feat: complete M1-M3 infrastructure — OSS client, native execution chain, rule lifecycle API, system docs
- M1: unified OSS client (upload/download/presign) + path utils + config - M2: rule service with validate/create/publish/rollback + binding CRUD endpoints - M3: native AuditCtx runner, file/rule resolvers, storage adapter with full persistence - docs: SYSTEM_OVERVIEW.md as comprehensive architecture reference - fix: double finalize — terminal state now written once by finalize_run
This commit is contained in:
@@ -1,44 +1,30 @@
|
||||
"""leaudit bridge — use leaudit's full pipeline with docauditai's database storage.
|
||||
"""LeAudit Bridge 模块。
|
||||
|
||||
Directly calls leaudit's OCR → extraction → evaluation pipeline
|
||||
and persists results into docauditai's PostgreSQL via PostgREST.
|
||||
|
||||
Configuration switch (in env.{port}):
|
||||
PIPELINE_MODE=leaudit → use leaudit pipeline
|
||||
对平台暴露统一桥接入口,内部逐步从旧的手写 pipeline
|
||||
迁移到原生 ``AuditCtx`` + ``AuditService`` 路线。
|
||||
"""
|
||||
|
||||
from leaudit_bridge.client_factory import (
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.client_factory import (
|
||||
create_ocr_client,
|
||||
create_llm_client,
|
||||
create_vlm_client,
|
||||
)
|
||||
from leaudit_bridge.ocr_bridge import BridgeOCRClient
|
||||
from leaudit_bridge.pipeline import LauditPipeline, PipelineResult
|
||||
from leaudit_bridge.rules_loader import RulesLoader
|
||||
from leaudit_bridge.storage_adapter import StorageAdapter
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ocr_bridge import BridgeOCRClient
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline import LauditPipeline, PipelineResult
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
|
||||
|
||||
|
||||
def is_leaudit_mode() -> bool:
|
||||
"""Check if the system is configured to use the leaudit pipeline."""
|
||||
from core.config import PIPELINE_MODE
|
||||
return PIPELINE_MODE == "leaudit"
|
||||
"""新平台始终使用 leaudit pipeline。"""
|
||||
return True
|
||||
|
||||
|
||||
def create_pipeline(rules_path: str | None = None) -> LauditPipeline:
|
||||
"""Create a fully configured LauditPipeline from current config.
|
||||
"""创建旧版兼容 LauditPipeline。
|
||||
|
||||
Wraps the raw OCR client with DocNormalizationAdapter so that a single
|
||||
``.ocr()`` call produces a fully enriched OcrResult with:
|
||||
- Document classification (type_id + rules_file_path)
|
||||
- Dossier segmentation (sub-document page mapping)
|
||||
- Seal/signature enrichment (text, seal_id, party_id)
|
||||
- Normalized markdown (seal blocks + page separators)
|
||||
|
||||
Args:
|
||||
rules_path: If provided, forces the adapter to use this rules file
|
||||
for classification and segmentation. When None, the adapter
|
||||
uses the RulesFileRegistry to classify from document content,
|
||||
enabling auto-detection of sub-types (e.g. 行政许可 variants).
|
||||
当前仍保留该入口兼容旧调用方,后续正式执行链应逐步切到
|
||||
``NativeRunner``。
|
||||
"""
|
||||
from pathlib import Path
|
||||
from leaudit.doc_normalization.adapter import DocNormalizationAdapter
|
||||
@@ -51,7 +37,7 @@ def create_pipeline(rules_path: str | None = None) -> LauditPipeline:
|
||||
# Build registry from rules/ directory for content-based classification
|
||||
registry = None
|
||||
if rules_path is None:
|
||||
rules_dir = Path(__file__).resolve().parents[1] / "rules"
|
||||
rules_dir = Path(__file__).resolve().parents[3] / "rules"
|
||||
if rules_dir.is_dir():
|
||||
registry = RulesFileRegistry.from_directory(rules_dir)
|
||||
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Build native leaudit ``AuditCtx`` instances from platform-side inputs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from leaudit.config.audit_config import AuditConfig
|
||||
from leaudit.services.audit_ctx import AuditCtx
|
||||
from leaudit.services.audit_services import AuditServices
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NativeAuditMetadata:
    """Platform-side identifiers carried next to, not inside, the native ``AuditCtx``.

    Instances are frozen: attributes cannot be reassigned after construction.
    Only ``run_id`` and ``document_id`` are required; every other identifier
    defaults to ``None`` and ``extras`` defaults to a fresh empty dict.
    """

    # Required identifiers of the platform run and of the audited document.
    run_id: int
    document_id: int

    # Optional platform identifiers; ``None`` when not supplied by the caller.
    document_file_id: int | None = None
    rule_set_id: int | None = None
    rule_version_id: int | None = None
    trigger_user_id: int | None = None

    # Free-form additional metadata; each instance gets its own dict.
    extras: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NativeAuditBuildInput:
    """Inputs the bridge has gathered before it constructs a native ``AuditCtx``.

    Frozen value object consumed by ``AuditCtxBuilder.build``.
    """

    # Platform identifiers for the run; only ``document_id`` is copied into the ctx.
    metadata: NativeAuditMetadata
    # Path of the document file handed to the native ctx.
    file_path: str
    # Native service container handed to the native ctx.
    services: AuditServices

    # Already-loaded rules object, forwarded to the ctx unchanged (may be None).
    rules_file: Any | None = None
    # Optional page selection, forwarded to the ctx unchanged.
    page_range: tuple[int, ...] | None = None
    # Rules path used as the forced path when ``force_rules_path`` is not set.
    rule_source_path: str | None = None
    # Explicit forced rules path; takes precedence over ``rule_source_path``.
    force_rules_path: str | None = None
    # Keyword overrides for ``AuditConfig``; each instance gets its own dict.
    config_overrides: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class AuditCtxBuilder:
    """Turn platform-side run inputs into leaudit's native ``AuditCtx``."""

    def build(self, payload: NativeAuditBuildInput) -> AuditCtx:
        """Return an ``AuditCtx`` populated from ``payload``.

        The forced rules path given to the config is ``payload.force_rules_path``
        when it is truthy, otherwise ``payload.rule_source_path``. The platform
        ``document_id`` is passed to ``AuditCtx`` as a string.
        """
        forced_path = payload.force_rules_path or payload.rule_source_path
        audit_config = self.build_config(
            force_rules_path=forced_path,
            overrides=payload.config_overrides,
        )
        return AuditCtx(
            document_id=str(payload.metadata.document_id),
            rules_file=payload.rules_file,
            services=payload.services,
            file_path=payload.file_path,
            page_range=payload.page_range,
            config=audit_config,
        )

    def build_config(
        self,
        *,
        force_rules_path: str | None = None,
        overrides: dict[str, Any] | None = None,
    ) -> AuditConfig:
        """Create an ``AuditConfig`` from keyword overrides.

        Args:
            force_rules_path: Added under the ``"force_rules_path"`` key only when
                it is truthy and ``overrides`` does not already contain that key.
            overrides: Keyword arguments for ``AuditConfig``; copied, never mutated.

        Returns:
            ``AuditConfig(**merged_overrides)``.
        """
        settings = dict(overrides) if overrides else {}
        if force_rules_path:
            # An explicit override supplied by the caller wins over the argument.
            settings.setdefault("force_rules_path", force_rules_path)
        return AuditConfig(**settings)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AuditCtxBuilder",
|
||||
"NativeAuditBuildInput",
|
||||
"NativeAuditMetadata",
|
||||
]
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Factory helpers for leaudit's native service-layer orchestration.
|
||||
|
||||
This module is the bridge-side assembly point for native leaudit services:
|
||||
|
||||
- ``AuditServices``
|
||||
- ``DocNormalizationService``
|
||||
- ``ExtractionService``
|
||||
- ``EvaluationService``
|
||||
- ``RescueService``
|
||||
- ``AuditService``
|
||||
|
||||
The platform should not construct these objects in controllers/services
|
||||
directly. Keep all leaudit-native wiring inside ``leaudit_bridge/``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from leaudit.services.audit_service import AuditService
|
||||
from leaudit.services.audit_services import AuditServices
|
||||
from leaudit.services.doc_normalization_service import DocNormalizationService
|
||||
from leaudit.services.evaluation_service import EvaluationService
|
||||
from leaudit.services.extraction_service import ExtractionService
|
||||
from leaudit.services.rescue_service import RescueService
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.client_factory import (
|
||||
create_llm_client,
|
||||
create_ocr_client,
|
||||
create_vlm_client,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ocr_bridge import BridgeOCRClient
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NativeServiceBundle:
    """Immutable holder for one fully wired set of native leaudit services."""

    # Top-level orchestrator whose ``audit(ctx)`` executes a run.
    audit_service: AuditService
    # Service container; ``AuditServiceFactory`` fills it from ``audit_service.services``.
    audit_services: AuditServices
    # Individual stage services, exposed so callers can reach them directly.
    normalization_service: DocNormalizationService
    extraction_service: ExtractionService
    evaluation_service: EvaluationService
    # Rescue stage service; the annotation allows ``None``.
    rescue_service: RescueService | None
|
||||
|
||||
|
||||
class AuditServiceFactory:
    """Assemble the native leaudit service objects needed for one platform run."""

    def create_bundle(self, rules_path: str | None = None) -> NativeServiceBundle:
        """Wire every native service together and return them as one bundle.

        Args:
            rules_path: When given, it is passed to the normalization adapter as
                its forced rules path and no rules registry is built. When
                ``None``, a registry is built from the ``rules`` directory if
                that directory exists.

        Returns:
            A ``NativeServiceBundle`` whose ``audit_services`` is taken from the
            created ``AuditService`` instance.
        """
        normalization_service, shared = self._create_normalization_services(
            rules_path=rules_path
        )
        llm_client = shared.llm_client
        vlm_client = shared.vlm_client

        # All stage services are created without a session (``session=None``).
        extraction_service = ExtractionService(session=None, llm_client=llm_client)
        evaluation_service = EvaluationService(session=None)
        rescue_service = RescueService(
            session=None,
            llm_client=llm_client,
            vlm_client=vlm_client,
            extraction_service=extraction_service,
        )

        # Second container: same clients as ``shared`` plus the stage services
        # that did not exist yet when ``shared`` was built.
        full_services = AuditServices(
            llm_client=llm_client,
            vlm_client=vlm_client,
            ocr_client=shared.ocr_client,
            normalization=normalization_service,
            extraction=extraction_service,
            evaluation=evaluation_service,
        )
        audit_service = AuditService(
            document_service=None,
            normalization_service=normalization_service,
            extraction_service=extraction_service,
            evaluation_service=evaluation_service,
            rescue_service=rescue_service,
            services=full_services,
        )

        return NativeServiceBundle(
            audit_service=audit_service,
            audit_services=audit_service.services,
            normalization_service=normalization_service,
            extraction_service=extraction_service,
            evaluation_service=evaluation_service,
            rescue_service=rescue_service,
        )

    def _create_normalization_services(
        self, rules_path: str | None = None
    ) -> tuple[DocNormalizationService, AuditServices]:
        """Create the normalization service and a container of the shared clients.

        The returned ``AuditServices`` holds the raw OCR client, while the
        normalization service is built on a ``BridgeOCRClient`` that wraps the
        ``DocNormalizationAdapter``.
        """
        from leaudit.doc_normalization.adapter import DocNormalizationAdapter
        from leaudit.doc_normalization.doc_classifier import RulesFileRegistry

        raw_ocr = create_ocr_client()
        llm_client = create_llm_client()
        vlm_client = create_vlm_client()

        # A registry is only built when no fixed rules file was requested and
        # the ``rules`` directory three levels above this file exists.
        candidate_dir = (
            Path(__file__).resolve().parents[3] / "rules" if rules_path is None else None
        )
        if candidate_dir is not None and candidate_dir.is_dir():
            registry = RulesFileRegistry.from_directory(candidate_dir)
        else:
            registry = None

        adapter = DocNormalizationAdapter(
            ocr_client=raw_ocr,
            registry=registry,
            llm_client=llm_client,
            vlm_client=vlm_client,
            force_rules_path=rules_path,
        )
        bridged_ocr = BridgeOCRClient(adapter, vlm_client=vlm_client)
        normalization_service = DocNormalizationService(bridged_ocr)

        shared_services = AuditServices(
            llm_client=llm_client,
            vlm_client=vlm_client,
            ocr_client=raw_ocr,
            normalization=normalization_service,
        )
        return normalization_service, shared_services
|
||||
|
||||
|
||||
__all__ = ["AuditServiceFactory", "NativeServiceBundle"]
|
||||
@@ -5,13 +5,15 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.config import (
|
||||
OCR_CONFIG,
|
||||
DEFAULT_BASE_URL,
|
||||
DEFAULT_LLM_MODEL,
|
||||
DEFAULT_API_KEY,
|
||||
DEFAULT_VLM_BASE_URL,
|
||||
DEFAULT_VLM_MODEL,
|
||||
from fastapi_admin.config import (
|
||||
OCR_BASE_URL,
|
||||
OCR_TIMEOUT,
|
||||
LLM_BASE_URL,
|
||||
LLM_MODEL,
|
||||
LLM_API_KEY,
|
||||
VLM_BASE_URL,
|
||||
VLM_MODEL,
|
||||
VLM_API_KEY,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -29,8 +31,8 @@ def create_ocr_client() -> BaseOCRClient:
|
||||
|
||||
base_url = os.getenv("LEAUDIT_OCR_URL", "").rstrip("/")
|
||||
if not base_url:
|
||||
base_url = OCR_CONFIG["API_URL"].rsplit("/api/v1/ocr", 1)[0]
|
||||
timeout = float(OCR_CONFIG["TIMEOUT"])
|
||||
base_url = OCR_BASE_URL.rstrip("/")
|
||||
timeout = float(OCR_TIMEOUT)
|
||||
|
||||
client = ChandraOCRClient(
|
||||
base_url=base_url,
|
||||
@@ -45,9 +47,9 @@ def create_llm_client() -> BaseLLMClient:
|
||||
"""Create a leaudit OpenAICompatibleClient from docauditai's LLM config."""
|
||||
from leaudit.llm.openai_client import OpenAICompatibleClient
|
||||
|
||||
base_url = DEFAULT_BASE_URL
|
||||
model = DEFAULT_LLM_MODEL
|
||||
api_key = DEFAULT_API_KEY or "no-key"
|
||||
base_url = LLM_BASE_URL
|
||||
model = LLM_MODEL
|
||||
api_key = LLM_API_KEY or "no-key"
|
||||
|
||||
client = OpenAICompatibleClient(
|
||||
api_key=api_key,
|
||||
@@ -63,9 +65,9 @@ def create_vlm_client() -> BaseVLMClient | None:
|
||||
"""Create a leaudit QwenVLMClient from docauditai's VLM config."""
|
||||
from leaudit.llm.qwen_vlm_client import QwenVLMClient
|
||||
|
||||
base_url = DEFAULT_VLM_BASE_URL
|
||||
model = DEFAULT_VLM_MODEL
|
||||
api_key = DEFAULT_API_KEY or "no-key"
|
||||
base_url = VLM_BASE_URL
|
||||
model = VLM_MODEL
|
||||
api_key = VLM_API_KEY or LLM_API_KEY or "no-key"
|
||||
|
||||
if not base_url or not model:
|
||||
log.info("leaudit VLM client skipped: no VLM config")
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
"""文档文件来源解析器。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi_common.fastapi_common_logger import logger
|
||||
from fastapi_common.fastapi_common_storage.oss_client import OssClient
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.models.leauditDocumentFile import LeauditDocumentFile
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class FileSourcePayload:
    """Immutable file payload that the execution chain can consume directly."""

    # File name as recorded on the document file.
    fileName: str
    # Raw bytes of the file.
    fileContent: bytes
    # Origin of the bytes; ``FileSourceResolver`` sets "local" or "oss".
    sourceType: str
    # Local path or URL the bytes were read from, when known.
    sourcePath: str | None = None
|
||||
|
||||
|
||||
class FileSourceResolver:
    """Resolve a document file record into raw bytes for the execution chain."""

    def __init__(self, Oss: OssClient | None = None) -> None:
        """Keep the given OSS client, or create a default ``OssClient`` if it is falsy."""
        self.Oss = Oss or OssClient()

    async def ResolvePayload(self, DocumentFile: LeauditDocumentFile) -> FileSourcePayload:
        """Return the file's bytes, preferring an existing local copy over OSS.

        Resolution order:
            1. ``DocumentFile.localPath`` when it is set and points to a file.
            2. ``DocumentFile.ossUrl`` when it is set (downloaded via the OSS client).

        Raises:
            ValueError: If neither a usable local path nor an OSS URL is available.
        """
        # NOTE(review): ``read_bytes`` is a blocking call inside a coroutine.
        LocalCopy = Path(DocumentFile.localPath) if DocumentFile.localPath else None
        if LocalCopy is not None and LocalCopy.is_file():
            return FileSourcePayload(
                fileName=DocumentFile.fileName,
                fileContent=LocalCopy.read_bytes(),
                sourceType="local",
                sourcePath=str(LocalCopy),
            )

        if not DocumentFile.ossUrl:
            raise ValueError("当前文档文件既无可用 localPath,也无可用 ossUrl")

        return await self._DownloadFromUrl(
            FileName=DocumentFile.fileName,
            Url=DocumentFile.ossUrl,
        )

    async def _DownloadFromUrl(self, FileName: str, Url: str) -> FileSourcePayload:
        """Download ``Url`` through the OSS client and wrap the bytes in a payload.

        Any download exception is logged and then re-raised unchanged.
        """
        # NOTE(review): ``DownloadBytes`` is called without ``await`` — presumably
        # a synchronous (blocking) client call; confirm against ``OssClient``.
        try:
            Downloaded = self.Oss.DownloadBytes(Url)
        except Exception as Error:
            logger.error(f"下载 OSS 文件失败: url={Url}, error={Error}")
            raise
        else:
            return FileSourcePayload(
                fileName=FileName,
                fileContent=Downloaded,
                sourceType="oss",
                sourcePath=Url,
            )
|
||||
@@ -0,0 +1,185 @@
|
||||
"""Native AuditCtx runner for the platform bridge.
|
||||
|
||||
This module is the target execution path after deprecating the old
|
||||
platform-side hand-written pipeline orchestration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from leaudit.services.audit_ctx import AuditCtx
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import (
|
||||
AuditCtxBuilder,
|
||||
NativeAuditBuildInput,
|
||||
NativeAuditMetadata,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditServiceFactory import (
|
||||
AuditServiceFactory,
|
||||
NativeServiceBundle,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NativeRunRequest:
    """Immutable platform-side request describing one native leaudit run."""

    # Platform identifiers for the run; returned unchanged on the run result.
    metadata: NativeAuditMetadata
    # Path of the document file passed to the ctx builder as ``file_path``.
    local_file_path: str

    # Already-loaded rules object, forwarded to the ctx builder (may be None).
    rules_file: Any | None = None
    # Rules path used when ``rules_path_override`` is not set.
    rule_source_path: str | None = None
    # Explicit rules path; takes precedence over ``rule_source_path``.
    rules_path_override: str | None = None
    # Optional page selection forwarded to the ctx builder.
    page_range: tuple[int, ...] | None = None
    # Keyword overrides forwarded to the ctx builder; own dict per instance.
    config_overrides: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NativeRunResult:
    """Immutable outcome of one native ``AuditService.audit(ctx)`` call."""

    # Context returned by the native audit call.
    ctx: AuditCtx
    # Service bundle that executed the run.
    service_bundle: NativeServiceBundle
    # Platform metadata copied from the originating request.
    metadata: NativeAuditMetadata
|
||||
|
||||
|
||||
class NativeRunner:
    """Run audits through native leaudit orchestration and persist their results."""

    def __init__(
        self,
        *,
        service_factory: AuditServiceFactory | None = None,
        ctx_builder: AuditCtxBuilder | None = None,
        storage_adapter: StorageAdapter | None = None,
    ) -> None:
        """Store the collaborators, creating a default for any that is not supplied."""
        self.service_factory = service_factory or AuditServiceFactory()
        self.ctx_builder = ctx_builder or AuditCtxBuilder()
        self.storage = storage_adapter or StorageAdapter()

    async def run(self, request: NativeRunRequest) -> NativeRunResult:
        """Execute one native leaudit run and return the populated ctx.

        Nothing is persisted here; the caller decides when to hand the result
        to ``persist_result``.
        """
        # The override wins; otherwise the run's rule source path is used.
        forced_rules_path = request.rules_path_override or request.rule_source_path
        bundle = self.service_factory.create_bundle(rules_path=forced_rules_path)

        build_input = NativeAuditBuildInput(
            metadata=request.metadata,
            file_path=request.local_file_path,
            services=bundle.audit_services,
            rules_file=request.rules_file,
            page_range=request.page_range,
            rule_source_path=request.rule_source_path,
            force_rules_path=request.rules_path_override,
            config_overrides=request.config_overrides,
        )
        initial_ctx = self.ctx_builder.build(build_input)
        populated_ctx = await bundle.audit_service.audit(initial_ctx)

        return NativeRunResult(
            ctx=populated_ctx,
            service_bundle=bundle,
            metadata=request.metadata,
        )

    async def persist_result(self, result: NativeRunResult) -> None:
        """Write a finished native run into the platform-owned ``leaudit_*`` tables.

        Stage outputs are saved only when present on the ctx, in this order:
        OCR result, extraction result, evaluation results, extraction warnings,
        rescue outcomes. Run metrics are always saved, and ``finalize_run`` is
        always called last.
        """
        meta = result.metadata
        ctx = result.ctx
        document_id = meta.document_id
        run_id = meta.run_id

        # Prefer errors recorded on the ctx; fall back to the extraction's own list.
        warnings = list(ctx.extraction_errors)
        if ctx.extraction is not None and not warnings:
            warnings = list(ctx.extraction.all_errors)

        if ctx.normalized_doc is not None:
            await self.storage.save_ocr_result(
                document_id,
                ctx.normalized_doc,
                run_id=run_id,
            )

        if ctx.extraction is not None:
            await self.storage.save_extraction_result(
                document_id,
                ctx.extraction,
                run_id=run_id,
            )

        # Evaluation rows need the rules file and the extraction as well.
        evaluation_inputs = (ctx.evaluation, ctx.rules_file, ctx.extraction)
        if all(part is not None for part in evaluation_inputs):
            await self.storage.save_evaluation_results(
                document_id,
                ctx.rules_file,
                ctx.evaluation,
                ctx.extraction,
                run_id=run_id,
                rule_version_id=meta.rule_version_id,
            )

        if warnings:
            await self.storage.save_run_errors(
                document_id,
                run_id=run_id,
                stage=ctx.phase or "extract",
                messages=warnings,
                level="warning",
                error_code="EXTRACTION_WARNING",
            )

        if ctx.fallback_tasks:
            await self.storage.save_rescue_outcomes(
                document_id,
                run_id=run_id,
                tasks=ctx.fallback_tasks,
            )

        # Derive the metric counters from whatever stages produced output.
        if ctx.normalized_doc is not None:
            page_count = len(ctx.normalized_doc.pages)
        else:
            page_count = None

        if ctx.extraction is not None:
            sub_documents = getattr(ctx.extraction, "sub_documents", None)
            field_count = len(ctx.extraction.fields)
        else:
            sub_documents = None
            field_count = 0
        sub_document_count = len(sub_documents) if sub_documents else 0

        if ctx.evaluation is not None:
            rule_count = len(ctx.evaluation.rules)
        elif ctx.rules_file is not None:
            rule_count = len(ctx.rules_file.flat_rules)
        else:
            rule_count = 0

        await self.storage.save_run_metrics(
            document_id,
            run_id=run_id,
            timing=dict(ctx.timing),
            page_count=page_count,
            sub_document_count=sub_document_count,
            field_count=field_count,
            rule_count=rule_count,
            rescue_rule_count=len(ctx.fallback_tasks),
            artifact_count=self._estimate_artifact_count(ctx),
        )

        # Any rescue task that needs a human forces the "review" status.
        needs_review = any(task.requires_human_review for task in ctx.fallback_tasks)
        if needs_review:
            result_status = "review"
        else:
            result_status = self._resolve_result_status(ctx)

        await self.storage.finalize_run(
            document_id,
            run_id=run_id,
            result_status=result_status,
            rescue_applied=bool(ctx.fallback_tasks),
            phase=ctx.phase,
            finished=True,
        )

    def _estimate_artifact_count(self, ctx: AuditCtx) -> int:
        """Roughly count the platform artifacts this run has produced so far.

        One per non-``None`` stage output (normalized doc, extraction,
        evaluation) plus one per fallback task.
        """
        stage_outputs = (ctx.normalized_doc, ctx.extraction, ctx.evaluation)
        produced = sum(1 for output in stage_outputs if output is not None)
        tasks = ctx.fallback_tasks
        return produced + (len(tasks) if tasks else 0)

    def _resolve_result_status(self, ctx: AuditCtx) -> str:
        """Derive the run status from the native ctx's evaluation.

        Returns:
            "error" when there is no evaluation or it reports errors, "fail"
            when any rule failed, "pass" when nothing failed and nothing was
            skipped, and "partial" otherwise.
        """
        evaluation = ctx.evaluation
        if evaluation is None or evaluation.errors:
            return "error"
        if evaluation.failed_count > 0:
            return "fail"
        if evaluation.failed_count == 0 and evaluation.skipped_count == 0:
            return "pass"
        return "partial"
|
||||
|
||||
|
||||
__all__ = ["NativeRunRequest", "NativeRunResult", "NativeRunner"]
|
||||
@@ -23,7 +23,7 @@ from leaudit.llm.base import BaseLLMClient
|
||||
from leaudit.ocr.base import BaseOCRClient
|
||||
from leaudit.ocr.models import OcrResult
|
||||
|
||||
from leaudit_bridge.storage_adapter import StorageAdapter
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -229,7 +229,7 @@ class LauditPipeline:
|
||||
self, document_id: int, ocr_result: OcrResult,
|
||||
) -> None:
|
||||
"""Extract case number from OCR and write to database."""
|
||||
from leaudit_bridge.case_number_extractor import (
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.case_number_extractor import (
|
||||
extract_case_number_with_llm,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
"""规则 YAML 校验器。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from leaudit.dsl.loader import parse_rules_yaml_text
|
||||
from leaudit.dsl.validator import DSLValidationError, validate as validate_rules
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RuleValidationPayload:
    """Immutable summary of one rule-YAML validation attempt."""

    # True when the YAML parsed and passed DSL validation.
    valid: bool

    # Summary values that ``RuleValidator`` copies from the rules metadata
    # (type_id, name, version) on success.
    ruleType: str | None = None
    ruleName: str | None = None
    versionNo: str | None = None

    # Sizes of the flattened rule and extract lists on success.
    ruleCount: int = 0
    extractCount: int = 0

    # Failure messages; ``RuleValidator`` sets an empty list on success.
    errors: list[str] | None = None
|
||||
|
||||
|
||||
class RuleValidator:
    """Validate rule YAML for syntax, schema and DSL-level semantics."""

    # Check modules that are imported (or reloaded) before DSL validation runs.
    _CHECK_MODULES = (
        "leaudit.engine.checks.required",
        "leaudit.engine.checks.compare",
        "leaudit.engine.checks.format_check",
        "leaudit.engine.checks.text",
        "leaudit.engine.checks.multi_entity",
        "leaudit.engine.checks.visual",
        "leaudit.engine.checks.external",
        "leaudit.engine.checks.assert_check",
        "leaudit.engine.checks.code_check",
        "leaudit.engine.checks.ai_check",
    )

    def ValidateYaml(self, YamlText: str) -> RuleValidationPayload:
        """Validate ``YamlText`` and return a summary payload.

        Never raises for invalid rules: every parse or validation failure is
        reported as ``valid=False`` with a single message in ``errors``.
        """
        RulesFile, Failure = self._ParseChecked(YamlText)
        if Failure is not None:
            return RuleValidationPayload(valid=False, errors=[Failure])

        return RuleValidationPayload(
            valid=True,
            ruleType=RulesFile.metadata.type_id,
            ruleName=RulesFile.metadata.name,
            versionNo=RulesFile.metadata.version,
            ruleCount=len(RulesFile.flat_rules),
            extractCount=len(RulesFile.flat_extract),
            errors=[],
        )

    def ParseValidated(self, YamlText: str):
        """Parse ``YamlText`` and return the rules object that passed validation.

        The object returned is the very instance that was validated; the YAML
        is parsed only once.

        Raises:
            ValueError: With the validation message when the YAML is invalid.
        """
        RulesFile, Failure = self._ParseChecked(YamlText)
        if Failure is not None:
            raise ValueError(Failure)
        return RulesFile

    def _ParseChecked(self, YamlText: str):
        """Parse and validate once; return ``(rules_file, None)`` or ``(None, message)``."""
        try:
            RulesFile = parse_rules_yaml_text(YamlText)
            self._EnsureChecksImported()
            validate_rules(RulesFile, registered_primitives=None)
        except yaml.YAMLError as Error:
            return None, f"YAML 语法错误: {Error}"
        except ValidationError as Error:
            return None, f"Schema 校验失败: {Error}"
        except DSLValidationError as Error:
            return None, f"DSL 校验失败: {Error}"
        except Exception as Error:
            # Boundary catch-all: any other failure is reported, not raised.
            return None, f"规则校验失败: {Error}"
        return RulesFile, None

    def _EnsureChecksImported(self) -> None:
        """Import every check module, reloading those that are already loaded.

        Failures stay best-effort (validation continues), but each one is
        logged so a missing or broken check module can be diagnosed.
        """
        import logging

        Log = logging.getLogger(__name__)
        for ModuleName in self._CHECK_MODULES:
            try:
                # NOTE(review): already-imported modules are reloaded on every
                # call — presumably to re-run their registration side effects;
                # confirm this is intended.
                if ModuleName in sys.modules:
                    importlib.reload(sys.modules[ModuleName])
                else:
                    importlib.import_module(ModuleName)
            except Exception as Error:
                Log.warning("check module %s could not be loaded: %s", ModuleName, Error)
|
||||
@@ -0,0 +1,125 @@
|
||||
"""规则版本来源解析器。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi_common.fastapi_common_logger import logger
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from fastapi_common.fastapi_common_storage.oss_client import OssClient
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RuleVersionPayload:
    """Immutable result of resolving a rule YAML file to a local path."""

    # Local filesystem path of the rule YAML.
    localPath: str
    # Origin of the file; ``RuleVersionResolver`` sets "local_cache" or "oss".
    sourceType: str
    # Cache path or URL the file was resolved from.
    sourcePath: str

    # Values copied from the run record, when present.
    ruleVersionId: int | None = None
    ruleTypeId: str | None = None
    # SHA-256 of the rule file: the recorded value for cache hits, the
    # computed digest for downloads.
    fileSha256: str | None = None
|
||||
|
||||
|
||||
class RuleVersionResolver:
    """Locate the rule YAML file that belongs to an audit run."""

    def __init__(self, Oss: OssClient | None = None) -> None:
        """Keep the given OSS client, or create a default ``OssClient`` if it is falsy."""
        self.Oss = Oss or OssClient()

    async def ResolveForRun(self, RunId: int) -> RuleVersionPayload | None:
        """Resolve the rule file for ``RunId``.

        An existing local cache file is preferred; otherwise the file is
        downloaded from the recorded OSS URL.

        Returns:
            The resolved payload, or ``None`` when the run is unknown or has
            neither a usable cache file nor a source URL.
        """
        RunInfo = await self._LoadRunInfo(RunId)
        if not RunInfo:
            return None

        CacheEntry = RunInfo["rule_local_cache_path"]
        CachedFile = Path(CacheEntry) if CacheEntry else None
        if CachedFile is not None and CachedFile.is_file():
            # NOTE(review): the cached file is returned without recomputing its
            # SHA-256; ``fileSha256`` is the recorded value. Confirm intended.
            Location = str(CachedFile)
            return RuleVersionPayload(
                localPath=Location,
                sourceType="local_cache",
                sourcePath=Location,
                ruleVersionId=RunInfo["rule_version_id"],
                ruleTypeId=RunInfo["rule_type_id"],
                fileSha256=RunInfo["rule_source_sha256"],
            )

        RemoteUrl = RunInfo["rule_source_oss_url"]
        if RemoteUrl:
            return await self._DownloadFromUrl(
                Url=RemoteUrl,
                RuleVersionId=RunInfo["rule_version_id"],
                RuleTypeId=RunInfo["rule_type_id"],
                ExpectedSha256=RunInfo["rule_source_sha256"],
            )
        return None

    async def _LoadRunInfo(self, RunId: int) -> dict[str, object] | None:
        """Read the rule-source columns of one ``leaudit_audit_runs`` row.

        Returns the row as a plain dict, or ``None`` when no row matches.
        """
        Query = text(
            """
            SELECT
                rule_version_id,
                rule_type_id,
                rule_source_oss_url,
                rule_source_sha256,
                rule_local_cache_path
            FROM leaudit_audit_runs
            WHERE id = :run_id
            LIMIT 1
            """
        )
        async with GetAsyncSession() as Session:
            Result = await Session.execute(Query, {"run_id": RunId})
            Row = Result.mappings().first()
            if not Row:
                return None
            return dict(Row)

    async def _DownloadFromUrl(
        self,
        *,
        Url: str,
        RuleVersionId: int | None,
        RuleTypeId: str | None,
        ExpectedSha256: str | None,
    ) -> RuleVersionPayload:
        """Download the rule YAML from OSS into a local temporary file.

        Raises:
            ValueError: If ``ExpectedSha256`` is set and differs
                (case-insensitively) from the digest of the downloaded bytes.
            Exception: Any download error is logged and re-raised unchanged.
        """
        # NOTE(review): ``DownloadBytes`` is called without ``await`` — presumably
        # a synchronous (blocking) client call; confirm against ``OssClient``.
        try:
            Downloaded = self.Oss.DownloadBytes(Url)
        except Exception as Error:
            logger.error(f"下载规则 YAML 失败: url={Url}, error={Error}")
            raise

        ActualSha256 = hashlib.sha256(Downloaded).hexdigest()
        if ExpectedSha256 and ActualSha256.lower() != ExpectedSha256.lower():
            raise ValueError(
                "规则 YAML SHA256 校验失败: "
                f"expected={ExpectedSha256}, actual={ActualSha256}"
            )

        # Temp-file prefix; "/" and "." in the type id are replaced with "_".
        FilePrefix = "leaudit-rule-"
        if RuleTypeId:
            FilePrefix += RuleTypeId.replace("/", "_").replace(".", "_") + "-"

        TempPath = self.Oss.WriteTempBytes(
            Content=Downloaded,
            Suffix=".yaml",
            Prefix=FilePrefix,
        )

        return RuleVersionPayload(
            localPath=TempPath,
            sourceType="oss",
            sourcePath=Url,
            ruleVersionId=RuleVersionId,
            ruleTypeId=RuleTypeId,
            fileSha256=ActualSha256,
        )
|
||||
@@ -1,28 +1,38 @@
|
||||
"""Celery task for leaudit pipeline processing.
|
||||
|
||||
Activated when PIPELINE_MODE=leaudit in env.{port} config.
|
||||
Replaces the legacy OCR → extraction → evaluation pipeline with
|
||||
leaudit's YAML-rules-driven approach.
|
||||
"""
|
||||
"""LeAudit 任务入口。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from core.celery_app_limited import celery_app
|
||||
from core.postgrest.client import get_postgrest_client
|
||||
from core.logger import log
|
||||
from fastapi_common.fastapi_common_logger import logger
|
||||
|
||||
from leaudit_bridge import create_pipeline, RulesLoader
|
||||
from fastapi_admin.config import LEAUDIT_RULES_DIR
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import (
|
||||
NativeRunRequest,
|
||||
NativeRunner,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import (
|
||||
NativeAuditMetadata,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ruleVersionResolver import (
|
||||
RuleVersionResolver,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
|
||||
|
||||
# Celery 集成待 P2 阶段实现,当前使用同步占位
|
||||
# from core.celery_app_limited import celery_app
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="leaudit.process_document")
|
||||
# P2: Celery 集成后启用 @celery_app.task 装饰器
|
||||
def leaudit_process_document(
|
||||
self,
|
||||
document_id: int,
|
||||
file_content: bytes,
|
||||
filename: str,
|
||||
@@ -30,43 +40,41 @@ def leaudit_process_document(
|
||||
source_port: Optional[int] = None,
|
||||
rules_path: Optional[str] = None,
|
||||
):
|
||||
"""Process a document using leaudit's full pipeline.
|
||||
|
||||
Steps: OCR → Extraction → Evaluation → Store in docauditai DB.
|
||||
"""
|
||||
task_id = self.request.id
|
||||
log.task.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")
|
||||
"""处理单个文档的 LeAudit 任务。"""
|
||||
task_id = os.urandom(8).hex()
|
||||
log.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")
|
||||
|
||||
# 新平台:region 通过参数传递,不再依赖 os.environ 切换
|
||||
if source_port:
|
||||
from core.utils.instance_context import set_instance_environment
|
||||
instance_name = set_instance_environment(source_port)
|
||||
log.task.info(
|
||||
f"[任务ID: {task_id}] 实例环境: {instance_name} (端口: {source_port})"
|
||||
)
|
||||
log.info(f"[任务ID: {task_id}] 来源端口: {source_port}")
|
||||
|
||||
if upload_info is None:
|
||||
upload_info = {}
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
temp_paths: list[str] = []
|
||||
storage = StorageAdapter()
|
||||
|
||||
try:
|
||||
rules_path_resolved = rules_path or _resolve_rules_path(document_id, loop)
|
||||
run_id = _resolve_run_id(document_id, upload_info, loop)
|
||||
rules_resolution = _resolve_rules_runtime(document_id, run_id, rules_path, loop)
|
||||
loop.run_until_complete(_update_run_status_safe(run_id, "running"))
|
||||
rules_path_resolved = rules_resolution["rules_path"]
|
||||
|
||||
# For types with a known mapping (e.g. 行政处罚), pre-load rules_file.
|
||||
# For types that need content classification (e.g. 行政许可 sub-types),
|
||||
# rules_path will be None → adapter classifies after OCR → pipeline
|
||||
# loads rules from ocr_result.rules_file_path.
|
||||
rules_file = None
|
||||
if rules_path_resolved:
|
||||
loader = RulesLoader()
|
||||
rules_file = loader.load(rules_path_resolved)
|
||||
log.task.info(
|
||||
temp_rule_path = rules_resolution.get("temp_rule_path")
|
||||
if isinstance(temp_rule_path, str):
|
||||
temp_paths.append(temp_rule_path)
|
||||
log.info(
|
||||
f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} "
|
||||
f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)"
|
||||
)
|
||||
else:
|
||||
log.task.info(
|
||||
log.info(
|
||||
f"[任务ID: {task_id}] No fixed rules_path — "
|
||||
"will classify from document content after OCR"
|
||||
)
|
||||
@@ -75,47 +83,79 @@ def leaudit_process_document(
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp:
|
||||
temp.write(file_content)
|
||||
temp_path = temp.name
|
||||
temp_paths.append(temp_path)
|
||||
|
||||
pipeline = create_pipeline(rules_path=rules_path_resolved)
|
||||
runner = NativeRunner()
|
||||
|
||||
t0 = time.time()
|
||||
result = loop.run_until_complete(
|
||||
pipeline.run(
|
||||
document_id=document_id,
|
||||
file_path=temp_path,
|
||||
rules_file=rules_file,
|
||||
source_port=source_port or int(os.getenv("APP_PORT", "8000")),
|
||||
native_result = loop.run_until_complete(
|
||||
runner.run(
|
||||
NativeRunRequest(
|
||||
metadata=NativeAuditMetadata(
|
||||
run_id=run_id,
|
||||
document_id=document_id,
|
||||
rule_version_id=_optional_int(upload_info, "rule_version_id", "ruleVersionId"),
|
||||
extras={"taskId": task_id},
|
||||
),
|
||||
local_file_path=temp_path,
|
||||
rules_file=rules_file,
|
||||
rule_source_path=rules_path_resolved,
|
||||
)
|
||||
)
|
||||
)
|
||||
loop.run_until_complete(runner.persist_result(native_result))
|
||||
elapsed = round(time.time() - t0, 2)
|
||||
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
log.task.info(
|
||||
f"[任务ID: {task_id}] leaudit管线完成: phase={result.detected_phase}, "
|
||||
f"timing={result.timing}, 总耗时={elapsed:.1f}s"
|
||||
ctx = native_result.ctx
|
||||
loop.run_until_complete(_update_run_phase_safe(run_id, ctx.phase))
|
||||
loop.run_until_complete(_update_run_status_safe(run_id, "completed"))
|
||||
loop.run_until_complete(_update_status_safe(document_id, "completed"))
|
||||
log.info(
|
||||
f"[任务ID: {task_id}] leaudit管线完成: phase={ctx.phase}, "
|
||||
f"timing={dict(ctx.timing)}, 总耗时={elapsed:.1f}s"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"document_id": document_id,
|
||||
"phase": result.detected_phase,
|
||||
"timing": result.timing,
|
||||
"errors": result.errors,
|
||||
"run_id": run_id,
|
||||
"phase": ctx.phase,
|
||||
"timing": dict(ctx.timing),
|
||||
"errors": list(ctx.extraction.all_errors) if ctx.extraction is not None else list(ctx.extraction_errors),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
log.task.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
|
||||
log.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
|
||||
try:
|
||||
loop.run_until_complete(_update_status_safe(document_id, "Failed"))
|
||||
loop.run_until_complete(_update_status_safe(document_id, "failed"))
|
||||
if 'run_id' in locals():
|
||||
failed_phase = "persist"
|
||||
if "native_result" in locals():
|
||||
failed_phase = native_result.ctx.phase or failed_phase
|
||||
loop.run_until_complete(
|
||||
storage.fail_run(
|
||||
document_id,
|
||||
run_id=run_id,
|
||||
phase=failed_phase,
|
||||
message=str(e),
|
||||
detail_json={
|
||||
"taskId": task_id,
|
||||
"filename": filename,
|
||||
"errorType": type(e).__name__,
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
finally:
|
||||
for temp_path in temp_paths:
|
||||
try:
|
||||
if Path(temp_path).exists():
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
loop.close()
|
||||
|
||||
|
||||
@@ -129,49 +169,158 @@ _TYPE_ID_RULES_MAP: dict[int, str] = {
|
||||
|
||||
def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None:
|
||||
"""Resolve rules_path: config override → document metadata → type_id mapping."""
|
||||
from core.config import LEAUDIT_CONFIG
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from sqlalchemy import text as sa_text
|
||||
|
||||
# 1. Config override (when explicitly set)
|
||||
config_path = LEAUDIT_CONFIG.get("RULES_PATH", "")
|
||||
if config_path:
|
||||
return config_path
|
||||
# 1. Config override (when explicitly set in app.toml)
|
||||
if LEAUDIT_RULES_DIR:
|
||||
return LEAUDIT_RULES_DIR
|
||||
|
||||
try:
|
||||
client = get_postgrest_client()
|
||||
doc = loop.run_until_complete(
|
||||
client.select(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
single=True,
|
||||
)
|
||||
)
|
||||
if not doc:
|
||||
return None
|
||||
async def _fetch():
|
||||
async with GetAsyncSession() as session:
|
||||
result = await session.execute(
|
||||
sa_text("SELECT type_id FROM leaudit_documents WHERE id = :did"),
|
||||
{"did": document_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if row and row[0] and row[0] in _TYPE_ID_RULES_MAP:
|
||||
return f"{_TYPE_ID_RULES_MAP[row[0]]}/rules.yaml"
|
||||
return None
|
||||
|
||||
# 2. Document-level override
|
||||
rfp = doc.get("rules_file_path")
|
||||
if rfp:
|
||||
return rfp
|
||||
|
||||
# 3. type_id mapping
|
||||
type_id = doc.get("type_id")
|
||||
if type_id and type_id in _TYPE_ID_RULES_MAP:
|
||||
return f"{_TYPE_ID_RULES_MAP[type_id]}/rules.yaml"
|
||||
return loop.run_until_complete(_fetch())
|
||||
except Exception as e:
|
||||
log.task.warning(f"Failed to resolve rules_path from document: {e}")
|
||||
log.warning(f"Failed to resolve rules_path from document: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _update_status_safe(document_id: int, status: str) -> None:
|
||||
"""Safely update document status, ignoring errors."""
|
||||
def _resolve_rules_runtime(
|
||||
document_id: int,
|
||||
run_id: int,
|
||||
explicit_rules_path: str | None,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> dict[str, str | None]:
|
||||
"""解析本次执行使用的规则来源。"""
|
||||
if explicit_rules_path:
|
||||
return {
|
||||
"rules_path": explicit_rules_path,
|
||||
"temp_rule_path": None,
|
||||
"source_type": "explicit",
|
||||
"source_path": explicit_rules_path,
|
||||
}
|
||||
|
||||
resolver = RuleVersionResolver()
|
||||
try:
|
||||
client = get_postgrest_client()
|
||||
await client.update(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
data={"status": status},
|
||||
)
|
||||
payload = loop.run_until_complete(resolver.ResolveForRun(run_id))
|
||||
if payload:
|
||||
log.info(
|
||||
f"run_id={run_id} 规则来源已解析: sourceType={payload.sourceType}, "
|
||||
f"sourcePath={payload.sourcePath}, localPath={payload.localPath}"
|
||||
)
|
||||
return {
|
||||
"rules_path": payload.localPath,
|
||||
"temp_rule_path": payload.localPath if payload.sourceType == "oss" else None,
|
||||
"source_type": payload.sourceType,
|
||||
"source_path": payload.sourcePath,
|
||||
}
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to resolve rule version from run: run_id={run_id}, error={e}")
|
||||
|
||||
fallback_rules_path = _resolve_rules_path(document_id, loop)
|
||||
return {
|
||||
"rules_path": fallback_rules_path,
|
||||
"temp_rule_path": None,
|
||||
"source_type": "legacy_fallback",
|
||||
"source_path": fallback_rules_path,
|
||||
}
|
||||
|
||||
|
||||
def _resolve_run_id(
|
||||
document_id: int,
|
||||
upload_info: Dict[str, Any] | None,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> int:
|
||||
"""解析本次任务对应的运行 ID。"""
|
||||
if upload_info:
|
||||
for key in ("run_id", "runId"):
|
||||
value = upload_info.get(key)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from sqlalchemy import text as sa_text
|
||||
|
||||
async def _fetch() -> int:
|
||||
async with GetAsyncSession() as session:
|
||||
result = await session.execute(
|
||||
sa_text("SELECT id FROM leaudit_audit_runs WHERE document_id = :did ORDER BY id DESC LIMIT 1"),
|
||||
{"did": document_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if not row:
|
||||
raise ValueError(f"未找到 document_id={document_id} 对应的 run 记录")
|
||||
return int(row[0])
|
||||
|
||||
return loop.run_until_complete(_fetch())
|
||||
|
||||
|
||||
def _optional_int(payload: Dict[str, Any] | None, *keys: str) -> int | None:
|
||||
"""从字典中按顺序取可用整数。"""
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
for key in keys:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
async def _update_status_safe(document_id: int, status: str) -> None:
    """Best-effort update of a document's processing status.

    Sets ``processing_status`` (and refreshes ``update_time``) on the
    ``leaudit_documents`` row with the given ID. Any failure is logged as a
    warning and suppressed, so the call never raises.

    Args:
        document_id: Primary key of the ``leaudit_documents`` row.
        status: New value for ``processing_status``.
    """
    try:
        from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
        from sqlalchemy import text as sa_text

        async with GetAsyncSession() as session:
            await session.execute(
                sa_text("UPDATE leaudit_documents SET processing_status = :s, update_time = now() WHERE id = :did"),
                {"s": status, "did": document_id},
            )
            await session.commit()
    except Exception as exc:
        # Still non-fatal, but leave a trace: a silently dropped status write
        # makes a stuck document impossible to diagnose.
        log.warning(
            f"Failed to update document status: document_id={document_id}, "
            f"status={status}, error={exc}"
        )
|
||||
|
||||
|
||||
async def _update_run_status_safe(run_id: int, status: str) -> None:
    """Best-effort update of an audit run's status.

    Sets ``status`` (and refreshes ``update_time``) on the
    ``leaudit_audit_runs`` row with the given ID. Any failure is logged as a
    warning and suppressed, so the call never raises.

    Args:
        run_id: Primary key of the ``leaudit_audit_runs`` row.
        status: New value for the ``status`` column.
    """
    try:
        from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
        from sqlalchemy import text as sa_text

        async with GetAsyncSession() as session:
            await session.execute(
                sa_text("UPDATE leaudit_audit_runs SET status = :s, update_time = now() WHERE id = :rid"),
                {"s": status, "rid": run_id},
            )
            await session.commit()
    except Exception as exc:
        # Still non-fatal, but leave a trace: a silently dropped status write
        # leaves the run looking stuck with no hint as to why.
        log.warning(
            f"Failed to update run status: run_id={run_id}, "
            f"status={status}, error={exc}"
        )
|
||||
|
||||
|
||||
async def _update_run_phase_safe(run_id: int, phase: str | None) -> None:
    """Best-effort update of an audit run's phase.

    Sets ``phase`` (and refreshes ``update_time``) on the
    ``leaudit_audit_runs`` row with the given ID. Any failure is logged as a
    warning and suppressed, so the call never raises.

    Args:
        run_id: Primary key of the ``leaudit_audit_runs`` row.
        phase: New value for the ``phase`` column; ``None`` is passed through
            to the query as-is.
    """
    try:
        from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
        from sqlalchemy import text as sa_text

        async with GetAsyncSession() as session:
            await session.execute(
                sa_text("UPDATE leaudit_audit_runs SET phase = :p, update_time = now() WHERE id = :rid"),
                {"p": phase, "rid": run_id},
            )
            await session.commit()
    except Exception as exc:
        # Still non-fatal, but leave a trace so a missing phase update can be
        # explained from the logs.
        log.warning(
            f"Failed to update run phase: run_id={run_id}, "
            f"phase={phase}, error={exc}"
        )
|
||||
|
||||
@@ -190,12 +339,16 @@ def dispatch_leaudit_task(
|
||||
source_port: Optional[int] = None,
|
||||
rules_path: Optional[str] = None,
|
||||
):
|
||||
"""Dispatch a leaudit processing task."""
|
||||
return leaudit_process_document.apply_async(
|
||||
args=[document_id, file_content, filename],
|
||||
kwargs={
|
||||
"upload_info": upload_info,
|
||||
"source_port": source_port or int(os.getenv("APP_PORT", "8000")),
|
||||
"rules_path": rules_path,
|
||||
},
|
||||
"""Dispatch a leaudit processing task.
|
||||
|
||||
P2: Celery 集成后改用 leaudit_process_document.apply_async(...)
|
||||
当前阶段直接同步调用。
|
||||
"""
|
||||
return leaudit_process_document(
|
||||
document_id=document_id,
|
||||
file_content=file_content,
|
||||
filename=filename,
|
||||
upload_info=upload_info,
|
||||
source_port=source_port or int(os.getenv("APP_PORT", "8000")),
|
||||
rules_path=rules_path,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user