fix: restore rag chat permission feedback
This commit is contained in:
@@ -0,0 +1,91 @@
|
||||
"""RAG 聊天流式引用来源测试。"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
|
||||
|
||||
|
||||
async def _fake_generate_stream(**kwargs):
|
||||
yield "data: " + json.dumps(
|
||||
{
|
||||
"event": "message",
|
||||
"task_id": kwargs["task_id"],
|
||||
"message_id": kwargs["message_id"],
|
||||
"conversation_id": kwargs["conversation_id"],
|
||||
"answer": "测试回答",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
) + "\n\n"
|
||||
yield "data: " + json.dumps(
|
||||
{
|
||||
"event": "message_end",
|
||||
"task_id": kwargs["task_id"],
|
||||
"message_id": kwargs["message_id"],
|
||||
"conversation_id": kwargs["conversation_id"],
|
||||
"metadata": {},
|
||||
},
|
||||
ensure_ascii=False,
|
||||
) + "\n\n"
|
||||
|
||||
|
||||
async def _run_streaming_task() -> list[dict]:
|
||||
service = RagChatServiceImpl()
|
||||
task_id = "task-test"
|
||||
service._task_events[task_id] = []
|
||||
service._task_done[task_id] = False
|
||||
service._task_locks[task_id] = asyncio.Lock()
|
||||
|
||||
async def fake_retrieve_context(dataset_id, query):
|
||||
return [
|
||||
{
|
||||
"dataset_id": dataset_id,
|
||||
"document_id": "doc-1",
|
||||
"document_name": "引用文档.pdf",
|
||||
"source": "引用文档.pdf",
|
||||
"id": "segment-1",
|
||||
"score": 0.91,
|
||||
"hit_count": 3,
|
||||
"text": "引用片段内容",
|
||||
}
|
||||
], "测试知识库"
|
||||
|
||||
async def noop_async(*args, **kwargs):
|
||||
return None
|
||||
|
||||
service._retrieve_context = fake_retrieve_context
|
||||
service._finalize_message_record = noop_async
|
||||
service._maybe_schedule_auto_title = noop_async
|
||||
|
||||
with (
|
||||
patch(
|
||||
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream",
|
||||
_fake_generate_stream,
|
||||
),
|
||||
patch(
|
||||
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
await service._run_message_task(
|
||||
task_id=task_id,
|
||||
conversation_id="conversation-test",
|
||||
message_id="message-test",
|
||||
query="测试问题",
|
||||
app={"dataset_id": 7},
|
||||
)
|
||||
|
||||
return service._task_events[task_id]
|
||||
|
||||
|
||||
def test_streaming_message_end_includes_retriever_resources():
|
||||
events = asyncio.run(_run_streaming_task())
|
||||
|
||||
message_end = next(event for event in events if event.get("event") == "message_end")
|
||||
|
||||
resources = message_end["metadata"].get("retriever_resources")
|
||||
assert resources
|
||||
assert resources[0]["dataset_id"] == "7"
|
||||
assert resources[0]["dataset_name"] == "测试知识库"
|
||||
assert resources[0]["document_name"] == "引用文档.pdf"
|
||||
Reference in New Issue
Block a user