feat: add async worker queues and retry controls
This commit is contained in:
@@ -0,0 +1,322 @@
|
||||
"""Bridge-side resilient client wrappers.
|
||||
|
||||
Keep retry policy configurable in leaudit-platform without changing the
|
||||
upstream leaudit package.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from leaudit.llm.base import LLMRequest, LLMResponse
|
||||
from leaudit.llm.openai_client import OpenAIClientError, OpenAICompatibleClient, _parse_response
|
||||
from leaudit.llm.qwen_vlm_client import QwenVLMClient, VisualClassification, _SYSTEM_PROMPT, _parse_json_loose
|
||||
from leaudit.ocr.chandra_client import ChandraOCRError, ChandraOCRClient, _guess_mime, _parse_response as parse_ocr_response
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_RETRYABLE_HTTP_STATUSES = {408, 429, 500, 502, 503, 504}
|
||||
|
||||
|
||||
def _retry_delay(base_seconds: float, attempt: int) -> float:
|
||||
"""Return exponential backoff delay for retry attempt index."""
|
||||
return max(0.0, float(base_seconds)) * (2 ** max(0, int(attempt)))
|
||||
|
||||
|
||||
def _should_retry_status(status_code: int) -> bool:
|
||||
return status_code in _RETRYABLE_HTTP_STATUSES or status_code >= 500
|
||||
|
||||
|
||||
class ResilientOpenAICompatibleClient(OpenAICompatibleClient):
|
||||
"""OpenAI-compatible client with configurable retry count/backoff."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
retry_max_attempts: int = 3,
|
||||
retry_backoff_base_seconds: float = 1.0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.retry_max_attempts = max(1, int(retry_max_attempts))
|
||||
self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds))
|
||||
|
||||
async def complete(self, request: LLMRequest) -> LLMResponse:
|
||||
self.call_history.append(request)
|
||||
|
||||
model = request.model if request.model != "default" else self.default_model
|
||||
payload: dict[str, Any] = {
|
||||
"model": model,
|
||||
"messages": [{"role": m.role, "content": m.content} for m in request.messages],
|
||||
"temperature": request.temperature,
|
||||
"max_tokens": request.max_tokens,
|
||||
}
|
||||
if request.return_logprobs:
|
||||
payload["logprobs"] = True
|
||||
payload["top_logprobs"] = 3
|
||||
if request.response_format:
|
||||
payload["response_format"] = request.response_format
|
||||
|
||||
thinking = request.enable_thinking if request.enable_thinking is not None else self.enable_thinking
|
||||
if not thinking:
|
||||
payload["enable_thinking"] = False
|
||||
|
||||
req_timeout = min(self.timeout, request.timeout_ms / 1000)
|
||||
last_error: Exception | None = None
|
||||
|
||||
for attempt in range(self.retry_max_attempts):
|
||||
t0 = time.monotonic()
|
||||
try:
|
||||
response = await self._client.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
json=payload,
|
||||
headers=self._headers,
|
||||
timeout=req_timeout,
|
||||
)
|
||||
except httpx.TimeoutException as exc:
|
||||
last_error = TimeoutError(f"LLM request timed out: {exc}")
|
||||
if attempt < self.retry_max_attempts - 1:
|
||||
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
|
||||
continue
|
||||
raise last_error from exc
|
||||
except self._RETRYABLE_EXCEPTIONS as exc:
|
||||
last_error = OpenAIClientError(f"Network error: {exc}")
|
||||
if attempt < self.retry_max_attempts - 1:
|
||||
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
|
||||
continue
|
||||
raise last_error from exc
|
||||
|
||||
duration_ms = int((time.monotonic() - t0) * 1000)
|
||||
if _should_retry_status(response.status_code) and attempt < self.retry_max_attempts - 1:
|
||||
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
|
||||
continue
|
||||
|
||||
if response.status_code != 200:
|
||||
raise OpenAIClientError(f"API returned {response.status_code}: {response.text[:500]}")
|
||||
|
||||
body = response.json()
|
||||
if "error" in body:
|
||||
raise OpenAIClientError(f"API error: {body['error']}")
|
||||
return _parse_response(body, duration_ms)
|
||||
|
||||
raise last_error or OpenAIClientError("LLM request failed after retries")
|
||||
|
||||
|
||||
class ResilientQwenVLMClient(QwenVLMClient):
|
||||
"""VLM client with configurable retry count/backoff."""
|
||||
|
||||
_RETRYABLE_EXCEPTIONS = (
|
||||
httpx.ConnectError,
|
||||
httpx.RemoteProtocolError,
|
||||
httpx.ReadError,
|
||||
httpx.ReadTimeout,
|
||||
httpx.TimeoutException,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
retry_max_attempts: int = 2,
|
||||
retry_backoff_base_seconds: float = 1.0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.retry_max_attempts = max(1, int(retry_max_attempts))
|
||||
self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds))
|
||||
|
||||
async def _post_with_retry(self, payload: dict[str, Any]) -> httpx.Response:
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(self.retry_max_attempts):
|
||||
try:
|
||||
response = await self._client.post(
|
||||
f"{self.base_url}/chat/completions",
|
||||
json=payload,
|
||||
headers={k: v for k, v in self._headers.items() if v},
|
||||
)
|
||||
except self._RETRYABLE_EXCEPTIONS as exc:
|
||||
last_error = exc
|
||||
if attempt < self.retry_max_attempts - 1:
|
||||
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
|
||||
continue
|
||||
raise
|
||||
|
||||
if _should_retry_status(response.status_code) and attempt < self.retry_max_attempts - 1:
|
||||
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
|
||||
continue
|
||||
return response
|
||||
|
||||
raise last_error or RuntimeError("VLM request failed after retries")
|
||||
|
||||
async def classify_visual(
|
||||
self,
|
||||
image_bytes: bytes,
|
||||
user_hint: str | None = None,
|
||||
max_tokens: int = 300,
|
||||
) -> VisualClassification:
|
||||
b64 = base64.b64encode(image_bytes).decode()
|
||||
url = f"data:image/png;base64,{b64}"
|
||||
|
||||
user_text = "请判断此区域内容并输出 JSON。"
|
||||
if user_hint:
|
||||
user_text += f"\n上下文:{user_hint}"
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": [
|
||||
{"role": "system", "content": _SYSTEM_PROMPT},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "image_url", "image_url": {"url": url}},
|
||||
{"type": "text", "text": user_text},
|
||||
],
|
||||
},
|
||||
],
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": 0.0,
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._post_with_retry(payload)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
return VisualClassification(kind="other", confidence=0.0, raw=f"retry_failed: {exc}")
|
||||
|
||||
if response.status_code != 200:
|
||||
return VisualClassification(kind="other", confidence=0.0, raw=f"HTTP {response.status_code}: {response.text[:300]}")
|
||||
|
||||
body = response.json()
|
||||
content = (body.get("choices") or [{}])[0].get("message", {}).get("content", "")
|
||||
parsed = _parse_json_loose(content)
|
||||
if parsed is None:
|
||||
return VisualClassification(kind="other", confidence=0.0, raw=content)
|
||||
|
||||
kind = parsed.get("kind", "other")
|
||||
if kind not in {"seal", "signature", "other"}:
|
||||
kind = "other"
|
||||
|
||||
return VisualClassification(
|
||||
kind=kind,
|
||||
seal_type=parsed.get("seal_type"),
|
||||
text=parsed.get("text"),
|
||||
is_complete=parsed.get("is_complete"),
|
||||
confidence=0.9 if kind != "other" else 0.3,
|
||||
raw=content,
|
||||
)
|
||||
|
||||
async def extract_multifield(
|
||||
self,
|
||||
*,
|
||||
prompt: str,
|
||||
images_data_urls: list[str],
|
||||
max_tokens: int = 800,
|
||||
) -> dict:
|
||||
content: list[dict] = [{"type": "text", "text": prompt}]
|
||||
for url in images_data_urls:
|
||||
content.append({"type": "image_url", "image_url": {"url": url}})
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": [{"role": "user", "content": content}],
|
||||
"max_tokens": max_tokens,
|
||||
"temperature": 0.0,
|
||||
}
|
||||
|
||||
try:
|
||||
response = await self._post_with_retry(payload)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
if response.status_code != 200:
|
||||
return {}
|
||||
|
||||
body = response.json()
|
||||
text = (body.get("choices") or [{}])[0].get("message", {}).get("content", "")
|
||||
parsed = _parse_json_loose(text)
|
||||
return parsed if isinstance(parsed, dict) else {}
|
||||
|
||||
|
||||
class ResilientChandraOCRClient(ChandraOCRClient):
|
||||
"""OCR client with configurable retry count/backoff."""
|
||||
|
||||
_RETRYABLE_EXCEPTIONS = (
|
||||
httpx.ConnectError,
|
||||
httpx.RemoteProtocolError,
|
||||
httpx.ReadError,
|
||||
httpx.ReadTimeout,
|
||||
httpx.TimeoutException,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
retry_max_attempts: int = 3,
|
||||
retry_backoff_base_seconds: float = 1.0,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.retry_max_attempts = max(1, int(retry_max_attempts))
|
||||
self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds))
|
||||
|
||||
async def _post_ocr(self, file_path: Path | str) -> dict:
|
||||
path = Path(file_path)
|
||||
if not path.exists():
|
||||
raise FileNotFoundError(f"File not found: {path}")
|
||||
|
||||
last_error: Exception | None = None
|
||||
for attempt in range(self.retry_max_attempts):
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=self.timeout, verify=False) as client:
|
||||
with open(path, "rb") as file_obj:
|
||||
files = {"file": (path.name, file_obj, _guess_mime(path))}
|
||||
data = {"include_images": str(self.include_images).lower()}
|
||||
response = await client.post(
|
||||
f"{self.base_url}/ocr",
|
||||
files=files,
|
||||
data=data,
|
||||
)
|
||||
except self._RETRYABLE_EXCEPTIONS as exc:
|
||||
last_error = exc
|
||||
if attempt < self.retry_max_attempts - 1:
|
||||
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
|
||||
continue
|
||||
raise ChandraOCRError(f"OCR request failed after retries: {exc}") from exc
|
||||
|
||||
if _should_retry_status(response.status_code) and attempt < self.retry_max_attempts - 1:
|
||||
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
|
||||
continue
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ChandraOCRError(f"OCR API returned {response.status_code}: {response.text[:500]}")
|
||||
|
||||
body = response.json()
|
||||
if "error" in body:
|
||||
raise ChandraOCRError(f"OCR error: {body['error']}")
|
||||
return body
|
||||
|
||||
raise ChandraOCRError(f"OCR request failed after retries: {last_error}")
|
||||
|
||||
async def ocr(self, file_path: Path | str):
|
||||
body = await self._post_ocr(file_path)
|
||||
result = parse_ocr_response(body)
|
||||
if self.vlm_client is not None and (
|
||||
result.visual_manifest.seals
|
||||
or result.visual_manifest.signatures
|
||||
or result.visual_manifest.cross_page_seals
|
||||
):
|
||||
try:
|
||||
from leaudit.ocr.visual_classifier import refine_visual_manifest
|
||||
|
||||
await refine_visual_manifest(result, self.vlm_client, concurrency=self.vlm_concurrency)
|
||||
except Exception as exc: # noqa: BLE001
|
||||
log.warning("VLM refinement failed, keeping geometric manifest: %s", exc)
|
||||
return result
|
||||
|
||||
async def ocr_raw(self, file_path: Path | str) -> dict:
|
||||
return await self._post_ocr(file_path)
|
||||
Reference in New Issue
Block a user