feat(rag): add shared retriever for audit pipeline

This commit is contained in:
2026-05-21 15:54:47 +08:00
parent 4847bccdec
commit 6ce1a290ab
6 changed files with 795 additions and 316 deletions
+128
View File
@@ -0,0 +1,128 @@
from __future__ import annotations
import tempfile
import sys
import types
import unittest
from pathlib import Path
from unittest.mock import patch
from leaudit.dsl.schema import Metadata, Rule, RuleAuthoringGroup, RulesFile, Stage
from leaudit.engine.models import EvaluationResult
from leaudit.extraction.bundle import bundle_from_single
from leaudit.extraction.models import ExtractionResult, FieldValue
from leaudit.ocr.models import OcrResult, Page
from fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline import LauditPipeline
class FakeOcrClient:
async def ocr(self, file_path):
return OcrResult(
pages=[Page(page_num=1, text="处罚决定书")],
full_text="处罚决定书",
)
class FakeStorage:
async def update_document_status(self, document_id, status):
return None
async def save_ocr_result(self, document_id, ocr_result):
return None
async def save_extraction_result(self, document_id, extraction_bundle):
return None
async def save_evaluation_results(self, document_id, rules_file, evaluation_result, extraction_bundle):
return None
class FakeRetriever:
pass
def _rules_file() -> RulesFile:
return RulesFile(
metadata=Metadata(
type_id="test.rag_bridge",
name="RAG bridge test",
version="1.0",
last_updated="2026-05-21",
),
rules=[
RuleAuthoringGroup(
group="测试",
rules=[
Rule(
rule_id="R-RAG",
name="RAG 规则",
risk="medium",
score=1,
stages=[Stage(id="1", check="required", field="处罚决定书.处罚依据")],
logic="1",
messages={"pass": "ok", "fail": "missing"},
)
],
)
],
)
class LeauditRagBridgeTest(unittest.IsolatedAsyncioTestCase):
async def test_pipeline_passes_injected_rag_retriever_to_evaluation(self):
captured = {}
retriever = FakeRetriever()
field_value = FieldValue(value="依据", confidence=0.9)
object.__setattr__(field_value, "position", None)
extraction = bundle_from_single(
ExtractionResult(
fields={"处罚决定书.处罚依据": field_value},
source_text="处罚决定书",
)
)
async def fake_dispatch_extract(*args, **kwargs):
return extraction
async def fake_determine_phase(*args, **kwargs):
return "executed"
async def fake_evaluate_extraction(*args, **kwargs):
captured["retriever"] = kwargs.get("retriever")
return EvaluationResult()
class TestPipeline(LauditPipeline):
async def _extract_and_save_case_number(self, document_id, ocr_result):
return None
with tempfile.NamedTemporaryFile() as tmp:
pipeline = TestPipeline(
ocr_client=FakeOcrClient(),
storage_adapter=FakeStorage(),
rag_retriever=retriever,
)
with patch(
"fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.dispatch_extract",
fake_dispatch_extract,
), patch(
"fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.determine_phase",
fake_determine_phase,
), patch(
"fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.evaluate_extraction",
fake_evaluate_extraction,
), patch.dict(
sys.modules,
{
"leaudit.extraction.coordinate_resolver": types.SimpleNamespace(
resolve_bundle_positions=lambda *args, **kwargs: None
)
},
):
await pipeline.run(1, Path(tmp.name), _rules_file())
self.assertIs(captured["retriever"], retriever)
if __name__ == "__main__":
unittest.main()
+102
View File
@@ -0,0 +1,102 @@
from __future__ import annotations
import unittest
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever
class FakeCollection:
def query(self, **kwargs):
return {
"ids": [["seg-1", "seg-2"]],
"documents": [["烟草处罚裁量标准内容", "其他来源内容"]],
"metadatas": [[
{
"document_name": "广东省烟草专卖行政处罚裁量执行标准-rag.md",
"chunk_index": 0,
},
{
"document_name": "其他.md",
"chunk_index": 1,
},
]],
"distances": [[0.0, 0.2]],
}
class FallbackCollection:
def query(self, **kwargs):
raise RuntimeError("vector unavailable")
def get(self, **kwargs):
return {
"ids": ["seg-fallback"],
"documents": ["未在当地烟草专卖批发企业进货,对应裁量档次内容"],
"metadatas": [
{
"document_name": "案由_行政处罚与反走私管理治理办法.md",
"chunk_index": 3,
}
],
}
class FakeChroma:
def __init__(self):
self.collection_name = None
def get_or_create_collection(self, name):
self.collection_name = name
return FakeCollection()
class FallbackChroma:
def get_or_create_collection(self, name):
return FallbackCollection()
class RagRetrieverTest(unittest.IsolatedAsyncioTestCase):
async def test_retrieve_from_collection_filters_sources_and_builds_resources(self):
chroma = FakeChroma()
retriever = RagRetriever(
chroma_client=chroma,
embed_texts=lambda texts, model_name="": [[0.1, 0.2] for _ in texts],
hydrate_documents=False,
)
result = await retriever.retrieve(
query="罚款幅度",
collection_name="general_legal_kb",
top_k=5,
source_names=["广东省烟草专卖行政处罚裁量执行标准-rag.md"],
)
self.assertEqual(chroma.collection_name, "general_legal_kb")
self.assertIn("烟草处罚裁量标准内容", result.rag_context)
self.assertNotIn("其他来源内容", result.rag_context)
self.assertEqual(len(result.rag_resources), 1)
self.assertEqual(
result.rag_resources[0]["document_name"],
"广东省烟草专卖行政处罚裁量执行标准-rag.md",
)
async def test_retrieve_uses_keyword_fallback_when_vector_search_fails(self):
retriever = RagRetriever(
chroma_client=FallbackChroma(),
embed_texts=lambda texts, model_name="": [[0.1, 0.2] for _ in texts],
hydrate_documents=False,
)
result = await retriever.retrieve(
query="未在当地烟草专卖批发企业进货",
collection_name="general_legal_kb",
source_names=["案由_行政处罚与反走私管理治理办法.md"],
)
self.assertIn("对应裁量档次内容", result.rag_context)
self.assertEqual(len(result.chunks), 1)
self.assertEqual(result.rag_resources[0]["segment_id"], "seg-fallback")
if __name__ == "__main__":
unittest.main()