feat(rag): add shared retriever for audit pipeline
This commit is contained in:
@@ -0,0 +1,128 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from leaudit.dsl.schema import Metadata, Rule, RuleAuthoringGroup, RulesFile, Stage
|
||||
from leaudit.engine.models import EvaluationResult
|
||||
from leaudit.extraction.bundle import bundle_from_single
|
||||
from leaudit.extraction.models import ExtractionResult, FieldValue
|
||||
from leaudit.ocr.models import OcrResult, Page
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline import LauditPipeline
|
||||
|
||||
|
||||
class FakeOcrClient:
    """OCR client stub that returns a fixed single-page result."""

    async def ocr(self, file_path):
        """Return a canned ``OcrResult``; ``file_path`` is never read."""
        only_page = Page(page_num=1, text="处罚决定书")
        return OcrResult(pages=[only_page], full_text="处罚决定书")
|
||||
|
||||
|
||||
class FakeStorage:
    """Storage adapter stub: every persistence hook accepts its arguments and does nothing."""

    async def update_document_status(self, document_id, status):
        """Discard the status update and return ``None``."""

    async def save_ocr_result(self, document_id, ocr_result):
        """Discard the OCR result and return ``None``."""

    async def save_extraction_result(self, document_id, extraction_bundle):
        """Discard the extraction bundle and return ``None``."""

    async def save_evaluation_results(self, document_id, rules_file, evaluation_result, extraction_bundle):
        """Discard the evaluation output and return ``None``."""
|
||||
|
||||
|
||||
class FakeRetriever:
    """Empty stand-in for a retriever; it carries no state or behavior of its own."""
|
||||
|
||||
|
||||
def _rules_file() -> RulesFile:
    """Build a minimal ``RulesFile``: one group holding one rule with a single ``required`` stage."""
    # Metadata is built first to keep the same construction order as a single
    # nested ``RulesFile(metadata=..., rules=...)`` expression.
    file_metadata = Metadata(
        type_id="test.rag_bridge",
        name="RAG bridge test",
        version="1.0",
        last_updated="2026-05-21",
    )
    required_stage = Stage(id="1", check="required", field="处罚决定书.处罚依据")
    rag_rule = Rule(
        rule_id="R-RAG",
        name="RAG 规则",
        risk="medium",
        score=1,
        stages=[required_stage],
        logic="1",
        messages={"pass": "ok", "fail": "missing"},
    )
    rule_group = RuleAuthoringGroup(group="测试", rules=[rag_rule])
    return RulesFile(metadata=file_metadata, rules=[rule_group])
|
||||
|
||||
|
||||
class LeauditRagBridgeTest(unittest.IsolatedAsyncioTestCase):
    """Verify that ``LauditPipeline`` hands its injected retriever to evaluation."""

    async def test_pipeline_passes_injected_rag_retriever_to_evaluation(self):
        """Run the pipeline with stubbed collaborators and check the ``retriever`` kwarg.

        ``dispatch_extract``, ``determine_phase`` and ``evaluate_extraction`` are
        patched inside the pipeline module, so the only thing under test is that
        the object passed as ``rag_retriever`` to the constructor is the same
        object received as ``retriever`` by ``evaluate_extraction``.
        """
        captured = {}
        retriever = FakeRetriever()
        field_value = FieldValue(value="依据", confidence=0.9)
        # ``position`` is set through ``object.__setattr__`` instead of a normal
        # assignment. NOTE(review): presumably FieldValue rejects or lacks this
        # attribute under regular setattr -- confirm against the model.
        object.__setattr__(field_value, "position", None)
        extraction = bundle_from_single(
            ExtractionResult(
                fields={"处罚决定书.处罚依据": field_value},
                source_text="处罚决定书",
            )
        )

        async def fake_dispatch_extract(*args, **kwargs):
            # Always yield the pre-built bundle, whatever the pipeline passes in.
            return extraction

        async def fake_determine_phase(*args, **kwargs):
            return "executed"

        async def fake_evaluate_extraction(*args, **kwargs):
            # Record what the pipeline supplied under the ``retriever`` keyword.
            captured["retriever"] = kwargs.get("retriever")
            return EvaluationResult()

        class TestPipeline(LauditPipeline):
            """Pipeline variant with the case-number step turned into a no-op."""

            async def _extract_and_save_case_number(self, document_id, ocr_result):
                return None

        # FakeOcrClient ignores the path it is given; a real temporary file is
        # still used so that the path handed to ``run`` exists.
        with tempfile.NamedTemporaryFile() as tmp:
            pipeline = TestPipeline(
                ocr_client=FakeOcrClient(),
                storage_adapter=FakeStorage(),
                rag_retriever=retriever,
            )
            with patch(
                "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.dispatch_extract",
                fake_dispatch_extract,
            ), patch(
                "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.determine_phase",
                fake_determine_phase,
            ), patch(
                "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.evaluate_extraction",
                fake_evaluate_extraction,
            ), patch.dict(
                sys.modules,
                {
                    # Replace the coordinate resolver module with a namespace
                    # whose ``resolve_bundle_positions`` does nothing;
                    # ``patch.dict`` restores ``sys.modules`` on exit.
                    "leaudit.extraction.coordinate_resolver": types.SimpleNamespace(
                        resolve_bundle_positions=lambda *args, **kwargs: None
                    )
                },
            ):
                await pipeline.run(1, Path(tmp.name), _rules_file())

        # Fail with a readable assertion (not a bare KeyError) when the
        # evaluation stub was never awaited by the pipeline.
        self.assertIn(
            "retriever",
            captured,
            "evaluate_extraction was never called by the pipeline",
        )
        self.assertIs(captured["retriever"], retriever)
|
||||
|
||||
|
||||
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever
|
||||
|
||||
|
||||
class FakeCollection:
    """Collection stub whose ``query`` always yields the same two canned hits."""

    def query(self, **kwargs):
        """Return a fixed two-hit result; every keyword argument is ignored.

        Each value is a list wrapping a single inner list, and the inner lists
        are aligned by position (ids, documents, metadatas, distances).
        """
        first_meta = {
            "document_name": "广东省烟草专卖行政处罚裁量执行标准-rag.md",
            "chunk_index": 0,
        }
        second_meta = {
            "document_name": "其他.md",
            "chunk_index": 1,
        }
        return dict(
            ids=[["seg-1", "seg-2"]],
            documents=[["烟草处罚裁量标准内容", "其他来源内容"]],
            metadatas=[[first_meta, second_meta]],
            distances=[[0.0, 0.2]],
        )
|
||||
|
||||
|
||||
class FallbackCollection:
    """Collection stub whose ``query`` always fails while ``get`` returns one record."""

    def query(self, **kwargs):
        """Always raise ``RuntimeError``; keyword arguments are ignored."""
        raise RuntimeError("vector unavailable")

    def get(self, **kwargs):
        """Return one fixed record in flat lists; keyword arguments are ignored."""
        segment_meta = {
            "document_name": "案由_行政处罚与反走私管理治理办法.md",
            "chunk_index": 3,
        }
        return dict(
            ids=["seg-fallback"],
            documents=["未在当地烟草专卖批发企业进货,对应裁量档次内容"],
            metadatas=[segment_meta],
        )
|
||||
|
||||
|
||||
class FakeChroma:
    """Client stub that remembers the most recently requested collection name."""

    def __init__(self):
        # Stays ``None`` until ``get_or_create_collection`` is called.
        self.collection_name = None

    def get_or_create_collection(self, name):
        """Record ``name`` on the instance and return a new ``FakeCollection``."""
        self.collection_name = name
        return FakeCollection()
|
||||
|
||||
|
||||
class FallbackChroma:
    """Client stub that serves a ``FallbackCollection`` for any collection name."""

    def get_or_create_collection(self, name):
        """Return a new ``FallbackCollection``; ``name`` is not used."""
        return FallbackCollection()
|
||||
|
||||
|
||||
class RagRetrieverTest(unittest.IsolatedAsyncioTestCase):
    """Exercise ``RagRetriever.retrieve`` against stubbed collection clients."""

    @staticmethod
    def _stub_embeddings(texts, model_name=""):
        """Return the same two-component vector for every input text."""
        return [[0.1, 0.2] for _ in texts]

    def _make_retriever(self, chroma_client):
        """Build a ``RagRetriever`` around ``chroma_client`` with hydration disabled."""
        return RagRetriever(
            chroma_client=chroma_client,
            embed_texts=self._stub_embeddings,
            hydrate_documents=False,
        )

    async def test_retrieve_from_collection_filters_sources_and_builds_resources(self):
        """Only the hit whose document name is in ``source_names`` should survive."""
        wanted_source = "广东省烟草专卖行政处罚裁量执行标准-rag.md"
        chroma = FakeChroma()
        retriever = self._make_retriever(chroma)

        result = await retriever.retrieve(
            query="罚款幅度",
            collection_name="general_legal_kb",
            top_k=5,
            source_names=[wanted_source],
        )

        # The requested collection name must reach the client unchanged.
        self.assertEqual(chroma.collection_name, "general_legal_kb")
        context = result.rag_context
        self.assertIn("烟草处罚裁量标准内容", context)
        self.assertNotIn("其他来源内容", context)
        resources = result.rag_resources
        self.assertEqual(len(resources), 1)
        self.assertEqual(resources[0]["document_name"], wanted_source)

    async def test_retrieve_uses_keyword_fallback_when_vector_search_fails(self):
        """A failing ``query`` should still yield the record served by ``get``."""
        retriever = self._make_retriever(FallbackChroma())

        result = await retriever.retrieve(
            query="未在当地烟草专卖批发企业进货",
            collection_name="general_legal_kb",
            source_names=["案由_行政处罚与反走私管理治理办法.md"],
        )

        self.assertIn("对应裁量档次内容", result.rag_context)
        self.assertEqual(len(result.chunks), 1)
        first_resource = result.rag_resources[0]
        self.assertEqual(first_resource["segment_id"], "seg-fallback")
|
||||
|
||||
|
||||
# Run the tests in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user