Files
leaudit-platform-backend/fastapi_modules/fastapi_leaudit/leaudit_bridge/tasks.py
T
wren 535d97a70c chore: initial commit — leaudit-platform project skeleton
17-table PostgreSQL schema with full Chinese column comments,
FastAPI project structure (admin/common/modules),
DSL rule files, and schema migration scripts.
2026-04-27 16:48:22 +08:00

202 lines
6.2 KiB
Python

"""Celery task for leaudit pipeline processing.
Activated when PIPELINE_MODE=leaudit in env.{port} config.
Replaces the legacy OCR → extraction → evaluation pipeline with
leaudit's YAML-rules-driven approach.
"""
from __future__ import annotations
import asyncio
import os
import tempfile
import time
from typing import Any, Dict, Optional
from core.celery_app_limited import celery_app
from core.postgrest.client import get_postgrest_client
from core.logger import log
from leaudit_bridge import create_pipeline, RulesLoader
@celery_app.task(bind=True, name="leaudit.process_document")
def leaudit_process_document(
    self,
    document_id: int,
    file_content: bytes,
    filename: str,
    upload_info: Optional[Dict[str, Any]] = None,
    source_port: Optional[int] = None,
    rules_path: Optional[str] = None,
):
    """Process a document using leaudit's full pipeline.

    Steps: OCR → Extraction → Evaluation → Store in docauditai DB.

    Args:
        self: The bound Celery task (``bind=True``); used for ``request.id``.
        document_id: Id of the row in the ``documents`` table.
        file_content: Raw bytes of the uploaded file, written to a temp file
            for the pipeline.
        filename: Original filename; only its extension is used (via
            ``_get_suffix``) to name the temp file.
        upload_info: Extra upload metadata; defaulted to ``{}`` but not
            otherwise read by this task.
        source_port: Port of the originating app instance. When truthy, the
            instance environment is switched to it; when falsy, the pipeline
            receives ``APP_PORT`` from the environment (default 8000).
        rules_path: Explicit rules path. When falsy it is resolved with
            ``_resolve_rules_path``; that may yield None, in which case the
            rules are chosen after OCR from the document content.

    Returns:
        A dict with ``status`` ("success"), ``document_id``, ``phase``,
        ``timing`` and ``errors`` taken from the pipeline result.

    Raises:
        Any exception from rules resolution/loading or the pipeline is logged,
        the document status is set to "Failed" on a best-effort basis, and the
        exception is re-raised.
    """
    task_id = self.request.id
    log.task.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")
    if source_port:
        from core.utils.instance_context import set_instance_environment

        instance_name = set_instance_environment(source_port)
        log.task.info(
            f"[任务ID: {task_id}] 实例环境: {instance_name} (端口: {source_port})"
        )
    if upload_info is None:
        upload_info = {}

    # Celery tasks are synchronous; a private event loop drives the async
    # pipeline and the PostgREST lookups made during this task.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    # Set before the try so the finally block can always clean up the temp
    # file, including when the pipeline raises or the write itself fails.
    temp_path: Optional[str] = None
    try:
        rules_path_resolved = rules_path or _resolve_rules_path(document_id, loop)
        # For types with a known mapping (e.g. 行政处罚), pre-load rules_file.
        # For types that need content classification (e.g. 行政许可 sub-types),
        # rules_path will be None → adapter classifies after OCR → pipeline
        # loads rules from ocr_result.rules_file_path.
        rules_file = None
        if rules_path_resolved:
            loader = RulesLoader()
            rules_file = loader.load(rules_path_resolved)
            log.task.info(
                f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} "
                f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)"
            )
        else:
            log.task.info(
                f"[任务ID: {task_id}] No fixed rules_path — "
                "will classify from document content after OCR"
            )

        suffix = _get_suffix(filename)
        # delete=False: the pipeline reopens the file by path, so it must
        # outlive this with-block; removal happens in the finally below.
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp:
            temp_path = temp.name
            temp.write(file_content)

        pipeline = create_pipeline(rules_path=rules_path_resolved)
        t0 = time.time()
        result = loop.run_until_complete(
            pipeline.run(
                document_id=document_id,
                file_path=temp_path,
                rules_file=rules_file,
                source_port=source_port or int(os.getenv("APP_PORT", "8000")),
            )
        )
        elapsed = round(time.time() - t0, 2)
        log.task.info(
            f"[任务ID: {task_id}] leaudit管线完成: phase={result.detected_phase}, "
            f"timing={result.timing}, 总耗时={elapsed:.1f}s"
        )
        return {
            "status": "success",
            "document_id": document_id,
            "phase": result.detected_phase,
            "timing": result.timing,
            "errors": result.errors,
        }
    except Exception as e:
        log.task.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
        try:
            loop.run_until_complete(_update_status_safe(document_id, "Failed"))
        except Exception:
            # Best effort only: the original pipeline error is re-raised below.
            pass
        raise
    finally:
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass
        # Do not leave a closed loop installed as this thread's current loop
        # (mirrors what asyncio.run() does on shutdown).
        asyncio.set_event_loop(None)
        loop.close()
# type_id → rules directory mapping (only fixed-mapping types)
# 行政许可 (type_id=2) has 9 sub-types, NOT mapped here —
# must come from document metadata (rules_file_path) or content classification.
_TYPE_ID_RULES_MAP: dict[int, str] = {
3: "行政处罚",
}
def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None:
    """Pick the rules path to use for a document.

    Precedence, highest first:
      1. ``LEAUDIT_CONFIG["RULES_PATH"]`` when it is non-empty;
      2. the document row's ``rules_file_path`` value;
      3. ``_TYPE_ID_RULES_MAP[type_id]`` joined with ``/rules.yaml``.

    The document row is fetched through PostgREST, with *loop* running the
    request to completion.

    Returns:
        The chosen path, or None when no source yields one, when the document
        row is missing, or when the lookup raises. A raised lookup error is
        logged as a warning rather than propagated.
    """
    from core.config import LEAUDIT_CONFIG

    override = LEAUDIT_CONFIG.get("RULES_PATH", "")
    if override:
        return override

    try:
        postgrest = get_postgrest_client()
        query = postgrest.select(
            table="documents",
            filters={"id": f"eq.{document_id}"},
            single=True,
        )
        row = loop.run_until_complete(query)
        if not row:
            return None

        per_document = row.get("rules_file_path")
        if per_document:
            return per_document

        type_id = row.get("type_id")
        if type_id and type_id in _TYPE_ID_RULES_MAP:
            return f"{_TYPE_ID_RULES_MAP[type_id]}/rules.yaml"
    except Exception as exc:
        log.task.warning(f"Failed to resolve rules_path from document: {exc}")
    return None
async def _update_status_safe(document_id: int, status: str) -> None:
    """Best-effort update of the ``status`` column of one ``documents`` row.

    Ordinary exceptions are caught, not raised, so this can be called from an
    error-handling path without replacing the error being handled. A failed
    update is logged as a warning so that documents left in a stale status
    can be diagnosed.

    Args:
        document_id: Id of the row to update.
        status: New status value (e.g. ``"Failed"``).
    """
    try:
        client = get_postgrest_client()
        await client.update(
            table="documents",
            filters={"id": f"eq.{document_id}"},
            data={"status": status},
        )
    except Exception as e:
        log.task.warning(
            f"Failed to update status of document {document_id} to {status!r}: {e}"
        )
def _get_suffix(filename: str) -> str:
"""Extract file suffix from filename."""
_, ext = os.path.splitext(filename)
return ext if ext else ".pdf"
def dispatch_leaudit_task(
    document_id: int,
    file_content: bytes,
    filename: str,
    upload_info: Optional[Dict[str, Any]] = None,
    source_port: Optional[int] = None,
    rules_path: Optional[str] = None,
):
    """Enqueue ``leaudit_process_document`` and return what ``apply_async`` returns.

    ``document_id``, ``file_content`` and ``filename`` are sent positionally;
    the remaining options are sent as keyword arguments. A falsy
    *source_port* is replaced by the ``APP_PORT`` environment variable
    (default 8000) before the task is queued.
    """
    port = source_port or int(os.getenv("APP_PORT", "8000"))
    task_kwargs = {
        "upload_info": upload_info,
        "source_port": port,
        "rules_path": rules_path,
    }
    positional = [document_id, file_content, filename]
    return leaudit_process_document.apply_async(args=positional, kwargs=task_kwargs)