246c0e5ded
- M1: unified OSS client (upload/download/presign) + path utils + config - M2: rule service with validate/create/publish/rollback + binding CRUD endpoints - M3: native AuditCtx runner, file/rule resolvers, storage adapter with full persistence - docs: SYSTEM_OVERVIEW.md as comprehensive architecture reference - fix: double finalize — terminal state now written once by finalize_run
355 lines
12 KiB
Python
355 lines
12 KiB
Python
"""LeAudit 任务入口。"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
from pathlib import Path
|
|
import tempfile
|
|
import time
|
|
from typing import Any, Dict, Optional
|
|
|
|
from fastapi_common.fastapi_common_logger import logger
|
|
|
|
from fastapi_admin.config import LEAUDIT_RULES_DIR
|
|
from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import (
|
|
NativeRunRequest,
|
|
NativeRunner,
|
|
)
|
|
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import (
|
|
NativeAuditMetadata,
|
|
)
|
|
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ruleVersionResolver import (
|
|
RuleVersionResolver,
|
|
)
|
|
from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader
|
|
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
|
|
|
|
# Celery 集成待 P2 阶段实现,当前使用同步占位
|
|
# from core.celery_app_limited import celery_app
|
|
|
|
# Module-level alias for the shared project logger; the rest of this file logs through `log`.
log = logger
|
|
|
|
|
|
# P2: Celery 集成后启用 @celery_app.task 装饰器
|
|
def leaudit_process_document(
    document_id: int,
    file_content: bytes,
    filename: str,
    upload_info: Optional[Dict[str, Any]] = None,
    source_port: Optional[int] = None,
    rules_path: Optional[str] = None,
):
    """Run the LeAudit pipeline for one document and persist its outcome.

    The function is synchronous: it creates a private event loop and drives
    every coroutine (run lookup, rule resolution, the native runner and the
    status updates) through ``loop.run_until_complete``.

    Args:
        document_id: ID of the document being processed.
        file_content: Raw bytes of the uploaded file; written to a temporary
            file whose path is handed to the runner.
        filename: Original file name; used in log messages and to choose the
            temporary file's suffix.
        upload_info: Optional metadata dict. ``run_id``/``runId`` and
            ``rule_version_id``/``ruleVersionId`` are read from it.
        source_port: Optional originating port; only logged here.
        rules_path: Optional explicit rules path; when given it is used
            instead of resolving the rule source from the run record.

    Returns:
        A dict with the keys ``status``, ``document_id``, ``run_id``,
        ``phase``, ``timing`` and ``errors``.

    Raises:
        Exception: Any pipeline failure is logged, recorded best-effort (the
            document status update and ``storage.fail_run``) and re-raised.
    """
    task_id = os.urandom(8).hex()
    log.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")

    # New platform: the region arrives as a parameter instead of being
    # switched through os.environ.
    if source_port:
        log.info(f"[任务ID: {task_id}] 来源端口: {source_port}")

    if upload_info is None:
        upload_info = {}

    loop = asyncio.new_event_loop()
    asyncio.set_event_loop(loop)
    temp_paths: list[str] = []
    storage = StorageAdapter()

    # Explicit sentinels let the failure handler know how far the pipeline
    # got without inspecting locals().
    run_id: Optional[int] = None
    native_result = None

    try:
        run_id = _resolve_run_id(document_id, upload_info, loop)
        rules_resolution = _resolve_rules_runtime(document_id, run_id, rules_path, loop)
        loop.run_until_complete(_update_run_status_safe(run_id, "running"))
        rules_path_resolved = rules_resolution["rules_path"]

        rules_file = None
        if rules_path_resolved:
            # Register the downloaded rule file for cleanup BEFORE loading it,
            # so a parse failure does not leave the temp file behind.
            temp_rule_path = rules_resolution.get("temp_rule_path")
            if isinstance(temp_rule_path, str):
                temp_paths.append(temp_rule_path)

            loader = RulesLoader()
            rules_file = loader.load(rules_path_resolved)
            log.info(
                f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} "
                f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)"
            )
        else:
            log.info(
                f"[任务ID: {task_id}] No fixed rules_path — "
                "will classify from document content after OCR"
            )

        suffix = _get_suffix(filename)
        with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp:
            # delete=False: the runner reopens the file by path, so removal is
            # handled in the finally block. Register first so a failed write
            # is still cleaned up.
            temp_path = temp.name
            temp_paths.append(temp_path)
            temp.write(file_content)

        runner = NativeRunner()

        # Monotonic clock: the elapsed time must not jump with wall-clock
        # adjustments.
        t0 = time.monotonic()
        native_result = loop.run_until_complete(
            runner.run(
                NativeRunRequest(
                    metadata=NativeAuditMetadata(
                        run_id=run_id,
                        document_id=document_id,
                        rule_version_id=_optional_int(upload_info, "rule_version_id", "ruleVersionId"),
                        extras={"taskId": task_id},
                    ),
                    local_file_path=temp_path,
                    rules_file=rules_file,
                    rule_source_path=rules_path_resolved,
                )
            )
        )
        loop.run_until_complete(runner.persist_result(native_result))
        elapsed = round(time.monotonic() - t0, 2)

        ctx = native_result.ctx
        loop.run_until_complete(_update_run_phase_safe(run_id, ctx.phase))
        loop.run_until_complete(_update_run_status_safe(run_id, "completed"))
        loop.run_until_complete(_update_status_safe(document_id, "completed"))
        log.info(
            f"[任务ID: {task_id}] leaudit管线完成: phase={ctx.phase}, "
            f"timing={dict(ctx.timing)}, 总耗时={elapsed:.1f}s"
        )

        return {
            "status": "success",
            "document_id": document_id,
            "run_id": run_id,
            "phase": ctx.phase,
            "timing": dict(ctx.timing),
            "errors": list(ctx.extraction.all_errors) if ctx.extraction is not None else list(ctx.extraction_errors),
        }

    except Exception as e:
        log.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
        try:
            loop.run_until_complete(_update_status_safe(document_id, "failed"))
            if run_id is not None:
                # Without a runner result the failing phase is unknown; keep
                # the historical default label "persist".
                failed_phase = "persist"
                if native_result is not None:
                    failed_phase = native_result.ctx.phase or failed_phase
                loop.run_until_complete(
                    storage.fail_run(
                        document_id,
                        run_id=run_id,
                        phase=failed_phase,
                        message=str(e),
                        detail_json={
                            "taskId": task_id,
                            "filename": filename,
                            "errorType": type(e).__name__,
                        },
                    )
                )
        except Exception as record_err:
            # Recording the failure is best-effort and must never mask the
            # original pipeline error, but it should not vanish silently.
            log.warning(
                f"[任务ID: {task_id}] failed to record pipeline failure: {record_err}",
                exc_info=True,
            )
        raise

    finally:
        for stale_path in temp_paths:
            try:
                os.remove(stale_path)
            except FileNotFoundError:
                # Already gone; nothing to clean up.
                pass
            except OSError as cleanup_err:
                log.warning(
                    f"[任务ID: {task_id}] failed to remove temp file {stale_path}: {cleanup_err}"
                )
        # Do not leave a closed loop installed as this thread's current loop.
        asyncio.set_event_loop(None)
        loop.close()
|
|
|
|
|
|
# type_id → rules directory mapping (only fixed-mapping types)
|
|
# 行政许可 (type_id=2) has 9 sub-types, NOT mapped here —
|
|
# must come from document metadata (rules_file_path) or content classification.
|
|
_TYPE_ID_RULES_MAP: dict[int, str] = {
|
|
3: "行政处罚",
|
|
}
|
|
|
|
|
|
def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None:
    """Resolve a rules path from the config override or the document's type_id.

    Resolution order:
      1. ``LEAUDIT_RULES_DIR`` when it is set (returned unchanged).
      2. ``<dir>/rules.yaml`` when the document's ``type_id`` has an entry in
         ``_TYPE_ID_RULES_MAP``.

    Args:
        document_id: ID of the row in ``leaudit_documents`` to inspect.
        loop: Event loop used to run the database query.

    Returns:
        The resolved rules path, or ``None`` when nothing matches or the
        database lookup fails (the failure is logged as a warning).
    """
    from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
    from sqlalchemy import text as sa_text

    # 1. Config override always wins.
    if LEAUDIT_RULES_DIR:
        return LEAUDIT_RULES_DIR

    async def _lookup_by_type_id():
        async with GetAsyncSession() as session:
            query_result = await session.execute(
                sa_text("SELECT type_id FROM leaudit_documents WHERE id = :did"),
                {"did": document_id},
            )
            record = query_result.fetchone()
            if not record:
                return None
            type_id = record[0]
            if not type_id or type_id not in _TYPE_ID_RULES_MAP:
                return None
            return f"{_TYPE_ID_RULES_MAP[type_id]}/rules.yaml"

    # 2. Fixed type_id mapping; any lookup error degrades to "no rules path".
    try:
        return loop.run_until_complete(_lookup_by_type_id())
    except Exception as e:
        log.warning(f"Failed to resolve rules_path from document: {e}")
        return None
|
|
|
|
|
|
def _resolve_rules_runtime(
|
|
document_id: int,
|
|
run_id: int,
|
|
explicit_rules_path: str | None,
|
|
loop: asyncio.AbstractEventLoop,
|
|
) -> dict[str, str | None]:
|
|
"""解析本次执行使用的规则来源。"""
|
|
if explicit_rules_path:
|
|
return {
|
|
"rules_path": explicit_rules_path,
|
|
"temp_rule_path": None,
|
|
"source_type": "explicit",
|
|
"source_path": explicit_rules_path,
|
|
}
|
|
|
|
resolver = RuleVersionResolver()
|
|
try:
|
|
payload = loop.run_until_complete(resolver.ResolveForRun(run_id))
|
|
if payload:
|
|
log.info(
|
|
f"run_id={run_id} 规则来源已解析: sourceType={payload.sourceType}, "
|
|
f"sourcePath={payload.sourcePath}, localPath={payload.localPath}"
|
|
)
|
|
return {
|
|
"rules_path": payload.localPath,
|
|
"temp_rule_path": payload.localPath if payload.sourceType == "oss" else None,
|
|
"source_type": payload.sourceType,
|
|
"source_path": payload.sourcePath,
|
|
}
|
|
except Exception as e:
|
|
log.warning(f"Failed to resolve rule version from run: run_id={run_id}, error={e}")
|
|
|
|
fallback_rules_path = _resolve_rules_path(document_id, loop)
|
|
return {
|
|
"rules_path": fallback_rules_path,
|
|
"temp_rule_path": None,
|
|
"source_type": "legacy_fallback",
|
|
"source_path": fallback_rules_path,
|
|
}
|
|
|
|
|
|
def _resolve_run_id(
|
|
document_id: int,
|
|
upload_info: Dict[str, Any] | None,
|
|
loop: asyncio.AbstractEventLoop,
|
|
) -> int:
|
|
"""解析本次任务对应的运行 ID。"""
|
|
if upload_info:
|
|
for key in ("run_id", "runId"):
|
|
value = upload_info.get(key)
|
|
if isinstance(value, int):
|
|
return value
|
|
|
|
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
|
from sqlalchemy import text as sa_text
|
|
|
|
async def _fetch() -> int:
|
|
async with GetAsyncSession() as session:
|
|
result = await session.execute(
|
|
sa_text("SELECT id FROM leaudit_audit_runs WHERE document_id = :did ORDER BY id DESC LIMIT 1"),
|
|
{"did": document_id},
|
|
)
|
|
row = result.fetchone()
|
|
if not row:
|
|
raise ValueError(f"未找到 document_id={document_id} 对应的 run 记录")
|
|
return int(row[0])
|
|
|
|
return loop.run_until_complete(_fetch())
|
|
|
|
|
|
def _optional_int(payload: Dict[str, Any] | None, *keys: str) -> int | None:
|
|
"""从字典中按顺序取可用整数。"""
|
|
if not payload:
|
|
return None
|
|
|
|
for key in keys:
|
|
value = payload.get(key)
|
|
if isinstance(value, int):
|
|
return value
|
|
return None
|
|
|
|
|
|
async def _update_status_safe(document_id: int, status: str) -> None:
    """Best-effort update of ``leaudit_documents.processing_status``.

    Any error (import, connection, SQL) is logged as a warning and swallowed,
    so callers can use this on failure paths without extra guarding.

    Args:
        document_id: ID of the document row to update.
        status: New value for ``processing_status``.
    """
    try:
        from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
        from sqlalchemy import text as sa_text

        async with GetAsyncSession() as session:
            await session.execute(
                sa_text("UPDATE leaudit_documents SET processing_status = :s, update_time = now() WHERE id = :did"),
                {"s": status, "did": document_id},
            )
            await session.commit()
    except Exception as e:
        # Still best-effort, but leave a trace so a stuck status can be diagnosed.
        log.warning(
            f"Failed to update document status: document_id={document_id}, status={status}, error={e}"
        )
|
|
|
|
|
|
async def _update_run_status_safe(run_id: int, status: str) -> None:
    """Best-effort update of ``leaudit_audit_runs.status``.

    Any error (import, connection, SQL) is logged as a warning and swallowed,
    so callers can use this on failure paths without extra guarding.

    Args:
        run_id: ID of the run row to update.
        status: New value for ``status``.
    """
    try:
        from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
        from sqlalchemy import text as sa_text

        async with GetAsyncSession() as session:
            await session.execute(
                sa_text("UPDATE leaudit_audit_runs SET status = :s, update_time = now() WHERE id = :rid"),
                {"s": status, "rid": run_id},
            )
            await session.commit()
    except Exception as e:
        # Still best-effort, but leave a trace so a stuck run can be diagnosed.
        log.warning(
            f"Failed to update run status: run_id={run_id}, status={status}, error={e}"
        )
|
|
|
|
|
|
async def _update_run_phase_safe(run_id: int, phase: str | None) -> None:
    """Best-effort update of ``leaudit_audit_runs.phase``.

    Any error (import, connection, SQL) is logged as a warning and swallowed,
    so callers can use this on failure paths without extra guarding.

    Args:
        run_id: ID of the run row to update.
        phase: New value for ``phase``; ``None`` is passed through as the
            bound parameter.
    """
    try:
        from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
        from sqlalchemy import text as sa_text

        async with GetAsyncSession() as session:
            await session.execute(
                sa_text("UPDATE leaudit_audit_runs SET phase = :p, update_time = now() WHERE id = :rid"),
                {"p": phase, "rid": run_id},
            )
            await session.commit()
    except Exception as e:
        # Still best-effort, but leave a trace so a stale phase can be diagnosed.
        log.warning(
            f"Failed to update run phase: run_id={run_id}, phase={phase}, error={e}"
        )
|
|
|
|
|
|
def _get_suffix(filename: str) -> str:
|
|
"""Extract file suffix from filename."""
|
|
_, ext = os.path.splitext(filename)
|
|
return ext if ext else ".pdf"
|
|
|
|
|
|
def dispatch_leaudit_task(
    document_id: int,
    file_content: bytes,
    filename: str,
    upload_info: Optional[Dict[str, Any]] = None,
    source_port: Optional[int] = None,
    rules_path: Optional[str] = None,
):
    """Dispatch a LeAudit processing task.

    P2: once Celery is integrated this should switch to
    ``leaudit_process_document.apply_async(...)``. For now the task is invoked
    synchronously and its result is returned directly.

    Args:
        document_id: ID of the document to process.
        file_content: Raw bytes of the uploaded file.
        filename: Original file name.
        upload_info: Optional metadata forwarded unchanged.
        source_port: Originating port. When falsy, the ``APP_PORT``
            environment variable (default ``"8000"``) is used instead.
        rules_path: Optional explicit rules path forwarded unchanged.

    Returns:
        Whatever ``leaudit_process_document`` returns.
    """
    if not source_port:
        raw_port = os.getenv("APP_PORT", "8000")
        try:
            source_port = int(raw_port)
        except ValueError:
            # A malformed APP_PORT must not abort document processing; the
            # port is passed on as unknown instead.
            log.warning(f"Ignoring invalid APP_PORT value: {raw_port!r}")
            source_port = None

    return leaudit_process_document(
        document_id=document_id,
        file_content=file_content,
        filename=filename,
        upload_info=upload_info,
        source_port=source_port,
        rules_path=rules_path,
    )
|