"""Celery task for leaudit pipeline processing. Activated when PIPELINE_MODE=leaudit in env.{port} config. Replaces the legacy OCR → extraction → evaluation pipeline with leaudit's YAML-rules-driven approach. """ from __future__ import annotations import asyncio import os import tempfile import time from typing import Any, Dict, Optional from core.celery_app_limited import celery_app from core.postgrest.client import get_postgrest_client from core.logger import log from leaudit_bridge import create_pipeline, RulesLoader @celery_app.task(bind=True, name="leaudit.process_document") def leaudit_process_document( self, document_id: int, file_content: bytes, filename: str, upload_info: Optional[Dict[str, Any]] = None, source_port: Optional[int] = None, rules_path: Optional[str] = None, ): """Process a document using leaudit's full pipeline. Steps: OCR → Extraction → Evaluation → Store in docauditai DB. """ task_id = self.request.id log.task.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}") if source_port: from core.utils.instance_context import set_instance_environment instance_name = set_instance_environment(source_port) log.task.info( f"[任务ID: {task_id}] 实例环境: {instance_name} (端口: {source_port})" ) if upload_info is None: upload_info = {} loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) try: rules_path_resolved = rules_path or _resolve_rules_path(document_id, loop) # For types with a known mapping (e.g. 行政处罚), pre-load rules_file. # For types that need content classification (e.g. 行政许可 sub-types), # rules_path will be None → adapter classifies after OCR → pipeline # loads rules from ocr_result.rules_file_path. rules_file = None if rules_path_resolved: loader = RulesLoader() rules_file = loader.load(rules_path_resolved) log.task.info( f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} " f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)" ) else: log.task.info( f"[任务ID: {task_id}] No fixed rules_path — " "will classify from document content after OCR" ) suffix = _get_suffix(filename) with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp: temp.write(file_content) temp_path = temp.name pipeline = create_pipeline(rules_path=rules_path_resolved) t0 = time.time() result = loop.run_until_complete( pipeline.run( document_id=document_id, file_path=temp_path, rules_file=rules_file, source_port=source_port or int(os.getenv("APP_PORT", "8000")), ) ) elapsed = round(time.time() - t0, 2) try: os.remove(temp_path) except OSError: pass log.task.info( f"[任务ID: {task_id}] leaudit管线完成: phase={result.detected_phase}, " f"timing={result.timing}, 总耗时={elapsed:.1f}s" ) return { "status": "success", "document_id": document_id, "phase": result.detected_phase, "timing": result.timing, "errors": result.errors, } except Exception as e: log.task.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True) try: loop.run_until_complete(_update_status_safe(document_id, "Failed")) except Exception: pass raise finally: loop.close() # type_id → rules directory mapping (only fixed-mapping types) # 行政许可 (type_id=2) has 9 sub-types, NOT mapped here — # must come from document metadata (rules_file_path) or content classification. _TYPE_ID_RULES_MAP: dict[int, str] = { 3: "行政处罚", } def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None: """Resolve rules_path: config override → document metadata → type_id mapping.""" from core.config import LEAUDIT_CONFIG # 1. Config override (when explicitly set) config_path = LEAUDIT_CONFIG.get("RULES_PATH", "") if config_path: return config_path try: client = get_postgrest_client() doc = loop.run_until_complete( client.select( table="documents", filters={"id": f"eq.{document_id}"}, single=True, ) ) if not doc: return None # 2. Document-level override rfp = doc.get("rules_file_path") if rfp: return rfp # 3. type_id mapping type_id = doc.get("type_id") if type_id and type_id in _TYPE_ID_RULES_MAP: return f"{_TYPE_ID_RULES_MAP[type_id]}/rules.yaml" except Exception as e: log.task.warning(f"Failed to resolve rules_path from document: {e}") return None async def _update_status_safe(document_id: int, status: str) -> None: """Safely update document status, ignoring errors.""" try: client = get_postgrest_client() await client.update( table="documents", filters={"id": f"eq.{document_id}"}, data={"status": status}, ) except Exception: pass def _get_suffix(filename: str) -> str: """Extract file suffix from filename.""" _, ext = os.path.splitext(filename) return ext if ext else ".pdf" def dispatch_leaudit_task( document_id: int, file_content: bytes, filename: str, upload_info: Optional[Dict[str, Any]] = None, source_port: Optional[int] = None, rules_path: Optional[str] = None, ): """Dispatch a leaudit processing task.""" return leaudit_process_document.apply_async( args=[document_id, file_content, filename], kwargs={ "upload_info": upload_info, "source_port": source_port or int(os.getenv("APP_PORT", "8000")), "rules_path": rules_path, }, )