"""Tests that the LeAudit bridge pipeline forwards an injected RAG retriever.

The pipeline is run end-to-end against fakes for OCR and storage, while the
extraction, phase-determination and evaluation entry points are patched so the
test can observe which retriever reaches ``evaluate_extraction``.
"""

from __future__ import annotations

import sys
import tempfile
import types
import unittest
from pathlib import Path
from unittest.mock import patch

from leaudit.dsl.schema import Metadata, Rule, RuleAuthoringGroup, RulesFile, Stage
from leaudit.engine.models import EvaluationResult
from leaudit.extraction.bundle import bundle_from_single
from leaudit.extraction.models import ExtractionResult, FieldValue
from leaudit.ocr.models import OcrResult, Page

from fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline import LauditPipeline


class FakeOcrClient:
    """OCR client stub that ignores the input file and returns one fixed page."""

    async def ocr(self, file_path):
        """Return a single-page OCR result regardless of ``file_path``."""
        return OcrResult(
            pages=[Page(page_num=1, text="处罚决定书")],
            full_text="处罚决定书",
        )


class FakeStorage:
    """Storage adapter stub whose persistence methods all do nothing."""

    async def update_document_status(self, document_id, status):
        return None

    async def save_ocr_result(self, document_id, ocr_result):
        return None

    async def save_extraction_result(self, document_id, extraction_bundle):
        return None

    async def save_evaluation_results(
        self, document_id, rules_file, evaluation_result, extraction_bundle
    ):
        return None


class FakeRetriever:
    """Opaque sentinel object; the test only checks its identity."""


def _rules_file() -> RulesFile:
    """Build a minimal rules file with one ``required`` rule on a single field."""
    return RulesFile(
        metadata=Metadata(
            type_id="test.rag_bridge",
            name="RAG bridge test",
            version="1.0",
            last_updated="2026-05-21",
        ),
        rules=[
            RuleAuthoringGroup(
                group="测试",
                rules=[
                    Rule(
                        rule_id="R-RAG",
                        name="RAG 规则",
                        risk="medium",
                        score=1,
                        stages=[
                            Stage(id="1", check="required", field="处罚决定书.处罚依据")
                        ],
                        logic="1",
                        messages={"pass": "ok", "fail": "missing"},
                    )
                ],
            )
        ],
    )


class LeauditRagBridgeTest(unittest.IsolatedAsyncioTestCase):
    """Pipeline-level check of RAG retriever injection."""

    async def test_pipeline_passes_injected_rag_retriever_to_evaluation(self):
        """The retriever given to the pipeline constructor must reach evaluation."""
        # Filled in by the patched evaluate_extraction so we can inspect its kwargs.
        captured = {}
        retriever = FakeRetriever()

        field_value = FieldValue(value="依据", confidence=0.9)
        # Force a ``position`` attribute of None via object.__setattr__, bypassing
        # the model's normal attribute assignment. NOTE(review): presumably
        # FieldValue is frozen or lacks a default ``position`` and the pipeline
        # reads it — confirm.
        object.__setattr__(field_value, "position", None)
        extraction = bundle_from_single(
            ExtractionResult(
                fields={"处罚决定书.处罚依据": field_value},
                source_text="处罚决定书",
            )
        )

        async def fake_dispatch_extract(*args, **kwargs):
            return extraction

        async def fake_determine_phase(*args, **kwargs):
            return "executed"

        async def fake_evaluate_extraction(*args, **kwargs):
            # Only the keyword argument is inspected; a positional retriever
            # would be recorded as None and fail the identity check below.
            captured["retriever"] = kwargs.get("retriever")
            return EvaluationResult()

        class TestPipeline(LauditPipeline):
            # Overridden as a no-op so the test does not exercise case-number
            # extraction.
            async def _extract_and_save_case_number(self, document_id, ocr_result):
                return None

        # The fake OCR client never reads the file; the temp file just supplies
        # a real path for pipeline.run.
        with tempfile.NamedTemporaryFile() as tmp:
            pipeline = TestPipeline(
                ocr_client=FakeOcrClient(),
                storage_adapter=FakeStorage(),
                rag_retriever=retriever,
            )
            with patch(
                "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.dispatch_extract",
                fake_dispatch_extract,
            ), patch(
                "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.determine_phase",
                fake_determine_phase,
            ), patch(
                "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.evaluate_extraction",
                fake_evaluate_extraction,
            ), patch.dict(
                sys.modules,
                {
                    # Stand-in module whose resolve_bundle_positions is a no-op.
                    # NOTE(review): presumably the pipeline imports this module
                    # lazily inside run(), which is why sys.modules is patched.
                    "leaudit.extraction.coordinate_resolver": types.SimpleNamespace(
                        resolve_bundle_positions=lambda *args, **kwargs: None
                    )
                },
            ):
                await pipeline.run(1, Path(tmp.name), _rules_file())

        # Report a missing evaluation call as a clear failure instead of a KeyError.
        self.assertIn(
            "retriever",
            captured,
            "evaluate_extraction was never called by the pipeline",
        )
        self.assertIs(captured["retriever"], retriever)


if __name__ == "__main__":
    unittest.main()