feat(govdoc): 新增内部公文模块全链路(后端58+前端11文件)
This commit is contained in:
@@ -0,0 +1,258 @@
|
||||
"""Qwen LLM 客户端(OpenAI 兼容协议)。
|
||||
|
||||
包含:超时(asyncio.wait_for)、重试(指数退避)、并发上限(Semaphore)。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI, APIError, APIConnectionError, RateLimitError
|
||||
|
||||
from fastapi_admin.config import (
|
||||
LLM_API_KEY,
|
||||
LLM_BASE_URL,
|
||||
LLM_MODEL,
|
||||
LEAUDIT_LLM_MAX_CONCURRENCY,
|
||||
LEAUDIT_LLM_REQUEST_TIMEOUT,
|
||||
LEAUDIT_LLM_RETRY_MAX_ATTEMPTS,
|
||||
LEAUDIT_LLM_RETRY_BACKOFF_BASE_SECONDS,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.govdoc_engine.llm.cache import LlmCache, make_key
|
||||
|
||||
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)

# Matches a Markdown code fence (optionally tagged ``json``) and captures the
# text between the opening and closing fence, without surrounding whitespace.
_FENCE_RE = re.compile(r"```(?:json)?\s*([\s\S]+?)\s*```", re.MULTILINE)
|
||||
|
||||
# These exceptions trigger a retry; business errors such as JSON parse
# failures are not retried.
_RETRYABLE = (
    asyncio.TimeoutError,
    TimeoutError,
    APIConnectionError,
    RateLimitError,
)
|
||||
|
||||
|
||||
class LlmJsonError(Exception):
    """Raised when the content returned by the LLM cannot be parsed as JSON."""
|
||||
|
||||
|
||||
class LlmConfigError(Exception):
    """Raised when the LLM client is missing required configuration."""
|
||||
|
||||
|
||||
def _parse_json_text(text: str) -> dict[str, Any]:
    """Extract and decode the JSON object contained in an LLM reply.

    The reply is stripped; if it contains a Markdown code fence, only the
    content of the first fence is considered.  That candidate is decoded as
    JSON directly, and if that fails, the span from the first ``{`` to the
    last ``}`` is tried instead.

    Args:
        text: Raw text returned by the model.

    Returns:
        The decoded JSON object.

    Raises:
        LlmJsonError: If no JSON can be decoded, or if the decoded value is
            not a JSON object (for example a list, string, number or null).
    """
    candidate = text.strip()
    fenced = _FENCE_RE.search(candidate)
    if fenced:
        candidate = fenced.group(1)

    try:
        parsed = json.loads(candidate)
    except json.JSONDecodeError:
        # Models often wrap the object in prose; fall back to the outermost braces.
        start = candidate.find("{")
        end = candidate.rfind("}")
        if start < 0 or end <= start:
            raise LlmJsonError(f"LLM returned non-JSON content: {candidate!r}")
        try:
            parsed = json.loads(candidate[start : end + 1])
        except json.JSONDecodeError as e:
            raise LlmJsonError(f"failed to parse LLM JSON: {candidate!r}") from e

    # json.loads happily returns lists and scalars; the declared contract is a
    # dict, so surface anything else as a parse failure instead of leaking it.
    if not isinstance(parsed, dict):
        raise LlmJsonError(f"LLM returned JSON that is not an object: {candidate!r}")
    return parsed
|
||||
|
||||
|
||||
def _is_retryable_status(exc: Exception) -> bool:
    """Report whether an API failure should be retried.

    A ``RateLimitError`` is always retryable.  Any other ``APIError`` is
    retryable only when it carries a ``status_code`` that is 429 or 5xx.
    Everything else is not retryable.
    """
    if isinstance(exc, RateLimitError):
        return True
    if not isinstance(exc, APIError):
        return False

    code = getattr(exc, "status_code", None)
    if code is None:
        # Without an HTTP status there is nothing to base a retry decision on.
        return False
    return code >= 500 or code == 429
|
||||
|
||||
|
||||
def _clip_text(value: Any, limit: int = 400) -> str:
|
||||
text = str(value).strip()
|
||||
if len(text) <= limit:
|
||||
return text
|
||||
return text[: limit - 3] + "..."
|
||||
|
||||
|
||||
def _format_exc(exc: Exception) -> str:
|
||||
text = str(exc).strip()
|
||||
parts = [exc.__class__.__name__]
|
||||
if text:
|
||||
parts.append(text)
|
||||
|
||||
status = getattr(exc, "status_code", None)
|
||||
if status is not None:
|
||||
parts.append(f"status={status}")
|
||||
|
||||
body = getattr(exc, "body", None)
|
||||
if body not in (None, "", b""):
|
||||
parts.append(f"body={_clip_text(body)}")
|
||||
|
||||
response = getattr(exc, "response", None)
|
||||
if response is not None:
|
||||
try:
|
||||
request = getattr(response, "request", None)
|
||||
if request is not None and getattr(request, "url", None):
|
||||
parts.append(f"url={request.url}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
request = getattr(exc, "request", None)
|
||||
if request is not None and getattr(request, "url", None):
|
||||
parts.append(f"url={request.url}")
|
||||
|
||||
return ": ".join(parts[:2]) + ("" if len(parts) <= 2 else " | " + " | ".join(parts[2:]))
|
||||
|
||||
|
||||
class LlmClient:
    """Client for an OpenAI-compatible chat-completions endpoint.

    Offers a blocking path (:meth:`chat`, :meth:`chat_json`) and an asyncio
    path (:meth:`chat_async`, :meth:`chat_json_async`).  Both paths retry
    retryable failures with exponential backoff and can consult an optional
    response cache.  Only the asyncio path is bounded by the concurrency
    semaphore and wrapped in ``asyncio.wait_for``; the blocking path passes
    the timeout to the SDK call instead.
    """

    # Request options that are folded into the cache key.
    _CACHE_KEY_OPTIONS = ("temperature", "top_p", "max_tokens", "response_format")
    # Upper bound for a single backoff sleep, in seconds.
    _MAX_BACKOFF_SECONDS = 8.0

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model: str | None = None,
        max_concurrency: int | None = None,
        timeout_seconds: float | None = None,
        max_retries: int | None = None,
        cache: LlmCache | None = None,
        cache_enabled: bool | None = None,
    ):
        """Create the client; every argument falls back to platform config.

        A missing API key does not raise here.  The error is stored and
        raised on the first call, so merely constructing the client is safe.

        Args:
            api_key: API key; falls back to ``LLM_API_KEY``.
            base_url: Endpoint base URL; falls back to ``LLM_BASE_URL``.
            model: Model name; falls back to ``LLM_MODEL``.
            max_concurrency: Size of the semaphore guarding async calls;
                falls back to ``LEAUDIT_LLM_MAX_CONCURRENCY``.
            timeout_seconds: Per-request timeout; falls back to
                ``LEAUDIT_LLM_REQUEST_TIMEOUT``.
            max_retries: Number of retries after the first attempt; falls
                back to ``LEAUDIT_LLM_RETRY_MAX_ATTEMPTS``.
            cache: Response cache to use.  Caching is active only when a
                cache object is passed in.
            cache_enabled: Accepted for interface compatibility; it currently
                has no effect, caching stays off unless ``cache`` is given.
        """
        key = api_key or LLM_API_KEY
        self._misconfigured_error: LlmConfigError | None = None
        self._client: OpenAI | None = None
        self._aclient: AsyncOpenAI | None = None
        if key:
            endpoint = base_url or LLM_BASE_URL
            # NOTE(review): the SDK clients apply their own default retry
            # policy underneath the retry loop in this class — confirm that
            # the combined number of attempts is intended.
            self._client = OpenAI(api_key=key, base_url=endpoint)
            self._aclient = AsyncOpenAI(api_key=key, base_url=endpoint)
        else:
            self._misconfigured_error = LlmConfigError(
                "LLM_API_KEY is not configured. Set LLM_API_KEY in platform config."
            )

        self.model = model or LLM_MODEL
        self.timeout = timeout_seconds if timeout_seconds is not None else LEAUDIT_LLM_REQUEST_TIMEOUT
        self.max_retries = max_retries if max_retries is not None else LEAUDIT_LLM_RETRY_MAX_ATTEMPTS
        conc = max_concurrency if max_concurrency is not None else LEAUDIT_LLM_MAX_CONCURRENCY
        self._sem = asyncio.Semaphore(conc)
        # Cache: used only when passed in explicitly; otherwise disabled.
        self.cache: LlmCache | None = cache

    def _ensure_ready(self) -> None:
        """Raise the stored configuration error, if the client has one."""
        if self._misconfigured_error is not None:
            raise self._misconfigured_error

    @staticmethod
    def _prompt_text(messages: list[dict[str, str]]) -> str:
        """Flatten chat messages into one ``[role]``-prefixed text block.

        A message without a role is labelled ``user``; missing content is
        rendered as an empty string.  Messages are separated by a blank line.
        """
        return "\n\n".join(
            f"[{m.get('role', 'user')}]\n{m.get('content', '')}"
            for m in messages
        )

    # -- shared helpers ------------------------------------------------

    def _begin_call(
        self, messages: list[dict[str, str]], kwargs: dict[str, Any]
    ) -> tuple[str | None, str | None]:
        """Strip wrapper-only options from ``kwargs`` and consult the cache.

        ``use_cache`` and ``label`` are removed from ``kwargs`` in place so
        they are never forwarded to the API.

        Returns:
            ``(cache_key, cached_content)``.  ``cache_key`` is ``None`` when
            caching does not apply to this call; ``cached_content`` is
            ``None`` on a cache miss.
        """
        use_cache = kwargs.pop("use_cache", True)
        kwargs.pop("label", None)
        if not use_cache or self.cache is None:
            return None, None

        cache_kwargs = {k: kwargs.get(k) for k in self._CACHE_KEY_OPTIONS}
        cache_key = make_key(self.model, messages, **cache_kwargs)
        hit = self.cache.get(cache_key)
        if hit is not None:
            _log.debug("LLM cache HIT key=%s", cache_key[:12])
        return cache_key, hit

    def _finish_call(self, cache_key: str | None, resp: Any) -> str:
        """Extract the reply text from ``resp`` and cache it when non-empty."""
        content = resp.choices[0].message.content or ""
        if cache_key is not None and content:
            self.cache.put(cache_key, self.model, content)
        return content

    def _attempt_count(self) -> int:
        """Total number of attempts; at least one even if ``max_retries`` < 0."""
        return max(self.max_retries, 0) + 1

    def _next_wait(self, log_format: str, exc: Exception, attempt: int) -> float:
        """Log the upcoming retry and return how long to sleep before it.

        NOTE(review): the file imports LEAUDIT_LLM_RETRY_BACKOFF_BASE_SECONDS
        but the backoff here is a fixed ``2 ** attempt`` capped at 8 seconds —
        confirm whether the base is meant to be configurable.
        """
        wait = min(self._MAX_BACKOFF_SECONDS, 2 ** attempt)
        _log.warning(log_format, _format_exc(exc), attempt + 1, self.max_retries, wait)
        return wait

    # -- synchronous path ----------------------------------------------

    def chat(self, messages: list[dict[str, str]], **kwargs) -> str:
        """Send a chat completion request and return the reply text (blocking).

        Args:
            messages: Chat messages to send.
            **kwargs: Forwarded to the SDK ``chat.completions.create`` call,
                except ``use_cache`` (default ``True``) and ``label``, which
                are consumed here.  ``timeout`` defaults to the client
                timeout.

        Returns:
            The content of the first choice, or ``""`` if it has none.

        Raises:
            LlmConfigError: If the client was created without an API key.
            APIError: If the API fails with a non-retryable status.
            Exception: The last retryable failure, once attempts run out.
        """
        self._ensure_ready()
        cache_key, cached = self._begin_call(messages, kwargs)
        if cached is not None:
            return cached

        kwargs.setdefault("timeout", self.timeout)
        attempts = self._attempt_count()
        for attempt in range(attempts):
            try:
                resp = self._client.chat.completions.create(
                    model=self.model, messages=messages, **kwargs
                )
                return self._finish_call(cache_key, resp)
            except _RETRYABLE as e:
                failure = e
            except APIError as e:
                if not _is_retryable_status(e):
                    raise
                failure = e
            if attempt + 1 >= attempts:
                raise failure
            time.sleep(
                self._next_wait(
                    "LLM call failed (%s); retry %d/%d after %.1fs", failure, attempt
                )
            )
        # The loop always returns or raises; this only guards against misuse.
        raise RuntimeError("LLM retry loop finished without a result")

    def chat_json(self, messages: list[dict[str, str]], **kwargs) -> dict[str, Any]:
        """Like :meth:`chat`, but parse the reply as JSON."""
        return _parse_json_text(self.chat(messages, **kwargs))

    # -- asynchronous path ---------------------------------------------

    async def chat_async(self, messages: list[dict[str, str]], **kwargs) -> str:
        """Send a chat completion request and return the reply text (asyncio).

        Each attempt holds the concurrency semaphore only for the duration of
        the request and is bounded by ``asyncio.wait_for`` with the client
        timeout; backoff sleeps happen outside the semaphore.

        Args:
            messages: Chat messages to send.
            **kwargs: Forwarded to the SDK ``chat.completions.create`` call,
                except ``use_cache`` (default ``True``) and ``label``, which
                are consumed here.

        Returns:
            The content of the first choice, or ``""`` if it has none.

        Raises:
            LlmConfigError: If the client was created without an API key.
            APIError: If the API fails with a non-retryable status.
            Exception: The last retryable failure, once attempts run out.
        """
        self._ensure_ready()
        cache_key, cached = self._begin_call(messages, kwargs)
        if cached is not None:
            return cached

        attempts = self._attempt_count()
        for attempt in range(attempts):
            try:
                async with self._sem:
                    resp = await asyncio.wait_for(
                        self._aclient.chat.completions.create(
                            model=self.model, messages=messages, **kwargs,
                        ),
                        timeout=self.timeout,
                    )
                return self._finish_call(cache_key, resp)
            except _RETRYABLE as e:
                failure = e
            except APIError as e:
                if not _is_retryable_status(e):
                    raise
                failure = e
            if attempt + 1 >= attempts:
                raise failure
            await asyncio.sleep(
                self._next_wait(
                    "LLM async call failed (%s); retry %d/%d after %.1fs", failure, attempt
                )
            )
        # The loop always returns or raises; this only guards against misuse.
        raise RuntimeError("LLM retry loop finished without a result")

    async def chat_json_async(
        self, messages: list[dict[str, str]], **kwargs
    ) -> dict[str, Any]:
        """Like :meth:`chat_async`, but parse the reply as JSON."""
        return _parse_json_text(await self.chat_async(messages, **kwargs))
|
||||
Reference in New Issue
Block a user