535d97a70c
17-table PostgreSQL schema with full Chinese column comments, FastAPI project structure (admin/common/modules), DSL rule files, and schema migration scripts.
304 lines
10 KiB
Python
304 lines
10 KiB
Python
"""Adapt leaudit raw results into docauditai's standardized format.
|
|
|
|
Currently this logic is inlined in ``storage_adapter.py``
|
|
(``_rule_result_to_row``, ``_bundle_to_extracted``, ``_ocr_to_dict``).
|
|
This module extracts the conversion layer so storage_adapter focuses
|
|
on persistence (the "adapter" part of result_adapter).
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import re
|
|
from typing import Any
|
|
|
|
from leaudit.dsl.schema import Rule as DslRule
|
|
from leaudit.dsl.schema import RulesFile
|
|
from leaudit.engine.models import EvaluationResult, RuleResult
|
|
from leaudit.extraction.bundle import ExtractionBundle
|
|
from leaudit.extraction.models import FieldValue
|
|
from leaudit.ocr.models import OcrResult
|
|
|
|
log = logging.getLogger(__name__)
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# OCR result → dict
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def ocr_to_dict(ocr: OcrResult) -> dict[str, Any]:
    """Convert an ``OcrResult`` into a plain dict for storage.

    Always-present keys: ``numPages``, ``full_text`` and ``pages``.
    ``visual_manifest`` (dumped with ``model_dump(mode="json")``) and
    ``images`` are added only when truthy on the source.

    Each page dict carries ``page_num``, ``text`` and ``page_box``, plus
    ``chunks`` (normalised by ``_chunk_to_dict``) and ``bboxes`` when those
    are non-empty.

    NOTE(review): ``page_box``, ``bboxes`` and ``images`` are copied through
    unchanged, so JSON-safety of the result depends on their upstream types.
    """
    result: dict[str, Any] = {
        "numPages": len(ocr.pages),
        "full_text": ocr.full_text,
        "pages": [],
    }

    for page in ocr.pages:
        page_dict: dict[str, Any] = {
            "page_num": page.page_num,
            "text": page.text,
            "page_box": page.page_box,
        }
        if page.chunks:
            page_dict["chunks"] = [_chunk_to_dict(c) for c in page.chunks]
        if page.bboxes:
            page_dict["bboxes"] = page.bboxes
        result["pages"].append(page_dict)

    if ocr.visual_manifest:
        result["visual_manifest"] = ocr.visual_manifest.model_dump(mode="json")

    if ocr.images:
        result["images"] = ocr.images

    return result


def _chunk_to_dict(chunk: Any) -> dict[str, Any]:
    """Normalise one OCR chunk (dict or object) to ``{"bbox", "content", "label"}``.

    Dict chunks are read by key and any other chunk by attribute.
    A missing ``bbox`` or ``label`` becomes ``None``; a missing ``content``
    falls back to ``str(chunk)``.
    """
    if isinstance(chunk, dict):
        # Read partial dicts key-by-key rather than via attributes. Dicts
        # never expose their keys as attributes, so an attribute lookup would
        # silently drop any bbox/label that is present.
        return {
            "bbox": chunk.get("bbox"),
            "content": chunk["content"] if "content" in chunk else str(chunk),
            "label": chunk.get("label"),
        }
    return {
        "bbox": getattr(chunk, "bbox", None),
        "content": chunk.content if hasattr(chunk, "content") else str(chunk),
        "label": getattr(chunk, "label", None),
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# ExtractionBundle → dict
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def bundle_to_extracted(bundle: ExtractionBundle) -> dict[str, Any]:
    """Flatten an ``ExtractionBundle`` into docauditai's ``extracted_results`` shape.

    The returned dict has four keys:

    * ``fields``: ``{name: entry}``. A ``FieldValue`` becomes
      ``{"value", "confidence"}`` plus ``"position"`` when it has one. Any
      other raw entry becomes ``{"value": raw}``.
    * ``multi_entity``: ``{name: [row, ...]}``. Dict rows have their
      ``FieldValue`` cells unwrapped to ``.value``; non-dict rows become
      ``{"value": row}``.
    * ``derived``: a shallow copy of ``bundle.derived``, or ``{}`` when falsy.
    * ``is_case_file``: copied straight from the bundle.
    """

    def _field_entry(fv: Any) -> dict[str, Any]:
        # Raw (non-FieldValue) entries carry no confidence/position metadata.
        if not isinstance(fv, FieldValue):
            return {"value": fv}
        entry: dict[str, Any] = {
            "value": fv.value,
            # A None/zero/otherwise-falsy confidence is normalised to 0.0.
            "confidence": float(fv.confidence or 0.0),
        }
        if fv.position is not None:
            entry["position"] = fv.position.model_dump(mode="json")
        return entry

    def _row_entry(row: Any) -> Any:
        if not isinstance(row, dict):
            return {"value": row}
        return {
            key: (cell.value if isinstance(cell, FieldValue) else cell)
            for key, cell in row.items()
        }

    fields = {name: _field_entry(fv) for name, fv in bundle.fields.items()}
    multi_entity = {
        name: [_row_entry(row) for row in rows]
        for name, rows in bundle.multi_entity.items()
    }
    derived = dict(bundle.derived) if bundle.derived else {}

    return {
        "fields": fields,
        "multi_entity": multi_entity,
        "derived": derived,
        "is_case_file": bundle.is_case_file,
    }
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# EvaluationResult → per-rule rows
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
def rule_result_to_row(
    document_id: int,
    run_id: int,
    rule_result: RuleResult,
    rule: Any | None,
    bundle: ExtractionBundle,
) -> dict[str, Any]:
    """Build the database row dict for a single leaudit ``RuleResult``.

    Args:
        document_id: docauditai document ID.
        run_id: ``leaudit_audit_runs.id`` for this execution.
        rule_result: Single rule evaluation result from leaudit.
        rule: DSL Rule definition, used for message and metadata fallbacks.
            Anything that is not a ``DslRule`` is ignored for those lookups.
        bundle: The extraction bundle, used for field value and position lookups.

    Returns:
        A flat dict keyed by column name. Notes on specific columns:

        * ``pass_message`` / ``fail_message`` come from
          ``rule_result.messages`` when that is truthy. Otherwise they come
          from the DSL rule's ``messages``, and default to ``""``.
        * ``risk`` defaults to ``"medium"`` when the result has none.
        * The rescue columns are always initialised to "not applied".
    """
    # The engine's own messages take priority; the DSL definition is consulted
    # only when the result carries none at all.
    if rule_result.messages:
        message_source = rule_result.messages
    elif isinstance(rule, DslRule) and rule.messages:
        message_source = rule.messages
    else:
        message_source = {}
    pass_msg = message_source.get("pass", "")
    fail_msg = message_source.get("fail", "")

    relevant_fields = _extract_relevant_fields(rule, bundle)
    field_positions = _extract_relevant_field_positions(rule, bundle)

    remediation = (
        rule_result.remediation.model_dump(mode="json")
        if rule_result.remediation
        else None
    )

    # Copy only the truthy descriptive attributes of a DSL rule, in a fixed order.
    rule_meta: dict[str, Any] = {}
    if isinstance(rule, DslRule):
        for attr_name in ("references_laws", "desc", "group"):
            attr_value = getattr(rule, attr_name)
            if attr_value:
                rule_meta[attr_name] = attr_value

    stage_rows = [stage.model_dump(mode="json") for stage in rule_result.stages or []]

    return {
        "document_id": document_id,
        "run_id": run_id,
        "rule_id": rule_result.rule_id,
        "rule_name": rule_result.name,
        "risk": rule_result.risk or "medium",
        "score": rule_result.score,
        "passed": rule_result.passed,
        "status": rule_result.status,
        "skip_reason": rule_result.skip_reason,
        "confidence": rule_result.confidence,
        "pass_message": pass_msg,
        "fail_message": fail_msg,
        "stages": stage_rows,
        "extracted_fields": relevant_fields,
        "field_positions": field_positions,
        "remediation": remediation,
        "rule_meta": rule_meta,
        "rescue_applied": False,
        "rescue_passed": None,
    }
|
|
|
|
|
|
def evaluation_summary(eval_result: EvaluationResult) -> dict[str, Any]:
    """Summarise an ``EvaluationResult`` into run-level summary values.

    The score and the passed/failed/skipped counters are copied verbatim.
    A ``result_status`` label is added, derived by ``_result_status``.
    """
    summary: dict[str, Any] = {
        attr: getattr(eval_result, attr)
        for attr in ("total_score", "passed_count", "failed_count", "skipped_count")
    }
    summary["result_status"] = _result_status(eval_result)
    return summary
|
|
|
|
|
|
def _result_status(eval_result: EvaluationResult) -> str:
    """Collapse an evaluation's error list and counters into one status label.

    Checks are applied in this order:

    * any recorded error returns ``"error"``;
    * otherwise, at least one failure returns ``"fail"``;
    * otherwise, zero failures and zero skips returns ``"pass"``;
    * anything else returns ``"partial"``.
    """
    if eval_result.errors:
        return "error"
    failed = eval_result.failed_count
    if failed > 0:
        return "fail"
    if failed == 0 and eval_result.skipped_count == 0:
        return "pass"
    return "partial"
|
|
|
|
|
|
# ---------------------------------------------------------------------------
|
|
# Internal helpers
|
|
# ---------------------------------------------------------------------------
|
|
|
|
|
|
# Matches ``{{ field_name }}`` placeholders inside stage prompt templates.
_PROMPT_FIELD_RE = re.compile(r"\{\{\s*([^{}]+?)\s*\}\}")

# Stage keys whose value is a field name or a list of field names.
_STAGE_FIELD_KEYS = ("field", "field1", "field2", "fields")

# Keys inside each ``pairs`` entry whose value is a field name.
_PAIR_FIELD_KEYS = ("source", "target", "a", "b")


def _stage_field_names(stage: Any) -> list[str]:
    """Return the field names a single DSL stage references, in discovery order.

    Names are collected from three places:

    1. The ``field``/``field1``/``field2``/``fields`` keys. Each key is looked
       up as an attribute first, then in ``stage.extra``, then in the dumped
       model.
    2. The string ``source``/``target``/``a``/``b`` values of each dict in
       ``pairs``.
    3. ``{{ name }}`` placeholders in ``prompt``.

    Duplicates are kept; callers de-duplicate.
    """
    stage_data = (
        stage.model_dump(exclude_none=True) if hasattr(stage, "model_dump") else {}
    )
    # ``extra`` may be absent *or* present-but-None. Fall back to an empty dict
    # so the lookup chain below never calls ``.get`` on None.
    extra = getattr(stage, "extra", None) or {}

    names: list[str] = []
    for key in _STAGE_FIELD_KEYS:
        value = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
        if isinstance(value, list):
            names.extend(v for v in value if isinstance(v, str))
        elif isinstance(value, str):
            names.append(value)

    pairs = stage_data.get("pairs")
    if isinstance(pairs, list):
        for pair in pairs:
            if not isinstance(pair, dict):
                continue
            for key in _PAIR_FIELD_KEYS:
                ref = pair.get(key)
                if isinstance(ref, str):
                    names.append(ref)

    prompt = stage_data.get("prompt")
    if isinstance(prompt, str):
        names.extend(m.group(1).strip() for m in _PROMPT_FIELD_RE.finditer(prompt))

    return names


def _rule_field_names(rule: Any) -> list[str]:
    """Return every field name referenced across ``rule``'s stages (may repeat).

    Returns ``[]`` when ``rule`` is falsy or has no (or empty) ``stages``.
    """
    stages = getattr(rule, "stages", None) if rule else None
    if not stages:
        return []
    return [name for stage in stages for name in _stage_field_names(stage)]


def _extract_relevant_fields(
    rule: Any, bundle: ExtractionBundle,
) -> dict[str, Any]:
    """Return ``{field_name: value}`` for bundle fields referenced by the rule's stages.

    ``FieldValue`` entries are unwrapped to ``.value``; other entries are
    returned unchanged. Referenced names missing from ``bundle.fields`` are
    skipped. Result order follows the first reference to each name.
    """
    relevant: dict[str, Any] = {}
    for name in _rule_field_names(rule):
        if name in relevant or name not in bundle.fields:
            continue
        fv = bundle.fields[name]
        relevant[name] = fv.value if isinstance(fv, FieldValue) else fv
    return relevant


def _extract_relevant_field_positions(
    rule: Any, bundle: ExtractionBundle,
) -> dict[str, Any]:
    """Return ``{field_name: position}`` for referenced fields that carry a position.

    Only ``FieldValue`` entries with a non-None ``position`` are included.
    Each position is serialised with ``model_dump(mode="json")``.
    """
    positions: dict[str, Any] = {}
    for name in _rule_field_names(rule):
        if name in positions:
            continue
        fv = bundle.fields.get(name)
        if fv is not None and isinstance(fv, FieldValue) and fv.position is not None:
            positions[name] = fv.position.model_dump(mode="json")
    return positions
|