103 lines
2.9 KiB
Python
103 lines
2.9 KiB
Python
from __future__ import annotations
|
|
|
|
import unittest
|
|
|
|
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever
|
|
|
|
|
|
class FakeCollection:
    """Collection stub whose ``query`` always returns the same two hits.

    The two hits carry different ``document_name`` values, which lets a
    caller check that results are filtered by source name.
    """

    def query(self, **kwargs):
        """Return the canned two-hit result; all keyword arguments are ignored."""
        # One row per hit: (segment id, text, source document, chunk index, distance).
        hits = (
            ("seg-1", "烟草处罚裁量标准内容", "广东省烟草专卖行政处罚裁量执行标准-rag.md", 0, 0.0),
            ("seg-2", "其他来源内容", "其他.md", 1, 0.2),
        )
        # Transpose the rows into per-field lists; fresh lists are built on
        # every call so a caller mutating the result cannot affect later calls.
        seg_ids, texts, sources, positions, distances = map(list, zip(*hits))
        # Every field is wrapped in one extra outer list.
        return {
            "ids": [seg_ids],
            "documents": [texts],
            "metadatas": [[
                {"document_name": source, "chunk_index": position}
                for source, position in zip(sources, positions)
            ]],
            "distances": [distances],
        }
|
|
|
|
|
|
class FallbackCollection:
    """Collection stub whose ``query`` always fails while ``get`` succeeds."""

    def query(self, **kwargs):
        """Raise ``RuntimeError`` unconditionally; keyword arguments are ignored."""
        raise RuntimeError("vector unavailable")

    def get(self, **kwargs):
        """Return a single canned segment; keyword arguments are ignored."""
        segment_meta = {
            "document_name": "案由_行政处罚与反走私管理治理办法.md",
            "chunk_index": 3,
        }
        # Unlike FakeCollection.query, these lists are flat (no outer wrapper).
        return dict(
            ids=["seg-fallback"],
            documents=["未在当地烟草专卖批发企业进货,对应裁量档次内容"],
            metadatas=[segment_meta],
        )
|
|
|
|
|
|
class FakeChroma:
    """Client stub that remembers the most recently requested collection name."""

    def __init__(self):
        # Stays None until get_or_create_collection is called.
        self.collection_name = None

    def get_or_create_collection(self, name):
        """Store ``name`` on the instance and return a new ``FakeCollection``."""
        self.collection_name = name
        collection = FakeCollection()
        return collection
|
|
|
|
|
|
class FallbackChroma:
    """Client stub that hands out a ``FallbackCollection`` for any name."""

    def get_or_create_collection(self, name):
        """Return a new ``FallbackCollection``; ``name`` is not used."""
        collection = FallbackCollection()
        return collection
|
|
|
|
|
|
class RagRetrieverTest(unittest.IsolatedAsyncioTestCase):
    """Async tests for ``RagRetriever.retrieve`` run against in-memory stubs."""

    @staticmethod
    def _make_retriever(chroma):
        """Build a ``RagRetriever`` on ``chroma`` with a stub embedder and no hydration."""
        return RagRetriever(
            chroma_client=chroma,
            # The stub embedder yields one fixed 2-d vector per input text.
            embed_texts=lambda texts, model_name="": [[0.1, 0.2] for _ in texts],
            hydrate_documents=False,
        )

    async def test_retrieve_from_collection_filters_sources_and_builds_resources(self):
        """Only the hit whose document name was requested should be returned."""
        wanted_source = "广东省烟草专卖行政处罚裁量执行标准-rag.md"
        chroma = FakeChroma()
        retriever = self._make_retriever(chroma)

        result = await retriever.retrieve(
            query="罚款幅度",
            collection_name="general_legal_kb",
            top_k=5,
            source_names=[wanted_source],
        )

        # The retriever must have asked the client for the named collection.
        self.assertEqual(chroma.collection_name, "general_legal_kb")
        # The context keeps the matching source's text and drops the other hit.
        self.assertIn("烟草处罚裁量标准内容", result.rag_context)
        self.assertNotIn("其他来源内容", result.rag_context)
        self.assertEqual(len(result.rag_resources), 1)
        self.assertEqual(result.rag_resources[0]["document_name"], wanted_source)

    async def test_retrieve_uses_keyword_fallback_when_vector_search_fails(self):
        """When ``query`` raises, the segment served by ``get`` should be returned."""
        retriever = self._make_retriever(FallbackChroma())

        result = await retriever.retrieve(
            query="未在当地烟草专卖批发企业进货",
            collection_name="general_legal_kb",
            source_names=["案由_行政处罚与反走私管理治理办法.md"],
        )

        self.assertIn("对应裁量档次内容", result.rag_context)
        self.assertEqual(len(result.chunks), 1)
        self.assertEqual(result.rag_resources[0]["segment_id"], "seg-fallback")
|
|
|
|
|
|
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()
|