"""Bridge-side resilient client wrappers. Keep retry policy configurable in leaudit-platform without changing the upstream leaudit package. """ from __future__ import annotations import asyncio import base64 import logging import time from pathlib import Path from typing import Any import httpx from leaudit.llm.base import LLMRequest, LLMResponse from leaudit.llm.openai_client import OpenAIClientError, OpenAICompatibleClient, _parse_response from leaudit.llm.qwen_vlm_client import QwenVLMClient, VisualClassification, _SYSTEM_PROMPT, _parse_json_loose from leaudit.ocr.chandra_client import ChandraOCRError, ChandraOCRClient, _guess_mime, _parse_response as parse_ocr_response log = logging.getLogger(__name__) _RETRYABLE_HTTP_STATUSES = {408, 429, 500, 502, 503, 504} def _retry_delay(base_seconds: float, attempt: int) -> float: """Return exponential backoff delay for retry attempt index.""" return max(0.0, float(base_seconds)) * (2 ** max(0, int(attempt))) def _should_retry_status(status_code: int) -> bool: return status_code in _RETRYABLE_HTTP_STATUSES or status_code >= 500 class ResilientOpenAICompatibleClient(OpenAICompatibleClient): """OpenAI-compatible client with configurable retry count/backoff.""" def __init__( self, *, retry_max_attempts: int = 3, retry_backoff_base_seconds: float = 1.0, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.retry_max_attempts = max(1, int(retry_max_attempts)) self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds)) async def complete(self, request: LLMRequest) -> LLMResponse: self.call_history.append(request) model = request.model if request.model != "default" else self.default_model payload: dict[str, Any] = { "model": model, "messages": [{"role": m.role, "content": m.content} for m in request.messages], "temperature": request.temperature, "max_tokens": request.max_tokens, } if request.return_logprobs: payload["logprobs"] = True payload["top_logprobs"] = 3 if request.response_format: payload["response_format"] = request.response_format thinking = request.enable_thinking if request.enable_thinking is not None else self.enable_thinking if not thinking: payload["enable_thinking"] = False req_timeout = min(self.timeout, request.timeout_ms / 1000) last_error: Exception | None = None for attempt in range(self.retry_max_attempts): t0 = time.monotonic() try: response = await self._client.post( f"{self.base_url}/chat/completions", json=payload, headers=self._headers, timeout=req_timeout, ) except httpx.TimeoutException as exc: last_error = TimeoutError(f"LLM request timed out: {exc}") if attempt < self.retry_max_attempts - 1: await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt)) continue raise last_error from exc except self._RETRYABLE_EXCEPTIONS as exc: last_error = OpenAIClientError(f"Network error: {exc}") if attempt < self.retry_max_attempts - 1: await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt)) continue raise last_error from exc duration_ms = int((time.monotonic() - t0) * 1000) if _should_retry_status(response.status_code) and attempt < self.retry_max_attempts - 1: await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt)) continue if response.status_code != 200: raise OpenAIClientError(f"API returned {response.status_code}: {response.text[:500]}") body = response.json() if "error" in body: raise OpenAIClientError(f"API error: {body['error']}") return _parse_response(body, duration_ms) raise last_error or OpenAIClientError("LLM request failed after retries") class ResilientQwenVLMClient(QwenVLMClient): """VLM client with configurable retry count/backoff.""" _RETRYABLE_EXCEPTIONS = ( httpx.ConnectError, httpx.RemoteProtocolError, httpx.ReadError, httpx.ReadTimeout, httpx.TimeoutException, ) def __init__( self, *, retry_max_attempts: int = 2, retry_backoff_base_seconds: float = 1.0, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.retry_max_attempts = max(1, int(retry_max_attempts)) self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds)) async def _post_with_retry(self, payload: dict[str, Any]) -> httpx.Response: last_error: Exception | None = None for attempt in range(self.retry_max_attempts): try: response = await self._client.post( f"{self.base_url}/chat/completions", json=payload, headers={k: v for k, v in self._headers.items() if v}, ) except self._RETRYABLE_EXCEPTIONS as exc: last_error = exc if attempt < self.retry_max_attempts - 1: await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt)) continue raise if _should_retry_status(response.status_code) and attempt < self.retry_max_attempts - 1: await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt)) continue return response raise last_error or RuntimeError("VLM request failed after retries") async def classify_visual( self, image_bytes: bytes, user_hint: str | None = None, max_tokens: int = 300, ) -> VisualClassification: b64 = base64.b64encode(image_bytes).decode() url = f"data:image/png;base64,{b64}" user_text = "请判断此区域内容并输出 JSON。" if user_hint: user_text += f"\n上下文:{user_hint}" payload: dict[str, Any] = { "model": self.model, "messages": [ {"role": "system", "content": _SYSTEM_PROMPT}, { "role": "user", "content": [ {"type": "image_url", "image_url": {"url": url}}, {"type": "text", "text": user_text}, ], }, ], "max_tokens": max_tokens, "temperature": 0.0, } try: response = await self._post_with_retry(payload) except Exception as exc: # noqa: BLE001 return VisualClassification(kind="other", confidence=0.0, raw=f"retry_failed: {exc}") if response.status_code != 200: return VisualClassification(kind="other", confidence=0.0, raw=f"HTTP {response.status_code}: {response.text[:300]}") body = response.json() content = (body.get("choices") or [{}])[0].get("message", {}).get("content", "") parsed = _parse_json_loose(content) if parsed is None: return VisualClassification(kind="other", confidence=0.0, raw=content) kind = parsed.get("kind", "other") if kind not in {"seal", "signature", "other"}: kind = "other" return VisualClassification( kind=kind, seal_type=parsed.get("seal_type"), text=parsed.get("text"), is_complete=parsed.get("is_complete"), confidence=0.9 if kind != "other" else 0.3, raw=content, ) async def extract_multifield( self, *, prompt: str, images_data_urls: list[str], max_tokens: int = 800, ) -> dict: content: list[dict] = [{"type": "text", "text": prompt}] for url in images_data_urls: content.append({"type": "image_url", "image_url": {"url": url}}) payload: dict[str, Any] = { "model": self.model, "messages": [{"role": "user", "content": content}], "max_tokens": max_tokens, "temperature": 0.0, } try: response = await self._post_with_retry(payload) except Exception: return {} if response.status_code != 200: return {} body = response.json() text = (body.get("choices") or [{}])[0].get("message", {}).get("content", "") parsed = _parse_json_loose(text) return parsed if isinstance(parsed, dict) else {"result": text, "reason": text} class ResilientChandraOCRClient(ChandraOCRClient): """OCR client with configurable retry count/backoff.""" _RETRYABLE_EXCEPTIONS = ( httpx.ConnectError, httpx.RemoteProtocolError, httpx.ReadError, httpx.ReadTimeout, httpx.TimeoutException, ) def __init__( self, *, retry_max_attempts: int = 3, retry_backoff_base_seconds: float = 1.0, **kwargs: Any, ) -> None: super().__init__(**kwargs) self.retry_max_attempts = max(1, int(retry_max_attempts)) self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds)) async def _post_ocr(self, file_path: Path | str) -> dict: path = Path(file_path) if not path.exists(): raise FileNotFoundError(f"File not found: {path}") last_error: Exception | None = None for attempt in range(self.retry_max_attempts): try: async with httpx.AsyncClient(timeout=self.timeout, verify=False) as client: with open(path, "rb") as file_obj: files = {"file": (path.name, file_obj, _guess_mime(path))} data = {"include_images": str(self.include_images).lower()} response = await client.post( f"{self.base_url}/chandra/ocr", files=files, data=data, ) except self._RETRYABLE_EXCEPTIONS as exc: last_error = exc if attempt < self.retry_max_attempts - 1: await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt)) continue raise ChandraOCRError(f"OCR request failed after retries: {exc}") from exc if _should_retry_status(response.status_code) and attempt < self.retry_max_attempts - 1: await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt)) continue if response.status_code != 200: raise ChandraOCRError(f"OCR API returned {response.status_code}: {response.text[:500]}") body = response.json() if "error" in body: raise ChandraOCRError(f"OCR error: {body['error']}") return body raise ChandraOCRError(f"OCR request failed after retries: {last_error}") async def ocr(self, file_path: Path | str): body = await self._post_ocr(file_path) result = parse_ocr_response(body) if self.vlm_client is not None and ( result.visual_manifest.seals or result.visual_manifest.signatures or result.visual_manifest.cross_page_seals ): try: from leaudit.ocr.visual_classifier import refine_visual_manifest await refine_visual_manifest(result, self.vlm_client, concurrency=self.vlm_concurrency) except Exception as exc: # noqa: BLE001 log.warning("VLM refinement failed, keeping geometric manifest: %s", exc) return result async def ocr_raw(self, file_path: Path | str) -> dict: return await self._post_ocr(file_path)