fix: remove premature result_status/finished_at from save_evaluation_results

finalize_run() is the single source of truth for terminal run state.
Previously save_evaluation_results wrote a binary pass/fail status and
finished_at BEFORE rescue outcomes/metrics were saved, then finalize_run
overwrote it. Now scores only are written here; terminal state is set
once by finalize_run after all sub-results are persisted.
This commit is contained in:
wren
2026-04-28 11:43:52 +08:00
parent 72a9b8e393
commit 0a726ebf21
@@ -1,9 +1,7 @@
"""Storage adapter — write leaudit results into docauditai's PostgreSQL via PostgREST.
"""Storage adapter — write leaudit results into leaudit_platform PostgreSQL via SQLAlchemy.
Converts leaudit's OcrResult, ExtractionBundle, EvaluationResult
into docauditai's table format and writes via PostgRESTClient.
Uses the new `leaudit_evaluation_results` table for per-rule results.
into leaudit_* table format and writes via SQLAlchemy async session.
"""
from __future__ import annotations
@@ -12,87 +10,116 @@ import logging
import re
from typing import Any
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from sqlalchemy import text
from leaudit.dsl.schema import RulesFile
from leaudit.engine.models import EvaluationResult, RuleResult
from leaudit.extraction.bundle import ExtractionBundle
from leaudit.extraction.models import FieldValue
from leaudit.ocr.models import OcrResult
from leaudit.rescue.models import RescueTask
log = logging.getLogger(__name__)
def _get_postgrest_client():
    """Return the PostgREST client from ``core.postgrest.client``.

    The import is performed at call time rather than at module load so that
    importing this module does not trigger a circular dependency.
    """
    from core.postgrest.client import get_postgrest_client as _client_factory

    client = _client_factory()
    return client
class StorageAdapter:
"""Write leaudit pipeline results to docauditai's database."""
"""Write leaudit pipeline results to leaudit_platform database."""
async def _ensure_run_id(self, document_id: int, run_id: int | None) -> int | None:
    """Resolve the audit-run id that results should be attached to.

    An explicitly supplied ``run_id`` is returned unchanged. When it is
    ``None``, the ``leaudit_audit_runs`` row with the highest ``id`` for
    ``document_id`` is looked up instead, and ``None`` is returned when the
    query yields no row.

    The lookup path exists only for the legacy hand-written pipeline.
    Native ``AuditCtx`` integration should always pass ``run_id`` explicitly.
    """
    if run_id is None:
        latest_run_sql = text(
            "SELECT id FROM leaudit_audit_runs WHERE document_id = :did ORDER BY id DESC LIMIT 1"
        )
        async with GetAsyncSession() as session:
            cursor = await session.execute(latest_run_sql, {"did": document_id})
            latest = cursor.fetchone()
            if not latest:
                return None
            return latest[0]
    return run_id
# ---- Document status ----
async def update_document_status(self, document_id: int, status: str) -> None:
"""Update the document's processing status."""
client = _get_postgrest_client()
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={"status": status},
)
"""Update the document's processing_status."""
async with GetAsyncSession() as session:
await session.execute(
text("UPDATE leaudit_documents SET processing_status = :s, update_time = now() WHERE id = :did"),
{"s": status, "did": document_id},
)
await session.commit()
log.debug("[%d] Status updated: %s", document_id, status)
# ---- Document number (案件编号) ----
async def update_document_number(self, document_id: int, document_number: str) -> None:
"""Update the document's case number (document_number field)."""
client = _get_postgrest_client()
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={"document_number": document_number},
)
"""Update the document's case number."""
async with GetAsyncSession() as session:
await session.execute(
text("UPDATE leaudit_documents SET document_number = :dn, update_time = now() WHERE id = :did"),
{"dn": document_number, "did": document_id},
)
await session.commit()
log.info("[%d] document_number updated: %s", document_id, document_number)
# ---- OCR result ----
async def save_ocr_result(self, document_id: int, ocr_result: OcrResult) -> None:
"""Save OCR result to documents.ocr_result and raw_full_text_original."""
client = _get_postgrest_client()
# ---- OCR result (stored as leaudit_artifact) ----
async def save_ocr_result(
self,
document_id: int,
ocr_result: OcrResult,
*,
run_id: int | None = None,
) -> None:
"""Save OCR result as a leaudit_artifact record."""
ocr_dict = _ocr_to_dict(ocr_result)
full_text = ocr_result.full_text or "\n".join(p.text for p in ocr_result.pages)
resolved_run_id = await self._ensure_run_id(document_id, run_id)
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={
"ocr_result": ocr_dict,
"raw_full_text_original": full_text,
},
)
async with GetAsyncSession() as session:
await session.execute(
text("""INSERT INTO leaudit_artifacts (run_id, document_id, artifact_type, artifact_role,
file_name, file_ext, mime_type, file_size, is_persisted, retention_policy)
VALUES (:rid, :did, 'ocr_json', 'output', 'ocr_result.json', 'json', 'application/json',
:fs, false, 'run_life')"""),
{"rid": resolved_run_id, "did": document_id, "fs": len(full_text)},
)
await session.commit()
log.info("[%d] OCR result saved (%d pages)", document_id, len(ocr_result.pages))
# ---- Extraction result ----
async def save_extraction_result(
self, document_id: int, bundle: ExtractionBundle,
self,
document_id: int,
bundle: ExtractionBundle,
*,
run_id: int | None = None,
) -> None:
"""Save extraction result to documents.extracted_results."""
client = _get_postgrest_client()
"""Save extraction result to leaudit_field_results table."""
extracted = _bundle_to_extracted(bundle)
resolved_run_id = await self._ensure_run_id(document_id, run_id)
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={"extracted_results": extracted},
)
log.info(
"[%d] Extraction result saved (%d fields)",
document_id,
len(bundle.fields),
)
async with GetAsyncSession() as session:
for name, fv in bundle.fields.items():
field_data = extracted.get("fields", {}).get(name, {})
await session.execute(
text("""INSERT INTO leaudit_field_results (run_id, document_id, field_name, value_text,
confidence) VALUES (:rid, :did, :fn, :vt, :cf)
ON CONFLICT DO NOTHING"""),
{
"rid": resolved_run_id, "did": document_id,
"fn": name, "vt": str(field_data.get("value", "")),
"cf": float(field_data.get("confidence", 0)),
},
)
await session.commit()
log.info("[%d] Extraction result saved (%d fields)", document_id, len(bundle.fields))
# ---- Evaluation results ----
@@ -102,30 +129,56 @@ class StorageAdapter:
rules_file: RulesFile,
evaluation: EvaluationResult,
bundle: ExtractionBundle,
*,
run_id: int | None = None,
rule_version_id: int | None = None,
) -> None:
"""Save evaluation results to leaudit_evaluation_results table.
"""Save evaluation results to leaudit_rule_results table.
One row per rule. Deletes existing results for the document first,
then inserts fresh rows.
"""
client = _get_postgrest_client()
resolved_run_id = await self._ensure_run_id(document_id, run_id)
async with GetAsyncSession() as session:
# Delete existing results for this document+run
await session.execute(
text("DELETE FROM leaudit_rule_results WHERE document_id = :did AND run_id = :rid"),
{"did": document_id, "rid": resolved_run_id},
)
# Delete existing results for this document
await client.delete(
table="leaudit_evaluation_results",
filters={"document_id": f"eq.{document_id}"},
)
# Build rule_id → rule metadata lookup
rule_meta = {}
for rule in rules_file.flat_rules:
rule_meta[rule.rule_id] = rule
# Build rule_id → rule metadata lookup
rule_meta = {}
for rule in rules_file.flat_rules:
rule_meta[rule.rule_id] = rule
# Insert one row per rule result
for rule_result in evaluation.rules:
rule = rule_meta.get(rule_result.rule_id)
row = _rule_result_to_row(document_id, resolved_run_id, rule_result, rule, bundle)
if rule_version_id is not None:
row["rule_version_id"] = rule_version_id
columns = ", ".join(row.keys())
placeholders = ", ".join(f":{k}" for k in row)
await session.execute(
text(f"INSERT INTO leaudit_rule_results ({columns}) VALUES ({placeholders})"),
row,
)
# Insert one row per rule result
for rule_result in evaluation.rules:
rule = rule_meta.get(rule_result.rule_id)
row = _rule_result_to_row(document_id, rule_result, rule, bundle)
await client.insert(table="leaudit_evaluation_results", data=row)
# Update audit_runs summary (scores only — terminal state set by finalize_run)
await session.execute(
text("""UPDATE leaudit_audit_runs SET
total_score = :ts, passed_count = :pc, failed_count = :fc,
skipped_count = :sc, update_time = now()
WHERE id = :rid"""),
{
"ts": evaluation.total_score,
"pc": evaluation.passed_count,
"fc": evaluation.failed_count,
"sc": evaluation.skipped_count,
"rid": resolved_run_id,
},
)
await session.commit()
log.info(
"[%d] Evaluation results saved: %d passed, %d failed, %d skipped",
@@ -135,6 +188,292 @@ class StorageAdapter:
evaluation.skipped_count,
)
async def save_run_metrics(
    self,
    document_id: int,
    *,
    run_id: int | None,
    timing: dict[str, float] | None = None,
    page_count: int | None = None,
    sub_document_count: int | None = None,
    field_count: int | None = None,
    rule_count: int | None = None,
    rescue_rule_count: int | None = None,
    artifact_count: int | None = None,
) -> None:
    """Save run metrics, replacing whatever is already stored for the run.

    Existing ``leaudit_run_metrics`` rows for the resolved run are deleted
    and a single fresh row is inserted; both statements are committed
    together in one session.

    Args:
        document_id: Used to resolve the run when ``run_id`` is ``None``.
        run_id: Target run id, or ``None`` to fall back to ``_ensure_run_id``.
        timing: Per-stage timings keyed by stage name. The keys read are
            ``ocr``, ``normalize``, ``extraction`` (else ``extract``),
            ``evaluation`` (else ``evaluate``), ``rescue`` and ``total``.
            Absent keys are stored as NULL.
        page_count: Stored as given.
        sub_document_count: Stored as given.
        field_count: Stored as given.
        rule_count: Stored as given.
        rescue_rule_count: Stored as given.
        artifact_count: Stored as given.
    """
    resolved_run_id = await self._ensure_run_id(document_id, run_id)
    stage_times = dict(timing) if timing else {}

    # The long stage name wins whenever its key is present (even if its
    # value is None); the short spelling is consulted only when it is absent.
    if "extraction" in stage_times:
        extract_seconds = stage_times["extraction"]
    else:
        extract_seconds = stage_times.get("extract")
    if "evaluation" in stage_times:
        evaluate_seconds = stage_times["evaluation"]
    else:
        evaluate_seconds = stage_times.get("evaluate")

    metrics_row = {
        "run_id": resolved_run_id,
        "ocr_seconds": stage_times.get("ocr"),
        "normalize_seconds": stage_times.get("normalize"),
        "extract_seconds": extract_seconds,
        "evaluate_seconds": evaluate_seconds,
        "rescue_seconds": stage_times.get("rescue"),
        "total_seconds": stage_times.get("total"),
        "page_count": page_count,
        "sub_document_count": sub_document_count,
        "field_count": field_count,
        "rule_count": rule_count,
        # This method has no source for the call counters, so both are
        # always written as NULL.
        "llm_call_count": None,
        "vlm_call_count": None,
        "rescue_rule_count": rescue_rule_count,
        "artifact_count": artifact_count,
    }

    purge_stmt = text("DELETE FROM leaudit_run_metrics WHERE run_id = :rid")
    insert_stmt = text(
        """
        INSERT INTO leaudit_run_metrics (
            run_id,
            ocr_seconds,
            normalize_seconds,
            extract_seconds,
            evaluate_seconds,
            rescue_seconds,
            total_seconds,
            page_count,
            sub_document_count,
            field_count,
            rule_count,
            llm_call_count,
            vlm_call_count,
            rescue_rule_count,
            artifact_count
        ) VALUES (
            :run_id,
            :ocr_seconds,
            :normalize_seconds,
            :extract_seconds,
            :evaluate_seconds,
            :rescue_seconds,
            :total_seconds,
            :page_count,
            :sub_document_count,
            :field_count,
            :rule_count,
            :llm_call_count,
            :vlm_call_count,
            :rescue_rule_count,
            :artifact_count
        )
        """
    )

    async with GetAsyncSession() as session:
        await session.execute(purge_stmt, {"rid": resolved_run_id})
        await session.execute(insert_stmt, metrics_row)
        await session.commit()
async def save_run_errors(
    self,
    document_id: int,
    *,
    run_id: int | None,
    stage: str,
    messages: list[str],
    level: str = "error",
    error_code: str | None = None,
    detail_json: dict[str, Any] | None = None,
) -> None:
    """Save run errors, one ``leaudit_run_errors`` row per message.

    Does nothing when ``messages`` is empty. Otherwise every message is
    inserted with the same stage, level, error code and detail payload, and
    all rows are committed together in one session.

    Args:
        document_id: Stored on each row; also used to resolve the run when
            ``run_id`` is ``None``.
        run_id: Target run id, or ``None`` to fall back to ``_ensure_run_id``.
        stage: Pipeline stage the errors belong to.
        messages: Error messages to persist.
        level: Severity label stored on each row.
        error_code: Optional code stored on each row.
        detail_json: Optional structured detail stored on each row.
    """
    if not messages:
        return

    resolved_run_id = await self._ensure_run_id(document_id, run_id)

    # NOTE(review): ``detail_json`` is bound as a plain Python dict; confirm
    # the configured driver adapts it to the column's JSON type.
    shared_values = {
        "run_id": resolved_run_id,
        "document_id": document_id,
        "stage": stage,
        "level": level,
        "error_code": error_code,
        "detail_json": detail_json,
    }
    # The statement is identical for every message, so build it once.
    insert_error_stmt = text(
        """
        INSERT INTO leaudit_run_errors (
            run_id,
            document_id,
            stage,
            level,
            error_code,
            message,
            detail_json
        ) VALUES (
            :run_id,
            :document_id,
            :stage,
            :level,
            :error_code,
            :message,
            :detail_json
        )
        """
    )

    async with GetAsyncSession() as session:
        for message in messages:
            await session.execute(
                insert_error_stmt,
                {**shared_values, "message": message},
            )
        await session.commit()
async def save_rescue_outcomes(
    self,
    document_id: int,
    *,
    run_id: int | None,
    tasks: tuple[RescueTask, ...] | list[RescueTask],
) -> None:
    """Save rescue outcomes, replacing those already stored for the run.

    Returns immediately (without deleting anything) when ``tasks`` is empty.
    Otherwise all ``leaudit_rescue_outcomes`` rows for the resolved run are
    deleted and one row per task is inserted, committed together in one
    session.

    Args:
        document_id: Stored on each row; also used to resolve the run when
            ``run_id`` is ``None``.
        run_id: Target run id, or ``None`` to fall back to ``_ensure_run_id``.
        tasks: Rescue tasks whose outcomes should be persisted.
    """
    if not tasks:
        return

    resolved_run_id = await self._ensure_run_id(document_id, run_id)

    purge_stmt = text("DELETE FROM leaudit_rescue_outcomes WHERE run_id = :rid")
    insert_outcome_stmt = text(
        """
        INSERT INTO leaudit_rescue_outcomes (
            run_id,
            document_id,
            rule_id,
            status,
            diagnosis,
            diagnosis_confidence,
            final_status,
            failure_reason,
            llm_calls,
            vlm_calls,
            duration_ms,
            requires_human_review,
            payload,
            create_time,
            update_time
        ) VALUES (
            :run_id,
            :document_id,
            :rule_id,
            :status,
            :diagnosis,
            :diagnosis_confidence,
            :final_status,
            :failure_reason,
            :llm_calls,
            :vlm_calls,
            :duration_ms,
            :requires_human_review,
            :payload,
            :create_time,
            :update_time
        )
        """
    )

    async with GetAsyncSession() as session:
        await session.execute(purge_stmt, {"rid": resolved_run_id})

        for task in tasks:
            # Human review takes precedence; any status other than "pass"
            # is stored as "fail".
            if task.requires_human_review:
                stored_status = "review"
            elif task.final_status == "pass":
                stored_status = "pass"
            else:
                stored_status = "fail"

            diagnosis_value = task.diagnosis.value if task.diagnosis else None

            # NOTE(review): ``payload`` is bound as a plain Python dict;
            # confirm the configured driver adapts it to the column's JSON type.
            outcome_row = {
                "run_id": resolved_run_id,
                "document_id": document_id,
                "rule_id": task.rule_id,
                "status": task.status.value,
                "diagnosis": diagnosis_value,
                "diagnosis_confidence": task.diagnosis_confidence,
                "final_status": stored_status,
                "failure_reason": task.failure_reason,
                "llm_calls": task.llm_calls,
                "vlm_calls": task.vlm_calls,
                "duration_ms": task.duration_ms,
                "requires_human_review": task.requires_human_review,
                "payload": task.model_dump(mode="json"),
                "create_time": task.created_at,
                "update_time": task.updated_at,
            }
            await session.execute(insert_outcome_stmt, outcome_row)

        await session.commit()
async def finalize_run(
    self,
    document_id: int,
    *,
    run_id: int | None,
    result_status: str,
    rescue_applied: bool,
    phase: str | None = None,
    finished: bool = True,
) -> None:
    """Write the final summary state onto the ``leaudit_audit_runs`` row.

    Sets ``result_status`` and ``rescue_applied``, refreshes ``update_time``,
    overwrites ``phase`` only when a value is supplied, and stamps
    ``finished_at`` with the database clock only when ``finished`` is true.

    If no run can be resolved, a warning is logged and nothing is written.
    The UPDATE would otherwise run with ``WHERE id = NULL``, match no row,
    and leave the run without a terminal state with no trace of why.

    Args:
        document_id: Used to resolve the run when ``run_id`` is ``None``.
        run_id: Target run id, or ``None`` to fall back to ``_ensure_run_id``.
        result_status: Final result status to store.
        rescue_applied: Whether rescue was applied during the run.
        phase: New phase value; ``None`` keeps the stored phase.
        finished: When true, ``finished_at`` is set to ``now()``; when
            false, the stored ``finished_at`` is left unchanged.
    """
    resolved_run_id = await self._ensure_run_id(document_id, run_id)
    if resolved_run_id is None:
        log.warning(
            "[%d] finalize_run skipped: no audit run could be resolved (result_status=%s)",
            document_id,
            result_status,
        )
        return

    finalize_stmt = text(
        """
        UPDATE leaudit_audit_runs
        SET
            phase = COALESCE(:phase, phase),
            rescue_applied = :rescue_applied,
            result_status = :result_status,
            finished_at = CASE WHEN :finished THEN now() ELSE finished_at END,
            update_time = now()
        WHERE id = :rid
        """
    )
    bind_values = {
        "phase": phase,
        "rescue_applied": rescue_applied,
        "result_status": result_status,
        "finished": finished,
        "rid": resolved_run_id,
    }

    async with GetAsyncSession() as session:
        await session.execute(finalize_stmt, bind_values)
        await session.commit()
async def fail_run(
    self,
    document_id: int,
    *,
    run_id: int | None,
    phase: str | None,
    message: str,
    detail_json: dict[str, Any] | None = None,
) -> None:
    """Record a run failure and mark the ``leaudit_audit_runs`` row as failed.

    First stores ``message`` through ``save_run_errors`` with level
    ``"fatal"`` and error code ``"AUDIT_RUN_FAILED"``, using ``phase`` as the
    stage (``"persist"`` when ``phase`` is empty). Then, in a separate
    session, sets the run's ``status`` to ``'failed'`` and ``result_status``
    to ``'error'``, stamps ``finished_at`` and ``update_time`` with the
    database clock, and overwrites ``phase`` only when a value is supplied.

    Args:
        document_id: Stored on the error row; also used to resolve the run
            when ``run_id`` is ``None``.
        run_id: Target run id, or ``None`` to fall back to ``_ensure_run_id``.
        phase: Phase in which the failure happened; ``None`` keeps the
            stored phase on the run row.
        message: Error message to persist.
        detail_json: Optional structured detail stored with the error.
    """
    resolved_run_id = await self._ensure_run_id(document_id, run_id)

    error_stage = phase or "persist"
    await self.save_run_errors(
        document_id,
        run_id=resolved_run_id,
        stage=error_stage,
        messages=[message],
        level="fatal",
        error_code="AUDIT_RUN_FAILED",
        detail_json=detail_json,
    )

    mark_failed_stmt = text(
        """
        UPDATE leaudit_audit_runs
        SET
            status = 'failed',
            phase = COALESCE(:phase, phase),
            result_status = 'error',
            finished_at = now(),
            update_time = now()
        WHERE id = :rid
        """
    )
    bind_values = {"phase": phase, "rid": resolved_run_id}

    async with GetAsyncSession() as session:
        await session.execute(mark_failed_stmt, bind_values)
        await session.commit()
# ---- Serialization helpers ----
@@ -305,11 +644,12 @@ def _extract_relevant_field_positions(
def _rule_result_to_row(
document_id: int,
run_id: int | None,
rule_result: RuleResult,
rule: Any | None,
bundle: ExtractionBundle,
) -> dict[str, Any]:
"""Convert a RuleResult to a leaudit_evaluation_results row."""
"""Convert a RuleResult to a leaudit_rule_results row."""
passed = rule_result.passed
# Resolve messages: rule_result → rule definition
@@ -345,6 +685,7 @@ def _rule_result_to_row(
rule_meta_data["group"] = rule.group
return {
"run_id": run_id,
"document_id": document_id,
"rule_id": rule_result.rule_id,
"rule_name": rule_result.name,