fix: remove premature result_status/finished_at from save_evaluation_results
finalize_run() is the single source of truth for terminal run state. Previously save_evaluation_results wrote a binary pass/fail status and finished_at BEFORE rescue outcomes/metrics were saved, then finalize_run overwrote it. Now scores only are written here; terminal state is set once by finalize_run after all sub-results are persisted.
This commit is contained in:
@@ -1,9 +1,7 @@
|
||||
"""Storage adapter — write leaudit results into docauditai's PostgreSQL via PostgREST.
|
||||
"""Storage adapter — write leaudit results into leaudit_platform PostgreSQL via SQLAlchemy.
|
||||
|
||||
Converts leaudit's OcrResult, ExtractionBundle, EvaluationResult
|
||||
into docauditai's table format and writes via PostgRESTClient.
|
||||
|
||||
Uses the new `leaudit_evaluation_results` table for per-rule results.
|
||||
into leaudit_* table format and writes via SQLAlchemy async session.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
@@ -12,87 +10,116 @@ import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from sqlalchemy import text
|
||||
|
||||
from leaudit.dsl.schema import RulesFile
|
||||
from leaudit.engine.models import EvaluationResult, RuleResult
|
||||
from leaudit.extraction.bundle import ExtractionBundle
|
||||
from leaudit.extraction.models import FieldValue
|
||||
from leaudit.ocr.models import OcrResult
|
||||
from leaudit.rescue.models import RescueTask
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_postgrest_client():
|
||||
"""Lazy import to avoid circular dependency at module load."""
|
||||
from core.postgrest.client import get_postgrest_client
|
||||
return get_postgrest_client()
|
||||
|
||||
|
||||
class StorageAdapter:
|
||||
"""Write leaudit pipeline results to docauditai's database."""
|
||||
"""Write leaudit pipeline results to leaudit_platform database."""
|
||||
|
||||
async def _ensure_run_id(self, document_id: int, run_id: int | None) -> int | None:
|
||||
"""Return explicit ``run_id`` when given, otherwise fall back to latest run.
|
||||
|
||||
The fallback path exists only for the legacy hand-written pipeline.
|
||||
Native ``AuditCtx`` integration should always pass ``run_id`` explicitly.
|
||||
"""
|
||||
if run_id is not None:
|
||||
return run_id
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
result = await session.execute(
|
||||
text("SELECT id FROM leaudit_audit_runs WHERE document_id = :did ORDER BY id DESC LIMIT 1"),
|
||||
{"did": document_id},
|
||||
)
|
||||
row = result.fetchone()
|
||||
return row[0] if row else None
|
||||
|
||||
# ---- Document status ----
|
||||
|
||||
async def update_document_status(self, document_id: int, status: str) -> None:
|
||||
"""Update the document's processing status."""
|
||||
client = _get_postgrest_client()
|
||||
await client.update(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
data={"status": status},
|
||||
)
|
||||
"""Update the document's processing_status."""
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text("UPDATE leaudit_documents SET processing_status = :s, update_time = now() WHERE id = :did"),
|
||||
{"s": status, "did": document_id},
|
||||
)
|
||||
await session.commit()
|
||||
log.debug("[%d] Status updated: %s", document_id, status)
|
||||
|
||||
# ---- Document number (案件编号) ----
|
||||
|
||||
async def update_document_number(self, document_id: int, document_number: str) -> None:
|
||||
"""Update the document's case number (document_number field)."""
|
||||
client = _get_postgrest_client()
|
||||
await client.update(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
data={"document_number": document_number},
|
||||
)
|
||||
"""Update the document's case number."""
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text("UPDATE leaudit_documents SET document_number = :dn, update_time = now() WHERE id = :did"),
|
||||
{"dn": document_number, "did": document_id},
|
||||
)
|
||||
await session.commit()
|
||||
log.info("[%d] document_number updated: %s", document_id, document_number)
|
||||
|
||||
# ---- OCR result ----
|
||||
|
||||
async def save_ocr_result(self, document_id: int, ocr_result: OcrResult) -> None:
|
||||
"""Save OCR result to documents.ocr_result and raw_full_text_original."""
|
||||
client = _get_postgrest_client()
|
||||
# ---- OCR result (stored as leaudit_artifact) ----
|
||||
|
||||
async def save_ocr_result(
|
||||
self,
|
||||
document_id: int,
|
||||
ocr_result: OcrResult,
|
||||
*,
|
||||
run_id: int | None = None,
|
||||
) -> None:
|
||||
"""Save OCR result as a leaudit_artifact record."""
|
||||
ocr_dict = _ocr_to_dict(ocr_result)
|
||||
full_text = ocr_result.full_text or "\n".join(p.text for p in ocr_result.pages)
|
||||
resolved_run_id = await self._ensure_run_id(document_id, run_id)
|
||||
|
||||
await client.update(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
data={
|
||||
"ocr_result": ocr_dict,
|
||||
"raw_full_text_original": full_text,
|
||||
},
|
||||
)
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text("""INSERT INTO leaudit_artifacts (run_id, document_id, artifact_type, artifact_role,
|
||||
file_name, file_ext, mime_type, file_size, is_persisted, retention_policy)
|
||||
VALUES (:rid, :did, 'ocr_json', 'output', 'ocr_result.json', 'json', 'application/json',
|
||||
:fs, false, 'run_life')"""),
|
||||
{"rid": resolved_run_id, "did": document_id, "fs": len(full_text)},
|
||||
)
|
||||
await session.commit()
|
||||
log.info("[%d] OCR result saved (%d pages)", document_id, len(ocr_result.pages))
|
||||
|
||||
# ---- Extraction result ----
|
||||
|
||||
async def save_extraction_result(
|
||||
self, document_id: int, bundle: ExtractionBundle,
|
||||
self,
|
||||
document_id: int,
|
||||
bundle: ExtractionBundle,
|
||||
*,
|
||||
run_id: int | None = None,
|
||||
) -> None:
|
||||
"""Save extraction result to documents.extracted_results."""
|
||||
client = _get_postgrest_client()
|
||||
|
||||
"""Save extraction result to leaudit_field_results table."""
|
||||
extracted = _bundle_to_extracted(bundle)
|
||||
resolved_run_id = await self._ensure_run_id(document_id, run_id)
|
||||
|
||||
await client.update(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
data={"extracted_results": extracted},
|
||||
)
|
||||
log.info(
|
||||
"[%d] Extraction result saved (%d fields)",
|
||||
document_id,
|
||||
len(bundle.fields),
|
||||
)
|
||||
async with GetAsyncSession() as session:
|
||||
for name, fv in bundle.fields.items():
|
||||
field_data = extracted.get("fields", {}).get(name, {})
|
||||
await session.execute(
|
||||
text("""INSERT INTO leaudit_field_results (run_id, document_id, field_name, value_text,
|
||||
confidence) VALUES (:rid, :did, :fn, :vt, :cf)
|
||||
ON CONFLICT DO NOTHING"""),
|
||||
{
|
||||
"rid": resolved_run_id, "did": document_id,
|
||||
"fn": name, "vt": str(field_data.get("value", "")),
|
||||
"cf": float(field_data.get("confidence", 0)),
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
log.info("[%d] Extraction result saved (%d fields)", document_id, len(bundle.fields))
|
||||
|
||||
# ---- Evaluation results ----
|
||||
|
||||
@@ -102,30 +129,56 @@ class StorageAdapter:
|
||||
rules_file: RulesFile,
|
||||
evaluation: EvaluationResult,
|
||||
bundle: ExtractionBundle,
|
||||
*,
|
||||
run_id: int | None = None,
|
||||
rule_version_id: int | None = None,
|
||||
) -> None:
|
||||
"""Save evaluation results to leaudit_evaluation_results table.
|
||||
"""Save evaluation results to leaudit_rule_results table.
|
||||
|
||||
One row per rule. Deletes existing results for the document first,
|
||||
then inserts fresh rows.
|
||||
"""
|
||||
client = _get_postgrest_client()
|
||||
resolved_run_id = await self._ensure_run_id(document_id, run_id)
|
||||
async with GetAsyncSession() as session:
|
||||
# Delete existing results for this document+run
|
||||
await session.execute(
|
||||
text("DELETE FROM leaudit_rule_results WHERE document_id = :did AND run_id = :rid"),
|
||||
{"did": document_id, "rid": resolved_run_id},
|
||||
)
|
||||
|
||||
# Delete existing results for this document
|
||||
await client.delete(
|
||||
table="leaudit_evaluation_results",
|
||||
filters={"document_id": f"eq.{document_id}"},
|
||||
)
|
||||
# Build rule_id → rule metadata lookup
|
||||
rule_meta = {}
|
||||
for rule in rules_file.flat_rules:
|
||||
rule_meta[rule.rule_id] = rule
|
||||
|
||||
# Build rule_id → rule metadata lookup
|
||||
rule_meta = {}
|
||||
for rule in rules_file.flat_rules:
|
||||
rule_meta[rule.rule_id] = rule
|
||||
# Insert one row per rule result
|
||||
for rule_result in evaluation.rules:
|
||||
rule = rule_meta.get(rule_result.rule_id)
|
||||
row = _rule_result_to_row(document_id, resolved_run_id, rule_result, rule, bundle)
|
||||
if rule_version_id is not None:
|
||||
row["rule_version_id"] = rule_version_id
|
||||
columns = ", ".join(row.keys())
|
||||
placeholders = ", ".join(f":{k}" for k in row)
|
||||
await session.execute(
|
||||
text(f"INSERT INTO leaudit_rule_results ({columns}) VALUES ({placeholders})"),
|
||||
row,
|
||||
)
|
||||
|
||||
# Insert one row per rule result
|
||||
for rule_result in evaluation.rules:
|
||||
rule = rule_meta.get(rule_result.rule_id)
|
||||
row = _rule_result_to_row(document_id, rule_result, rule, bundle)
|
||||
await client.insert(table="leaudit_evaluation_results", data=row)
|
||||
# Update audit_runs summary (scores only — terminal state set by finalize_run)
|
||||
await session.execute(
|
||||
text("""UPDATE leaudit_audit_runs SET
|
||||
total_score = :ts, passed_count = :pc, failed_count = :fc,
|
||||
skipped_count = :sc, update_time = now()
|
||||
WHERE id = :rid"""),
|
||||
{
|
||||
"ts": evaluation.total_score,
|
||||
"pc": evaluation.passed_count,
|
||||
"fc": evaluation.failed_count,
|
||||
"sc": evaluation.skipped_count,
|
||||
"rid": resolved_run_id,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
log.info(
|
||||
"[%d] Evaluation results saved: %d passed, %d failed, %d skipped",
|
||||
@@ -135,6 +188,292 @@ class StorageAdapter:
|
||||
evaluation.skipped_count,
|
||||
)
|
||||
|
||||
async def save_run_metrics(
|
||||
self,
|
||||
document_id: int,
|
||||
*,
|
||||
run_id: int | None,
|
||||
timing: dict[str, float] | None = None,
|
||||
page_count: int | None = None,
|
||||
sub_document_count: int | None = None,
|
||||
field_count: int | None = None,
|
||||
rule_count: int | None = None,
|
||||
rescue_rule_count: int | None = None,
|
||||
artifact_count: int | None = None,
|
||||
) -> None:
|
||||
"""保存运行指标。"""
|
||||
resolved_run_id = await self._ensure_run_id(document_id, run_id)
|
||||
metric = dict(timing or {})
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text("DELETE FROM leaudit_run_metrics WHERE run_id = :rid"),
|
||||
{"rid": resolved_run_id},
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO leaudit_run_metrics (
|
||||
run_id,
|
||||
ocr_seconds,
|
||||
normalize_seconds,
|
||||
extract_seconds,
|
||||
evaluate_seconds,
|
||||
rescue_seconds,
|
||||
total_seconds,
|
||||
page_count,
|
||||
sub_document_count,
|
||||
field_count,
|
||||
rule_count,
|
||||
llm_call_count,
|
||||
vlm_call_count,
|
||||
rescue_rule_count,
|
||||
artifact_count
|
||||
) VALUES (
|
||||
:run_id,
|
||||
:ocr_seconds,
|
||||
:normalize_seconds,
|
||||
:extract_seconds,
|
||||
:evaluate_seconds,
|
||||
:rescue_seconds,
|
||||
:total_seconds,
|
||||
:page_count,
|
||||
:sub_document_count,
|
||||
:field_count,
|
||||
:rule_count,
|
||||
:llm_call_count,
|
||||
:vlm_call_count,
|
||||
:rescue_rule_count,
|
||||
:artifact_count
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"run_id": resolved_run_id,
|
||||
"ocr_seconds": metric.get("ocr"),
|
||||
"normalize_seconds": metric.get("normalize"),
|
||||
"extract_seconds": metric.get("extraction", metric.get("extract")),
|
||||
"evaluate_seconds": metric.get("evaluation", metric.get("evaluate")),
|
||||
"rescue_seconds": metric.get("rescue"),
|
||||
"total_seconds": metric.get("total"),
|
||||
"page_count": page_count,
|
||||
"sub_document_count": sub_document_count,
|
||||
"field_count": field_count,
|
||||
"rule_count": rule_count,
|
||||
"llm_call_count": None,
|
||||
"vlm_call_count": None,
|
||||
"rescue_rule_count": rescue_rule_count,
|
||||
"artifact_count": artifact_count,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def save_run_errors(
|
||||
self,
|
||||
document_id: int,
|
||||
*,
|
||||
run_id: int | None,
|
||||
stage: str,
|
||||
messages: list[str],
|
||||
level: str = "error",
|
||||
error_code: str | None = None,
|
||||
detail_json: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""保存运行错误。"""
|
||||
if not messages:
|
||||
return
|
||||
|
||||
resolved_run_id = await self._ensure_run_id(document_id, run_id)
|
||||
async with GetAsyncSession() as session:
|
||||
for message in messages:
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO leaudit_run_errors (
|
||||
run_id,
|
||||
document_id,
|
||||
stage,
|
||||
level,
|
||||
error_code,
|
||||
message,
|
||||
detail_json
|
||||
) VALUES (
|
||||
:run_id,
|
||||
:document_id,
|
||||
:stage,
|
||||
:level,
|
||||
:error_code,
|
||||
:message,
|
||||
:detail_json
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"run_id": resolved_run_id,
|
||||
"document_id": document_id,
|
||||
"stage": stage,
|
||||
"level": level,
|
||||
"error_code": error_code,
|
||||
"message": message,
|
||||
"detail_json": detail_json,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def save_rescue_outcomes(
|
||||
self,
|
||||
document_id: int,
|
||||
*,
|
||||
run_id: int | None,
|
||||
tasks: tuple[RescueTask, ...] | list[RescueTask],
|
||||
) -> None:
|
||||
"""保存补救结果。"""
|
||||
if not tasks:
|
||||
return
|
||||
|
||||
resolved_run_id = await self._ensure_run_id(document_id, run_id)
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text("DELETE FROM leaudit_rescue_outcomes WHERE run_id = :rid"),
|
||||
{"rid": resolved_run_id},
|
||||
)
|
||||
for task in tasks:
|
||||
final_status = "review" if task.requires_human_review else ("pass" if task.final_status == "pass" else "fail")
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
INSERT INTO leaudit_rescue_outcomes (
|
||||
run_id,
|
||||
document_id,
|
||||
rule_id,
|
||||
status,
|
||||
diagnosis,
|
||||
diagnosis_confidence,
|
||||
final_status,
|
||||
failure_reason,
|
||||
llm_calls,
|
||||
vlm_calls,
|
||||
duration_ms,
|
||||
requires_human_review,
|
||||
payload,
|
||||
create_time,
|
||||
update_time
|
||||
) VALUES (
|
||||
:run_id,
|
||||
:document_id,
|
||||
:rule_id,
|
||||
:status,
|
||||
:diagnosis,
|
||||
:diagnosis_confidence,
|
||||
:final_status,
|
||||
:failure_reason,
|
||||
:llm_calls,
|
||||
:vlm_calls,
|
||||
:duration_ms,
|
||||
:requires_human_review,
|
||||
:payload,
|
||||
:create_time,
|
||||
:update_time
|
||||
)
|
||||
"""
|
||||
),
|
||||
{
|
||||
"run_id": resolved_run_id,
|
||||
"document_id": document_id,
|
||||
"rule_id": task.rule_id,
|
||||
"status": task.status.value,
|
||||
"diagnosis": task.diagnosis.value if task.diagnosis else None,
|
||||
"diagnosis_confidence": task.diagnosis_confidence,
|
||||
"final_status": final_status,
|
||||
"failure_reason": task.failure_reason,
|
||||
"llm_calls": task.llm_calls,
|
||||
"vlm_calls": task.vlm_calls,
|
||||
"duration_ms": task.duration_ms,
|
||||
"requires_human_review": task.requires_human_review,
|
||||
"payload": task.model_dump(mode="json"),
|
||||
"create_time": task.created_at,
|
||||
"update_time": task.updated_at,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def finalize_run(
|
||||
self,
|
||||
document_id: int,
|
||||
*,
|
||||
run_id: int | None,
|
||||
result_status: str,
|
||||
rescue_applied: bool,
|
||||
phase: str | None = None,
|
||||
finished: bool = True,
|
||||
) -> None:
|
||||
"""更新运行主表的最终摘要状态。"""
|
||||
resolved_run_id = await self._ensure_run_id(document_id, run_id)
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE leaudit_audit_runs
|
||||
SET
|
||||
phase = COALESCE(:phase, phase),
|
||||
rescue_applied = :rescue_applied,
|
||||
result_status = :result_status,
|
||||
finished_at = CASE WHEN :finished THEN now() ELSE finished_at END,
|
||||
update_time = now()
|
||||
WHERE id = :rid
|
||||
"""
|
||||
),
|
||||
{
|
||||
"phase": phase,
|
||||
"rescue_applied": rescue_applied,
|
||||
"result_status": result_status,
|
||||
"finished": finished,
|
||||
"rid": resolved_run_id,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
async def fail_run(
|
||||
self,
|
||||
document_id: int,
|
||||
*,
|
||||
run_id: int | None,
|
||||
phase: str | None,
|
||||
message: str,
|
||||
detail_json: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""记录运行失败并更新主表。"""
|
||||
resolved_run_id = await self._ensure_run_id(document_id, run_id)
|
||||
await self.save_run_errors(
|
||||
document_id,
|
||||
run_id=resolved_run_id,
|
||||
stage=phase or "persist",
|
||||
messages=[message],
|
||||
level="fatal",
|
||||
error_code="AUDIT_RUN_FAILED",
|
||||
detail_json=detail_json,
|
||||
)
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE leaudit_audit_runs
|
||||
SET
|
||||
status = 'failed',
|
||||
phase = COALESCE(:phase, phase),
|
||||
result_status = 'error',
|
||||
finished_at = now(),
|
||||
update_time = now()
|
||||
WHERE id = :rid
|
||||
"""
|
||||
),
|
||||
{
|
||||
"phase": phase,
|
||||
"rid": resolved_run_id,
|
||||
},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
|
||||
# ---- Serialization helpers ----
|
||||
|
||||
@@ -305,11 +644,12 @@ def _extract_relevant_field_positions(
|
||||
|
||||
def _rule_result_to_row(
|
||||
document_id: int,
|
||||
run_id: int | None,
|
||||
rule_result: RuleResult,
|
||||
rule: Any | None,
|
||||
bundle: ExtractionBundle,
|
||||
) -> dict[str, Any]:
|
||||
"""Convert a RuleResult to a leaudit_evaluation_results row."""
|
||||
"""Convert a RuleResult to a leaudit_rule_results row."""
|
||||
passed = rule_result.passed
|
||||
|
||||
# Resolve messages: rule_result → rule definition
|
||||
@@ -345,6 +685,7 @@ def _rule_result_to_row(
|
||||
rule_meta_data["group"] = rule.group
|
||||
|
||||
return {
|
||||
"run_id": run_id,
|
||||
"document_id": document_id,
|
||||
"rule_id": rule_result.rule_id,
|
||||
"rule_name": rule_result.name,
|
||||
|
||||
Reference in New Issue
Block a user