"""Unit tests for ``RagRetriever`` driven by in-memory Chroma stand-ins.

The fakes below replace the Chroma client (``get_or_create_collection``) and
its collections (``query`` / ``get``), so the tests need neither a running
vector store nor a real embedding model.
"""

from __future__ import annotations

import unittest

from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever


def _fake_embed_texts(texts, model_name: str = "") -> list[list[float]]:
    """Return the same fixed 2-dimensional vector for every text in *texts*.

    Stands in for the retriever's ``embed_texts`` callable; ``model_name`` is
    accepted only for signature compatibility and is ignored.
    """
    return [[0.1, 0.2] for _ in texts]


class FakeCollection:
    """Collection whose vector ``query`` returns two canned hits.

    The first hit comes from the source document the test filters on; the
    second comes from an unrelated document and is expected to be filtered
    out. The nested-list layout mirrors Chroma's ``query`` result shape.
    """

    def query(self, **kwargs: object) -> dict:
        # Query arguments are ignored; the same canned result is always returned.
        return {
            "ids": [["seg-1", "seg-2"]],
            "documents": [["烟草处罚裁量标准内容", "其他来源内容"]],
            "metadatas": [[
                {
                    "document_name": "广东省烟草专卖行政处罚裁量执行标准-rag.md",
                    "chunk_index": 0,
                },
                {
                    "document_name": "其他.md",
                    "chunk_index": 1,
                },
            ]],
            "distances": [[0.0, 0.2]],
        }


class FallbackCollection:
    """Collection whose vector search fails but whose ``get`` still works.

    ``query`` raises to simulate an unavailable vector index. ``get`` returns a
    single flat (non-nested) record, mirroring the shape of Chroma's ``get``
    results, for the retriever's keyword fallback to match against.
    """

    def query(self, **kwargs: object) -> dict:
        raise RuntimeError("vector unavailable")

    def get(self, **kwargs: object) -> dict:
        return {
            "ids": ["seg-fallback"],
            "documents": ["未在当地烟草专卖批发企业进货,对应裁量档次内容"],
            "metadatas": [
                {
                    "document_name": "案由_行政处罚与反走私管理治理办法.md",
                    "chunk_index": 3,
                }
            ],
        }


class FakeChroma:
    """Chroma client stub that records the collection name it was asked for."""

    def __init__(self) -> None:
        self.collection_name: str | None = None

    def get_or_create_collection(self, name: str, **kwargs: object) -> FakeCollection:
        # **kwargs tolerates the optional arguments of the real Chroma method
        # (e.g. metadata, embedding_function) so the stub does not break if
        # the retriever starts passing them.
        self.collection_name = name
        return FakeCollection()


class FallbackChroma:
    """Chroma client stub that always hands out a ``FallbackCollection``."""

    def get_or_create_collection(self, name: str, **kwargs: object) -> FallbackCollection:
        return FallbackCollection()


class RagRetrieverTest(unittest.IsolatedAsyncioTestCase):
    """Behavioral tests for ``RagRetriever.retrieve``."""

    async def test_retrieve_from_collection_filters_sources_and_builds_resources(self):
        """Vector hits outside ``source_names`` are dropped from context and resources."""
        chroma = FakeChroma()
        retriever = RagRetriever(
            chroma_client=chroma,
            embed_texts=_fake_embed_texts,
            hydrate_documents=False,
        )

        result = await retriever.retrieve(
            query="罚款幅度",
            collection_name="general_legal_kb",
            top_k=5,
            source_names=["广东省烟草专卖行政处罚裁量执行标准-rag.md"],
        )

        # The requested collection name must be passed through to the client.
        self.assertEqual(chroma.collection_name, "general_legal_kb")
        self.assertIn("烟草处罚裁量标准内容", result.rag_context)
        self.assertNotIn("其他来源内容", result.rag_context)
        self.assertEqual(len(result.rag_resources), 1)
        self.assertEqual(
            result.rag_resources[0]["document_name"],
            "广东省烟草专卖行政处罚裁量执行标准-rag.md",
        )

    async def test_retrieve_uses_keyword_fallback_when_vector_search_fails(self):
        """When ``query`` raises, results come from the collection's ``get`` data."""
        retriever = RagRetriever(
            chroma_client=FallbackChroma(),
            embed_texts=_fake_embed_texts,
            hydrate_documents=False,
        )

        result = await retriever.retrieve(
            query="未在当地烟草专卖批发企业进货",
            collection_name="general_legal_kb",
            source_names=["案由_行政处罚与反走私管理治理办法.md"],
        )

        self.assertIn("对应裁量档次内容", result.rag_context)
        self.assertEqual(len(result.chunks), 1)
        # Check the length first so a missing resource fails with a clear
        # assertion message instead of an IndexError.
        self.assertEqual(len(result.rag_resources), 1)
        self.assertEqual(result.rag_resources[0]["segment_id"], "seg-fallback")


if __name__ == "__main__":
    unittest.main()