Files

282 lines
10 KiB
Python

"""Qwen LLM 客户端(OpenAI 兼容协议)。
包含:超时(asyncio.wait_for)、重试(指数退避)、并发上限(Semaphore)。
"""
from __future__ import annotations
import asyncio
import json
import logging
import re
import time
from typing import Any
# ``openai`` is an optional dependency. When it is missing, the client classes
# are set to None and minimal stand-in exception classes are defined so that
# the ``except`` clauses and ``isinstance`` checks further down still resolve.
try:
    from openai import AsyncOpenAI, OpenAI, APIError, APIConnectionError, RateLimitError
    # No import problem to report.
    _OPENAI_IMPORT_ERROR: Exception | None = None
except ModuleNotFoundError as exc: # pragma: no cover - optional dependency
    AsyncOpenAI = None # type: ignore[assignment]
    OpenAI = None # type: ignore[assignment]
    # Keep the original import failure around for diagnostics.
    _OPENAI_IMPORT_ERROR = exc
    # Stand-in for openai.APIError; carries the ``status_code`` attribute that
    # the retry logic reads via getattr.
    class APIError(Exception):
        status_code: int | None = None
    # Stand-in for openai.APIConnectionError.
    class APIConnectionError(Exception):
        pass
    # Stand-in for openai.RateLimitError.
    class RateLimitError(Exception):
        pass
from fastapi_admin.config import (
LLM_API_KEY,
LLM_BASE_URL,
LLM_MODEL,
LEAUDIT_LLM_MAX_CONCURRENCY,
LEAUDIT_LLM_REQUEST_TIMEOUT,
LEAUDIT_LLM_RETRY_MAX_ATTEMPTS,
LEAUDIT_LLM_RETRY_BACKOFF_BASE_SECONDS,
)
from fastapi_modules.fastapi_leaudit.govdoc_engine.llm.cache import LlmCache, make_key
from fastapi_modules.fastapi_leaudit.rag_engine.config import normalize_openai_base_url
# Module-level logger, named after this module.
_log = logging.getLogger(__name__)
# Matches a Markdown code fence (optionally tagged ``json``) and captures the
# fenced body with surrounding whitespace trimmed.
_FENCE_RE = re.compile(r"```(?:json)?\s*([\s\S]+?)\s*```", re.MULTILINE)
# These exceptions trigger a retry; business errors such as JSON parse
# failures are not retried.
_RETRYABLE = (
    asyncio.TimeoutError,
    TimeoutError,
    APIConnectionError,
    RateLimitError,
)
class LlmJsonError(Exception):
    """The content returned by the LLM could not be parsed as JSON."""
class LlmConfigError(Exception):
    """The LLM client is missing required configuration."""
def _decode_json_candidate(candidate: str) -> dict[str, Any]:
    """Decode one candidate string, tolerating prose around a JSON object.

    The candidate is first decoded as-is. If that fails, the span from the
    first ``{`` to the last ``}`` is decoded instead.

    Raises:
        LlmJsonError: if neither attempt yields valid JSON.
    """
    try:
        return json.loads(candidate)
    except json.JSONDecodeError:
        start = candidate.find("{")
        end = candidate.rfind("}")
        if start >= 0 and end > start:
            try:
                return json.loads(candidate[start : end + 1])
            except json.JSONDecodeError as e:
                raise LlmJsonError(f"failed to parse LLM JSON: {candidate!r}") from e
        raise LlmJsonError(f"LLM returned non-JSON content: {candidate!r}")


def _parse_json_text(text: str) -> dict[str, Any]:
    """Extract a JSON value from raw LLM output.

    Candidates are tried in order:

    1. the body of the first Markdown code fence, if the reply contains one;
    2. the whole (stripped) reply.

    The second step matters when the first fence holds something other than
    JSON (for example a code sample) while the JSON object sits elsewhere in
    the reply; previously such replies were rejected outright.

    NOTE(review): the decoded value is returned as-is. A reply whose top-level
    JSON value is an array or scalar is not rejected here, so callers that
    need a mapping should check the result.

    Args:
        text: Raw text returned by the model.

    Returns:
        The decoded JSON value.

    Raises:
        LlmJsonError: if no candidate can be decoded. The error reported is
            the one produced for the first candidate.
    """
    whole = text.strip()
    candidates: list[str] = []
    fenced = _FENCE_RE.search(whole)
    if fenced:
        candidates.append(fenced.group(1))
    if whole not in candidates:
        candidates.append(whole)

    primary, *fallbacks = candidates
    try:
        return _decode_json_candidate(primary)
    except LlmJsonError:
        for fallback in fallbacks:
            try:
                return _decode_json_candidate(fallback)
            except LlmJsonError:
                continue
        # Nothing worked: re-raise the error for the primary candidate.
        raise
def _is_retryable_status(exc: Exception) -> bool:
    """Tell whether an API error should be retried.

    Rate-limit errors are always retryable. Any other ``APIError`` is
    retryable only when it carries a ``status_code`` that is 429 or in the
    5xx range. Everything else is not retryable.
    """
    if isinstance(exc, RateLimitError):
        return True
    if not isinstance(exc, APIError):
        return False
    code = getattr(exc, "status_code", None)
    if code is None:
        return False
    return code >= 500 or code == 429
def _clip_text(value: Any, limit: int = 400) -> str:
text = str(value).strip()
if len(text) <= limit:
return text
return text[: limit - 3] + "..."
def _format_exc(exc: Exception) -> str:
text = str(exc).strip()
parts = [exc.__class__.__name__]
if text:
parts.append(text)
status = getattr(exc, "status_code", None)
if status is not None:
parts.append(f"status={status}")
body = getattr(exc, "body", None)
if body not in (None, "", b""):
parts.append(f"body={_clip_text(body)}")
response = getattr(exc, "response", None)
if response is not None:
try:
request = getattr(response, "request", None)
if request is not None and getattr(request, "url", None):
parts.append(f"url={request.url}")
except Exception:
pass
request = getattr(exc, "request", None)
if request is not None and getattr(request, "url", None):
parts.append(f"url={request.url}")
return ": ".join(parts[:2]) + ("" if len(parts) <= 2 else " | " + " | ".join(parts[2:]))
class LlmClient:
    """Chat-completions client for an OpenAI-compatible endpoint.

    Offers a blocking path (``chat`` / ``chat_json``) and an asyncio path
    (``chat_async`` / ``chat_json_async``). Both paths share:

    * an optional response cache, used only when a ``cache`` object is given;
    * retries with capped exponential backoff for timeouts, connection
      errors, rate limits and API errors with status 429 or 5xx.

    The async path additionally bounds the number of in-flight requests with
    a semaphore and enforces the request timeout with ``asyncio.wait_for``.

    A client built without the ``openai`` package or without an API key is
    still constructed; the configuration error is raised on first use.
    """

    # Request options that take part in the cache key.
    _CACHE_KEY_PARAMS = ("temperature", "top_p", "max_tokens", "response_format")
    # Upper bound, in seconds, for a single pause between attempts.
    _MAX_BACKOFF_SECONDS = 8.0

    def __init__(
        self,
        api_key: str | None = None,
        base_url: str | None = None,
        model: str | None = None,
        max_concurrency: int | None = None,
        timeout_seconds: float | None = None,
        max_retries: int | None = None,
        cache: LlmCache | None = None,
        cache_enabled: bool | None = None,
    ):
        """Build the sync and async SDK clients, or record why that is impossible.

        Args:
            api_key: API key; falls back to ``LLM_API_KEY``.
            base_url: Endpoint URL; falls back to ``LLM_BASE_URL``. It is
                passed through ``normalize_openai_base_url`` before use.
            model: Model name; falls back to ``LLM_MODEL``.
            max_concurrency: Maximum concurrent async requests; falls back to
                ``LEAUDIT_LLM_MAX_CONCURRENCY``. Values below 1 are treated
                as 1.
            timeout_seconds: Per-request timeout; falls back to
                ``LEAUDIT_LLM_REQUEST_TIMEOUT``.
            max_retries: Number of retries after the first attempt; falls
                back to ``LEAUDIT_LLM_RETRY_MAX_ATTEMPTS``.
            cache: Response cache. Caching is active only when this is given.
            cache_enabled: Accepted for interface compatibility. It is not
                consulted: without a ``cache`` object caching stays off, and
                with one it stays on.
        """
        key = api_key or LLM_API_KEY
        self._client = None
        self._aclient = None
        self._misconfigured_error: LlmConfigError | None = None
        if OpenAI is None or AsyncOpenAI is None:
            error = LlmConfigError(
                "python package 'openai' is not installed; govdoc LLM features are unavailable."
            )
            # Surface the original import failure in tracebacks.
            error.__cause__ = _OPENAI_IMPORT_ERROR
            self._misconfigured_error = error
        elif not key:
            self._misconfigured_error = LlmConfigError(
                "LLM_API_KEY is not configured. Set LLM_API_KEY in platform config."
            )
        else:
            normalized_base_url = normalize_openai_base_url(base_url or LLM_BASE_URL)
            self._client = OpenAI(api_key=key, base_url=normalized_base_url)
            self._aclient = AsyncOpenAI(api_key=key, base_url=normalized_base_url)

        self.model = model or LLM_MODEL
        self.timeout = timeout_seconds if timeout_seconds is not None else LEAUDIT_LLM_REQUEST_TIMEOUT
        self.max_retries = max_retries if max_retries is not None else LEAUDIT_LLM_RETRY_MAX_ATTEMPTS
        conc = max_concurrency if max_concurrency is not None else LEAUDIT_LLM_MAX_CONCURRENCY
        # A zero-slot semaphore would block every async call forever, because
        # acquiring it is not covered by the request timeout.
        self._sem = asyncio.Semaphore(max(1, conc))
        # Cache: used when passed in explicitly; otherwise off.
        self.cache: LlmCache | None = cache

    def _ensure_ready(self) -> None:
        """Raise the configuration error recorded at construction, if any."""
        if self._misconfigured_error is not None:
            raise self._misconfigured_error

    @staticmethod
    def _prompt_text(messages: list[dict[str, str]]) -> str:
        """Flatten chat messages into one string of ``[role]`` / content blocks."""
        return "\n\n".join(
            f"[{m.get('role', 'user')}]\n{m.get('content', '')}"
            for m in messages
        )

    # -- shared helpers ---------------------------------------------------

    def _cache_lookup(
        self,
        messages: list[dict[str, str]],
        options: dict[str, Any],
        use_cache: bool,
    ) -> tuple[str | None, str | None]:
        """Return ``(cache_key, cached_content)`` for a request.

        Both items are ``None`` when caching is not in effect for this call;
        ``cached_content`` alone is ``None`` on a cache miss.
        """
        if not use_cache or self.cache is None:
            return None, None
        key_params = {name: options.get(name) for name in self._CACHE_KEY_PARAMS}
        cache_key = make_key(self.model, messages, **key_params)
        hit = self.cache.get(cache_key)
        if hit is not None:
            _log.debug("LLM cache HIT key=%s", cache_key[:12])
        return cache_key, hit

    def _cache_store(self, cache_key: str | None, content: str) -> None:
        """Store a non-empty response under ``cache_key``; no-op without a key."""
        if cache_key is not None and content:
            self.cache.put(cache_key, self.model, content)

    def _attempt_count(self) -> int:
        """Total attempts per call: one plus the retries, and never fewer than one."""
        return max(self.max_retries, 0) + 1

    @classmethod
    def _retry_delay(cls, attempt: int) -> float:
        """Seconds to pause after the zero-based ``attempt`` has failed.

        The delay doubles with every attempt and is capped at
        ``_MAX_BACKOFF_SECONDS``.
        """
        # NOTE(review): assumes LEAUDIT_LLM_RETRY_BACKOFF_BASE_SECONDS is the
        # delay, in seconds, before the first retry. Confirm against config.
        delay = LEAUDIT_LLM_RETRY_BACKOFF_BASE_SECONDS * (2 ** attempt)
        return max(0.0, min(cls._MAX_BACKOFF_SECONDS, delay))

    # -- synchronous path -------------------------------------------------

    def chat(self, messages: list[dict[str, str]], **kwargs) -> str:
        """Send a chat request and return the reply text (blocking).

        Args:
            messages: Chat messages to send.
            **kwargs: Options forwarded to ``chat.completions.create``. Two
                keys are consumed here and not forwarded: ``use_cache``
                (default ``True``) and ``label``. ``timeout`` defaults to the
                client's timeout when not supplied.

        Returns:
            The content of the first choice, or ``""`` when it is empty.

        Raises:
            LlmConfigError: if the client was constructed without the
                ``openai`` package or without an API key.
            Exception: the last retryable error once all attempts are used
                up; a non-retryable ``APIError`` is raised immediately.
        """
        self._ensure_ready()
        use_cache = kwargs.pop("use_cache", True)
        kwargs.pop("label", None)  # client-side tag; must not reach the SDK
        cache_key, cached = self._cache_lookup(messages, kwargs, use_cache)
        if cached is not None:
            return cached

        kwargs.setdefault("timeout", self.timeout)
        attempts = self._attempt_count()
        for attempt in range(attempts):
            try:
                resp = self._client.chat.completions.create(
                    model=self.model, messages=messages, **kwargs
                )
                content = resp.choices[0].message.content or ""
                self._cache_store(cache_key, content)
                return content
            except _RETRYABLE as exc:
                failure = exc
            except APIError as exc:
                if not _is_retryable_status(exc):
                    raise
                failure = exc
            if attempt + 1 >= attempts:
                raise failure
            wait = self._retry_delay(attempt)
            _log.warning(
                "LLM call failed (%s); retry %d/%d after %.1fs",
                _format_exc(failure), attempt + 1, self.max_retries, wait,
            )
            time.sleep(wait)
        raise RuntimeError("LLM retry loop ended without a result")  # pragma: no cover

    def chat_json(self, messages: list[dict[str, str]], **kwargs) -> dict[str, Any]:
        """Run ``chat`` and parse the reply with ``_parse_json_text``."""
        return _parse_json_text(self.chat(messages, **kwargs))

    # -- asynchronous path ------------------------------------------------

    async def chat_async(self, messages: list[dict[str, str]], **kwargs) -> str:
        """Send a chat request and return the reply text (asyncio).

        Each attempt holds one slot of the concurrency semaphore and is
        bounded by ``asyncio.wait_for`` with the client's timeout. The slot
        is released before the pause between attempts.

        Args:
            messages: Chat messages to send.
            **kwargs: Options forwarded to ``chat.completions.create``. Two
                keys are consumed here and not forwarded: ``use_cache``
                (default ``True``) and ``label``.

        Returns:
            The content of the first choice, or ``""`` when it is empty.

        Raises:
            LlmConfigError: if the client was constructed without the
                ``openai`` package or without an API key.
            Exception: the last retryable error once all attempts are used
                up; a non-retryable ``APIError`` is raised immediately.
        """
        self._ensure_ready()
        use_cache = kwargs.pop("use_cache", True)
        kwargs.pop("label", None)  # client-side tag; must not reach the SDK
        cache_key, cached = self._cache_lookup(messages, kwargs, use_cache)
        if cached is not None:
            return cached

        attempts = self._attempt_count()
        for attempt in range(attempts):
            try:
                async with self._sem:
                    resp = await asyncio.wait_for(
                        self._aclient.chat.completions.create(
                            model=self.model, messages=messages, **kwargs,
                        ),
                        timeout=self.timeout,
                    )
                    content = resp.choices[0].message.content or ""
                    self._cache_store(cache_key, content)
                    return content
            except _RETRYABLE as exc:
                failure = exc
            except APIError as exc:
                if not _is_retryable_status(exc):
                    raise
                failure = exc
            if attempt + 1 >= attempts:
                raise failure
            wait = self._retry_delay(attempt)
            _log.warning(
                "LLM async call failed (%s); retry %d/%d after %.1fs",
                _format_exc(failure), attempt + 1, self.max_retries, wait,
            )
            await asyncio.sleep(wait)
        raise RuntimeError("LLM retry loop ended without a result")  # pragma: no cover

    async def chat_json_async(
        self, messages: list[dict[str, str]], **kwargs
    ) -> dict[str, Any]:
        """Run ``chat_async`` and parse the reply with ``_parse_json_text``."""
        return _parse_json_text(await self.chat_async(messages, **kwargs))