diff --git a/api/config.py b/api/config.py index 977923a..a2e993f 100644 --- a/api/config.py +++ b/api/config.py @@ -363,7 +363,8 @@ def get_available_models() -> dict: Discovery order: 1. Read config.yaml 'model' section for active provider info 2. Check for known API keys in env or ~/.hermes/.env - 3. Fall back to hardcoded model list (OpenRouter-style) + 3. Fetch models from custom endpoint if base_url is configured + 4. Fall back to hardcoded model list (OpenRouter-style) Returns: { 'active_provider': str|None, @@ -382,6 +383,7 @@ def get_available_models() -> dict: elif isinstance(model_cfg, dict): active_provider = model_cfg.get('provider') cfg_default = model_cfg.get('default', '') + cfg_base_url = model_cfg.get('base_url', '') if cfg_default: default_model = cfg_default @@ -442,6 +444,96 @@ def get_available_models() -> dict: if all_env.get('DEEPSEEK_API_KEY'): detected_providers.add('deepseek') + # 3. Fetch models from custom endpoint if base_url is configured + if cfg_base_url: + try: + import requests as _req + import ipaddress + + # Normalize the base_url + base_url = cfg_base_url.strip() + if base_url.endswith('/v1'): + endpoint_url = base_url[:-3] + '/models' + else: + endpoint_url = base_url + '/v1/models' + + # Detect provider from base_url + provider = 'custom' + normalized = base_url.strip('/') + parsed = urlparse(normalized if '://' in normalized else f'http://{normalized}') + host = parsed.netloc.lower() or parsed.path.lower() + + # Check if it's a local/private IP + if parsed.hostname: + try: + addr = ipaddress.ip_address(parsed.hostname) + if addr.is_private or addr.is_loopback or addr.is_link_local: + provider = 'local' + except ValueError: + pass + + # Get the API key for this provider + headers = {} + + # Try hermes-agent style API key resolution + if provider == 'local': + # For local endpoints, check common API key env vars + for key in ('HERMES_API_KEY', 'HERMES_OPENAI_API_KEY', 'OPENAI_API_KEY', + 'LOCAL_API_KEY', 'OPENROUTER_API_KEY', 'API_KEY'): + api_key = os.getenv(key) + if api_key: + headers['Authorization'] = f'Bearer {api_key}' + break + else: + # For known providers, use their specific key env vars + for key in ('OPENROUTER_API_KEY', 'OPENAI_API_KEY', 'API_KEY'): + api_key = os.getenv(key) + if api_key: + headers['Authorization'] = f'Bearer {api_key}' + break + + # Make the request + try: + resp = _req.get(endpoint_url, headers=headers, timeout=10) + resp.raise_for_status() + data = resp.json() + + # Parse the response - handle both OpenAI-compatible and llama.cpp formats + models_list = [] + + # OpenAI-compatible format: data is in 'data' key + if 'data' in data and isinstance(data['data'], list): + models_list = data['data'] + + # llama.cpp format: data is 'models' array at root + elif 'models' in data and isinstance(data['models'], list): + models_list = data['models'] + + for model in models_list: + if not isinstance(model, dict): + continue + model_id = model.get('id', '') or model.get('name', '') or model.get('model', '') + model_name = model.get('name', '') or model.get('model', '') or model_id + if model_id and model_name: + # Detect provider from model_id + detected_provider = 'Custom' + for pid in _PROVIDER_DISPLAY: + if pid in model_id.lower(): + detected_provider = _PROVIDER_DISPLAY.get(pid, pid.title()) + break + groups.append({ + 'provider': detected_provider, + 'models': [{'id': model_id, 'label': model_name}], + }) + detected_providers.add(provider.lower()) + except Exception as e: + # Endpoint unavailable, fall through to fallback list + logger.debug(f"Failed to fetch models from {endpoint_url}: {e}") + pass + except Exception: + # Import failed, fall through to fallback list + pass + # 5. Build model groups if detected_providers: for pid in sorted(detected_providers): diff --git a/api/routes.py b/api/routes.py index ee7ccea..5873f01 100644 --- a/api/routes.py +++ b/api/routes.py @@ -17,7 +17,7 @@ from api.config import ( SESSIONS, SESSIONS_MAX, LOCK, STREAMS, STREAMS_LOCK, CANCEL_FLAGS, SERVER_START_TIME, CLI_TOOLSETS, _INDEX_HTML_PATH, get_available_models, IMAGE_EXTS, MD_EXTS, MIME_MAP, MAX_FILE_BYTES, MAX_UPLOAD_BYTES, - CHAT_LOCK, load_settings, save_settings, + CHAT_LOCK, load_settings, save_settings, cfg, ) from api.helpers import require, bad, safe_resolve, j, t, read_body from api.models import ( @@ -679,12 +679,36 @@ def _handle_chat_start(handler, body): model = body.get('model') or s.model s.workspace = workspace; s.model = model; s.save() set_last_workspace(workspace) + + # Read base_url from config.yaml for this model + # cfg is a global variable loaded at module load time + model_cfg = cfg.get('model', {}) + base_url = model_cfg.get('base_url', '') + + # If a model is selected that matches our custom endpoint, use that base_url AND local provider + if base_url and model: + # Check if the model name contains "qwen" (your local model) + if 'qwen' in model.lower(): + # Use the local endpoint with "local" provider + if base_url.endswith('/v1'): + effective_base_url = base_url[:-3] + else: + effective_base_url = base_url + '/v1' + effective_provider = 'custom' + print(f"DEBUG: Using custom base_url for {model}: {effective_base_url}", file=sys.stderr) + stream_id = uuid.uuid4().hex q = queue.Queue() with STREAMS_LOCK: STREAMS[stream_id] = q + kwargs = {} + if 'effective_base_url' in locals(): + kwargs['base_url'] = effective_base_url + if 'effective_provider' in locals(): + kwargs['provider'] = effective_provider thr = threading.Thread( target=_run_agent_streaming, args=(s.session_id, msg, model, workspace, stream_id, attachments), + kwargs=kwargs, daemon=True, ) thr.start() diff --git a/api/streaming.py b/api/streaming.py index 3bf8bd9..34dcadb 100644 --- a/api/streaming.py +++ b/api/streaming.py @@ -32,7 +32,7 @@ def _sse(handler, event, data): handler.wfile.flush() -def _run_agent_streaming(session_id, msg_text, model, workspace, stream_id, attachments=None): +def _run_agent_streaming(session_id, msg_text, model, workspace, stream_id, attachments=None, base_url=None, provider=None): """Run agent in background thread, writing SSE events to STREAMS[stream_id].""" q = STREAMS.get(stream_id) if q is None: