Files
leaudit-platform-backend/fastapi_modules/fastapi_leaudit/leaudit_bridge/result_adapter.py
T
wren 535d97a70c chore: initial commit — leaudit-platform project skeleton
17-table PostgreSQL schema with full Chinese column comments,
FastAPI project structure (admin/common/modules),
DSL rule files, and schema migration scripts.
2026-04-27 16:48:22 +08:00

304 lines
10 KiB
Python

"""Adapt leaudit raw results into docauditai's standardized format.
Currently this logic is inlined in ``storage_adapter.py``
(``_rule_result_to_row``, ``_bundle_to_extracted``, ``_ocr_to_dict``).
This module extracts the conversion layer so storage_adapter focuses
on persistence (the "adapter" part of result_adapter).
"""
from __future__ import annotations
import logging
import re
from typing import Any
from leaudit.dsl.schema import Rule as DslRule
from leaudit.dsl.schema import RulesFile
from leaudit.engine.models import EvaluationResult, RuleResult
from leaudit.extraction.bundle import ExtractionBundle
from leaudit.extraction.models import FieldValue
from leaudit.ocr.models import OcrResult
# Module-scoped logger, named after this module; the conversion helpers in
# this section do not currently emit any log records through it.
log = logging.getLogger(name=__name__)
# ---------------------------------------------------------------------------
# OCR result → dict
# ---------------------------------------------------------------------------
def ocr_to_dict(ocr: OcrResult) -> dict[str, Any]:
    """Serialise an ``OcrResult`` into a plain dict suitable for storage.

    The result always carries ``numPages``, ``full_text`` and ``pages``.
    ``visual_manifest`` (dumped with ``model_dump(mode="json")``) and
    ``images`` are added only when truthy on *ocr*.

    Each page entry has ``page_num``/``text``/``page_box``. It gets ``chunks``
    only when the page has chunks, and ``bboxes`` only when the page has bboxes.
    """

    def _chunk_entry(chunk: Any) -> dict[str, Any]:
        # Dict chunks that already carry bbox + content are copied directly.
        if isinstance(chunk, dict) and "bbox" in chunk and "content" in chunk:
            return {
                "bbox": chunk["bbox"],
                "content": chunk["content"],
                "label": chunk.get("label"),
            }
        # Anything else is read attribute-wise. Missing bbox/label become None,
        # and a missing content falls back to the chunk's string form.
        return {
            "bbox": getattr(chunk, "bbox", None),
            "content": chunk.content if hasattr(chunk, "content") else str(chunk),
            "label": getattr(chunk, "label", None),
        }

    pages: list[dict[str, Any]] = []
    for page in ocr.pages:
        entry: dict[str, Any] = {
            "page_num": page.page_num,
            "text": page.text,
            "page_box": page.page_box,
        }
        if page.chunks:
            entry["chunks"] = [_chunk_entry(chunk) for chunk in page.chunks]
        if page.bboxes:
            entry["bboxes"] = page.bboxes
        pages.append(entry)

    serialised: dict[str, Any] = {
        "numPages": len(ocr.pages),
        "full_text": ocr.full_text,
        "pages": pages,
    }
    if ocr.visual_manifest:
        serialised["visual_manifest"] = ocr.visual_manifest.model_dump(mode="json")
    if ocr.images:
        serialised["images"] = ocr.images
    return serialised
# ---------------------------------------------------------------------------
# ExtractionBundle → dict
# ---------------------------------------------------------------------------
def bundle_to_extracted(bundle: ExtractionBundle) -> dict[str, Any]:
    """Map an ``ExtractionBundle`` onto docauditai's ``extracted_results`` shape.

    Returns a dict with four keys:

    - ``fields``: per field name, ``{"value", "confidence"[, "position"]}``
      for ``FieldValue`` entries, or ``{"value": raw}`` for anything else.
      A falsy confidence is stored as ``0.0``.
    - ``multi_entity``: per entity name, a list of rows. Dict rows have their
      ``FieldValue`` cells unwrapped to ``.value``; non-dict rows become
      ``{"value": row}``.
    - ``derived``: a shallow copy of ``bundle.derived``, or ``{}`` when falsy.
    - ``is_case_file``: copied verbatim.
    """

    def _field_entry(fv: Any) -> dict[str, Any]:
        if not isinstance(fv, FieldValue):
            return {"value": fv}
        entry: dict[str, Any] = {
            "value": fv.value,
            "confidence": float(fv.confidence) if fv.confidence else 0.0,
        }
        if fv.position is not None:
            entry["position"] = fv.position.model_dump(mode="json")
        return entry

    def _row_entry(row: Any) -> dict[str, Any]:
        if not isinstance(row, dict):
            return {"value": row}
        return {
            key: cell.value if isinstance(cell, FieldValue) else cell
            for key, cell in row.items()
        }

    return {
        "fields": {name: _field_entry(fv) for name, fv in bundle.fields.items()},
        "multi_entity": {
            name: [_row_entry(row) for row in rows]
            for name, rows in bundle.multi_entity.items()
        },
        "derived": dict(bundle.derived) if bundle.derived else {},
        "is_case_file": bundle.is_case_file,
    }
# ---------------------------------------------------------------------------
# EvaluationResult → per-rule rows
# ---------------------------------------------------------------------------
def rule_result_to_row(
    document_id: int,
    run_id: int,
    rule_result: RuleResult,
    rule: Any | None,
    bundle: ExtractionBundle,
) -> dict[str, Any]:
    """Flatten one ``RuleResult`` into a row dict for persistence.

    Pass/fail messages are taken from ``rule_result.messages`` when that is
    non-empty. Otherwise they come from the DSL rule's ``messages``, but only
    when *rule* is a ``DslRule``. A missing message becomes ``""``.

    Args:
        document_id: docauditai document ID.
        run_id: ``leaudit_audit_runs.id`` for this execution.
        rule_result: Single rule evaluation result from leaudit.
        rule: DSL Rule definition used for metadata and stage lookups. It may
            be ``None`` or a non-``DslRule`` object.
        bundle: Extraction bundle used to look up referenced field values
            and positions.

    Returns:
        A dict keyed by row column name. ``risk`` defaults to ``"medium"``
        when the result has none. ``rescue_applied``/``rescue_passed`` are
        always initialised to ``False``/``None``.
    """
    dsl_rule = rule if isinstance(rule, DslRule) else None

    # Result-level messages take precedence over the rule's defaults.
    if rule_result.messages:
        message_source = rule_result.messages
    elif dsl_rule is not None and dsl_rule.messages:
        message_source = dsl_rule.messages
    else:
        message_source = {}
    pass_message = message_source.get("pass", "")
    fail_message = message_source.get("fail", "")

    fields_used = _extract_relevant_fields(rule, bundle)
    positions_used = _extract_relevant_field_positions(rule, bundle)

    remediation_data = (
        rule_result.remediation.model_dump(mode="json")
        if rule_result.remediation
        else None
    )

    # Only truthy DSL metadata attributes are copied, in a fixed key order.
    meta: dict[str, Any] = {}
    if dsl_rule is not None:
        for attr in ("references_laws", "desc", "group"):
            attr_value = getattr(dsl_rule, attr)
            if attr_value:
                meta[attr] = attr_value

    stage_dumps = [stage.model_dump(mode="json") for stage in (rule_result.stages or [])]

    return {
        "document_id": document_id,
        "run_id": run_id,
        "rule_id": rule_result.rule_id,
        "rule_name": rule_result.name,
        "risk": rule_result.risk or "medium",
        "score": rule_result.score,
        "passed": rule_result.passed,
        "status": rule_result.status,
        "skip_reason": rule_result.skip_reason,
        "confidence": rule_result.confidence,
        "pass_message": pass_message,
        "fail_message": fail_message,
        "stages": stage_dumps,
        "extracted_fields": fields_used,
        "field_positions": positions_used,
        "remediation": remediation_data,
        "rule_meta": meta,
        "rescue_applied": False,
        "rescue_passed": None,
    }
def evaluation_summary(eval_result: EvaluationResult) -> dict[str, Any]:
    """Return the run-level summary columns for an ``EvaluationResult``.

    The score and the passed/failed/skipped counts are copied straight from
    *eval_result*. ``result_status`` is derived by ``_result_status``.
    """
    summary: dict[str, Any] = {
        attr: getattr(eval_result, attr)
        for attr in ("total_score", "passed_count", "failed_count", "skipped_count")
    }
    summary["result_status"] = _result_status(eval_result)
    return summary
def _result_status(eval_result: EvaluationResult) -> str:
    """Collapse an ``EvaluationResult`` into a single status label.

    The checks are applied in this order:

    - ``"error"`` if any errors were recorded.
    - ``"fail"`` if at least one rule failed.
    - ``"pass"`` if nothing failed and nothing was skipped.
    - ``"partial"`` otherwise.
    """
    if eval_result.errors:
        return "error"
    failed = eval_result.failed_count
    if failed > 0:
        return "fail"
    if failed == 0 and eval_result.skipped_count == 0:
        return "pass"
    return "partial"
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
# Matches ``{{ field_name }}`` placeholders inside a stage's ``prompt`` text.
_TEMPLATE_VAR_RE = re.compile(r"\{\{\s*([^{}]+?)\s*\}\}")
# Stage keys whose value is a field name or a list of field names.
_FIELD_REF_KEYS = ("field", "field1", "field2", "fields")
# Keys inside each ``pairs`` entry that hold a field name.
_PAIR_REF_KEYS = ("source", "target", "a", "b")


def _stage_field_names(stage: Any) -> list[str]:
    """Return the field names referenced by one stage, in discovery order.

    Names are looked up in three places:

    - ``field``/``field1``/``field2``/``fields``: checked on the stage
      attribute first, then ``stage.extra``, then ``stage.model_dump()``.
    - The ``source``/``target``/``a``/``b`` keys of dict entries in the
      dumped ``pairs``.
    - ``{{ name }}`` placeholders in the dumped ``prompt``.

    The result may contain duplicates.
    """
    stage_data = (
        stage.model_dump(exclude_none=True)
        if hasattr(stage, "model_dump")
        else {}
    )
    # ``extra`` may exist but be None. Treat that as "no extras" instead of
    # crashing on ``None.get``.
    extra = getattr(stage, "extra", None) or {}

    names: list[str] = []
    for key in _FIELD_REF_KEYS:
        value = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
        if isinstance(value, list):
            names.extend(item for item in value if isinstance(item, str))
        elif isinstance(value, str):
            names.append(value)

    # ``pairs`` is read from the dumped data so nested models arrive as dicts.
    pairs = stage_data.get("pairs")
    if isinstance(pairs, list):
        for pair in pairs:
            if not isinstance(pair, dict):
                continue
            for key in _PAIR_REF_KEYS:
                ref = pair.get(key)
                if isinstance(ref, str):
                    names.append(ref)

    prompt = stage_data.get("prompt")
    if isinstance(prompt, str):
        names.extend(m.group(1).strip() for m in _TEMPLATE_VAR_RE.finditer(prompt))
    return names


def _referenced_field_names(rule: Any) -> list[str]:
    """Return the unique field names referenced across *rule*'s stages.

    Names are returned in first-seen order. A falsy *rule*, or one without a
    non-empty ``stages``, yields ``[]``.
    """
    stages = getattr(rule, "stages", None) if rule else None
    if not stages:
        return []
    return list(dict.fromkeys(
        name for stage in stages for name in _stage_field_names(stage)
    ))


def _extract_relevant_fields(
    rule: Any, bundle: ExtractionBundle,
) -> dict[str, Any]:
    """Map each field referenced by *rule*'s stages to its value in *bundle*.

    ``FieldValue`` entries are unwrapped to ``.value``; other entries are
    returned as-is. Names missing from ``bundle.fields`` are skipped.
    """
    relevant: dict[str, Any] = {}
    for name in _referenced_field_names(rule):
        if name not in bundle.fields:
            continue
        fv = bundle.fields[name]
        relevant[name] = fv.value if isinstance(fv, FieldValue) else fv
    return relevant


def _extract_relevant_field_positions(
    rule: Any, bundle: ExtractionBundle,
) -> dict[str, Any]:
    """Map each field referenced by *rule*'s stages to its JSON-dumped position.

    Only fields stored in *bundle* as ``FieldValue`` with a non-None
    ``position`` are included.
    """
    positions: dict[str, Any] = {}
    for name in _referenced_field_names(rule):
        fv = bundle.fields.get(name)
        if fv is None:
            continue
        if isinstance(fv, FieldValue) and fv.position is not None:
            positions[name] = fv.position.model_dump(mode="json")
    return positions