feat: complete M1-M3 infrastructure — OSS client, native execution chain, rule lifecycle API, system docs
- M1: unified OSS client (upload/download/presign) + path utils + config - M2: rule service with validate/create/publish/rollback + binding CRUD endpoints - M3: native AuditCtx runner, file/rule resolvers, storage adapter with full persistence - docs: SYSTEM_OVERVIEW.md as comprehensive architecture reference - fix: double finalize — terminal state now written once by finalize_run
This commit is contained in:
@@ -1,28 +1,38 @@
|
||||
"""Celery task for leaudit pipeline processing.
|
||||
|
||||
Activated when PIPELINE_MODE=leaudit in env.{port} config.
|
||||
Replaces the legacy OCR → extraction → evaluation pipeline with
|
||||
leaudit's YAML-rules-driven approach.
|
||||
"""
|
||||
"""LeAudit 任务入口。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from core.celery_app_limited import celery_app
|
||||
from core.postgrest.client import get_postgrest_client
|
||||
from core.logger import log
|
||||
from fastapi_common.fastapi_common_logger import logger
|
||||
|
||||
from leaudit_bridge import create_pipeline, RulesLoader
|
||||
from fastapi_admin.config import LEAUDIT_RULES_DIR
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import (
|
||||
NativeRunRequest,
|
||||
NativeRunner,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import (
|
||||
NativeAuditMetadata,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ruleVersionResolver import (
|
||||
RuleVersionResolver,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
|
||||
|
||||
# Celery 集成待 P2 阶段实现,当前使用同步占位
|
||||
# from core.celery_app_limited import celery_app
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="leaudit.process_document")
|
||||
# P2: Celery 集成后启用 @celery_app.task 装饰器
|
||||
def leaudit_process_document(
|
||||
self,
|
||||
document_id: int,
|
||||
file_content: bytes,
|
||||
filename: str,
|
||||
@@ -30,43 +40,41 @@ def leaudit_process_document(
|
||||
source_port: Optional[int] = None,
|
||||
rules_path: Optional[str] = None,
|
||||
):
|
||||
"""Process a document using leaudit's full pipeline.
|
||||
|
||||
Steps: OCR → Extraction → Evaluation → Store in docauditai DB.
|
||||
"""
|
||||
task_id = self.request.id
|
||||
log.task.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")
|
||||
"""处理单个文档的 LeAudit 任务。"""
|
||||
task_id = os.urandom(8).hex()
|
||||
log.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")
|
||||
|
||||
# 新平台:region 通过参数传递,不再依赖 os.environ 切换
|
||||
if source_port:
|
||||
from core.utils.instance_context import set_instance_environment
|
||||
instance_name = set_instance_environment(source_port)
|
||||
log.task.info(
|
||||
f"[任务ID: {task_id}] 实例环境: {instance_name} (端口: {source_port})"
|
||||
)
|
||||
log.info(f"[任务ID: {task_id}] 来源端口: {source_port}")
|
||||
|
||||
if upload_info is None:
|
||||
upload_info = {}
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
temp_paths: list[str] = []
|
||||
storage = StorageAdapter()
|
||||
|
||||
try:
|
||||
rules_path_resolved = rules_path or _resolve_rules_path(document_id, loop)
|
||||
run_id = _resolve_run_id(document_id, upload_info, loop)
|
||||
rules_resolution = _resolve_rules_runtime(document_id, run_id, rules_path, loop)
|
||||
loop.run_until_complete(_update_run_status_safe(run_id, "running"))
|
||||
rules_path_resolved = rules_resolution["rules_path"]
|
||||
|
||||
# For types with a known mapping (e.g. 行政处罚), pre-load rules_file.
|
||||
# For types that need content classification (e.g. 行政许可 sub-types),
|
||||
# rules_path will be None → adapter classifies after OCR → pipeline
|
||||
# loads rules from ocr_result.rules_file_path.
|
||||
rules_file = None
|
||||
if rules_path_resolved:
|
||||
loader = RulesLoader()
|
||||
rules_file = loader.load(rules_path_resolved)
|
||||
log.task.info(
|
||||
temp_rule_path = rules_resolution.get("temp_rule_path")
|
||||
if isinstance(temp_rule_path, str):
|
||||
temp_paths.append(temp_rule_path)
|
||||
log.info(
|
||||
f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} "
|
||||
f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)"
|
||||
)
|
||||
else:
|
||||
log.task.info(
|
||||
log.info(
|
||||
f"[任务ID: {task_id}] No fixed rules_path — "
|
||||
"will classify from document content after OCR"
|
||||
)
|
||||
@@ -75,47 +83,79 @@ def leaudit_process_document(
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp:
|
||||
temp.write(file_content)
|
||||
temp_path = temp.name
|
||||
temp_paths.append(temp_path)
|
||||
|
||||
pipeline = create_pipeline(rules_path=rules_path_resolved)
|
||||
runner = NativeRunner()
|
||||
|
||||
t0 = time.time()
|
||||
result = loop.run_until_complete(
|
||||
pipeline.run(
|
||||
document_id=document_id,
|
||||
file_path=temp_path,
|
||||
rules_file=rules_file,
|
||||
source_port=source_port or int(os.getenv("APP_PORT", "8000")),
|
||||
native_result = loop.run_until_complete(
|
||||
runner.run(
|
||||
NativeRunRequest(
|
||||
metadata=NativeAuditMetadata(
|
||||
run_id=run_id,
|
||||
document_id=document_id,
|
||||
rule_version_id=_optional_int(upload_info, "rule_version_id", "ruleVersionId"),
|
||||
extras={"taskId": task_id},
|
||||
),
|
||||
local_file_path=temp_path,
|
||||
rules_file=rules_file,
|
||||
rule_source_path=rules_path_resolved,
|
||||
)
|
||||
)
|
||||
)
|
||||
loop.run_until_complete(runner.persist_result(native_result))
|
||||
elapsed = round(time.time() - t0, 2)
|
||||
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
log.task.info(
|
||||
f"[任务ID: {task_id}] leaudit管线完成: phase={result.detected_phase}, "
|
||||
f"timing={result.timing}, 总耗时={elapsed:.1f}s"
|
||||
ctx = native_result.ctx
|
||||
loop.run_until_complete(_update_run_phase_safe(run_id, ctx.phase))
|
||||
loop.run_until_complete(_update_run_status_safe(run_id, "completed"))
|
||||
loop.run_until_complete(_update_status_safe(document_id, "completed"))
|
||||
log.info(
|
||||
f"[任务ID: {task_id}] leaudit管线完成: phase={ctx.phase}, "
|
||||
f"timing={dict(ctx.timing)}, 总耗时={elapsed:.1f}s"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"document_id": document_id,
|
||||
"phase": result.detected_phase,
|
||||
"timing": result.timing,
|
||||
"errors": result.errors,
|
||||
"run_id": run_id,
|
||||
"phase": ctx.phase,
|
||||
"timing": dict(ctx.timing),
|
||||
"errors": list(ctx.extraction.all_errors) if ctx.extraction is not None else list(ctx.extraction_errors),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
log.task.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
|
||||
log.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
|
||||
try:
|
||||
loop.run_until_complete(_update_status_safe(document_id, "Failed"))
|
||||
loop.run_until_complete(_update_status_safe(document_id, "failed"))
|
||||
if 'run_id' in locals():
|
||||
failed_phase = "persist"
|
||||
if "native_result" in locals():
|
||||
failed_phase = native_result.ctx.phase or failed_phase
|
||||
loop.run_until_complete(
|
||||
storage.fail_run(
|
||||
document_id,
|
||||
run_id=run_id,
|
||||
phase=failed_phase,
|
||||
message=str(e),
|
||||
detail_json={
|
||||
"taskId": task_id,
|
||||
"filename": filename,
|
||||
"errorType": type(e).__name__,
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
finally:
|
||||
for temp_path in temp_paths:
|
||||
try:
|
||||
if Path(temp_path).exists():
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
loop.close()
|
||||
|
||||
|
||||
@@ -129,49 +169,158 @@ _TYPE_ID_RULES_MAP: dict[int, str] = {
|
||||
|
||||
def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None:
|
||||
"""Resolve rules_path: config override → document metadata → type_id mapping."""
|
||||
from core.config import LEAUDIT_CONFIG
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from sqlalchemy import text as sa_text
|
||||
|
||||
# 1. Config override (when explicitly set)
|
||||
config_path = LEAUDIT_CONFIG.get("RULES_PATH", "")
|
||||
if config_path:
|
||||
return config_path
|
||||
# 1. Config override (when explicitly set in app.toml)
|
||||
if LEAUDIT_RULES_DIR:
|
||||
return LEAUDIT_RULES_DIR
|
||||
|
||||
try:
|
||||
client = get_postgrest_client()
|
||||
doc = loop.run_until_complete(
|
||||
client.select(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
single=True,
|
||||
)
|
||||
)
|
||||
if not doc:
|
||||
return None
|
||||
async def _fetch():
|
||||
async with GetAsyncSession() as session:
|
||||
result = await session.execute(
|
||||
sa_text("SELECT type_id FROM leaudit_documents WHERE id = :did"),
|
||||
{"did": document_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if row and row[0] and row[0] in _TYPE_ID_RULES_MAP:
|
||||
return f"{_TYPE_ID_RULES_MAP[row[0]]}/rules.yaml"
|
||||
return None
|
||||
|
||||
# 2. Document-level override
|
||||
rfp = doc.get("rules_file_path")
|
||||
if rfp:
|
||||
return rfp
|
||||
|
||||
# 3. type_id mapping
|
||||
type_id = doc.get("type_id")
|
||||
if type_id and type_id in _TYPE_ID_RULES_MAP:
|
||||
return f"{_TYPE_ID_RULES_MAP[type_id]}/rules.yaml"
|
||||
return loop.run_until_complete(_fetch())
|
||||
except Exception as e:
|
||||
log.task.warning(f"Failed to resolve rules_path from document: {e}")
|
||||
log.warning(f"Failed to resolve rules_path from document: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _update_status_safe(document_id: int, status: str) -> None:
|
||||
"""Safely update document status, ignoring errors."""
|
||||
def _resolve_rules_runtime(
|
||||
document_id: int,
|
||||
run_id: int,
|
||||
explicit_rules_path: str | None,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> dict[str, str | None]:
|
||||
"""解析本次执行使用的规则来源。"""
|
||||
if explicit_rules_path:
|
||||
return {
|
||||
"rules_path": explicit_rules_path,
|
||||
"temp_rule_path": None,
|
||||
"source_type": "explicit",
|
||||
"source_path": explicit_rules_path,
|
||||
}
|
||||
|
||||
resolver = RuleVersionResolver()
|
||||
try:
|
||||
client = get_postgrest_client()
|
||||
await client.update(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
data={"status": status},
|
||||
)
|
||||
payload = loop.run_until_complete(resolver.ResolveForRun(run_id))
|
||||
if payload:
|
||||
log.info(
|
||||
f"run_id={run_id} 规则来源已解析: sourceType={payload.sourceType}, "
|
||||
f"sourcePath={payload.sourcePath}, localPath={payload.localPath}"
|
||||
)
|
||||
return {
|
||||
"rules_path": payload.localPath,
|
||||
"temp_rule_path": payload.localPath if payload.sourceType == "oss" else None,
|
||||
"source_type": payload.sourceType,
|
||||
"source_path": payload.sourcePath,
|
||||
}
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to resolve rule version from run: run_id={run_id}, error={e}")
|
||||
|
||||
fallback_rules_path = _resolve_rules_path(document_id, loop)
|
||||
return {
|
||||
"rules_path": fallback_rules_path,
|
||||
"temp_rule_path": None,
|
||||
"source_type": "legacy_fallback",
|
||||
"source_path": fallback_rules_path,
|
||||
}
|
||||
|
||||
|
||||
def _resolve_run_id(
|
||||
document_id: int,
|
||||
upload_info: Dict[str, Any] | None,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> int:
|
||||
"""解析本次任务对应的运行 ID。"""
|
||||
if upload_info:
|
||||
for key in ("run_id", "runId"):
|
||||
value = upload_info.get(key)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from sqlalchemy import text as sa_text
|
||||
|
||||
async def _fetch() -> int:
|
||||
async with GetAsyncSession() as session:
|
||||
result = await session.execute(
|
||||
sa_text("SELECT id FROM leaudit_audit_runs WHERE document_id = :did ORDER BY id DESC LIMIT 1"),
|
||||
{"did": document_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if not row:
|
||||
raise ValueError(f"未找到 document_id={document_id} 对应的 run 记录")
|
||||
return int(row[0])
|
||||
|
||||
return loop.run_until_complete(_fetch())
|
||||
|
||||
|
||||
def _optional_int(payload: Dict[str, Any] | None, *keys: str) -> int | None:
|
||||
"""从字典中按顺序取可用整数。"""
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
for key in keys:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
async def _update_status_safe(document_id: int, status: str) -> None:
|
||||
"""Safely update document status via SQLAlchemy, ignoring errors."""
|
||||
try:
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from sqlalchemy import text as sa_text
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
sa_text("UPDATE leaudit_documents SET processing_status = :s, update_time = now() WHERE id = :did"),
|
||||
{"s": status, "did": document_id},
|
||||
)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def _update_run_status_safe(run_id: int, status: str) -> None:
|
||||
"""安全更新运行状态。"""
|
||||
try:
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from sqlalchemy import text as sa_text
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
sa_text("UPDATE leaudit_audit_runs SET status = :s, update_time = now() WHERE id = :rid"),
|
||||
{"s": status, "rid": run_id},
|
||||
)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
async def _update_run_phase_safe(run_id: int, phase: str | None) -> None:
|
||||
"""安全更新运行阶段。"""
|
||||
try:
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from sqlalchemy import text as sa_text
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
sa_text("UPDATE leaudit_audit_runs SET phase = :p, update_time = now() WHERE id = :rid"),
|
||||
{"p": phase, "rid": run_id},
|
||||
)
|
||||
await session.commit()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -190,12 +339,16 @@ def dispatch_leaudit_task(
|
||||
source_port: Optional[int] = None,
|
||||
rules_path: Optional[str] = None,
|
||||
):
|
||||
"""Dispatch a leaudit processing task."""
|
||||
return leaudit_process_document.apply_async(
|
||||
args=[document_id, file_content, filename],
|
||||
kwargs={
|
||||
"upload_info": upload_info,
|
||||
"source_port": source_port or int(os.getenv("APP_PORT", "8000")),
|
||||
"rules_path": rules_path,
|
||||
},
|
||||
"""Dispatch a leaudit processing task.
|
||||
|
||||
P2: Celery 集成后改用 leaudit_process_document.apply_async(...)
|
||||
当前阶段直接同步调用。
|
||||
"""
|
||||
return leaudit_process_document(
|
||||
document_id=document_id,
|
||||
file_content=file_content,
|
||||
filename=filename,
|
||||
upload_info=upload_info,
|
||||
source_port=source_port or int(os.getenv("APP_PORT", "8000")),
|
||||
rules_path=rules_path,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user