129 lines
4.1 KiB
Python
129 lines
4.1 KiB
Python
from __future__ import annotations
|
|
|
|
import tempfile
|
|
import sys
|
|
import types
|
|
import unittest
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
from leaudit.dsl.schema import Metadata, Rule, RuleAuthoringGroup, RulesFile, Stage
|
|
from leaudit.engine.models import EvaluationResult
|
|
from leaudit.extraction.bundle import bundle_from_single
|
|
from leaudit.extraction.models import ExtractionResult, FieldValue
|
|
from leaudit.ocr.models import OcrResult, Page
|
|
|
|
from fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline import LauditPipeline
|
|
|
|
|
|
class FakeOcrClient:
    """OCR client stand-in that returns a canned, single-page result."""

    async def ocr(self, file_path):
        """Return a fixed one-page ``OcrResult``; ``file_path`` is never read."""
        recognized_text = "处罚决定书"
        only_page = Page(page_num=1, text=recognized_text)
        return OcrResult(pages=[only_page], full_text=recognized_text)
|
|
|
|
|
|
class FakeStorage:
    """Storage adapter stand-in whose coroutines accept their arguments and do nothing."""

    async def update_document_status(self, document_id, status):
        """Discard the status update; resolves to ``None``."""

    async def save_ocr_result(self, document_id, ocr_result):
        """Discard the OCR result; resolves to ``None``."""

    async def save_extraction_result(self, document_id, extraction_bundle):
        """Discard the extraction bundle; resolves to ``None``."""

    async def save_evaluation_results(self, document_id, rules_file, evaluation_result, extraction_bundle):
        """Discard the evaluation output; resolves to ``None``."""
|
|
|
|
|
|
class FakeRetriever:
    """Attribute-free placeholder object; it defines no behavior of its own."""
|
|
|
|
|
|
def _rules_file() -> RulesFile:
    """Build a minimal ``RulesFile`` containing one group with one rule.

    Returns:
        A rules file whose single rule ``R-RAG`` has one ``required`` stage on
        the field ``处罚决定书.处罚依据``.
    """
    # Objects are created innermost-first, in the same order the nested
    # constructor call would evaluate them.
    meta = Metadata(
        type_id="test.rag_bridge",
        name="RAG bridge test",
        version="1.0",
        last_updated="2026-05-21",
    )
    required_stage = Stage(id="1", check="required", field="处罚决定书.处罚依据")
    rag_rule = Rule(
        rule_id="R-RAG",
        name="RAG 规则",
        risk="medium",
        score=1,
        stages=[required_stage],
        logic="1",
        messages={"pass": "ok", "fail": "missing"},
    )
    only_group = RuleAuthoringGroup(group="测试", rules=[rag_rule])
    return RulesFile(metadata=meta, rules=[only_group])
|
|
|
|
|
|
class LeauditRagBridgeTest(unittest.IsolatedAsyncioTestCase):
    """Verify that ``LauditPipeline`` forwards its injected RAG retriever to evaluation."""

    async def test_pipeline_passes_injected_rag_retriever_to_evaluation(self):
        """Run the pipeline with stubbed stages and check which retriever reaches evaluation.

        ``dispatch_extract``, ``determine_phase`` and ``evaluate_extraction`` are
        patched in the pipeline module. The evaluation stub records its
        ``retriever`` keyword argument, which must be the same object that was
        passed to the pipeline constructor as ``rag_retriever``.
        """
        pipeline_module = "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline"
        captured = {}
        retriever = FakeRetriever()

        field_value = FieldValue(value="依据", confidence=0.9)
        # Assign ``position`` via object.__setattr__, which bypasses any
        # __setattr__ defined by FieldValue (presumably the model rejects plain
        # assignment -- confirm against leaudit.extraction.models).
        object.__setattr__(field_value, "position", None)
        extraction = bundle_from_single(
            ExtractionResult(
                fields={"处罚决定书.处罚依据": field_value},
                source_text="处罚决定书",
            )
        )

        async def fake_dispatch_extract(*args, **kwargs):
            return extraction

        async def fake_determine_phase(*args, **kwargs):
            return "executed"

        async def fake_evaluate_extraction(*args, **kwargs):
            # Only a keyword argument named ``retriever`` is recorded.
            captured["retriever"] = kwargs.get("retriever")
            return EvaluationResult()

        class TestPipeline(LauditPipeline):
            async def _extract_and_save_case_number(self, document_id, ocr_result):
                # Case-number extraction is replaced with a no-op for this test.
                return None

        # While active, ``sys.modules`` maps the coordinate resolver's name to
        # this namespace, so an import of it yields a no-op
        # ``resolve_bundle_positions``.
        resolver_stub = types.SimpleNamespace(
            resolve_bundle_positions=lambda *args, **kwargs: None
        )

        with tempfile.NamedTemporaryFile() as tmp:
            pipeline = TestPipeline(
                ocr_client=FakeOcrClient(),
                storage_adapter=FakeStorage(),
                rag_retriever=retriever,
            )
            with patch(
                f"{pipeline_module}.dispatch_extract",
                fake_dispatch_extract,
            ), patch(
                f"{pipeline_module}.determine_phase",
                fake_determine_phase,
            ), patch(
                f"{pipeline_module}.evaluate_extraction",
                fake_evaluate_extraction,
            ), patch.dict(
                sys.modules,
                {"leaudit.extraction.coordinate_resolver": resolver_stub},
            ):
                await pipeline.run(1, Path(tmp.name), _rules_file())

        # Fail with a readable message (instead of erroring with KeyError) when
        # the pipeline never reached the evaluation stub.
        self.assertIn(
            "retriever",
            captured,
            "pipeline.run() never called evaluate_extraction",
        )
        self.assertIs(captured["retriever"], retriever)
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|