Files
leaudit-platform-backend/fastapi_modules/fastapi_leaudit/leaudit_bridge/tasks.py
T
2026-04-28 16:53:16 +08:00

365 lines
12 KiB
Python

"""LeAudit 任务入口。"""
from __future__ import annotations
import asyncio
from concurrent.futures import ThreadPoolExecutor
import os
from pathlib import Path
import tempfile
import time
from typing import Any, Dict, Optional
from fastapi_common.fastapi_common_logger import logger
from fastapi_admin.config import LEAUDIT_RULES_DIR
from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import (
NativeRunRequest,
NativeRunner,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import (
NativeAuditMetadata,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ruleVersionResolver import (
RuleVersionResolver,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
# Celery integration is deferred to phase P2; for now tasks run synchronously
# as a placeholder.
# from core.celery_app_limited import celery_app
log = logger
# P2: once Celery is integrated, enable the @celery_app.task decorator on the
# task entry point below.
def leaudit_process_document(
    document_id: int,
    file_content: bytes,
    filename: str,
    upload_info: Optional[Dict[str, Any]] = None,
    source_port: Optional[int] = None,
    rules_path: Optional[str] = None,
):
    """Run the LeAudit pipeline for a single document, synchronously.

    The function owns a private asyncio event loop for its whole duration and
    drives every async step (run lookup, rule resolution, the native runner,
    status updates) through ``run_until_complete``. The loop is detached from
    the thread and closed before returning.

    Args:
        document_id: Primary key of the document being audited.
        file_content: Raw bytes of the uploaded file; written to a temporary
            file that is handed to the native runner.
        filename: Original file name; only its extension is used for the
            temporary file (``.pdf`` when it has none).
        upload_info: Optional upload metadata. ``run_id``/``runId`` select the
            run record; ``rule_version_id``/``ruleVersionId`` is forwarded in
            the audit metadata.
        source_port: Port of the originating service; only logged here.
        rules_path: Explicit rules path; when given it bypasses rule-version
            resolution.

    Returns:
        A dict with ``status``, ``document_id``, ``run_id``, ``phase``,
        ``timing`` and ``errors``.

    Raises:
        Re-raises any pipeline exception after marking the document
        ``failed`` and, when a run id was resolved, recording the failure via
        ``StorageAdapter.fail_run``.
    """
    task_id = os.urandom(8).hex()
    log.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")
    # New platform: the region is passed in as a parameter instead of being
    # switched through os.environ.
    if source_port:
        log.info(f"[任务ID: {task_id}] 来源端口: {source_port}")
    if upload_info is None:
        upload_info = {}

    # Private loop for this call; detached and closed in ``finally``.
    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    temp_paths: list[str] = []
    # Explicit sentinels tell the failure handler how far the run got
    # (instead of probing ``locals()``).
    storage = None
    run_id: Optional[int] = None
    native_result = None
    try:
        # Created inside ``try`` so a constructor failure still closes the loop.
        storage = StorageAdapter()
        run_id = _resolve_run_id(document_id, upload_info, loop)
        rules_resolution = _resolve_rules_runtime(document_id, run_id, rules_path, loop)
        # Register a temporary rule file for cleanup *before* parsing it, so it
        # is removed even if RulesLoader.load() raises.
        temp_rule_path = rules_resolution.get("temp_rule_path")
        if isinstance(temp_rule_path, str) and temp_rule_path:
            temp_paths.append(temp_rule_path)
        loop.run_until_complete(_update_run_status_safe(run_id, "running"))

        rules_path_resolved = rules_resolution["rules_path"]
        rules_file = None
        if rules_path_resolved:
            rules_file = RulesLoader().load(rules_path_resolved)
            log.info(
                f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} "
                f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)"
            )
        else:
            log.info(
                f"[任务ID: {task_id}] No fixed rules_path — "
                "will classify from document content after OCR"
            )

        with tempfile.NamedTemporaryFile(suffix=_get_suffix(filename), delete=False) as temp:
            # Track the path first so the file is removed even if the write fails.
            temp_paths.append(temp.name)
            temp.write(file_content)
        temp_path = temp.name

        runner = NativeRunner()
        t0 = time.perf_counter()  # monotonic clock for elapsed-time measurement
        native_result = loop.run_until_complete(
            runner.run(
                NativeRunRequest(
                    metadata=NativeAuditMetadata(
                        run_id=run_id,
                        document_id=document_id,
                        rule_version_id=_optional_int(upload_info, "rule_version_id", "ruleVersionId"),
                        extras={"taskId": task_id},
                    ),
                    local_file_path=temp_path,
                    rules_file=rules_file,
                    rule_source_path=rules_path_resolved,
                )
            )
        )
        loop.run_until_complete(runner.persist_result(native_result))
        elapsed = round(time.perf_counter() - t0, 2)
        ctx = native_result.ctx
        loop.run_until_complete(_update_run_phase_safe(run_id, ctx.phase))
        loop.run_until_complete(_update_run_status_safe(run_id, "completed"))
        loop.run_until_complete(_update_status_safe(document_id, "completed"))
        log.info(
            f"[任务ID: {task_id}] leaudit管线完成: phase={ctx.phase}, "
            f"timing={dict(ctx.timing)}, 总耗时={elapsed:.1f}s"
        )
        return {
            "status": "success",
            "document_id": document_id,
            "run_id": run_id,
            "phase": ctx.phase,
            "timing": dict(ctx.timing),
            "errors": list(ctx.extraction.all_errors) if ctx.extraction is not None else list(ctx.extraction_errors),
        }
    except Exception as e:
        log.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
        try:
            loop.run_until_complete(_update_status_safe(document_id, "failed"))
            if run_id is not None and storage is not None:
                failed_phase = "persist"
                if native_result is not None:
                    failed_phase = native_result.ctx.phase or failed_phase
                loop.run_until_complete(
                    storage.fail_run(
                        document_id,
                        run_id=run_id,
                        phase=failed_phase,
                        message=str(e),
                        detail_json={
                            "taskId": task_id,
                            "filename": filename,
                            "errorType": type(e).__name__,
                        },
                    )
                )
        except Exception:
            # Best effort: never mask the original pipeline error, but leave a trace.
            log.warning(f"[任务ID: {task_id}] failed to record failure state", exc_info=True)
        raise
    finally:
        for path in temp_paths:
            try:
                Path(path).unlink(missing_ok=True)
            except OSError:
                pass
        # Detach before closing so later asyncio.get_event_loop() calls on this
        # thread do not receive a closed loop.
        asyncio.set_event_loop(None)
        loop.close()
# Document type_id -> rules directory, for types with a single fixed rule set.
# Administrative licensing ("行政许可", type_id=2) is intentionally absent: it
# has nine sub-types, so its rules must come from document metadata
# (rules_file_path) or from content classification.
_TYPE_ID_RULES_MAP: dict[int, str] = dict(
    [
        (3, "行政处罚"),  # administrative penalty
    ]
)
def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None:
    """Resolve a legacy rules path for a document.

    Lookup order:
      1. ``LEAUDIT_RULES_DIR`` from config, when set, is returned as-is.
      2. The document's ``type_id`` (read from ``leaudit_documents``) mapped
         through ``_TYPE_ID_RULES_MAP`` to ``"<dir>/rules.yaml"``.

    Args:
        document_id: Primary key in ``leaudit_documents``.
        loop: Event loop used to run the async DB query.

    Returns:
        The rules path, or ``None`` when neither source applies or the DB
        lookup fails (failures are logged as warnings, not raised).
    """
    from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
    from sqlalchemy import text as sa_text

    # 1. Config override (when explicitly set in app.toml).
    if LEAUDIT_RULES_DIR:
        return LEAUDIT_RULES_DIR

    # 2. Fixed type_id mapping.
    async def _fetch() -> str | None:
        async with GetAsyncSession() as session:
            result = await session.execute(
                sa_text("SELECT type_id FROM leaudit_documents WHERE id = :did"),
                {"did": document_id},
            )
            row = result.fetchone()
            rules_dir = _TYPE_ID_RULES_MAP.get(row[0]) if row else None
            return f"{rules_dir}/rules.yaml" if rules_dir is not None else None

    try:
        return loop.run_until_complete(_fetch())
    except Exception as e:
        log.warning(f"Failed to resolve rules_path from document: {e}")
        return None
def _resolve_rules_runtime(
document_id: int,
run_id: int,
explicit_rules_path: str | None,
loop: asyncio.AbstractEventLoop,
) -> dict[str, str | None]:
"""解析本次执行使用的规则来源。"""
if explicit_rules_path:
return {
"rules_path": explicit_rules_path,
"temp_rule_path": None,
"source_type": "explicit",
"source_path": explicit_rules_path,
}
resolver = RuleVersionResolver()
try:
payload = loop.run_until_complete(resolver.ResolveForRun(run_id))
if payload:
log.info(
f"run_id={run_id} 规则来源已解析: sourceType={payload.sourceType}, "
f"sourcePath={payload.sourcePath}, localPath={payload.localPath}"
)
return {
"rules_path": payload.localPath,
"temp_rule_path": payload.localPath if payload.sourceType == "oss" else None,
"source_type": payload.sourceType,
"source_path": payload.sourcePath,
}
except Exception as e:
log.warning(f"Failed to resolve rule version from run: run_id={run_id}, error={e}")
fallback_rules_path = _resolve_rules_path(document_id, loop)
return {
"rules_path": fallback_rules_path,
"temp_rule_path": None,
"source_type": "legacy_fallback",
"source_path": fallback_rules_path,
}
def _resolve_run_id(
document_id: int,
upload_info: Dict[str, Any] | None,
loop: asyncio.AbstractEventLoop,
) -> int:
"""解析本次任务对应的运行 ID。"""
if upload_info:
for key in ("run_id", "runId"):
value = upload_info.get(key)
if isinstance(value, int):
return value
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from sqlalchemy import text as sa_text
async def _fetch() -> int:
async with GetAsyncSession() as session:
result = await session.execute(
sa_text("SELECT id FROM leaudit_audit_runs WHERE document_id = :did ORDER BY id DESC LIMIT 1"),
{"did": document_id},
)
row = result.fetchone()
if not row:
raise ValueError(f"未找到 document_id={document_id} 对应的 run 记录")
return int(row[0])
return loop.run_until_complete(_fetch())
def _optional_int(payload: Dict[str, Any] | None, *keys: str) -> int | None:
"""从字典中按顺序取可用整数。"""
if not payload:
return None
for key in keys:
value = payload.get(key)
if isinstance(value, int):
return value
return None
async def _execute_write_safe(statement: str, params: Dict[str, Any], what: str) -> None:
    """Execute one write statement in its own session and commit it, best effort.

    Any ``Exception`` (including import failures) is logged as a warning and
    swallowed, so status bookkeeping never breaks the caller.

    Args:
        statement: SQL text with named bind parameters.
        params: Bind values for ``statement``.
        what: Short description included in the warning on failure.
    """
    try:
        from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
        from sqlalchemy import text as sa_text

        async with GetAsyncSession() as session:
            await session.execute(sa_text(statement), params)
            await session.commit()
    except Exception:
        log.warning(f"Best-effort status update failed: {what}", exc_info=True)


async def _update_status_safe(document_id: int, status: str) -> None:
    """Set ``leaudit_documents.processing_status``; errors are logged, not raised."""
    await _execute_write_safe(
        "UPDATE leaudit_documents SET processing_status = :s, updated_at = now() WHERE id = :did",
        {"s": status, "did": document_id},
        f"document_id={document_id} processing_status={status}",
    )


async def _update_run_status_safe(run_id: int, status: str) -> None:
    """Set ``leaudit_audit_runs.status``; errors are logged, not raised."""
    await _execute_write_safe(
        "UPDATE leaudit_audit_runs SET status = :s, updated_at = now() WHERE id = :rid",
        {"s": status, "rid": run_id},
        f"run_id={run_id} status={status}",
    )


async def _update_run_phase_safe(run_id: int, phase: str | None) -> None:
    """Set ``leaudit_audit_runs.phase``; errors are logged, not raised."""
    await _execute_write_safe(
        "UPDATE leaudit_audit_runs SET phase = :p, updated_at = now() WHERE id = :rid",
        {"p": phase, "rid": run_id},
        f"run_id={run_id} phase={phase}",
    )
def _get_suffix(filename: str) -> str:
    """Return the extension of ``filename`` (dot included), or ``".pdf"`` if it has none."""
    extension = os.path.splitext(filename)[1]
    return extension or ".pdf"
def dispatch_leaudit_task(
    document_id: int,
    file_content: bytes,
    filename: str,
    upload_info: Optional[Dict[str, Any]] = None,
    source_port: Optional[int] = None,
    rules_path: Optional[str] = None,
):
    """Dispatch a LeAudit processing task and return its result.

    P2: after Celery integration this should switch to
    ``leaudit_process_document.apply_async(...)``. For now the task runs
    synchronously.

    ``source_port`` defaults to the ``APP_PORT`` environment variable (``8000``
    if unset). When no event loop is running in the current thread, the task
    runs inline. Otherwise it runs on a one-off worker thread, because the task
    drives its own event loop with ``run_until_complete``, which cannot run
    while another loop is running on the same thread.

    NOTE: In both cases this call blocks until the task finishes. Inside a
    running loop, ``future.result()`` blocks that loop's thread.
    """
    task_kwargs = dict(
        document_id=document_id,
        file_content=file_content,
        filename=filename,
        upload_info=upload_info,
        source_port=source_port or int(os.getenv("APP_PORT", "8000")),
        rules_path=rules_path,
    )
    try:
        asyncio.get_running_loop()
    except RuntimeError:
        # No loop running on this thread: run in place.
        return leaudit_process_document(**task_kwargs)

    with ThreadPoolExecutor(max_workers=1) as worker:
        return worker.submit(leaudit_process_document, **task_kwargs).result()