"""RAG 聊天流式引用来源测试。""" import asyncio import json from unittest.mock import patch from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl async def _fake_generate_stream(**kwargs): yield "data: " + json.dumps( { "event": "message", "task_id": kwargs["task_id"], "message_id": kwargs["message_id"], "conversation_id": kwargs["conversation_id"], "answer": "测试回答", }, ensure_ascii=False, ) + "\n\n" yield "data: " + json.dumps( { "event": "message_end", "task_id": kwargs["task_id"], "message_id": kwargs["message_id"], "conversation_id": kwargs["conversation_id"], "metadata": {}, }, ensure_ascii=False, ) + "\n\n" async def _run_streaming_task() -> list[dict]: service = RagChatServiceImpl() task_id = "task-test" service._task_events[task_id] = [] service._task_done[task_id] = False service._task_locks[task_id] = asyncio.Lock() async def fake_retrieve_context(dataset_id, query): return [ { "dataset_id": dataset_id, "document_id": "doc-1", "document_name": "引用文档.pdf", "source": "引用文档.pdf", "id": "segment-1", "score": 0.91, "hit_count": 3, "text": "引用片段内容", } ], "测试知识库" async def noop_async(*args, **kwargs): return None service._retrieve_context = fake_retrieve_context service._finalize_message_record = noop_async service._maybe_schedule_auto_title = noop_async with ( patch( "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream", _fake_generate_stream, ), patch( "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups", return_value=[], ), ): await service._run_message_task( task_id=task_id, conversation_id="conversation-test", message_id="message-test", query="测试问题", app={"dataset_id": 7}, ) return service._task_events[task_id] async def _run_streaming_task_with_attachment() -> tuple[list[dict], list[dict]]: service = RagChatServiceImpl() task_id = "task-attachment-test" service._task_events[task_id] = [] service._task_done[task_id] = False service._task_locks[task_id] = asyncio.Lock() captured_context_chunks: list[dict] = [] async def fake_retrieve_context(dataset_id, query): return [ { "dataset_id": dataset_id, "document_id": "law-doc-1", "document_name": "处罚依据.md", "source": "处罚依据.md", "id": "law-segment-1", "score": 0.88, "hit_count": 2, "text": "正式知识库中的处罚依据", } ], "正式法规知识库" async def fake_retrieve_attachment_context(**kwargs): return [ { "attachment_id": kwargs["AttachmentId"], "document_name": "用户上传.docx", "source": "用户上传.docx", "id": "attachment-segment-1", "score": 0.96, "text": "上传文档中的违法事实", "source_scope": "chat_attachment", "data_source_type": "chat_attachment", } ], "用户上传.docx" async def fake_generate_stream(**kwargs): captured_context_chunks.extend(kwargs["context_chunks"]) async for event in _fake_generate_stream(**kwargs): yield event async def noop_async(*args, **kwargs): return None service._retrieve_context = fake_retrieve_context service._retrieve_attachment_context = fake_retrieve_attachment_context service._finalize_message_record = noop_async service._maybe_schedule_auto_title = noop_async with ( patch( "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream", fake_generate_stream, ), patch( "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups", return_value=[], ), ): await service._run_message_task( task_id=task_id, conversation_id="conversation-test", message_id="message-test", query="这份材料的违法内容会受到什么处罚", app={"dataset_id": 7}, current_user_id=100, tenant_code="tenant-a", attachment_id="attachment-1", ) return service._task_events[task_id], captured_context_chunks async def _run_streaming_task_with_conversation_attachment_fallback() -> list[dict]: service = RagChatServiceImpl() task_id = "task-attachment-fallback-test" service._task_events[task_id] = [] service._task_done[task_id] = False service._task_locks[task_id] = asyncio.Lock() captured_context_chunks: list[dict] = [] async def fake_resolve_attachment_ids_for_conversation(**kwargs): assert kwargs["ConversationId"] == "conversation-test" assert kwargs["CurrentUserId"] == 100 return ["attachment-1"] async def fake_retrieve_context(dataset_id, query): return [ { "dataset_id": dataset_id, "document_id": "law-doc-1", "document_name": "处罚依据.md", "source": "处罚依据.md", "id": "law-segment-1", "score": 0.88, "hit_count": 2, "text": "正式知识库中的处罚依据", } ], "正式法规知识库" async def fake_retrieve_attachment_context(**kwargs): return [ { "attachment_id": kwargs["AttachmentId"], "document_name": "用户上传.docx", "source": "用户上传.docx", "id": "attachment-segment-1", "score": 0.96, "text": "上传文档中的违法事实", "source_scope": "chat_attachment", "data_source_type": "chat_attachment", } ], "用户上传.docx" async def fake_generate_stream(**kwargs): captured_context_chunks.extend(kwargs["context_chunks"]) async for event in _fake_generate_stream(**kwargs): yield event async def noop_async(*args, **kwargs): return None service._resolve_attachment_ids_for_conversation = fake_resolve_attachment_ids_for_conversation service._retrieve_context = fake_retrieve_context service._retrieve_attachment_context = fake_retrieve_attachment_context service._finalize_message_record = noop_async service._maybe_schedule_auto_title = noop_async with ( patch( "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream", fake_generate_stream, ), patch( "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups", return_value=[], ), ): await service._run_message_task( task_id=task_id, conversation_id="conversation-test", message_id="message-test", query="江小妹违法了什么", app={"dataset_id": 7}, current_user_id=100, tenant_code="tenant-a", attachment_id=None, ) return captured_context_chunks async def _run_streaming_task_with_multiple_attachments() -> tuple[list[dict], list[dict]]: service = RagChatServiceImpl() task_id = "task-multiple-attachments-test" service._task_events[task_id] = [] service._task_done[task_id] = False service._task_locks[task_id] = asyncio.Lock() captured_context_chunks: list[dict] = [] async def fake_retrieve_context(dataset_id, query): return [ { "dataset_id": dataset_id, "document_id": "law-doc-1", "document_name": "处罚依据.md", "source": "处罚依据.md", "id": "law-segment-1", "score": 0.88, "hit_count": 2, "text": "正式知识库中的处罚依据", } ], "正式法规知识库" async def fake_retrieve_attachment_context(**kwargs): attachment_id = kwargs["AttachmentId"] return [ { "attachment_id": attachment_id, "document_name": f"{attachment_id}.docx", "source": f"{attachment_id}.docx", "id": f"{attachment_id}-segment-1", "score": 0.96, "text": f"{attachment_id} 中的违法事实", "source_scope": "chat_attachment", "data_source_type": "chat_attachment", } ], f"{attachment_id}.docx" async def fake_generate_stream(**kwargs): captured_context_chunks.extend(kwargs["context_chunks"]) async for event in _fake_generate_stream(**kwargs): yield event async def noop_async(*args, **kwargs): return None service._retrieve_context = fake_retrieve_context service._retrieve_attachment_context = fake_retrieve_attachment_context service._finalize_message_record = noop_async service._maybe_schedule_auto_title = noop_async with ( patch( "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream", fake_generate_stream, ), patch( "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups", return_value=[], ), ): await service._run_message_task( task_id=task_id, conversation_id="conversation-test", message_id="message-test", query="这些材料的违法内容会受到什么处罚", app={"dataset_id": 7}, current_user_id=100, tenant_code="tenant-a", attachment_ids=["attachment-1", "attachment-2"], ) return service._task_events[task_id], captured_context_chunks def test_streaming_message_end_includes_retriever_resources(): events = asyncio.run(_run_streaming_task()) message_end = next(event for event in events if event.get("event") == "message_end") resources = message_end["metadata"].get("retriever_resources") assert resources assert resources[0]["dataset_id"] == "7" assert resources[0]["dataset_name"] == "测试知识库" assert resources[0]["document_name"] == "引用文档.pdf" def test_message_task_merges_attachment_facts_and_formal_kb_context(): events, context_chunks = asyncio.run(_run_streaming_task_with_attachment()) message_end = next(event for event in events if event.get("event") == "message_end") resources = message_end["metadata"].get("retriever_resources") scopes = [chunk.get("source_scope") for chunk in context_chunks] assert scopes == ["chat_attachment", "formal_kb"] assert context_chunks[0]["document_name"] == "用户上传.docx" assert context_chunks[1]["document_name"] == "处罚依据.md" assert resources[0]["data_source_type"] == "chat_attachment" assert resources[1]["data_source_type"] == "formal_kb" def test_message_task_uses_active_conversation_attachment_when_request_omits_attachment_id(): context_chunks = asyncio.run(_run_streaming_task_with_conversation_attachment_fallback()) scopes = [chunk.get("source_scope") for chunk in context_chunks] assert scopes == ["chat_attachment", "formal_kb"] assert context_chunks[0]["attachment_id"] == "attachment-1" def test_message_task_merges_multiple_attachment_contexts_before_formal_kb(): events, context_chunks = asyncio.run(_run_streaming_task_with_multiple_attachments()) message_end = next(event for event in events if event.get("event") == "message_end") resources = message_end["metadata"].get("retriever_resources") scopes = [chunk.get("source_scope") for chunk in context_chunks] assert scopes == ["chat_attachment", "chat_attachment", "formal_kb"] assert [chunk.get("attachment_id") for chunk in context_chunks[:2]] == ["attachment-1", "attachment-2"] assert resources[0]["data_source_type"] == "chat_attachment" assert resources[1]["data_source_type"] == "chat_attachment" assert resources[2]["data_source_type"] == "formal_kb" def test_message_attachment_metadata_is_stored_for_history_display(): service = RagChatServiceImpl() files = service._build_user_message_files( [ { "attachment_id": "att-doc", "original_name": "处罚材料.docx", "content_type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "file_size": 1234, }, { "attachment_id": "att-img", "original_name": "现场照片.png", "content_type": "image/png", "file_size": 4567, }, ] ) assert files == [ { "id": "att-doc", "upload_file_id": "att-doc", "name": "处罚材料.docx", "fileName": "处罚材料.docx", "type": "file", "transfer_method": "local_file", "contentType": "application/vnd.openxmlformats-officedocument.wordprocessingml.document", "fileSize": 1234, "belongs_to": "user", "usage": "temporary_attachment", }, { "id": "att-img", "upload_file_id": "att-img", "name": "现场照片.png", "fileName": "现场照片.png", "type": "image", "transfer_method": "local_file", "contentType": "image/png", "fileSize": 4567, "belongs_to": "user", "usage": "temporary_attachment", }, ]