diff --git a/fastapi_modules/fastapi_leaudit/leaudit_bridge/storage_adapter.py b/fastapi_modules/fastapi_leaudit/leaudit_bridge/storage_adapter.py index 391d29d..0a3a414 100644 --- a/fastapi_modules/fastapi_leaudit/leaudit_bridge/storage_adapter.py +++ b/fastapi_modules/fastapi_leaudit/leaudit_bridge/storage_adapter.py @@ -1,9 +1,7 @@ -"""Storage adapter — write leaudit results into docauditai's PostgreSQL via PostgREST. +"""Storage adapter — write leaudit results into leaudit_platform PostgreSQL via SQLAlchemy. Converts leaudit's OcrResult, ExtractionBundle, EvaluationResult -into docauditai's table format and writes via PostgRESTClient. - -Uses the new `leaudit_evaluation_results` table for per-rule results. +into leaudit_* table format and writes via SQLAlchemy async session. """ from __future__ import annotations @@ -12,87 +10,116 @@ import logging import re from typing import Any +from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession +from sqlalchemy import text + from leaudit.dsl.schema import RulesFile from leaudit.engine.models import EvaluationResult, RuleResult from leaudit.extraction.bundle import ExtractionBundle from leaudit.extraction.models import FieldValue from leaudit.ocr.models import OcrResult +from leaudit.rescue.models import RescueTask log = logging.getLogger(__name__) -def _get_postgrest_client(): - """Lazy import to avoid circular dependency at module load.""" - from core.postgrest.client import get_postgrest_client - return get_postgrest_client() - - class StorageAdapter: - """Write leaudit pipeline results to docauditai's database.""" + """Write leaudit pipeline results to leaudit_platform database.""" + + async def _ensure_run_id(self, document_id: int, run_id: int | None) -> int | None: + """Return explicit ``run_id`` when given, otherwise fall back to latest run. + + The fallback path exists only for the legacy hand-written pipeline. + Native ``AuditCtx`` integration should always pass ``run_id`` explicitly. + """ + if run_id is not None: + return run_id + + async with GetAsyncSession() as session: + result = await session.execute( + text("SELECT id FROM leaudit_audit_runs WHERE document_id = :did ORDER BY id DESC LIMIT 1"), + {"did": document_id}, + ) + row = result.fetchone() + return row[0] if row else None # ---- Document status ---- async def update_document_status(self, document_id: int, status: str) -> None: - """Update the document's processing status.""" - client = _get_postgrest_client() - await client.update( - table="documents", - filters={"id": f"eq.{document_id}"}, - data={"status": status}, - ) + """Update the document's processing_status.""" + async with GetAsyncSession() as session: + await session.execute( + text("UPDATE leaudit_documents SET processing_status = :s, update_time = now() WHERE id = :did"), + {"s": status, "did": document_id}, + ) + await session.commit() log.debug("[%d] Status updated: %s", document_id, status) # ---- Document number (案件编号) ---- async def update_document_number(self, document_id: int, document_number: str) -> None: - """Update the document's case number (document_number field).""" - client = _get_postgrest_client() - await client.update( - table="documents", - filters={"id": f"eq.{document_id}"}, - data={"document_number": document_number}, - ) + """Update the document's case number.""" + async with GetAsyncSession() as session: + await session.execute( + text("UPDATE leaudit_documents SET document_number = :dn, update_time = now() WHERE id = :did"), + {"dn": document_number, "did": document_id}, + ) + await session.commit() log.info("[%d] document_number updated: %s", document_id, document_number) - # ---- OCR result ---- - - async def save_ocr_result(self, document_id: int, ocr_result: OcrResult) -> None: - """Save OCR result to documents.ocr_result and raw_full_text_original.""" - client = _get_postgrest_client() + # ---- OCR result (stored as leaudit_artifact) ---- + async def save_ocr_result( + self, + document_id: int, + ocr_result: OcrResult, + *, + run_id: int | None = None, + ) -> None: + """Save OCR result as a leaudit_artifact record.""" ocr_dict = _ocr_to_dict(ocr_result) full_text = ocr_result.full_text or "\n".join(p.text for p in ocr_result.pages) + resolved_run_id = await self._ensure_run_id(document_id, run_id) - await client.update( - table="documents", - filters={"id": f"eq.{document_id}"}, - data={ - "ocr_result": ocr_dict, - "raw_full_text_original": full_text, - }, - ) + async with GetAsyncSession() as session: + await session.execute( + text("""INSERT INTO leaudit_artifacts (run_id, document_id, artifact_type, artifact_role, + file_name, file_ext, mime_type, file_size, is_persisted, retention_policy) + VALUES (:rid, :did, 'ocr_json', 'output', 'ocr_result.json', 'json', 'application/json', + :fs, false, 'run_life')"""), + {"rid": resolved_run_id, "did": document_id, "fs": len(full_text)}, + ) + await session.commit() log.info("[%d] OCR result saved (%d pages)", document_id, len(ocr_result.pages)) # ---- Extraction result ---- async def save_extraction_result( - self, document_id: int, bundle: ExtractionBundle, + self, + document_id: int, + bundle: ExtractionBundle, + *, + run_id: int | None = None, ) -> None: - """Save extraction result to documents.extracted_results.""" - client = _get_postgrest_client() - + """Save extraction result to leaudit_field_results table.""" extracted = _bundle_to_extracted(bundle) + resolved_run_id = await self._ensure_run_id(document_id, run_id) - await client.update( - table="documents", - filters={"id": f"eq.{document_id}"}, - data={"extracted_results": extracted}, - ) - log.info( - "[%d] Extraction result saved (%d fields)", - document_id, - len(bundle.fields), - ) + async with GetAsyncSession() as session: + for name, fv in bundle.fields.items(): + field_data = extracted.get("fields", {}).get(name, {}) + await session.execute( + text("""INSERT INTO leaudit_field_results (run_id, document_id, field_name, value_text, + confidence) VALUES (:rid, :did, :fn, :vt, :cf) + ON CONFLICT DO NOTHING"""), + { + "rid": resolved_run_id, "did": document_id, + "fn": name, "vt": str(field_data.get("value", "")), + "cf": float(field_data.get("confidence", 0)), + }, + ) + await session.commit() + log.info("[%d] Extraction result saved (%d fields)", document_id, len(bundle.fields)) # ---- Evaluation results ---- @@ -102,30 +129,56 @@ class StorageAdapter: rules_file: RulesFile, evaluation: EvaluationResult, bundle: ExtractionBundle, + *, + run_id: int | None = None, + rule_version_id: int | None = None, ) -> None: - """Save evaluation results to leaudit_evaluation_results table. + """Save evaluation results to leaudit_rule_results table. One row per rule. Deletes existing results for the document first, then inserts fresh rows. """ - client = _get_postgrest_client() + resolved_run_id = await self._ensure_run_id(document_id, run_id) + async with GetAsyncSession() as session: + # Delete existing results for this document+run + await session.execute( + text("DELETE FROM leaudit_rule_results WHERE document_id = :did AND run_id = :rid"), + {"did": document_id, "rid": resolved_run_id}, + ) - # Delete existing results for this document - await client.delete( - table="leaudit_evaluation_results", - filters={"document_id": f"eq.{document_id}"}, - ) + # Build rule_id → rule metadata lookup + rule_meta = {} + for rule in rules_file.flat_rules: + rule_meta[rule.rule_id] = rule - # Build rule_id → rule metadata lookup - rule_meta = {} - for rule in rules_file.flat_rules: - rule_meta[rule.rule_id] = rule + # Insert one row per rule result + for rule_result in evaluation.rules: + rule = rule_meta.get(rule_result.rule_id) + row = _rule_result_to_row(document_id, resolved_run_id, rule_result, rule, bundle) + if rule_version_id is not None: + row["rule_version_id"] = rule_version_id + columns = ", ".join(row.keys()) + placeholders = ", ".join(f":{k}" for k in row) + await session.execute( + text(f"INSERT INTO leaudit_rule_results ({columns}) VALUES ({placeholders})"), + row, + ) - # Insert one row per rule result - for rule_result in evaluation.rules: - rule = rule_meta.get(rule_result.rule_id) - row = _rule_result_to_row(document_id, rule_result, rule, bundle) - await client.insert(table="leaudit_evaluation_results", data=row) + # Update audit_runs summary (scores only — terminal state set by finalize_run) + await session.execute( + text("""UPDATE leaudit_audit_runs SET + total_score = :ts, passed_count = :pc, failed_count = :fc, + skipped_count = :sc, update_time = now() + WHERE id = :rid"""), + { + "ts": evaluation.total_score, + "pc": evaluation.passed_count, + "fc": evaluation.failed_count, + "sc": evaluation.skipped_count, + "rid": resolved_run_id, + }, + ) + await session.commit() log.info( "[%d] Evaluation results saved: %d passed, %d failed, %d skipped", @@ -135,6 +188,292 @@ class StorageAdapter: evaluation.skipped_count, ) + async def save_run_metrics( + self, + document_id: int, + *, + run_id: int | None, + timing: dict[str, float] | None = None, + page_count: int | None = None, + sub_document_count: int | None = None, + field_count: int | None = None, + rule_count: int | None = None, + rescue_rule_count: int | None = None, + artifact_count: int | None = None, + ) -> None: + """保存运行指标。""" + resolved_run_id = await self._ensure_run_id(document_id, run_id) + metric = dict(timing or {}) + + async with GetAsyncSession() as session: + await session.execute( + text("DELETE FROM leaudit_run_metrics WHERE run_id = :rid"), + {"rid": resolved_run_id}, + ) + await session.execute( + text( + """ + INSERT INTO leaudit_run_metrics ( + run_id, + ocr_seconds, + normalize_seconds, + extract_seconds, + evaluate_seconds, + rescue_seconds, + total_seconds, + page_count, + sub_document_count, + field_count, + rule_count, + llm_call_count, + vlm_call_count, + rescue_rule_count, + artifact_count + ) VALUES ( + :run_id, + :ocr_seconds, + :normalize_seconds, + :extract_seconds, + :evaluate_seconds, + :rescue_seconds, + :total_seconds, + :page_count, + :sub_document_count, + :field_count, + :rule_count, + :llm_call_count, + :vlm_call_count, + :rescue_rule_count, + :artifact_count + ) + """ + ), + { + "run_id": resolved_run_id, + "ocr_seconds": metric.get("ocr"), + "normalize_seconds": metric.get("normalize"), + "extract_seconds": metric.get("extraction", metric.get("extract")), + "evaluate_seconds": metric.get("evaluation", metric.get("evaluate")), + "rescue_seconds": metric.get("rescue"), + "total_seconds": metric.get("total"), + "page_count": page_count, + "sub_document_count": sub_document_count, + "field_count": field_count, + "rule_count": rule_count, + "llm_call_count": None, + "vlm_call_count": None, + "rescue_rule_count": rescue_rule_count, + "artifact_count": artifact_count, + }, + ) + await session.commit() + + async def save_run_errors( + self, + document_id: int, + *, + run_id: int | None, + stage: str, + messages: list[str], + level: str = "error", + error_code: str | None = None, + detail_json: dict[str, Any] | None = None, + ) -> None: + """保存运行错误。""" + if not messages: + return + + resolved_run_id = await self._ensure_run_id(document_id, run_id) + async with GetAsyncSession() as session: + for message in messages: + await session.execute( + text( + """ + INSERT INTO leaudit_run_errors ( + run_id, + document_id, + stage, + level, + error_code, + message, + detail_json + ) VALUES ( + :run_id, + :document_id, + :stage, + :level, + :error_code, + :message, + :detail_json + ) + """ + ), + { + "run_id": resolved_run_id, + "document_id": document_id, + "stage": stage, + "level": level, + "error_code": error_code, + "message": message, + "detail_json": detail_json, + }, + ) + await session.commit() + + async def save_rescue_outcomes( + self, + document_id: int, + *, + run_id: int | None, + tasks: tuple[RescueTask, ...] | list[RescueTask], + ) -> None: + """保存补救结果。""" + if not tasks: + return + + resolved_run_id = await self._ensure_run_id(document_id, run_id) + async with GetAsyncSession() as session: + await session.execute( + text("DELETE FROM leaudit_rescue_outcomes WHERE run_id = :rid"), + {"rid": resolved_run_id}, + ) + for task in tasks: + final_status = "review" if task.requires_human_review else ("pass" if task.final_status == "pass" else "fail") + await session.execute( + text( + """ + INSERT INTO leaudit_rescue_outcomes ( + run_id, + document_id, + rule_id, + status, + diagnosis, + diagnosis_confidence, + final_status, + failure_reason, + llm_calls, + vlm_calls, + duration_ms, + requires_human_review, + payload, + create_time, + update_time + ) VALUES ( + :run_id, + :document_id, + :rule_id, + :status, + :diagnosis, + :diagnosis_confidence, + :final_status, + :failure_reason, + :llm_calls, + :vlm_calls, + :duration_ms, + :requires_human_review, + :payload, + :create_time, + :update_time + ) + """ + ), + { + "run_id": resolved_run_id, + "document_id": document_id, + "rule_id": task.rule_id, + "status": task.status.value, + "diagnosis": task.diagnosis.value if task.diagnosis else None, + "diagnosis_confidence": task.diagnosis_confidence, + "final_status": final_status, + "failure_reason": task.failure_reason, + "llm_calls": task.llm_calls, + "vlm_calls": task.vlm_calls, + "duration_ms": task.duration_ms, + "requires_human_review": task.requires_human_review, + "payload": task.model_dump(mode="json"), + "create_time": task.created_at, + "update_time": task.updated_at, + }, + ) + await session.commit() + + async def finalize_run( + self, + document_id: int, + *, + run_id: int | None, + result_status: str, + rescue_applied: bool, + phase: str | None = None, + finished: bool = True, + ) -> None: + """更新运行主表的最终摘要状态。""" + resolved_run_id = await self._ensure_run_id(document_id, run_id) + async with GetAsyncSession() as session: + await session.execute( + text( + """ + UPDATE leaudit_audit_runs + SET + phase = COALESCE(:phase, phase), + rescue_applied = :rescue_applied, + result_status = :result_status, + finished_at = CASE WHEN :finished THEN now() ELSE finished_at END, + update_time = now() + WHERE id = :rid + """ + ), + { + "phase": phase, + "rescue_applied": rescue_applied, + "result_status": result_status, + "finished": finished, + "rid": resolved_run_id, + }, + ) + await session.commit() + + async def fail_run( + self, + document_id: int, + *, + run_id: int | None, + phase: str | None, + message: str, + detail_json: dict[str, Any] | None = None, + ) -> None: + """记录运行失败并更新主表。""" + resolved_run_id = await self._ensure_run_id(document_id, run_id) + await self.save_run_errors( + document_id, + run_id=resolved_run_id, + stage=phase or "persist", + messages=[message], + level="fatal", + error_code="AUDIT_RUN_FAILED", + detail_json=detail_json, + ) + async with GetAsyncSession() as session: + await session.execute( + text( + """ + UPDATE leaudit_audit_runs + SET + status = 'failed', + phase = COALESCE(:phase, phase), + result_status = 'error', + finished_at = now(), + update_time = now() + WHERE id = :rid + """ + ), + { + "phase": phase, + "rid": resolved_run_id, + }, + ) + await session.commit() + # ---- Serialization helpers ---- @@ -305,11 +644,12 @@ def _extract_relevant_field_positions( def _rule_result_to_row( document_id: int, + run_id: int | None, rule_result: RuleResult, rule: Any | None, bundle: ExtractionBundle, ) -> dict[str, Any]: - """Convert a RuleResult to a leaudit_evaluation_results row.""" + """Convert a RuleResult to a leaudit_rule_results row.""" passed = rule_result.passed # Resolve messages: rule_result → rule definition @@ -345,6 +685,7 @@ def _rule_result_to_row( rule_meta_data["group"] = rule.group return { + "run_id": run_id, "document_id": document_id, "rule_id": rule_result.rule_id, "rule_name": rule_result.name,