feat(rag): add shared retriever for audit pipeline
This commit is contained in:
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever
|
||||
|
||||
|
||||
class FakeCollection:
|
||||
def query(self, **kwargs):
|
||||
return {
|
||||
"ids": [["seg-1", "seg-2"]],
|
||||
"documents": [["烟草处罚裁量标准内容", "其他来源内容"]],
|
||||
"metadatas": [[
|
||||
{
|
||||
"document_name": "广东省烟草专卖行政处罚裁量执行标准-rag.md",
|
||||
"chunk_index": 0,
|
||||
},
|
||||
{
|
||||
"document_name": "其他.md",
|
||||
"chunk_index": 1,
|
||||
},
|
||||
]],
|
||||
"distances": [[0.0, 0.2]],
|
||||
}
|
||||
|
||||
|
||||
class FallbackCollection:
|
||||
def query(self, **kwargs):
|
||||
raise RuntimeError("vector unavailable")
|
||||
|
||||
def get(self, **kwargs):
|
||||
return {
|
||||
"ids": ["seg-fallback"],
|
||||
"documents": ["未在当地烟草专卖批发企业进货,对应裁量档次内容"],
|
||||
"metadatas": [
|
||||
{
|
||||
"document_name": "案由_行政处罚与反走私管理治理办法.md",
|
||||
"chunk_index": 3,
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
class FakeChroma:
|
||||
def __init__(self):
|
||||
self.collection_name = None
|
||||
|
||||
def get_or_create_collection(self, name):
|
||||
self.collection_name = name
|
||||
return FakeCollection()
|
||||
|
||||
|
||||
class FallbackChroma:
|
||||
def get_or_create_collection(self, name):
|
||||
return FallbackCollection()
|
||||
|
||||
|
||||
class RagRetrieverTest(unittest.IsolatedAsyncioTestCase):
|
||||
async def test_retrieve_from_collection_filters_sources_and_builds_resources(self):
|
||||
chroma = FakeChroma()
|
||||
retriever = RagRetriever(
|
||||
chroma_client=chroma,
|
||||
embed_texts=lambda texts, model_name="": [[0.1, 0.2] for _ in texts],
|
||||
hydrate_documents=False,
|
||||
)
|
||||
|
||||
result = await retriever.retrieve(
|
||||
query="罚款幅度",
|
||||
collection_name="general_legal_kb",
|
||||
top_k=5,
|
||||
source_names=["广东省烟草专卖行政处罚裁量执行标准-rag.md"],
|
||||
)
|
||||
|
||||
self.assertEqual(chroma.collection_name, "general_legal_kb")
|
||||
self.assertIn("烟草处罚裁量标准内容", result.rag_context)
|
||||
self.assertNotIn("其他来源内容", result.rag_context)
|
||||
self.assertEqual(len(result.rag_resources), 1)
|
||||
self.assertEqual(
|
||||
result.rag_resources[0]["document_name"],
|
||||
"广东省烟草专卖行政处罚裁量执行标准-rag.md",
|
||||
)
|
||||
|
||||
async def test_retrieve_uses_keyword_fallback_when_vector_search_fails(self):
|
||||
retriever = RagRetriever(
|
||||
chroma_client=FallbackChroma(),
|
||||
embed_texts=lambda texts, model_name="": [[0.1, 0.2] for _ in texts],
|
||||
hydrate_documents=False,
|
||||
)
|
||||
|
||||
result = await retriever.retrieve(
|
||||
query="未在当地烟草专卖批发企业进货",
|
||||
collection_name="general_legal_kb",
|
||||
source_names=["案由_行政处罚与反走私管理治理办法.md"],
|
||||
)
|
||||
|
||||
self.assertIn("对应裁量档次内容", result.rag_context)
|
||||
self.assertEqual(len(result.chunks), 1)
|
||||
self.assertEqual(result.rag_resources[0]["segment_id"], "seg-fallback")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
Reference in New Issue
Block a user