refactor: keep only model discovery, drop redundant routing changes
- Revert routes.py and streaming.py to master: resolve_model_provider() already handles provider routing and base_url passthrough for all models.
- Fix indentation error in config.py (2-space indent on comment line).
- Fix auto_detected_models scope: initialize before try block.
- Remove unused urllib.parse import.
- Simplify unknown-provider model group logic.
- Remove verbose comments and redundant variable assignments.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
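For context on the first bullet: the reverted call site in routes.py (removed below) unpacked resolve_model_provider(model) as a three-tuple of (model, provider, base_url). A minimal sketch of that contract, with made-up routing rules for illustration only; the real implementation lives in api/config.py and is not shown in this diff:

    def resolve_model_provider(model: str) -> tuple[str, str, str]:
        # Illustrative sketch of the inferred contract, not the actual code.
        if model.startswith('ollama/'):
            # assumed convention: strip the prefix and route to a local endpoint
            return model.split('/', 1)[1], 'ollama', 'http://localhost:11434/v1'
        return model, 'openai', ''  # assumed default passthrough

Because this single function already returns the resolved provider and base_url, the per-request plumbing removed from routes.py and streaming.py below was redundant.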
api/config.py (118 changed lines)
@@ -446,14 +446,13 @@ def get_available_models() -> dict:
         detected_providers.add('deepseek')

     # 3. Fetch models from custom endpoint if base_url is configured
+    auto_detected_models = []
     if cfg_base_url:
-        auto_detected_models = []  # Store models fetched from endpoint
         try:
             import ipaddress
             import urllib.request
-            import urllib.parse

-            # Normalize the base_url
+            # Normalize the base_url and build models endpoint
             base_url = cfg_base_url.strip()
             if base_url.endswith('/v1'):
                 endpoint_url = base_url[:-3] + '/models'
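Moving the auto_detected_models initialization above the conditional matters because the list is read later in the group-building code even when no base_url is configured. A standalone sketch of the bug the old placement caused:

    def old_placement(cfg_base_url=None):
        if cfg_base_url:
            auto_detected_models = []  # old: bound only when base_url is set
        return auto_detected_models    # read happens unconditionally

    try:
        old_placement()
    except UnboundLocalError as e:
        print(e)  # 'auto_detected_models' referenced before assignment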
@@ -462,16 +461,13 @@ def get_available_models() -> dict:

             # Detect provider from base_url
             provider = 'custom'
-            normalized = base_url.strip('/')
-            parsed = urlparse(normalized if '://' in normalized else f'http://{normalized}')
-            host = parsed.netloc.lower() or parsed.path.lower()
+            parsed = urlparse(base_url if '://' in base_url else f'http://{base_url}')
+            host = (parsed.netloc or parsed.path).lower()

-            # Check if it's a local/private IP
             if parsed.hostname:
                 try:
                     addr = ipaddress.ip_address(parsed.hostname)
                     if addr.is_private or addr.is_loopback or addr.is_link_local:
-                        # Detect specific local provider based on hostname
                         if 'ollama' in host or '127.0.0.1' in host or 'localhost' in host:
                             provider = 'ollama'
                         elif 'lmstudio' in host or 'lm-studio' in host:
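The scheme guard before urlparse() exists because the standard library only populates netloc (and hence hostname) when a scheme is present; otherwise the host lands in path, which is also why host falls back to parsed.path. A quick standard-library illustration:

    from urllib.parse import urlparse

    bare = urlparse('192.168.1.10')          # no scheme
    assert bare.netloc == '' and bare.path == '192.168.1.10'
    assert bare.hostname is None             # the private-IP check would be skipped

    fixed = urlparse('http://192.168.1.10')  # after prepending 'http://'
    assert fixed.netloc == '192.168.1.10'
    assert fixed.hostname == '192.168.1.10'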
@@ -481,70 +477,40 @@ def get_available_models() -> dict:
                 except ValueError:
                     pass

-            # Get the API key for this provider
+            # Resolve API key from environment
            headers = {}
+            api_key_vars = ('HERMES_API_KEY', 'HERMES_OPENAI_API_KEY', 'OPENAI_API_KEY',
+                            'LOCAL_API_KEY', 'OPENROUTER_API_KEY', 'API_KEY')
+            for key in api_key_vars:
+                api_key = os.getenv(key)
+                if api_key:
+                    headers['Authorization'] = f'Bearer {api_key}'
+                    break

-            # Try hermes-agent style API key resolution
-            if provider == 'local':
-                # For local endpoints, check common API key env vars
-                for key in ('HERMES_API_KEY', 'HERMES_OPENAI_API_KEY', 'OPENAI_API_KEY',
-                            'LOCAL_API_KEY', 'OPENROUTER_API_KEY', 'API_KEY'):
-                    api_key = os.getenv(key)
-                    if api_key:
-                        headers['Authorization'] = f'Bearer {api_key}'
-                        break
-            else:
-                # For known providers, use their specific key env vars
-                for key in ('OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'API_KEY'):
-                    api_key = os.getenv(key)
-                    if api_key:
-                        headers['Authorization'] = f'Bearer {api_key}'
-                        break
+            # Fetch model list from endpoint
+            req = urllib.request.Request(endpoint_url, method='GET')
+            for k, v in headers.items():
+                req.add_header(k, v)
+            with urllib.request.urlopen(req, timeout=10) as response:
+                data = json.loads(response.read().decode('utf-8'))

-            # Make the request using urllib.request
-            try:
-                # Build request URL
-                url = endpoint_url
-
-                # Prepare request
-                req = urllib.request.Request(url, method='GET')
-                for key, value in headers.items():
-                    req.add_header(key, value)
-
-                # Send request with timeout
-                with urllib.request.urlopen(req, timeout=10) as response:
-                    data = json.loads(response.read().decode('utf-8'))
-
-                # Parse the response - handle both OpenAI-compatible and llama.cpp formats
-                models_list = []
-
-                # OpenAI-compatible format: data is in 'data' key
-                if 'data' in data and isinstance(data['data'], list):
-                    models_list = data['data']
-
-                # llama.cpp format: data is 'models' array at root
-                elif 'models' in data and isinstance(data['models'], list):
-                    models_list = data['models']
-
-                for model in models_list:
-                    if not isinstance(model, dict):
-                        continue
-                    model_id = model.get('id', '') or model.get('name', '') or model.get('model', '')
-                    model_name = model.get('name', '') or model.get('model', '') or model_id
-                    if model_id and model_name:
-                        # Store model in auto_detected_models for later use
-                        auto_detected_models.append({
-                            'id': model_id,
-                            'label': model_name
-                        })
-                detected_providers.add(provider.lower())
-            except Exception as e:
-                # Endpoint unavailable, fall through to fallback list
-                logger.debug(f"Failed to fetch models from {endpoint_url}: {e}")
-                pass
-        except Exception:
-            # Import failed, fall through to fallback list
-            pass
+            # Handle both OpenAI-compatible and llama.cpp response formats
+            models_list = []
+            if 'data' in data and isinstance(data['data'], list):
+                models_list = data['data']
+            elif 'models' in data and isinstance(data['models'], list):
+                models_list = data['models']
+
+            for model in models_list:
+                if not isinstance(model, dict):
+                    continue
+                model_id = model.get('id', '') or model.get('name', '') or model.get('model', '')
+                model_name = model.get('name', '') or model.get('model', '') or model_id
+                if model_id and model_name:
+                    auto_detected_models.append({'id': model_id, 'label': model_name})
+            detected_providers.add(provider.lower())
+        except Exception as e:
+            logger.debug(f"Failed to fetch models from custom endpoint: {e}")

     # 5. Build model groups
     if detected_providers:
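The consolidated fetch-and-parse block accepts both payload shapes named in the comment: an OpenAI-style /v1/models response and a llama.cpp-style root models array. A self-contained sketch of the normalization with illustrative payloads (model ids are made up, and the isinstance guards of the real code are elided):

    openai_style = {'data': [{'id': 'gpt-4o', 'object': 'model'}]}
    llamacpp_style = {'models': [{'name': 'qwen2.5-7b-instruct'}]}

    def extract(data):
        # Simplified version of the loop above: normalize either shape
        # into the {'id': ..., 'label': ...} entries the UI consumes.
        models = data.get('data') or data.get('models') or []
        out = []
        for m in models:
            model_id = m.get('id', '') or m.get('name', '') or m.get('model', '')
            name = m.get('name', '') or m.get('model', '') or model_id
            if model_id and name:
                out.append({'id': model_id, 'label': name})
        return out

    assert extract(openai_style) == [{'id': 'gpt-4o', 'label': 'gpt-4o'}]
    assert extract(llamacpp_style) == [{'id': 'qwen2.5-7b-instruct',
                                        'label': 'qwen2.5-7b-instruct'}]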
@@ -562,22 +528,14 @@ def get_available_models() -> dict:
                 'models': _PROVIDER_MODELS[pid],
             })
         else:
-            # Unknown provider with key
-            # If we have auto-detected models from base_url, use those instead of hardcoded default
-            if cfg_base_url and cfg_default:
-                # Use the default model from config
+            # Unknown provider -- use auto-detected models if available,
+            # otherwise fall back to default model placeholder
+            if auto_detected_models:
                 groups.append({
                     'provider': provider_name,
-                    'models': [{'id': default_model, 'label': default_model.split('/')[-1]}],
-                })
-            elif cfg_base_url:
-                # Use auto-detected models from the endpoint
-                groups.append({
-                    'provider': provider_name,
-                    'models': [{'id': model['id'], 'label': model['label']} for model in auto_detected_models],
+                    'models': auto_detected_models,
                 })
             else:
-                # Fallback to placeholder with default model
                 groups.append({
                     'provider': provider_name,
                     'models': [{'id': default_model, 'label': default_model.split('/')[-1]}],
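In the fallback branch, the placeholder label is just the last path segment of the configured default model. A quick illustration with an assumed model id:

    default_model = 'meta-llama/llama-3.1-8b-instruct'  # illustrative value
    placeholder = [{'id': default_model, 'label': default_model.split('/')[-1]}]
    assert placeholder == [{'id': 'meta-llama/llama-3.1-8b-instruct',
                            'label': 'llama-3.1-8b-instruct'}]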

api/routes.py

@@ -17,7 +17,7 @@ from api.config import (
     SESSIONS, SESSIONS_MAX, LOCK, STREAMS, STREAMS_LOCK, CANCEL_FLAGS,
     SERVER_START_TIME, CLI_TOOLSETS, _INDEX_HTML_PATH, get_available_models,
     IMAGE_EXTS, MD_EXTS, MIME_MAP, MAX_FILE_BYTES, MAX_UPLOAD_BYTES,
-    CHAT_LOCK, load_settings, save_settings, cfg,
+    CHAT_LOCK, load_settings, save_settings,
 )
 from api.helpers import require, bad, safe_resolve, j, t, read_body
 from api.models import (
@@ -679,28 +679,12 @@ def _handle_chat_start(handler, body):
     model = body.get('model') or s.model
     s.workspace = workspace; s.model = model; s.save()
     set_last_workspace(workspace)
-
-    # Read base_url from config.yaml for this model
-    # cfg is a global variable loaded at module load time
-    model_cfg = cfg.get('model', {})
-    base_url = model_cfg.get('base_url', '')
-
-    # Use resolve_model_provider to get the correct model, provider, and base_url
-    # This handles all providers including local Ollama/LM Studio endpoints
-    from api.config import resolve_model_provider
-    resolved_model, resolved_provider, resolved_base_url = resolve_model_provider(model)
-
     stream_id = uuid.uuid4().hex
     q = queue.Queue()
     with STREAMS_LOCK: STREAMS[stream_id] = q
-    kwargs = {}
-    # Pass resolved provider and base_url to the streaming handler
-    kwargs['provider'] = resolved_provider
-    kwargs['base_url'] = resolved_base_url
     thr = threading.Thread(
         target=_run_agent_streaming,
-        args=(s.session_id, msg, resolved_model, workspace, stream_id, attachments),
-        kwargs=kwargs,
+        args=(s.session_id, msg, model, workspace, stream_id, attachments),
         daemon=True,
     )
     thr.start()
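After this revert the handler passes the raw model string straight through and lets the agent layer resolve provider and base_url (per the commit message, via resolve_model_provider()). A schematic of the restored launch path, with hypothetical session values and a stub in place of the real streaming function:

    import queue, threading, uuid

    stream_id = uuid.uuid4().hex
    q = queue.Queue()

    def _run_agent_streaming(session_id, msg_text, model, workspace, stream_id, attachments=None):
        q.put(('token', f'echo: {msg_text}'))  # stand-in for real agent streaming

    thr = threading.Thread(
        target=_run_agent_streaming,
        args=('sess-1', 'hello', 'gpt-4o', '/tmp/ws', stream_id, None),
        daemon=True,
    )
    thr.start(); thr.join()
    print(q.get())  # ('token', 'echo: hello')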

api/streaming.py

@@ -32,7 +32,7 @@ def _sse(handler, event, data):
     handler.wfile.flush()


-def _run_agent_streaming(session_id, msg_text, model, workspace, stream_id, attachments=None, base_url=None, provider=None):
+def _run_agent_streaming(session_id, msg_text, model, workspace, stream_id, attachments=None):
     """Run agent in background thread, writing SSE events to STREAMS[stream_id]."""
     q = STREAMS.get(stream_id)
     if q is None: