280 lines
9.8 KiB
Python
280 lines
9.8 KiB
Python
"""Qwen LLM 客户端(OpenAI 兼容协议)。
|
|
|
|
包含:超时(asyncio.wait_for)、重试(指数退避)、并发上限(Semaphore)。
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
import asyncio
|
|
import json
|
|
import logging
|
|
import re
|
|
import time
|
|
from typing import Any
|
|
|
|
# ``openai`` is an optional dependency. When it is missing, the client classes
# are set to ``None`` (LlmClient checks for that and reports a configuration
# error) and minimal stand-in exception classes are defined so that the
# ``except`` clauses and ``isinstance`` checks in this module still resolve.
try:
    from openai import AsyncOpenAI, OpenAI, APIError, APIConnectionError, RateLimitError

    # No import problem to report.
    _OPENAI_IMPORT_ERROR: Exception | None = None
except ModuleNotFoundError as exc: # pragma: no cover - optional dependency
    AsyncOpenAI = None # type: ignore[assignment]
    OpenAI = None # type: ignore[assignment]
    # Keep the original import failure around for diagnostics.
    _OPENAI_IMPORT_ERROR = exc

    class APIError(Exception):
        # Stand-in for ``openai.APIError``; ``status_code`` is read by
        # ``_is_retryable_status`` and ``_format_exc`` via ``getattr``.
        status_code: int | None = None

    class APIConnectionError(Exception):
        # Stand-in for ``openai.APIConnectionError`` (listed in ``_RETRYABLE``).
        pass

    class RateLimitError(Exception):
        # Stand-in for ``openai.RateLimitError`` (listed in ``_RETRYABLE``).
        pass
|
|
|
|
from fastapi_admin.config import (
|
|
LLM_API_KEY,
|
|
LLM_BASE_URL,
|
|
LLM_MODEL,
|
|
LEAUDIT_LLM_MAX_CONCURRENCY,
|
|
LEAUDIT_LLM_REQUEST_TIMEOUT,
|
|
LEAUDIT_LLM_RETRY_MAX_ATTEMPTS,
|
|
LEAUDIT_LLM_RETRY_BACKOFF_BASE_SECONDS,
|
|
)
|
|
from fastapi_modules.fastapi_leaudit.govdoc_engine.llm.cache import LlmCache, make_key
|
|
|
|
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)

# Matches a Markdown code fence, optionally tagged ``json``, and captures the
# text between the opening and closing fence. The capture is non-greedy, so
# ``search`` returns the first fenced block in the text.
_FENCE_RE = re.compile(r"```(?:json)?\s*([\s\S]+?)\s*```", re.MULTILINE)

# These exceptions trigger a retry; business errors such as JSON parse
# failures are not retried.
_RETRYABLE = (
    asyncio.TimeoutError,
    TimeoutError,
    APIConnectionError,
    RateLimitError,
)
|
|
|
|
|
|
class LlmJsonError(Exception):
    """The content returned by the LLM could not be parsed as JSON."""
|
|
|
|
|
|
class LlmConfigError(Exception):
    """The LLM client is missing required configuration."""
|
|
|
|
|
|
def _parse_json_text(text: str) -> dict[str, Any]:
    """Decode the JSON payload carried by an LLM reply.

    The reply is stripped; if it contains a Markdown code fence, only the
    fenced text is considered. That text is decoded as JSON directly. When
    direct decoding fails, the span from the first ``{`` to the last ``}`` is
    decoded instead.

    Args:
        text: Raw reply text from the model.

    Returns:
        Whatever ``json.loads`` produces for the selected text.

    Raises:
        LlmJsonError: If neither the text nor its brace-delimited span is
            valid JSON.
    """
    text = text.strip()
    fenced = _FENCE_RE.search(text)
    if fenced is not None:
        text = fenced.group(1)

    try:
        return json.loads(text)
    except json.JSONDecodeError:
        # Fall back to the outermost brace-delimited span, which tolerates
        # prose before or after the JSON object.
        first_brace, last_brace = text.find("{"), text.rfind("}")
        if first_brace < 0 or last_brace <= first_brace:
            raise LlmJsonError(f"LLM returned non-JSON content: {text!r}")
        candidate = text[first_brace : last_brace + 1]
        try:
            return json.loads(candidate)
        except json.JSONDecodeError as e:
            raise LlmJsonError(f"failed to parse LLM JSON: {text!r}") from e
|
|
|
|
|
|
def _is_retryable_status(exc: Exception) -> bool:
    """Tell whether an API-level failure should be retried.

    A ``RateLimitError`` is always retryable. Any other ``APIError`` is
    retryable only when it carries a ``status_code`` that is 429 or in the
    5xx range. Everything else is not retryable.
    """
    if isinstance(exc, RateLimitError):
        return True
    if not isinstance(exc, APIError):
        return False

    code = getattr(exc, "status_code", None)
    if code is None:
        # No HTTP status available, so there is nothing to base a retry on.
        return False
    return code >= 500 or code == 429
|
|
|
|
|
|
def _clip_text(value: Any, limit: int = 400) -> str:
|
|
text = str(value).strip()
|
|
if len(text) <= limit:
|
|
return text
|
|
return text[: limit - 3] + "..."
|
|
|
|
|
|
def _format_exc(exc: Exception) -> str:
|
|
text = str(exc).strip()
|
|
parts = [exc.__class__.__name__]
|
|
if text:
|
|
parts.append(text)
|
|
|
|
status = getattr(exc, "status_code", None)
|
|
if status is not None:
|
|
parts.append(f"status={status}")
|
|
|
|
body = getattr(exc, "body", None)
|
|
if body not in (None, "", b""):
|
|
parts.append(f"body={_clip_text(body)}")
|
|
|
|
response = getattr(exc, "response", None)
|
|
if response is not None:
|
|
try:
|
|
request = getattr(response, "request", None)
|
|
if request is not None and getattr(request, "url", None):
|
|
parts.append(f"url={request.url}")
|
|
except Exception:
|
|
pass
|
|
|
|
request = getattr(exc, "request", None)
|
|
if request is not None and getattr(request, "url", None):
|
|
parts.append(f"url={request.url}")
|
|
|
|
return ": ".join(parts[:2]) + ("" if len(parts) <= 2 else " | " + " | ".join(parts[2:]))
|
|
|
|
|
|
class LlmClient:
    """Chat-completions client for an OpenAI-compatible endpoint.

    Offers a blocking path (``chat`` / ``chat_json``) and an asyncio path
    (``chat_async`` / ``chat_json_async``). Both paths share:

    * an optional response cache (only used when a cache object is passed
      to the constructor),
    * a retry loop: exceptions listed in ``_RETRYABLE`` and ``APIError``
      instances accepted by ``_is_retryable_status`` are retried after a
      backoff of ``min(8.0, 2 ** attempt)`` seconds; other ``APIError``
      instances propagate immediately.

    The async path additionally bounds each request with
    ``asyncio.wait_for`` and limits concurrent requests with a semaphore.
    The blocking path passes ``timeout`` to the underlying client instead.

    A missing ``openai`` package or API key does not make the constructor
    fail; the problem is stored and raised as ``LlmConfigError`` when a chat
    method is called.
    """

    # Request parameters that take part in the cache key.
    _CACHE_KEY_PARAMS = ("temperature", "top_p", "max_tokens", "response_format")
    # Upper bound, in seconds, for one backoff sleep between attempts.
    _MAX_BACKOFF_SECONDS = 8.0

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model: str | None = None,
        max_concurrency: int | None = None,
        timeout_seconds: float | None = None,
        max_retries: int | None = None,
        cache: LlmCache | None = None,
        cache_enabled: bool | None = None,
    ):
        """Create the client.

        Args:
            api_key: API key; falls back to ``LLM_API_KEY``.
            base_url: Endpoint URL; falls back to ``LLM_BASE_URL``.
            model: Model name; falls back to ``LLM_MODEL``.
            max_concurrency: Maximum number of concurrent async requests;
                falls back to ``LEAUDIT_LLM_MAX_CONCURRENCY``. Values below
                1 are treated as 1.
            timeout_seconds: Per-request timeout; falls back to
                ``LEAUDIT_LLM_REQUEST_TIMEOUT``.
            max_retries: Number of retries after the first attempt; falls
                back to ``LEAUDIT_LLM_RETRY_MAX_ATTEMPTS``.
            cache: Response cache to use. Caching is off when omitted.
            cache_enabled: Accepted for backward compatibility; it has no
                effect, because no cache is ever created implicitly.
        """
        key = api_key or LLM_API_KEY
        self._client = None
        self._aclient = None
        self._misconfigured_error: LlmConfigError | None = None
        if OpenAI is None or AsyncOpenAI is None:
            self._misconfigured_error = LlmConfigError(
                "python package 'openai' is not installed; govdoc LLM features are unavailable."
            )
        elif not key:
            self._misconfigured_error = LlmConfigError(
                "LLM_API_KEY is not configured. Set LLM_API_KEY in platform config."
            )
        else:
            endpoint = base_url or LLM_BASE_URL
            self._client = OpenAI(api_key=key, base_url=endpoint)
            self._aclient = AsyncOpenAI(api_key=key, base_url=endpoint)

        self.model = model or LLM_MODEL
        self.timeout = timeout_seconds if timeout_seconds is not None else LEAUDIT_LLM_REQUEST_TIMEOUT
        self.max_retries = max_retries if max_retries is not None else LEAUDIT_LLM_RETRY_MAX_ATTEMPTS
        conc = max_concurrency if max_concurrency is not None else LEAUDIT_LLM_MAX_CONCURRENCY
        # A semaphore of 0 would block every async call forever and a
        # negative value raises ValueError, so enforce at least one slot.
        self._sem = asyncio.Semaphore(max(1, conc))
        # Cache: used only when passed in explicitly; otherwise disabled.
        self.cache: LlmCache | None = cache

    def _ensure_ready(self) -> None:
        """Raise the stored ``LlmConfigError`` if the client is unusable."""
        if self._misconfigured_error is not None:
            raise self._misconfigured_error

    @staticmethod
    def _prompt_text(messages: list[dict[str, str]]) -> str:
        """Flatten chat messages into ``[role]\\ncontent`` blocks.

        Missing roles default to ``user`` and missing content to an empty
        string; blocks are separated by a blank line.
        """
        return "\n\n".join(
            f"[{m.get('role', 'user')}]\n{m.get('content', '')}"
            for m in messages
        )

    @staticmethod
    def _pop_client_options(kwargs: dict[str, Any]) -> bool:
        """Remove client-only options from ``kwargs`` in place.

        ``use_cache`` and ``label`` are consumed here so they are not
        forwarded to the API. Returns the ``use_cache`` flag (default True).
        """
        use_cache = kwargs.pop("use_cache", True)
        kwargs.pop("label", None)
        return use_cache

    def _cache_lookup(
        self,
        messages: list[dict[str, str]],
        kwargs: dict[str, Any],
        use_cache: bool,
    ) -> tuple[str | None, str | None]:
        """Return ``(cache_key, cached_content)`` for this request.

        Both are ``None`` when caching is off for the call; the content is
        ``None`` on a cache miss.
        """
        if not use_cache or self.cache is None:
            return None, None
        params = {name: kwargs.get(name) for name in self._CACHE_KEY_PARAMS}
        cache_key = make_key(self.model, messages, **params)
        hit = self.cache.get(cache_key)
        if hit is not None:
            _log.debug("LLM cache HIT key=%s", cache_key[:12])
        return cache_key, hit

    def _content_from(self, resp: Any, cache_key: str | None) -> str:
        """Extract the first choice's text and store it in the cache.

        Empty content is returned as ``""`` and is never cached.
        """
        content = resp.choices[0].message.content or ""
        if cache_key is not None and content:
            self.cache.put(cache_key, self.model, content)
        return content

    def _attempt_count(self) -> int:
        """Total attempts to make: one initial try plus ``max_retries``.

        Never less than 1, so a zero or negative ``max_retries`` still
        results in a single attempt.
        """
        return max(1, self.max_retries + 1)

    def _backoff_seconds(self, attempt: int):
        """Delay before the retry that follows 0-based ``attempt``."""
        # NOTE(review): LEAUDIT_LLM_RETRY_BACKOFF_BASE_SECONDS is imported by
        # this module but not used; the base is fixed at 1s (1, 2, 4, 8, ...).
        # Confirm whether the config value was meant to scale this delay.
        return min(self._MAX_BACKOFF_SECONDS, 2 ** attempt)

    # -- synchronous path ----------------------------------------------
    def chat(self, messages: list[dict[str, str]], **kwargs) -> str:
        """Send a chat request and return the reply text (blocking).

        Args:
            messages: Chat messages to send.
            **kwargs: Forwarded to ``chat.completions.create``, except
                ``use_cache`` (default True; skip the cache when false) and
                ``label`` (accepted and discarded). ``timeout`` defaults to
                the client's timeout.

        Returns:
            The content of the first choice, or ``""`` if it is empty.

        Raises:
            LlmConfigError: If the client is not usable.
            APIError: Immediately, for a non-retryable API error.
            Exception: The last retryable error once all attempts failed.
        """
        self._ensure_ready()
        use_cache = self._pop_client_options(kwargs)
        cache_key, hit = self._cache_lookup(messages, kwargs, use_cache)
        if hit is not None:
            return hit

        kwargs.setdefault("timeout", self.timeout)
        attempts = self._attempt_count()
        last_exc: Exception | None = None
        for attempt in range(attempts):
            try:
                resp = self._client.chat.completions.create(
                    model=self.model, messages=messages, **kwargs
                )
                return self._content_from(resp, cache_key)
            except _RETRYABLE as e:
                last_exc = e
            except APIError as e:
                if not _is_retryable_status(e):
                    raise
                last_exc = e
            if attempt < attempts - 1:
                wait = self._backoff_seconds(attempt)
                _log.warning(
                    "LLM call failed (%s); retry %d/%d after %.1fs",
                    _format_exc(last_exc), attempt + 1, self.max_retries, wait,
                )
                time.sleep(wait)
        if last_exc is None:  # pragma: no cover - the loop runs at least once
            raise RuntimeError("LLM retry loop ended without a result")
        raise last_exc

    def chat_json(self, messages: list[dict[str, str]], **kwargs) -> dict[str, Any]:
        """Like ``chat`` but parse the reply with ``_parse_json_text``."""
        return _parse_json_text(self.chat(messages, **kwargs))

    # -- asynchronous path ---------------------------------------------
    async def chat_async(self, messages: list[dict[str, str]], **kwargs) -> str:
        """Send a chat request and return the reply text (asyncio).

        Each attempt holds the concurrency semaphore only while the request
        is in flight (not during backoff) and is bounded by the client's
        timeout through ``asyncio.wait_for``.

        Args:
            messages: Chat messages to send.
            **kwargs: Forwarded to ``chat.completions.create``, except
                ``use_cache`` (default True; skip the cache when false) and
                ``label`` (accepted and discarded).

        Returns:
            The content of the first choice, or ``""`` if it is empty.

        Raises:
            LlmConfigError: If the client is not usable.
            APIError: Immediately, for a non-retryable API error.
            Exception: The last retryable error once all attempts failed.
        """
        self._ensure_ready()
        use_cache = self._pop_client_options(kwargs)
        cache_key, hit = self._cache_lookup(messages, kwargs, use_cache)
        if hit is not None:
            return hit

        attempts = self._attempt_count()
        last_exc: Exception | None = None
        for attempt in range(attempts):
            try:
                async with self._sem:
                    resp = await asyncio.wait_for(
                        self._aclient.chat.completions.create(
                            model=self.model, messages=messages, **kwargs,
                        ),
                        timeout=self.timeout,
                    )
                return self._content_from(resp, cache_key)
            except _RETRYABLE as e:
                last_exc = e
            except APIError as e:
                if not _is_retryable_status(e):
                    raise
                last_exc = e
            if attempt < attempts - 1:
                wait = self._backoff_seconds(attempt)
                _log.warning(
                    "LLM async call failed (%s); retry %d/%d after %.1fs",
                    _format_exc(last_exc), attempt + 1, self.max_retries, wait,
                )
                await asyncio.sleep(wait)
        if last_exc is None:  # pragma: no cover - the loop runs at least once
            raise RuntimeError("LLM retry loop ended without a result")
        raise last_exc

    async def chat_json_async(
        self, messages: list[dict[str, str]], **kwargs
    ) -> dict[str, Any]:
        """Like ``chat_async`` but parse the reply with ``_parse_json_text``."""
        return _parse_json_text(await self.chat_async(messages, **kwargs))
|