chore: add missing type hints across 10 files
This commit is contained in:
20
api/auth.py
20
api/auth.py
@@ -57,7 +57,7 @@ def _hash_password(password):
|
||||
return dk.hex()
|
||||
|
||||
|
||||
def get_password_hash():
|
||||
def get_password_hash() -> bool:
|
||||
"""Return the active password hash, or None if auth is disabled.
|
||||
Priority: env var > settings.json."""
|
||||
env_pw = os.getenv('HERMES_WEBUI_PASSWORD', '').strip()
|
||||
@@ -67,12 +67,12 @@ def get_password_hash():
|
||||
return settings.get('password_hash') or None
|
||||
|
||||
|
||||
def is_auth_enabled():
|
||||
def is_auth_enabled() -> bool:
|
||||
"""True if a password is configured (env var or settings)."""
|
||||
return get_password_hash() is not None
|
||||
|
||||
|
||||
def verify_password(plain):
|
||||
def verify_password(plain) -> bool:
|
||||
"""Verify a plaintext password against the stored hash."""
|
||||
expected = get_password_hash()
|
||||
if not expected:
|
||||
@@ -80,7 +80,7 @@ def verify_password(plain):
|
||||
return hmac.compare_digest(_hash_password(plain), expected)
|
||||
|
||||
|
||||
def create_session():
|
||||
def create_session() -> str:
|
||||
"""Create a new auth session. Returns signed cookie value."""
|
||||
token = secrets.token_hex(32)
|
||||
_sessions[token] = time.time() + SESSION_TTL
|
||||
@@ -88,7 +88,7 @@ def create_session():
|
||||
return f"{token}.{sig}"
|
||||
|
||||
|
||||
def verify_session(cookie_value):
|
||||
def verify_session(cookie_value) -> bool:
|
||||
"""Verify a signed session cookie. Returns True if valid and not expired."""
|
||||
if not cookie_value or '.' not in cookie_value:
|
||||
return False
|
||||
@@ -103,14 +103,14 @@ def verify_session(cookie_value):
|
||||
return True
|
||||
|
||||
|
||||
def invalidate_session(cookie_value):
|
||||
def invalidate_session(cookie_value) -> None:
|
||||
"""Remove a session token."""
|
||||
if cookie_value and '.' in cookie_value:
|
||||
token = cookie_value.rsplit('.', 1)[0]
|
||||
_sessions.pop(token, None)
|
||||
|
||||
|
||||
def parse_cookie(handler):
|
||||
def parse_cookie(handler) -> None:
|
||||
"""Extract the auth cookie from the request headers."""
|
||||
cookie_header = handler.headers.get('Cookie', '')
|
||||
if not cookie_header:
|
||||
@@ -124,7 +124,7 @@ def parse_cookie(handler):
|
||||
return morsel.value if morsel else None
|
||||
|
||||
|
||||
def check_auth(handler, parsed):
|
||||
def check_auth(handler, parsed) -> bool:
|
||||
"""Check if request is authorized. Returns True if OK.
|
||||
If not authorized, sends 401 (API) or 302 redirect (page) and returns False."""
|
||||
if not is_auth_enabled():
|
||||
@@ -149,7 +149,7 @@ def check_auth(handler, parsed):
|
||||
return False
|
||||
|
||||
|
||||
def set_auth_cookie(handler, cookie_value):
|
||||
def set_auth_cookie(handler, cookie_value) -> None:
|
||||
"""Set the auth cookie on the response."""
|
||||
cookie = http.cookies.SimpleCookie()
|
||||
cookie[COOKIE_NAME] = cookie_value
|
||||
@@ -160,7 +160,7 @@ def set_auth_cookie(handler, cookie_value):
|
||||
handler.send_header('Set-Cookie', cookie[COOKIE_NAME].OutputString())
|
||||
|
||||
|
||||
def clear_auth_cookie(handler):
|
||||
def clear_auth_cookie(handler) -> None:
|
||||
"""Clear the auth cookie on the response."""
|
||||
cookie = http.cookies.SimpleCookie()
|
||||
cookie[COOKIE_NAME] = ''
|
||||
|
||||
@@ -169,7 +169,7 @@ def get_config() -> dict:
|
||||
reload_config()
|
||||
return _cfg_cache
|
||||
|
||||
def reload_config():
|
||||
def reload_config() -> None:
|
||||
"""Reload config.yaml from the active profile's directory."""
|
||||
with _cfg_lock:
|
||||
_cfg_cache.clear()
|
||||
@@ -208,7 +208,7 @@ DEFAULT_WORKSPACE = _discover_default_workspace()
|
||||
DEFAULT_MODEL = os.getenv('HERMES_WEBUI_DEFAULT_MODEL', 'openai/gpt-5.4-mini')
|
||||
|
||||
# ── Startup diagnostics ───────────────────────────────────────────────────────
|
||||
def print_startup_config():
|
||||
def print_startup_config() -> None:
|
||||
"""Print detected configuration at startup so the user can verify what was found."""
|
||||
ok = '\033[32m[ok]\033[0m'
|
||||
warn = '\033[33m[!!]\033[0m'
|
||||
@@ -243,7 +243,7 @@ def print_startup_config():
|
||||
flush=True
|
||||
)
|
||||
|
||||
def verify_hermes_imports():
|
||||
def verify_hermes_imports() -> tuple:
|
||||
"""
|
||||
Attempt to import the key Hermes modules.
|
||||
Returns (ok: bool, missing: list[str], errors: dict[str, str]).
|
||||
@@ -366,7 +366,7 @@ _PROVIDER_MODELS = {
|
||||
}
|
||||
|
||||
|
||||
def resolve_model_provider(model_id: str):
|
||||
def resolve_model_provider(model_id: str) -> tuple:
|
||||
"""Resolve bare model name, provider, and base_url for AIAgent.
|
||||
|
||||
Model IDs from the dropdown may include a provider prefix
|
||||
|
||||
@@ -6,14 +6,14 @@ from pathlib import Path
|
||||
from api.config import IMAGE_EXTS, MD_EXTS
|
||||
|
||||
|
||||
def require(body: dict, *fields):
|
||||
def require(body: dict, *fields) -> None:
|
||||
"""Phase D: Validate required fields. Raises ValueError with clean message."""
|
||||
missing = [f for f in fields if not body.get(f) and body.get(f) != 0]
|
||||
if missing:
|
||||
raise ValueError(f"Missing required field(s): {', '.join(missing)}")
|
||||
|
||||
|
||||
def bad(handler, msg, status=400):
|
||||
def bad(handler, msg, status: int=400):
|
||||
"""Return a clean JSON error response."""
|
||||
return j(handler, {'error': msg}, status=status)
|
||||
|
||||
@@ -32,7 +32,7 @@ def _security_headers(handler):
|
||||
handler.send_header('Referrer-Policy', 'same-origin')
|
||||
|
||||
|
||||
def j(handler, payload, status=200):
|
||||
def j(handler, payload, status: int=200) -> None:
|
||||
"""Send a JSON response."""
|
||||
body = _json.dumps(payload, ensure_ascii=False, indent=2).encode('utf-8')
|
||||
handler.send_response(status)
|
||||
@@ -44,7 +44,7 @@ def j(handler, payload, status=200):
|
||||
handler.wfile.write(body)
|
||||
|
||||
|
||||
def t(handler, payload, status=200, content_type='text/plain; charset=utf-8'):
|
||||
def t(handler, payload, status: int=200, content_type: str='text/plain; charset=utf-8') -> None:
|
||||
"""Send a plain text or HTML response."""
|
||||
body = payload if isinstance(payload, bytes) else str(payload).encode('utf-8')
|
||||
handler.send_response(status)
|
||||
@@ -59,7 +59,7 @@ def t(handler, payload, status=200, content_type='text/plain; charset=utf-8'):
|
||||
MAX_BODY_BYTES = 20 * 1024 * 1024 # 20MB limit for non-upload POST bodies
|
||||
|
||||
|
||||
def read_body(handler):
|
||||
def read_body(handler) -> dict:
|
||||
"""Read and JSON-parse a POST request body (capped at 20MB)."""
|
||||
length = int(handler.headers.get('Content-Length', 0))
|
||||
if length > MAX_BODY_BYTES:
|
||||
|
||||
@@ -34,13 +34,13 @@ def _write_session_index():
|
||||
|
||||
|
||||
class Session:
|
||||
def __init__(self, session_id=None, title='Untitled',
|
||||
def __init__(self, session_id: int=None, title: str='Untitled',
|
||||
workspace=str(DEFAULT_WORKSPACE), model=DEFAULT_MODEL,
|
||||
messages=None, created_at=None, updated_at=None,
|
||||
tool_calls=None, pinned=False, archived=False,
|
||||
project_id=None, profile=None,
|
||||
input_tokens=0, output_tokens=0, estimated_cost=None,
|
||||
**kwargs):
|
||||
tool_calls=None, pinned: bool=False, archived: bool=False,
|
||||
project_id: int=None, profile=None,
|
||||
input_tokens: int=0, output_tokens: int=0, estimated_cost=None,
|
||||
**kwargs: dict):
|
||||
self.session_id = session_id or uuid.uuid4().hex[:12]
|
||||
self.title = title
|
||||
self.workspace = str(Path(workspace).expanduser().resolve())
|
||||
@@ -61,7 +61,7 @@ class Session:
|
||||
def path(self):
|
||||
return SESSION_DIR / f'{self.session_id}.json'
|
||||
|
||||
def save(self):
|
||||
def save(self) -> None:
|
||||
self.updated_at = time.time()
|
||||
self.path.write_text(
|
||||
json.dumps(self.__dict__, ensure_ascii=False, indent=2),
|
||||
@@ -70,13 +70,13 @@ class Session:
|
||||
_write_session_index()
|
||||
|
||||
@classmethod
|
||||
def load(cls, sid):
|
||||
def load(cls, sid) -> None:
|
||||
p = SESSION_DIR / f'{sid}.json'
|
||||
if not p.exists():
|
||||
return None
|
||||
return cls(**json.loads(p.read_text(encoding='utf-8')))
|
||||
|
||||
def compact(self):
|
||||
def compact(self) -> dict:
|
||||
return {
|
||||
'session_id': self.session_id,
|
||||
'title': self.title,
|
||||
@@ -165,7 +165,7 @@ def all_sessions():
|
||||
return result
|
||||
|
||||
|
||||
def title_from(messages, fallback='Untitled'):
|
||||
def title_from(messages, fallback: str='Untitled'):
|
||||
"""Derive a session title from the first user message."""
|
||||
for m in messages:
|
||||
if m.get('role') == 'user':
|
||||
@@ -180,7 +180,7 @@ def title_from(messages, fallback='Untitled'):
|
||||
|
||||
# ── Project helpers ──────────────────────────────────────────────────────────
|
||||
|
||||
def load_projects():
|
||||
def load_projects() -> list:
|
||||
"""Load project list from disk. Returns list of project dicts."""
|
||||
if not PROJECTS_FILE.exists():
|
||||
return []
|
||||
@@ -189,12 +189,12 @@ def load_projects():
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def save_projects(projects):
|
||||
def save_projects(projects) -> None:
|
||||
"""Write project list to disk."""
|
||||
PROJECTS_FILE.write_text(json.dumps(projects, ensure_ascii=False, indent=2), encoding='utf-8')
|
||||
|
||||
|
||||
def import_cli_session(session_id, title, messages, model='unknown', profile=None):
|
||||
def import_cli_session(session_id: int, title: str, messages, model: str='unknown', profile=None):
|
||||
"""Create a new WebUI session populated with CLI messages.
|
||||
Returns the Session object.
|
||||
"""
|
||||
@@ -212,7 +212,7 @@ def import_cli_session(session_id, title, messages, model='unknown', profile=Non
|
||||
|
||||
# ── CLI session bridge ──────────────────────────────────────────────────────
|
||||
|
||||
def get_cli_sessions():
|
||||
def get_cli_sessions() -> list:
|
||||
"""Read CLI sessions from the agent's SQLite store and return them as
|
||||
dicts in a format the WebUI sidebar can render alongside local sessions.
|
||||
|
||||
@@ -296,7 +296,7 @@ def get_cli_sessions():
|
||||
return cli_sessions
|
||||
|
||||
|
||||
def get_cli_session_messages(sid):
|
||||
def get_cli_session_messages(sid) -> list:
|
||||
"""Read messages for a single CLI session from the SQLite store.
|
||||
Returns a list of {role, content, timestamp} dicts.
|
||||
Returns empty list on any error.
|
||||
@@ -338,7 +338,7 @@ def get_cli_session_messages(sid):
|
||||
return msgs
|
||||
|
||||
|
||||
def delete_cli_session(sid):
|
||||
def delete_cli_session(sid) -> bool:
|
||||
"""Delete a CLI session from state.db (messages + session row).
|
||||
Returns True if deleted, False if not found or error.
|
||||
"""
|
||||
|
||||
@@ -137,7 +137,7 @@ def _reload_dotenv(home: Path):
|
||||
pass
|
||||
|
||||
|
||||
def init_profile_state():
|
||||
def init_profile_state() -> None:
|
||||
"""Initialize profile state at server startup.
|
||||
|
||||
Reads ~/.hermes/active_profile, sets HERMES_HOME env var, patches
|
||||
|
||||
@@ -107,7 +107,7 @@ async function doLogin(e){
|
||||
|
||||
# ── GET routes ────────────────────────────────────────────────────────────────
|
||||
|
||||
def handle_get(handler, parsed):
|
||||
def handle_get(handler, parsed) -> bool:
|
||||
"""Handle all GET routes. Returns True if handled, False for 404."""
|
||||
|
||||
if parsed.path in ('/', '/index.html'):
|
||||
@@ -318,7 +318,7 @@ def handle_get(handler, parsed):
|
||||
|
||||
# ── POST routes ───────────────────────────────────────────────────────────────
|
||||
|
||||
def handle_post(handler, parsed):
|
||||
def handle_post(handler, parsed) -> bool:
|
||||
"""Handle all POST routes. Returns True if handled, False for 404."""
|
||||
|
||||
if parsed.path == '/api/upload':
|
||||
|
||||
@@ -43,7 +43,7 @@ def _get_state_db():
|
||||
return None
|
||||
|
||||
|
||||
def sync_session_start(session_id, model=None):
|
||||
def sync_session_start(session_id: int, model=None) -> None:
|
||||
"""Register a WebUI session in state.db (idempotent).
|
||||
Called when a session's first message is sent.
|
||||
"""
|
||||
@@ -65,8 +65,8 @@ def sync_session_start(session_id, model=None):
|
||||
pass
|
||||
|
||||
|
||||
def sync_session_usage(session_id, input_tokens=0, output_tokens=0,
|
||||
estimated_cost=None, model=None, title=None):
|
||||
def sync_session_usage(session_id: int, input_tokens: int=0, output_tokens: int=0,
|
||||
estimated_cost=None, model=None, title: str=None) -> None:
|
||||
"""Update token usage and title for a WebUI session in state.db.
|
||||
Called after each turn completes. Uses absolute=True to set totals
|
||||
(the WebUI Session already accumulates across turns).
|
||||
|
||||
@@ -11,7 +11,7 @@ from api.models import get_session
|
||||
from api.workspace import safe_resolve_ws
|
||||
|
||||
|
||||
def parse_multipart(rfile, content_type, content_length):
|
||||
def parse_multipart(rfile, content_type, content_length) -> tuple:
|
||||
import re as _re, email.parser as _ep
|
||||
m = _re.search(r'boundary=([^;\s]+)', content_type)
|
||||
if not m:
|
||||
|
||||
@@ -176,7 +176,7 @@ def load_workspaces() -> list:
|
||||
return [{'path': _profile_default_workspace(), 'name': 'Home'}]
|
||||
|
||||
|
||||
def save_workspaces(workspaces: list):
|
||||
def save_workspaces(workspaces: list) -> None:
|
||||
ws_file = _workspaces_file()
|
||||
ws_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
ws_file.write_text(json.dumps(workspaces, ensure_ascii=False, indent=2), encoding='utf-8')
|
||||
@@ -202,7 +202,7 @@ def get_last_workspace() -> str:
|
||||
return _profile_default_workspace()
|
||||
|
||||
|
||||
def set_last_workspace(path: str):
|
||||
def set_last_workspace(path: str) -> None:
|
||||
try:
|
||||
lw_file = _last_workspace_file()
|
||||
lw_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
@@ -218,7 +218,7 @@ def safe_resolve_ws(root: Path, requested: str) -> Path:
|
||||
return resolved
|
||||
|
||||
|
||||
def list_dir(workspace: Path, rel='.'):
|
||||
def list_dir(workspace: Path, rel: str='.'):
|
||||
target = safe_resolve_ws(workspace, rel)
|
||||
if not target.is_dir():
|
||||
raise FileNotFoundError(f"Not a directory: {rel}")
|
||||
@@ -235,7 +235,7 @@ def list_dir(workspace: Path, rel='.'):
|
||||
return entries
|
||||
|
||||
|
||||
def read_file_content(workspace: Path, rel: str):
|
||||
def read_file_content(workspace: Path, rel: str) -> dict:
|
||||
target = safe_resolve_ws(workspace, rel)
|
||||
if not target.is_file():
|
||||
raise FileNotFoundError(f"Not a file: {rel}")
|
||||
|
||||
Reference in New Issue
Block a user