535d97a70c
17-table PostgreSQL schema with full Chinese column comments, FastAPI project structure (admin/common/modules), DSL rule files, and schema migration scripts.
202 lines
6.2 KiB
Python
202 lines
6.2 KiB
Python
"""Celery task for leaudit pipeline processing.

Activated when PIPELINE_MODE=leaudit in env.{port} config.
Replaces the legacy OCR → extraction → evaluation pipeline with
leaudit's YAML-rules-driven approach.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
import tempfile
|
|
import time
|
|
from typing import Any, Dict, Optional
|
|
|
|
from core.celery_app_limited import celery_app
|
|
from core.postgrest.client import get_postgrest_client
|
|
from core.logger import log
|
|
|
|
from leaudit_bridge import create_pipeline, RulesLoader
|
|
|
|
|
|
@celery_app.task(bind=True, name="leaudit.process_document")
def leaudit_process_document(
    self,
    document_id: int,
    file_content: bytes,
    filename: str,
    upload_info: Optional[Dict[str, Any]] = None,
    source_port: Optional[int] = None,
    rules_path: Optional[str] = None,
):
    """Process a document using leaudit's full pipeline.

    Steps: OCR → Extraction → Evaluation → Store in docauditai DB.

    The uploaded bytes are written to a temporary file, which is handed to
    the pipeline and always removed afterwards, whether the run succeeds
    or fails.

    Args:
        document_id: Id of the row in the ``documents`` table; passed to the
            pipeline and used to resolve the rules path.
        file_content: Raw bytes of the uploaded file.
        filename: Original file name; used for logging and to pick the
            temporary file's suffix.
        upload_info: Optional upload metadata. Defaulted to ``{}`` but not
            otherwise consumed by this function.
        source_port: Port of the originating instance. When truthy, the
            instance environment is switched via ``set_instance_environment``.
            When falsy, the ``APP_PORT`` environment variable (default
            ``8000``) is passed to the pipeline instead.
        rules_path: Explicit rules path. When falsy, the path is resolved
            with ``_resolve_rules_path`` and may remain ``None``.

    Returns:
        A dict with ``status``, ``document_id``, ``phase``, ``timing`` and
        ``errors`` taken from the pipeline result.

    Raises:
        Exception: Any error raised during processing is logged, the document
            is marked ``"Failed"`` on a best-effort basis, and the original
            exception is re-raised.
    """
    task_id = self.request.id
    log.task.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")

    if source_port:
        from core.utils.instance_context import set_instance_environment

        instance_name = set_instance_environment(source_port)
        log.task.info(
            f"[任务ID: {task_id}] 实例环境: {instance_name} (端口: {source_port})"
        )

    if upload_info is None:
        upload_info = {}

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)

    # Set as soon as the temp file exists so ``finally`` can always remove it.
    temp_path: Optional[str] = None

    try:
        rules_path_resolved = rules_path or _resolve_rules_path(document_id, loop)

        # For types with a known mapping (e.g. 行政处罚), pre-load rules_file.
        # For types that need content classification (e.g. 行政许可 sub-types),
        # rules_path will be None → adapter classifies after OCR → pipeline
        # loads rules from ocr_result.rules_file_path.
        rules_file = None
        if rules_path_resolved:
            loader = RulesLoader()
            rules_file = loader.load(rules_path_resolved)
            log.task.info(
                f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} "
                f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)"
            )
        else:
            log.task.info(
                f"[任务ID: {task_id}] No fixed rules_path — "
                "will classify from document content after OCR"
            )

        suffix = _get_suffix(filename)
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp:
            # Record the path before writing so a failed write is still cleaned up.
            temp_path = temp.name
            temp.write(file_content)

        pipeline = create_pipeline(rules_path=rules_path_resolved)

        t0 = time.time()
        result = loop.run_until_complete(
            pipeline.run(
                document_id=document_id,
                file_path=temp_path,
                rules_file=rules_file,
                source_port=source_port or int(os.getenv("APP_PORT", "8000")),
            )
        )
        elapsed = round(time.time() - t0, 2)

        log.task.info(
            f"[任务ID: {task_id}] leaudit管线完成: phase={result.detected_phase}, "
            f"timing={result.timing}, 总耗时={elapsed:.1f}s"
        )

        return {
            "status": "success",
            "document_id": document_id,
            "phase": result.detected_phase,
            "timing": result.timing,
            "errors": result.errors,
        }

    except Exception as e:
        log.task.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
        try:
            loop.run_until_complete(_update_status_safe(document_id, "Failed"))
        except Exception:
            # Marking the document as failed is best-effort; the original
            # error re-raised below is the one that matters.
            pass
        raise

    finally:
        # Remove the temp file on every exit path (success or failure).
        if temp_path is not None:
            try:
                os.remove(temp_path)
            except OSError:
                pass
        loop.close()
|
|
|
|
|
|
# type_id → rules directory mapping (only fixed-mapping types)
|
|
# 行政许可 (type_id=2) has 9 sub-types, NOT mapped here —
|
|
# must come from document metadata (rules_file_path) or content classification.
|
|
_TYPE_ID_RULES_MAP: dict[int, str] = {
|
|
3: "行政处罚",
|
|
}
|
|
|
|
|
|
def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None:
    """Work out which rules path applies to a document, if any.

    Sources are tried in this order, and the first truthy value wins:

    1. ``LEAUDIT_CONFIG["RULES_PATH"]`` (config override).
    2. The ``rules_file_path`` field of the document's row.
    3. ``"<dir>/rules.yaml"`` where ``<dir>`` is the ``_TYPE_ID_RULES_MAP``
       entry for the row's ``type_id``.

    Any exception raised while querying the document is logged as a warning
    and treated as "no path found".

    Args:
        document_id: Id of the row to look up in the ``documents`` table.
        loop: Event loop used to run the asynchronous PostgREST query.

    Returns:
        The resolved rules path, or ``None`` when nothing applies.
    """
    from core.config import LEAUDIT_CONFIG

    # 1. Explicit config override short-circuits everything else.
    override = LEAUDIT_CONFIG.get("RULES_PATH", "")
    if override:
        return override

    try:
        query = get_postgrest_client().select(
            table="documents",
            filters={"id": f"eq.{document_id}"},
            single=True,
        )
        record = loop.run_until_complete(query)

        if record:
            # 2. Path stored on the document itself.
            per_document = record.get("rules_file_path")
            if per_document:
                return per_document

            # 3. Fixed mapping keyed by the document's type_id.
            doc_type = record.get("type_id")
            if doc_type and doc_type in _TYPE_ID_RULES_MAP:
                return f"{_TYPE_ID_RULES_MAP[doc_type]}/rules.yaml"
    except Exception as e:
        log.task.warning(f"Failed to resolve rules_path from document: {e}")

    return None
|
|
|
|
|
|
async def _update_status_safe(document_id: int, status: str) -> None:
    """Best-effort update of a document's ``status`` column.

    Failures are logged as warnings instead of being raised, so this can be
    called from error-handling paths without masking the original error.

    Args:
        document_id: Id of the row in the ``documents`` table.
        status: New value for the row's ``status`` field.
    """
    try:
        client = get_postgrest_client()
        await client.update(
            table="documents",
            filters={"id": f"eq.{document_id}"},
            data={"status": status},
        )
    except Exception as exc:
        # Still swallowed on purpose, but leave a trace so a document stuck in
        # a stale status can be diagnosed.
        log.task.warning(
            f"Failed to update status of document {document_id} to {status!r}: {exc}"
        )
|
|
|
|
|
|
def _get_suffix(filename: str) -> str:
|
|
"""Extract file suffix from filename."""
|
|
_, ext = os.path.splitext(filename)
|
|
return ext if ext else ".pdf"
|
|
|
|
|
|
def dispatch_leaudit_task(
    document_id: int,
    file_content: bytes,
    filename: str,
    upload_info: Optional[Dict[str, Any]] = None,
    source_port: Optional[int] = None,
    rules_path: Optional[str] = None,
):
    """Enqueue ``leaudit_process_document`` for asynchronous execution.

    Args:
        document_id: Id of the document to process.
        file_content: Raw bytes of the uploaded file.
        filename: Original file name.
        upload_info: Optional upload metadata, forwarded unchanged.
        source_port: Originating instance port. When falsy, the ``APP_PORT``
            environment variable (default ``8000``) is used instead.
        rules_path: Optional explicit rules path, forwarded unchanged.

    Returns:
        Whatever ``leaudit_process_document.apply_async`` returns.
    """
    # Resolve the port here so the task always receives a concrete value.
    effective_port = source_port or int(os.getenv("APP_PORT", "8000"))

    positional = [document_id, file_content, filename]
    keyword = {
        "upload_info": upload_info,
        "source_port": effective_port,
        "rules_path": rules_path,
    }
    return leaudit_process_document.apply_async(args=positional, kwargs=keyword)
|