"""LeAudit 任务入口。""" from __future__ import annotations import asyncio import os from pathlib import Path import tempfile import time from typing import Any, Dict, Optional from fastapi_common.fastapi_common_logger import logger from fastapi_admin.celery_app import celery_app from fastapi_admin.config import ( LEAUDIT_RULES_DIR, LEAUDIT_WORKER_QUEUE_NORMAL, LEAUDIT_WORKER_QUEUE_URGENT, ) from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import ( NativeRunRequest, NativeRunner, ) from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import ( NativeAuditMetadata, ) from fastapi_modules.fastapi_leaudit.leaudit_bridge.fileSourceResolver import ( FileSourceResolver, ) from fastapi_modules.fastapi_leaudit.leaudit_bridge.ruleVersionResolver import ( RuleVersionResolver, ) from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter from fastapi_modules.fastapi_leaudit.models import ( LeauditAuditRun, LeauditDocument, LeauditDocumentFile, ) log = logger def leaudit_process_document( document_id: int, file_content: bytes, filename: str, upload_info: Optional[Dict[str, Any]] = None, source_port: Optional[int] = None, rules_path: Optional[str] = None, ): """处理单个文档的 LeAudit 任务。""" task_id = os.urandom(8).hex() log.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}") # 新平台:region 通过参数传递,不再依赖 os.environ 切换 if source_port: log.info(f"[任务ID: {task_id}] 来源端口: {source_port}") if upload_info is None: upload_info = {} loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) temp_paths: list[str] = [] storage = StorageAdapter() try: run_id = _resolve_run_id(document_id, upload_info, loop) rules_resolution = _resolve_rules_runtime(document_id, run_id, rules_path, loop) loop.run_until_complete(_update_run_status_safe(run_id, "running")) rules_path_resolved = rules_resolution["rules_path"] rules_file = None if rules_path_resolved: loader = RulesLoader() rules_file = loader.load(rules_path_resolved) temp_rule_path = rules_resolution.get("temp_rule_path") if isinstance(temp_rule_path, str): temp_paths.append(temp_rule_path) log.info( f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} " f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)" ) else: log.info( f"[任务ID: {task_id}] No fixed rules_path — " "will classify from document content after OCR" ) suffix = _get_suffix(filename) with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp: temp.write(file_content) temp_path = temp.name temp_paths.append(temp_path) runner = NativeRunner() t0 = time.time() native_result = loop.run_until_complete( runner.run( NativeRunRequest( metadata=NativeAuditMetadata( run_id=run_id, document_id=document_id, rule_version_id=_optional_int(upload_info, "rule_version_id", "ruleVersionId"), extras={"taskId": task_id}, ), local_file_path=temp_path, rules_file=rules_file, rule_source_path=rules_path_resolved, ) ) ) loop.run_until_complete(runner.persist_result(native_result)) elapsed = round(time.time() - t0, 2) ctx = native_result.ctx loop.run_until_complete(_update_run_phase_safe(run_id, ctx.phase)) loop.run_until_complete(_update_run_status_safe(run_id, "completed")) loop.run_until_complete(_update_status_safe(document_id, "completed")) log.info( f"[任务ID: {task_id}] leaudit管线完成: phase={ctx.phase}, " f"timing={dict(ctx.timing)}, 总耗时={elapsed:.1f}s" ) return { "status": "success", "document_id": document_id, "run_id": run_id, "phase": ctx.phase, "timing": dict(ctx.timing), "errors": list(ctx.extraction.all_errors) if ctx.extraction is not None else list(ctx.extraction_errors), } except Exception as e: log.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True) try: loop.run_until_complete(_update_status_safe(document_id, "failed")) if 'run_id' in locals(): failed_phase = "persist" if "native_result" in locals(): failed_phase = native_result.ctx.phase or failed_phase loop.run_until_complete( storage.fail_run( document_id, run_id=run_id, phase=failed_phase, message=str(e), detail_json={ "taskId": task_id, "filename": filename, "errorType": type(e).__name__, }, ) ) except Exception: pass raise finally: for temp_path in temp_paths: try: if Path(temp_path).exists(): os.remove(temp_path) except OSError: pass loop.close() def leaudit_process_document_by_run( run_id: int, *, task_id: str | None = None, rules_path: str | None = None, queue_name: str | None = None, ) -> dict[str, Any]: """按 runId 加载执行上下文并执行原生 leaudit。""" loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: claimed = loop.run_until_complete(_claim_run_safe(run_id, task_id)) if not claimed: log.warning("run_id=%s 未抢占成功,跳过重复消费", run_id) return {"status": "skipped", "run_id": run_id, "reason": "already_claimed"} context = loop.run_until_complete(_load_run_context(run_id)) log.info( "run_id=%s worker开始执行: queue=%s, speed=%s, filename=%s", run_id, queue_name or resolve_worker_queue(context.get("trigger_source")), _queue_label(queue_name or resolve_worker_queue(context.get("trigger_source"))), context["filename"], ) return leaudit_process_document( document_id=context["document_id"], file_content=context["file_content"], filename=context["filename"], upload_info={ "run_id": run_id, "rule_version_id": context["rule_version_id"], "rule_source_oss_url": context["rule_source_oss_url"], "source_type": context["source_type"], "source_path": context["source_path"], "trigger_source": context["trigger_source"], }, rules_path=rules_path, ) finally: loop.close() @celery_app.task( bind=True, name="leaudit.process_document", acks_late=True, ) def leaudit_process_document_task(self, run_id: int, rules_path: str | None = None) -> dict[str, Any]: """Celery worker 入口 —— 按 runId 执行评查。""" delivery_info = getattr(self.request, "delivery_info", {}) or {} queue_name = delivery_info.get("routing_key") or delivery_info.get("queue") return leaudit_process_document_by_run( run_id=run_id, task_id=self.request.id, rules_path=rules_path, queue_name=queue_name, ) # type_id → rules directory mapping (only fixed-mapping types) # 行政许可 (type_id=2) has 9 sub-types, NOT mapped here — # must come from document metadata (rules_file_path) or content classification. _TYPE_ID_RULES_MAP: dict[int, str] = { 3: "行政处罚", } def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None: """Resolve rules_path: config override → document metadata → type_id mapping.""" from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from sqlalchemy import text as sa_text # 1. Config override (when explicitly set in app.toml) if LEAUDIT_RULES_DIR: return LEAUDIT_RULES_DIR try: async def _fetch(): async with GetAsyncSession() as session: result = await session.execute( sa_text("SELECT type_id FROM leaudit_documents WHERE id = :did"), {"did": document_id}, ) row = result.fetchone() if row and row[0] and row[0] in _TYPE_ID_RULES_MAP: return f"{_TYPE_ID_RULES_MAP[row[0]]}/rules.yaml" return None return loop.run_until_complete(_fetch()) except Exception as e: log.warning(f"Failed to resolve rules_path from document: {e}") return None def _resolve_rules_runtime( document_id: int, run_id: int, explicit_rules_path: str | None, loop: asyncio.AbstractEventLoop, ) -> dict[str, str | None]: """解析本次执行使用的规则来源。""" if explicit_rules_path: return { "rules_path": explicit_rules_path, "temp_rule_path": None, "source_type": "explicit", "source_path": explicit_rules_path, } resolver = RuleVersionResolver() try: payload = loop.run_until_complete(resolver.ResolveForRun(run_id)) if payload: log.info( f"run_id={run_id} 规则来源已解析: sourceType={payload.sourceType}, " f"sourcePath={payload.sourcePath}, localPath={payload.localPath}" ) return { "rules_path": payload.localPath, "temp_rule_path": payload.localPath if payload.sourceType == "oss" else None, "source_type": payload.sourceType, "source_path": payload.sourcePath, } except Exception as e: log.warning(f"Failed to resolve rule version from run: run_id={run_id}, error={e}") fallback_rules_path = _resolve_rules_path(document_id, loop) return { "rules_path": fallback_rules_path, "temp_rule_path": None, "source_type": "legacy_fallback", "source_path": fallback_rules_path, } def _resolve_run_id( document_id: int, upload_info: Dict[str, Any] | None, loop: asyncio.AbstractEventLoop, ) -> int: """解析本次任务对应的运行 ID。""" if upload_info: for key in ("run_id", "runId"): value = upload_info.get(key) if isinstance(value, int): return value from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from sqlalchemy import text as sa_text async def _fetch() -> int: async with GetAsyncSession() as session: result = await session.execute( sa_text("SELECT id FROM leaudit_audit_runs WHERE document_id = :did ORDER BY id DESC LIMIT 1"), {"did": document_id}, ) row = result.fetchone() if not row: raise ValueError(f"未找到 document_id={document_id} 对应的 run 记录") return int(row[0]) return loop.run_until_complete(_fetch()) def _optional_int(payload: Dict[str, Any] | None, *keys: str) -> int | None: """从字典中按顺序取可用整数。""" if not payload: return None for key in keys: value = payload.get(key) if isinstance(value, int): return value return None async def _update_status_safe(document_id: int, status: str) -> None: """Safely update document status via SQLAlchemy, ignoring errors.""" try: from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from sqlalchemy import text as sa_text async with GetAsyncSession() as session: await session.execute( sa_text("UPDATE leaudit_documents SET processing_status = :s, updated_at = now() WHERE id = :did"), {"s": status, "did": document_id}, ) await session.commit() except Exception: pass async def _update_run_status_safe(run_id: int, status: str) -> None: """安全更新运行状态。""" try: from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from sqlalchemy import text as sa_text async with GetAsyncSession() as session: await session.execute( sa_text("UPDATE leaudit_audit_runs SET status = :s, updated_at = now() WHERE id = :rid"), {"s": status, "rid": run_id}, ) await session.commit() except Exception: pass async def _update_run_phase_safe(run_id: int, phase: str | None) -> None: """安全更新运行阶段。""" try: from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from sqlalchemy import text as sa_text async with GetAsyncSession() as session: await session.execute( sa_text("UPDATE leaudit_audit_runs SET phase = :p, updated_at = now() WHERE id = :rid"), {"p": phase, "rid": run_id}, ) await session.commit() except Exception: pass async def _claim_run_safe(run_id: int, task_id: str | None) -> bool: """原子抢占 queued/pending 运行,避免重复消费。""" try: from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from sqlalchemy import text as sa_text async with GetAsyncSession() as session: result = await session.execute( sa_text( """ UPDATE leaudit_audit_runs SET status = 'running', phase = 'prepare', task_id = COALESCE(:task_id, task_id), started_at = COALESCE(started_at, now()), updated_at = now() WHERE id = :rid AND status IN ('queued', 'pending', 'retrying') RETURNING id """ ), {"rid": run_id, "task_id": task_id}, ) row = result.fetchone() await session.commit() return row is not None except Exception: log.exception("run_id=%s 抢占执行权失败", run_id) return False async def _load_run_context(run_id: int) -> dict[str, Any]: """按 runId 加载执行所需文档文件上下文。""" from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession async with GetAsyncSession() as session: run = await session.get(LeauditAuditRun, run_id) if not run: raise ValueError(f"未找到 run_id={run_id} 对应的运行记录") document = await session.get(LeauditDocument, run.documentId) if not document: raise ValueError(f"未找到 document_id={run.documentId} 对应的文档记录") document_file = await session.get(LeauditDocumentFile, run.documentFileId) if not document_file: raise ValueError(f"未找到 document_file_id={run.documentFileId} 对应的文件记录") resolver = FileSourceResolver() payload = await resolver.ResolvePayload(document_file) return { "document_id": document.Id, "filename": payload.fileName, "file_content": payload.fileContent, "source_type": payload.sourceType, "source_path": payload.sourcePath, "rule_version_id": run.ruleVersionId, "rule_source_oss_url": run.ruleSourceOssUrl, "trigger_source": run.triggerSource, } def _get_suffix(filename: str) -> str: """Extract file suffix from filename.""" _, ext = os.path.splitext(filename) return ext if ext else ".pdf" def dispatch_leaudit_task( run_id: int, *, queue_name: str | None = None, rules_path: Optional[str] = None, ) -> str: """投递 runId 到 Celery worker 队列。""" target_queue = queue_name or LEAUDIT_WORKER_QUEUE_NORMAL task = leaudit_process_document_task.apply_async( kwargs={"run_id": run_id, "rules_path": rules_path}, queue=target_queue, ) log.info( "run_id=%s 已投递到 worker 队列: queue=%s, speed=%s, task_id=%s", run_id, target_queue, _queue_label(target_queue), task.id, ) return task.id def resolve_worker_queue(trigger_source: str | None) -> str: """按触发来源选择 worker 队列。""" normalized = (trigger_source or "").strip().lower() if "urgent" in normalized or "high" in normalized: return LEAUDIT_WORKER_QUEUE_URGENT return LEAUDIT_WORKER_QUEUE_NORMAL def _queue_label(queue_name: str | None) -> str: """Map queue name to a user-facing speed label for logs.""" if queue_name == LEAUDIT_WORKER_QUEUE_URGENT: return "urgent" return "normal"