chore: initial commit — leaudit-platform project skeleton

17-table PostgreSQL schema with full Chinese column comments,
FastAPI project structure (admin/common/modules),
DSL rule files, and schema migration scripts.
This commit is contained in:
wren
2026-04-27 16:48:22 +08:00
commit 535d97a70c
142 changed files with 25219 additions and 0 deletions
@@ -0,0 +1,82 @@
"""leaudit bridge — use leaudit's full pipeline with docauditai's database storage.
Directly calls leaudit's OCR → extraction → evaluation pipeline
and persists results into docauditai's PostgreSQL via PostgREST.
Configuration switch (in env.{port}):
PIPELINE_MODE=leaudit → use leaudit pipeline
"""
from leaudit_bridge.client_factory import (
create_ocr_client,
create_llm_client,
create_vlm_client,
)
from leaudit_bridge.ocr_bridge import BridgeOCRClient
from leaudit_bridge.pipeline import LauditPipeline, PipelineResult
from leaudit_bridge.rules_loader import RulesLoader
from leaudit_bridge.storage_adapter import StorageAdapter
def is_leaudit_mode() -> bool:
"""Check if the system is configured to use the leaudit pipeline."""
from core.config import PIPELINE_MODE
return PIPELINE_MODE == "leaudit"
def create_pipeline(rules_path: str | None = None) -> LauditPipeline:
"""Create a fully configured LauditPipeline from current config.
Wraps the raw OCR client with DocNormalizationAdapter so that a single
``.ocr()`` call produces a fully enriched OcrResult with:
- Document classification (type_id + rules_file_path)
- Dossier segmentation (sub-document page mapping)
- Seal/signature enrichment (text, seal_id, party_id)
- Normalized markdown (seal blocks + page separators)
Args:
rules_path: If provided, forces the adapter to use this rules file
for classification and segmentation. When None, the adapter
uses the RulesFileRegistry to classify from document content,
enabling auto-detection of sub-types (e.g. 行政许可 variants).
"""
from pathlib import Path
from leaudit.doc_normalization.adapter import DocNormalizationAdapter
from leaudit.doc_normalization.doc_classifier import RulesFileRegistry
raw_ocr = create_ocr_client()
llm_client = create_llm_client()
vlm_client = create_vlm_client()
# Build registry from rules/ directory for content-based classification
registry = None
if rules_path is None:
rules_dir = Path(__file__).resolve().parents[1] / "rules"
if rules_dir.is_dir():
registry = RulesFileRegistry.from_directory(rules_dir)
ocr_client = DocNormalizationAdapter(
ocr_client=raw_ocr,
registry=registry,
llm_client=llm_client,
vlm_client=vlm_client,
force_rules_path=rules_path,
)
ocr_client = BridgeOCRClient(ocr_client, vlm_client=vlm_client)
return LauditPipeline(
ocr_client=ocr_client,
llm_client=llm_client,
)
__all__ = [
"LauditPipeline",
"PipelineResult",
"StorageAdapter",
"RulesLoader",
"create_ocr_client",
"create_llm_client",
"create_pipeline",
"is_leaudit_mode",
]
@@ -0,0 +1,128 @@
"""Extract case number (案件编号) from leaudit OcrResult.
Port of docauditai's ``extract_case_number_for_regex`` adapted for leaudit's
OcrResult model. Uses regex patterns first; falls back to LLM when available.
"""
from __future__ import annotations
import logging
import re
from typing import Any
from leaudit.ocr.models import OcrResult
log = logging.getLogger(__name__)
# Regex patterns for case number extraction
_PATTERNS: list[tuple[str, str]] = [
# Direct match: "案件编号:梅烟专罚〔2024〕第XX号"
(r"案件编号[:]\s*(.*?)(?:\n|$)", "direct"),
# From 卷宗/卷 宗 header: extract content between 卷宗 and 案由
(r"\s*宗([\s\S]*?)案\s*由", "file_content"),
# Standalone pattern: e.g. 梅烟专罚〔2024〕第001号
(r"[\u4e00-\u9fa5]{2,8}[专罚处决][〔(\(\[]\d{4}[〕)\)\]][\u4e00-\u9fa5]*\d+号", "standalone"),
# Dossier number: e.g. "2024 年度 郁烟 第 71 号"
(r"\d{4}\s*年度\s*[\u4e00-\u9fa5]{1,6}\s*第\s*\d+\s*号", "dossier"),
]
# Chinese number format within 卷宗 extracted text
_YEAR_NUMBER_RE = re.compile(r"\d{4}[\u4e00-\u9fa5]+\d+号")
def extract_case_number(ocr_result: OcrResult) -> str | None:
"""Extract case number from OCR result using regex patterns.
Searches across all pages but prioritizes early pages where case numbers
typically appear (封面, 卷宗封面).
Args:
ocr_result: OCR result with pages containing text.
Returns:
Extracted case number string, or None if not found.
"""
# Build text from first few pages (case numbers appear early)
pages_to_check = ocr_result.pages[:5] if len(ocr_result.pages) > 5 else ocr_result.pages
text = "\n".join(p.text for p in pages_to_check)
if not text.strip():
return None
for pattern, ptype in _PATTERNS:
match = re.search(pattern, text)
if not match:
continue
if ptype == "direct":
return match.group(1).strip()
if ptype == "file_content":
content = match.group(1).strip()
num_match = _YEAR_NUMBER_RE.search(content)
if num_match:
return num_match.group()
if ptype == "standalone":
return match.group()
if ptype == "dossier":
return match.group()
return None
async def extract_case_number_with_llm(
ocr_result: OcrResult,
llm_client: Any = None,
) -> str | None:
"""Extract case number using regex first, then LLM fallback.
Args:
ocr_result: OCR result with pages containing text.
llm_client: Optional LLM client for fallback extraction.
Returns:
Extracted case number string, or None if not found.
"""
# Try regex first (fast, no API call)
result = extract_case_number(ocr_result)
if result:
return result
# LLM fallback
if llm_client is None:
return None
text = "\n".join(p.text for p in ocr_result.pages[:5])
if not text.strip():
return None
try:
from leaudit.llm.base import BaseLLMClient, LLMRequest, LLMMessage
if not isinstance(llm_client, BaseLLMClient):
return None
prompt = (
"请从以下法律文书文本中提取案件编号。"
"案件编号通常格式如:梅烟专罚〔2024〕第001号。\n"
"只返回JSON: {\"case_number\": \"案件编号\"}\n\n"
f"文本:\n{text[:2000]}"
)
request = LLMRequest(
messages=[LLMMessage(role="user", content=prompt)],
response_format={"type": "json_object"},
max_tokens=512,
)
response = await llm_client.complete(request)
import json
parsed = json.loads(response.content)
case_number = parsed.get("case_number")
if case_number and case_number != "未找到":
if re.search(r"[\u4e00-\u9fa5]|[\d()]", case_number):
return case_number
except Exception as e:
log.warning("LLM case number extraction failed: %s", e)
return None
@@ -0,0 +1,80 @@
"""Create leaudit OCR/LLM/VLM clients from docauditai's env.{port} config."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from core.config import (
OCR_CONFIG,
DEFAULT_BASE_URL,
DEFAULT_LLM_MODEL,
DEFAULT_API_KEY,
DEFAULT_VLM_BASE_URL,
DEFAULT_VLM_MODEL,
)
if TYPE_CHECKING:
from leaudit.llm.base import BaseLLMClient
from leaudit.llm.vlm_base import BaseVLMClient
from leaudit.ocr.base import BaseOCRClient
log = logging.getLogger(__name__)
def create_ocr_client() -> BaseOCRClient:
"""Create a leaudit ChandraOCRClient from LEAUDIT_OCR_URL config."""
import os
from leaudit.ocr.chandra_client import ChandraOCRClient
base_url = os.getenv("LEAUDIT_OCR_URL", "").rstrip("/")
if not base_url:
base_url = OCR_CONFIG["API_URL"].rsplit("/api/v1/ocr", 1)[0]
timeout = float(OCR_CONFIG["TIMEOUT"])
client = ChandraOCRClient(
base_url=base_url,
timeout=timeout,
include_images=True,
)
log.info("leaudit OCR client created: %s (timeout=%ss)", base_url, timeout)
return client
def create_llm_client() -> BaseLLMClient:
"""Create a leaudit OpenAICompatibleClient from docauditai's LLM config."""
from leaudit.llm.openai_client import OpenAICompatibleClient
base_url = DEFAULT_BASE_URL
model = DEFAULT_LLM_MODEL
api_key = DEFAULT_API_KEY or "no-key"
client = OpenAICompatibleClient(
api_key=api_key,
base_url=base_url,
default_model=model,
timeout=120.0,
)
log.info("leaudit LLM client created: %s (model=%s)", base_url, model)
return client
def create_vlm_client() -> BaseVLMClient | None:
"""Create a leaudit QwenVLMClient from docauditai's VLM config."""
from leaudit.llm.qwen_vlm_client import QwenVLMClient
base_url = DEFAULT_VLM_BASE_URL
model = DEFAULT_VLM_MODEL
api_key = DEFAULT_API_KEY or "no-key"
if not base_url or not model:
log.info("leaudit VLM client skipped: no VLM config")
return None
client = QwenVLMClient(
base_url=base_url,
api_key=api_key,
model=model,
)
log.info("leaudit VLM client created: %s (model=%s)", base_url, model)
return client
@@ -0,0 +1,132 @@
"""Build leaudit execution context from docauditai document data.
Currently leaudit's pipeline in docauditai bypasses leaudit's own
``AuditCtx`` / ``AuditService`` and calls engine modules directly.
This module encapsulates the pre-execution setup that currently lives
inlined in ``pipeline.py`` and ``tasks.py``:
- Resolve local file path (download from OSS to temp if needed)
- Determine RulesFile (from document metadata, type binding, or
content classification)
- Prepare OCR/LLM/VLM client references
"""
from __future__ import annotations
import logging
import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
from leaudit.dsl.schema import RulesFile
from leaudit.llm.base import BaseLLMClient
from leaudit.ocr.base import BaseOCRClient
if TYPE_CHECKING:
from leaudit.llm.vlm_base import BaseVLMClient
log = logging.getLogger(__name__)
@dataclass
class ExecutionContext:
"""Everything leaudit needs to run for one document."""
document_id: int
file_path: Path
rules_file: RulesFile
ocr_client: BaseOCRClient
llm_client: BaseLLMClient | None = None
vlm_client: object | None = None
source_port: int = 8000
tmp_path: Path | None = None
metadata: dict = field(default_factory=dict)
def cleanup(self) -> None:
"""Remove temporary file if one was created."""
if self.tmp_path is not None:
try:
os.remove(self.tmp_path)
except OSError:
pass
class CtxBuilder:
"""Build :class:`ExecutionContext` from docauditai document data.
Handles the glue between docauditai's document model and leaudit's
execution expectations — primarily file-path resolution and rules
selection.
"""
def __init__(
self,
ocr_client: BaseOCRClient | None = None,
llm_client: BaseLLMClient | None = None,
vlm_client: object | None = None,
) -> None:
self.ocr_client = ocr_client
self.llm_client = llm_client
self.vlm_client = vlm_client
async def build(
self,
document_id: int,
file_path: str | Path | None = None,
file_content: bytes | None = None,
filename: str | None = None,
rules_file: RulesFile | None = None,
*,
source_port: int = 8000,
) -> ExecutionContext:
"""Build a ready-to-use execution context.
At least one of *file_path* or (*file_content* + *filename*)
must be provided.
Args:
document_id: docauditai document ID.
file_path: Existing local path to the document file.
file_content: Raw bytes (from DB or OSS) — a temp file is
created.
filename: Required when *file_content* is given.
rules_file: Pre-loaded RulesFile. When None, the caller
must resolve after OCR classification.
source_port: Instance port.
Returns:
ExecutionContext ready for pipeline.run().
"""
tmp_path: Path | None = None
if file_path is not None:
resolved = Path(file_path)
elif file_content is not None and filename is not None:
suffix = self._suffix(filename)
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
tmp.write(file_content)
tmp.close()
resolved = Path(tmp.name)
tmp_path = resolved
else:
raise ValueError(
"Either file_path or (file_content + filename) is required"
)
return ExecutionContext(
document_id=document_id,
file_path=resolved,
rules_file=rules_file, # type: ignore[arg-type]
ocr_client=self.ocr_client, # type: ignore[arg-type]
llm_client=self.llm_client,
vlm_client=self.vlm_client,
source_port=source_port,
tmp_path=tmp_path,
)
@staticmethod
def _suffix(filename: str) -> str:
_, ext = os.path.splitext(filename)
return ext if ext else ".pdf"
@@ -0,0 +1,272 @@
"""Bridge-side OCR post-processing for leaudit integration.
Keeps docauditai-specific fixes outside ``services/leaudit/**``:
- DOCX embedded-image visuals can be refined once more with the VLM after
the merged ``OcrResult`` is built.
- Cross-page seals with missing completeness flags are normalized so the
legacy compatibility checks have a stable shape to consume.
"""
from __future__ import annotations
import logging
from io import BytesIO
from pathlib import Path
from leaudit.ocr.base import BaseOCRClient
from leaudit.ocr.models import OcrResult, VisualManifestItem
log = logging.getLogger(__name__)
class BridgeOCRClient(BaseOCRClient):
"""Wrap an OCR client and apply integration-side post-processing."""
def __init__(
self,
inner: BaseOCRClient,
*,
vlm_client: object | None = None,
vlm_concurrency: int = 6,
) -> None:
self.inner = inner
self.vlm_client = vlm_client
self.vlm_concurrency = vlm_concurrency
async def ocr(self, file_path: Path | str) -> OcrResult:
path = Path(file_path)
result = await self.inner.ocr(path)
await postprocess_ocr_result(
result,
file_path=path,
vlm_client=self.vlm_client,
vlm_concurrency=self.vlm_concurrency,
)
return result
async def postprocess_ocr_result(
ocr_result: OcrResult,
*,
file_path: Path,
vlm_client: object | None = None,
vlm_concurrency: int = 6,
) -> OcrResult:
"""Apply bridge-side visual repairs without touching leaudit core."""
suffix = file_path.suffix.lower()
if suffix not in {".docx", ".doc", ".wps"}:
return ocr_result
await _maybe_refine_docx_visuals(
ocr_result,
vlm_client=vlm_client,
concurrency=vlm_concurrency,
)
await _inject_docx_signature_candidates(
ocr_result,
vlm_client=vlm_client,
)
_normalize_cross_page_seals(ocr_result)
return ocr_result
async def _maybe_refine_docx_visuals(
ocr_result: OcrResult,
*,
vlm_client: object | None,
concurrency: int,
) -> None:
vm = ocr_result.visual_manifest
if vlm_client is None or vm is None:
return
if not (vm.seals or vm.signatures or vm.cross_page_seals):
return
try:
from leaudit.ocr.visual_classifier import refine_visual_manifest
await refine_visual_manifest(
ocr_result,
vlm_client,
concurrency=concurrency,
)
except Exception as exc:
log.warning("bridge visual refinement skipped: %s", exc)
async def _inject_docx_signature_candidates(
ocr_result: OcrResult,
*,
vlm_client: object | None,
) -> None:
"""Probe likely handwritten-signature zones on DOCX parent images."""
if vlm_client is None:
return
try:
from PIL import Image
except ImportError:
log.warning("Pillow unavailable, skip DOCX signature candidate probing")
return
parent_to_items: dict[str, list[VisualManifestItem]] = {}
for bucket in (
ocr_result.visual_manifest.seals or [],
ocr_result.visual_manifest.signatures or [],
ocr_result.visual_manifest.cross_page_seals or [],
):
for item in bucket:
parent_key = getattr(item, "parent_image_key", None)
if parent_key:
parent_to_items.setdefault(parent_key, []).append(item)
for parent_key, items in parent_to_items.items():
if any((it.label or "") == "signature" for it in items):
continue
parent_bytes = ocr_result.get_image_bytes(parent_key)
if not parent_bytes:
continue
try:
image = Image.open(BytesIO(parent_bytes))
except Exception as exc:
log.warning("failed to open parent image %s: %s", parent_key, exc)
continue
width, height = image.size
for candidate_bbox in _signature_candidate_boxes(items, width, height):
try:
crop = image.crop(tuple(candidate_bbox))
buf = BytesIO()
crop.save(buf, format="PNG")
result = await _classify_signature_candidate(
vlm_client,
buf.getvalue(),
"这是合同签章页里疑似法人签名的候选区域,请优先判断是否为手写签名。",
)
except Exception as exc:
log.warning("signature probe failed for %s: %s", parent_key, exc)
continue
if getattr(result, "kind", None) != "signature":
continue
page_num = _infer_parent_page_num(items)
ocr_result.visual_manifest.signatures.append(
VisualManifestItem(
page_num=page_num,
bbox=candidate_bbox,
label="signature",
confidence=getattr(result, "confidence", 0.9) or 0.9,
text_match=(getattr(result, "text", None) or "").strip() or None,
alt_text="docx_signature_candidate",
image_key=parent_key,
parent_image_key=parent_key,
)
)
break
async def _classify_signature_candidate(
vlm_client: object,
image_bytes: bytes,
user_hint: str,
) -> object:
"""Classify with one retry using a fresh VLM client when needed."""
try:
return await vlm_client.classify_visual(image_bytes, user_hint=user_hint)
except Exception as exc:
log.warning("signature probe primary VLM failed, retrying fresh client: %s", exc)
try:
from leaudit.llm.qwen_vlm_client import QwenVLMClient
fresh = QwenVLMClient(
base_url=getattr(vlm_client, "base_url"),
api_key=getattr(vlm_client, "api_key", ""),
model=getattr(vlm_client, "model"),
timeout=getattr(vlm_client, "timeout", 90.0),
)
try:
return await fresh.classify_visual(image_bytes, user_hint=user_hint)
finally:
await fresh.close()
except Exception as exc:
raise RuntimeError(exc) from exc
def _signature_candidate_boxes(
items: list[VisualManifestItem],
width: int,
height: int,
) -> list[list[int]]:
candidates: list[list[int]] = []
seen: set[tuple[int, int, int, int]] = set()
for item in items:
seal_type = getattr(item, "seal_type", None)
label = getattr(item, "label", None)
bbox = getattr(item, "bbox", None) or []
if len(bbox) != 4:
continue
x1, y1, x2, y2 = bbox
box_w = max(1, x2 - x1)
box_h = max(1, y2 - y1)
ratio = box_w / box_h
if seal_type == "法人章" or label == "法人章":
continue
if not (0.75 <= ratio <= 1.35):
continue
if box_w < width * 0.10 or box_h < height * 0.10:
continue
cand = [
max(0, int(x1 - box_w * 0.25)),
max(0, int(y1 + box_h * 0.50)),
min(width, int(x2 + box_w * 0.25)),
min(height, int(y2 + box_h * 0.95)),
]
if cand[2] - cand[0] < 24 or cand[3] - cand[1] < 24:
continue
key = tuple(cand)
if key not in seen:
seen.add(key)
candidates.append(cand)
return candidates
def _infer_parent_page_num(items: list[VisualManifestItem]) -> int:
for item in items:
page_num = getattr(item, "page_num", None)
if isinstance(page_num, int):
return page_num
return 0
def _normalize_cross_page_seals(ocr_result: OcrResult) -> None:
"""Fill obvious completeness defaults for bridge-side checks."""
for item in ocr_result.visual_manifest.cross_page_seals or []:
if item.pages and len(item.pages) >= 2:
item.is_complete = True
continue
bbox = item.bbox or []
if len(bbox) == 4:
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
ratio = width / height
# DOCX embedded images often contain a complete round seal near the
# page edge; Chandra may still classify it as a seam-seal half by
# geometry. A near-square crop is a strong signal that the visible
# stamp is already complete.
if 0.65 <= ratio <= 1.35:
item.is_complete = True
continue
if item.is_complete is not None:
continue
if item.pages and len(item.pages) == 1:
item.is_complete = False
@@ -0,0 +1,268 @@
"""Main leaudit pipeline orchestrator: OCR → Extract → Evaluate.
Uses leaudit's own pipeline directly (no conversion),
stores results into docauditai's database via StorageAdapter.
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from leaudit.dsl.schema import RulesFile
from leaudit.engine.case_file_evaluator import evaluate_extraction
from leaudit.engine.models import EvaluationResult
from leaudit.extraction.bundle import ExtractionBundle
from leaudit.extraction.dispatcher import dispatch_extract
from leaudit.extraction.phase_detection import determine_phase
from leaudit.llm.base import BaseLLMClient
from leaudit.ocr.base import BaseOCRClient
from leaudit.ocr.models import OcrResult
from leaudit_bridge.storage_adapter import StorageAdapter
log = logging.getLogger(__name__)
@dataclass
class PipelineResult:
"""Complete result from the leaudit pipeline."""
ocr_result: OcrResult
extraction_bundle: ExtractionBundle
evaluation_result: EvaluationResult
detected_phase: str
timing: dict[str, float] = field(default_factory=dict)
errors: list[str] = field(default_factory=list)
class LauditPipeline:
"""Run leaudit's full OCR → extraction → evaluation pipeline.
Does NOT use leaudit's own SQLAlchemy storage.
All results are written to docauditai's database via StorageAdapter.
"""
def __init__(
self,
ocr_client: BaseOCRClient,
llm_client: BaseLLMClient | None = None,
storage_adapter: StorageAdapter | None = None,
) -> None:
self.ocr_client = ocr_client
self.llm_client = llm_client
self.storage = storage_adapter or StorageAdapter()
async def run(
self,
document_id: int,
file_path: str | Path,
rules_file: RulesFile | None = None,
*,
source_port: int = 8000,
) -> PipelineResult:
"""Execute the full pipeline for one document.
Args:
document_id: docauditai document ID for DB writes.
file_path: Path to the document file (PDF/DOCX/etc).
rules_file: leaudit RulesFile (parsed from YAML). When None,
the pipeline attempts to load rules from the OCR result's
``rules_file_path`` (set by the classifier).
source_port: Instance port for context switching.
Returns:
PipelineResult with all intermediate and final outputs.
"""
file_path = Path(file_path)
errors: list[str] = []
timing: dict[str, float] = {}
# --- Phase 1: Update status to Cutting ---
await self.storage.update_document_status(document_id, "Cutting")
# --- Phase 2: OCR ---
t0 = time.time()
log.info("[%d] OCR starting: %s", document_id, file_path.name)
ocr_result = await self._run_ocr(file_path)
timing["ocr"] = round(time.time() - t0, 2)
log.info(
"[%d] OCR done: %d pages, %.1fs",
document_id,
len(ocr_result.pages),
timing["ocr"],
)
# --- Resolve rules_file after OCR if not provided ---
if rules_file is None:
rules_file = await self._resolve_rules_from_ocr(ocr_result, document_id)
if rules_file is None:
raise ValueError(
f"Cannot resolve rules_file for document {document_id}. "
"Neither passed explicitly nor classified from OCR content."
)
# --- Save OCR result ---
await self.storage.save_ocr_result(document_id, ocr_result)
# --- Extract & save case number (案件编号) ---
await self._extract_and_save_case_number(document_id, ocr_result)
# --- Phase 3: Extraction ---
t0 = time.time()
await self.storage.update_document_status(document_id, "Extractioning")
log.info("[%d] Extraction starting", document_id)
extraction_bundle = await dispatch_extract(
ocr_result,
rules_file,
llm_client=self.llm_client,
phase="executed",
)
timing["extraction"] = round(time.time() - t0, 2)
if extraction_bundle.all_errors:
errors.extend(extraction_bundle.all_errors)
log.warning(
"[%d] Extraction completed with %d errors",
document_id,
len(extraction_bundle.all_errors),
)
log.info(
"[%d] Extraction done: %d fields, %.1fs",
document_id,
len(extraction_bundle.fields),
timing["extraction"],
)
# --- Save extraction result ---
await self.storage.save_extraction_result(document_id, extraction_bundle)
# --- Resolve field positions from OCR chunks ---
from leaudit.extraction.coordinate_resolver import resolve_bundle_positions
resolve_bundle_positions(extraction_bundle, ocr_result)
positioned_count = sum(
1 for fv in extraction_bundle.fields.values() if fv.position is not None
)
log.info(
"[%d] Coordinate resolution: %d/%d fields positioned",
document_id,
positioned_count,
len(extraction_bundle.fields),
)
# --- Phase 4: Phase detection ---
visual_manifest = extraction_bundle.visual_manifest or ocr_result.visual_manifest
detected_phase = await determine_phase(
extraction_bundle.fields,
llm_client=self.llm_client,
visual_manifest=visual_manifest,
)
log.info("[%d] Detected phase: %s", document_id, detected_phase)
# --- Phase 5: Evaluation ---
t0 = time.time()
await self.storage.update_document_status(document_id, "Evaluationing")
log.info("[%d] Evaluation starting (phase=%s)", document_id, detected_phase)
external_mocks: dict[str, Any] = {}
if self.llm_client is not None:
external_mocks["llm_client"] = self.llm_client
external_mocks["rules_file"] = rules_file
evaluation_result = await evaluate_extraction(
rules_file,
extraction_bundle,
visual_manifest=visual_manifest,
phase=detected_phase,
external_mocks=external_mocks,
)
timing["evaluation"] = round(time.time() - t0, 2)
log.info(
"[%d] Evaluation done: %d passed, %d failed, %d skipped, %.1fs",
document_id,
evaluation_result.passed_count,
evaluation_result.failed_count,
evaluation_result.skipped_count,
timing["evaluation"],
)
# --- Save evaluation results ---
await self.storage.save_evaluation_results(
document_id, rules_file, evaluation_result, extraction_bundle,
)
# --- Phase 6: Finalize ---
timing["total"] = round(sum(timing.values()), 2)
await self.storage.update_document_status(document_id, "Processed")
log.info(
"[%d] Pipeline complete: phase=%s, timing=%s",
document_id,
detected_phase,
timing,
)
return PipelineResult(
ocr_result=ocr_result,
extraction_bundle=extraction_bundle,
evaluation_result=evaluation_result,
detected_phase=detected_phase,
timing=timing,
errors=errors,
)
async def _run_ocr(self, file_path: Path) -> OcrResult:
"""Run OCR with error handling."""
try:
return await self.ocr_client.ocr(file_path)
except Exception as e:
log.error("OCR failed for %s: %s", file_path.name, e)
raise
async def _extract_and_save_case_number(
self, document_id: int, ocr_result: OcrResult,
) -> None:
"""Extract case number from OCR and write to database."""
from leaudit_bridge.case_number_extractor import (
extract_case_number_with_llm,
)
case_number = await extract_case_number_with_llm(
ocr_result, llm_client=self.llm_client,
)
if case_number:
await self.storage.update_document_number(document_id, case_number)
log.info("[%d] Case number: %s", document_id, case_number)
else:
log.info("[%d] No case number found", document_id)
async def _resolve_rules_from_ocr(
self, ocr_result: OcrResult, document_id: int,
) -> RulesFile | None:
"""Load rules_file from OCR classification result."""
from leaudit.dsl.loader import load_rules_file
rfp = ocr_result.rules_file_path
if not rfp:
log.warning(
"[%d] No rules_file_path in OCR result, cannot resolve rules",
document_id,
)
return None
try:
rules_file = load_rules_file(rfp)
log.info(
"[%d] Resolved rules from classification: %s (%d rules)",
document_id, rfp, len(rules_file.flat_rules),
)
return rules_file
except Exception as e:
log.error("[%d] Failed to load rules from %s: %s", document_id, rfp, e)
return None
@@ -0,0 +1,303 @@
"""Adapt leaudit raw results into docauditai's standardized format.
Currently this logic is inlined in ``storage_adapter.py``
(``_rule_result_to_row``, ``_bundle_to_extracted``, ``_ocr_to_dict``).
This module extracts the conversion layer so storage_adapter focuses
on persistence (the "adapter" part of result_adapter).
"""
from __future__ import annotations
import logging
import re
from typing import Any
from leaudit.dsl.schema import Rule as DslRule
from leaudit.dsl.schema import RulesFile
from leaudit.engine.models import EvaluationResult, RuleResult
from leaudit.extraction.bundle import ExtractionBundle
from leaudit.extraction.models import FieldValue
from leaudit.ocr.models import OcrResult
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# OCR result → dict
# ---------------------------------------------------------------------------
def ocr_to_dict(ocr: OcrResult) -> dict[str, Any]:
"""Convert OcrResult to a JSON-safe dict for storage."""
result: dict[str, Any] = {
"numPages": len(ocr.pages),
"full_text": ocr.full_text,
"pages": [],
}
for page in ocr.pages:
page_dict: dict[str, Any] = {
"page_num": page.page_num,
"text": page.text,
"page_box": page.page_box,
}
if page.chunks:
page_dict["chunks"] = [
(
{"bbox": c["bbox"], "content": c["content"], "label": c.get("label")}
if isinstance(c, dict) and "bbox" in c and "content" in c
else {
"bbox": c.bbox if hasattr(c, "bbox") else None,
"content": c.content if hasattr(c, "content") else str(c),
"label": c.label if hasattr(c, "label") else None,
}
)
for c in page.chunks
]
if page.bboxes:
page_dict["bboxes"] = page.bboxes
result["pages"].append(page_dict)
if ocr.visual_manifest:
result["visual_manifest"] = ocr.visual_manifest.model_dump(mode="json")
if ocr.images:
result["images"] = ocr.images
return result
# ---------------------------------------------------------------------------
# ExtractionBundle → dict
# ---------------------------------------------------------------------------
def bundle_to_extracted(bundle: ExtractionBundle) -> dict[str, Any]:
"""Convert ExtractionBundle to docauditai's extracted_results format."""
fields: dict[str, Any] = {}
for name, fv in bundle.fields.items():
if isinstance(fv, FieldValue):
field_data = {
"value": fv.value,
"confidence": float(fv.confidence) if fv.confidence else 0.0,
}
if fv.position is not None:
field_data["position"] = fv.position.model_dump(mode="json")
fields[name] = field_data
else:
fields[name] = {"value": fv}
multi_entity: dict[str, Any] = {}
for name, rows in bundle.multi_entity.items():
multi_entity[name] = [
{
k: (v.value if isinstance(v, FieldValue) else v)
for k, v in row.items()
}
if isinstance(row, dict)
else {"value": row}
for row in rows
]
return {
"fields": fields,
"multi_entity": multi_entity,
"derived": dict(bundle.derived) if bundle.derived else {},
"is_case_file": bundle.is_case_file,
}
# ---------------------------------------------------------------------------
# EvaluationResult → per-rule rows
# ---------------------------------------------------------------------------
def rule_result_to_row(
document_id: int,
run_id: int,
rule_result: RuleResult,
rule: Any | None,
bundle: ExtractionBundle,
) -> dict[str, Any]:
"""Convert one RuleResult to a database row dict.
Args:
document_id: docauditai document ID.
run_id: ``leaudit_audit_runs.id`` for this execution.
rule_result: Single rule evaluation result from leaudit.
rule: DSL Rule definition (for metadata lookups).
bundle: The extraction bundle (for field position lookups).
"""
passed = rule_result.passed
pass_msg = ""
fail_msg = ""
if rule_result.messages:
pass_msg = rule_result.messages.get("pass", "")
fail_msg = rule_result.messages.get("fail", "")
elif isinstance(rule, DslRule) and rule.messages:
pass_msg = rule.messages.get("pass", "")
fail_msg = rule.messages.get("fail", "")
relevant_fields = _extract_relevant_fields(rule, bundle)
field_positions = _extract_relevant_field_positions(rule, bundle)
remediation = None
if rule_result.remediation:
remediation = rule_result.remediation.model_dump(mode="json")
rule_meta: dict[str, Any] = {}
if isinstance(rule, DslRule):
if rule.references_laws:
rule_meta["references_laws"] = rule.references_laws
if rule.desc:
rule_meta["desc"] = rule.desc
if rule.group:
rule_meta["group"] = rule.group
return {
"document_id": document_id,
"run_id": run_id,
"rule_id": rule_result.rule_id,
"rule_name": rule_result.name,
"risk": rule_result.risk or "medium",
"score": rule_result.score,
"passed": passed,
"status": rule_result.status,
"skip_reason": rule_result.skip_reason,
"confidence": rule_result.confidence,
"pass_message": pass_msg,
"fail_message": fail_msg,
"stages": [s.model_dump(mode="json") for s in (rule_result.stages or [])],
"extracted_fields": relevant_fields,
"field_positions": field_positions,
"remediation": remediation,
"rule_meta": rule_meta,
"rescue_applied": False,
"rescue_passed": None,
}
def evaluation_summary(eval_result: EvaluationResult) -> dict[str, Any]:
"""Extract summary fields from an EvaluationResult."""
return {
"total_score": eval_result.total_score,
"passed_count": eval_result.passed_count,
"failed_count": eval_result.failed_count,
"skipped_count": eval_result.skipped_count,
"result_status": _result_status(eval_result),
}
def _result_status(eval_result: EvaluationResult) -> str:
if eval_result.errors:
return "error"
if eval_result.failed_count == 0 and eval_result.skipped_count == 0:
return "pass"
if eval_result.failed_count > 0:
return "fail"
return "partial"
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _extract_relevant_fields(
rule: Any, bundle: ExtractionBundle,
) -> dict[str, Any]:
"""Extract field values referenced by a rule's stages."""
relevant: dict[str, Any] = {}
if not rule or not hasattr(rule, "stages") or not rule.stages:
return relevant
for stage in rule.stages:
stage_data = (
stage.model_dump(exclude_none=True)
if hasattr(stage, "model_dump")
else {}
)
extra = stage.extra if hasattr(stage, "extra") else {}
field_names: list[str] = []
for key in ("field", "field1", "field2", "fields"):
val = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
if isinstance(val, list):
field_names.extend(f for f in val if isinstance(f, str))
elif isinstance(val, str):
field_names.append(val)
pairs = stage_data.get("pairs")
if isinstance(pairs, list):
for pair in pairs:
if not isinstance(pair, dict):
continue
for key in ("source", "target", "a", "b"):
ref = pair.get(key)
if isinstance(ref, str):
field_names.append(ref)
prompt = stage_data.get("prompt")
if isinstance(prompt, str):
for m in re.finditer(r"\{\{\s*([^{}]+?)\s*\}\}", prompt):
field_names.append(m.group(1).strip())
for f in field_names:
if f in relevant:
continue
if f not in bundle.fields:
continue
fv = bundle.fields[f]
relevant[f] = fv.value if isinstance(fv, FieldValue) else fv
return relevant
def _extract_relevant_field_positions(
rule: Any, bundle: ExtractionBundle,
) -> dict[str, Any]:
"""Extract position data for fields referenced by a rule's stages."""
positions: dict[str, Any] = {}
if not rule or not hasattr(rule, "stages") or not rule.stages:
return positions
for stage in rule.stages:
stage_data = (
stage.model_dump(exclude_none=True)
if hasattr(stage, "model_dump")
else {}
)
extra = stage.extra if hasattr(stage, "extra") else {}
field_names: list[str] = []
for key in ("field", "field1", "field2", "fields"):
val = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
if isinstance(val, list):
field_names.extend(f for f in val if isinstance(f, str))
elif isinstance(val, str):
field_names.append(val)
pairs = stage_data.get("pairs")
if isinstance(pairs, list):
for pair in pairs:
if not isinstance(pair, dict):
continue
for key in ("source", "target", "a", "b"):
ref = pair.get(key)
if isinstance(ref, str):
field_names.append(ref)
prompt = stage_data.get("prompt")
if isinstance(prompt, str):
for m in re.finditer(r"\{\{\s*([^{}]+?)\s*\}\}", prompt):
field_names.append(m.group(1).strip())
for f in field_names:
if f in positions:
continue
fv = bundle.fields.get(f)
if fv is not None and isinstance(fv, FieldValue) and fv.position is not None:
positions[f] = fv.position.model_dump(mode="json")
return positions
@@ -0,0 +1,55 @@
"""Load leaudit YAML RulesFile from filesystem or MinIO."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from leaudit.dsl.schema import RulesFile
log = logging.getLogger(__name__)
_DEFAULT_RULES_DIR = Path(__file__).resolve().parents[2] / "rules"
class RulesLoader:
"""Load and cache leaudit RulesFile from YAML files."""
def __init__(self, rules_dir: str | Path | None = None) -> None:
self._rules_dir = Path(rules_dir) if rules_dir else _DEFAULT_RULES_DIR
self._cache: dict[str, RulesFile] = {}
def load(self, rules_path: str) -> RulesFile:
"""Load a RulesFile by relative path under rules_dir, or absolute path."""
from leaudit.dsl.loader import load_rules_file
if rules_path in self._cache:
return self._cache[rules_path]
p = Path(rules_path)
if not p.is_absolute():
p = self._rules_dir / p
log.info("Loading RulesFile: %s", p)
rules_file = load_rules_file(p)
self._cache[rules_path] = rules_file
return rules_file
def load_from_yaml_text(self, yaml_text: str, cache_key: str | None = None) -> RulesFile:
"""Parse a RulesFile from raw YAML string."""
from leaudit.dsl.loader import parse_rules_yaml_text
if cache_key and cache_key in self._cache:
return self._cache[cache_key]
rules_file = parse_rules_yaml_text(yaml_text)
if cache_key:
self._cache[cache_key] = rules_file
return rules_file
def clear_cache(self) -> None:
self._cache.clear()
@@ -0,0 +1,364 @@
"""Storage adapter — write leaudit results into docauditai's PostgreSQL via PostgREST.
Converts leaudit's OcrResult, ExtractionBundle, EvaluationResult
into docauditai's table format and writes via PostgRESTClient.
Uses the new `leaudit_evaluation_results` table for per-rule results.
"""
from __future__ import annotations
import logging
import re
from typing import Any
from leaudit.dsl.schema import RulesFile
from leaudit.engine.models import EvaluationResult, RuleResult
from leaudit.extraction.bundle import ExtractionBundle
from leaudit.extraction.models import FieldValue
from leaudit.ocr.models import OcrResult
log = logging.getLogger(__name__)
def _get_postgrest_client():
"""Lazy import to avoid circular dependency at module load."""
from core.postgrest.client import get_postgrest_client
return get_postgrest_client()
class StorageAdapter:
"""Write leaudit pipeline results to docauditai's database."""
# ---- Document status ----
async def update_document_status(self, document_id: int, status: str) -> None:
"""Update the document's processing status."""
client = _get_postgrest_client()
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={"status": status},
)
log.debug("[%d] Status updated: %s", document_id, status)
# ---- Document number (案件编号) ----
async def update_document_number(self, document_id: int, document_number: str) -> None:
"""Update the document's case number (document_number field)."""
client = _get_postgrest_client()
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={"document_number": document_number},
)
log.info("[%d] document_number updated: %s", document_id, document_number)
# ---- OCR result ----
async def save_ocr_result(self, document_id: int, ocr_result: OcrResult) -> None:
"""Save OCR result to documents.ocr_result and raw_full_text_original."""
client = _get_postgrest_client()
ocr_dict = _ocr_to_dict(ocr_result)
full_text = ocr_result.full_text or "\n".join(p.text for p in ocr_result.pages)
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={
"ocr_result": ocr_dict,
"raw_full_text_original": full_text,
},
)
log.info("[%d] OCR result saved (%d pages)", document_id, len(ocr_result.pages))
# ---- Extraction result ----
async def save_extraction_result(
self, document_id: int, bundle: ExtractionBundle,
) -> None:
"""Save extraction result to documents.extracted_results."""
client = _get_postgrest_client()
extracted = _bundle_to_extracted(bundle)
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={"extracted_results": extracted},
)
log.info(
"[%d] Extraction result saved (%d fields)",
document_id,
len(bundle.fields),
)
# ---- Evaluation results ----
async def save_evaluation_results(
self,
document_id: int,
rules_file: RulesFile,
evaluation: EvaluationResult,
bundle: ExtractionBundle,
) -> None:
"""Save evaluation results to leaudit_evaluation_results table.
One row per rule. Deletes existing results for the document first,
then inserts fresh rows.
"""
client = _get_postgrest_client()
# Delete existing results for this document
await client.delete(
table="leaudit_evaluation_results",
filters={"document_id": f"eq.{document_id}"},
)
# Build rule_id → rule metadata lookup
rule_meta = {}
for rule in rules_file.flat_rules:
rule_meta[rule.rule_id] = rule
# Insert one row per rule result
for rule_result in evaluation.rules:
rule = rule_meta.get(rule_result.rule_id)
row = _rule_result_to_row(document_id, rule_result, rule, bundle)
await client.insert(table="leaudit_evaluation_results", data=row)
log.info(
"[%d] Evaluation results saved: %d passed, %d failed, %d skipped",
document_id,
evaluation.passed_count,
evaluation.failed_count,
evaluation.skipped_count,
)
# ---- Serialization helpers ----
def _ocr_to_dict(ocr: OcrResult) -> dict[str, Any]:
"""Convert OcrResult to a JSON-safe dict for PostgREST storage."""
result: dict[str, Any] = {
"numPages": len(ocr.pages),
"full_text": ocr.full_text,
"pages": [],
}
for page in ocr.pages:
page_dict: dict[str, Any] = {
"page_num": page.page_num,
"text": page.text,
"page_box": page.page_box,
}
if page.chunks:
page_dict["chunks"] = [
(
{"bbox": c["bbox"], "content": c["content"], "label": c.get("label")}
if isinstance(c, dict) and "bbox" in c and "content" in c
else {
"bbox": c.bbox if hasattr(c, "bbox") else None,
"content": c.content if hasattr(c, "content") else str(c),
"label": c.label if hasattr(c, "label") else None,
}
)
for c in page.chunks
]
if page.bboxes:
page_dict["bboxes"] = page.bboxes
result["pages"].append(page_dict)
if ocr.visual_manifest:
result["visual_manifest"] = ocr.visual_manifest.model_dump(mode="json")
if ocr.images:
result["images"] = ocr.images
return result
def _bundle_to_extracted(bundle: ExtractionBundle) -> dict[str, Any]:
"""Convert ExtractionBundle to docauditai's extracted_results format."""
fields: dict[str, Any] = {}
for name, fv in bundle.fields.items():
if isinstance(fv, FieldValue):
field_data = {
"value": fv.value,
"confidence": float(fv.confidence) if fv.confidence else 0.0,
}
if fv.position is not None:
field_data["position"] = fv.position.model_dump(mode="json")
fields[name] = field_data
else:
fields[name] = {"value": fv}
multi_entity: dict[str, Any] = {}
for name, rows in bundle.multi_entity.items():
multi_entity[name] = [
{
k: (v.value if isinstance(v, FieldValue) else v)
for k, v in row.items()
}
if isinstance(row, dict) else {"value": row}
for row in rows
]
return {
"fields": fields,
"multi_entity": multi_entity,
"derived": dict(bundle.derived) if bundle.derived else {},
"is_case_file": bundle.is_case_file,
}
def _extract_relevant_fields(rule: Any, bundle: ExtractionBundle) -> dict[str, Any]:
"""Extract field values referenced by a rule's stages."""
relevant: dict[str, Any] = {}
if not rule or not hasattr(rule, "stages") or not rule.stages:
return relevant
for stage in rule.stages:
stage_data = stage.model_dump(exclude_none=True) if hasattr(stage, "model_dump") else {}
extra = stage.extra if hasattr(stage, "extra") else {}
field_names: list[str] = []
for key in ("field", "field1", "field2", "fields"):
val = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
if isinstance(val, list):
field_names.extend(f for f in val if isinstance(f, str))
elif isinstance(val, str):
field_names.append(val)
pairs = stage_data.get("pairs")
if isinstance(pairs, list):
for pair in pairs:
if not isinstance(pair, dict):
continue
for key in ("source", "target", "a", "b"):
ref = pair.get(key)
if isinstance(ref, str):
field_names.append(ref)
prompt = stage_data.get("prompt")
if isinstance(prompt, str):
for m in re.finditer(r"\{\{\s*([^{}]+?)\s*\}\}", prompt):
field_names.append(m.group(1).strip())
for f in field_names:
if f in relevant:
continue
if f not in bundle.fields:
continue
fv = bundle.fields[f]
relevant[f] = fv.value if isinstance(fv, FieldValue) else fv
return relevant
def _extract_relevant_field_positions(
rule: Any,
bundle: ExtractionBundle,
) -> dict[str, Any]:
"""Extract position data for fields referenced by a rule's stages."""
positions: dict[str, Any] = {}
if not rule or not hasattr(rule, "stages") or not rule.stages:
return positions
for stage in rule.stages:
stage_data = stage.model_dump(exclude_none=True) if hasattr(stage, "model_dump") else {}
extra = stage.extra if hasattr(stage, "extra") else {}
field_names: list[str] = []
for key in ("field", "field1", "field2", "fields"):
val = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
if isinstance(val, list):
field_names.extend(f for f in val if isinstance(f, str))
elif isinstance(val, str):
field_names.append(val)
pairs = stage_data.get("pairs")
if isinstance(pairs, list):
for pair in pairs:
if not isinstance(pair, dict):
continue
for key in ("source", "target", "a", "b"):
ref = pair.get(key)
if isinstance(ref, str):
field_names.append(ref)
prompt = stage_data.get("prompt")
if isinstance(prompt, str):
for m in re.finditer(r"\{\{\s*([^{}]+?)\s*\}\}", prompt):
field_names.append(m.group(1).strip())
for f in field_names:
if f in positions:
continue
fv = bundle.fields.get(f)
if fv is not None and isinstance(fv, FieldValue) and fv.position is not None:
positions[f] = fv.position.model_dump(mode="json")
return positions
def _rule_result_to_row(
document_id: int,
rule_result: RuleResult,
rule: Any | None,
bundle: ExtractionBundle,
) -> dict[str, Any]:
"""Convert a RuleResult to a leaudit_evaluation_results row."""
passed = rule_result.passed
# Resolve messages: rule_result → rule definition
pass_msg = ""
fail_msg = ""
if rule_result.messages:
pass_msg = rule_result.messages.get("pass", "")
fail_msg = rule_result.messages.get("fail", "")
elif rule:
from leaudit.dsl.schema import Rule as DslRule
if isinstance(rule, DslRule) and rule.messages:
pass_msg = rule.messages.get("pass", "")
fail_msg = rule.messages.get("fail", "")
# Extract relevant fields
relevant_fields = _extract_relevant_fields(rule, bundle)
# Remediation (if present)
remediation = None
if rule_result.remediation:
remediation = rule_result.remediation.model_dump(mode="json")
# Rule metadata (references_laws, etc.)
rule_meta_data: dict[str, Any] = {}
if rule:
from leaudit.dsl.schema import Rule as DslRule
if isinstance(rule, DslRule):
if rule.references_laws:
rule_meta_data["references_laws"] = rule.references_laws
if rule.desc:
rule_meta_data["desc"] = rule.desc
if rule.group:
rule_meta_data["group"] = rule.group
return {
"document_id": document_id,
"rule_id": rule_result.rule_id,
"rule_name": rule_result.name,
"risk": rule_result.risk or "medium",
"score": rule_result.score,
"passed": passed,
"status": rule_result.status,
"skip_reason": rule_result.skip_reason,
"confidence": rule_result.confidence,
"pass_message": pass_msg,
"fail_message": fail_msg,
"stages": [s.model_dump(mode="json") for s in (rule_result.stages or [])],
"extracted_fields": relevant_fields,
"field_positions": _extract_relevant_field_positions(rule, bundle),
"remediation": remediation,
"rule_meta": rule_meta_data,
}
@@ -0,0 +1,201 @@
"""Celery task for leaudit pipeline processing.
Activated when PIPELINE_MODE=leaudit in env.{port} config.
Replaces the legacy OCR → extraction → evaluation pipeline with
leaudit's YAML-rules-driven approach.
"""
from __future__ import annotations
import asyncio
import os
import tempfile
import time
from typing import Any, Dict, Optional
from core.celery_app_limited import celery_app
from core.postgrest.client import get_postgrest_client
from core.logger import log
from leaudit_bridge import create_pipeline, RulesLoader
@celery_app.task(bind=True, name="leaudit.process_document")
def leaudit_process_document(
self,
document_id: int,
file_content: bytes,
filename: str,
upload_info: Optional[Dict[str, Any]] = None,
source_port: Optional[int] = None,
rules_path: Optional[str] = None,
):
"""Process a document using leaudit's full pipeline.
Steps: OCR → Extraction → Evaluation → Store in docauditai DB.
"""
task_id = self.request.id
log.task.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")
if source_port:
from core.utils.instance_context import set_instance_environment
instance_name = set_instance_environment(source_port)
log.task.info(
f"[任务ID: {task_id}] 实例环境: {instance_name} (端口: {source_port})"
)
if upload_info is None:
upload_info = {}
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
rules_path_resolved = rules_path or _resolve_rules_path(document_id, loop)
# For types with a known mapping (e.g. 行政处罚), pre-load rules_file.
# For types that need content classification (e.g. 行政许可 sub-types),
# rules_path will be None → adapter classifies after OCR → pipeline
# loads rules from ocr_result.rules_file_path.
rules_file = None
if rules_path_resolved:
loader = RulesLoader()
rules_file = loader.load(rules_path_resolved)
log.task.info(
f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} "
f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)"
)
else:
log.task.info(
f"[任务ID: {task_id}] No fixed rules_path — "
"will classify from document content after OCR"
)
suffix = _get_suffix(filename)
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp:
temp.write(file_content)
temp_path = temp.name
pipeline = create_pipeline(rules_path=rules_path_resolved)
t0 = time.time()
result = loop.run_until_complete(
pipeline.run(
document_id=document_id,
file_path=temp_path,
rules_file=rules_file,
source_port=source_port or int(os.getenv("APP_PORT", "8000")),
)
)
elapsed = round(time.time() - t0, 2)
try:
os.remove(temp_path)
except OSError:
pass
log.task.info(
f"[任务ID: {task_id}] leaudit管线完成: phase={result.detected_phase}, "
f"timing={result.timing}, 总耗时={elapsed:.1f}s"
)
return {
"status": "success",
"document_id": document_id,
"phase": result.detected_phase,
"timing": result.timing,
"errors": result.errors,
}
except Exception as e:
log.task.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
try:
loop.run_until_complete(_update_status_safe(document_id, "Failed"))
except Exception:
pass
raise
finally:
loop.close()
# type_id → rules directory mapping (only fixed-mapping types)
# 行政许可 (type_id=2) has 9 sub-types, NOT mapped here —
# must come from document metadata (rules_file_path) or content classification.
_TYPE_ID_RULES_MAP: dict[int, str] = {
3: "行政处罚",
}
def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None:
"""Resolve rules_path: config override → document metadata → type_id mapping."""
from core.config import LEAUDIT_CONFIG
# 1. Config override (when explicitly set)
config_path = LEAUDIT_CONFIG.get("RULES_PATH", "")
if config_path:
return config_path
try:
client = get_postgrest_client()
doc = loop.run_until_complete(
client.select(
table="documents",
filters={"id": f"eq.{document_id}"},
single=True,
)
)
if not doc:
return None
# 2. Document-level override
rfp = doc.get("rules_file_path")
if rfp:
return rfp
# 3. type_id mapping
type_id = doc.get("type_id")
if type_id and type_id in _TYPE_ID_RULES_MAP:
return f"{_TYPE_ID_RULES_MAP[type_id]}/rules.yaml"
except Exception as e:
log.task.warning(f"Failed to resolve rules_path from document: {e}")
return None
async def _update_status_safe(document_id: int, status: str) -> None:
"""Safely update document status, ignoring errors."""
try:
client = get_postgrest_client()
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={"status": status},
)
except Exception:
pass
def _get_suffix(filename: str) -> str:
"""Extract file suffix from filename."""
_, ext = os.path.splitext(filename)
return ext if ext else ".pdf"
def dispatch_leaudit_task(
document_id: int,
file_content: bytes,
filename: str,
upload_info: Optional[Dict[str, Any]] = None,
source_port: Optional[int] = None,
rules_path: Optional[str] = None,
):
"""Dispatch a leaudit processing task."""
return leaudit_process_document.apply_async(
args=[document_id, file_content, filename],
kwargs={
"upload_info": upload_info,
"source_port": source_port or int(os.getenv("APP_PORT", "8000")),
"rules_path": rules_path,
},
)