"""Bridge-side resilient client wrappers.

Keep retry policy configurable in leaudit-platform without changing the
upstream leaudit package.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import base64
|
|
import logging
|
|
import time
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import httpx
|
|
|
|
from leaudit.llm.base import LLMRequest, LLMResponse
|
|
from leaudit.llm.openai_client import OpenAIClientError, OpenAICompatibleClient, _parse_response
|
|
from leaudit.llm.qwen_vlm_client import QwenVLMClient, VisualClassification, _SYSTEM_PROMPT, _parse_json_loose
|
|
from leaudit.ocr.chandra_client import ChandraOCRError, ChandraOCRClient, _guess_mime, _parse_response as parse_ocr_response
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# HTTP statuses that are explicitly listed as retryable.
_RETRYABLE_HTTP_STATUSES = {
    408,  # Request Timeout
    429,  # Too Many Requests
    500,  # Internal Server Error
    502,  # Bad Gateway
    503,  # Service Unavailable
    504,  # Gateway Timeout
}
|
|
|
|
|
|
def _retry_delay(base_seconds: float, attempt: int) -> float:
    """Compute the pause, in seconds, before the next retry.

    The pause doubles with each attempt: ``base_seconds * 2**attempt``.
    A negative base and a negative attempt index are both clamped to zero.
    """
    base = float(base_seconds)
    if not base > 0.0:
        base = 0.0

    exponent = int(attempt)
    if exponent < 0:
        exponent = 0

    # 1 << n equals 2 ** n for a non-negative integer n.
    return base * (1 << exponent)
|
|
|
|
|
|
def _should_retry_status(status_code: int) -> bool:
    """Tell whether a response with this HTTP status should be retried.

    A status qualifies when it is listed in ``_RETRYABLE_HTTP_STATUSES`` or
    when it is 500 or greater.
    """
    if status_code in _RETRYABLE_HTTP_STATUSES:
        return True
    return status_code >= 500
|
|
|
|
|
|
class ResilientOpenAICompatibleClient(OpenAICompatibleClient):
    """OpenAI-compatible client with configurable retry count/backoff.

    Adds two keyword-only settings on top of the upstream client:

    * ``retry_max_attempts``: total number of tries per request; values
      below 1 are raised to 1.
    * ``retry_backoff_base_seconds``: base of the exponential backoff; the
      pause after failed attempt ``n`` (0-based) is ``base * 2**n`` seconds.
      Negative values are clamped to 0.

    Every other keyword argument is forwarded unchanged to
    ``OpenAICompatibleClient.__init__``.
    """

    def __init__(
        self,
        *,
        retry_max_attempts: int = 3,
        retry_backoff_base_seconds: float = 1.0,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.retry_max_attempts = max(1, int(retry_max_attempts))
        self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds))

    async def _sleep_before_retry(self, attempt: int, reason: object) -> None:
        """Log a failed attempt, then wait out the exponential backoff."""
        delay = _retry_delay(self.retry_backoff_base_seconds, attempt)
        log.warning(
            "LLM request attempt %d/%d failed (%s); retrying in %.2fs",
            attempt + 1,
            self.retry_max_attempts,
            reason,
            delay,
        )
        await asyncio.sleep(delay)

    async def complete(self, request: LLMRequest) -> LLMResponse:
        """Send a chat-completion request, retrying transient failures.

        The request is recorded in ``self.call_history`` and then POSTed to
        ``{base_url}/chat/completions``. Timeouts, the exceptions listed in
        ``self._RETRYABLE_EXCEPTIONS`` and retryable HTTP statuses trigger a
        new attempt until ``retry_max_attempts`` is exhausted.

        ``call_history``, ``default_model``, ``enable_thinking``, ``timeout``,
        ``_client``, ``base_url``, ``_headers`` and ``_RETRYABLE_EXCEPTIONS``
        are not defined in this class; they are expected to come from the
        upstream base class.

        Args:
            request: The LLM request to send.

        Returns:
            The parsed response. The duration passed to the parser covers
            only the final, successful attempt.

        Raises:
            TimeoutError: The last attempt timed out.
            OpenAIClientError: The last attempt hit a network error, the
                API answered with a non-200 status, the body was not a JSON
                object, or the body carried an ``"error"`` key.
        """
        self.call_history.append(request)

        model = request.model if request.model != "default" else self.default_model
        payload: dict[str, Any] = {
            "model": model,
            "messages": [{"role": m.role, "content": m.content} for m in request.messages],
            "temperature": request.temperature,
            "max_tokens": request.max_tokens,
        }
        if request.return_logprobs:
            payload["logprobs"] = True
            payload["top_logprobs"] = 3
        if request.response_format:
            payload["response_format"] = request.response_format

        # A per-request setting wins over the client-wide default.
        thinking = request.enable_thinking if request.enable_thinking is not None else self.enable_thinking
        if not thinking:
            payload["enable_thinking"] = False

        # Never wait longer than the client-wide timeout allows.
        req_timeout = min(self.timeout, request.timeout_ms / 1000)
        last_error: Exception | None = None

        for attempt in range(self.retry_max_attempts):
            is_last_attempt = attempt >= self.retry_max_attempts - 1
            t0 = time.monotonic()
            try:
                response = await self._client.post(
                    f"{self.base_url}/chat/completions",
                    json=payload,
                    headers=self._headers,
                    timeout=req_timeout,
                )
            except httpx.TimeoutException as exc:
                # Handled before the generic branch so callers keep getting
                # the builtin TimeoutError for timeouts.
                last_error = TimeoutError(f"LLM request timed out: {exc}")
                if is_last_attempt:
                    raise last_error from exc
                await self._sleep_before_retry(attempt, last_error)
                continue
            except self._RETRYABLE_EXCEPTIONS as exc:
                last_error = OpenAIClientError(f"Network error: {exc}")
                if is_last_attempt:
                    raise last_error from exc
                await self._sleep_before_retry(attempt, last_error)
                continue

            duration_ms = int((time.monotonic() - t0) * 1000)
            if _should_retry_status(response.status_code) and not is_last_attempt:
                await self._sleep_before_retry(attempt, f"HTTP {response.status_code}")
                continue

            if response.status_code != 200:
                raise OpenAIClientError(f"API returned {response.status_code}: {response.text[:500]}")

            # A 200 response with a non-JSON body (e.g. a proxy error page)
            # must surface as the client's own error type, not a ValueError.
            try:
                body = response.json()
            except ValueError as exc:
                raise OpenAIClientError(f"API returned invalid JSON: {response.text[:500]}") from exc
            if not isinstance(body, dict):
                raise OpenAIClientError(f"API returned unexpected payload: {response.text[:500]}")
            if "error" in body:
                raise OpenAIClientError(f"API error: {body['error']}")
            return _parse_response(body, duration_ms)

        # Only reachable when the loop body never ran.
        raise last_error or OpenAIClientError("LLM request failed after retries")
|
|
|
|
|
|
class ResilientQwenVLMClient(QwenVLMClient):
    """VLM client with configurable retry count/backoff.

    Adds two keyword-only settings on top of the upstream client:

    * ``retry_max_attempts``: total number of tries per request; values
      below 1 are raised to 1.
    * ``retry_backoff_base_seconds``: base of the exponential backoff; the
      pause after failed attempt ``n`` (0-based) is ``base * 2**n`` seconds.

    Both public methods are best-effort: they return a fallback value
    instead of raising when the request or the response parsing fails.
    """

    # Transport-level failures that justify another attempt.
    # httpx.ReadTimeout is a subclass of httpx.TimeoutException; both are
    # listed explicitly.
    _RETRYABLE_EXCEPTIONS = (
        httpx.ConnectError,
        httpx.RemoteProtocolError,
        httpx.ReadError,
        httpx.ReadTimeout,
        httpx.TimeoutException,
    )

    def __init__(
        self,
        *,
        retry_max_attempts: int = 2,
        retry_backoff_base_seconds: float = 1.0,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.retry_max_attempts = max(1, int(retry_max_attempts))
        self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds))

    async def _post_with_retry(self, payload: dict[str, Any]) -> httpx.Response:
        """POST ``payload`` to ``/chat/completions`` with retries.

        Returns the last response received, whatever its status; callers
        check ``status_code`` themselves.

        Raises:
            One of ``_RETRYABLE_EXCEPTIONS`` when the final attempt fails at
            the transport level; any other exception propagates immediately.
        """
        # Headers with empty values are dropped; computed once for all attempts.
        headers = {k: v for k, v in self._headers.items() if v}
        last_error: Exception | None = None

        for attempt in range(self.retry_max_attempts):
            is_last_attempt = attempt >= self.retry_max_attempts - 1
            try:
                response = await self._client.post(
                    f"{self.base_url}/chat/completions",
                    json=payload,
                    headers=headers,
                )
            except self._RETRYABLE_EXCEPTIONS as exc:
                last_error = exc
                if is_last_attempt:
                    raise
                await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
                continue

            if _should_retry_status(response.status_code) and not is_last_attempt:
                await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
                continue
            return response

        # Only reachable when the loop body never ran.
        raise last_error or RuntimeError("VLM request failed after retries")

    @staticmethod
    def _first_choice_content(response: httpx.Response) -> str:
        """Return the first choice's message text, or ``""`` when absent.

        Tolerates missing or ``null`` ``choices`` / ``message`` / ``content``
        entries instead of failing on them.

        Raises:
            ValueError: The response body is not valid JSON.
        """
        body = response.json()
        if not isinstance(body, dict):
            return ""
        choices = body.get("choices")
        first = choices[0] if isinstance(choices, list) and choices else {}
        message = first.get("message") if isinstance(first, dict) else None
        content = message.get("content") if isinstance(message, dict) else None
        return content if isinstance(content, str) else ""

    async def classify_visual(
        self,
        image_bytes: bytes,
        user_hint: str | None = None,
        max_tokens: int = 300,
    ) -> VisualClassification:
        """Classify an image region as seal, signature or other.

        The image is sent inline as a base64 data URL that is always
        labelled ``image/png``.

        Args:
            image_bytes: Raw bytes of the image region.
            user_hint: Optional context appended to the user prompt.
            max_tokens: Completion token limit sent to the model.

        Returns:
            A ``VisualClassification``. Every failure (request error,
            non-200 status, invalid JSON, unparseable model output) yields
            ``kind="other"`` with ``confidence=0.0`` and a diagnostic in
            ``raw``; this method does not raise for those cases.
        """
        b64 = base64.b64encode(image_bytes).decode()
        url = f"data:image/png;base64,{b64}"

        user_text = "请判断此区域内容并输出 JSON。"
        if user_hint:
            user_text += f"\n上下文:{user_hint}"

        payload: dict[str, Any] = {
            "model": self.model,
            "messages": [
                {"role": "system", "content": _SYSTEM_PROMPT},
                {
                    "role": "user",
                    "content": [
                        {"type": "image_url", "image_url": {"url": url}},
                        {"type": "text", "text": user_text},
                    ],
                },
            ],
            "max_tokens": max_tokens,
            "temperature": 0.0,
        }

        try:
            response = await self._post_with_retry(payload)
        except Exception as exc:  # noqa: BLE001
            return VisualClassification(kind="other", confidence=0.0, raw=f"retry_failed: {exc}")

        if response.status_code != 200:
            return VisualClassification(kind="other", confidence=0.0, raw=f"HTTP {response.status_code}: {response.text[:300]}")

        try:
            content = self._first_choice_content(response)
        except ValueError as exc:
            # A 200 response with a non-JSON body: stay best-effort.
            return VisualClassification(kind="other", confidence=0.0, raw=f"invalid_json: {exc}")

        parsed = _parse_json_loose(content)
        # Same guard as extract_multifield: anything that is not a JSON
        # object (None, list, scalar) cannot be read with .get().
        if not isinstance(parsed, dict):
            return VisualClassification(kind="other", confidence=0.0, raw=content)

        kind = parsed.get("kind", "other")
        if not isinstance(kind, str) or kind not in {"seal", "signature", "other"}:
            kind = "other"

        return VisualClassification(
            kind=kind,
            seal_type=parsed.get("seal_type"),
            text=parsed.get("text"),
            is_complete=parsed.get("is_complete"),
            confidence=0.9 if kind != "other" else 0.3,
            raw=content,
        )

    async def extract_multifield(
        self,
        *,
        prompt: str,
        images_data_urls: list[str],
        max_tokens: int = 800,
    ) -> dict:
        """Ask the model to extract several fields from one or more images.

        Args:
            prompt: Text instruction, sent as the first content part.
            images_data_urls: Image URLs appended after the prompt, in order.
            max_tokens: Completion token limit sent to the model.

        Returns:
            The JSON object parsed from the model output, or ``{}`` when the
            request fails, the status is not 200, the body is not valid
            JSON, or the output does not parse to a dict. Failures are
            logged as warnings rather than raised.
        """
        content: list[dict] = [{"type": "text", "text": prompt}]
        content.extend({"type": "image_url", "image_url": {"url": url}} for url in images_data_urls)

        payload: dict[str, Any] = {
            "model": self.model,
            "messages": [{"role": "user", "content": content}],
            "max_tokens": max_tokens,
            "temperature": 0.0,
        }

        try:
            response = await self._post_with_retry(payload)
        except Exception as exc:  # noqa: BLE001
            log.warning("VLM multi-field extraction request failed: %s", exc)
            return {}

        if response.status_code != 200:
            log.warning("VLM multi-field extraction returned HTTP %s", response.status_code)
            return {}

        try:
            text = self._first_choice_content(response)
        except ValueError as exc:
            log.warning("VLM multi-field extraction returned invalid JSON: %s", exc)
            return {}

        parsed = _parse_json_loose(text)
        return parsed if isinstance(parsed, dict) else {}
|
|
|
|
|
|
class ResilientChandraOCRClient(ChandraOCRClient):
    """OCR client with configurable retry count/backoff.

    Adds two keyword-only settings on top of the upstream client:

    * ``retry_max_attempts``: total number of tries per upload; values
      below 1 are raised to 1.
    * ``retry_backoff_base_seconds``: base of the exponential backoff; the
      pause after failed attempt ``n`` (0-based) is ``base * 2**n`` seconds.
    """

    # Transport-level failures that justify another attempt.
    # httpx.ReadTimeout is a subclass of httpx.TimeoutException; both are
    # listed explicitly.
    _RETRYABLE_EXCEPTIONS = (
        httpx.ConnectError,
        httpx.RemoteProtocolError,
        httpx.ReadError,
        httpx.ReadTimeout,
        httpx.TimeoutException,
    )

    def __init__(
        self,
        *,
        retry_max_attempts: int = 3,
        retry_backoff_base_seconds: float = 1.0,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.retry_max_attempts = max(1, int(retry_max_attempts))
        self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds))

    async def _sleep_before_retry(self, attempt: int, reason: object) -> None:
        """Log a failed attempt, then wait out the exponential backoff."""
        delay = _retry_delay(self.retry_backoff_base_seconds, attempt)
        log.warning(
            "OCR request attempt %d/%d failed (%s); retrying in %.2fs",
            attempt + 1,
            self.retry_max_attempts,
            reason,
            delay,
        )
        await asyncio.sleep(delay)

    async def _post_ocr(self, file_path: Path | str) -> dict:
        """Upload a file to ``{base_url}/chandra/ocr`` and return the JSON body.

        A fresh HTTP client is created and the file is reopened for every
        attempt, so each retry uploads from the start of the file.

        Args:
            file_path: Path of the document to upload.

        Returns:
            The decoded JSON object returned by the OCR service.

        Raises:
            FileNotFoundError: ``file_path`` does not exist.
            ChandraOCRError: The final attempt failed at the transport
                level, the service answered with a non-200 status, the body
                was not a JSON object, or the body carried an ``"error"`` key.
        """
        path = Path(file_path)
        if not path.exists():
            raise FileNotFoundError(f"File not found: {path}")

        # Form fields do not change between attempts.
        data = {"include_images": str(self.include_images).lower()}
        last_error: Exception | None = None

        for attempt in range(self.retry_max_attempts):
            is_last_attempt = attempt >= self.retry_max_attempts - 1
            try:
                # NOTE(review): verify=False disables TLS certificate
                # verification. Kept as in the original; confirm this is
                # intended (e.g. an internal service with a self-signed cert).
                async with httpx.AsyncClient(timeout=self.timeout, verify=False) as client:
                    with open(path, "rb") as file_obj:
                        files = {"file": (path.name, file_obj, _guess_mime(path))}
                        response = await client.post(
                            f"{self.base_url}/chandra/ocr",
                            files=files,
                            data=data,
                        )
            except self._RETRYABLE_EXCEPTIONS as exc:
                last_error = exc
                if is_last_attempt:
                    raise ChandraOCRError(f"OCR request failed after retries: {exc}") from exc
                await self._sleep_before_retry(attempt, exc)
                continue

            if _should_retry_status(response.status_code) and not is_last_attempt:
                await self._sleep_before_retry(attempt, f"HTTP {response.status_code}")
                continue

            if response.status_code != 200:
                raise ChandraOCRError(f"OCR API returned {response.status_code}: {response.text[:500]}")

            # A 200 response with a non-JSON body must surface as the OCR
            # client's own error type, not a bare ValueError.
            try:
                body = response.json()
            except ValueError as exc:
                raise ChandraOCRError(f"OCR API returned invalid JSON: {response.text[:500]}") from exc
            if not isinstance(body, dict):
                raise ChandraOCRError(f"OCR API returned unexpected payload: {response.text[:500]}")
            if "error" in body:
                raise ChandraOCRError(f"OCR error: {body['error']}")
            return body

        # Only reachable when the loop body never ran.
        raise ChandraOCRError(f"OCR request failed after retries: {last_error}")

    async def ocr(self, file_path: Path | str):
        """Run OCR on a file and return the parsed result.

        When a VLM client is configured and the parsed result's visual
        manifest lists any seals, signatures or cross-page seals, the
        manifest is passed to ``refine_visual_manifest``. A failure there is
        logged and the unrefined result is returned.

        Raises:
            FileNotFoundError: ``file_path`` does not exist.
            ChandraOCRError: The OCR request failed.
        """
        body = await self._post_ocr(file_path)
        result = parse_ocr_response(body)

        manifest = result.visual_manifest
        has_visuals = manifest.seals or manifest.signatures or manifest.cross_page_seals
        if self.vlm_client is not None and has_visuals:
            try:
                # Imported here rather than at module top, as in the original.
                from leaudit.ocr.visual_classifier import refine_visual_manifest

                await refine_visual_manifest(result, self.vlm_client, concurrency=self.vlm_concurrency)
            except Exception as exc:  # noqa: BLE001
                log.warning("VLM refinement failed, keeping geometric manifest: %s", exc)
        return result

    async def ocr_raw(self, file_path: Path | str) -> dict:
        """Run OCR on a file and return the service's JSON body unparsed."""
        return await self._post_ocr(file_path)
|