chore: add missing type hints across 10 files

This commit is contained in:
Nguyễn Công Thuận Huy
2026-04-05 13:30:20 +07:00
parent 3d063b08a9
commit 4d333acbbc
10 changed files with 49 additions and 49 deletions

View File

@@ -57,7 +57,7 @@ def _hash_password(password):
return dk.hex() return dk.hex()
def get_password_hash(): def get_password_hash() -> bool:
"""Return the active password hash, or None if auth is disabled. """Return the active password hash, or None if auth is disabled.
Priority: env var > settings.json.""" Priority: env var > settings.json."""
env_pw = os.getenv('HERMES_WEBUI_PASSWORD', '').strip() env_pw = os.getenv('HERMES_WEBUI_PASSWORD', '').strip()
@@ -67,12 +67,12 @@ def get_password_hash():
return settings.get('password_hash') or None return settings.get('password_hash') or None
def is_auth_enabled(): def is_auth_enabled() -> bool:
"""True if a password is configured (env var or settings).""" """True if a password is configured (env var or settings)."""
return get_password_hash() is not None return get_password_hash() is not None
def verify_password(plain): def verify_password(plain) -> bool:
"""Verify a plaintext password against the stored hash.""" """Verify a plaintext password against the stored hash."""
expected = get_password_hash() expected = get_password_hash()
if not expected: if not expected:
@@ -80,7 +80,7 @@ def verify_password(plain):
return hmac.compare_digest(_hash_password(plain), expected) return hmac.compare_digest(_hash_password(plain), expected)
def create_session(): def create_session() -> str:
"""Create a new auth session. Returns signed cookie value.""" """Create a new auth session. Returns signed cookie value."""
token = secrets.token_hex(32) token = secrets.token_hex(32)
_sessions[token] = time.time() + SESSION_TTL _sessions[token] = time.time() + SESSION_TTL
@@ -88,7 +88,7 @@ def create_session():
return f"{token}.{sig}" return f"{token}.{sig}"
def verify_session(cookie_value): def verify_session(cookie_value) -> bool:
"""Verify a signed session cookie. Returns True if valid and not expired.""" """Verify a signed session cookie. Returns True if valid and not expired."""
if not cookie_value or '.' not in cookie_value: if not cookie_value or '.' not in cookie_value:
return False return False
@@ -103,14 +103,14 @@ def verify_session(cookie_value):
return True return True
def invalidate_session(cookie_value): def invalidate_session(cookie_value) -> None:
"""Remove a session token.""" """Remove a session token."""
if cookie_value and '.' in cookie_value: if cookie_value and '.' in cookie_value:
token = cookie_value.rsplit('.', 1)[0] token = cookie_value.rsplit('.', 1)[0]
_sessions.pop(token, None) _sessions.pop(token, None)
def parse_cookie(handler): def parse_cookie(handler) -> None:
"""Extract the auth cookie from the request headers.""" """Extract the auth cookie from the request headers."""
cookie_header = handler.headers.get('Cookie', '') cookie_header = handler.headers.get('Cookie', '')
if not cookie_header: if not cookie_header:
@@ -124,7 +124,7 @@ def parse_cookie(handler):
return morsel.value if morsel else None return morsel.value if morsel else None
def check_auth(handler, parsed): def check_auth(handler, parsed) -> bool:
"""Check if request is authorized. Returns True if OK. """Check if request is authorized. Returns True if OK.
If not authorized, sends 401 (API) or 302 redirect (page) and returns False.""" If not authorized, sends 401 (API) or 302 redirect (page) and returns False."""
if not is_auth_enabled(): if not is_auth_enabled():
@@ -149,7 +149,7 @@ def check_auth(handler, parsed):
return False return False
def set_auth_cookie(handler, cookie_value): def set_auth_cookie(handler, cookie_value) -> None:
"""Set the auth cookie on the response.""" """Set the auth cookie on the response."""
cookie = http.cookies.SimpleCookie() cookie = http.cookies.SimpleCookie()
cookie[COOKIE_NAME] = cookie_value cookie[COOKIE_NAME] = cookie_value
@@ -160,7 +160,7 @@ def set_auth_cookie(handler, cookie_value):
handler.send_header('Set-Cookie', cookie[COOKIE_NAME].OutputString()) handler.send_header('Set-Cookie', cookie[COOKIE_NAME].OutputString())
def clear_auth_cookie(handler): def clear_auth_cookie(handler) -> None:
"""Clear the auth cookie on the response.""" """Clear the auth cookie on the response."""
cookie = http.cookies.SimpleCookie() cookie = http.cookies.SimpleCookie()
cookie[COOKIE_NAME] = '' cookie[COOKIE_NAME] = ''

View File

@@ -169,7 +169,7 @@ def get_config() -> dict:
reload_config() reload_config()
return _cfg_cache return _cfg_cache
def reload_config(): def reload_config() -> None:
"""Reload config.yaml from the active profile's directory.""" """Reload config.yaml from the active profile's directory."""
with _cfg_lock: with _cfg_lock:
_cfg_cache.clear() _cfg_cache.clear()
@@ -208,7 +208,7 @@ DEFAULT_WORKSPACE = _discover_default_workspace()
DEFAULT_MODEL = os.getenv('HERMES_WEBUI_DEFAULT_MODEL', 'openai/gpt-5.4-mini') DEFAULT_MODEL = os.getenv('HERMES_WEBUI_DEFAULT_MODEL', 'openai/gpt-5.4-mini')
# ── Startup diagnostics ─────────────────────────────────────────────────────── # ── Startup diagnostics ───────────────────────────────────────────────────────
def print_startup_config(): def print_startup_config() -> None:
"""Print detected configuration at startup so the user can verify what was found.""" """Print detected configuration at startup so the user can verify what was found."""
ok = '\033[32m[ok]\033[0m' ok = '\033[32m[ok]\033[0m'
warn = '\033[33m[!!]\033[0m' warn = '\033[33m[!!]\033[0m'
@@ -243,7 +243,7 @@ def print_startup_config():
flush=True flush=True
) )
def verify_hermes_imports(): def verify_hermes_imports() -> tuple:
""" """
Attempt to import the key Hermes modules. Attempt to import the key Hermes modules.
Returns (ok: bool, missing: list[str], errors: dict[str, str]). Returns (ok: bool, missing: list[str], errors: dict[str, str]).
@@ -366,7 +366,7 @@ _PROVIDER_MODELS = {
} }
def resolve_model_provider(model_id: str): def resolve_model_provider(model_id: str) -> tuple:
"""Resolve bare model name, provider, and base_url for AIAgent. """Resolve bare model name, provider, and base_url for AIAgent.
Model IDs from the dropdown may include a provider prefix Model IDs from the dropdown may include a provider prefix

View File

@@ -6,14 +6,14 @@ from pathlib import Path
from api.config import IMAGE_EXTS, MD_EXTS from api.config import IMAGE_EXTS, MD_EXTS
def require(body: dict, *fields): def require(body: dict, *fields) -> None:
"""Phase D: Validate required fields. Raises ValueError with clean message.""" """Phase D: Validate required fields. Raises ValueError with clean message."""
missing = [f for f in fields if not body.get(f) and body.get(f) != 0] missing = [f for f in fields if not body.get(f) and body.get(f) != 0]
if missing: if missing:
raise ValueError(f"Missing required field(s): {', '.join(missing)}") raise ValueError(f"Missing required field(s): {', '.join(missing)}")
def bad(handler, msg, status=400): def bad(handler, msg, status: int=400):
"""Return a clean JSON error response.""" """Return a clean JSON error response."""
return j(handler, {'error': msg}, status=status) return j(handler, {'error': msg}, status=status)
@@ -32,7 +32,7 @@ def _security_headers(handler):
handler.send_header('Referrer-Policy', 'same-origin') handler.send_header('Referrer-Policy', 'same-origin')
def j(handler, payload, status=200): def j(handler, payload, status: int=200) -> None:
"""Send a JSON response.""" """Send a JSON response."""
body = _json.dumps(payload, ensure_ascii=False, indent=2).encode('utf-8') body = _json.dumps(payload, ensure_ascii=False, indent=2).encode('utf-8')
handler.send_response(status) handler.send_response(status)
@@ -44,7 +44,7 @@ def j(handler, payload, status=200):
handler.wfile.write(body) handler.wfile.write(body)
def t(handler, payload, status=200, content_type='text/plain; charset=utf-8'): def t(handler, payload, status: int=200, content_type: str='text/plain; charset=utf-8') -> None:
"""Send a plain text or HTML response.""" """Send a plain text or HTML response."""
body = payload if isinstance(payload, bytes) else str(payload).encode('utf-8') body = payload if isinstance(payload, bytes) else str(payload).encode('utf-8')
handler.send_response(status) handler.send_response(status)
@@ -59,7 +59,7 @@ def t(handler, payload, status=200, content_type='text/plain; charset=utf-8'):
MAX_BODY_BYTES = 20 * 1024 * 1024 # 20MB limit for non-upload POST bodies MAX_BODY_BYTES = 20 * 1024 * 1024 # 20MB limit for non-upload POST bodies
def read_body(handler): def read_body(handler) -> dict:
"""Read and JSON-parse a POST request body (capped at 20MB).""" """Read and JSON-parse a POST request body (capped at 20MB)."""
length = int(handler.headers.get('Content-Length', 0)) length = int(handler.headers.get('Content-Length', 0))
if length > MAX_BODY_BYTES: if length > MAX_BODY_BYTES:

View File

@@ -34,13 +34,13 @@ def _write_session_index():
class Session: class Session:
def __init__(self, session_id=None, title='Untitled', def __init__(self, session_id: int=None, title: str='Untitled',
workspace=str(DEFAULT_WORKSPACE), model=DEFAULT_MODEL, workspace=str(DEFAULT_WORKSPACE), model=DEFAULT_MODEL,
messages=None, created_at=None, updated_at=None, messages=None, created_at=None, updated_at=None,
tool_calls=None, pinned=False, archived=False, tool_calls=None, pinned: bool=False, archived: bool=False,
project_id=None, profile=None, project_id: int=None, profile=None,
input_tokens=0, output_tokens=0, estimated_cost=None, input_tokens: int=0, output_tokens: int=0, estimated_cost=None,
**kwargs): **kwargs: dict):
self.session_id = session_id or uuid.uuid4().hex[:12] self.session_id = session_id or uuid.uuid4().hex[:12]
self.title = title self.title = title
self.workspace = str(Path(workspace).expanduser().resolve()) self.workspace = str(Path(workspace).expanduser().resolve())
@@ -61,7 +61,7 @@ class Session:
def path(self): def path(self):
return SESSION_DIR / f'{self.session_id}.json' return SESSION_DIR / f'{self.session_id}.json'
def save(self): def save(self) -> None:
self.updated_at = time.time() self.updated_at = time.time()
self.path.write_text( self.path.write_text(
json.dumps(self.__dict__, ensure_ascii=False, indent=2), json.dumps(self.__dict__, ensure_ascii=False, indent=2),
@@ -70,13 +70,13 @@ class Session:
_write_session_index() _write_session_index()
@classmethod @classmethod
def load(cls, sid): def load(cls, sid) -> None:
p = SESSION_DIR / f'{sid}.json' p = SESSION_DIR / f'{sid}.json'
if not p.exists(): if not p.exists():
return None return None
return cls(**json.loads(p.read_text(encoding='utf-8'))) return cls(**json.loads(p.read_text(encoding='utf-8')))
def compact(self): def compact(self) -> dict:
return { return {
'session_id': self.session_id, 'session_id': self.session_id,
'title': self.title, 'title': self.title,
@@ -165,7 +165,7 @@ def all_sessions():
return result return result
def title_from(messages, fallback='Untitled'): def title_from(messages, fallback: str='Untitled'):
"""Derive a session title from the first user message.""" """Derive a session title from the first user message."""
for m in messages: for m in messages:
if m.get('role') == 'user': if m.get('role') == 'user':
@@ -180,7 +180,7 @@ def title_from(messages, fallback='Untitled'):
# ── Project helpers ────────────────────────────────────────────────────────── # ── Project helpers ──────────────────────────────────────────────────────────
def load_projects(): def load_projects() -> list:
"""Load project list from disk. Returns list of project dicts.""" """Load project list from disk. Returns list of project dicts."""
if not PROJECTS_FILE.exists(): if not PROJECTS_FILE.exists():
return [] return []
@@ -189,12 +189,12 @@ def load_projects():
except Exception: except Exception:
return [] return []
def save_projects(projects): def save_projects(projects) -> None:
"""Write project list to disk.""" """Write project list to disk."""
PROJECTS_FILE.write_text(json.dumps(projects, ensure_ascii=False, indent=2), encoding='utf-8') PROJECTS_FILE.write_text(json.dumps(projects, ensure_ascii=False, indent=2), encoding='utf-8')
def import_cli_session(session_id, title, messages, model='unknown', profile=None): def import_cli_session(session_id: int, title: str, messages, model: str='unknown', profile=None):
"""Create a new WebUI session populated with CLI messages. """Create a new WebUI session populated with CLI messages.
Returns the Session object. Returns the Session object.
""" """
@@ -212,7 +212,7 @@ def import_cli_session(session_id, title, messages, model='unknown', profile=Non
# ── CLI session bridge ────────────────────────────────────────────────────── # ── CLI session bridge ──────────────────────────────────────────────────────
def get_cli_sessions(): def get_cli_sessions() -> list:
"""Read CLI sessions from the agent's SQLite store and return them as """Read CLI sessions from the agent's SQLite store and return them as
dicts in a format the WebUI sidebar can render alongside local sessions. dicts in a format the WebUI sidebar can render alongside local sessions.
@@ -296,7 +296,7 @@ def get_cli_sessions():
return cli_sessions return cli_sessions
def get_cli_session_messages(sid): def get_cli_session_messages(sid) -> list:
"""Read messages for a single CLI session from the SQLite store. """Read messages for a single CLI session from the SQLite store.
Returns a list of {role, content, timestamp} dicts. Returns a list of {role, content, timestamp} dicts.
Returns empty list on any error. Returns empty list on any error.
@@ -338,7 +338,7 @@ def get_cli_session_messages(sid):
return msgs return msgs
def delete_cli_session(sid): def delete_cli_session(sid) -> bool:
"""Delete a CLI session from state.db (messages + session row). """Delete a CLI session from state.db (messages + session row).
Returns True if deleted, False if not found or error. Returns True if deleted, False if not found or error.
""" """

View File

@@ -137,7 +137,7 @@ def _reload_dotenv(home: Path):
pass pass
def init_profile_state(): def init_profile_state() -> None:
"""Initialize profile state at server startup. """Initialize profile state at server startup.
Reads ~/.hermes/active_profile, sets HERMES_HOME env var, patches Reads ~/.hermes/active_profile, sets HERMES_HOME env var, patches

View File

@@ -107,7 +107,7 @@ async function doLogin(e){
# ── GET routes ──────────────────────────────────────────────────────────────── # ── GET routes ────────────────────────────────────────────────────────────────
def handle_get(handler, parsed): def handle_get(handler, parsed) -> bool:
"""Handle all GET routes. Returns True if handled, False for 404.""" """Handle all GET routes. Returns True if handled, False for 404."""
if parsed.path in ('/', '/index.html'): if parsed.path in ('/', '/index.html'):
@@ -318,7 +318,7 @@ def handle_get(handler, parsed):
# ── POST routes ─────────────────────────────────────────────────────────────── # ── POST routes ───────────────────────────────────────────────────────────────
def handle_post(handler, parsed): def handle_post(handler, parsed) -> bool:
"""Handle all POST routes. Returns True if handled, False for 404.""" """Handle all POST routes. Returns True if handled, False for 404."""
if parsed.path == '/api/upload': if parsed.path == '/api/upload':

View File

@@ -43,7 +43,7 @@ def _get_state_db():
return None return None
def sync_session_start(session_id, model=None): def sync_session_start(session_id: int, model=None) -> None:
"""Register a WebUI session in state.db (idempotent). """Register a WebUI session in state.db (idempotent).
Called when a session's first message is sent. Called when a session's first message is sent.
""" """
@@ -65,8 +65,8 @@ def sync_session_start(session_id, model=None):
pass pass
def sync_session_usage(session_id, input_tokens=0, output_tokens=0, def sync_session_usage(session_id: int, input_tokens: int=0, output_tokens: int=0,
estimated_cost=None, model=None, title=None): estimated_cost=None, model=None, title: str=None) -> None:
"""Update token usage and title for a WebUI session in state.db. """Update token usage and title for a WebUI session in state.db.
Called after each turn completes. Uses absolute=True to set totals Called after each turn completes. Uses absolute=True to set totals
(the WebUI Session already accumulates across turns). (the WebUI Session already accumulates across turns).

View File

@@ -11,7 +11,7 @@ from api.models import get_session
from api.workspace import safe_resolve_ws from api.workspace import safe_resolve_ws
def parse_multipart(rfile, content_type, content_length): def parse_multipart(rfile, content_type, content_length) -> tuple:
import re as _re, email.parser as _ep import re as _re, email.parser as _ep
m = _re.search(r'boundary=([^;\s]+)', content_type) m = _re.search(r'boundary=([^;\s]+)', content_type)
if not m: if not m:

View File

@@ -176,7 +176,7 @@ def load_workspaces() -> list:
return [{'path': _profile_default_workspace(), 'name': 'Home'}] return [{'path': _profile_default_workspace(), 'name': 'Home'}]
def save_workspaces(workspaces: list): def save_workspaces(workspaces: list) -> None:
ws_file = _workspaces_file() ws_file = _workspaces_file()
ws_file.parent.mkdir(parents=True, exist_ok=True) ws_file.parent.mkdir(parents=True, exist_ok=True)
ws_file.write_text(json.dumps(workspaces, ensure_ascii=False, indent=2), encoding='utf-8') ws_file.write_text(json.dumps(workspaces, ensure_ascii=False, indent=2), encoding='utf-8')
@@ -202,7 +202,7 @@ def get_last_workspace() -> str:
return _profile_default_workspace() return _profile_default_workspace()
def set_last_workspace(path: str): def set_last_workspace(path: str) -> None:
try: try:
lw_file = _last_workspace_file() lw_file = _last_workspace_file()
lw_file.parent.mkdir(parents=True, exist_ok=True) lw_file.parent.mkdir(parents=True, exist_ok=True)
@@ -218,7 +218,7 @@ def safe_resolve_ws(root: Path, requested: str) -> Path:
return resolved return resolved
def list_dir(workspace: Path, rel='.'): def list_dir(workspace: Path, rel: str='.'):
target = safe_resolve_ws(workspace, rel) target = safe_resolve_ws(workspace, rel)
if not target.is_dir(): if not target.is_dir():
raise FileNotFoundError(f"Not a directory: {rel}") raise FileNotFoundError(f"Not a directory: {rel}")
@@ -235,7 +235,7 @@ def list_dir(workspace: Path, rel='.'):
return entries return entries
def read_file_content(workspace: Path, rel: str): def read_file_content(workspace: Path, rel: str) -> dict:
target = safe_resolve_ws(workspace, rel) target = safe_resolve_ws(workspace, rel)
if not target.is_file(): if not target.is_file():
raise FileNotFoundError(f"Not a file: {rel}") raise FileNotFoundError(f"Not a file: {rel}")

View File

@@ -18,7 +18,7 @@ class Handler(BaseHTTPRequestHandler):
server_version = 'HermesWebUI/0.2' server_version = 'HermesWebUI/0.2'
def log_message(self, fmt, *args): pass # suppress default Apache-style log def log_message(self, fmt, *args): pass # suppress default Apache-style log
def log_request(self, code='-', size='-'): def log_request(self, code: str='-', size: str='-') -> None:
"""Structured JSON logs for each request.""" """Structured JSON logs for each request."""
import json as _json import json as _json
duration_ms = round((time.time() - getattr(self, '_req_t0', time.time())) * 1000, 1) duration_ms = round((time.time() - getattr(self, '_req_t0', time.time())) * 1000, 1)
@@ -31,7 +31,7 @@ class Handler(BaseHTTPRequestHandler):
}) })
print(f'[webui] {record}', flush=True) print(f'[webui] {record}', flush=True)
def do_GET(self): def do_GET(self) -> None:
self._req_t0 = time.time() self._req_t0 = time.time()
try: try:
parsed = urlparse(self.path) parsed = urlparse(self.path)
@@ -43,7 +43,7 @@ class Handler(BaseHTTPRequestHandler):
print(f'[webui] ERROR {self.command} {self.path}\n' + traceback.format_exc(), flush=True) print(f'[webui] ERROR {self.command} {self.path}\n' + traceback.format_exc(), flush=True)
return j(self, {'error': 'Internal server error'}, status=500) return j(self, {'error': 'Internal server error'}, status=500)
def do_POST(self): def do_POST(self) -> None:
self._req_t0 = time.time() self._req_t0 = time.time()
try: try:
parsed = urlparse(self.path) parsed = urlparse(self.path)
@@ -56,7 +56,7 @@ class Handler(BaseHTTPRequestHandler):
return j(self, {'error': 'Internal server error'}, status=500) return j(self, {'error': 'Internal server error'}, status=500)
def main(): def main() -> None:
from api.config import print_startup_config, verify_hermes_imports, _HERMES_FOUND from api.config import print_startup_config, verify_hermes_imports, _HERMES_FOUND
print_startup_config() print_startup_config()