Files
leaudit-platform-backend/tests/test_rag_retriever.py
T
2026-05-22 18:49:57 +08:00

103 lines
3.3 KiB
Python

from __future__ import annotations
import unittest
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever
class FakeCollection:
    """Collection stub whose vector query always returns two canned hits.

    The two hits carry different ``document_name`` values, which lets a test
    verify that results are filtered by source document.
    """

    def query(self, **kwargs: object) -> dict[str, list]:
        """Return a fixed two-hit result; every keyword argument is ignored.

        Each value is a list wrapping a single inner list of per-hit entries,
        and the inner lists are index-aligned across ``ids``, ``documents``,
        ``metadatas`` and ``distances``.
        """
        return {
            "ids": [["seg-1", "seg-2"]],
            "documents": [["烟草处罚裁量标准内容", "其他来源内容"]],
            "metadatas": [[
                {
                    "document_name": "广东省烟草专卖行政处罚裁量执行标准-rag.md",
                    "chunk_index": 0,
                },
                {
                    "document_name": "其他.md",
                    "chunk_index": 1,
                },
            ]],
            "distances": [[0.0, 0.2]],
        }
class FallbackCollection:
    """Collection stub whose vector query fails while ``get`` still works.

    Used to exercise the path a caller takes when ``query`` raises and the
    stored documents have to be read through ``get`` instead.
    """

    def query(self, **kwargs: object) -> dict[str, list]:
        """Always raise ``RuntimeError`` to simulate a failing vector search."""
        raise RuntimeError("vector unavailable")

    def get(self, **kwargs: object) -> dict[str, list]:
        """Return one canned segment; every keyword argument is ignored.

        Unlike a ``query`` result, the lists here are flat (one entry per
        stored segment, no outer per-query wrapper).
        """
        return {
            "ids": ["seg-fallback"],
            "documents": ["未在当地烟草专卖批发企业进货,对应裁量档次内容"],
            "metadatas": [
                {
                    "document_name": "案由_行政处罚与反走私管理治理办法.md",
                    "chunk_index": 3,
                }
            ],
        }
class FakeChroma:
    """Client stub that hands out a ``FakeCollection`` and records its name."""

    def __init__(self) -> None:
        # Name passed to the most recent get_or_create_collection() call;
        # stays None until that method has been called.
        self.collection_name = None

    def get_or_create_collection(self, name):
        """Remember ``name`` and return a fresh ``FakeCollection``."""
        self.collection_name = name
        return FakeCollection()
class FallbackChroma:
    """Client stub that always hands out a ``FallbackCollection``."""

    def get_or_create_collection(self, name):
        """Return a fresh ``FallbackCollection``; ``name`` is ignored."""
        return FallbackCollection()
class RagRetrieverTest(unittest.IsolatedAsyncioTestCase):
    """Async tests for ``RagRetriever.retrieve`` against in-memory fakes."""

    async def test_retrieve_from_collection_filters_sources_and_builds_resources(self):
        """Only the hit from the requested source document is kept."""
        chroma = FakeChroma()
        retriever = RagRetriever(
            chroma_client=chroma,
            # Deterministic stand-in embedder: one fixed vector per input text.
            embed_texts=lambda texts, model_name="": [[0.1, 0.2] for _ in texts],
            hydrate_documents=False,
        )

        result = await retriever.retrieve(
            query="罚款幅度",
            collection_name="general_legal_kb",
            top_k=5,
            source_names=["广东省烟草专卖行政处罚裁量执行标准-rag.md"],
        )

        # The retriever must have opened the collection it was asked for.
        self.assertEqual(chroma.collection_name, "general_legal_kb")
        # FakeCollection returns two hits; only the one whose document_name
        # is listed in source_names may appear in the output.
        self.assertIn("烟草处罚裁量标准内容", result.rag_context)
        self.assertNotIn("其他来源内容", result.rag_context)
        self.assertEqual(len(result.rag_resources), 1)
        self.assertEqual(
            result.rag_resources[0]["document_name"],
            "广东省烟草专卖行政处罚裁量执行标准-rag.md",
        )

    async def test_retrieve_uses_keyword_fallback_when_vector_search_fails(self):
        """A failing vector query still yields the segment served by ``get``."""
        retriever = RagRetriever(
            # FallbackCollection.query() raises, so only get() can supply data.
            chroma_client=FallbackChroma(),
            embed_texts=lambda texts, model_name="": [[0.1, 0.2] for _ in texts],
            hydrate_documents=False,
        )

        result = await retriever.retrieve(
            query="未在当地烟草专卖批发企业进货",
            collection_name="general_legal_kb",
            source_names=["案由_行政处罚与反走私管理治理办法.md"],
        )

        self.assertIn("对应裁量档次内容", result.rag_context)
        self.assertEqual(len(result.chunks), 1)
        self.assertEqual(result.rag_resources[0]["segment_id"], "seg-fallback")
# Allow running this test module directly as a script.
if __name__ == "__main__":
    unittest.main()