chore: initial commit — leaudit-platform project skeleton
17-table PostgreSQL schema with full Chinese column comments, FastAPI project structure (admin/common/modules), DSL rule files, and schema migration scripts.
This commit is contained in:
@@ -0,0 +1,268 @@
|
||||
"""Main leaudit pipeline orchestrator: OCR → Extract → Evaluate.
|
||||
|
||||
Uses leaudit's own pipeline directly (no conversion),
|
||||
stores results into docauditai's database via StorageAdapter.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
from leaudit.dsl.schema import RulesFile
|
||||
from leaudit.engine.case_file_evaluator import evaluate_extraction
|
||||
from leaudit.engine.models import EvaluationResult
|
||||
from leaudit.extraction.bundle import ExtractionBundle
|
||||
from leaudit.extraction.dispatcher import dispatch_extract
|
||||
from leaudit.extraction.phase_detection import determine_phase
|
||||
from leaudit.llm.base import BaseLLMClient
|
||||
from leaudit.ocr.base import BaseOCRClient
|
||||
from leaudit.ocr.models import OcrResult
|
||||
|
||||
from leaudit_bridge.storage_adapter import StorageAdapter
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
class PipelineResult:
    """Everything produced by one run of the leaudit pipeline.

    Bundles the intermediate outputs (OCR, extraction) together with the
    final evaluation so a caller can inspect any stage after the run.
    """

    # Output of the OCR stage.
    ocr_result: OcrResult
    # Output of the extraction stage.
    extraction_bundle: ExtractionBundle
    # Output of the rule-evaluation stage.
    evaluation_result: EvaluationResult
    # Phase label returned by phase detection and used for evaluation.
    detected_phase: str
    # Stage name -> duration in seconds; a fresh dict per instance.
    timing: dict[str, float] = field(default_factory=dict)
    # Error messages collected during the run; a fresh list per instance.
    errors: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class LauditPipeline:
    """Run leaudit's full OCR → extraction → evaluation pipeline.

    Does NOT use leaudit's own SQLAlchemy storage.
    All results are written to docauditai's database via StorageAdapter.
    """

    def __init__(
        self,
        ocr_client: BaseOCRClient,
        llm_client: BaseLLMClient | None = None,
        storage_adapter: StorageAdapter | None = None,
    ) -> None:
        """Wire up the clients used by the pipeline.

        Args:
            ocr_client: Client whose ``ocr(path)`` coroutine produces the
                OCR result.
            llm_client: Optional LLM client forwarded to extraction, phase
                detection, case-number extraction and evaluation.
            storage_adapter: Adapter used for all database writes. A default
                ``StorageAdapter()`` is created when omitted.
        """
        self.ocr_client = ocr_client
        self.llm_client = llm_client
        # Explicit ``is None`` check: a caller-supplied adapter must never be
        # replaced just because it happens to evaluate as falsy.
        self.storage = (
            StorageAdapter() if storage_adapter is None else storage_adapter
        )

    @staticmethod
    def _elapsed(started: float) -> float:
        """Return seconds since *started*, rounded to two decimals.

        *started* must be a ``time.perf_counter()`` reading; the monotonic
        clock keeps durations correct even if the wall clock is adjusted.
        """
        return round(time.perf_counter() - started, 2)

    async def run(
        self,
        document_id: int,
        file_path: str | Path,
        rules_file: RulesFile | None = None,
        *,
        source_port: int = 8000,
    ) -> PipelineResult:
        """Execute the full pipeline for one document.

        Stages, in order: status "Cutting" → OCR → rules resolution →
        save OCR → case-number extraction → extraction → save extraction →
        coordinate resolution → phase detection → evaluation →
        save evaluation → status "Processed".

        Args:
            document_id: docauditai document ID for DB writes.
            file_path: Path to the document file (PDF/DOCX/etc).
            rules_file: leaudit RulesFile (parsed from YAML). When None,
                the pipeline attempts to load rules from the OCR result's
                ``rules_file_path``.
            source_port: Accepted for interface compatibility; it is not
                read anywhere in this method.

        Returns:
            PipelineResult with all intermediate and final outputs.
            ``timing["total"]`` is the sum of the timed stages (ocr,
            extraction, evaluation), not the wall-clock time of the call.

        Raises:
            ValueError: If no rules file was passed and none could be
                loaded from the OCR result.
            Exception: Any error raised by a stage propagates unchanged.
        """
        # NOTE(review): when a stage raises, no failure status is written
        # here, so the document keeps its last in-progress status. Confirm
        # the caller handles that.
        file_path = Path(file_path)
        errors: list[str] = []
        timing: dict[str, float] = {}

        # --- Phase 1: Update status to Cutting ---
        await self.storage.update_document_status(document_id, "Cutting")

        # --- Phase 2: OCR ---
        started = time.perf_counter()
        log.info("[%d] OCR starting: %s", document_id, file_path.name)
        ocr_result = await self._run_ocr(file_path)
        timing["ocr"] = self._elapsed(started)
        log.info(
            "[%d] OCR done: %d pages, %.1fs",
            document_id,
            len(ocr_result.pages),
            timing["ocr"],
        )

        # --- Resolve rules_file after OCR if not provided ---
        if rules_file is None:
            rules_file = await self._resolve_rules_from_ocr(ocr_result, document_id)
        if rules_file is None:
            raise ValueError(
                f"Cannot resolve rules_file for document {document_id}. "
                "Neither passed explicitly nor classified from OCR content."
            )

        # --- Save OCR result ---
        await self.storage.save_ocr_result(document_id, ocr_result)

        # --- Extract & save case number ---
        await self._extract_and_save_case_number(document_id, ocr_result)

        # --- Phase 3: Extraction ---
        # The timer deliberately includes the status update, as before.
        started = time.perf_counter()
        await self.storage.update_document_status(document_id, "Extractioning")
        log.info("[%d] Extraction starting", document_id)

        # NOTE(review): extraction always runs with phase="executed", while
        # evaluation below uses the detected phase. Confirm this is intended.
        extraction_bundle = await dispatch_extract(
            ocr_result,
            rules_file,
            llm_client=self.llm_client,
            phase="executed",
        )
        timing["extraction"] = self._elapsed(started)

        if extraction_bundle.all_errors:
            errors.extend(extraction_bundle.all_errors)
            log.warning(
                "[%d] Extraction completed with %d errors",
                document_id,
                len(extraction_bundle.all_errors),
            )

        log.info(
            "[%d] Extraction done: %d fields, %.1fs",
            document_id,
            len(extraction_bundle.fields),
            timing["extraction"],
        )

        # --- Save extraction result ---
        # NOTE(review): this save happens before positions are resolved, so
        # whatever it persists is the bundle without positions. Confirm.
        await self.storage.save_extraction_result(document_id, extraction_bundle)

        # --- Resolve field positions from OCR chunks ---
        from leaudit.extraction.coordinate_resolver import resolve_bundle_positions

        resolve_bundle_positions(extraction_bundle, ocr_result)
        positioned_count = sum(
            1 for fv in extraction_bundle.fields.values() if fv.position is not None
        )
        log.info(
            "[%d] Coordinate resolution: %d/%d fields positioned",
            document_id,
            positioned_count,
            len(extraction_bundle.fields),
        )

        # --- Phase 4: Phase detection ---
        # Prefer the manifest attached to the bundle; fall back to OCR's.
        visual_manifest = extraction_bundle.visual_manifest or ocr_result.visual_manifest
        detected_phase = await determine_phase(
            extraction_bundle.fields,
            llm_client=self.llm_client,
            visual_manifest=visual_manifest,
        )
        log.info("[%d] Detected phase: %s", document_id, detected_phase)

        # --- Phase 5: Evaluation ---
        started = time.perf_counter()
        await self.storage.update_document_status(document_id, "Evaluationing")
        log.info("[%d] Evaluation starting (phase=%s)", document_id, detected_phase)

        external_mocks: dict[str, Any] = {}
        if self.llm_client is not None:
            external_mocks["llm_client"] = self.llm_client
        external_mocks["rules_file"] = rules_file

        evaluation_result = await evaluate_extraction(
            rules_file,
            extraction_bundle,
            visual_manifest=visual_manifest,
            phase=detected_phase,
            external_mocks=external_mocks,
        )
        timing["evaluation"] = self._elapsed(started)
        log.info(
            "[%d] Evaluation done: %d passed, %d failed, %d skipped, %.1fs",
            document_id,
            evaluation_result.passed_count,
            evaluation_result.failed_count,
            evaluation_result.skipped_count,
            timing["evaluation"],
        )

        # --- Save evaluation results ---
        await self.storage.save_evaluation_results(
            document_id, rules_file, evaluation_result, extraction_bundle,
        )

        # --- Phase 6: Finalize ---
        # Sum of the timed stages only; untimed steps are not included.
        timing["total"] = round(sum(timing.values()), 2)
        await self.storage.update_document_status(document_id, "Processed")

        log.info(
            "[%d] Pipeline complete: phase=%s, timing=%s",
            document_id,
            detected_phase,
            timing,
        )

        return PipelineResult(
            ocr_result=ocr_result,
            extraction_bundle=extraction_bundle,
            evaluation_result=evaluation_result,
            detected_phase=detected_phase,
            timing=timing,
            errors=errors,
        )

    async def _run_ocr(self, file_path: Path) -> OcrResult:
        """Run OCR on *file_path*, logging any failure before re-raising it."""
        try:
            return await self.ocr_client.ocr(file_path)
        except Exception as e:
            # Logged for context only; the exception still reaches the caller.
            log.error("OCR failed for %s: %s", file_path.name, e)
            raise

    async def _extract_and_save_case_number(
        self, document_id: int, ocr_result: OcrResult,
    ) -> None:
        """Extract the case number from the OCR result and store it if found."""
        from leaudit_bridge.case_number_extractor import (
            extract_case_number_with_llm,
        )

        case_number = await extract_case_number_with_llm(
            ocr_result, llm_client=self.llm_client,
        )
        if case_number:
            await self.storage.update_document_number(document_id, case_number)
            log.info("[%d] Case number: %s", document_id, case_number)
        else:
            log.info("[%d] No case number found", document_id)

    async def _resolve_rules_from_ocr(
        self, ocr_result: OcrResult, document_id: int,
    ) -> RulesFile | None:
        """Load the rules file named by ``ocr_result.rules_file_path``.

        Returns:
            The loaded RulesFile, or None when the OCR result carries no
            path or loading fails. A load failure is logged, not raised.
        """
        from leaudit.dsl.loader import load_rules_file

        rfp = ocr_result.rules_file_path
        if not rfp:
            log.warning(
                "[%d] No rules_file_path in OCR result, cannot resolve rules",
                document_id,
            )
            return None

        try:
            rules_file = load_rules_file(rfp)
            log.info(
                "[%d] Resolved rules from classification: %s (%d rules)",
                document_id, rfp, len(rules_file.flat_rules),
            )
            return rules_file
        except Exception as e:
            # The error is swallowed here, so keep the traceback in the log;
            # otherwise the root cause of the later ValueError is lost.
            log.exception("[%d] Failed to load rules from %s: %s", document_id, rfp, e)
            return None
|
||||
Reference in New Issue
Block a user