feat: complete M1-M3 infrastructure — OSS client, native execution chain, rule lifecycle API, system docs
- M1: unified OSS client (upload/download/presign) + path utils + config - M2: rule service with validate/create/publish/rollback + binding CRUD endpoints - M3: native AuditCtx runner, file/rule resolvers, storage adapter with full persistence - docs: SYSTEM_OVERVIEW.md as comprehensive architecture reference - fix: double finalize — terminal state now written once by finalize_run
This commit is contained in:
@@ -0,0 +1,10 @@
|
||||
"""规则发布 DTO。"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RulePublishDTO(BaseModel):
    """Request body for publishing or rolling back a rule version."""

    # ID of the rule version the request refers to (required).
    versionId: int = Field(..., description="规则版本ID")
    # ID of the operating user; optional, defaults to None.
    operatorUserId: int | None = Field(None, description="操作用户ID")
|
||||
@@ -0,0 +1,9 @@
|
||||
"""规则校验 DTO。"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RuleValidateDTO(BaseModel):
    """Request body for validating a rule YAML document."""

    # Full text of the rule YAML to validate (required).
    yamlText: str = Field(..., description="规则 YAML 正文")
|
||||
@@ -0,0 +1,11 @@
|
||||
"""规则版本创建 DTO。"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RuleVersionCreateDTO(BaseModel):
    """Request body for creating a rule version."""

    # Full text of the rule YAML (required).
    yamlText: str = Field(..., description="规则 YAML 正文")
    # Optional note describing the change made in this version.
    changeNote: str | None = Field(None, description="版本变更说明")
    # Optional ID of the editing user.
    editorUserId: int | None = Field(None, description="编辑者用户ID")
|
||||
@@ -1,44 +1,30 @@
|
||||
"""leaudit bridge — use leaudit's full pipeline with docauditai's database storage.
|
||||
"""LeAudit Bridge 模块。
|
||||
|
||||
Directly calls leaudit's OCR → extraction → evaluation pipeline
|
||||
and persists results into docauditai's PostgreSQL via PostgREST.
|
||||
|
||||
Configuration switch (in env.{port}):
|
||||
PIPELINE_MODE=leaudit → use leaudit pipeline
|
||||
对平台暴露统一桥接入口,内部逐步从旧的手写 pipeline
|
||||
迁移到原生 ``AuditCtx`` + ``AuditService`` 路线。
|
||||
"""
|
||||
|
||||
from leaudit_bridge.client_factory import (
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.client_factory import (
|
||||
create_ocr_client,
|
||||
create_llm_client,
|
||||
create_vlm_client,
|
||||
)
|
||||
from leaudit_bridge.ocr_bridge import BridgeOCRClient
|
||||
from leaudit_bridge.pipeline import LauditPipeline, PipelineResult
|
||||
from leaudit_bridge.rules_loader import RulesLoader
|
||||
from leaudit_bridge.storage_adapter import StorageAdapter
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ocr_bridge import BridgeOCRClient
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline import LauditPipeline, PipelineResult
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
|
||||
|
||||
|
||||
def is_leaudit_mode() -> bool:
|
||||
"""Check if the system is configured to use the leaudit pipeline."""
|
||||
from core.config import PIPELINE_MODE
|
||||
return PIPELINE_MODE == "leaudit"
|
||||
"""新平台始终使用 leaudit pipeline。"""
|
||||
return True
|
||||
|
||||
|
||||
def create_pipeline(rules_path: str | None = None) -> LauditPipeline:
|
||||
"""Create a fully configured LauditPipeline from current config.
|
||||
"""创建旧版兼容 LauditPipeline。
|
||||
|
||||
Wraps the raw OCR client with DocNormalizationAdapter so that a single
|
||||
``.ocr()`` call produces a fully enriched OcrResult with:
|
||||
- Document classification (type_id + rules_file_path)
|
||||
- Dossier segmentation (sub-document page mapping)
|
||||
- Seal/signature enrichment (text, seal_id, party_id)
|
||||
- Normalized markdown (seal blocks + page separators)
|
||||
|
||||
Args:
|
||||
rules_path: If provided, forces the adapter to use this rules file
|
||||
for classification and segmentation. When None, the adapter
|
||||
uses the RulesFileRegistry to classify from document content,
|
||||
enabling auto-detection of sub-types (e.g. 行政许可 variants).
|
||||
当前仍保留该入口兼容旧调用方,后续正式执行链应逐步切到
|
||||
``NativeRunner``。
|
||||
"""
|
||||
from pathlib import Path
|
||||
from leaudit.doc_normalization.adapter import DocNormalizationAdapter
|
||||
@@ -51,7 +37,7 @@ def create_pipeline(rules_path: str | None = None) -> LauditPipeline:
|
||||
# Build registry from rules/ directory for content-based classification
|
||||
registry = None
|
||||
if rules_path is None:
|
||||
rules_dir = Path(__file__).resolve().parents[1] / "rules"
|
||||
rules_dir = Path(__file__).resolve().parents[3] / "rules"
|
||||
if rules_dir.is_dir():
|
||||
registry = RulesFileRegistry.from_directory(rules_dir)
|
||||
|
||||
|
||||
@@ -0,0 +1,75 @@
|
||||
"""Build native leaudit ``AuditCtx`` instances from platform-side inputs."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from leaudit.config.audit_config import AuditConfig
|
||||
from leaudit.services.audit_ctx import AuditCtx
|
||||
from leaudit.services.audit_services import AuditServices
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NativeAuditMetadata:
    """Immutable platform-side identifiers that travel alongside a native ``AuditCtx``.

    These values are kept outside the native ``AuditCtx`` model; only
    ``run_id`` and ``document_id`` are mandatory.
    """

    # Required identifiers of the run and the audited document.
    run_id: int
    document_id: int

    # Optional platform references; all default to None.
    document_file_id: int | None = None
    rule_set_id: int | None = None
    rule_version_id: int | None = None
    trigger_user_id: int | None = None

    # Free-form extra data; each instance gets its own empty dict by default.
    extras: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NativeAuditBuildInput:
    """Immutable bundle of everything the bridge has on hand before building an ``AuditCtx``."""

    # Required inputs.
    metadata: NativeAuditMetadata
    file_path: str
    services: AuditServices

    # Optional inputs forwarded to ``AuditCtx`` by ``AuditCtxBuilder.build``.
    rules_file: Any | None = None
    page_range: tuple[int, ...] | None = None

    # Rules-path hints: ``AuditCtxBuilder.build`` prefers ``force_rules_path``
    # and falls back to ``rule_source_path`` when it is falsy.
    rule_source_path: str | None = None
    force_rules_path: str | None = None

    # Keyword overrides for the native config; a fresh dict per instance.
    config_overrides: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class AuditCtxBuilder:
    """Convert platform-side run inputs into leaudit's native ``AuditCtx``."""

    def build(self, payload: NativeAuditBuildInput) -> AuditCtx:
        """Construct the ``AuditCtx`` for one run.

        The forced rules path handed to the config is
        ``payload.force_rules_path``, or ``payload.rule_source_path`` when the
        former is falsy. The document id is passed to ``AuditCtx`` as a string.
        """
        forced_path = payload.force_rules_path or payload.rule_source_path
        audit_config = self.build_config(
            force_rules_path=forced_path,
            overrides=payload.config_overrides,
        )
        return AuditCtx(
            document_id=str(payload.metadata.document_id),
            rules_file=payload.rules_file,
            services=payload.services,
            file_path=payload.file_path,
            page_range=payload.page_range,
            config=audit_config,
        )

    def build_config(
        self,
        *,
        force_rules_path: str | None = None,
        overrides: dict[str, Any] | None = None,
    ) -> AuditConfig:
        """Create an ``AuditConfig`` from keyword overrides.

        ``overrides`` is copied, never mutated. A truthy ``force_rules_path``
        is added only when the overrides do not already carry a
        ``force_rules_path`` key, so an explicit override wins.
        """
        settings = dict(overrides) if overrides else {}
        if force_rules_path:
            # setdefault leaves an explicitly supplied override untouched.
            settings.setdefault("force_rules_path", force_rules_path)
        return AuditConfig(**settings)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AuditCtxBuilder",
|
||||
"NativeAuditBuildInput",
|
||||
"NativeAuditMetadata",
|
||||
]
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Factory helpers for leaudit's native service-layer orchestration.
|
||||
|
||||
This module is the bridge-side assembly point for native leaudit services:
|
||||
|
||||
- ``AuditServices``
|
||||
- ``DocNormalizationService``
|
||||
- ``ExtractionService``
|
||||
- ``EvaluationService``
|
||||
- ``RescueService``
|
||||
- ``AuditService``
|
||||
|
||||
The platform should not construct these objects in controllers/services
|
||||
directly. Keep all leaudit-native wiring inside ``leaudit_bridge/``.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from leaudit.services.audit_service import AuditService
|
||||
from leaudit.services.audit_services import AuditServices
|
||||
from leaudit.services.doc_normalization_service import DocNormalizationService
|
||||
from leaudit.services.evaluation_service import EvaluationService
|
||||
from leaudit.services.extraction_service import ExtractionService
|
||||
from leaudit.services.rescue_service import RescueService
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.client_factory import (
|
||||
create_llm_client,
|
||||
create_ocr_client,
|
||||
create_vlm_client,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ocr_bridge import BridgeOCRClient
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NativeServiceBundle:
    """Immutable holder for one fully assembled set of native leaudit services."""

    # Top-level orchestrator and the shared services container it was built with.
    audit_service: AuditService
    audit_services: AuditServices

    # Individual stage services, exposed for callers that need direct access.
    normalization_service: DocNormalizationService
    extraction_service: ExtractionService
    evaluation_service: EvaluationService

    # Rescue service; the annotation allows None.
    rescue_service: RescueService | None
|
||||
|
||||
|
||||
class AuditServiceFactory:
    """Assemble the native leaudit services needed for one platform-side run."""

    def create_bundle(self, rules_path: str | None = None) -> NativeServiceBundle:
        """Wire up every native service and return them as one bundle.

        ``rules_path`` is only forwarded to the normalization setup, where it
        forces the adapter's classification path when the caller wants a
        fixed rules file.
        """
        normalization, base_services = self._create_normalization_services(
            rules_path=rules_path
        )
        llm = base_services.llm_client
        vlm = base_services.vlm_client

        # The stage services are created without a session (session=None).
        extraction = ExtractionService(session=None, llm_client=llm)
        evaluation = EvaluationService(session=None)
        rescue = RescueService(
            session=None,
            llm_client=llm,
            vlm_client=vlm,
            extraction_service=extraction,
        )

        # A second AuditServices container that also carries the extraction
        # and evaluation services, which did not exist when the first was built.
        full_services = AuditServices(
            llm_client=llm,
            vlm_client=vlm,
            ocr_client=base_services.ocr_client,
            normalization=normalization,
            extraction=extraction,
            evaluation=evaluation,
        )
        audit_service = AuditService(
            document_service=None,
            normalization_service=normalization,
            extraction_service=extraction,
            evaluation_service=evaluation,
            rescue_service=rescue,
            services=full_services,
        )

        return NativeServiceBundle(
            audit_service=audit_service,
            # Read back from the service so the bundle exposes what it actually holds.
            audit_services=audit_service.services,
            normalization_service=normalization,
            extraction_service=extraction,
            evaluation_service=evaluation,
            rescue_service=rescue,
        )

    def _create_normalization_services(
        self, rules_path: str | None = None
    ) -> tuple[DocNormalizationService, AuditServices]:
        """Build the normalization service together with the shared low-level clients.

        When ``rules_path`` is None, a ``RulesFileRegistry`` is built from the
        ``rules`` directory three levels above this file, if that directory
        exists. Otherwise no registry is used and ``rules_path`` is passed to
        the adapter as ``force_rules_path``.
        """
        from leaudit.doc_normalization.adapter import DocNormalizationAdapter
        from leaudit.doc_normalization.doc_classifier import RulesFileRegistry

        ocr_raw = create_ocr_client()
        llm = create_llm_client()
        vlm = create_vlm_client()

        rules_registry = None
        if rules_path is None:
            candidate_dir = Path(__file__).resolve().parents[3] / "rules"
            if candidate_dir.is_dir():
                rules_registry = RulesFileRegistry.from_directory(candidate_dir)

        normalizing_adapter = DocNormalizationAdapter(
            ocr_client=ocr_raw,
            registry=rules_registry,
            llm_client=llm,
            vlm_client=vlm,
            force_rules_path=rules_path,
        )
        bridged_ocr = BridgeOCRClient(normalizing_adapter, vlm_client=vlm)
        normalization = DocNormalizationService(bridged_ocr)

        # The shared container keeps the raw OCR client, not the bridged one.
        shared = AuditServices(
            llm_client=llm,
            vlm_client=vlm,
            ocr_client=ocr_raw,
            normalization=normalization,
        )
        return normalization, shared
|
||||
|
||||
|
||||
__all__ = ["AuditServiceFactory", "NativeServiceBundle"]
|
||||
@@ -5,13 +5,15 @@ from __future__ import annotations
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.config import (
|
||||
OCR_CONFIG,
|
||||
DEFAULT_BASE_URL,
|
||||
DEFAULT_LLM_MODEL,
|
||||
DEFAULT_API_KEY,
|
||||
DEFAULT_VLM_BASE_URL,
|
||||
DEFAULT_VLM_MODEL,
|
||||
from fastapi_admin.config import (
|
||||
OCR_BASE_URL,
|
||||
OCR_TIMEOUT,
|
||||
LLM_BASE_URL,
|
||||
LLM_MODEL,
|
||||
LLM_API_KEY,
|
||||
VLM_BASE_URL,
|
||||
VLM_MODEL,
|
||||
VLM_API_KEY,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -29,8 +31,8 @@ def create_ocr_client() -> BaseOCRClient:
|
||||
|
||||
base_url = os.getenv("LEAUDIT_OCR_URL", "").rstrip("/")
|
||||
if not base_url:
|
||||
base_url = OCR_CONFIG["API_URL"].rsplit("/api/v1/ocr", 1)[0]
|
||||
timeout = float(OCR_CONFIG["TIMEOUT"])
|
||||
base_url = OCR_BASE_URL.rstrip("/")
|
||||
timeout = float(OCR_TIMEOUT)
|
||||
|
||||
client = ChandraOCRClient(
|
||||
base_url=base_url,
|
||||
@@ -45,9 +47,9 @@ def create_llm_client() -> BaseLLMClient:
|
||||
"""Create a leaudit OpenAICompatibleClient from docauditai's LLM config."""
|
||||
from leaudit.llm.openai_client import OpenAICompatibleClient
|
||||
|
||||
base_url = DEFAULT_BASE_URL
|
||||
model = DEFAULT_LLM_MODEL
|
||||
api_key = DEFAULT_API_KEY or "no-key"
|
||||
base_url = LLM_BASE_URL
|
||||
model = LLM_MODEL
|
||||
api_key = LLM_API_KEY or "no-key"
|
||||
|
||||
client = OpenAICompatibleClient(
|
||||
api_key=api_key,
|
||||
@@ -63,9 +65,9 @@ def create_vlm_client() -> BaseVLMClient | None:
|
||||
"""Create a leaudit QwenVLMClient from docauditai's VLM config."""
|
||||
from leaudit.llm.qwen_vlm_client import QwenVLMClient
|
||||
|
||||
base_url = DEFAULT_VLM_BASE_URL
|
||||
model = DEFAULT_VLM_MODEL
|
||||
api_key = DEFAULT_API_KEY or "no-key"
|
||||
base_url = VLM_BASE_URL
|
||||
model = VLM_MODEL
|
||||
api_key = VLM_API_KEY or LLM_API_KEY or "no-key"
|
||||
|
||||
if not base_url or not model:
|
||||
log.info("leaudit VLM client skipped: no VLM config")
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
"""文档文件来源解析器。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi_common.fastapi_common_logger import logger
|
||||
from fastapi_common.fastapi_common_storage.oss_client import OssClient
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.models.leauditDocumentFile import LeauditDocumentFile
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class FileSourcePayload:
    """Immutable file payload that the execution chain can consume directly."""

    # Name of the file and its raw bytes.
    fileName: str
    fileContent: bytes
    # Origin tag; ``FileSourceResolver`` sets it to "local" or "oss".
    sourceType: str
    # Local path or URL the bytes were read from, when known.
    sourcePath: str | None = None
|
||||
|
||||
|
||||
class FileSourceResolver:
    """Work out where a document file's bytes live and load them."""

    def __init__(self, Oss: OssClient | None = None) -> None:
        """Keep the given OSS client, or create a default ``OssClient``."""
        self.Oss = Oss or OssClient()

    async def ResolvePayload(self, DocumentFile: LeauditDocumentFile) -> FileSourcePayload:
        """Return the document file as a byte payload.

        A ``localPath`` that points to an existing file is read first.
        Otherwise the file is downloaded from ``ossUrl``.

        Raises:
            ValueError: if neither a usable ``localPath`` nor an ``ossUrl`` exists.
        """
        LocalCandidate = Path(DocumentFile.localPath) if DocumentFile.localPath else None
        if LocalCandidate is not None and LocalCandidate.is_file():
            return FileSourcePayload(
                fileName=DocumentFile.fileName,
                fileContent=LocalCandidate.read_bytes(),
                sourceType="local",
                sourcePath=str(LocalCandidate),
            )

        if not DocumentFile.ossUrl:
            raise ValueError("当前文档文件既无可用 localPath,也无可用 ossUrl")

        return await self._DownloadFromUrl(
            FileName=DocumentFile.fileName,
            Url=DocumentFile.ossUrl,
        )

    async def _DownloadFromUrl(self, FileName: str, Url: str) -> FileSourcePayload:
        """Download the bytes behind ``Url`` through the OSS client.

        Any download failure is logged and re-raised unchanged.
        """
        try:
            Data = self.Oss.DownloadBytes(Url)
        except Exception as Failure:
            logger.error(f"下载 OSS 文件失败: url={Url}, error={Failure}")
            raise

        return FileSourcePayload(
            fileName=FileName,
            fileContent=Data,
            sourceType="oss",
            sourcePath=Url,
        )
|
||||
@@ -0,0 +1,185 @@
|
||||
"""Native AuditCtx runner for the platform bridge.
|
||||
|
||||
This module is the target execution path after deprecating the old
|
||||
platform-side hand-written pipeline orchestration.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
from leaudit.services.audit_ctx import AuditCtx
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import (
|
||||
AuditCtxBuilder,
|
||||
NativeAuditBuildInput,
|
||||
NativeAuditMetadata,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditServiceFactory import (
|
||||
AuditServiceFactory,
|
||||
NativeServiceBundle,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NativeRunRequest:
    """Immutable platform-side request describing one native leaudit run."""

    # Required: platform identifiers and the local path of the file to audit.
    metadata: NativeAuditMetadata
    local_file_path: str

    # Optional rules object, forwarded unchanged to the ctx builder.
    rules_file: Any | None = None

    # Rules-path hints: ``NativeRunner.run`` prefers ``rules_path_override``
    # and falls back to ``rule_source_path``.
    rule_source_path: str | None = None
    rules_path_override: str | None = None

    # Optional page selection, forwarded to the ctx builder.
    page_range: tuple[int, ...] | None = None

    # Keyword overrides for the native config; a fresh dict per instance.
    config_overrides: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class NativeRunResult:
    """Immutable outcome of one native ``AuditService.audit(ctx)`` call."""

    # The ctx returned by the audit call.
    ctx: AuditCtx
    # The services that executed the run.
    service_bundle: NativeServiceBundle
    # The platform metadata of the originating request.
    metadata: NativeAuditMetadata
|
||||
|
||||
|
||||
class NativeRunner:
    """Bridge-side runner that hands orchestration over to native leaudit."""

    def __init__(
        self,
        *,
        service_factory: AuditServiceFactory | None = None,
        ctx_builder: AuditCtxBuilder | None = None,
        storage_adapter: StorageAdapter | None = None,
    ) -> None:
        """Store the collaborators, creating a default for any that is not supplied."""
        self.service_factory = service_factory or AuditServiceFactory()
        self.ctx_builder = ctx_builder or AuditCtxBuilder()
        self.storage = storage_adapter or StorageAdapter()

    async def run(self, request: NativeRunRequest) -> NativeRunResult:
        """Run one native audit and return the resulting ctx.

        Nothing is persisted here. The caller decides when to pass the result
        to ``persist_result``.
        """
        effective_rules_path = request.rules_path_override or request.rule_source_path
        bundle = self.service_factory.create_bundle(rules_path=effective_rules_path)

        build_input = NativeAuditBuildInput(
            metadata=request.metadata,
            file_path=request.local_file_path,
            services=bundle.audit_services,
            rules_file=request.rules_file,
            page_range=request.page_range,
            rule_source_path=request.rule_source_path,
            force_rules_path=request.rules_path_override,
            config_overrides=request.config_overrides,
        )
        audited_ctx = await bundle.audit_service.audit(self.ctx_builder.build(build_input))

        return NativeRunResult(
            ctx=audited_ctx,
            service_bundle=bundle,
            metadata=request.metadata,
        )

    async def persist_result(self, result: NativeRunResult) -> None:
        """Write a native run's outputs through the storage adapter.

        Order: OCR result, extraction result, evaluation results, extraction
        warnings, rescue outcomes, run metrics, then the final run status.
        Each optional step is skipped when the ctx carries no data for it.
        """
        document_id = result.metadata.document_id
        run_id = result.metadata.run_id
        ctx = result.ctx

        # Prefer errors recorded on the ctx; otherwise use the extraction's own list.
        warnings = list(ctx.extraction_errors)
        if not warnings and ctx.extraction is not None:
            warnings = list(ctx.extraction.all_errors)

        if ctx.normalized_doc is not None:
            await self.storage.save_ocr_result(
                document_id,
                ctx.normalized_doc,
                run_id=run_id,
            )

        if ctx.extraction is not None:
            await self.storage.save_extraction_result(
                document_id,
                ctx.extraction,
                run_id=run_id,
            )

        # Evaluation rows need the evaluation, the rules file and the extraction.
        evaluation_ready = not (
            ctx.evaluation is None or ctx.rules_file is None or ctx.extraction is None
        )
        if evaluation_ready:
            await self.storage.save_evaluation_results(
                document_id,
                ctx.rules_file,
                ctx.evaluation,
                ctx.extraction,
                run_id=run_id,
                rule_version_id=result.metadata.rule_version_id,
            )

        if warnings:
            await self.storage.save_run_errors(
                document_id,
                run_id=run_id,
                stage=ctx.phase or "extract",
                messages=warnings,
                level="warning",
                error_code="EXTRACTION_WARNING",
            )

        if ctx.fallback_tasks:
            await self.storage.save_rescue_outcomes(
                document_id,
                run_id=run_id,
                tasks=ctx.fallback_tasks,
            )

        await self.storage.save_run_metrics(
            document_id,
            run_id=run_id,
            **self._collect_metrics(ctx),
        )

        # Any fallback task that needs a human overrides the computed status.
        needs_review = any(task.requires_human_review for task in ctx.fallback_tasks)
        result_status = "review" if needs_review else self._resolve_result_status(ctx)
        await self.storage.finalize_run(
            document_id,
            run_id=run_id,
            result_status=result_status,
            rescue_applied=bool(ctx.fallback_tasks),
            phase=ctx.phase,
            finished=True,
        )

    def _collect_metrics(self, ctx: AuditCtx) -> dict[str, Any]:
        """Gather the keyword arguments passed to ``save_run_metrics``."""
        timing = dict(ctx.timing)

        page_count = len(ctx.normalized_doc.pages) if ctx.normalized_doc is not None else None

        # sub_documents may be absent or empty on the extraction; both count as 0.
        if ctx.extraction is not None and getattr(ctx.extraction, "sub_documents", None):
            sub_document_count = len(ctx.extraction.sub_documents)
        else:
            sub_document_count = 0

        field_count = len(ctx.extraction.fields) if ctx.extraction is not None else 0

        # Without an evaluation, fall back to the number of rules in the rules file.
        if ctx.evaluation is not None:
            rule_count = len(ctx.evaluation.rules)
        elif ctx.rules_file is not None:
            rule_count = len(ctx.rules_file.flat_rules)
        else:
            rule_count = 0

        return {
            "timing": timing,
            "page_count": page_count,
            "sub_document_count": sub_document_count,
            "field_count": field_count,
            "rule_count": rule_count,
            "rescue_rule_count": len(ctx.fallback_tasks),
            "artifact_count": self._estimate_artifact_count(ctx),
        }

    def _estimate_artifact_count(self, ctx: AuditCtx) -> int:
        """Roughly count the platform artifacts this run has produced.

        One for each of normalized doc, extraction and evaluation that is
        present, plus one per fallback task.
        """
        stages = (ctx.normalized_doc, ctx.extraction, ctx.evaluation)
        produced = sum(1 for stage in stages if stage is not None)
        if ctx.fallback_tasks:
            produced += len(ctx.fallback_tasks)
        return produced

    def _resolve_result_status(self, ctx: AuditCtx) -> str:
        """Derive the run status from the ctx's evaluation.

        Returns "error" when there is no evaluation or it reports errors,
        "fail" when any rule failed, "pass" when nothing failed and nothing
        was skipped, and "partial" otherwise.
        """
        evaluation = ctx.evaluation
        if evaluation is None or evaluation.errors:
            return "error"

        failed = evaluation.failed_count
        skipped = evaluation.skipped_count
        if failed > 0:
            return "fail"
        if failed == 0 and skipped == 0:
            return "pass"
        return "partial"
|
||||
|
||||
|
||||
__all__ = ["NativeRunRequest", "NativeRunResult", "NativeRunner"]
|
||||
@@ -23,7 +23,7 @@ from leaudit.llm.base import BaseLLMClient
|
||||
from leaudit.ocr.base import BaseOCRClient
|
||||
from leaudit.ocr.models import OcrResult
|
||||
|
||||
from leaudit_bridge.storage_adapter import StorageAdapter
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -229,7 +229,7 @@ class LauditPipeline:
|
||||
self, document_id: int, ocr_result: OcrResult,
|
||||
) -> None:
|
||||
"""Extract case number from OCR and write to database."""
|
||||
from leaudit_bridge.case_number_extractor import (
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.case_number_extractor import (
|
||||
extract_case_number_with_llm,
|
||||
)
|
||||
|
||||
|
||||
@@ -0,0 +1,86 @@
|
||||
"""规则 YAML 校验器。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import sys
|
||||
from dataclasses import dataclass
|
||||
|
||||
import yaml
|
||||
from pydantic import ValidationError
|
||||
|
||||
from leaudit.dsl.loader import parse_rules_yaml_text
|
||||
from leaudit.dsl.validator import DSLValidationError, validate as validate_rules
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RuleValidationPayload:
    """Immutable summary of one rule-YAML validation."""

    # Whether the YAML passed validation.
    valid: bool

    # Metadata copied from the parsed rules file by ``RuleValidator.ValidateYaml``
    # on success; left as None on failure.
    ruleType: str | None = None
    ruleName: str | None = None
    versionNo: str | None = None

    # Number of flattened rules and extract entries; 0 by default.
    ruleCount: int = 0
    extractCount: int = 0

    # Error messages; ``ValidateYaml`` sets an empty list on success.
    errors: list[str] | None = None
|
||||
|
||||
|
||||
class RuleValidator:
    """Validate rule YAML for syntax, schema and DSL semantics."""

    # Check modules that are imported (or reloaded) before DSL validation runs.
    _CHECK_MODULES = (
        "leaudit.engine.checks.required",
        "leaudit.engine.checks.compare",
        "leaudit.engine.checks.format_check",
        "leaudit.engine.checks.text",
        "leaudit.engine.checks.multi_entity",
        "leaudit.engine.checks.visual",
        "leaudit.engine.checks.external",
        "leaudit.engine.checks.assert_check",
        "leaudit.engine.checks.code_check",
        "leaudit.engine.checks.ai_check",
    )

    def ValidateYaml(self, YamlText: str) -> RuleValidationPayload:
        """Validate ``YamlText`` and return a summary payload.

        This method never raises for invalid input. YAML syntax errors, schema
        errors, DSL errors and any other exception each become a payload with
        ``valid=False`` and a single prefixed error message.
        """
        try:
            RulesFile = parse_rules_yaml_text(YamlText)
            self._EnsureChecksImported()
            validate_rules(RulesFile, registered_primitives=None)
        except yaml.YAMLError as Error:
            return RuleValidationPayload(valid=False, errors=[f"YAML 语法错误: {Error}"])
        except ValidationError as Error:
            return RuleValidationPayload(valid=False, errors=[f"Schema 校验失败: {Error}"])
        except DSLValidationError as Error:
            return RuleValidationPayload(valid=False, errors=[f"DSL 校验失败: {Error}"])
        except Exception as Error:
            # Boundary catch-all: report the failure to the caller instead of raising.
            return RuleValidationPayload(valid=False, errors=[f"规则校验失败: {Error}"])

        return RuleValidationPayload(
            valid=True,
            ruleType=RulesFile.metadata.type_id,
            ruleName=RulesFile.metadata.name,
            versionNo=RulesFile.metadata.version,
            ruleCount=len(RulesFile.flat_rules),
            extractCount=len(RulesFile.flat_extract),
            errors=[],
        )

    def ParseValidated(self, YamlText: str):
        """Return the parsed rules file once ``YamlText`` passes full validation.

        Raises:
            ValueError: if validation fails; the message joins all reported
                errors with "; ".
        """
        Validation = self.ValidateYaml(YamlText)
        if not Validation.valid:
            raise ValueError("; ".join(Validation.errors or ["规则校验失败"]))
        return parse_rules_yaml_text(YamlText)

    def _EnsureChecksImported(self) -> None:
        """Import every module in ``_CHECK_MODULES``, reloading those already loaded.

        This is best-effort: a module that fails to import or reload does not
        abort validation. The failure is logged with its traceback, because a
        missing check module would otherwise only show up later as an
        unexplained DSL validation error.
        """
        import logging

        Log = logging.getLogger(__name__)
        for ModuleName in self._CHECK_MODULES:
            try:
                # NOTE(review): already-loaded modules are reloaded on every call,
                # presumably to re-run their registration side effects - confirm
                # this is still required.
                if ModuleName in sys.modules:
                    importlib.reload(sys.modules[ModuleName])
                else:
                    importlib.import_module(ModuleName)
            except Exception:
                Log.warning(
                    "Failed to import rule check module %s; continuing without it",
                    ModuleName,
                    exc_info=True,
                )
|
||||
@@ -0,0 +1,125 @@
|
||||
"""规则版本来源解析器。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
from fastapi_common.fastapi_common_logger import logger
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from fastapi_common.fastapi_common_storage.oss_client import OssClient
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class RuleVersionPayload:
    """Immutable result of resolving a rule file."""

    # Local filesystem path of the rule YAML.
    localPath: str
    # Origin tag; ``RuleVersionResolver`` uses "local_cache" or "oss".
    sourceType: str
    # Cache path or URL the rule file came from.
    sourcePath: str

    # Optional identifiers and checksum taken from the run record or the download.
    ruleVersionId: int | None = None
    ruleTypeId: str | None = None
    fileSha256: str | None = None
|
||||
|
||||
|
||||
class RuleVersionResolver:
    """Locate the rule YAML file that belongs to an audit run."""

    def __init__(self, Oss: OssClient | None = None) -> None:
        """Keep the given OSS client, or create a default ``OssClient``."""
        self.Oss = Oss or OssClient()

    async def ResolveForRun(self, RunId: int) -> RuleVersionPayload | None:
        """Resolve the rule file for ``RunId``.

        An existing local cache file is used first; otherwise the file is
        downloaded from the run's OSS URL. Returns None when the run record
        is missing or has neither a usable cache file nor an OSS URL.
        """
        RunInfo = await self._LoadRunInfo(RunId)
        if not RunInfo:
            return None

        CachedLocation = RunInfo["rule_local_cache_path"]
        if CachedLocation and Path(CachedLocation).is_file():
            CacheFile = str(Path(CachedLocation))
            return RuleVersionPayload(
                localPath=CacheFile,
                sourceType="local_cache",
                sourcePath=CacheFile,
                ruleVersionId=RunInfo["rule_version_id"],
                ruleTypeId=RunInfo["rule_type_id"],
                fileSha256=RunInfo["rule_source_sha256"],
            )

        RemoteUrl = RunInfo["rule_source_oss_url"]
        if RemoteUrl:
            return await self._DownloadFromUrl(
                Url=RemoteUrl,
                RuleVersionId=RunInfo["rule_version_id"],
                RuleTypeId=RunInfo["rule_type_id"],
                ExpectedSha256=RunInfo["rule_source_sha256"],
            )
        return None

    async def _LoadRunInfo(self, RunId: int) -> dict[str, object] | None:
        """Read the rule-source columns of the run record, or None if no row matches."""
        Query = text(
            """
            SELECT
                rule_version_id,
                rule_type_id,
                rule_source_oss_url,
                rule_source_sha256,
                rule_local_cache_path
            FROM leaudit_audit_runs
            WHERE id = :run_id
            LIMIT 1
            """
        )
        async with GetAsyncSession() as Session:
            Result = await Session.execute(Query, {"run_id": RunId})
            FirstRow = Result.mappings().first()
            return dict(FirstRow) if FirstRow else None

    async def _DownloadFromUrl(
        self,
        *,
        Url: str,
        RuleVersionId: int | None,
        RuleTypeId: str | None,
        ExpectedSha256: str | None,
    ) -> RuleVersionPayload:
        """Download the rule YAML from ``Url`` into a local temp file.

        Download failures are logged and re-raised unchanged.

        Raises:
            ValueError: if ``ExpectedSha256`` is given and differs
                (case-insensitively) from the SHA-256 of the downloaded bytes.
        """
        try:
            Data = self.Oss.DownloadBytes(Url)
        except Exception as Failure:
            logger.error(f"下载规则 YAML 失败: url={Url}, error={Failure}")
            raise

        ActualSha256 = hashlib.sha256(Data).hexdigest()
        if ExpectedSha256 and ActualSha256.lower() != ExpectedSha256.lower():
            raise ValueError(
                "规则 YAML SHA256 校验失败: "
                f"expected={ExpectedSha256}, actual={ActualSha256}"
            )

        # Put the rule type into the temp-file prefix, with "/" and "." replaced by "_".
        FilePrefix = "leaudit-rule-"
        if RuleTypeId:
            FilePrefix += RuleTypeId.translate(str.maketrans("/.", "__")) + "-"

        TempPath = self.Oss.WriteTempBytes(
            Content=Data,
            Suffix=".yaml",
            Prefix=FilePrefix,
        )

        return RuleVersionPayload(
            localPath=TempPath,
            sourceType="oss",
            sourcePath=Url,
            ruleVersionId=RuleVersionId,
            ruleTypeId=RuleTypeId,
            fileSha256=ActualSha256,
        )
|
||||
@@ -1,28 +1,38 @@
|
||||
"""Celery task for leaudit pipeline processing.
|
||||
|
||||
Activated when PIPELINE_MODE=leaudit in env.{port} config.
|
||||
Replaces the legacy OCR → extraction → evaluation pipeline with
|
||||
leaudit's YAML-rules-driven approach.
|
||||
"""
|
||||
"""LeAudit 任务入口。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from core.celery_app_limited import celery_app
|
||||
from core.postgrest.client import get_postgrest_client
|
||||
from core.logger import log
|
||||
from fastapi_common.fastapi_common_logger import logger
|
||||
|
||||
from leaudit_bridge import create_pipeline, RulesLoader
|
||||
from fastapi_admin.config import LEAUDIT_RULES_DIR
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import (
|
||||
NativeRunRequest,
|
||||
NativeRunner,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import (
|
||||
NativeAuditMetadata,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ruleVersionResolver import (
|
||||
RuleVersionResolver,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
|
||||
|
||||
# Celery 集成待 P2 阶段实现,当前使用同步占位
|
||||
# from core.celery_app_limited import celery_app
|
||||
|
||||
log = logger
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="leaudit.process_document")
|
||||
# P2: Celery 集成后启用 @celery_app.task 装饰器
|
||||
def leaudit_process_document(
|
||||
self,
|
||||
document_id: int,
|
||||
file_content: bytes,
|
||||
filename: str,
|
||||
@@ -30,43 +40,41 @@ def leaudit_process_document(
|
||||
source_port: Optional[int] = None,
|
||||
rules_path: Optional[str] = None,
|
||||
):
|
||||
"""Process a document using leaudit's full pipeline.
|
||||
|
||||
Steps: OCR → Extraction → Evaluation → Store in docauditai DB.
|
||||
"""
|
||||
task_id = self.request.id
|
||||
log.task.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")
|
||||
"""处理单个文档的 LeAudit 任务。"""
|
||||
task_id = os.urandom(8).hex()
|
||||
log.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")
|
||||
|
||||
# 新平台:region 通过参数传递,不再依赖 os.environ 切换
|
||||
if source_port:
|
||||
from core.utils.instance_context import set_instance_environment
|
||||
instance_name = set_instance_environment(source_port)
|
||||
log.task.info(
|
||||
f"[任务ID: {task_id}] 实例环境: {instance_name} (端口: {source_port})"
|
||||
)
|
||||
log.info(f"[任务ID: {task_id}] 来源端口: {source_port}")
|
||||
|
||||
if upload_info is None:
|
||||
upload_info = {}
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
temp_paths: list[str] = []
|
||||
storage = StorageAdapter()
|
||||
|
||||
try:
|
||||
rules_path_resolved = rules_path or _resolve_rules_path(document_id, loop)
|
||||
run_id = _resolve_run_id(document_id, upload_info, loop)
|
||||
rules_resolution = _resolve_rules_runtime(document_id, run_id, rules_path, loop)
|
||||
loop.run_until_complete(_update_run_status_safe(run_id, "running"))
|
||||
rules_path_resolved = rules_resolution["rules_path"]
|
||||
|
||||
# For types with a known mapping (e.g. 行政处罚), pre-load rules_file.
|
||||
# For types that need content classification (e.g. 行政许可 sub-types),
|
||||
# rules_path will be None → adapter classifies after OCR → pipeline
|
||||
# loads rules from ocr_result.rules_file_path.
|
||||
rules_file = None
|
||||
if rules_path_resolved:
|
||||
loader = RulesLoader()
|
||||
rules_file = loader.load(rules_path_resolved)
|
||||
log.task.info(
|
||||
temp_rule_path = rules_resolution.get("temp_rule_path")
|
||||
if isinstance(temp_rule_path, str):
|
||||
temp_paths.append(temp_rule_path)
|
||||
log.info(
|
||||
f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} "
|
||||
f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)"
|
||||
)
|
||||
else:
|
||||
log.task.info(
|
||||
log.info(
|
||||
f"[任务ID: {task_id}] No fixed rules_path — "
|
||||
"will classify from document content after OCR"
|
||||
)
|
||||
@@ -75,47 +83,79 @@ def leaudit_process_document(
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp:
|
||||
temp.write(file_content)
|
||||
temp_path = temp.name
|
||||
temp_paths.append(temp_path)
|
||||
|
||||
pipeline = create_pipeline(rules_path=rules_path_resolved)
|
||||
runner = NativeRunner()
|
||||
|
||||
t0 = time.time()
|
||||
result = loop.run_until_complete(
|
||||
pipeline.run(
|
||||
document_id=document_id,
|
||||
file_path=temp_path,
|
||||
rules_file=rules_file,
|
||||
source_port=source_port or int(os.getenv("APP_PORT", "8000")),
|
||||
native_result = loop.run_until_complete(
|
||||
runner.run(
|
||||
NativeRunRequest(
|
||||
metadata=NativeAuditMetadata(
|
||||
run_id=run_id,
|
||||
document_id=document_id,
|
||||
rule_version_id=_optional_int(upload_info, "rule_version_id", "ruleVersionId"),
|
||||
extras={"taskId": task_id},
|
||||
),
|
||||
local_file_path=temp_path,
|
||||
rules_file=rules_file,
|
||||
rule_source_path=rules_path_resolved,
|
||||
)
|
||||
)
|
||||
)
|
||||
loop.run_until_complete(runner.persist_result(native_result))
|
||||
elapsed = round(time.time() - t0, 2)
|
||||
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
log.task.info(
|
||||
f"[任务ID: {task_id}] leaudit管线完成: phase={result.detected_phase}, "
|
||||
f"timing={result.timing}, 总耗时={elapsed:.1f}s"
|
||||
ctx = native_result.ctx
|
||||
loop.run_until_complete(_update_run_phase_safe(run_id, ctx.phase))
|
||||
loop.run_until_complete(_update_run_status_safe(run_id, "completed"))
|
||||
loop.run_until_complete(_update_status_safe(document_id, "completed"))
|
||||
log.info(
|
||||
f"[任务ID: {task_id}] leaudit管线完成: phase={ctx.phase}, "
|
||||
f"timing={dict(ctx.timing)}, 总耗时={elapsed:.1f}s"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"document_id": document_id,
|
||||
"phase": result.detected_phase,
|
||||
"timing": result.timing,
|
||||
"errors": result.errors,
|
||||
"run_id": run_id,
|
||||
"phase": ctx.phase,
|
||||
"timing": dict(ctx.timing),
|
||||
"errors": list(ctx.extraction.all_errors) if ctx.extraction is not None else list(ctx.extraction_errors),
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
log.task.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
|
||||
log.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
|
||||
try:
|
||||
loop.run_until_complete(_update_status_safe(document_id, "Failed"))
|
||||
loop.run_until_complete(_update_status_safe(document_id, "failed"))
|
||||
if 'run_id' in locals():
|
||||
failed_phase = "persist"
|
||||
if "native_result" in locals():
|
||||
failed_phase = native_result.ctx.phase or failed_phase
|
||||
loop.run_until_complete(
|
||||
storage.fail_run(
|
||||
document_id,
|
||||
run_id=run_id,
|
||||
phase=failed_phase,
|
||||
message=str(e),
|
||||
detail_json={
|
||||
"taskId": task_id,
|
||||
"filename": filename,
|
||||
"errorType": type(e).__name__,
|
||||
},
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
finally:
|
||||
for temp_path in temp_paths:
|
||||
try:
|
||||
if Path(temp_path).exists():
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
loop.close()
|
||||
|
||||
|
||||
@@ -129,49 +169,158 @@ _TYPE_ID_RULES_MAP: dict[int, str] = {
|
||||
|
||||
def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None:
|
||||
"""Resolve rules_path: config override → document metadata → type_id mapping."""
|
||||
from core.config import LEAUDIT_CONFIG
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from sqlalchemy import text as sa_text
|
||||
|
||||
# 1. Config override (when explicitly set)
|
||||
config_path = LEAUDIT_CONFIG.get("RULES_PATH", "")
|
||||
if config_path:
|
||||
return config_path
|
||||
# 1. Config override (when explicitly set in app.toml)
|
||||
if LEAUDIT_RULES_DIR:
|
||||
return LEAUDIT_RULES_DIR
|
||||
|
||||
try:
|
||||
client = get_postgrest_client()
|
||||
doc = loop.run_until_complete(
|
||||
client.select(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
single=True,
|
||||
)
|
||||
)
|
||||
if not doc:
|
||||
return None
|
||||
async def _fetch():
|
||||
async with GetAsyncSession() as session:
|
||||
result = await session.execute(
|
||||
sa_text("SELECT type_id FROM leaudit_documents WHERE id = :did"),
|
||||
{"did": document_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if row and row[0] and row[0] in _TYPE_ID_RULES_MAP:
|
||||
return f"{_TYPE_ID_RULES_MAP[row[0]]}/rules.yaml"
|
||||
return None
|
||||
|
||||
# 2. Document-level override
|
||||
rfp = doc.get("rules_file_path")
|
||||
if rfp:
|
||||
return rfp
|
||||
|
||||
# 3. type_id mapping
|
||||
type_id = doc.get("type_id")
|
||||
if type_id and type_id in _TYPE_ID_RULES_MAP:
|
||||
return f"{_TYPE_ID_RULES_MAP[type_id]}/rules.yaml"
|
||||
return loop.run_until_complete(_fetch())
|
||||
except Exception as e:
|
||||
log.task.warning(f"Failed to resolve rules_path from document: {e}")
|
||||
log.warning(f"Failed to resolve rules_path from document: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _update_status_safe(document_id: int, status: str) -> None:
|
||||
"""Safely update document status, ignoring errors."""
|
||||
def _resolve_rules_runtime(
|
||||
document_id: int,
|
||||
run_id: int,
|
||||
explicit_rules_path: str | None,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> dict[str, str | None]:
|
||||
"""解析本次执行使用的规则来源。"""
|
||||
if explicit_rules_path:
|
||||
return {
|
||||
"rules_path": explicit_rules_path,
|
||||
"temp_rule_path": None,
|
||||
"source_type": "explicit",
|
||||
"source_path": explicit_rules_path,
|
||||
}
|
||||
|
||||
resolver = RuleVersionResolver()
|
||||
try:
|
||||
client = get_postgrest_client()
|
||||
await client.update(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
data={"status": status},
|
||||
)
|
||||
payload = loop.run_until_complete(resolver.ResolveForRun(run_id))
|
||||
if payload:
|
||||
log.info(
|
||||
f"run_id={run_id} 规则来源已解析: sourceType={payload.sourceType}, "
|
||||
f"sourcePath={payload.sourcePath}, localPath={payload.localPath}"
|
||||
)
|
||||
return {
|
||||
"rules_path": payload.localPath,
|
||||
"temp_rule_path": payload.localPath if payload.sourceType == "oss" else None,
|
||||
"source_type": payload.sourceType,
|
||||
"source_path": payload.sourcePath,
|
||||
}
|
||||
except Exception as e:
|
||||
log.warning(f"Failed to resolve rule version from run: run_id={run_id}, error={e}")
|
||||
|
||||
fallback_rules_path = _resolve_rules_path(document_id, loop)
|
||||
return {
|
||||
"rules_path": fallback_rules_path,
|
||||
"temp_rule_path": None,
|
||||
"source_type": "legacy_fallback",
|
||||
"source_path": fallback_rules_path,
|
||||
}
|
||||
|
||||
|
||||
def _resolve_run_id(
|
||||
document_id: int,
|
||||
upload_info: Dict[str, Any] | None,
|
||||
loop: asyncio.AbstractEventLoop,
|
||||
) -> int:
|
||||
"""解析本次任务对应的运行 ID。"""
|
||||
if upload_info:
|
||||
for key in ("run_id", "runId"):
|
||||
value = upload_info.get(key)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from sqlalchemy import text as sa_text
|
||||
|
||||
async def _fetch() -> int:
|
||||
async with GetAsyncSession() as session:
|
||||
result = await session.execute(
|
||||
sa_text("SELECT id FROM leaudit_audit_runs WHERE document_id = :did ORDER BY id DESC LIMIT 1"),
|
||||
{"did": document_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
if not row:
|
||||
raise ValueError(f"未找到 document_id={document_id} 对应的 run 记录")
|
||||
return int(row[0])
|
||||
|
||||
return loop.run_until_complete(_fetch())
|
||||
|
||||
|
||||
def _optional_int(payload: Dict[str, Any] | None, *keys: str) -> int | None:
|
||||
"""从字典中按顺序取可用整数。"""
|
||||
if not payload:
|
||||
return None
|
||||
|
||||
for key in keys:
|
||||
value = payload.get(key)
|
||||
if isinstance(value, int):
|
||||
return value
|
||||
return None
|
||||
|
||||
|
||||
async def _update_status_safe(document_id: int, status: str) -> None:
    """Best-effort update of a document's processing status.

    Writes ``processing_status`` (and ``update_time``) on the
    ``leaudit_documents`` row. Any failure is logged and suppressed so that
    status bookkeeping never aborts the calling task.

    Args:
        document_id: Primary key of the document row.
        status: New value for ``processing_status``.
    """
    try:
        from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
        from sqlalchemy import text as sa_text

        async with GetAsyncSession() as session:
            await session.execute(
                sa_text("UPDATE leaudit_documents SET processing_status = :s, update_time = now() WHERE id = :did"),
                {"s": status, "did": document_id},
            )
            await session.commit()
    except Exception as e:
        # Still best-effort, but leave a trace instead of failing silently.
        log.warning(
            f"Failed to update document status: document_id={document_id}, status={status}, error={e}"
        )
|
||||
|
||||
|
||||
async def _update_run_status_safe(run_id: int, status: str) -> None:
    """Best-effort update of an audit run's status.

    Writes ``status`` (and ``update_time``) on the ``leaudit_audit_runs`` row.
    Any failure is logged and suppressed so that status bookkeeping never
    aborts the calling task.

    Args:
        run_id: Primary key of the run row.
        status: New value for ``status``.
    """
    try:
        from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
        from sqlalchemy import text as sa_text

        async with GetAsyncSession() as session:
            await session.execute(
                sa_text("UPDATE leaudit_audit_runs SET status = :s, update_time = now() WHERE id = :rid"),
                {"s": status, "rid": run_id},
            )
            await session.commit()
    except Exception as e:
        # Still best-effort, but leave a trace instead of failing silently.
        log.warning(f"Failed to update run status: run_id={run_id}, status={status}, error={e}")
|
||||
|
||||
|
||||
async def _update_run_phase_safe(run_id: int, phase: str | None) -> None:
    """Best-effort update of an audit run's phase.

    Writes ``phase`` (and ``update_time``) on the ``leaudit_audit_runs`` row.
    Any failure is logged and suppressed so that phase bookkeeping never
    aborts the calling task.

    Args:
        run_id: Primary key of the run row.
        phase: New value for ``phase``; ``None`` is passed through to the
            UPDATE statement as-is.
    """
    try:
        from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
        from sqlalchemy import text as sa_text

        async with GetAsyncSession() as session:
            await session.execute(
                sa_text("UPDATE leaudit_audit_runs SET phase = :p, update_time = now() WHERE id = :rid"),
                {"p": phase, "rid": run_id},
            )
            await session.commit()
    except Exception as e:
        # Still best-effort, but leave a trace instead of failing silently.
        log.warning(f"Failed to update run phase: run_id={run_id}, phase={phase}, error={e}")
|
||||
|
||||
@@ -190,12 +339,16 @@ def dispatch_leaudit_task(
|
||||
source_port: Optional[int] = None,
|
||||
rules_path: Optional[str] = None,
|
||||
):
|
||||
"""Dispatch a leaudit processing task."""
|
||||
return leaudit_process_document.apply_async(
|
||||
args=[document_id, file_content, filename],
|
||||
kwargs={
|
||||
"upload_info": upload_info,
|
||||
"source_port": source_port or int(os.getenv("APP_PORT", "8000")),
|
||||
"rules_path": rules_path,
|
||||
},
|
||||
"""Dispatch a leaudit processing task.
|
||||
|
||||
P2: Celery 集成后改用 leaudit_process_document.apply_async(...)
|
||||
当前阶段直接同步调用。
|
||||
"""
|
||||
return leaudit_process_document(
|
||||
document_id=document_id,
|
||||
file_content=file_content,
|
||||
filename=filename,
|
||||
upload_info=upload_info,
|
||||
source_port=source_port or int(os.getenv("APP_PORT", "8000")),
|
||||
rules_path=rules_path,
|
||||
)
|
||||
|
||||
@@ -2,7 +2,8 @@
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.services.auditService import IAuditService
|
||||
from fastapi_modules.fastapi_leaudit.services.authService import IAuthService
|
||||
from fastapi_modules.fastapi_leaudit.services.ossService import IOssService
|
||||
from fastapi_modules.fastapi_leaudit.services.permissionService import IPermissionService
|
||||
from fastapi_modules.fastapi_leaudit.services.ruleService import IRuleService
|
||||
|
||||
__all__ = ["IAuditService", "IAuthService", "IPermissionService", "IRuleService"]
|
||||
__all__ = ["IAuditService", "IAuthService", "IOssService", "IPermissionService", "IRuleService"]
|
||||
|
||||
@@ -9,7 +9,7 @@ class IAuditService(ABC):
|
||||
"""评查服务接口。"""
|
||||
|
||||
@abstractmethod
|
||||
async def Run(self, DocumentId: int) -> AuditRunVO:
|
||||
async def Run(self, DocumentId: int, RuleType: str | None = None, Force: bool = False) -> AuditRunVO:
|
||||
"""触发文档评查。"""
|
||||
...
|
||||
|
||||
|
||||
@@ -4,13 +4,22 @@
|
||||
文档 → OCR → Extract → Evaluate → Rescue → Persist
|
||||
"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from fastapi_common.fastapi_common_logger import logger
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
||||
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
||||
from sqlalchemy import select, text
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.auditVo import AuditRunVO, AuditResultVO
|
||||
from fastapi_modules.fastapi_leaudit.models import LeauditAuditRun
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.fileSourceResolver import FileSourceResolver
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.tasks import dispatch_leaudit_task
|
||||
from fastapi_modules.fastapi_leaudit.models import (
|
||||
LeauditAuditRun,
|
||||
LeauditDocument,
|
||||
LeauditDocumentFile,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.services import IAuditService
|
||||
|
||||
|
||||
@@ -20,12 +29,116 @@ class AuditServiceImpl(IAuditService):
|
||||
async def Run(self, DocumentId: int, RuleType: str | None = None, Force: bool = False) -> AuditRunVO:
    """Trigger an audit run for a document and return its run record.

    Looks up the document, its newest active file version and the
    highest-priority active rule binding for the document type, creates a
    ``pending`` run row, marks the document as ``running``, then executes the
    bridge task and reports the refreshed run state.

    The bridge task is still executed within the request; it is not yet
    dispatched through Celery.

    Args:
        DocumentId: Primary key of the document to audit.
        RuleType: Optional value forwarded to the bridge task as ``rules_path``.
        Force: When true the run is recorded with trigger source ``retry``
            instead of ``manual``.

    Returns:
        An ``AuditRunVO`` built from the run row after the task has returned.

    Raises:
        LeauditException: 404 when the document does not exist, 400 when it
            has no active file version or no bound rule version, 500 when
            the file payload cannot be read.
    """
    # Local import keeps this fix self-contained within the method.
    import asyncio

    async with GetAsyncSession() as session:
        logger.info(f"触发评查: documentId={DocumentId}, ruleType={RuleType}")
        document = await session.get(LeauditDocument, DocumentId)
        if not document:
            raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查文档不存在")

        # Newest active file version of the document.
        fileResult = await session.execute(
            select(LeauditDocumentFile)
            .where(
                LeauditDocumentFile.documentId == DocumentId,
                LeauditDocumentFile.isActive.is_(True),
            )
            .order_by(LeauditDocumentFile.Id.desc())
            .limit(1)
        )
        documentFile = fileResult.scalar_one_or_none()
        if not documentFile:
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前文档没有可执行文件版本")

        # Highest run number so far for this document (0 when none exist).
        runNoResult = await session.execute(
            select(LeauditAuditRun.runNo)
            .where(LeauditAuditRun.documentId == DocumentId)
            .order_by(LeauditAuditRun.runNo.desc())
            .limit(1)
        )
        latestRunNo = runNoResult.scalar_one_or_none() or 0

        # Active binding with the highest priority for the document type,
        # joined to the rule set's current version.
        bindingResult = await session.execute(
            text(
                """
                SELECT
                    rs.id AS rule_set_id,
                    rs.current_version_id AS rule_version_id,
                    rv.oss_url AS rule_source_oss_url,
                    rv.file_sha256 AS rule_source_sha256,
                    rv.metadata_type_id AS rule_type_id
                FROM leaudit_rule_type_bindings b
                JOIN leaudit_rule_sets rs ON rs.id = b.rule_set_id
                LEFT JOIN leaudit_rule_versions rv ON rv.id = rs.current_version_id
                WHERE b.doc_type_id = :doc_type_id
                  AND b.is_active = true
                ORDER BY b.priority DESC, b.id DESC
                LIMIT 1
                """
            ),
            {"doc_type_id": document.typeId},
        )
        binding = bindingResult.mappings().first()
        if not binding or not binding["rule_set_id"] or not binding["rule_version_id"]:
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前文档类型未绑定可用规则版本")

        run = LeauditAuditRun(
            documentId=DocumentId,
            documentFileId=documentFile.Id,
            runNo=int(latestRunNo) + 1,
            triggerSource="manual" if not Force else "retry",
            status="pending",
            ruleSetId=int(binding["rule_set_id"]),
            ruleVersionId=int(binding["rule_version_id"]),
            ruleTypeId=binding["rule_type_id"],
            ruleSourceOssUrl=binding["rule_source_oss_url"],
            ruleSourceSha256=binding["rule_source_sha256"],
            startedAt=datetime.now(),
        )
        session.add(run)
        # Flush so the run's primary key is available before it is referenced.
        await session.flush()

        document.currentRunId = run.Id
        document.processingStatus = "running"
        await session.commit()
        await session.refresh(run)

        # NOTE(review): if reading the file fails here, the run row stays
        # "pending" and the document stays "running" — confirm whether a
        # failure state should be written before raising.
        try:
            Resolver = FileSourceResolver()
            Payload = await Resolver.ResolvePayload(documentFile)
        except Exception as Error:
            raise LeauditException(
                StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                f"读取评查文件失败: {Error}",
            ) from Error

        # The bridge task is synchronous and drives its own event loop with
        # run_until_complete. Calling it directly from this coroutine raises
        # RuntimeError because a loop is already running on this thread, so
        # it is executed in a worker thread instead.
        # NOTE(review): the task opens its own DB sessions on that separate
        # loop — confirm the async engine tolerates use from another loop.
        await asyncio.to_thread(
            dispatch_leaudit_task,
            document_id=DocumentId,
            file_content=Payload.fileContent,
            filename=Payload.fileName,
            upload_info={
                "run_id": run.Id,
                "rule_version_id": run.ruleVersionId,
                "rule_source_oss_url": run.ruleSourceOssUrl,
                "source_type": Payload.sourceType,
                "source_path": Payload.sourcePath,
            },
            rules_path=RuleType,
        )

        # Reload the row so the response reflects what the task persisted.
        await session.refresh(run)
        return AuditRunVO(
            runId=run.Id,
            documentId=run.documentId,
            runNo=run.runNo,
            status=run.status,
            phase=run.phase,
            # Compare against None so a legitimate score of 0 is reported.
            totalScore=float(run.totalScore) if run.totalScore is not None else None,
            passedCount=run.passedCount,
            failedCount=run.failedCount,
            startedAt=run.startedAt,
            finishedAt=run.finishedAt,
        )
|
||||
|
||||
async def GetRunStatus(self, RunId: int) -> AuditRunVO:
|
||||
"""查询评查运行状态。"""
|
||||
@@ -52,7 +165,33 @@ class AuditServiceImpl(IAuditService):
|
||||
run = await session.get(LeauditAuditRun, RunId)
|
||||
if not run:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查运行记录不存在")
|
||||
# TODO: 从 leaudit_rule_results 表查询规则级结果
|
||||
result = await session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT
|
||||
rule_id,
|
||||
rule_name,
|
||||
risk,
|
||||
score,
|
||||
passed,
|
||||
status,
|
||||
skip_reason,
|
||||
confidence,
|
||||
pass_message,
|
||||
fail_message,
|
||||
remediation,
|
||||
extracted_fields,
|
||||
field_positions,
|
||||
rescue_applied,
|
||||
rescue_passed
|
||||
FROM leaudit_rule_results
|
||||
WHERE run_id = :run_id
|
||||
ORDER BY id ASC
|
||||
"""
|
||||
),
|
||||
{"run_id": RunId},
|
||||
)
|
||||
rules = [dict(row) for row in result.mappings().all()]
|
||||
return AuditResultVO(
|
||||
runId=run.Id,
|
||||
totalScore=float(run.totalScore) if run.totalScore else None,
|
||||
@@ -61,5 +200,5 @@ class AuditServiceImpl(IAuditService):
|
||||
skippedCount=run.skippedCount or 0,
|
||||
phase=run.phase,
|
||||
rescueApplied=run.rescueApplied or False,
|
||||
rules=[],
|
||||
rules=rules,
|
||||
)
|
||||
|
||||
@@ -0,0 +1,69 @@
|
||||
"""OSS 服务实现。"""
|
||||
|
||||
from fastapi_common.fastapi_common_storage.oss_client import OssClient
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.services.ossService import IOssService
|
||||
|
||||
|
||||
class OssServiceImpl(IOssService):
    """Concrete OSS service that forwards every operation to an ``OssClient``.

    NOTE(review): the client calls below are not awaited, so they appear to
    be synchronous — confirm whether they block the event loop.
    """

    def __init__(self, Client: OssClient | None = None) -> None:
        """Keep the supplied client, or build a default ``OssClient`` when none is given."""
        if not Client:
            Client = OssClient()
        self.Client = Client

    async def DownloadBytes(self, Source: str, Bucket: str | None = None) -> bytes:
        """Return the content of the object identified by ``Source``."""
        Content = self.Client.DownloadBytes(Source=Source, Bucket=Bucket)
        return Content

    async def DownloadToTempFile(
        self,
        Source: str,
        Suffix: str = "",
        Prefix: str = "oss-",
        Bucket: str | None = None,
    ) -> str:
        """Download the object to a local temporary file and return the client's result."""
        Options = {
            "Source": Source,
            "Suffix": Suffix,
            "Prefix": Prefix,
            "Bucket": Bucket,
        }
        return self.Client.DownloadToTempFile(**Options)

    async def UploadBytes(
        self,
        ObjectKey: str,
        Content: bytes,
        ContentType: str = "application/octet-stream",
        Bucket: str | None = None,
    ) -> str:
        """Upload binary content under ``ObjectKey`` and return the client's result."""
        Options = {
            "ObjectKey": ObjectKey,
            "Content": Content,
            "ContentType": ContentType,
            "Bucket": Bucket,
        }
        return self.Client.UploadBytes(**Options)

    async def UploadText(
        self,
        ObjectKey: str,
        Content: str,
        ContentType: str = "text/plain; charset=utf-8",
        Bucket: str | None = None,
    ) -> str:
        """Upload text content under ``ObjectKey`` and return the client's result."""
        Options = {
            "ObjectKey": ObjectKey,
            "Content": Content,
            "ContentType": ContentType,
            "Bucket": Bucket,
        }
        return self.Client.UploadText(**Options)

    async def ObjectExists(self, Source: str, Bucket: str | None = None) -> bool:
        """Report whether the object identified by ``Source`` exists."""
        Found = self.Client.ObjectExists(Source=Source, Bucket=Bucket)
        return Found

    async def PresignGetUrl(self, Source: str, Bucket: str | None = None) -> str:
        """Return a signed download URL for the object identified by ``Source``."""
        SignedUrl = self.Client.PresignGetUrl(Source=Source, Bucket=Bucket)
        return SignedUrl
|
||||
@@ -0,0 +1,55 @@
|
||||
"""OSS 服务接口。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class IOssService(ABC):
|
||||
"""OSS 服务接口。"""
|
||||
|
||||
@abstractmethod
|
||||
async def DownloadBytes(self, Source: str, Bucket: str | None = None) -> bytes:
|
||||
"""下载对象内容。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def DownloadToTempFile(
|
||||
self,
|
||||
Source: str,
|
||||
Suffix: str = "",
|
||||
Prefix: str = "oss-",
|
||||
Bucket: str | None = None,
|
||||
) -> str:
|
||||
"""下载对象到本地临时文件。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def UploadBytes(
|
||||
self,
|
||||
ObjectKey: str,
|
||||
Content: bytes,
|
||||
ContentType: str = "application/octet-stream",
|
||||
Bucket: str | None = None,
|
||||
) -> str:
|
||||
"""上传二进制内容。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def UploadText(
|
||||
self,
|
||||
ObjectKey: str,
|
||||
Content: str,
|
||||
ContentType: str = "text/plain; charset=utf-8",
|
||||
Bucket: str | None = None,
|
||||
) -> str:
|
||||
"""上传文本内容。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def ObjectExists(self, Source: str, Bucket: str | None = None) -> bool:
|
||||
"""判断对象是否存在。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def PresignGetUrl(self, Source: str, Bucket: str | None = None) -> str:
|
||||
"""生成对象下载签名 URL。"""
|
||||
...
|
||||
Reference in New Issue
Block a user