diff --git a/api/config.py b/api/config.py
index 84e023c..f33bfe5 100644
--- a/api/config.py
+++ b/api/config.py
@@ -446,32 +446,28 @@ def get_available_models() -> dict:
         detected_providers.add('deepseek')
 
     # 3. Fetch models from custom endpoint if base_url is configured
+    auto_detected_models = []
     if cfg_base_url:
-        auto_detected_models = []  # Store models fetched from endpoint
         try:
             import ipaddress
             import urllib.request
-            import urllib.parse
-
-            # Normalize the base_url
+
+            # Normalize the base_url and build models endpoint
             base_url = cfg_base_url.strip()
             if base_url.endswith('/v1'):
                 endpoint_url = base_url[:-3] + '/models'
             else:
                 endpoint_url = base_url + '/v1/models'
-
+
             # Detect provider from base_url
             provider = 'custom'
-            normalized = base_url.strip('/')
-            parsed = urlparse(normalized if '://' in normalized else f'http://{normalized}')
-            host = parsed.netloc.lower() or parsed.path.lower()
-
-            # Check if it's a local/private IP
+            parsed = urlparse(base_url if '://' in base_url else f'http://{base_url}')
+            host = (parsed.netloc or parsed.path).lower()
+
             if parsed.hostname:
                 try:
                     addr = ipaddress.ip_address(parsed.hostname)
                     if addr.is_private or addr.is_loopback or addr.is_link_local:
-                        # Detect specific local provider based on hostname
                         if 'ollama' in host or '127.0.0.1' in host or 'localhost' in host:
                             provider = 'ollama'
                         elif 'lmstudio' in host or 'lm-studio' in host:
@@ -480,71 +476,41 @@ def get_available_models() -> dict:
                             provider = 'local'
                 except ValueError:
                     pass
-
-            # Get the API key for this provider
+
+            # Resolve API key from environment
             headers = {}
-
-            # Try hermes-agent style API key resolution
-            if provider == 'local':
-                # For local endpoints, check common API key env vars
-                for key in ('HERMES_API_KEY', 'HERMES_OPENAI_API_KEY', 'OPENAI_API_KEY',
-                            'LOCAL_API_KEY', 'OPENROUTER_API_KEY', 'API_KEY'):
-                    api_key = os.getenv(key)
-                    if api_key:
-                        headers['Authorization'] = f'Bearer {api_key}'
-                        break
-            else:
-                # For known providers, use their specific key env vars
-                for key in ('OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'API_KEY'):
-                    api_key = os.getenv(key)
-                    if api_key:
-                        headers['Authorization'] = f'Bearer {api_key}'
-                        break
-
-            # Make the request using urllib.request
-            try:
-                # Build request URL
-                url = endpoint_url
-
-                # Prepare request
-                req = urllib.request.Request(url, method='GET')
-                for key, value in headers.items():
-                    req.add_header(key, value)
-
-                # Send request with timeout
-                with urllib.request.urlopen(req, timeout=10) as response:
-                    data = json.loads(response.read().decode('utf-8'))
-
-                # Parse the response - handle both OpenAI-compatible and llama.cpp formats
-                models_list = []
-
-                # OpenAI-compatible format: data is in 'data' key
-                if 'data' in data and isinstance(data['data'], list):
-                    models_list = data['data']
-
-                # llama.cpp format: data is 'models' array at root
-                elif 'models' in data and isinstance(data['models'], list):
-                    models_list = data['models']
-
-                for model in models_list:
-                    if not isinstance(model, dict):
-                        continue
-                    model_id = model.get('id', '') or model.get('name', '') or model.get('model', '')
-                    model_name = model.get('name', '') or model.get('model', '') or model_id
-                    if model_id and model_name:
-                        # Store model in auto_detected_models for later use
-                        auto_detected_models.append({
-                            'id': model_id,
-                            'label': model_name
-                        })
-                        detected_providers.add(provider.lower())
-            except Exception as e:
-                # Endpoint unavailable, fall through to fallback list
-                logger.debug(f"Failed to fetch models from {endpoint_url}: {e}")
-                pass
-        except Exception:
-            # Import failed, fall through to fallback list
-            pass
+            api_key_vars = ('HERMES_API_KEY', 'HERMES_OPENAI_API_KEY', 'OPENAI_API_KEY',
+                            'LOCAL_API_KEY', 'OPENROUTER_API_KEY', 'API_KEY')
+            for key in api_key_vars:
+                api_key = os.getenv(key)
+                if api_key:
+                    headers['Authorization'] = f'Bearer {api_key}'
+                    break
+
+            # Fetch model list from endpoint
+            req = urllib.request.Request(endpoint_url, method='GET')
+            for k, v in headers.items():
+                req.add_header(k, v)
+            with urllib.request.urlopen(req, timeout=10) as response:
+                data = json.loads(response.read().decode('utf-8'))
+
+            # Handle both OpenAI-compatible and llama.cpp response formats
+            models_list = []
+            if 'data' in data and isinstance(data['data'], list):
+                models_list = data['data']
+            elif 'models' in data and isinstance(data['models'], list):
+                models_list = data['models']
+
+            for model in models_list:
+                if not isinstance(model, dict):
+                    continue
+                model_id = model.get('id', '') or model.get('name', '') or model.get('model', '')
+                model_name = model.get('name', '') or model.get('model', '') or model_id
+                if model_id and model_name:
+                    auto_detected_models.append({'id': model_id, 'label': model_name})
+                    detected_providers.add(provider.lower())
+        except Exception as e:
+            logger.debug(f"Failed to fetch models from custom endpoint: {e}")
 
     # 5. Build model groups
     if detected_providers:
@@ -562,22 +528,14 @@ def get_available_models() -> dict:
                     'models': _PROVIDER_MODELS[pid],
                 })
             else:
-                # Unknown provider with key
-                # If we have auto-detected models from base_url, use those instead of hardcoded default
-                if cfg_base_url and cfg_default:
-                    # Use the default model from config
+                # Unknown provider -- use auto-detected models if available,
+                # otherwise fall back to default model placeholder
+                if auto_detected_models:
                     groups.append({
                         'provider': provider_name,
-                        'models': [{'id': default_model, 'label': default_model.split('/')[-1]}],
-                    })
-                elif cfg_base_url:
-                    # Use auto-detected models from the endpoint
-                    groups.append({
-                        'provider': provider_name,
-                        'models': [{'id': model['id'], 'label': model['label']} for model in auto_detected_models],
+                        'models': auto_detected_models,
                     })
                 else:
-                    # Fallback to placeholder with default model
                     groups.append({
                         'provider': provider_name,
                         'models': [{'id': default_model, 'label': default_model.split('/')[-1]}],
diff --git a/api/routes.py b/api/routes.py
index 153cff7..ee7ccea 100644
--- a/api/routes.py
+++ b/api/routes.py
@@ -17,7 +17,7 @@ from api.config import (
     SESSIONS, SESSIONS_MAX, LOCK, STREAMS, STREAMS_LOCK, CANCEL_FLAGS,
     SERVER_START_TIME, CLI_TOOLSETS, _INDEX_HTML_PATH, get_available_models,
     IMAGE_EXTS, MD_EXTS, MIME_MAP, MAX_FILE_BYTES, MAX_UPLOAD_BYTES,
-    CHAT_LOCK, load_settings, save_settings, cfg,
+    CHAT_LOCK, load_settings, save_settings,
 )
 from api.helpers import require, bad, safe_resolve, j, t, read_body
 from api.models import (
@@ -679,28 +679,12 @@ def _handle_chat_start(handler, body):
     model = body.get('model') or s.model
     s.workspace = workspace; s.model = model; s.save()
     set_last_workspace(workspace)
-
-    # Read base_url from config.yaml for this model
-    # cfg is a global variable loaded at module load time
-    model_cfg = cfg.get('model', {})
-    base_url = model_cfg.get('base_url', '')
-
-    # Use resolve_model_provider to get the correct model, provider, and base_url
-    # This handles all providers including local Ollama/LM Studio endpoints
-    from api.config import resolve_model_provider
-    resolved_model, resolved_provider, resolved_base_url = resolve_model_provider(model)
-
     stream_id = uuid.uuid4().hex
     q = queue.Queue()
     with STREAMS_LOCK:
         STREAMS[stream_id] = q
-    kwargs = {}
-    # Pass resolved provider and base_url to the streaming handler
-    kwargs['provider'] = resolved_provider
-    kwargs['base_url'] = resolved_base_url
     thr = threading.Thread(
         target=_run_agent_streaming,
-        args=(s.session_id, msg, resolved_model, workspace, stream_id, attachments),
-        kwargs=kwargs,
+        args=(s.session_id, msg, model, workspace, stream_id, attachments),
         daemon=True,
     )
     thr.start()
diff --git a/api/streaming.py b/api/streaming.py
index 34dcadb..3bf8bd9 100644
--- a/api/streaming.py
+++ b/api/streaming.py
@@ -32,7 +32,7 @@ def _sse(handler, event, data):
     handler.wfile.flush()
 
 
-def _run_agent_streaming(session_id, msg_text, model, workspace, stream_id, attachments=None, base_url=None, provider=None):
+def _run_agent_streaming(session_id, msg_text, model, workspace, stream_id, attachments=None):
     """Run agent in background thread, writing SSE events to STREAMS[stream_id]."""
     q = STREAMS.get(stream_id)
     if q is None:
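Review aid: the sketch below pulls the consolidated fetch-and-detect path from the api/config.py hunks into a standalone function, so the new control flow can be read (and run) without the surrounding module. It is a minimal sketch, not the module's API: `fetch_custom_models` and the `__main__` URL are hypothetical, and the `provider = 'lmstudio'` assignment is inferred from lines elided between the first two hunks; everything else mirrors the patch.

```python
# Standalone sketch of the consolidated fetch path added in api/config.py.
# fetch_custom_models() is a hypothetical name; the 'lmstudio' assignment is
# assumed (elided between hunks). The rest mirrors the patch.
import ipaddress
import json
import os
import urllib.request
from urllib.parse import urlparse


def fetch_custom_models(cfg_base_url, timeout=10):
    """Return (provider, models) for an OpenAI-compatible models endpoint."""
    # Normalize the base_url and build the models endpoint, as in the patch:
    # a trailing '/v1' is replaced by '/models', otherwise '/v1/models' is appended.
    base_url = cfg_base_url.strip()
    if base_url.endswith('/v1'):
        endpoint_url = base_url[:-3] + '/models'
    else:
        endpoint_url = base_url + '/v1/models'

    # Detect provider: local/private addresses map to ollama/lmstudio/local.
    provider = 'custom'
    parsed = urlparse(base_url if '://' in base_url else f'http://{base_url}')
    host = (parsed.netloc or parsed.path).lower()
    if parsed.hostname:
        try:
            addr = ipaddress.ip_address(parsed.hostname)
            if addr.is_private or addr.is_loopback or addr.is_link_local:
                if 'ollama' in host or '127.0.0.1' in host or 'localhost' in host:
                    provider = 'ollama'
                elif 'lmstudio' in host or 'lm-studio' in host:
                    provider = 'lmstudio'  # assumed; elided between hunks
                else:
                    provider = 'local'
        except ValueError:
            pass  # hostname is not an IP literal

    # Single flat env-var search; first match wins.
    headers = {}
    for key in ('HERMES_API_KEY', 'HERMES_OPENAI_API_KEY', 'OPENAI_API_KEY',
                'LOCAL_API_KEY', 'OPENROUTER_API_KEY', 'API_KEY'):
        api_key = os.getenv(key)
        if api_key:
            headers['Authorization'] = f'Bearer {api_key}'
            break

    req = urllib.request.Request(endpoint_url, method='GET')
    for k, v in headers.items():
        req.add_header(k, v)
    with urllib.request.urlopen(req, timeout=timeout) as response:
        data = json.loads(response.read().decode('utf-8'))

    # OpenAI-compatible servers nest the list under 'data'; llama.cpp uses 'models'.
    if isinstance(data.get('data'), list):
        models_list = data['data']
    elif isinstance(data.get('models'), list):
        models_list = data['models']
    else:
        models_list = []

    models = []
    for model in models_list:
        if not isinstance(model, dict):
            continue
        model_id = model.get('id', '') or model.get('name', '') or model.get('model', '')
        model_name = model.get('name', '') or model.get('model', '') or model_id
        if model_id and model_name:
            models.append({'id': model_id, 'label': model_name})
    return provider, models


if __name__ == '__main__':
    # Requires a running OpenAI-compatible server; the URL is illustrative only.
    print(fetch_custom_models('http://127.0.0.1:11434/v1'))
```

One behavioral note on the patch itself: the flattened key search now applies the full HERMES_*/LOCAL_* list to every provider, whereas the removed code reserved that list for `provider == 'local'` and tried only OPENROUTER_API_KEY/OPENAI_API_KEY/API_KEY otherwise.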