diff --git a/api/streaming.py b/api/streaming.py index eeb8907..8ee5bc8 100644 --- a/api/streaming.py +++ b/api/streaming.py @@ -463,12 +463,29 @@ def _run_agent_streaming(session_id, msg_text, model, workspace, stream_id, atta # Detect rate limit errors specifically so the client can show a helpful card # rather than the generic "Connection lost" message is_rate_limit = 'rate limit' in err_str.lower() or '429' in err_str or 'RateLimitError' in type(e).__name__ + is_auth_error = ( + '401' in err_str + or 'AuthenticationError' in type(e).__name__ + or 'authentication' in err_str.lower() + or 'unauthorized' in err_str.lower() + or 'invalid api key' in err_str.lower() + or 'no cookie auth credentials' in err_str.lower() + ) if is_rate_limit: put('apperror', { 'message': err_str, 'type': 'rate_limit', 'hint': 'Rate limit reached. The fallback model (if configured) was also exhausted. Try again in a moment.', }) + elif is_auth_error: + put('apperror', { + 'message': err_str, + 'type': 'auth_mismatch', + 'hint': ( + 'The selected model may not be supported by your configured provider. ' + 'Run `hermes model` in your terminal to switch providers, then restart the WebUI.' + ), + }) else: put('apperror', {'message': err_str, 'type': 'error'}) finally: diff --git a/static/boot.js b/static/boot.js index 650a06b..13d42b1 100644 --- a/static/boot.js +++ b/static/boot.js @@ -209,6 +209,11 @@ $('modelSelect').onchange=async()=>{ localStorage.setItem('hermes-webui-model', selectedModel); await api('/api/session/update',{method:'POST',body:JSON.stringify({session_id:S.session.session_id,workspace:S.session.workspace,model:selectedModel})}); S.session.model=selectedModel;syncTopbar(); + // Warn if selected model belongs to a different provider than what Hermes is configured for + if(typeof _checkProviderMismatch==='function'){ + const warn=_checkProviderMismatch(selectedModel); + if(warn&&typeof showToast==='function') showToast(warn,4000); + } }; $('msg').addEventListener('input',()=>{ autoResize(); diff --git a/static/i18n.js b/static/i18n.js index b81255b..777d533 100644 --- a/static/i18n.js +++ b/static/i18n.js @@ -48,6 +48,8 @@ const LOCALES = { n_messages: (n) => `${n} messages`, model_unavailable: ' (unavailable)', model_unavailable_title: 'This model is no longer in your current provider list', + provider_mismatch_warning: (m,p)=>`"${m}" may not work with your configured provider (${p}). Send anyway, or run \`hermes model\` in your terminal to switch.`, + provider_mismatch_label: 'Provider mismatch', // commands.js cmd_help: 'List available commands', cmd_clear: 'Clear conversation messages', @@ -237,6 +239,8 @@ const LOCALES = { n_messages: (n) => `${n} mensajes`, model_unavailable: ' (no disponible)', model_unavailable_title: 'Este modelo ya no está en tu lista actual de proveedores', + provider_mismatch_warning: (m,p)=>`"${m}" puede no funcionar con tu proveedor configurado (${p}). Envía de todas formas, o ejecuta \`hermes model\` en la terminal para cambiar.`, + provider_mismatch_label: 'Proveedor incompatible', // commands.js cmd_help: 'Listar los comandos disponibles', cmd_clear: 'Borrar los mensajes de la conversación', @@ -426,6 +430,8 @@ const LOCALES = { n_messages: (n) => `${n} Nachrichten`, model_unavailable: ' (nicht verfügbar)', model_unavailable_title: 'Dieses Modell ist nicht mehr in Ihrer aktuellen Provider-Liste', + provider_mismatch_warning: (m,p)=>`"${m}" funktioniert möglicherweise nicht mit Ihrem konfigurierten Provider (${p}). Trotzdem senden, oder \`hermes model\` im Terminal ausführen.`, + provider_mismatch_label: 'Provider-Konflikt', // commands.js cmd_help: 'Verfügbare Befehle auflisten', cmd_clear: 'Konversationsverlauf löschen', @@ -615,6 +621,8 @@ const LOCALES = { n_messages: (n) => `${n} \u6761\u6d88\u606f`, model_unavailable: '\uff08\u4e0d\u53ef\u7528\uff09', model_unavailable_title: '\u8fd9\u4e2a\u6a21\u578b\u5df2\u7ecf\u4e0d\u5728\u5f53\u524d provider \u5217\u8868\u4e2d', + provider_mismatch_warning: (m,p)=>`\"${m}\" \u53ef\u80fd\u65e0\u6cd5\u5728\u5f53\u524d\u914d\u7f6e\u7684\u63d0\u4f9b\u5546 (${p}) \u4e0b\u5de5\u4f5c\u3002\u76f4\u63a5\u53d1\u9001\uff0c\u6216\u5728\u7ec8\u7aef\u8fd0\u884c \`hermes model\` \u5207\u6362\u3002`, + provider_mismatch_label: '\u63d0\u4f9b\u5546\u4e0d\u5339\u914d', // commands.js cmd_help: '\u67e5\u770b\u53ef\u7528\u547d\u4ee4', cmd_clear: '\u6e05\u7a7a\u5f53\u524d\u5bf9\u8bdd\u6d88\u606f', @@ -802,6 +810,8 @@ const LOCALES = { n_messages: (n) => `${n} \u689d\u8a0a\u606f`, model_unavailable: '\uff08\u4e0d\u53ef\u7528\uff09', model_unavailable_title: '\u6b64\u6a21\u578b\u5df2\u7d93\u4e0d\u5728\u7576\u524d provider \u5217\u8868\u4e2d', + provider_mismatch_warning: (m,p)=>`\"${m}\" \u53ef\u80fd\u7121\u6cd5\u5728\u7576\u524d\u914d\u7f6e\u7684\u63d0\u4f9b\u8005 (${p}) \u4e0b\u904b\u4f5c\u3002\u5c1a\u9001\uff0c\u6216\u5728\u7d42\u7aef\u57f7\u884c \`hermes model\` \u5207\u63db\u3002`, + provider_mismatch_label: '\u63d0\u4f9b\u8005\u4e0d\u76f8\u7b26', // commands.js cmd_help: '\u67e5\u770b\u53ef\u7528\u547d\u4ee4', cmd_clear: '\u6e05\u7a7a\u7576\u524d\u5c0d\u8a71\u8a0a\u606f', diff --git a/static/messages.js b/static/messages.js index a379be4..0f5ec33 100644 --- a/static/messages.js +++ b/static/messages.js @@ -234,7 +234,8 @@ async function send(){ try{ const d=JSON.parse(e.data); const isRateLimit=d.type==='rate_limit'; - const label=isRateLimit?'Rate limit reached':'Error'; + const isAuthMismatch=d.type==='auth_mismatch'; + const label=isRateLimit?'Rate limit reached':isAuthMismatch?(typeof t==='function'?t('provider_mismatch_label'):'Provider mismatch'):'Error'; const hint=d.hint?`\n\n*${d.hint}*`:''; S.messages.push({role:'assistant',content:`**${label}:** ${d.message}${hint}`}); }catch(_){ diff --git a/static/ui.js b/static/ui.js index 987d0f5..2396409 100644 --- a/static/ui.js +++ b/static/ui.js @@ -45,6 +45,8 @@ async function populateModelDropdown(){ try{ const data=await fetch(new URL('/api/models',location.origin).href,{credentials:'include'}).then(r=>r.json()); if(!data.groups||!data.groups.length) return; // keep HTML defaults + // Store active provider globally so the send path can warn on mismatch + window._activeProvider=data.active_provider||null; // Clear existing options sel.innerHTML=''; _dynamicModelLabels={}; @@ -70,6 +72,32 @@ async function populateModelDropdown(){ } } +/** + * Check if the given model ID belongs to a different provider than the one + * currently configured in Hermes. Returns a warning string if mismatched, + * or null if the selection looks compatible. + * + * Provider detection is intentionally loose — we compare the model's slash + * prefix (e.g. "openai/" from "openai/gpt-4o") against the active provider + * name. Custom/local endpoints report active_provider='custom' or the + * base_url hostname and we skip the check to avoid false positives. + */ +function _checkProviderMismatch(modelId){ + const ap=(window._activeProvider||'').toLowerCase(); + if(!ap||ap==='custom'||ap==='openrouter') return null; // can't reliably check + const slash=modelId.indexOf('/'); + if(slash<0) return null; // bare model name, no provider prefix + const modelProvider=modelId.substring(0,slash).toLowerCase(); + // Normalise common aliases + const aliases={'claude':'anthropic','gpt':'openai','gemini':'google'}; + const norm=p=>aliases[p]||p; + if(norm(modelProvider)!==norm(ap)){ + return (window.t?window.t('provider_mismatch_warning',modelId,ap): + `"${modelId}" may not work with your configured provider (${ap}). Send anyway or run \`hermes model\` to switch.`); + } + return null; +} + // ── Scroll pinning ────────────────────────────────────────────────────────── // When streaming, auto-scroll only if the user hasn't manually scrolled up. // Once the user scrolls back to within 80px of the bottom, re-pin. diff --git a/tests/test_provider_mismatch.py b/tests/test_provider_mismatch.py new file mode 100644 index 0000000..f0e9ef9 --- /dev/null +++ b/tests/test_provider_mismatch.py @@ -0,0 +1,266 @@ +""" +Tests for issue #266 — provider/model mismatch warning. + +Covers: + 1. streaming.py: auth errors detected and classified as 'auth_mismatch' + 2. static/ui.js: _checkProviderMismatch() helper exists and logic is correct + 3. static/messages.js: apperror handler has auth_mismatch branch + 4. static/i18n.js: provider_mismatch_warning and provider_mismatch_label keys + present in all 5 locales (en, es, de, zh, zh-Hant) + 5. static/boot.js: modelSelect.onchange calls _checkProviderMismatch + 6. /api/models: response includes active_provider field +""" +import json +import pathlib +import re +import urllib.request + +REPO_ROOT = pathlib.Path(__file__).parent.parent.resolve() +BASE = "http://127.0.0.1:8788" + + +def _read(rel_path: str) -> str: + return (REPO_ROOT / rel_path).read_text(encoding="utf-8") + + +# ── 1. streaming.py: auth error detection ─────────────────────────────────── + +class TestStreamingAuthErrorDetection: + """streaming.py must classify auth/401 errors as auth_mismatch.""" + + def test_auth_mismatch_type_defined_in_streaming(self): + """'auth_mismatch' type must be emitted for auth errors.""" + src = _read("api/streaming.py") + assert "auth_mismatch" in src, ( + "auth_mismatch type not found in streaming.py — " + "401/auth errors will not be surfaced with a helpful message" + ) + + def test_is_auth_error_flag_defined(self): + """is_auth_error variable must exist in the error handler.""" + src = _read("api/streaming.py") + assert "is_auth_error" in src, ( + "is_auth_error flag not found in streaming.py" + ) + + def test_auth_error_detects_401(self): + """'401' must be part of the auth error detection logic.""" + src = _read("api/streaming.py") + # Find the is_auth_error block + idx = src.find("is_auth_error") + assert idx != -1 + block = src[idx:idx + 400] + assert "'401'" in block or '"401"' in block, ( + "'401' not in is_auth_error detection block" + ) + + def test_auth_error_detects_unauthorized(self): + """'unauthorized' must be part of the auth error detection logic.""" + src = _read("api/streaming.py") + idx = src.find("is_auth_error") + block = src[idx:idx + 400] + assert "unauthorized" in block.lower(), ( + "'unauthorized' not in is_auth_error detection block" + ) + + def test_auth_error_hint_mentions_hermes_model(self): + """The auth_mismatch hint must mention 'hermes model' command.""" + src = _read("api/streaming.py") + # Find the auth_mismatch apperror block + idx = src.find("auth_mismatch") + block = src[idx:idx + 500] + assert "hermes model" in block, ( + "auth_mismatch hint must mention 'hermes model' command " + "so users know how to fix provider mismatch" + ) + + def test_auth_error_does_not_catch_rate_limit(self): + """Rate limit errors must not be reclassified as auth_mismatch.""" + src = _read("api/streaming.py") + # is_rate_limit must come before is_auth_error in the elif chain + rl_idx = src.find("is_rate_limit") + ae_idx = src.find("is_auth_error") + assert rl_idx < ae_idx, ( + "is_rate_limit check should precede is_auth_error — " + "rate limit errors must not be mistaken for auth errors" + ) + + +# ── 2. static/ui.js: _checkProviderMismatch() ─────────────────────────────── + +class TestCheckProviderMismatch: + """ui.js must expose _checkProviderMismatch() helper.""" + + def test_function_defined(self): + """_checkProviderMismatch function must be defined in ui.js.""" + src = _read("static/ui.js") + assert "function _checkProviderMismatch" in src, ( + "_checkProviderMismatch not defined in ui.js" + ) + + def test_uses_window_active_provider(self): + """Function must read window._activeProvider.""" + src = _read("static/ui.js") + idx = src.find("function _checkProviderMismatch") + block = src[idx:idx + 800] + assert "_activeProvider" in block, ( + "_checkProviderMismatch must read window._activeProvider" + ) + + def test_skips_check_for_openrouter(self): + """OpenRouter can route to any provider — skip the warning.""" + src = _read("static/ui.js") + idx = src.find("function _checkProviderMismatch") + block = src[idx:idx + 800] + assert "openrouter" in block.lower(), ( + "_checkProviderMismatch must skip the check for openrouter" + ) + + def test_skips_check_for_custom(self): + """Custom endpoints can serve any model — skip the warning.""" + src = _read("static/ui.js") + idx = src.find("function _checkProviderMismatch") + block = src[idx:idx + 800] + assert "custom" in block.lower(), ( + "_checkProviderMismatch must skip the check for custom provider" + ) + + def test_active_provider_stored_on_model_load(self): + """populateModelDropdown must store active_provider from /api/models.""" + src = _read("static/ui.js") + # Find the function definition (skip the comment that also mentions the name) + idx = src.find("async function populateModelDropdown") + assert idx != -1, "async function populateModelDropdown not found" + block = src[idx:idx + 800] + assert "_activeProvider" in block, ( + "populateModelDropdown must set window._activeProvider " + "from the /api/models response" + ) + + +# ── 3. static/messages.js: apperror handler ───────────────────────────────── + +class TestApperrorHandler: + """messages.js apperror handler must handle auth_mismatch type.""" + + def test_auth_mismatch_type_handled(self): + """apperror handler must check for type='auth_mismatch'.""" + src = _read("static/messages.js") + assert "auth_mismatch" in src, ( + "auth_mismatch type not handled in messages.js apperror handler" + ) + + def test_provider_mismatch_label(self): + """'Provider mismatch' label must appear in the error handling.""" + src = _read("static/messages.js") + assert "Provider mismatch" in src, ( + "'Provider mismatch' label not found in messages.js" + ) + + def test_is_auth_mismatch_variable(self): + """isAuthMismatch variable must be defined.""" + src = _read("static/messages.js") + assert "isAuthMismatch" in src, ( + "isAuthMismatch variable not found in messages.js apperror handler" + ) + + +# ── 4. static/i18n.js: all 5 locales ──────────────────────────────────────── + +class TestI18nProviderMismatch: + """All 5 locales must have provider_mismatch_warning and provider_mismatch_label.""" + + REQUIRED_KEYS = ["provider_mismatch_warning", "provider_mismatch_label"] + + def _count_key(self, src: str, key: str) -> int: + return len(re.findall(r'\b' + re.escape(key) + r'\b', src)) + + def test_all_locales_have_warning_key(self): + """provider_mismatch_warning must appear in all 5 locales.""" + src = _read("static/i18n.js") + count = self._count_key(src, "provider_mismatch_warning") + assert count >= 5, ( + f"provider_mismatch_warning found {count} times, expected >= 5 " + f"(one per locale: en, es, de, zh, zh-Hant)" + ) + + def test_all_locales_have_label_key(self): + """provider_mismatch_label must appear in all 5 locales.""" + src = _read("static/i18n.js") + count = self._count_key(src, "provider_mismatch_label") + assert count >= 5, ( + f"provider_mismatch_label found {count} times, expected >= 5" + ) + + def test_warning_is_function_in_en(self): + """English provider_mismatch_warning must be a function (m, p) => ...""" + src = _read("static/i18n.js") + # Find the en block + en_start = src.find("\n en: {") + es_start = src.find("\n es: {") + en_block = src[en_start:es_start] + assert "provider_mismatch_warning" in en_block, "Key not in en block" + idx = en_block.find("provider_mismatch_warning") + line = en_block[idx:idx + 200] + # Must be a function, not a plain string + assert "=>" in line, ( + "provider_mismatch_warning in en locale must be an arrow function " + "that takes (m, p) parameters for model and provider interpolation" + ) + + def test_spanish_locale_key_coverage(self): + """Spanish locale must have the new keys (parity with English).""" + src = _read("static/i18n.js") + es_start = src.find("\n es: {") + de_start = src.find("\n de: {") + es_block = src[es_start:de_start] + for key in self.REQUIRED_KEYS: + assert key in es_block, f"Key '{key}' missing from Spanish locale" + + +# ── 5. static/boot.js: dropdown change handler ────────────────────────────── + +class TestBootModelSelectChange: + """boot.js modelSelect.onchange must call _checkProviderMismatch.""" + + def test_onchange_calls_check_function(self): + """modelSelect.onchange must invoke _checkProviderMismatch.""" + src = _read("static/boot.js") + assert "_checkProviderMismatch" in src, ( + "boot.js modelSelect.onchange must call _checkProviderMismatch " + "to warn users about provider/model mismatches" + ) + # Verify it's called from the onchange handler (near modelSelect.onchange) + idx = src.find("'modelSelect').onchange") or src.find('"modelSelect").onchange') + if idx == -1: + # Try alternate patterns + idx = src.find("modelSelect") + block_start = src.rfind("\n", 0, src.find("_checkProviderMismatch")) or 0 + surrounding = src[max(0, block_start - 200):block_start + 400] + assert "modelSelect" in surrounding or "selectedModel" in surrounding, ( + "_checkProviderMismatch must be called in the context of model selection" + ) + + def test_onchange_shows_toast_on_mismatch(self): + """The warning must be shown via showToast, not alert().""" + src = _read("static/boot.js") + # Both _checkProviderMismatch call and showToast must be near each other + idx = src.find("_checkProviderMismatch") + assert idx != -1, "_checkProviderMismatch not found in boot.js" + block = src[idx:idx + 300] + assert "showToast" in block, ( + "Provider mismatch warning must be shown via showToast(), not alert()" + ) + + +# ── 6. /api/models: active_provider in response ────────────────────────────── + +def test_api_models_includes_active_provider(): + """/api/models must include 'active_provider' key in response.""" + with urllib.request.urlopen(BASE + "/api/models", timeout=10) as r: + data = json.loads(r.read()) + # active_provider can be None/null but the key must exist + assert "active_provider" in data, ( + "/api/models response missing 'active_provider' field — " + "frontend needs this to detect provider mismatches" + )