Files
leaudit-platform-backend/fastapi_modules/fastapi_leaudit/leaudit_bridge/pipeline.py
T
wren 535d97a70c chore: initial commit — leaudit-platform project skeleton
17-table PostgreSQL schema with full Chinese column comments,
FastAPI project structure (admin/common/modules),
DSL rule files, and schema migration scripts.
2026-04-27 16:48:22 +08:00

269 lines
9.2 KiB
Python

"""Main leaudit pipeline orchestrator: OCR → Extract → Evaluate.
Uses leaudit's own pipeline directly (no conversion),
stores results into docauditai's database via StorageAdapter.
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from leaudit.dsl.schema import RulesFile
from leaudit.engine.case_file_evaluator import evaluate_extraction
from leaudit.engine.models import EvaluationResult
from leaudit.extraction.bundle import ExtractionBundle
from leaudit.extraction.dispatcher import dispatch_extract
from leaudit.extraction.phase_detection import determine_phase
from leaudit.llm.base import BaseLLMClient
from leaudit.ocr.base import BaseOCRClient
from leaudit.ocr.models import OcrResult
from leaudit_bridge.storage_adapter import StorageAdapter
# Module-level logger used by the pipeline classes in this file.
log = logging.getLogger(name=__name__)
@dataclass()
class PipelineResult:
    """Everything produced by one ``LauditPipeline.run`` call.

    Fields are declared in constructor order; ``timing`` and ``errors``
    default to fresh, per-instance empty containers.
    """

    # Raw OCR output for the document.
    ocr_result: OcrResult
    # Fields extracted from the OCR output using the rules file.
    extraction_bundle: ExtractionBundle
    # Outcome of evaluating the rules against the extracted fields.
    evaluation_result: EvaluationResult
    # Phase label from phase detection; also used as the evaluation phase.
    detected_phase: str
    # Stage name -> duration in seconds.
    timing: dict[str, float] = field(default_factory=dict)
    # Non-fatal error messages collected while the pipeline ran.
    errors: list[str] = field(default_factory=list)
class LauditPipeline:
    """Run leaudit's full OCR → extraction → evaluation pipeline.

    Does NOT use leaudit's own SQLAlchemy storage.
    All results are written to docauditai's database via StorageAdapter.
    """

    def __init__(
        self,
        ocr_client: BaseOCRClient,
        llm_client: BaseLLMClient | None = None,
        storage_adapter: StorageAdapter | None = None,
    ) -> None:
        """Wire up the pipeline's collaborators.

        Args:
            ocr_client: Client whose ``ocr(path)`` coroutine returns an
                ``OcrResult``.
            llm_client: Optional LLM client. It is forwarded to case-number
                extraction, field extraction, phase detection and evaluation.
            storage_adapter: Adapter used for every DB write. A default
                ``StorageAdapter()`` is created only when this is ``None``.
        """
        self.ocr_client = ocr_client
        self.llm_client = llm_client
        # Compare against None explicitly so a valid-but-falsy adapter is kept.
        self.storage = (
            storage_adapter if storage_adapter is not None else StorageAdapter()
        )

    @staticmethod
    def _elapsed(start: float) -> float:
        """Return seconds since ``start`` (a ``time.perf_counter()`` value), rounded to 0.01 s."""
        return round(time.perf_counter() - start, 2)

    async def run(
        self,
        document_id: int,
        file_path: str | Path,
        rules_file: RulesFile | None = None,
        *,
        source_port: int = 8000,
    ) -> PipelineResult:
        """Execute the full pipeline for one document.

        The document status is updated as the pipeline advances:
        "Cutting" (OCR), "Extractioning" (field extraction),
        "Evaluationing" (rule evaluation), and finally "Processed".

        Args:
            document_id: docauditai document ID for DB writes.
            file_path: Path to the document file (PDF/DOCX/etc).
            rules_file: leaudit RulesFile (parsed from YAML). When None,
                the pipeline attempts to load rules from the OCR result's
                ``rules_file_path`` (set by the classifier).
            source_port: Instance port for context switching. Not currently
                read by this method.

        Returns:
            PipelineResult with all intermediate and final outputs.
            ``timing`` holds per-stage durations in seconds under ``ocr``,
            ``extraction`` and ``evaluation``. ``total`` is the wall-clock
            duration of the whole call.

        Raises:
            ValueError: If no rules file was passed and none could be loaded
                from the OCR classification result.

        NOTE(review): any exception from a stage propagates and leaves the
        document in its last in-progress status. No failure status is
        written here; confirm that callers handle this.
        """
        file_path = Path(file_path)
        errors: list[str] = []
        timing: dict[str, float] = {}
        # Monotonic clock, so durations are immune to wall-clock adjustments.
        pipeline_start = time.perf_counter()

        # --- Phase 1: Update status to Cutting ---
        await self.storage.update_document_status(document_id, "Cutting")

        # --- Phase 2: OCR ---
        t0 = time.perf_counter()
        log.info("[%d] OCR starting: %s", document_id, file_path.name)
        ocr_result = await self._run_ocr(file_path)
        timing["ocr"] = self._elapsed(t0)
        log.info(
            "[%d] OCR done: %d pages, %.1fs",
            document_id,
            len(ocr_result.pages),
            timing["ocr"],
        )

        # --- Resolve rules_file after OCR if not provided ---
        if rules_file is None:
            rules_file = await self._resolve_rules_from_ocr(ocr_result, document_id)
        if rules_file is None:
            raise ValueError(
                f"Cannot resolve rules_file for document {document_id}. "
                "Neither passed explicitly nor classified from OCR content."
            )

        # --- Save OCR result ---
        await self.storage.save_ocr_result(document_id, ocr_result)

        # --- Extract & save the case number ---
        await self._extract_and_save_case_number(document_id, ocr_result)

        # --- Phase 3: Extraction ---
        t0 = time.perf_counter()
        await self.storage.update_document_status(document_id, "Extractioning")
        log.info("[%d] Extraction starting", document_id)
        extraction_bundle = await dispatch_extract(
            ocr_result,
            rules_file,
            llm_client=self.llm_client,
            phase="executed",
        )
        timing["extraction"] = self._elapsed(t0)
        if extraction_bundle.all_errors:
            # Extraction errors are non-fatal: they are collected into the
            # result and the pipeline continues.
            errors.extend(extraction_bundle.all_errors)
            log.warning(
                "[%d] Extraction completed with %d errors",
                document_id,
                len(extraction_bundle.all_errors),
            )
        log.info(
            "[%d] Extraction done: %d fields, %.1fs",
            document_id,
            len(extraction_bundle.fields),
            timing["extraction"],
        )

        # --- Save extraction result ---
        await self.storage.save_extraction_result(document_id, extraction_bundle)

        # --- Resolve field positions from OCR chunks ---
        # NOTE(review): positions are resolved after the extraction result has
        # been saved. They reach the DB only if a later write persists them
        # (save_evaluation_results does receive the bundle); confirm that this
        # ordering is intended.
        from leaudit.extraction.coordinate_resolver import resolve_bundle_positions

        resolve_bundle_positions(extraction_bundle, ocr_result)
        positioned_count = sum(
            1 for fv in extraction_bundle.fields.values() if fv.position is not None
        )
        log.info(
            "[%d] Coordinate resolution: %d/%d fields positioned",
            document_id,
            positioned_count,
            len(extraction_bundle.fields),
        )

        # --- Phase 4: Phase detection ---
        # Prefer the bundle's visual manifest; fall back to the OCR result's.
        visual_manifest = extraction_bundle.visual_manifest or ocr_result.visual_manifest
        detected_phase = await determine_phase(
            extraction_bundle.fields,
            llm_client=self.llm_client,
            visual_manifest=visual_manifest,
        )
        log.info("[%d] Detected phase: %s", document_id, detected_phase)

        # --- Phase 5: Evaluation ---
        t0 = time.perf_counter()
        await self.storage.update_document_status(document_id, "Evaluationing")
        log.info("[%d] Evaluation starting (phase=%s)", document_id, detected_phase)
        external_mocks: dict[str, Any] = {}
        if self.llm_client is not None:
            external_mocks["llm_client"] = self.llm_client
        # NOTE(review): rules_file is provided here even when no LLM client is
        # configured. Confirm this matches what the evaluator expects.
        external_mocks["rules_file"] = rules_file
        evaluation_result = await evaluate_extraction(
            rules_file,
            extraction_bundle,
            visual_manifest=visual_manifest,
            phase=detected_phase,
            external_mocks=external_mocks,
        )
        timing["evaluation"] = self._elapsed(t0)
        log.info(
            "[%d] Evaluation done: %d passed, %d failed, %d skipped, %.1fs",
            document_id,
            evaluation_result.passed_count,
            evaluation_result.failed_count,
            evaluation_result.skipped_count,
            timing["evaluation"],
        )

        # --- Save evaluation results ---
        await self.storage.save_evaluation_results(
            document_id, rules_file, evaluation_result, extraction_bundle,
        )

        # --- Phase 6: Finalize ---
        # "total" is the real elapsed time of the call. It also covers the DB
        # writes, case-number extraction, coordinate resolution and phase
        # detection, none of which has its own per-stage timer.
        timing["total"] = self._elapsed(pipeline_start)
        await self.storage.update_document_status(document_id, "Processed")
        log.info(
            "[%d] Pipeline complete: phase=%s, timing=%s",
            document_id,
            detected_phase,
            timing,
        )
        return PipelineResult(
            ocr_result=ocr_result,
            extraction_bundle=extraction_bundle,
            evaluation_result=evaluation_result,
            detected_phase=detected_phase,
            timing=timing,
            errors=errors,
        )

    async def _run_ocr(self, file_path: Path) -> OcrResult:
        """Run OCR on ``file_path``; log and re-raise any failure."""
        try:
            return await self.ocr_client.ocr(file_path)
        except Exception as e:
            log.error("OCR failed for %s: %s", file_path.name, e)
            raise

    async def _extract_and_save_case_number(
        self, document_id: int, ocr_result: OcrResult,
    ) -> None:
        """Extract the case number from the OCR output and store it on the document.

        Nothing is written when the extractor returns a falsy value. The
        extractor is called even when ``self.llm_client`` is None.
        """
        from leaudit_bridge.case_number_extractor import (
            extract_case_number_with_llm,
        )

        case_number = await extract_case_number_with_llm(
            ocr_result, llm_client=self.llm_client,
        )
        if case_number:
            await self.storage.update_document_number(document_id, case_number)
            log.info("[%d] Case number: %s", document_id, case_number)
        else:
            log.info("[%d] No case number found", document_id)

    async def _resolve_rules_from_ocr(
        self, ocr_result: OcrResult, document_id: int,
    ) -> RulesFile | None:
        """Load the RulesFile named by the OCR classification result.

        Returns:
            The loaded RulesFile. Returns None when the OCR result carries no
            ``rules_file_path``, or when loading that file fails. A loading
            failure is logged together with its traceback.
        """
        from leaudit.dsl.loader import load_rules_file

        rfp = ocr_result.rules_file_path
        if not rfp:
            log.warning(
                "[%d] No rules_file_path in OCR result, cannot resolve rules",
                document_id,
            )
            return None
        try:
            rules_file = load_rules_file(rfp)
        except Exception as e:
            # Best-effort: the caller turns None into a ValueError. Log the
            # traceback here so the root cause of the load failure is not lost.
            log.exception("[%d] Failed to load rules from %s: %s", document_id, rfp, e)
            return None
        log.info(
            "[%d] Resolved rules from classification: %s (%d rules)",
            document_id, rfp, len(rules_file.flat_rules),
        )
        return rules_file