"""LeAudit 任务入口。""" from __future__ import annotations import asyncio import os from pathlib import Path import tempfile import time from typing import Any, Dict, Optional from fastapi_common.fastapi_common_logger import logger from fastapi_admin.config import LEAUDIT_RULES_DIR from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import ( NativeRunRequest, NativeRunner, ) from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import ( NativeAuditMetadata, ) from fastapi_modules.fastapi_leaudit.leaudit_bridge.ruleVersionResolver import ( RuleVersionResolver, ) from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter # Celery 集成待 P2 阶段实现,当前使用同步占位 # from core.celery_app_limited import celery_app log = logger # P2: Celery 集成后启用 @celery_app.task 装饰器 def leaudit_process_document( document_id: int, file_content: bytes, filename: str, upload_info: Optional[Dict[str, Any]] = None, source_port: Optional[int] = None, rules_path: Optional[str] = None, ): """处理单个文档的 LeAudit 任务。""" task_id = os.urandom(8).hex() log.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}") # 新平台:region 通过参数传递,不再依赖 os.environ 切换 if source_port: log.info(f"[任务ID: {task_id}] 来源端口: {source_port}") if upload_info is None: upload_info = {} loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) temp_paths: list[str] = [] storage = StorageAdapter() try: run_id = _resolve_run_id(document_id, upload_info, loop) rules_resolution = _resolve_rules_runtime(document_id, run_id, rules_path, loop) loop.run_until_complete(_update_run_status_safe(run_id, "running")) rules_path_resolved = rules_resolution["rules_path"] rules_file = None if rules_path_resolved: loader = RulesLoader() rules_file = loader.load(rules_path_resolved) temp_rule_path = rules_resolution.get("temp_rule_path") if isinstance(temp_rule_path, str): temp_paths.append(temp_rule_path) log.info( f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} " f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)" ) else: log.info( f"[任务ID: {task_id}] No fixed rules_path — " "will classify from document content after OCR" ) suffix = _get_suffix(filename) with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp: temp.write(file_content) temp_path = temp.name temp_paths.append(temp_path) runner = NativeRunner() t0 = time.time() native_result = loop.run_until_complete( runner.run( NativeRunRequest( metadata=NativeAuditMetadata( run_id=run_id, document_id=document_id, rule_version_id=_optional_int(upload_info, "rule_version_id", "ruleVersionId"), extras={"taskId": task_id}, ), local_file_path=temp_path, rules_file=rules_file, rule_source_path=rules_path_resolved, ) ) ) loop.run_until_complete(runner.persist_result(native_result)) elapsed = round(time.time() - t0, 2) ctx = native_result.ctx loop.run_until_complete(_update_run_phase_safe(run_id, ctx.phase)) loop.run_until_complete(_update_run_status_safe(run_id, "completed")) loop.run_until_complete(_update_status_safe(document_id, "completed")) log.info( f"[任务ID: {task_id}] leaudit管线完成: phase={ctx.phase}, " f"timing={dict(ctx.timing)}, 总耗时={elapsed:.1f}s" ) return { "status": "success", "document_id": document_id, "run_id": run_id, "phase": ctx.phase, "timing": dict(ctx.timing), "errors": list(ctx.extraction.all_errors) if ctx.extraction is not None else list(ctx.extraction_errors), } except Exception as e: log.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True) try: loop.run_until_complete(_update_status_safe(document_id, "failed")) if 'run_id' in locals(): failed_phase = "persist" if "native_result" in locals(): failed_phase = native_result.ctx.phase or failed_phase loop.run_until_complete( storage.fail_run( document_id, run_id=run_id, phase=failed_phase, message=str(e), detail_json={ "taskId": task_id, "filename": filename, "errorType": type(e).__name__, }, ) ) except Exception: pass raise finally: for temp_path in temp_paths: try: if Path(temp_path).exists(): os.remove(temp_path) except OSError: pass loop.close() # type_id → rules directory mapping (only fixed-mapping types) # 行政许可 (type_id=2) has 9 sub-types, NOT mapped here — # must come from document metadata (rules_file_path) or content classification. _TYPE_ID_RULES_MAP: dict[int, str] = { 3: "行政处罚", } def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None: """Resolve rules_path: config override → document metadata → type_id mapping.""" from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from sqlalchemy import text as sa_text # 1. Config override (when explicitly set in app.toml) if LEAUDIT_RULES_DIR: return LEAUDIT_RULES_DIR try: async def _fetch(): async with GetAsyncSession() as session: result = await session.execute( sa_text("SELECT type_id FROM leaudit_documents WHERE id = :did"), {"did": document_id}, ) row = result.fetchone() if row and row[0] and row[0] in _TYPE_ID_RULES_MAP: return f"{_TYPE_ID_RULES_MAP[row[0]]}/rules.yaml" return None return loop.run_until_complete(_fetch()) except Exception as e: log.warning(f"Failed to resolve rules_path from document: {e}") return None def _resolve_rules_runtime( document_id: int, run_id: int, explicit_rules_path: str | None, loop: asyncio.AbstractEventLoop, ) -> dict[str, str | None]: """解析本次执行使用的规则来源。""" if explicit_rules_path: return { "rules_path": explicit_rules_path, "temp_rule_path": None, "source_type": "explicit", "source_path": explicit_rules_path, } resolver = RuleVersionResolver() try: payload = loop.run_until_complete(resolver.ResolveForRun(run_id)) if payload: log.info( f"run_id={run_id} 规则来源已解析: sourceType={payload.sourceType}, " f"sourcePath={payload.sourcePath}, localPath={payload.localPath}" ) return { "rules_path": payload.localPath, "temp_rule_path": payload.localPath if payload.sourceType == "oss" else None, "source_type": payload.sourceType, "source_path": payload.sourcePath, } except Exception as e: log.warning(f"Failed to resolve rule version from run: run_id={run_id}, error={e}") fallback_rules_path = _resolve_rules_path(document_id, loop) return { "rules_path": fallback_rules_path, "temp_rule_path": None, "source_type": "legacy_fallback", "source_path": fallback_rules_path, } def _resolve_run_id( document_id: int, upload_info: Dict[str, Any] | None, loop: asyncio.AbstractEventLoop, ) -> int: """解析本次任务对应的运行 ID。""" if upload_info: for key in ("run_id", "runId"): value = upload_info.get(key) if isinstance(value, int): return value from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from sqlalchemy import text as sa_text async def _fetch() -> int: async with GetAsyncSession() as session: result = await session.execute( sa_text("SELECT id FROM leaudit_audit_runs WHERE document_id = :did ORDER BY id DESC LIMIT 1"), {"did": document_id}, ) row = result.fetchone() if not row: raise ValueError(f"未找到 document_id={document_id} 对应的 run 记录") return int(row[0]) return loop.run_until_complete(_fetch()) def _optional_int(payload: Dict[str, Any] | None, *keys: str) -> int | None: """从字典中按顺序取可用整数。""" if not payload: return None for key in keys: value = payload.get(key) if isinstance(value, int): return value return None async def _update_status_safe(document_id: int, status: str) -> None: """Safely update document status via SQLAlchemy, ignoring errors.""" try: from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from sqlalchemy import text as sa_text async with GetAsyncSession() as session: await session.execute( sa_text("UPDATE leaudit_documents SET processing_status = :s, update_time = now() WHERE id = :did"), {"s": status, "did": document_id}, ) await session.commit() except Exception: pass async def _update_run_status_safe(run_id: int, status: str) -> None: """安全更新运行状态。""" try: from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from sqlalchemy import text as sa_text async with GetAsyncSession() as session: await session.execute( sa_text("UPDATE leaudit_audit_runs SET status = :s, update_time = now() WHERE id = :rid"), {"s": status, "rid": run_id}, ) await session.commit() except Exception: pass async def _update_run_phase_safe(run_id: int, phase: str | None) -> None: """安全更新运行阶段。""" try: from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from sqlalchemy import text as sa_text async with GetAsyncSession() as session: await session.execute( sa_text("UPDATE leaudit_audit_runs SET phase = :p, update_time = now() WHERE id = :rid"), {"p": phase, "rid": run_id}, ) await session.commit() except Exception: pass def _get_suffix(filename: str) -> str: """Extract file suffix from filename.""" _, ext = os.path.splitext(filename) return ext if ext else ".pdf" def dispatch_leaudit_task( document_id: int, file_content: bytes, filename: str, upload_info: Optional[Dict[str, Any]] = None, source_port: Optional[int] = None, rules_path: Optional[str] = None, ): """Dispatch a leaudit processing task. P2: Celery 集成后改用 leaudit_process_document.apply_async(...) 当前阶段直接同步调用。 """ return leaudit_process_document( document_id=document_id, file_content=file_content, filename=filename, upload_info=upload_info, source_port=source_port or int(os.getenv("APP_PORT", "8000")), rules_path=rules_path, )