Files
leaudit-platform-backend/fastapi_modules/fastapi_leaudit/govdoc_engine/llm/cache.py
T

102 lines
3.2 KiB
Python

"""LLM response cache backed by SQLite.

Cache key = sha256(model + NUL + canonical_json(messages, temperature, top_p,
max_tokens, response_format)); see ``make_key`` and ``_canonical``.
Callers are expected to store only successfully returned text: JSON parse
failures, API errors and timeouts should never be written to the cache.
``LlmCache.put`` itself only guards against empty text.
"""
from __future__ import annotations
import hashlib
import json
import logging
import sqlite3
import time
from pathlib import Path
from threading import Lock
from typing import Any
# Module-level logger (NOTE(review): not referenced by the code visible in this
# file as written).
_log = logging.getLogger(__name__)
# DDL executed by LlmCache.__init__ on every construction; IF NOT EXISTS makes
# it safe to re-run against an existing database file.
#   hit_count / last_hit_at are updated by LlmCache.get on each cache hit.
#   idx_llm_cache_created is presumably meant for age-based inspection or
#   eviction; no query in this file filters or sorts on created_at (TODO confirm).
_SCHEMA = """
CREATE TABLE IF NOT EXISTS llm_cache (
cache_key TEXT PRIMARY KEY,
model TEXT NOT NULL,
response_text TEXT NOT NULL,
created_at REAL NOT NULL,
hit_count INTEGER NOT NULL DEFAULT 0,
last_hit_at REAL
);
CREATE INDEX IF NOT EXISTS idx_llm_cache_created ON llm_cache(created_at);
"""
# Parameters that influence the model's response and therefore go into the key.
# Any other kwargs (e.g. stream/timeout) are deliberately left out of the hash.
_KEY_PARAMS = ("temperature", "top_p", "max_tokens", "response_format")
def _canonical(messages: list[dict[str, str]], **kwargs: Any) -> str:
    """Serialize *messages* plus the response-affecting kwargs deterministically.

    Only the names listed in ``_KEY_PARAMS`` are read from *kwargs*; a name
    that was not passed is recorded as ``None``. ``sort_keys=True`` sorts every
    dict (recursively), so key insertion order never changes the output.
    """
    key_params = dict(zip(_KEY_PARAMS, map(kwargs.get, _KEY_PARAMS)))
    return json.dumps(
        {"params": key_params, "messages": messages},
        ensure_ascii=False,
        sort_keys=True,
    )
def make_key(model: str, messages: list[dict[str, str]], **kwargs: Any) -> str:
    """Return the hex SHA-256 cache key for one LLM request.

    The hashed material is the model name, a NUL separator, and the canonical
    JSON produced by ``_canonical(messages, **kwargs)``, encoded as UTF-8.
    """
    material = model + "\x00"
    material += _canonical(messages, **kwargs)
    digest = hashlib.sha256(material.encode("utf-8"))
    return digest.hexdigest()
class LlmCache:
    """SQLite-backed store of LLM response text keyed by ``make_key`` digests.

    One connection is shared by all threads (``check_same_thread=False``), and
    every method serializes access to it with an internal ``Lock``. Write
    statements run inside the connection's context manager, so a failed write
    is rolled back instead of leaving an implicit transaction open on the
    shared connection.

    Usable as a context manager: ``with LlmCache(path) as cache: ...`` closes
    the connection on exit.
    """

    def __init__(self, path: str | Path):
        """Open (creating if needed) the cache database at *path*.

        Parent directories are created and the schema is applied idempotently.

        Raises:
            sqlite3.Error: if the schema cannot be applied; the connection is
                closed before the error propagates.
        """
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = Lock()
        self._conn = sqlite3.connect(str(self.path), check_same_thread=False)
        try:
            self._conn.executescript(_SCHEMA)
            self._conn.commit()
        except sqlite3.Error:
            # Don't leak the handle (and any file lock it holds) on failed setup.
            self._conn.close()
            raise

    def get(self, key: str) -> str | None:
        """Return the cached text for *key*, or ``None`` on a miss.

        A hit also increments ``hit_count`` and sets ``last_hit_at``. That
        bookkeeping is best-effort: if the UPDATE fails with
        ``sqlite3.OperationalError`` (e.g. the database is locked by another
        writer or is read-only), the failure is logged and the cached text is
        still returned.
        """
        with self._lock:
            row = self._conn.execute(
                "SELECT response_text FROM llm_cache WHERE cache_key = ?",
                (key,),
            ).fetchone()
            if row is None:
                return None
            try:
                # Commits on success, rolls back on error.
                with self._conn:
                    self._conn.execute(
                        "UPDATE llm_cache "
                        "SET hit_count = hit_count + 1, last_hit_at = ? "
                        "WHERE cache_key = ?",
                        (time.time(), key),
                    )
            except sqlite3.OperationalError:
                # Stats only: a failed counter update must not turn a hit into an error.
                _log.warning(
                    "llm cache: failed to record hit for key %s", key, exc_info=True
                )
            return row[0]

    def put(self, key: str, model: str, response_text: str) -> None:
        """Store *response_text* for *key* under *model*; empty text is skipped.

        ``INSERT OR IGNORE`` keeps the first stored value, so an existing entry
        for *key* is never overwritten.
        """
        if not response_text:
            return
        with self._lock, self._conn:
            self._conn.execute(
                "INSERT OR IGNORE INTO llm_cache "
                "(cache_key, model, response_text, created_at) "
                "VALUES (?, ?, ?, ?)",
                (key, model, response_text, time.time()),
            )

    def stats(self) -> dict[str, int]:
        """Return the number of entries and the sum of their hit counts."""
        with self._lock:
            row = self._conn.execute(
                "SELECT COUNT(*), COALESCE(SUM(hit_count), 0) FROM llm_cache"
            ).fetchone()
        return {"entries": int(row[0] or 0), "total_hits": int(row[1] or 0)}

    def clear(self) -> int:
        """Delete every entry and return the number of rows removed."""
        with self._lock, self._conn:
            deleted = self._conn.execute("DELETE FROM llm_cache").rowcount
        return deleted

    def close(self) -> None:
        """Close the underlying connection; later calls raise ``ProgrammingError``."""
        with self._lock:
            self._conn.close()

    def __enter__(self) -> LlmCache:
        return self

    def __exit__(self, *exc_info: object) -> None:
        self.close()