feat: add async worker queues and retry controls

This commit is contained in:
wren
2026-04-29 11:48:09 +08:00
parent e738398eb6
commit f3b83c9979
16 changed files with 1316 additions and 96 deletions
@@ -26,6 +26,7 @@ class AuditController(BaseController):
DocumentId=body.documentId,
RuleType=body.ruleType,
Force=body.force,
Speed=body.speed,
)
return Result.success(data=run)
@@ -9,3 +9,4 @@ class AuditRunDTO(BaseModel):
documentId: int = Field(..., description="文档ID")
ruleType: str | None = Field(None, description="指定规则类型编码")
force: bool = Field(False, description="是否强制重跑")
speed: str = Field("normal", description="执行速度档位:urgent/normal")
@@ -25,6 +25,7 @@ from leaudit.services.evaluation_service import EvaluationService
from leaudit.services.extraction_service import ExtractionService
from leaudit.services.rescue_service import RescueService
from fastapi_admin.config import LEAUDIT_OCR_VLM_CONCURRENCY
from fastapi_modules.fastapi_leaudit.leaudit_bridge.client_factory import (
create_llm_client,
create_ocr_client,
@@ -118,7 +119,11 @@ class AuditServiceFactory:
vlm_client=vlm_client,
force_rules_path=rules_path,
)
ocr_client = BridgeOCRClient(adapter, vlm_client=vlm_client)
ocr_client = BridgeOCRClient(
adapter,
vlm_client=vlm_client,
vlm_concurrency=LEAUDIT_OCR_VLM_CONCURRENCY,
)
normalization_service = DocNormalizationService(ocr_client)
audit_services = AuditServices(
llm_client=llm_client,
@@ -8,6 +8,15 @@ from typing import TYPE_CHECKING
from fastapi_admin.config import (
OCR_BASE_URL,
OCR_TIMEOUT,
LEAUDIT_LLM_REQUEST_TIMEOUT,
LEAUDIT_LLM_RETRY_BACKOFF_BASE_SECONDS,
LEAUDIT_LLM_RETRY_MAX_ATTEMPTS,
LEAUDIT_OCR_VLM_CONCURRENCY,
LEAUDIT_OCR_RETRY_BACKOFF_BASE_SECONDS,
LEAUDIT_OCR_RETRY_MAX_ATTEMPTS,
LEAUDIT_VLM_REQUEST_TIMEOUT,
LEAUDIT_VLM_RETRY_BACKOFF_BASE_SECONDS,
LEAUDIT_VLM_RETRY_MAX_ATTEMPTS,
LLM_BASE_URL,
LLM_MODEL,
LLM_API_KEY,
@@ -15,6 +24,11 @@ from fastapi_admin.config import (
VLM_MODEL,
VLM_API_KEY,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.resilient_clients import (
ResilientChandraOCRClient,
ResilientOpenAICompatibleClient,
ResilientQwenVLMClient,
)
if TYPE_CHECKING:
from leaudit.llm.base import BaseLLMClient
@@ -27,44 +41,58 @@ log = logging.getLogger(__name__)
def create_ocr_client() -> BaseOCRClient:
"""Create a leaudit ChandraOCRClient from LEAUDIT_OCR_URL config."""
import os
from leaudit.ocr.chandra_client import ChandraOCRClient
base_url = os.getenv("LEAUDIT_OCR_URL", "").rstrip("/")
if not base_url:
base_url = OCR_BASE_URL.rstrip("/")
timeout = float(OCR_TIMEOUT)
client = ChandraOCRClient(
client = ResilientChandraOCRClient(
base_url=base_url,
timeout=timeout,
include_images=True,
vlm_concurrency=LEAUDIT_OCR_VLM_CONCURRENCY,
retry_max_attempts=LEAUDIT_OCR_RETRY_MAX_ATTEMPTS,
retry_backoff_base_seconds=LEAUDIT_OCR_RETRY_BACKOFF_BASE_SECONDS,
)
log.info(
"leaudit OCR client created: %s (timeout=%ss, vlm_concurrency=%s, retry_max_attempts=%s, retry_backoff_base_seconds=%s)",
base_url,
timeout,
LEAUDIT_OCR_VLM_CONCURRENCY,
LEAUDIT_OCR_RETRY_MAX_ATTEMPTS,
LEAUDIT_OCR_RETRY_BACKOFF_BASE_SECONDS,
)
log.info("leaudit OCR client created: %s (timeout=%ss)", base_url, timeout)
return client
def create_llm_client() -> BaseLLMClient:
"""Create a leaudit OpenAICompatibleClient from docauditai's LLM config."""
from leaudit.llm.openai_client import OpenAICompatibleClient
base_url = LLM_BASE_URL
model = LLM_MODEL
api_key = LLM_API_KEY or "no-key"
client = OpenAICompatibleClient(
client = ResilientOpenAICompatibleClient(
api_key=api_key,
base_url=base_url,
default_model=model,
timeout=120.0,
timeout=float(LEAUDIT_LLM_REQUEST_TIMEOUT),
retry_max_attempts=LEAUDIT_LLM_RETRY_MAX_ATTEMPTS,
retry_backoff_base_seconds=LEAUDIT_LLM_RETRY_BACKOFF_BASE_SECONDS,
)
log.info(
"leaudit LLM client created: %s (model=%s, timeout=%ss, retry_max_attempts=%s, retry_backoff_base_seconds=%s)",
base_url,
model,
LEAUDIT_LLM_REQUEST_TIMEOUT,
LEAUDIT_LLM_RETRY_MAX_ATTEMPTS,
LEAUDIT_LLM_RETRY_BACKOFF_BASE_SECONDS,
)
log.info("leaudit LLM client created: %s (model=%s)", base_url, model)
return client
def create_vlm_client() -> BaseVLMClient | None:
"""Create a leaudit QwenVLMClient from docauditai's VLM config."""
from leaudit.llm.qwen_vlm_client import QwenVLMClient
base_url = VLM_BASE_URL
model = VLM_MODEL
api_key = VLM_API_KEY or LLM_API_KEY or "no-key"
@@ -73,10 +101,20 @@ def create_vlm_client() -> BaseVLMClient | None:
log.info("leaudit VLM client skipped: no VLM config")
return None
client = QwenVLMClient(
client = ResilientQwenVLMClient(
base_url=base_url,
api_key=api_key,
model=model,
timeout=float(LEAUDIT_VLM_REQUEST_TIMEOUT),
retry_max_attempts=LEAUDIT_VLM_RETRY_MAX_ATTEMPTS,
retry_backoff_base_seconds=LEAUDIT_VLM_RETRY_BACKOFF_BASE_SECONDS,
)
log.info(
"leaudit VLM client created: %s (model=%s, timeout=%ss, retry_max_attempts=%s, retry_backoff_base_seconds=%s)",
base_url,
model,
LEAUDIT_VLM_REQUEST_TIMEOUT,
LEAUDIT_VLM_RETRY_MAX_ATTEMPTS,
LEAUDIT_VLM_RETRY_BACKOFF_BASE_SECONDS,
)
log.info("leaudit VLM client created: %s (model=%s)", base_url, model)
return client
@@ -9,10 +9,17 @@ Keeps docauditai-specific fixes outside ``services/leaudit/**``:
from __future__ import annotations
import asyncio
import logging
from io import BytesIO
from pathlib import Path
from fastapi_admin.config import (
LEAUDIT_SIGNATURE_PROBE_CONCURRENCY,
LEAUDIT_SIGNATURE_PROBE_RETRY_BACKOFF_BASE_SECONDS,
LEAUDIT_SIGNATURE_PROBE_RETRY_MAX_ATTEMPTS,
LEAUDIT_SIGNATURE_PROBE_TIMEOUT,
)
from leaudit.ocr.base import BaseOCRClient
from leaudit.ocr.models import OcrResult, VisualManifestItem
@@ -120,18 +127,20 @@ async def _inject_docx_signature_candidates(
if parent_key:
parent_to_items.setdefault(parent_key, []).append(item)
for parent_key, items in parent_to_items.items():
sem = asyncio.Semaphore(max(1, int(LEAUDIT_SIGNATURE_PROBE_CONCURRENCY)))
async def _probe_parent(parent_key: str, items: list[VisualManifestItem]) -> None:
if any((it.label or "") == "signature" for it in items):
continue
return
parent_bytes = ocr_result.get_image_bytes(parent_key)
if not parent_bytes:
continue
return
try:
image = Image.open(BytesIO(parent_bytes))
except Exception as exc:
log.warning("failed to open parent image %s: %s", parent_key, exc)
continue
return
width, height = image.size
for candidate_bbox in _signature_candidate_boxes(items, width, height):
@@ -139,11 +148,13 @@ async def _inject_docx_signature_candidates(
crop = image.crop(tuple(candidate_bbox))
buf = BytesIO()
crop.save(buf, format="PNG")
result = await _classify_signature_candidate(
vlm_client,
buf.getvalue(),
"这是合同签章页里疑似法人签名的候选区域,请优先判断是否为手写签名。",
)
async with sem:
result = await _classify_signature_candidate(
vlm_client,
buf.getvalue(),
"这是合同签章页里疑似法人签名的候选区域,请优先判断是否为手写签名。",
parent_key=parent_key,
)
except Exception as exc:
log.warning("signature probe failed for %s: %s", parent_key, exc)
continue
@@ -166,33 +177,66 @@ async def _inject_docx_signature_candidates(
)
break
if parent_to_items:
await asyncio.gather(
*(_probe_parent(parent_key, items) for parent_key, items in parent_to_items.items()),
return_exceptions=False,
)
async def _classify_signature_candidate(
vlm_client: object,
image_bytes: bytes,
user_hint: str,
*,
parent_key: str | None = None,
) -> object:
"""Classify with one retry using a fresh VLM client when needed."""
try:
return await vlm_client.classify_visual(image_bytes, user_hint=user_hint)
except Exception as exc:
log.warning("signature probe primary VLM failed, retrying fresh client: %s", exc)
"""Classify with configurable retry using a fresh VLM client when needed."""
timeout = max(1, int(LEAUDIT_SIGNATURE_PROBE_TIMEOUT))
max_attempts = max(1, int(LEAUDIT_SIGNATURE_PROBE_RETRY_MAX_ATTEMPTS))
backoff_base = max(0.0, float(LEAUDIT_SIGNATURE_PROBE_RETRY_BACKOFF_BASE_SECONDS))
last_error: Exception | None = None
try:
from leaudit.llm.qwen_vlm_client import QwenVLMClient
fresh = QwenVLMClient(
base_url=getattr(vlm_client, "base_url"),
api_key=getattr(vlm_client, "api_key", ""),
model=getattr(vlm_client, "model"),
timeout=getattr(vlm_client, "timeout", 90.0),
)
for attempt in range(max_attempts):
current_client = vlm_client
fresh = None
try:
return await fresh.classify_visual(image_bytes, user_hint=user_hint)
if attempt > 0:
from fastapi_modules.fastapi_leaudit.leaudit_bridge.resilient_clients import ResilientQwenVLMClient
fresh = ResilientQwenVLMClient(
base_url=getattr(vlm_client, "base_url"),
api_key=getattr(vlm_client, "api_key", ""),
model=getattr(vlm_client, "model"),
timeout=getattr(vlm_client, "timeout", 90.0),
retry_max_attempts=1,
retry_backoff_base_seconds=0.0,
)
current_client = fresh
return await asyncio.wait_for(
current_client.classify_visual(image_bytes, user_hint=user_hint),
timeout=timeout,
)
except Exception as exc:
last_error = exc
if attempt < max_attempts - 1:
log.warning(
"signature probe attempt %s/%s failed for %s, retrying after %.2fs (timeout=%ss): %s",
attempt + 1,
max_attempts,
parent_key or "-",
backoff_base * (2 ** attempt),
timeout,
exc,
)
await asyncio.sleep(backoff_base * (2 ** attempt))
continue
finally:
await fresh.close()
except Exception as exc:
raise RuntimeError(exc) from exc
if fresh is not None:
await fresh.close()
raise RuntimeError(last_error) from last_error
def _signature_candidate_boxes(
@@ -0,0 +1,322 @@
"""Bridge-side resilient client wrappers.
Keep retry policy configurable in leaudit-platform without changing the
upstream leaudit package.
"""
from __future__ import annotations
import asyncio
import base64
import logging
import time
from pathlib import Path
from typing import Any
import httpx
from leaudit.llm.base import LLMRequest, LLMResponse
from leaudit.llm.openai_client import OpenAIClientError, OpenAICompatibleClient, _parse_response
from leaudit.llm.qwen_vlm_client import QwenVLMClient, VisualClassification, _SYSTEM_PROMPT, _parse_json_loose
from leaudit.ocr.chandra_client import ChandraOCRError, ChandraOCRClient, _guess_mime, _parse_response as parse_ocr_response
log = logging.getLogger(__name__)
_RETRYABLE_HTTP_STATUSES = {408, 429, 500, 502, 503, 504}
def _retry_delay(base_seconds: float, attempt: int) -> float:
"""Return exponential backoff delay for retry attempt index."""
return max(0.0, float(base_seconds)) * (2 ** max(0, int(attempt)))
def _should_retry_status(status_code: int) -> bool:
return status_code in _RETRYABLE_HTTP_STATUSES or status_code >= 500
class ResilientOpenAICompatibleClient(OpenAICompatibleClient):
"""OpenAI-compatible client with configurable retry count/backoff."""
def __init__(
self,
*,
retry_max_attempts: int = 3,
retry_backoff_base_seconds: float = 1.0,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.retry_max_attempts = max(1, int(retry_max_attempts))
self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds))
async def complete(self, request: LLMRequest) -> LLMResponse:
self.call_history.append(request)
model = request.model if request.model != "default" else self.default_model
payload: dict[str, Any] = {
"model": model,
"messages": [{"role": m.role, "content": m.content} for m in request.messages],
"temperature": request.temperature,
"max_tokens": request.max_tokens,
}
if request.return_logprobs:
payload["logprobs"] = True
payload["top_logprobs"] = 3
if request.response_format:
payload["response_format"] = request.response_format
thinking = request.enable_thinking if request.enable_thinking is not None else self.enable_thinking
if not thinking:
payload["enable_thinking"] = False
req_timeout = min(self.timeout, request.timeout_ms / 1000)
last_error: Exception | None = None
for attempt in range(self.retry_max_attempts):
t0 = time.monotonic()
try:
response = await self._client.post(
f"{self.base_url}/chat/completions",
json=payload,
headers=self._headers,
timeout=req_timeout,
)
except httpx.TimeoutException as exc:
last_error = TimeoutError(f"LLM request timed out: {exc}")
if attempt < self.retry_max_attempts - 1:
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
continue
raise last_error from exc
except self._RETRYABLE_EXCEPTIONS as exc:
last_error = OpenAIClientError(f"Network error: {exc}")
if attempt < self.retry_max_attempts - 1:
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
continue
raise last_error from exc
duration_ms = int((time.monotonic() - t0) * 1000)
if _should_retry_status(response.status_code) and attempt < self.retry_max_attempts - 1:
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
continue
if response.status_code != 200:
raise OpenAIClientError(f"API returned {response.status_code}: {response.text[:500]}")
body = response.json()
if "error" in body:
raise OpenAIClientError(f"API error: {body['error']}")
return _parse_response(body, duration_ms)
raise last_error or OpenAIClientError("LLM request failed after retries")
class ResilientQwenVLMClient(QwenVLMClient):
"""VLM client with configurable retry count/backoff."""
_RETRYABLE_EXCEPTIONS = (
httpx.ConnectError,
httpx.RemoteProtocolError,
httpx.ReadError,
httpx.ReadTimeout,
httpx.TimeoutException,
)
def __init__(
self,
*,
retry_max_attempts: int = 2,
retry_backoff_base_seconds: float = 1.0,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.retry_max_attempts = max(1, int(retry_max_attempts))
self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds))
async def _post_with_retry(self, payload: dict[str, Any]) -> httpx.Response:
last_error: Exception | None = None
for attempt in range(self.retry_max_attempts):
try:
response = await self._client.post(
f"{self.base_url}/chat/completions",
json=payload,
headers={k: v for k, v in self._headers.items() if v},
)
except self._RETRYABLE_EXCEPTIONS as exc:
last_error = exc
if attempt < self.retry_max_attempts - 1:
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
continue
raise
if _should_retry_status(response.status_code) and attempt < self.retry_max_attempts - 1:
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
continue
return response
raise last_error or RuntimeError("VLM request failed after retries")
async def classify_visual(
self,
image_bytes: bytes,
user_hint: str | None = None,
max_tokens: int = 300,
) -> VisualClassification:
b64 = base64.b64encode(image_bytes).decode()
url = f"data:image/png;base64,{b64}"
user_text = "请判断此区域内容并输出 JSON。"
if user_hint:
user_text += f"\n上下文:{user_hint}"
payload: dict[str, Any] = {
"model": self.model,
"messages": [
{"role": "system", "content": _SYSTEM_PROMPT},
{
"role": "user",
"content": [
{"type": "image_url", "image_url": {"url": url}},
{"type": "text", "text": user_text},
],
},
],
"max_tokens": max_tokens,
"temperature": 0.0,
}
try:
response = await self._post_with_retry(payload)
except Exception as exc: # noqa: BLE001
return VisualClassification(kind="other", confidence=0.0, raw=f"retry_failed: {exc}")
if response.status_code != 200:
return VisualClassification(kind="other", confidence=0.0, raw=f"HTTP {response.status_code}: {response.text[:300]}")
body = response.json()
content = (body.get("choices") or [{}])[0].get("message", {}).get("content", "")
parsed = _parse_json_loose(content)
if parsed is None:
return VisualClassification(kind="other", confidence=0.0, raw=content)
kind = parsed.get("kind", "other")
if kind not in {"seal", "signature", "other"}:
kind = "other"
return VisualClassification(
kind=kind,
seal_type=parsed.get("seal_type"),
text=parsed.get("text"),
is_complete=parsed.get("is_complete"),
confidence=0.9 if kind != "other" else 0.3,
raw=content,
)
async def extract_multifield(
self,
*,
prompt: str,
images_data_urls: list[str],
max_tokens: int = 800,
) -> dict:
content: list[dict] = [{"type": "text", "text": prompt}]
for url in images_data_urls:
content.append({"type": "image_url", "image_url": {"url": url}})
payload: dict[str, Any] = {
"model": self.model,
"messages": [{"role": "user", "content": content}],
"max_tokens": max_tokens,
"temperature": 0.0,
}
try:
response = await self._post_with_retry(payload)
except Exception:
return {}
if response.status_code != 200:
return {}
body = response.json()
text = (body.get("choices") or [{}])[0].get("message", {}).get("content", "")
parsed = _parse_json_loose(text)
return parsed if isinstance(parsed, dict) else {}
class ResilientChandraOCRClient(ChandraOCRClient):
"""OCR client with configurable retry count/backoff."""
_RETRYABLE_EXCEPTIONS = (
httpx.ConnectError,
httpx.RemoteProtocolError,
httpx.ReadError,
httpx.ReadTimeout,
httpx.TimeoutException,
)
def __init__(
self,
*,
retry_max_attempts: int = 3,
retry_backoff_base_seconds: float = 1.0,
**kwargs: Any,
) -> None:
super().__init__(**kwargs)
self.retry_max_attempts = max(1, int(retry_max_attempts))
self.retry_backoff_base_seconds = max(0.0, float(retry_backoff_base_seconds))
async def _post_ocr(self, file_path: Path | str) -> dict:
path = Path(file_path)
if not path.exists():
raise FileNotFoundError(f"File not found: {path}")
last_error: Exception | None = None
for attempt in range(self.retry_max_attempts):
try:
async with httpx.AsyncClient(timeout=self.timeout, verify=False) as client:
with open(path, "rb") as file_obj:
files = {"file": (path.name, file_obj, _guess_mime(path))}
data = {"include_images": str(self.include_images).lower()}
response = await client.post(
f"{self.base_url}/ocr",
files=files,
data=data,
)
except self._RETRYABLE_EXCEPTIONS as exc:
last_error = exc
if attempt < self.retry_max_attempts - 1:
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
continue
raise ChandraOCRError(f"OCR request failed after retries: {exc}") from exc
if _should_retry_status(response.status_code) and attempt < self.retry_max_attempts - 1:
await asyncio.sleep(_retry_delay(self.retry_backoff_base_seconds, attempt))
continue
if response.status_code != 200:
raise ChandraOCRError(f"OCR API returned {response.status_code}: {response.text[:500]}")
body = response.json()
if "error" in body:
raise ChandraOCRError(f"OCR error: {body['error']}")
return body
raise ChandraOCRError(f"OCR request failed after retries: {last_error}")
async def ocr(self, file_path: Path | str):
body = await self._post_ocr(file_path)
result = parse_ocr_response(body)
if self.vlm_client is not None and (
result.visual_manifest.seals
or result.visual_manifest.signatures
or result.visual_manifest.cross_page_seals
):
try:
from leaudit.ocr.visual_classifier import refine_visual_manifest
await refine_visual_manifest(result, self.vlm_client, concurrency=self.vlm_concurrency)
except Exception as exc: # noqa: BLE001
log.warning("VLM refinement failed, keeping geometric manifest: %s", exc)
return result
async def ocr_raw(self, file_path: Path | str) -> dict:
return await self._post_ocr(file_path)
@@ -3,7 +3,6 @@
from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
import os
from pathlib import Path
import tempfile
@@ -12,7 +11,12 @@ from typing import Any, Dict, Optional
from fastapi_common.fastapi_common_logger import logger
from fastapi_admin.config import LEAUDIT_RULES_DIR
from fastapi_admin.celery_app import celery_app
from fastapi_admin.config import (
LEAUDIT_RULES_DIR,
LEAUDIT_WORKER_QUEUE_NORMAL,
LEAUDIT_WORKER_QUEUE_URGENT,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import (
NativeRunRequest,
NativeRunner,
@@ -20,19 +24,23 @@ from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import (
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import (
NativeAuditMetadata,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.fileSourceResolver import (
FileSourceResolver,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ruleVersionResolver import (
RuleVersionResolver,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
# Celery 集成待 P2 阶段实现,当前使用同步占位
# from core.celery_app_limited import celery_app
from fastapi_modules.fastapi_leaudit.models import (
LeauditAuditRun,
LeauditDocument,
LeauditDocumentFile,
)
log = logger
# P2: Celery 集成后启用 @celery_app.task 装饰器
def leaudit_process_document(
document_id: int,
file_content: bytes,
@@ -160,6 +168,65 @@ def leaudit_process_document(
loop.close()
def leaudit_process_document_by_run(
run_id: int,
*,
task_id: str | None = None,
rules_path: str | None = None,
queue_name: str | None = None,
) -> dict[str, Any]:
"""按 runId 加载执行上下文并执行原生 leaudit。"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
claimed = loop.run_until_complete(_claim_run_safe(run_id, task_id))
if not claimed:
log.warning("run_id=%s 未抢占成功,跳过重复消费", run_id)
return {"status": "skipped", "run_id": run_id, "reason": "already_claimed"}
context = loop.run_until_complete(_load_run_context(run_id))
log.info(
"run_id=%s worker开始执行: queue=%s, speed=%s, filename=%s",
run_id,
queue_name or resolve_worker_queue(context.get("trigger_source")),
_queue_label(queue_name or resolve_worker_queue(context.get("trigger_source"))),
context["filename"],
)
return leaudit_process_document(
document_id=context["document_id"],
file_content=context["file_content"],
filename=context["filename"],
upload_info={
"run_id": run_id,
"rule_version_id": context["rule_version_id"],
"rule_source_oss_url": context["rule_source_oss_url"],
"source_type": context["source_type"],
"source_path": context["source_path"],
"trigger_source": context["trigger_source"],
},
rules_path=rules_path,
)
finally:
loop.close()
@celery_app.task(
bind=True,
name="leaudit.process_document",
acks_late=True,
)
def leaudit_process_document_task(self, run_id: int, rules_path: str | None = None) -> dict[str, Any]:
"""Celery worker 入口 —— 按 runId 执行评查。"""
delivery_info = getattr(self.request, "delivery_info", {}) or {}
queue_name = delivery_info.get("routing_key") or delivery_info.get("queue")
return leaudit_process_document_by_run(
run_id=run_id,
task_id=self.request.id,
rules_path=rules_path,
queue_name=queue_name,
)
# type_id → rules directory mapping (only fixed-mapping types)
# 行政许可 (type_id=2) has 9 sub-types, NOT mapped here —
# must come from document metadata (rules_file_path) or content classification.
@@ -326,6 +393,69 @@ async def _update_run_phase_safe(run_id: int, phase: str | None) -> None:
pass
async def _claim_run_safe(run_id: int, task_id: str | None) -> bool:
"""原子抢占 queued/pending 运行,避免重复消费。"""
try:
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from sqlalchemy import text as sa_text
async with GetAsyncSession() as session:
result = await session.execute(
sa_text(
"""
UPDATE leaudit_audit_runs
SET status = 'running',
phase = 'prepare',
task_id = COALESCE(:task_id, task_id),
started_at = COALESCE(started_at, now()),
updated_at = now()
WHERE id = :rid
AND status IN ('queued', 'pending', 'retrying')
RETURNING id
"""
),
{"rid": run_id, "task_id": task_id},
)
row = result.fetchone()
await session.commit()
return row is not None
except Exception:
log.exception("run_id=%s 抢占执行权失败", run_id)
return False
async def _load_run_context(run_id: int) -> dict[str, Any]:
"""按 runId 加载执行所需文档文件上下文。"""
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
async with GetAsyncSession() as session:
run = await session.get(LeauditAuditRun, run_id)
if not run:
raise ValueError(f"未找到 run_id={run_id} 对应的运行记录")
document = await session.get(LeauditDocument, run.documentId)
if not document:
raise ValueError(f"未找到 document_id={run.documentId} 对应的文档记录")
document_file = await session.get(LeauditDocumentFile, run.documentFileId)
if not document_file:
raise ValueError(f"未找到 document_file_id={run.documentFileId} 对应的文件记录")
resolver = FileSourceResolver()
payload = await resolver.ResolvePayload(document_file)
return {
"document_id": document.Id,
"filename": payload.fileName,
"file_content": payload.fileContent,
"source_type": payload.sourceType,
"source_path": payload.sourcePath,
"rule_version_id": run.ruleVersionId,
"rule_source_oss_url": run.ruleSourceOssUrl,
"trigger_source": run.triggerSource,
}
def _get_suffix(filename: str) -> str:
"""Extract file suffix from filename."""
_, ext = os.path.splitext(filename)
@@ -333,32 +463,37 @@ def _get_suffix(filename: str) -> str:
def dispatch_leaudit_task(
document_id: int,
file_content: bytes,
filename: str,
upload_info: Optional[Dict[str, Any]] = None,
source_port: Optional[int] = None,
run_id: int,
*,
queue_name: str | None = None,
rules_path: Optional[str] = None,
):
"""Dispatch a leaudit processing task.
) -> str:
"""投递 runId 到 Celery worker 队列。"""
target_queue = queue_name or LEAUDIT_WORKER_QUEUE_NORMAL
task = leaudit_process_document_task.apply_async(
kwargs={"run_id": run_id, "rules_path": rules_path},
queue=target_queue,
)
log.info(
"run_id=%s 已投递到 worker 队列: queue=%s, speed=%s, task_id=%s",
run_id,
target_queue,
_queue_label(target_queue),
task.id,
)
return task.id
P2: Celery 集成后改用 leaudit_process_document.apply_async(...)
当前阶段直接同步调用。
"""
kwargs = {
"document_id": document_id,
"file_content": file_content,
"filename": filename,
"upload_info": upload_info,
"source_port": source_port or int(os.getenv("APP_PORT", "8000")),
"rules_path": rules_path,
}
try:
asyncio.get_running_loop()
except RuntimeError:
return leaudit_process_document(**kwargs)
def resolve_worker_queue(trigger_source: str | None) -> str:
"""按触发来源选择 worker 队列。"""
normalized = (trigger_source or "").strip().lower()
if "urgent" in normalized or "high" in normalized:
return LEAUDIT_WORKER_QUEUE_URGENT
return LEAUDIT_WORKER_QUEUE_NORMAL
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(leaudit_process_document, **kwargs)
return future.result()
def _queue_label(queue_name: str | None) -> str:
"""Map queue name to a user-facing speed label for logs."""
if queue_name == LEAUDIT_WORKER_QUEUE_URGENT:
return "urgent"
return "normal"
@@ -9,7 +9,13 @@ class IAuditService(ABC):
"""评查服务接口。"""
@abstractmethod
async def Run(self, DocumentId: int, RuleType: str | None = None, Force: bool = False) -> AuditRunVO:
async def Run(
self,
DocumentId: int,
RuleType: str | None = None,
Force: bool = False,
Speed: str = "normal",
) -> AuditRunVO:
"""触发文档评查。"""
...
@@ -21,8 +21,10 @@ from fastapi_modules.fastapi_leaudit.domian.vo.auditVo import (
AuditRunErrorVO,
AuditRunVO,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.fileSourceResolver import FileSourceResolver
from fastapi_modules.fastapi_leaudit.leaudit_bridge.tasks import dispatch_leaudit_task
from fastapi_modules.fastapi_leaudit.leaudit_bridge.tasks import (
dispatch_leaudit_task,
resolve_worker_queue,
)
from fastapi_modules.fastapi_leaudit.models import (
LeauditAuditRun,
LeauditDocument,
@@ -31,20 +33,67 @@ from fastapi_modules.fastapi_leaudit.models import (
from fastapi_modules.fastapi_leaudit.services import IAuditService
def _normalize_speed(speed: str | None) -> str:
"""Normalize front-end speed selection to urgent/normal."""
normalized = (speed or "").strip().lower()
if normalized in {"urgent", "high", "fast", "emergency", "紧急"}:
return "urgent"
return "normal"
class AuditServiceImpl(IAuditService):
"""评查服务实现。"""
async def Run(self, DocumentId: int, RuleType: str | None = None, Force: bool = False) -> AuditRunVO:
async def Run(
self,
DocumentId: int,
RuleType: str | None = None,
Force: bool = False,
Speed: str = "normal",
) -> AuditRunVO:
"""触发文档评查。
当前阶段同步触发 bridge 执行链,后续再切换为 Celery 异步分发
当前阶段只负责创建 run 并投递 worker,不在 HTTP 请求内同步执行
"""
async with GetAsyncSession() as session:
logger.info(f"触发评查: documentId={DocumentId}, ruleType={RuleType}")
normalizedSpeed = _normalize_speed(Speed)
document = await session.get(LeauditDocument, DocumentId)
if not document:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查文档不存在")
if not Force:
activeRunResult = await session.execute(
select(LeauditAuditRun)
.where(
LeauditAuditRun.documentId == DocumentId,
LeauditAuditRun.status.in_(("queued", "running", "retrying")),
)
.order_by(LeauditAuditRun.Id.desc())
.limit(1)
)
activeRun = activeRunResult.scalar_one_or_none()
if activeRun:
return AuditRunVO(
runId=activeRun.Id,
documentId=activeRun.documentId,
runNo=activeRun.runNo,
documentFileId=activeRun.documentFileId,
status=activeRun.status,
phase=activeRun.phase,
resultStatus=activeRun.resultStatus,
ruleSetId=activeRun.ruleSetId,
ruleVersionId=activeRun.ruleVersionId,
ruleTypeId=activeRun.ruleTypeId,
rescueApplied=activeRun.rescueApplied or False,
totalScore=float(activeRun.totalScore) if activeRun.totalScore else None,
passedCount=activeRun.passedCount,
failedCount=activeRun.failedCount,
skippedCount=activeRun.skippedCount,
startedAt=activeRun.startedAt,
finishedAt=activeRun.finishedAt,
)
fileResult = await session.execute(
select(LeauditDocumentFile)
.where(
@@ -91,49 +140,48 @@ class AuditServiceImpl(IAuditService):
if not binding or not binding["rule_set_id"] or not binding["rule_version_id"]:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前文档类型未绑定可用规则版本")
triggerSource = f"{'retry' if Force else 'upload'}:{normalizedSpeed}"
run = LeauditAuditRun(
documentId=DocumentId,
documentFileId=documentFile.Id,
runNo=int(latestRunNo) + 1,
triggerSource="manual" if not Force else "retry",
status="pending",
triggerSource=triggerSource,
status="queued",
phase="dispatch",
ruleSetId=int(binding["rule_set_id"]),
ruleVersionId=int(binding["rule_version_id"]),
ruleTypeId=binding["rule_type_id"],
ruleSourceOssUrl=binding["rule_source_oss_url"],
ruleSourceSha256=binding["rule_source_sha256"],
startedAt=datetime.now(),
)
session.add(run)
await session.flush()
document.currentRunId = run.Id
document.processingStatus = "running"
document.processingStatus = "queued"
await session.commit()
await session.refresh(run)
try:
Resolver = FileSourceResolver()
Payload = await Resolver.ResolvePayload(documentFile)
taskId = dispatch_leaudit_task(
run_id=run.Id,
queue_name=resolve_worker_queue(triggerSource),
rules_path=RuleType,
)
except Exception as Error:
run.status = "failed"
run.phase = "dispatch"
run.finishedAt = datetime.now()
document.processingStatus = "failed"
await session.commit()
raise LeauditException(
StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
f"读取评查文件失败: {Error}",
f"投递评查任务失败: {Error}",
) from Error
dispatch_leaudit_task(
document_id=DocumentId,
file_content=Payload.fileContent,
filename=Payload.fileName,
upload_info={
"run_id": run.Id,
"rule_version_id": run.ruleVersionId,
"rule_source_oss_url": run.ruleSourceOssUrl,
"source_type": Payload.sourceType,
"source_path": Payload.sourcePath,
},
rules_path=RuleType,
)
run.taskId = taskId
await session.commit()
await session.refresh(run)
return AuditRunVO(