feat: add async worker queues and retry controls

This commit is contained in:
wren
2026-04-29 11:48:09 +08:00
parent e738398eb6
commit f3b83c9979
16 changed files with 1316 additions and 96 deletions
@@ -3,7 +3,6 @@
from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
import os
from pathlib import Path
import tempfile
@@ -12,7 +11,12 @@ from typing import Any, Dict, Optional
from fastapi_common.fastapi_common_logger import logger
from fastapi_admin.config import LEAUDIT_RULES_DIR
from fastapi_admin.celery_app import celery_app
from fastapi_admin.config import (
LEAUDIT_RULES_DIR,
LEAUDIT_WORKER_QUEUE_NORMAL,
LEAUDIT_WORKER_QUEUE_URGENT,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import (
NativeRunRequest,
NativeRunner,
@@ -20,19 +24,23 @@ from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import (
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import (
NativeAuditMetadata,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.fileSourceResolver import (
FileSourceResolver,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ruleVersionResolver import (
RuleVersionResolver,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
# Celery 集成待 P2 阶段实现,当前使用同步占位
# from core.celery_app_limited import celery_app
from fastapi_modules.fastapi_leaudit.models import (
LeauditAuditRun,
LeauditDocument,
LeauditDocumentFile,
)
log = logger
# P2: Celery 集成后启用 @celery_app.task 装饰器
def leaudit_process_document(
document_id: int,
file_content: bytes,
@@ -160,6 +168,65 @@ def leaudit_process_document(
loop.close()
def leaudit_process_document_by_run(
run_id: int,
*,
task_id: str | None = None,
rules_path: str | None = None,
queue_name: str | None = None,
) -> dict[str, Any]:
"""按 runId 加载执行上下文并执行原生 leaudit。"""
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
claimed = loop.run_until_complete(_claim_run_safe(run_id, task_id))
if not claimed:
log.warning("run_id=%s 未抢占成功,跳过重复消费", run_id)
return {"status": "skipped", "run_id": run_id, "reason": "already_claimed"}
context = loop.run_until_complete(_load_run_context(run_id))
log.info(
"run_id=%s worker开始执行: queue=%s, speed=%s, filename=%s",
run_id,
queue_name or resolve_worker_queue(context.get("trigger_source")),
_queue_label(queue_name or resolve_worker_queue(context.get("trigger_source"))),
context["filename"],
)
return leaudit_process_document(
document_id=context["document_id"],
file_content=context["file_content"],
filename=context["filename"],
upload_info={
"run_id": run_id,
"rule_version_id": context["rule_version_id"],
"rule_source_oss_url": context["rule_source_oss_url"],
"source_type": context["source_type"],
"source_path": context["source_path"],
"trigger_source": context["trigger_source"],
},
rules_path=rules_path,
)
finally:
loop.close()
@celery_app.task(
bind=True,
name="leaudit.process_document",
acks_late=True,
)
def leaudit_process_document_task(self, run_id: int, rules_path: str | None = None) -> dict[str, Any]:
"""Celery worker 入口 —— 按 runId 执行评查。"""
delivery_info = getattr(self.request, "delivery_info", {}) or {}
queue_name = delivery_info.get("routing_key") or delivery_info.get("queue")
return leaudit_process_document_by_run(
run_id=run_id,
task_id=self.request.id,
rules_path=rules_path,
queue_name=queue_name,
)
# type_id → rules directory mapping (only fixed-mapping types)
# 行政许可 (type_id=2) has 9 sub-types, NOT mapped here —
# must come from document metadata (rules_file_path) or content classification.
@@ -326,6 +393,69 @@ async def _update_run_phase_safe(run_id: int, phase: str | None) -> None:
pass
async def _claim_run_safe(run_id: int, task_id: str | None) -> bool:
"""原子抢占 queued/pending 运行,避免重复消费。"""
try:
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from sqlalchemy import text as sa_text
async with GetAsyncSession() as session:
result = await session.execute(
sa_text(
"""
UPDATE leaudit_audit_runs
SET status = 'running',
phase = 'prepare',
task_id = COALESCE(:task_id, task_id),
started_at = COALESCE(started_at, now()),
updated_at = now()
WHERE id = :rid
AND status IN ('queued', 'pending', 'retrying')
RETURNING id
"""
),
{"rid": run_id, "task_id": task_id},
)
row = result.fetchone()
await session.commit()
return row is not None
except Exception:
log.exception("run_id=%s 抢占执行权失败", run_id)
return False
async def _load_run_context(run_id: int) -> dict[str, Any]:
"""按 runId 加载执行所需文档文件上下文。"""
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
async with GetAsyncSession() as session:
run = await session.get(LeauditAuditRun, run_id)
if not run:
raise ValueError(f"未找到 run_id={run_id} 对应的运行记录")
document = await session.get(LeauditDocument, run.documentId)
if not document:
raise ValueError(f"未找到 document_id={run.documentId} 对应的文档记录")
document_file = await session.get(LeauditDocumentFile, run.documentFileId)
if not document_file:
raise ValueError(f"未找到 document_file_id={run.documentFileId} 对应的文件记录")
resolver = FileSourceResolver()
payload = await resolver.ResolvePayload(document_file)
return {
"document_id": document.Id,
"filename": payload.fileName,
"file_content": payload.fileContent,
"source_type": payload.sourceType,
"source_path": payload.sourcePath,
"rule_version_id": run.ruleVersionId,
"rule_source_oss_url": run.ruleSourceOssUrl,
"trigger_source": run.triggerSource,
}
def _get_suffix(filename: str) -> str:
"""Extract file suffix from filename."""
_, ext = os.path.splitext(filename)
@@ -333,32 +463,37 @@ def _get_suffix(filename: str) -> str:
def dispatch_leaudit_task(
document_id: int,
file_content: bytes,
filename: str,
upload_info: Optional[Dict[str, Any]] = None,
source_port: Optional[int] = None,
run_id: int,
*,
queue_name: str | None = None,
rules_path: Optional[str] = None,
):
"""Dispatch a leaudit processing task.
) -> str:
"""投递 runId 到 Celery worker 队列。"""
target_queue = queue_name or LEAUDIT_WORKER_QUEUE_NORMAL
task = leaudit_process_document_task.apply_async(
kwargs={"run_id": run_id, "rules_path": rules_path},
queue=target_queue,
)
log.info(
"run_id=%s 已投递到 worker 队列: queue=%s, speed=%s, task_id=%s",
run_id,
target_queue,
_queue_label(target_queue),
task.id,
)
return task.id
P2: Celery 集成后改用 leaudit_process_document.apply_async(...)
当前阶段直接同步调用。
"""
kwargs = {
"document_id": document_id,
"file_content": file_content,
"filename": filename,
"upload_info": upload_info,
"source_port": source_port or int(os.getenv("APP_PORT", "8000")),
"rules_path": rules_path,
}
try:
asyncio.get_running_loop()
except RuntimeError:
return leaudit_process_document(**kwargs)
def resolve_worker_queue(trigger_source: str | None) -> str:
"""按触发来源选择 worker 队列。"""
normalized = (trigger_source or "").strip().lower()
if "urgent" in normalized or "high" in normalized:
return LEAUDIT_WORKER_QUEUE_URGENT
return LEAUDIT_WORKER_QUEUE_NORMAL
with ThreadPoolExecutor(max_workers=1) as executor:
future = executor.submit(leaudit_process_document, **kwargs)
return future.result()
def _queue_label(queue_name: str | None) -> str:
"""Map queue name to a user-facing speed label for logs."""
if queue_name == LEAUDIT_WORKER_QUEUE_URGENT:
return "urgent"
return "normal"