307 lines
10 KiB
Python
307 lines
10 KiB
Python
"""Main leaudit pipeline orchestrator: OCR → Extract → Evaluate.

Uses leaudit's own pipeline directly (no conversion) and stores results
into docauditai's database via StorageAdapter.
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
|
|
from leaudit.dsl.schema import RulesFile
|
|
from leaudit.engine.case_file_evaluator import evaluate_extraction
|
|
from leaudit.engine.models import EvaluationResult
|
|
from leaudit.extraction.bundle import ExtractionBundle
|
|
from leaudit.extraction.dispatcher import dispatch_extract
|
|
from leaudit.extraction.phase_detection import determine_phase
|
|
from leaudit.llm.base import BaseLLMClient
|
|
from leaudit.ocr.base import BaseOCRClient
|
|
from leaudit.ocr.models import OcrResult
|
|
|
|
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
def _ensure_text_page_chunks(ocr_result: OcrResult) -> None:
    """Backfill pseudo chunks for text-native pages that have no OCR chunks.

    DOCX/legacy-doc normalization currently produces page text but often no
    geometric chunks, which causes ``resolve_bundle_positions`` to return zero
    positions for every extracted field. We synthesize coarse text chunks so
    at least page-level positioning can be recovered on the review page.

    For every page whose ``chunks`` is empty, the page text is split into
    paragraphs on blank lines. Each non-empty paragraph becomes one chunk
    dict with a zeroed ``bbox`` and the label ``"text"``, and ``page.bboxes``
    is rebuilt from those chunks. Pages that already have chunks, and pages
    with no usable text, are left untouched.

    Lines starting with ``<!-- PAGE `` are removed individually, so real text
    that shares a paragraph with such a marker line is kept instead of being
    discarded together with it.

    Args:
        ocr_result: OCR result whose ``pages`` are updated in place.
    """
    page_marker = "<!-- PAGE "

    for page in ocr_result.pages:
        if page.chunks:
            continue

        # Normalize Windows / old-Mac line endings before splitting.
        text = (page.text or "").replace("\r\n", "\n").replace("\r", "\n")

        blocks: list[str] = []
        for paragraph in text.split("\n\n"):
            # Drop only the marker lines, not the whole paragraph.
            kept_lines = [
                line
                for line in paragraph.split("\n")
                if not line.strip().startswith(page_marker)
            ]
            block = "\n".join(kept_lines).strip()
            if block:
                blocks.append(block)

        if not blocks:
            continue

        page.chunks = [
            {
                "content": block,
                "bbox": [0, 0, 0, 0],
                "label": "text",
            }
            for block in blocks
        ]
        page.bboxes = [chunk["bbox"] for chunk in page.chunks]
|
|
|
|
|
|
@dataclass
class PipelineResult:
    """Everything produced by one run of the leaudit pipeline.

    Bundles the output of each stage with per-stage timings and the error
    messages collected along the way.
    """

    # Output of the OCR stage.
    ocr_result: OcrResult
    # Fields extracted from the OCR output.
    extraction_bundle: ExtractionBundle
    # Outcome of evaluating the rules against the extraction.
    evaluation_result: EvaluationResult
    # Phase label detected for the document.
    detected_phase: str
    # Stage name -> duration in seconds; a fresh dict per instance.
    timing: dict[str, float] = field(default_factory=dict)
    # Error messages gathered during the run; a fresh list per instance.
    errors: list[str] = field(default_factory=list)
|
|
|
|
|
|
class LauditPipeline:
    """Run leaudit's full OCR → extraction → evaluation pipeline.

    Does NOT use leaudit's own SQLAlchemy storage.
    All results are written to docauditai's database via StorageAdapter.
    """

    def __init__(
        self,
        ocr_client: BaseOCRClient,
        llm_client: BaseLLMClient | None = None,
        storage_adapter: StorageAdapter | None = None,
    ) -> None:
        """Store the collaborators used by the pipeline.

        Args:
            ocr_client: Client whose ``ocr(file_path)`` coroutine produces the
                OCR result.
            llm_client: Optional LLM client forwarded to extraction, phase
                detection, case-number extraction and evaluation.
            storage_adapter: Adapter used for all database writes. A default
                ``StorageAdapter()`` is created only when this is ``None``.
        """
        self.ocr_client = ocr_client
        self.llm_client = llm_client
        # Compare against None explicitly so a caller-supplied adapter is
        # never replaced just because it happens to evaluate as falsy.
        self.storage = StorageAdapter() if storage_adapter is None else storage_adapter

    async def run(
        self,
        document_id: int,
        file_path: str | Path,
        rules_file: RulesFile | None = None,
        *,
        source_port: int = 8000,
    ) -> PipelineResult:
        """Execute the full pipeline for one document.

        Stages, in order: OCR (plus chunk backfill), rules resolution,
        OCR persistence, case-number extraction, field extraction, coordinate
        resolution, phase detection, evaluation, and finalization. The
        document status is updated through ``self.storage`` before each major
        stage and set to ``"Processed"`` at the end.

        Args:
            document_id: docauditai document ID for DB writes.
            file_path: Path to the document file (PDF/DOCX/etc).
            rules_file: leaudit RulesFile (parsed from YAML). When None,
                the pipeline attempts to load rules from the OCR result's
                ``rules_file_path``.
            source_port: Instance port. Accepted for interface compatibility;
                this method does not currently read it.

        Returns:
            PipelineResult with all intermediate and final outputs.
            ``timing["total"]`` is the sum of the timed stages (OCR,
            extraction, evaluation), not the wall-clock time of the call.

        Raises:
            ValueError: If no rules file was passed and none could be loaded
                from the OCR result.
            Exception: Any error raised by the OCR client or later stages
                propagates to the caller.
        """
        file_path = Path(file_path)
        errors: list[str] = []
        timing: dict[str, float] = {}

        # --- Phase 1: Update status to Cutting ---
        await self.storage.update_document_status(document_id, "Cutting")

        # --- Phase 2: OCR ---
        # perf_counter is monotonic, so durations are not distorted by
        # system clock adjustments.
        started = time.perf_counter()
        log.info("[%d] OCR starting: %s", document_id, file_path.name)
        ocr_result = await self._run_ocr(file_path)
        _ensure_text_page_chunks(ocr_result)
        timing["ocr"] = round(time.perf_counter() - started, 2)
        log.info(
            "[%d] OCR done: %d pages, %.1fs",
            document_id,
            len(ocr_result.pages),
            timing["ocr"],
        )

        # --- Resolve rules_file after OCR if not provided ---
        if rules_file is None:
            rules_file = await self._resolve_rules_from_ocr(ocr_result, document_id)
            if rules_file is None:
                raise ValueError(
                    f"Cannot resolve rules_file for document {document_id}. "
                    "Neither passed explicitly nor classified from OCR content."
                )

        # --- Save OCR result ---
        await self.storage.save_ocr_result(document_id, ocr_result)

        # --- Extract & save case number ---
        await self._extract_and_save_case_number(document_id, ocr_result)

        # --- Phase 3: Extraction ---
        started = time.perf_counter()
        await self.storage.update_document_status(document_id, "Extractioning")
        log.info("[%d] Extraction starting", document_id)

        extraction_bundle = await dispatch_extract(
            ocr_result,
            rules_file,
            llm_client=self.llm_client,
            phase="executed",
        )
        timing["extraction"] = round(time.perf_counter() - started, 2)

        if extraction_bundle.all_errors:
            # Extraction errors are collected, not fatal.
            errors.extend(extraction_bundle.all_errors)
            log.warning(
                "[%d] Extraction completed with %d errors",
                document_id,
                len(extraction_bundle.all_errors),
            )

        log.info(
            "[%d] Extraction done: %d fields, %.1fs",
            document_id,
            len(extraction_bundle.fields),
            timing["extraction"],
        )

        # --- Resolve field positions from OCR chunks ---
        from leaudit.extraction.coordinate_resolver import resolve_bundle_positions

        resolve_bundle_positions(extraction_bundle, ocr_result)
        positioned_count = sum(
            1 for fv in extraction_bundle.fields.values() if fv.position is not None
        )
        log.info(
            "[%d] Coordinate resolution: %d/%d fields positioned",
            document_id,
            positioned_count,
            len(extraction_bundle.fields),
        )

        # --- Save extraction result ---
        await self.storage.save_extraction_result(document_id, extraction_bundle)

        # --- Phase 4: Phase detection ---
        # Prefer the manifest attached to the extraction; fall back to OCR's.
        visual_manifest = extraction_bundle.visual_manifest or ocr_result.visual_manifest
        detected_phase = await determine_phase(
            extraction_bundle.fields,
            llm_client=self.llm_client,
            visual_manifest=visual_manifest,
        )
        log.info("[%d] Detected phase: %s", document_id, detected_phase)

        # --- Phase 5: Evaluation ---
        started = time.perf_counter()
        await self.storage.update_document_status(document_id, "Evaluationing")
        log.info("[%d] Evaluation starting (phase=%s)", document_id, detected_phase)

        external_mocks: dict[str, Any] = {}
        if self.llm_client is not None:
            external_mocks["llm_client"] = self.llm_client
            # NOTE(review): rules_file is forwarded only when an LLM client
            # is configured - confirm this grouping is intended.
            external_mocks["rules_file"] = rules_file

        evaluation_result = await evaluate_extraction(
            rules_file,
            extraction_bundle,
            visual_manifest=visual_manifest,
            phase=detected_phase,
            external_mocks=external_mocks,
        )
        timing["evaluation"] = round(time.perf_counter() - started, 2)
        log.info(
            "[%d] Evaluation done: %d passed, %d failed, %d skipped, %.1fs",
            document_id,
            evaluation_result.passed_count,
            evaluation_result.failed_count,
            evaluation_result.skipped_count,
            timing["evaluation"],
        )

        # --- Save evaluation results ---
        await self.storage.save_evaluation_results(
            document_id, rules_file, evaluation_result, extraction_bundle,
        )

        # --- Phase 6: Finalize ---
        # Sum of the timed stages only; untimed steps (saves, phase
        # detection, case-number extraction) are not included.
        timing["total"] = round(sum(timing.values()), 2)
        await self.storage.update_document_status(document_id, "Processed")

        log.info(
            "[%d] Pipeline complete: phase=%s, timing=%s",
            document_id,
            detected_phase,
            timing,
        )

        return PipelineResult(
            ocr_result=ocr_result,
            extraction_bundle=extraction_bundle,
            evaluation_result=evaluation_result,
            detected_phase=detected_phase,
            timing=timing,
            errors=errors,
        )

    async def _run_ocr(self, file_path: Path) -> OcrResult:
        """Run OCR on ``file_path``, logging any failure before re-raising.

        Args:
            file_path: Document to pass to ``self.ocr_client.ocr``.

        Returns:
            Whatever the OCR client returns for the file.

        Raises:
            Exception: Any exception from the OCR client, re-raised unchanged
                after being logged with its traceback.
        """
        try:
            return await self.ocr_client.ocr(file_path)
        except Exception as e:
            # log.exception keeps the traceback that log.error would drop.
            log.exception("OCR failed for %s: %s", file_path.name, e)
            raise

    async def _extract_and_save_case_number(
        self, document_id: int, ocr_result: OcrResult,
    ) -> None:
        """Extract the case number from OCR output and store it if found.

        Args:
            document_id: Document whose number is updated via ``self.storage``.
            ocr_result: OCR output passed to the case-number extractor.
        """
        # Imported at call time, as in the original module layout.
        from fastapi_modules.fastapi_leaudit.leaudit_bridge.case_number_extractor import (
            extract_case_number_with_llm,
        )

        case_number = await extract_case_number_with_llm(
            ocr_result, llm_client=self.llm_client,
        )
        if case_number:
            await self.storage.update_document_number(document_id, case_number)
            log.info("[%d] Case number: %s", document_id, case_number)
        else:
            log.info("[%d] No case number found", document_id)

    async def _resolve_rules_from_ocr(
        self, ocr_result: OcrResult, document_id: int,
    ) -> RulesFile | None:
        """Load the rules file named by ``ocr_result.rules_file_path``.

        Args:
            ocr_result: OCR output; its ``rules_file_path`` is read.
            document_id: Used only as a prefix in log messages.

        Returns:
            The loaded RulesFile, or ``None`` when the OCR result carries no
            path or when loading fails (the failure is logged).
        """
        from leaudit.dsl.loader import load_rules_file

        rfp = ocr_result.rules_file_path
        if not rfp:
            log.warning(
                "[%d] No rules_file_path in OCR result, cannot resolve rules",
                document_id,
            )
            return None

        try:
            rules_file = load_rules_file(rfp)
            log.info(
                "[%d] Resolved rules from classification: %s (%d rules)",
                document_id, rfp, len(rules_file.flat_rules),
            )
            return rules_file
        except Exception as e:
            # Best effort: the caller turns None into a ValueError. Log the
            # traceback so the real cause of the load failure is not lost.
            log.exception("[%d] Failed to load rules from %s: %s", document_id, rfp, e)
            return None
|