feat: add async worker queues and retry controls
This commit is contained in:
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
@@ -12,7 +11,12 @@ from typing import Any, Dict, Optional
|
||||
|
||||
from fastapi_common.fastapi_common_logger import logger
|
||||
|
||||
from fastapi_admin.config import LEAUDIT_RULES_DIR
|
||||
from fastapi_admin.celery_app import celery_app
|
||||
from fastapi_admin.config import (
|
||||
LEAUDIT_RULES_DIR,
|
||||
LEAUDIT_WORKER_QUEUE_NORMAL,
|
||||
LEAUDIT_WORKER_QUEUE_URGENT,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import (
|
||||
NativeRunRequest,
|
||||
NativeRunner,
|
||||
@@ -20,19 +24,23 @@ from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import (
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import (
|
||||
NativeAuditMetadata,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.fileSourceResolver import (
|
||||
FileSourceResolver,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ruleVersionResolver import (
|
||||
RuleVersionResolver,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
|
||||
|
||||
# Celery 集成待 P2 阶段实现,当前使用同步占位
|
||||
# from core.celery_app_limited import celery_app
|
||||
from fastapi_modules.fastapi_leaudit.models import (
|
||||
LeauditAuditRun,
|
||||
LeauditDocument,
|
||||
LeauditDocumentFile,
|
||||
)
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
# P2: Celery 集成后启用 @celery_app.task 装饰器
|
||||
def leaudit_process_document(
|
||||
document_id: int,
|
||||
file_content: bytes,
|
||||
@@ -160,6 +168,65 @@ def leaudit_process_document(
|
||||
loop.close()
|
||||
|
||||
|
||||
def leaudit_process_document_by_run(
    run_id: int,
    *,
    task_id: str | None = None,
    rules_path: str | None = None,
    queue_name: str | None = None,
) -> dict[str, Any]:
    """Claim a run, load its execution context, and run the native leaudit pipeline.

    Args:
        run_id: Primary key of the audit run to execute.
        task_id: Worker task identifier passed to ``_claim_run_safe``.
        rules_path: Rules location forwarded unchanged to
            ``leaudit_process_document``.
        queue_name: Queue the task arrived on. When falsy, the queue is derived
            from the run's trigger source; it is only used for logging here.

    Returns:
        Whatever ``leaudit_process_document`` returns, or a
        ``{"status": "skipped", ...}`` dict when ``_claim_run_safe`` reports
        that the run was not claimed.

    Raises:
        Any exception raised while loading the run context or while processing
        the document propagates to the caller.
    """
    # A dedicated event loop drives the async helpers from this sync entry point.
    worker_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(worker_loop)
    try:
        if not worker_loop.run_until_complete(_claim_run_safe(run_id, task_id)):
            log.warning("run_id=%s 未抢占成功,跳过重复消费", run_id)
            return {"status": "skipped", "run_id": run_id, "reason": "already_claimed"}

        ctx = worker_loop.run_until_complete(_load_run_context(run_id))

        # Resolve the queue once; both the queue and its speed label are logged.
        effective_queue = queue_name or resolve_worker_queue(ctx.get("trigger_source"))
        log.info(
            "run_id=%s worker开始执行: queue=%s, speed=%s, filename=%s",
            run_id,
            effective_queue,
            _queue_label(effective_queue),
            ctx["filename"],
        )

        # Context fields handed to the processor alongside the run id.
        forwarded_keys = (
            "rule_version_id",
            "rule_source_oss_url",
            "source_type",
            "source_path",
            "trigger_source",
        )
        upload_info: dict[str, Any] = {"run_id": run_id}
        upload_info.update((key, ctx[key]) for key in forwarded_keys)

        return leaudit_process_document(
            document_id=ctx["document_id"],
            file_content=ctx["file_content"],
            filename=ctx["filename"],
            upload_info=upload_info,
            rules_path=rules_path,
        )
    finally:
        worker_loop.close()
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="leaudit.process_document", acks_late=True)
def leaudit_process_document_task(self, run_id: int, rules_path: str | None = None) -> dict[str, Any]:
    """Celery worker entry point: execute the audit identified by ``run_id``.

    Args:
        run_id: Primary key of the audit run to execute.
        rules_path: Optional rules location forwarded to the runner.

    Returns:
        The result dict produced by ``leaudit_process_document_by_run``.
    """
    request = self.request
    # delivery_info may be missing or None on the request; fall back to an empty dict.
    routing = getattr(request, "delivery_info", None) or {}

    # Prefer the routing key; fall back to the "queue" entry when it is absent or empty.
    source_queue = routing.get("routing_key")
    if not source_queue:
        source_queue = routing.get("queue")

    return leaudit_process_document_by_run(
        run_id=run_id,
        task_id=request.id,
        rules_path=rules_path,
        queue_name=source_queue,
    )
|
||||
|
||||
|
||||
# type_id → rules directory mapping (only fixed-mapping types)
|
||||
# 行政许可 (type_id=2) has 9 sub-types, NOT mapped here —
|
||||
# must come from document metadata (rules_file_path) or content classification.
|
||||
@@ -326,6 +393,69 @@ async def _update_run_phase_safe(run_id: int, phase: str | None) -> None:
|
||||
pass
|
||||
|
||||
|
||||
async def _claim_run_safe(run_id: int, task_id: str | None) -> bool:
    """Atomically claim a queued/pending/retrying run so it is not consumed twice.

    Issues a single conditional ``UPDATE ... RETURNING`` that moves the run to
    ``running`` / phase ``prepare`` only if its status is still one of
    ``queued``, ``pending`` or ``retrying``.

    Args:
        run_id: Primary key of the run to claim.
        task_id: Worker task id to store on the run; when ``None`` the stored
            task id is left unchanged (``COALESCE``).

    Returns:
        ``True`` when a row was updated. ``False`` when no row matched, and also
        when any exception occurred (the exception is logged, not raised).
    """
    claim_sql = """
                    UPDATE leaudit_audit_runs
                    SET status = 'running',
                        phase = 'prepare',
                        task_id = COALESCE(:task_id, task_id),
                        started_at = COALESCE(started_at, now()),
                        updated_at = now()
                    WHERE id = :rid
                      AND status IN ('queued', 'pending', 'retrying')
                    RETURNING id
                    """
    try:
        # Imported inside the guarded section so import failures are also
        # reported as "not claimed" rather than raised.
        from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
        from sqlalchemy import text as sa_text

        bind_params = {"rid": run_id, "task_id": task_id}
        async with GetAsyncSession() as db:
            outcome = await db.execute(sa_text(claim_sql), bind_params)
            claimed_row = outcome.fetchone()
            await db.commit()
        return claimed_row is not None
    except Exception:
        log.exception("run_id=%s 抢占执行权失败", run_id)
        return False
|
||||
|
||||
|
||||
async def _load_run_context(run_id: int) -> dict[str, Any]:
    """Load the document/file context needed to execute the run ``run_id``.

    Looks up the run, its document and its document file, then resolves the
    file payload through ``FileSourceResolver``.

    Args:
        run_id: Primary key of the audit run.

    Returns:
        A dict with the keys ``document_id``, ``filename``, ``file_content``,
        ``source_type``, ``source_path``, ``rule_version_id``,
        ``rule_source_oss_url`` and ``trigger_source``.

    Raises:
        ValueError: If the run, its document, or its document file is not found.
    """
    from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession

    async with GetAsyncSession() as db:
        audit_run = await db.get(LeauditAuditRun, run_id)
        if not audit_run:
            raise ValueError(f"未找到 run_id={run_id} 对应的运行记录")

        doc = await db.get(LeauditDocument, audit_run.documentId)
        if not doc:
            raise ValueError(f"未找到 document_id={audit_run.documentId} 对应的文档记录")

        doc_file = await db.get(LeauditDocumentFile, audit_run.documentFileId)
        if not doc_file:
            raise ValueError(f"未找到 document_file_id={audit_run.documentFileId} 对应的文件记录")

        # NOTE(review): the payload is resolved while the session is still open;
        # confirm this matches the intended session scope.
        resolved = await FileSourceResolver().ResolvePayload(doc_file)

        return dict(
            document_id=doc.Id,
            filename=resolved.fileName,
            file_content=resolved.fileContent,
            source_type=resolved.sourceType,
            source_path=resolved.sourcePath,
            rule_version_id=audit_run.ruleVersionId,
            rule_source_oss_url=audit_run.ruleSourceOssUrl,
            trigger_source=audit_run.triggerSource,
        )
|
||||
|
||||
|
||||
def _get_suffix(filename: str) -> str:
|
||||
"""Extract file suffix from filename."""
|
||||
_, ext = os.path.splitext(filename)
|
||||
@@ -333,32 +463,37 @@ def _get_suffix(filename: str) -> str:
|
||||
|
||||
|
||||
def dispatch_leaudit_task(
    run_id: int,
    *,
    queue_name: str | None = None,
    rules_path: Optional[str] = None,
) -> str:
    """Enqueue a run id on a Celery worker queue and return the Celery task id.

    Args:
        run_id: Primary key of the audit run to process.
        queue_name: Target queue; when falsy, ``LEAUDIT_WORKER_QUEUE_NORMAL``
            is used.
        rules_path: Optional rules location passed through to the worker task.

    Returns:
        The id of the task created by ``apply_async``.
    """
    destination = queue_name if queue_name else LEAUDIT_WORKER_QUEUE_NORMAL
    task_kwargs = {"run_id": run_id, "rules_path": rules_path}

    async_result = leaudit_process_document_task.apply_async(
        kwargs=task_kwargs,
        queue=destination,
    )
    enqueued_id = async_result.id

    log.info(
        "run_id=%s 已投递到 worker 队列: queue=%s, speed=%s, task_id=%s",
        run_id,
        destination,
        _queue_label(destination),
        enqueued_id,
    )
    return enqueued_id
|
||||
|
||||
P2: Celery 集成后改用 leaudit_process_document.apply_async(...)
|
||||
当前阶段直接同步调用。
|
||||
"""
|
||||
kwargs = {
|
||||
"document_id": document_id,
|
||||
"file_content": file_content,
|
||||
"filename": filename,
|
||||
"upload_info": upload_info,
|
||||
"source_port": source_port or int(os.getenv("APP_PORT", "8000")),
|
||||
"rules_path": rules_path,
|
||||
}
|
||||
|
||||
try:
|
||||
asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
return leaudit_process_document(**kwargs)
|
||||
def resolve_worker_queue(trigger_source: str | None) -> str:
    """Pick the worker queue for a run based on its trigger source.

    Args:
        trigger_source: Free-form trigger source; ``None`` is treated as empty.

    Returns:
        ``LEAUDIT_WORKER_QUEUE_URGENT`` when the trimmed, lower-cased source
        contains ``"urgent"`` or ``"high"``; otherwise
        ``LEAUDIT_WORKER_QUEUE_NORMAL``.
    """
    source_text = "" if not trigger_source else trigger_source
    source_text = source_text.strip().lower()

    # Substring match: any trigger source mentioning one of these is urgent.
    is_urgent = any(marker in source_text for marker in ("urgent", "high"))
    return LEAUDIT_WORKER_QUEUE_URGENT if is_urgent else LEAUDIT_WORKER_QUEUE_NORMAL
|
||||
|
||||
with ThreadPoolExecutor(max_workers=1) as executor:
|
||||
future = executor.submit(leaudit_process_document, **kwargs)
|
||||
return future.result()
|
||||
|
||||
def _queue_label(queue_name: str | None) -> str:
    """Return the speed label used in log lines for ``queue_name``.

    Args:
        queue_name: Queue name, or ``None``.

    Returns:
        ``"urgent"`` when the name equals ``LEAUDIT_WORKER_QUEUE_URGENT``;
        ``"normal"`` for every other value, including ``None``.
    """
    return "urgent" if queue_name == LEAUDIT_WORKER_QUEUE_URGENT else "normal"
|
||||
|
||||
Reference in New Issue
Block a user