"""Qwen LLM 客户端(OpenAI 兼容协议)。 包含:超时(asyncio.wait_for)、重试(指数退避)、并发上限(Semaphore)。 """ from __future__ import annotations import asyncio import json import logging import re import time from typing import Any try: from openai import AsyncOpenAI, OpenAI, APIError, APIConnectionError, RateLimitError _OPENAI_IMPORT_ERROR: Exception | None = None except ModuleNotFoundError as exc: # pragma: no cover - optional dependency AsyncOpenAI = None # type: ignore[assignment] OpenAI = None # type: ignore[assignment] _OPENAI_IMPORT_ERROR = exc class APIError(Exception): status_code: int | None = None class APIConnectionError(Exception): pass class RateLimitError(Exception): pass from fastapi_admin.config import ( LLM_API_KEY, LLM_BASE_URL, LLM_MODEL, LEAUDIT_LLM_MAX_CONCURRENCY, LEAUDIT_LLM_REQUEST_TIMEOUT, LEAUDIT_LLM_RETRY_MAX_ATTEMPTS, LEAUDIT_LLM_RETRY_BACKOFF_BASE_SECONDS, ) from fastapi_modules.fastapi_leaudit.govdoc_engine.llm.cache import LlmCache, make_key _log = logging.getLogger(__name__) _FENCE_RE = re.compile(r"```(?:json)?\s*([\s\S]+?)\s*```", re.MULTILINE) # 这些异常会触发重试;JSON 解析错误等业务错误不重试 _RETRYABLE = ( asyncio.TimeoutError, TimeoutError, APIConnectionError, RateLimitError, ) class LlmJsonError(Exception): """LLM 返回内容无法解析为 JSON。""" class LlmConfigError(Exception): """LLM 客户端缺少必要配置。""" def _parse_json_text(text: str) -> dict[str, Any]: text = text.strip() m = _FENCE_RE.search(text) if m: text = m.group(1) try: return json.loads(text) except json.JSONDecodeError: start = text.find("{") end = text.rfind("}") if start >= 0 and end > start: try: return json.loads(text[start : end + 1]) except json.JSONDecodeError as e: raise LlmJsonError(f"failed to parse LLM JSON: {text!r}") from e raise LlmJsonError(f"LLM returned non-JSON content: {text!r}") def _is_retryable_status(exc: Exception) -> bool: """APIError 中只重试 5xx 与 429。""" if isinstance(exc, RateLimitError): return True if isinstance(exc, APIError): status = getattr(exc, "status_code", None) return status is not None and (status >= 500 or status == 429) return False def _clip_text(value: Any, limit: int = 400) -> str: text = str(value).strip() if len(text) <= limit: return text return text[: limit - 3] + "..." def _format_exc(exc: Exception) -> str: text = str(exc).strip() parts = [exc.__class__.__name__] if text: parts.append(text) status = getattr(exc, "status_code", None) if status is not None: parts.append(f"status={status}") body = getattr(exc, "body", None) if body not in (None, "", b""): parts.append(f"body={_clip_text(body)}") response = getattr(exc, "response", None) if response is not None: try: request = getattr(response, "request", None) if request is not None and getattr(request, "url", None): parts.append(f"url={request.url}") except Exception: pass request = getattr(exc, "request", None) if request is not None and getattr(request, "url", None): parts.append(f"url={request.url}") return ": ".join(parts[:2]) + ("" if len(parts) <= 2 else " | " + " | ".join(parts[2:])) class LlmClient: def __init__( self, api_key: str | None = None, base_url: str | None = None, model: str | None = None, max_concurrency: int | None = None, timeout_seconds: float | None = None, max_retries: int | None = None, cache: LlmCache | None = None, cache_enabled: bool | None = None, ): key = api_key or LLM_API_KEY self._misconfigured_error: LlmConfigError | None = None if OpenAI is None or AsyncOpenAI is None: self._client = None self._aclient = None self._misconfigured_error = LlmConfigError( "python package 'openai' is not installed; govdoc LLM features are unavailable." ) elif not key: self._client = None self._aclient = None self._misconfigured_error = LlmConfigError( "LLM_API_KEY is not configured. Set LLM_API_KEY in platform config." ) else: self._client = OpenAI(api_key=key, base_url=base_url or LLM_BASE_URL) self._aclient = AsyncOpenAI(api_key=key, base_url=base_url or LLM_BASE_URL) self.model = model or LLM_MODEL self.timeout = timeout_seconds if timeout_seconds is not None else LEAUDIT_LLM_REQUEST_TIMEOUT self.max_retries = max_retries if max_retries is not None else LEAUDIT_LLM_RETRY_MAX_ATTEMPTS conc = max_concurrency if max_concurrency is not None else LEAUDIT_LLM_MAX_CONCURRENCY self._sem = asyncio.Semaphore(conc) # 缓存:cache 显式传入则用之;否则默认关闭。 if cache is not None: self.cache: LlmCache | None = cache elif cache_enabled is not False: self.cache = None else: self.cache = None def _ensure_ready(self) -> None: if self._misconfigured_error is not None: raise self._misconfigured_error @staticmethod def _prompt_text(messages: list[dict[str, str]]) -> str: return "\n\n".join( f"[{m.get('role', 'user')}]\n{m.get('content', '')}" for m in messages ) # -- 同步路径 ------------------------------------------------------- def chat(self, messages: list[dict[str, str]], **kwargs) -> str: self._ensure_ready() use_cache = kwargs.pop("use_cache", True) label = kwargs.pop("label", "llm_call") cache_kwargs = {k: kwargs.get(k) for k in ("temperature", "top_p", "max_tokens", "response_format")} cache_key: str | None = None prompt_text = self._prompt_text(messages) t0 = time.monotonic() if use_cache and self.cache is not None: cache_key = make_key(self.model, messages, **cache_kwargs) hit = self.cache.get(cache_key) if hit is not None: _log.debug("LLM cache HIT key=%s", cache_key[:12]) return hit kwargs.setdefault("timeout", self.timeout) last_exc: Exception | None = None for attempt in range(self.max_retries + 1): try: resp = self._client.chat.completions.create( model=self.model, messages=messages, **kwargs ) content = resp.choices[0].message.content or "" if cache_key is not None and content: self.cache.put(cache_key, self.model, content) return content except _RETRYABLE as e: last_exc = e except APIError as e: if not _is_retryable_status(e): raise last_exc = e if attempt < self.max_retries: wait = min(8.0, 2 ** attempt) _log.warning( "LLM call failed (%s); retry %d/%d after %.1fs", _format_exc(last_exc), attempt + 1, self.max_retries, wait, ) time.sleep(wait) assert last_exc is not None raise last_exc def chat_json(self, messages: list[dict[str, str]], **kwargs) -> dict[str, Any]: return _parse_json_text(self.chat(messages, **kwargs)) # -- 异步路径 ------------------------------------------------------- async def chat_async(self, messages: list[dict[str, str]], **kwargs) -> str: self._ensure_ready() use_cache = kwargs.pop("use_cache", True) label = kwargs.pop("label", "llm_call") cache_kwargs = {k: kwargs.get(k) for k in ("temperature", "top_p", "max_tokens", "response_format")} cache_key: str | None = None prompt_text = self._prompt_text(messages) t0 = time.monotonic() if use_cache and self.cache is not None: cache_key = make_key(self.model, messages, **cache_kwargs) hit = self.cache.get(cache_key) if hit is not None: _log.debug("LLM cache HIT key=%s", cache_key[:12]) return hit last_exc: Exception | None = None for attempt in range(self.max_retries + 1): try: async with self._sem: resp = await asyncio.wait_for( self._aclient.chat.completions.create( model=self.model, messages=messages, **kwargs, ), timeout=self.timeout, ) content = resp.choices[0].message.content or "" if cache_key is not None and content: self.cache.put(cache_key, self.model, content) return content except _RETRYABLE as e: last_exc = e except APIError as e: if not _is_retryable_status(e): raise last_exc = e if attempt < self.max_retries: wait = min(8.0, 2 ** attempt) _log.warning( "LLM async call failed (%s); retry %d/%d after %.1fs", _format_exc(last_exc), attempt + 1, self.max_retries, wait, ) await asyncio.sleep(wait) assert last_exc is not None raise last_exc async def chat_json_async( self, messages: list[dict[str, str]], **kwargs ) -> dict[str, Any]: return _parse_json_text(await self.chat_async(messages, **kwargs))