diff --git a/api/config.py b/api/config.py
index 1ea94f7..d0490dd 100644
--- a/api/config.py
+++ b/api/config.py
@@ -389,14 +389,21 @@ def resolve_model_provider(model_id: str) -> tuple:
     """Resolve model name, provider, and base_url for AIAgent.
 
     Model IDs from the dropdown can be in several formats:
-    - 'claude-sonnet-4.6' (bare name, uses config default provider)
-    - 'anthropic/claude-sonnet-4.6' (OpenRouter format, provider/model)
-    - '@minimax:MiniMax-M2.7' (explicit provider hint from dropdown)
+    - 'claude-sonnet-4.6' (bare name, uses config default provider)
+    - 'anthropic/claude-sonnet-4.6' (OpenRouter-style provider/model)
+    - '@minimax:MiniMax-M2.7' (explicit provider hint from dropdown)
 
     The @provider:model format is used for models from non-default provider
     groups in the dropdown, so we can route them through the correct provider
     via resolve_runtime_provider(requested=provider) instead of the default.
 
+    Custom OpenAI-compatible endpoints are special: their model IDs often look
+    like provider/model (for example ``google/gemma-4-26b-a4b``), which would be
+    mistaken for an OpenRouter model if we only looked at the slash. To avoid
+    that, first check whether the selected model matches an entry in
+    config.yaml -> custom_providers and route it through that named custom
+    provider.
+
     Returns (model, provider, base_url) where provider and base_url may be None.
     """
     config_provider = None
@@ -410,6 +417,20 @@ def resolve_model_provider(model_id: str) -> tuple:
     if not model_id:
         return model_id, config_provider, config_base_url
 
+    # Custom providers declared in config.yaml should win over slash-based
+    # OpenRouter heuristics. Their model IDs commonly contain '/' too.
+    custom_providers = cfg.get('custom_providers', [])
+    if isinstance(custom_providers, list):
+        for entry in custom_providers:
+            if not isinstance(entry, dict):
+                continue
+            entry_model = (entry.get('model') or '').strip()
+            entry_name = (entry.get('name') or '').strip()
+            entry_base_url = (entry.get('base_url') or '').strip()
+            if entry_model and entry_name and model_id == entry_model:
+                provider_hint = 'custom:' + entry_name.lower().replace(' ', '-')
+                return model_id, provider_hint, entry_base_url or None
+
     # @provider:model format — explicit provider hint from the dropdown.
     # Route through that provider directly (resolve_runtime_provider will
     # resolve credentials in streaming.py).
diff --git a/tests/test_model_resolver.py b/tests/test_model_resolver.py
index ecad825..16bcf2a 100644
--- a/tests/test_model_resolver.py
+++ b/tests/test_model_resolver.py
@@ -6,8 +6,8 @@ tuples for different provider configurations.
 import api.config as config
 
 
-def _resolve_with_config(model_id, provider=None, base_url=None, default=None):
-    """Helper: temporarily set config.cfg model section, call resolve, restore."""
+def _resolve_with_config(model_id, provider=None, base_url=None, default=None, custom_providers=None):
+    """Helper: temporarily set config.cfg model/custom provider sections, call resolve, restore."""
     old_cfg = dict(config.cfg)
     model_cfg = {}
     if provider:
@@ -17,6 +17,8 @@ def _resolve_with_config(model_id, provider=None, base_url=None, default=None):
     if default:
         model_cfg['default'] = default
     config.cfg['model'] = model_cfg if model_cfg else {}
+    if custom_providers is not None:
+        config.cfg['custom_providers'] = custom_providers
     try:
         return config.resolve_model_provider(model_id)
     finally:
@@ -139,6 +141,23 @@ def test_slash_prefix_non_default_still_routes_openrouter():
     assert provider == 'openrouter'
 
 
+def test_custom_provider_model_with_slash_routes_to_named_custom_provider():
+    """Slash-containing custom endpoint model IDs must not be mistaken for OpenRouter
+    models."""
+    model, provider, base_url = _resolve_with_config(
+        'google/gemma-4-26b-a4b',
+        provider='openrouter',
+        base_url='https://openrouter.ai/api/v1',
+        custom_providers=[{
+            'name': 'Local LM Studio',
+            'base_url': 'http://lmstudio.local:1234/v1',
+            'model': 'google/gemma-4-26b-a4b',
+        }],
+    )
+    assert model == 'google/gemma-4-26b-a4b'
+    assert provider == 'custom:local-lm-studio'
+    assert base_url == 'http://lmstudio.local:1234/v1'
+
 # ── get_available_models() @provider: hint behaviour ──────────────────────
 
 def _available_models_with_provider(provider):