# leaudit-platform-backend/tests/test_leaudit_rag_bridge.py

from __future__ import annotations
import tempfile
import sys
import types
import unittest
from pathlib import Path
from unittest.mock import patch
from leaudit.dsl.schema import Metadata, Rule, RuleAuthoringGroup, RulesFile, Stage
from leaudit.engine.models import EvaluationResult
from leaudit.extraction.bundle import bundle_from_single
from leaudit.extraction.models import ExtractionResult, FieldValue
from leaudit.ocr.models import OcrResult, Page
from fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline import LauditPipeline
class FakeOcrClient:
    """Stand-in OCR client that always yields the same one-page result.

    ``file_path`` is accepted only to match the OCR client call signature;
    it is never opened or inspected.
    """

    async def ocr(self, file_path):
        """Return a canned ``OcrResult`` whose only page reads "处罚决定书"."""
        text = "处罚决定书"
        single_page = Page(page_num=1, text=text)
        return OcrResult(pages=[single_page], full_text=text)
class FakeStorage:
    """No-op storage adapter.

    Every persistence hook the pipeline may call accepts its arguments,
    discards them, and resolves to ``None``.
    """

    async def update_document_status(self, document_id, status):
        """Discard the status update."""

    async def save_ocr_result(self, document_id, ocr_result):
        """Discard the OCR result."""

    async def save_extraction_result(self, document_id, extraction_bundle):
        """Discard the extraction bundle."""

    async def save_evaluation_results(self, document_id, rules_file, evaluation_result, extraction_bundle):
        """Discard the evaluation results."""
class FakeRetriever:
    """Behaviorless sentinel used as the injected RAG retriever.

    The test compares the object that reaches evaluation against this
    instance by identity, so it needs no methods of its own.
    """
def _rules_file() -> RulesFile:
    """Build a minimal ``RulesFile`` holding a single rule for the bridge test.

    The file contains one authoring group ("测试") with one rule, ``R-RAG``,
    whose only stage is a ``required`` check on the field
    ``处罚决定书.处罚依据``.
    """
    metadata = Metadata(
        type_id="test.rag_bridge",
        name="RAG bridge test",
        version="1.0",
        last_updated="2026-05-21",
    )
    required_stage = Stage(id="1", check="required", field="处罚决定书.处罚依据")
    rag_rule = Rule(
        rule_id="R-RAG",
        name="RAG 规则",
        risk="medium",
        score=1,
        stages=[required_stage],
        logic="1",
        messages={"pass": "ok", "fail": "missing"},
    )
    group = RuleAuthoringGroup(group="测试", rules=[rag_rule])
    return RulesFile(metadata=metadata, rules=[group])
class LeauditRagBridgeTest(unittest.IsolatedAsyncioTestCase):
    """Checks that ``LauditPipeline`` hands its injected RAG retriever to evaluation."""

    async def test_pipeline_passes_injected_rag_retriever_to_evaluation(self):
        """Run the pipeline with OCR, storage, extraction, phase detection and
        evaluation all faked.

        Assert that the ``rag_retriever`` given to the pipeline constructor is
        the very object ``evaluate_extraction`` receives as its ``retriever``
        keyword argument.
        """
        captured = {}
        retriever = FakeRetriever()

        field_value = FieldValue(value="依据", confidence=0.9)
        # object.__setattr__ bypasses any __setattr__ override on FieldValue
        # (e.g. a frozen model) to force position=None; presumably the
        # pipeline reads this attribute downstream — TODO confirm.
        object.__setattr__(field_value, "position", None)
        extraction = bundle_from_single(
            ExtractionResult(
                fields={"处罚决定书.处罚依据": field_value},
                source_text="处罚决定书",
            )
        )

        async def fake_dispatch_extract(*args, **kwargs):
            # Skip real extraction and always return the canned bundle above.
            return extraction

        async def fake_determine_phase(*args, **kwargs):
            return "executed"

        async def fake_evaluate_extraction(*args, **kwargs):
            # Record what the pipeline passed as the ``retriever`` keyword.
            captured["retriever"] = kwargs.get("retriever")
            return EvaluationResult()

        class TestPipeline(LauditPipeline):
            async def _extract_and_save_case_number(self, document_id, ocr_result):
                # Case-number extraction is not under test; make it a no-op.
                return None

        pipeline_mod = "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline"

        # Provide a real on-disk path for ``pipeline.run``; FakeOcrClient
        # ignores the path it is given.
        with tempfile.NamedTemporaryFile() as tmp:
            pipeline = TestPipeline(
                ocr_client=FakeOcrClient(),
                storage_adapter=FakeStorage(),
                rag_retriever=retriever,
            )
            with patch(
                f"{pipeline_mod}.dispatch_extract",
                fake_dispatch_extract,
            ), patch(
                f"{pipeline_mod}.determine_phase",
                fake_determine_phase,
            ), patch(
                f"{pipeline_mod}.evaluate_extraction",
                fake_evaluate_extraction,
            ), patch.dict(
                sys.modules,
                {
                    # Stub module whose resolve_bundle_positions does nothing.
                    # A sys.modules entry only affects imports executed while
                    # patched, so this presumably targets an import made
                    # lazily inside the pipeline — TODO confirm.
                    "leaudit.extraction.coordinate_resolver": types.SimpleNamespace(
                        resolve_bundle_positions=lambda *args, **kwargs: None
                    )
                },
            ):
                await pipeline.run(1, Path(tmp.name), _rules_file())

        # Report a clear failure, rather than a KeyError, if the pipeline
        # never reached evaluate_extraction.
        self.assertIn(
            "retriever",
            captured,
            "evaluate_extraction was never called by the pipeline",
        )
        self.assertIs(captured["retriever"], retriever)
if __name__ == "__main__":
    # Allow running this module directly, e.g. `python test_leaudit_rag_bridge.py`.
    unittest.main()