102 lines
3.2 KiB
Python
102 lines
3.2 KiB
Python
"""LLM 响应缓存(SQLite)。
|
|
|
|
缓存键 = sha256(model + canonical_json(messages, temperature, top_p, max_tokens))。
|
|
仅缓存成功返回的文本;JSON 解析失败、API 错误、超时一律不入库。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import sqlite3
|
|
import time
|
|
from pathlib import Path
|
|
from threading import Lock
|
|
from typing import Any
|
|
|
|
# Module-level logger.
_log = logging.getLogger(__name__)


# DDL for the cache table, run via executescript() by LlmCache.__init__.
# Both statements use IF NOT EXISTS, so re-running it on an existing database
# changes nothing.
#   cache_key     -- primary key; expected to be a make_key() digest
#   model         -- model name; stored but not read by any query shown here
#   response_text -- the cached completion text
#   created_at    -- time.time() when the row was inserted
#   hit_count / last_hit_at -- bumped by LlmCache.get() on each cache hit
# NOTE(review): nothing in this module queries by created_at; the index
# presumably serves external expiry/pruning -- confirm before dropping it.
_SCHEMA = """
CREATE TABLE IF NOT EXISTS llm_cache (
    cache_key TEXT PRIMARY KEY,
    model TEXT NOT NULL,
    response_text TEXT NOT NULL,
    created_at REAL NOT NULL,
    hit_count INTEGER NOT NULL DEFAULT 0,
    last_hit_at REAL
);
CREATE INDEX IF NOT EXISTS idx_llm_cache_created ON llm_cache(created_at);
"""


# Request parameters that affect the response and are therefore part of the
# cache key. Other kwargs (e.g. stream/timeout) are deliberately not hashed.
_KEY_PARAMS = ("temperature", "top_p", "max_tokens", "response_format")
|
|
|
|
|
|
# Further request parameters that change the generated text but are not in
# _KEY_PARAMS. They join the key only when a caller passes a non-None value,
# so keys for calls that omit them -- including every entry already stored
# in an existing cache -- stay exactly the same.
_OPTIONAL_KEY_PARAMS = ("stop", "seed", "presence_penalty", "frequency_penalty")


def _canonical(messages: list[dict[str, str]], **kwargs: Any) -> str:
    """Serialize the response-affecting parts of a request to stable JSON.

    Args:
        messages: Chat messages; serialized as-is.
        **kwargs: Request keyword arguments. Every name in ``_KEY_PARAMS`` is
            always recorded (``None`` when absent). Names in
            ``_OPTIONAL_KEY_PARAMS`` are recorded only when given a non-None
            value. Everything else (e.g. stream/timeout) is ignored.

    Returns:
        A JSON string with sorted keys and non-ASCII text kept verbatim, so
        the same request always serializes to the same string.

    Raises:
        TypeError: if a recorded value is not JSON-serializable.
    """
    params = {name: kwargs.get(name) for name in _KEY_PARAMS}
    for name in _OPTIONAL_KEY_PARAMS:
        value = kwargs.get(name)
        # Only add when present, so requests that don't use the parameter
        # keep the same key they had before it was tracked.
        if value is not None:
            params[name] = value
    payload = {"messages": messages, "params": params}
    return json.dumps(payload, sort_keys=True, ensure_ascii=False)
|
|
|
|
|
|
def make_key(model: str, messages: list[dict[str, str]], **kwargs: Any) -> str:
    """Return the cache key for one request as a 64-character hex SHA-256 digest.

    The digest covers the UTF-8 bytes of the model name, a NUL separator and
    the canonical request JSON produced by ``_canonical``.
    """
    # The NUL separator keeps the model name from running into the JSON.
    prefix = model + "\x00"
    body = _canonical(messages, **kwargs)
    hasher = hashlib.sha256()
    for chunk in (prefix, body):
        hasher.update(chunk.encode("utf-8"))
    return hasher.hexdigest()
|
|
|
|
|
|
class LlmCache:
    """SQLite-backed store mapping cache keys to LLM response text.

    A single connection is opened with ``check_same_thread=False`` and every
    database access is serialized by ``self._lock``, so one instance may be
    shared between threads. Usable as a context manager; the connection is
    closed on exit.
    """

    def __init__(self, path: str | Path):
        """Open (creating if necessary) the cache database at *path*.

        Missing parent directories are created and ``_SCHEMA`` is applied.

        Raises:
            OSError: if the parent directory cannot be created.
            sqlite3.Error: if the file cannot be opened or the schema cannot
                be applied (e.g. *path* is not a SQLite database). The
                connection is closed before the error propagates.
        """
        self.path = Path(path)
        self.path.parent.mkdir(parents=True, exist_ok=True)
        self._lock = Lock()
        self._conn = sqlite3.connect(str(self.path), check_same_thread=False)
        try:
            self._conn.executescript(_SCHEMA)
            self._conn.commit()
        except BaseException:
            # The instance never reaches the caller, so nobody could close
            # this connection later -- close it here to avoid a leak.
            self._conn.close()
            raise

    def __enter__(self) -> LlmCache:
        """Return ``self`` so the cache can be used in a ``with`` statement."""
        return self

    def __exit__(self, *exc_info: object) -> None:
        """Close the connection when the ``with`` block ends."""
        self.close()

    def _execute_write(
        self, sql: str, params: tuple[object, ...] = ()
    ) -> sqlite3.Cursor:
        """Execute one write statement and commit it. Caller must hold ``_lock``.

        On an operational or integrity failure the implicit transaction is
        rolled back before the error is re-raised, so a failed write does not
        leave the connection inside an open transaction.
        """
        try:
            cur = self._conn.execute(sql, params)
            self._conn.commit()
        except (sqlite3.OperationalError, sqlite3.IntegrityError):
            self._conn.rollback()
            raise
        return cur

    def get(self, key: str) -> str | None:
        """Return the cached response text for *key*, or ``None`` on a miss.

        A hit also increments ``hit_count`` and sets ``last_hit_at``. That
        bookkeeping is best-effort: if the UPDATE cannot be committed (for
        example the database is read-only or locked by another process), the
        failure is logged and the cached text is still returned.
        """
        with self._lock:
            row = self._conn.execute(
                "SELECT response_text FROM llm_cache WHERE cache_key = ?",
                (key,),
            ).fetchone()
            if row is None:
                return None
            try:
                self._execute_write(
                    "UPDATE llm_cache "
                    "SET hit_count = hit_count + 1, last_hit_at = ? "
                    "WHERE cache_key = ?",
                    (time.time(), key),
                )
            except sqlite3.OperationalError:
                # Statistics only; never fail a successful lookup over them.
                _log.warning(
                    "llm_cache: could not record hit for key %s", key, exc_info=True
                )
            return row[0]

    def put(self, key: str, model: str, response_text: str) -> None:
        """Store *response_text* under *key*. Empty (falsy) text is not stored.

        Uses ``INSERT OR IGNORE``: if *key* already exists, the stored text
        and its hit statistics are kept and the new text is discarded.
        """
        if not response_text:
            return
        with self._lock:
            self._execute_write(
                "INSERT OR IGNORE INTO llm_cache "
                "(cache_key, model, response_text, created_at) "
                "VALUES (?, ?, ?, ?)",
                (key, model, response_text, time.time()),
            )

    def stats(self) -> dict[str, int]:
        """Return the number of entries and the sum of their hit counts."""
        with self._lock:
            row = self._conn.execute(
                "SELECT COUNT(*), COALESCE(SUM(hit_count), 0) FROM llm_cache"
            ).fetchone()
        return {"entries": int(row[0] or 0), "total_hits": int(row[1] or 0)}

    def clear(self) -> int:
        """Delete every cached entry and return the number of rows removed."""
        with self._lock:
            cur = self._execute_write("DELETE FROM llm_cache")
        return cur.rowcount

    def close(self) -> None:
        """Close the connection; later queries raise ``sqlite3.ProgrammingError``."""
        with self._lock:
            self._conn.close()
|