# 402 lines · 14 KiB · Python
"""RAG 聊天流式引用来源测试。"""
|
|
|
|
import asyncio
|
|
import json
|
|
from unittest.mock import patch
|
|
|
|
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
|
|
|
|
|
|
async def _fake_generate_stream(**kwargs):
    """Yield a fixed two-frame SSE stream in place of the real generator.

    The first frame is a ``message`` event carrying a canned answer; the
    second is a ``message_end`` event with empty metadata. Both frames echo
    the ``task_id``, ``message_id`` and ``conversation_id`` keyword
    arguments, so a missing one raises ``KeyError`` as soon as the first
    frame is requested. Any other keyword arguments are ignored.

    Yields:
        Strings of the form ``"data: <json>\\n\\n"``; non-ASCII text is kept
        verbatim in the JSON (``ensure_ascii=False``).
    """
    # Each frame differs only in its event name and one trailing field.
    frame_specs = (
        ("message", "answer", "测试回答"),
        ("message_end", "metadata", {}),
    )
    for event_name, tail_key, tail_value in frame_specs:
        payload = {
            "event": event_name,
            "task_id": kwargs["task_id"],
            "message_id": kwargs["message_id"],
            "conversation_id": kwargs["conversation_id"],
            tail_key: tail_value,
        }
        encoded = json.dumps(payload, ensure_ascii=False)
        yield "data: " + encoded + "\n\n"
|
|
|
|
|
|
async def _run_streaming_task() -> list[dict]:
    """Run one message task against stubbed collaborators and return its events.

    A fresh ``RagChatServiceImpl`` has its per-task bookkeeping seeded by
    hand, its retrieval / persistence / auto-title hooks replaced with
    in-memory stubs, and the module-level ``generate_stream`` and
    ``generate_followups`` patched before ``_run_message_task`` is awaited.

    Returns:
        The list stored under the task id in ``service._task_events`` once
        the task has finished.
    """
    task_key = "task-test"
    svc = RagChatServiceImpl()

    # Per-task state is created by hand because _run_message_task is called
    # directly here.
    svc._task_events[task_key] = []
    svc._task_done[task_key] = False
    svc._task_locks[task_key] = asyncio.Lock()

    async def stub_retrieve_context(dataset_id, query):
        # A single hit, returned together with a name string.
        hit = {
            "dataset_id": dataset_id,
            "document_id": "doc-1",
            "document_name": "引用文档.pdf",
            "source": "引用文档.pdf",
            "id": "segment-1",
            "score": 0.91,
            "hit_count": 3,
            "text": "引用片段内容",
        }
        return [hit], "测试知识库"

    async def do_nothing(*_args, **_kwargs):
        return None

    svc._retrieve_context = stub_retrieve_context
    svc._finalize_message_record = do_nothing
    svc._maybe_schedule_auto_title = do_nothing

    stream_target = "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream"
    followups_target = "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups"

    with patch(stream_target, _fake_generate_stream):
        with patch(followups_target, return_value=[]):
            await svc._run_message_task(
                task_id=task_key,
                conversation_id="conversation-test",
                message_id="message-test",
                query="测试问题",
                app={"dataset_id": 7},
            )

    return svc._task_events[task_key]
|
|
|
|
|
|
async def _run_streaming_task_with_attachment() -> tuple[list[dict], list[dict]]:
    """Run a message task with one explicit attachment and record its context.

    Knowledge-base retrieval and attachment retrieval are both stubbed, and
    the patched ``generate_stream`` records the ``context_chunks`` it is
    given before delegating to ``_fake_generate_stream``.

    Returns:
        ``(events, context_chunks)``: the events stored for the task in
        ``service._task_events`` and the chunks passed to the patched
        ``generate_stream``.
    """
    task_key = "task-attachment-test"
    svc = RagChatServiceImpl()

    # Per-task state is created by hand because _run_message_task is called
    # directly here.
    svc._task_events[task_key] = []
    svc._task_done[task_key] = False
    svc._task_locks[task_key] = asyncio.Lock()
    seen_chunks: list[dict] = []

    async def stub_retrieve_context(dataset_id, query):
        kb_hit = {
            "dataset_id": dataset_id,
            "document_id": "law-doc-1",
            "document_name": "处罚依据.md",
            "source": "处罚依据.md",
            "id": "law-segment-1",
            "score": 0.88,
            "hit_count": 2,
            "text": "正式知识库中的处罚依据",
        }
        return [kb_hit], "正式法规知识库"

    async def stub_retrieve_attachment_context(**call_kwargs):
        # The service is expected to pass the attachment id as "AttachmentId".
        attachment_hit = {
            "attachment_id": call_kwargs["AttachmentId"],
            "document_name": "用户上传.docx",
            "source": "用户上传.docx",
            "id": "attachment-segment-1",
            "score": 0.96,
            "text": "上传文档中的违法事实",
            "source_scope": "chat_attachment",
            "data_source_type": "chat_attachment",
        }
        return [attachment_hit], "用户上传.docx"

    async def recording_generate_stream(**call_kwargs):
        # Record what the service hands to the generator, then emit the
        # canned frames.
        seen_chunks.extend(call_kwargs["context_chunks"])
        async for frame in _fake_generate_stream(**call_kwargs):
            yield frame

    async def do_nothing(*_args, **_kwargs):
        return None

    svc._retrieve_context = stub_retrieve_context
    svc._retrieve_attachment_context = stub_retrieve_attachment_context
    svc._finalize_message_record = do_nothing
    svc._maybe_schedule_auto_title = do_nothing

    stream_target = "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream"
    followups_target = "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups"

    with patch(stream_target, recording_generate_stream):
        with patch(followups_target, return_value=[]):
            await svc._run_message_task(
                task_id=task_key,
                conversation_id="conversation-test",
                message_id="message-test",
                query="这份材料的违法内容会受到什么处罚",
                app={"dataset_id": 7},
                current_user_id=100,
                tenant_code="tenant-a",
                attachment_id="attachment-1",
            )

    return svc._task_events[task_key], seen_chunks
|
|
|
|
|
|
async def _run_streaming_task_with_conversation_attachment_fallback() -> list[dict]:
    """Run a message task with ``attachment_id=None`` and record its context.

    The conversation-level attachment resolver is stubbed to return a single
    attachment id (after asserting the conversation id and user id it is
    called with). Knowledge-base and attachment retrieval are stubbed as
    well, and the patched ``generate_stream`` records the ``context_chunks``
    it receives.

    Returns:
        The chunks passed to the patched ``generate_stream``.
    """
    task_key = "task-attachment-fallback-test"
    svc = RagChatServiceImpl()

    # Per-task state is created by hand because _run_message_task is called
    # directly here.
    svc._task_events[task_key] = []
    svc._task_done[task_key] = False
    svc._task_locks[task_key] = asyncio.Lock()
    seen_chunks: list[dict] = []

    async def stub_resolve_attachment_ids(**call_kwargs):
        # Check the lookup is made for the conversation and user given below.
        assert call_kwargs["ConversationId"] == "conversation-test"
        assert call_kwargs["CurrentUserId"] == 100
        return ["attachment-1"]

    async def stub_retrieve_context(dataset_id, query):
        kb_hit = {
            "dataset_id": dataset_id,
            "document_id": "law-doc-1",
            "document_name": "处罚依据.md",
            "source": "处罚依据.md",
            "id": "law-segment-1",
            "score": 0.88,
            "hit_count": 2,
            "text": "正式知识库中的处罚依据",
        }
        return [kb_hit], "正式法规知识库"

    async def stub_retrieve_attachment_context(**call_kwargs):
        attachment_hit = {
            "attachment_id": call_kwargs["AttachmentId"],
            "document_name": "用户上传.docx",
            "source": "用户上传.docx",
            "id": "attachment-segment-1",
            "score": 0.96,
            "text": "上传文档中的违法事实",
            "source_scope": "chat_attachment",
            "data_source_type": "chat_attachment",
        }
        return [attachment_hit], "用户上传.docx"

    async def recording_generate_stream(**call_kwargs):
        # Record what the service hands to the generator, then emit the
        # canned frames.
        seen_chunks.extend(call_kwargs["context_chunks"])
        async for frame in _fake_generate_stream(**call_kwargs):
            yield frame

    async def do_nothing(*_args, **_kwargs):
        return None

    svc._resolve_attachment_ids_for_conversation = stub_resolve_attachment_ids
    svc._retrieve_context = stub_retrieve_context
    svc._retrieve_attachment_context = stub_retrieve_attachment_context
    svc._finalize_message_record = do_nothing
    svc._maybe_schedule_auto_title = do_nothing

    stream_target = "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream"
    followups_target = "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups"

    with patch(stream_target, recording_generate_stream):
        with patch(followups_target, return_value=[]):
            await svc._run_message_task(
                task_id=task_key,
                conversation_id="conversation-test",
                message_id="message-test",
                query="江小妹违法了什么",
                app={"dataset_id": 7},
                current_user_id=100,
                tenant_code="tenant-a",
                attachment_id=None,
            )

    return seen_chunks
|
|
|
|
|
|
async def _run_streaming_task_with_multiple_attachments() -> tuple[list[dict], list[dict]]:
    """Run a message task with two attachment ids and record its context.

    The attachment stub derives every field of its single chunk from the
    ``AttachmentId`` it is called with, so chunks from different attachments
    are distinguishable. The patched ``generate_stream`` records the
    ``context_chunks`` it receives.

    Returns:
        ``(events, context_chunks)``: the events stored for the task in
        ``service._task_events`` and the chunks passed to the patched
        ``generate_stream``.
    """
    task_key = "task-multiple-attachments-test"
    svc = RagChatServiceImpl()

    # Per-task state is created by hand because _run_message_task is called
    # directly here.
    svc._task_events[task_key] = []
    svc._task_done[task_key] = False
    svc._task_locks[task_key] = asyncio.Lock()
    seen_chunks: list[dict] = []

    async def stub_retrieve_context(dataset_id, query):
        kb_hit = {
            "dataset_id": dataset_id,
            "document_id": "law-doc-1",
            "document_name": "处罚依据.md",
            "source": "处罚依据.md",
            "id": "law-segment-1",
            "score": 0.88,
            "hit_count": 2,
            "text": "正式知识库中的处罚依据",
        }
        return [kb_hit], "正式法规知识库"

    async def stub_retrieve_attachment_context(**call_kwargs):
        attachment_id = call_kwargs["AttachmentId"]
        file_name = f"{attachment_id}.docx"
        attachment_hit = {
            "attachment_id": attachment_id,
            "document_name": file_name,
            "source": file_name,
            "id": f"{attachment_id}-segment-1",
            "score": 0.96,
            "text": f"{attachment_id} 中的违法事实",
            "source_scope": "chat_attachment",
            "data_source_type": "chat_attachment",
        }
        return [attachment_hit], file_name

    async def recording_generate_stream(**call_kwargs):
        # Record what the service hands to the generator, then emit the
        # canned frames.
        seen_chunks.extend(call_kwargs["context_chunks"])
        async for frame in _fake_generate_stream(**call_kwargs):
            yield frame

    async def do_nothing(*_args, **_kwargs):
        return None

    svc._retrieve_context = stub_retrieve_context
    svc._retrieve_attachment_context = stub_retrieve_attachment_context
    svc._finalize_message_record = do_nothing
    svc._maybe_schedule_auto_title = do_nothing

    stream_target = "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream"
    followups_target = "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups"

    with patch(stream_target, recording_generate_stream):
        with patch(followups_target, return_value=[]):
            await svc._run_message_task(
                task_id=task_key,
                conversation_id="conversation-test",
                message_id="message-test",
                query="这些材料的违法内容会受到什么处罚",
                app={"dataset_id": 7},
                current_user_id=100,
                tenant_code="tenant-a",
                attachment_ids=["attachment-1", "attachment-2"],
            )

    return svc._task_events[task_key], seen_chunks
|
|
|
|
|
|
def test_streaming_message_end_includes_retriever_resources():
    """The ``message_end`` event carries the retrieved chunk as a resource.

    Checks that the first entry of ``metadata["retriever_resources"]`` has
    the dataset id as the string ``"7"`` together with the expected dataset
    and document names.
    """
    recorded = asyncio.run(_run_streaming_task())

    # First message_end event; StopIteration here means none was emitted.
    end_event = next(
        filter(lambda item: item.get("event") == "message_end", recorded)
    )

    resources = end_event["metadata"].get("retriever_resources")
    assert resources

    first_resource = resources[0]
    assert first_resource["dataset_id"] == "7"
    assert first_resource["dataset_name"] == "测试知识库"
    assert first_resource["document_name"] == "引用文档.pdf"
|
|
|
|
|
|
def test_message_task_merges_attachment_facts_and_formal_kb_context():
    """Attachment chunks come before knowledge-base chunks in context and resources.

    Checks both the ``context_chunks`` handed to the generator and the
    ``retriever_resources`` reported on the ``message_end`` event: the
    attachment entry is first, the formal knowledge-base entry second.
    """
    events, context_chunks = asyncio.run(_run_streaming_task_with_attachment())

    # First message_end event; StopIteration here means none was emitted.
    message_end = next(event for event in events if event.get("event") == "message_end")
    resources = message_end["metadata"].get("retriever_resources")

    scopes = [chunk.get("source_scope") for chunk in context_chunks]
    assert scopes == ["chat_attachment", "formal_kb"]
    assert context_chunks[0]["document_name"] == "用户上传.docx"
    assert context_chunks[1]["document_name"] == "处罚依据.md"

    # Guard before indexing, as the retriever-resources test does: a missing
    # or empty list should fail as an assertion, not as TypeError/IndexError.
    assert resources, "message_end metadata has no retriever_resources"
    assert resources[0]["data_source_type"] == "chat_attachment"
    assert resources[1]["data_source_type"] == "formal_kb"
|
|
|
|
|
|
def test_message_task_uses_active_conversation_attachment_when_request_omits_attachment_id():
    """With no attachment id in the request, the resolved attachment is used.

    The helper stubs the conversation-level resolver to return
    ``"attachment-1"``; the generator must then receive that attachment's
    chunk first, followed by the formal knowledge-base chunk.
    """
    captured = asyncio.run(
        _run_streaming_task_with_conversation_attachment_fallback()
    )

    observed_scopes = [item.get("source_scope") for item in captured]
    assert observed_scopes == ["chat_attachment", "formal_kb"]

    leading_chunk = captured[0]
    assert leading_chunk["attachment_id"] == "attachment-1"
|
|
|
|
|
|
def test_message_task_merges_multiple_attachment_contexts_before_formal_kb():
    """Chunks from every attachment precede the knowledge-base chunk.

    With two attachment ids in the request, the generator must receive the
    two attachment chunks in request order followed by the formal
    knowledge-base chunk, and ``retriever_resources`` on the ``message_end``
    event must follow the same order.
    """
    events, context_chunks = asyncio.run(_run_streaming_task_with_multiple_attachments())

    # First message_end event; StopIteration here means none was emitted.
    message_end = next(event for event in events if event.get("event") == "message_end")
    resources = message_end["metadata"].get("retriever_resources")

    scopes = [chunk.get("source_scope") for chunk in context_chunks]
    assert scopes == ["chat_attachment", "chat_attachment", "formal_kb"]
    assert [chunk.get("attachment_id") for chunk in context_chunks[:2]] == ["attachment-1", "attachment-2"]

    # Guard before indexing, as the retriever-resources test does: a missing
    # or empty list should fail as an assertion, not as TypeError/IndexError.
    assert resources, "message_end metadata has no retriever_resources"
    assert resources[0]["data_source_type"] == "chat_attachment"
    assert resources[1]["data_source_type"] == "chat_attachment"
    assert resources[2]["data_source_type"] == "formal_kb"
|
|
|
|
|
|
def test_message_attachment_metadata_is_stored_for_history_display():
    """``_build_user_message_files`` turns attachment rows into file entries.

    Two attachments are passed in, a ``.docx`` and a ``.png``. Each expected
    entry repeats the attachment id under ``id`` and ``upload_file_id`` and
    the original name under ``name`` and ``fileName``; the PNG is expected
    with type ``"image"`` and the DOCX with type ``"file"``.
    """
    docx_mime = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"

    def expected_entry(file_id, file_name, file_type, mime, size):
        # Key order is irrelevant to dict equality; it mirrors the original
        # literal for readability.
        return {
            "id": file_id,
            "upload_file_id": file_id,
            "name": file_name,
            "fileName": file_name,
            "type": file_type,
            "transfer_method": "local_file",
            "contentType": mime,
            "fileSize": size,
            "belongs_to": "user",
            "usage": "temporary_attachment",
        }

    attachments = [
        {
            "attachment_id": "att-doc",
            "original_name": "处罚材料.docx",
            "content_type": docx_mime,
            "file_size": 1234,
        },
        {
            "attachment_id": "att-img",
            "original_name": "现场照片.png",
            "content_type": "image/png",
            "file_size": 4567,
        },
    ]

    svc = RagChatServiceImpl()
    built = svc._build_user_message_files(attachments)

    assert built == [
        expected_entry("att-doc", "处罚材料.docx", "file", docx_mime, 1234),
        expected_entry("att-img", "现场照片.png", "image", "image/png", 4567),
    ]
|