Files
leaudit-platform-backend/fastapi_modules/fastapi_leaudit/leaudit_bridge/resilient_clients.py
T

323 lines
12 KiB
Python

"""Bridge-side resilient client wrappers.
Keep retry policy configurable in leaudit-platform without changing the
upstream leaudit package.
"""
from __future__ import annotations
import asyncio
import base64
import logging
import time
from pathlib import Path
from typing import Any
import httpx
from leaudit.llm.base import LLMRequest, LLMResponse
from leaudit.llm.openai_client import OpenAIClientError, OpenAICompatibleClient, _parse_response
from leaudit.llm.qwen_vlm_client import QwenVLMClient, VisualClassification, _SYSTEM_PROMPT, _parse_json_loose
from leaudit.ocr.chandra_client import ChandraOCRError, ChandraOCRClient, _guess_mime, _parse_response as parse_ocr_response
# Module-level logger named after this module.
log = logging.getLogger(__name__)
# HTTP status codes that _should_retry_status() treats as retryable. That
# helper also retries every status >= 500, so the 5xx entries below are
# covered twice; 408 and 429 are the only members that are not.
_RETRYABLE_HTTP_STATUSES = {408, 429, 500, 502, 503, 504}
def _retry_delay(base_seconds: float, attempt: int) -> float:
"""Return exponential backoff delay for retry attempt index."""
return max(0.0, float(base_seconds)) * (2 ** max(0, int(attempt)))
def _should_retry_status(status_code: int) -> bool:
    """Tell whether an HTTP status code should trigger another attempt.

    A status is retryable when it is listed in ``_RETRYABLE_HTTP_STATUSES``
    or when it is 500 or above.
    """
    if status_code in _RETRYABLE_HTTP_STATUSES:
        return True
    return status_code >= 500
class ResilientOpenAICompatibleClient(OpenAICompatibleClient):
    """OpenAI-compatible client with configurable retry count/backoff.

    ``complete`` is re-implemented here so the number of attempts and the
    backoff base are taken from constructor arguments.
    """

    def __init__(
        self,
        *,
        retry_max_attempts: int = 3,
        retry_backoff_base_seconds: float = 1.0,
        **kwargs: Any,
    ) -> None:
        """Initialise the parent client and store the retry policy.

        Args:
            retry_max_attempts: Total number of attempts per request,
                including the first one. Values below 1 are raised to 1.
            retry_backoff_base_seconds: Delay before the first retry; it
                doubles for each further retry. Negative values become 0.
            **kwargs: Forwarded unchanged to ``OpenAICompatibleClient``.
        """
        super().__init__(**kwargs)
        self.retry_max_attempts = max(1, int(retry_max_attempts))
        self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds))

    async def complete(self, request: LLMRequest) -> LLMResponse:
        """Send a chat-completion request, retrying transient failures.

        The request is appended to ``self.call_history`` once, before the
        first attempt. Timeouts, the exceptions in
        ``self._RETRYABLE_EXCEPTIONS`` and retryable HTTP statuses trigger
        another attempt (after an exponential backoff sleep) until
        ``retry_max_attempts`` is exhausted.

        Args:
            request: The LLM request to send.

        Returns:
            The response built by ``_parse_response`` from the JSON body and
            the duration of the successful attempt in milliseconds.

        Raises:
            TimeoutError: The last attempt timed out.
            OpenAIClientError: The last attempt hit a network error, the
                server answered with a non-200 status, the body was not a
                JSON object, or the body carried an ``"error"`` key.
        """
        self.call_history.append(request)
        model = request.model if request.model != "default" else self.default_model
        payload: dict[str, Any] = {
            "model": model,
            "messages": [{"role": m.role, "content": m.content} for m in request.messages],
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
        }
        if request.return_logprobs:
            payload["logprobs"] = True
            payload["top_logprobs"] = 3
        if request.response_format:
            payload["response_format"] = request.response_format
        # A per-request setting wins over the client-wide default.
        thinking = request.enable_thinking if request.enable_thinking is not None else self.enable_thinking
        if not thinking:
            payload["enable_thinking"] = False
        # Never wait longer than the client-wide timeout.
        req_timeout = min(self.timeout, request.timeout_ms / 1000)

        url = f"{self.base_url}/chat/completions"
        total = self.retry_max_attempts
        final_attempt = total - 1
        for attempt in range(total):
            if attempt:
                # Back off before every retry: base, 2*base, 4*base, ...
                await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt - 1))
            t0 = time.monotonic()
            try:
                response = await self._client.post(
                    url,
                    json=payload,
                    headers=self._headers,
                    timeout=req_timeout,
                )
            except httpx.TimeoutException as exc:
                # Handled before the generic retryable tuple so that a timeout
                # surfaces as TimeoutError rather than OpenAIClientError.
                if attempt < final_attempt:
                    log.warning("LLM request timed out (attempt %d/%d): %s", attempt + 1, total, exc)
                    continue
                raise TimeoutError(f"LLM request timed out: {exc}") from exc
            except self._RETRYABLE_EXCEPTIONS as exc:
                if attempt < final_attempt:
                    log.warning("LLM network error (attempt %d/%d): %s", attempt + 1, total, exc)
                    continue
                raise OpenAIClientError(f"Network error: {exc}") from exc
            duration_ms = int((time.monotonic() - t0) * 1000)
            if attempt < final_attempt and _should_retry_status(response.status_code):
                log.warning(
                    "LLM API returned HTTP %d (attempt %d/%d); retrying",
                    response.status_code,
                    attempt + 1,
                    total,
                )
                continue
            if response.status_code != 200:
                raise OpenAIClientError(f"API returned {response.status_code}: {response.text[:500]}")
            try:
                body = response.json()
            except ValueError as exc:
                # json.JSONDecodeError is a ValueError; report it through the
                # client's own error type like every other API failure.
                raise OpenAIClientError(f"API returned invalid JSON: {response.text[:500]}") from exc
            if not isinstance(body, dict):
                raise OpenAIClientError(f"API returned unexpected payload type: {type(body).__name__}")
            if "error" in body:
                raise OpenAIClientError(f"API error: {body['error']}")
            return _parse_response(body, duration_ms)
        # Defensive: every iteration above returns, raises or continues, and
        # the final iteration never continues.
        raise OpenAIClientError("LLM request failed after retries")
class ResilientQwenVLMClient(QwenVLMClient):
    """VLM client with configurable retry count/backoff.

    Both public calls are best-effort: they return a neutral fallback value
    instead of raising when the request or the response cannot be used.
    """

    # httpx.ReadTimeout derives from httpx.TimeoutException, so listing both
    # is redundant but harmless; the tuple is kept unchanged.
    _RETRYABLE_EXCEPTIONS = (
        httpx.ConnectError,
        httpx.RemoteProtocolError,
        httpx.ReadError,
        httpx.ReadTimeout,
        httpx.TimeoutException,
    )

    def __init__(
        self,
        *,
        retry_max_attempts: int = 2,
        retry_backoff_base_seconds: float = 1.0,
        **kwargs: Any,
    ) -> None:
        """Initialise the parent client and store the retry policy.

        Args:
            retry_max_attempts: Total number of attempts per request,
                including the first one. Values below 1 are raised to 1.
            retry_backoff_base_seconds: Delay before the first retry; it
                doubles for each further retry. Negative values become 0.
            **kwargs: Forwarded unchanged to ``QwenVLMClient``.
        """
        super().__init__(**kwargs)
        self.retry_max_attempts = max(1, int(retry_max_attempts))
        self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds))

    async def _post_with_retry(self, payload: dict[str, Any]) -> httpx.Response:
        """POST ``payload`` to the chat-completions endpoint with retries.

        Returns:
            The first response whose status is not retryable, or whatever
            response the final attempt produced (callers check the status).

        Raises:
            One of ``_RETRYABLE_EXCEPTIONS`` when the final attempt fails at
            the transport level; other exceptions propagate immediately.
        """
        url = f"{self.base_url}/chat/completions"
        # Headers with empty values are not sent.
        headers = {k: v for k, v in self._headers.items() if v}
        total = self.retry_max_attempts
        final_attempt = total - 1
        for attempt in range(total):
            if attempt:
                # Back off before every retry: base, 2*base, 4*base, ...
                await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt - 1))
            try:
                response = await self._client.post(url, json=payload, headers=headers)
            except self._RETRYABLE_EXCEPTIONS as exc:
                if attempt < final_attempt:
                    log.warning("VLM request failed (attempt %d/%d): %s", attempt + 1, total, exc)
                    continue
                raise
            if attempt < final_attempt and _should_retry_status(response.status_code):
                log.warning(
                    "VLM API returned HTTP %d (attempt %d/%d); retrying",
                    response.status_code,
                    attempt + 1,
                    total,
                )
                continue
            return response
        # Defensive: the final iteration always returns or raises.
        raise RuntimeError("VLM request failed after retries")

    @staticmethod
    def _first_choice_content(response: httpx.Response) -> str | None:
        """Return the first choice's message text from a chat response.

        Returns ``""`` when the body has no choices, message or content, and
        ``None`` when the body is not valid JSON or does not have the
        expected object/list/string shape.
        """
        try:
            body = response.json()
        except ValueError:
            return None
        if not isinstance(body, dict):
            return None
        choices = body.get("choices") or [{}]
        first = choices[0] if isinstance(choices, list) else None
        if not isinstance(first, dict):
            return None
        message = first.get("message", {})
        if not isinstance(message, dict):
            return None
        content = message.get("content", "")
        return content if isinstance(content, str) else None

    async def classify_visual(
        self,
        image_bytes: bytes,
        user_hint: str | None = None,
        max_tokens: int = 300,
    ) -> VisualClassification:
        """Ask the VLM to classify an image region as seal/signature/other.

        Args:
            image_bytes: Raw image bytes; sent as a base64 ``image/png``
                data URL.
            user_hint: Optional context appended to the user prompt.
            max_tokens: Completion token limit sent to the model.

        Returns:
            A ``VisualClassification``. Any failure (transport error, non-200
            status, malformed body, unparseable answer) yields
            ``kind="other"`` with ``confidence=0.0`` and a diagnostic ``raw``.
        """
        b64 = base64.b64encode(image_bytes).decode()
        url = f"data:image/png;base64,{b64}"
        user_text = "请判断此区域内容并输出 JSON。"
        if user_hint:
            user_text += f"\n上下文:{user_hint}"
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": _SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": url}},
                        {"type": "text", "text": user_text},
                    ],
                },
            ],
            "max_tokens": max_tokens,
            "temperature": 0.0,
        }
        try:
            response = await self._post_with_retry(payload)
        except Exception as exc:  # noqa: BLE001
            return VisualClassification(kind="other", confidence=0.0, raw=f"retry_failed: {exc}")
        if response.status_code != 200:
            return VisualClassification(
                kind="other",
                confidence=0.0,
                raw=f"HTTP {response.status_code}: {response.text[:300]}",
            )
        content = self._first_choice_content(response)
        if content is None:
            # Keep the best-effort contract for bodies that are not usable JSON.
            return VisualClassification(
                kind="other",
                confidence=0.0,
                raw=f"malformed_response: {response.text[:300]}",
            )
        parsed = _parse_json_loose(content)
        if not isinstance(parsed, dict):
            return VisualClassification(kind="other", confidence=0.0, raw=content)
        kind = parsed.get("kind", "other")
        # The isinstance guard also protects the set lookup from unhashable values.
        if not isinstance(kind, str) or kind not in {"seal", "signature", "other"}:
            kind = "other"
        return VisualClassification(
            kind=kind,
            seal_type=parsed.get("seal_type"),
            text=parsed.get("text"),
            is_complete=parsed.get("is_complete"),
            confidence=0.9 if kind != "other" else 0.3,
            raw=content,
        )

    async def extract_multifield(
        self,
        *,
        prompt: str,
        images_data_urls: list[str],
        max_tokens: int = 800,
    ) -> dict:
        """Send a prompt plus images and return the model's JSON object.

        Args:
            prompt: Text instruction placed first in the user message.
            images_data_urls: Image URLs appended after the prompt, in order.
            max_tokens: Completion token limit sent to the model.

        Returns:
            The parsed JSON object, or ``{}`` when the request fails, the
            status is not 200, or the answer is not a JSON object.
        """
        content: list[dict] = [{"type": "text", "text": prompt}]
        content.extend({"type": "image_url", "image_url": {"url": url}} for url in images_data_urls)
        payload: dict[str, Any] = {
            "model": self.model,
            "messages": [{"role": "user", "content": content}],
            "max_tokens": max_tokens,
            "temperature": 0.0,
        }
        try:
            response = await self._post_with_retry(payload)
        except Exception as exc:  # noqa: BLE001
            # Best-effort call: report the failure instead of hiding it.
            log.warning("VLM multi-field extraction request failed: %s", exc)
            return {}
        if response.status_code != 200:
            log.warning("VLM multi-field extraction got HTTP %d", response.status_code)
            return {}
        text = self._first_choice_content(response)
        if text is None:
            log.warning("VLM multi-field extraction got a malformed response body")
            return {}
        parsed = _parse_json_loose(text)
        return parsed if isinstance(parsed, dict) else {}
class ResilientChandraOCRClient(ChandraOCRClient):
    """OCR client with configurable retry count/backoff."""

    # httpx.ReadTimeout derives from httpx.TimeoutException, so listing both
    # is redundant but harmless; the tuple is kept unchanged.
    _RETRYABLE_EXCEPTIONS = (
        httpx.ConnectError,
        httpx.RemoteProtocolError,
        httpx.ReadError,
        httpx.ReadTimeout,
        httpx.TimeoutException,
    )

    def __init__(
        self,
        *,
        retry_max_attempts: int = 3,
        retry_backoff_base_seconds: float = 1.0,
        **kwargs: Any,
    ) -> None:
        """Initialise the parent client and store the retry policy.

        Args:
            retry_max_attempts: Total number of attempts per request,
                including the first one. Values below 1 are raised to 1.
            retry_backoff_base_seconds: Delay before the first retry; it
                doubles for each further retry. Negative values become 0.
            **kwargs: Forwarded unchanged to ``ChandraOCRClient``.
        """
        super().__init__(**kwargs)
        self.retry_max_attempts = max(1, int(retry_max_attempts))
        self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds))

    async def _post_ocr(self, file_path: Path | str) -> dict:
        """Upload a file to the OCR endpoint and return the JSON body.

        Args:
            file_path: Path of the document to upload.

        Returns:
            The decoded JSON object from a 200 response.

        Raises:
            FileNotFoundError: ``file_path`` does not exist.
            ChandraOCRError: The final attempt failed at the transport level,
                the server answered with a non-200 status, the body was not a
                JSON object, or the body carried an ``"error"`` key.
        """
        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {path}")

        url = f"{self.base_url}/chandra/ocr"
        data = {"include_images": str(self.include_images).lower()}
        mime = _guess_mime(path)
        total = self.retry_max_attempts
        final_attempt = total - 1
        for attempt in range(total):
            if attempt:
                # Back off before every retry: base, 2*base, 4*base, ...
                await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt - 1))
            try:
                # NOTE(review): verify=False turns off TLS certificate
                # verification for this request. Presumably the OCR service is
                # an internal endpoint with a self-signed certificate; confirm
                # and prefer a CA bundle if it is reachable over untrusted
                # networks.
                async with httpx.AsyncClient(timeout=self.timeout, verify=False) as client:
                    # The file is reopened for every attempt so each upload
                    # reads it from the beginning.
                    with open(path, "rb") as file_obj:
                        response = await client.post(
                            url,
                            files={"file": (path.name, file_obj, mime)},
                            data=data,
                        )
            except self._RETRYABLE_EXCEPTIONS as exc:
                if attempt < final_attempt:
                    log.warning("OCR request failed (attempt %d/%d): %s", attempt + 1, total, exc)
                    continue
                raise ChandraOCRError(f"OCR request failed after retries: {exc}") from exc
            if attempt < final_attempt and _should_retry_status(response.status_code):
                log.warning(
                    "OCR API returned HTTP %d (attempt %d/%d); retrying",
                    response.status_code,
                    attempt + 1,
                    total,
                )
                continue
            if response.status_code != 200:
                raise ChandraOCRError(f"OCR API returned {response.status_code}: {response.text[:500]}")
            try:
                body = response.json()
            except ValueError as exc:
                # json.JSONDecodeError is a ValueError; report it through the
                # client's own error type like every other OCR failure.
                raise ChandraOCRError(f"OCR API returned invalid JSON: {response.text[:500]}") from exc
            if not isinstance(body, dict):
                raise ChandraOCRError(f"OCR API returned unexpected payload type: {type(body).__name__}")
            if "error" in body:
                raise ChandraOCRError(f"OCR error: {body['error']}")
            return body
        # Defensive: the final iteration always returns or raises.
        raise ChandraOCRError("OCR request failed after retries")

    async def ocr(self, file_path: Path | str):
        """Run OCR on a file and optionally refine its visual manifest.

        The JSON body is converted with the upstream OCR response parser.
        When a VLM client is configured and the manifest lists seals,
        signatures or cross-page seals, ``refine_visual_manifest`` is run on
        the result; a failure there is logged and the unrefined result is
        returned.

        Args:
            file_path: Path of the document to upload.

        Returns:
            The parsed OCR result.

        Raises:
            FileNotFoundError: ``file_path`` does not exist.
            ChandraOCRError: The OCR request failed (see ``_post_ocr``).
        """
        body = await self._post_ocr(file_path)
        result = parse_ocr_response(body)
        manifest = result.visual_manifest
        has_visuals = manifest.seals or manifest.signatures or manifest.cross_page_seals
        if self.vlm_client is not None and has_visuals:
            try:
                # Imported inside the try so an import failure is also treated
                # as a refinement failure.
                from leaudit.ocr.visual_classifier import refine_visual_manifest

                await refine_visual_manifest(result, self.vlm_client, concurrency=self.vlm_concurrency)
            except Exception as exc:  # noqa: BLE001
                log.warning("VLM refinement failed, keeping geometric manifest: %s", exc)
        return result

    async def ocr_raw(self, file_path: Path | str) -> dict:
        """Return the OCR service's JSON body without parsing or refinement."""
        return await self._post_ocr(file_path)