feat(rag): add temporary chat attachments
This commit is contained in:
@@ -0,0 +1,372 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
||||
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
|
||||
|
||||
|
||||
def _service() -> RagChatAttachmentServiceImpl:
    """Build a service instance with no Chroma client and a stub embedder."""

    def _stub_embed(texts, model_name=""):
        # One fixed single-dimension vector per input text.
        return [[0.1] for _ in texts]

    return RagChatAttachmentServiceImpl(chroma_client=None, embed_texts=_stub_embed)
|
||||
|
||||
|
||||
def test_default_expiry_is_seven_days_from_now():
    """The default expiry for a fixed instant is exactly seven days later."""
    reference = datetime(2026, 5, 25, 12, 0, 0, tzinfo=timezone.utc)
    expected = reference + timedelta(days=7)

    assert _service()._default_expires_at(reference) == expected
|
||||
|
||||
|
||||
def test_collection_name_contains_isolation_components():
    """The collection name carries tenant, user and attachment parts.

    The raw conversation id must not appear verbatim, and the whole name
    must stay within 120 characters.
    """
    name = _service().BuildCollectionName(
        TenantCode="gd-tobacco",
        UserId=42,
        ConversationId="conversation-abc-123",
        AttachmentId="attach-xyz",
    )

    assert name.startswith("chat_attachment_gd_tobacco_42_")
    assert name.endswith("_attach_xyz")
    assert "conversation-abc-123" not in name
    assert len(name) <= 120
|
||||
|
||||
|
||||
def test_validate_attachment_scope_rejects_other_user_and_conversation():
    """Scope validation raises 403 for a different user or a different conversation.

    The stored record belongs to user 100 in conversation ``conv-a`` of
    ``tenant-a``; each check changes exactly one of those identifiers.
    """
    attachment = {
        "attachment_id": "att-1",
        "tenant_code": "tenant-a",
        "user_id": 100,
        "conversation_id": "conv-a",
        "indexing_status": "completed",
        "expires_at": datetime.now(timezone.utc) + timedelta(days=1),
        "deleted_at": None,
    }
    svc = _service()

    def check_scope(user_id, conversation_id):
        # Only the caller identity varies between the two checks below.
        svc._assert_attachment_scope(
            attachment,
            CurrentUserId=user_id,
            TenantCode="tenant-a",
            ConversationId=conversation_id,
            RequireCompleted=True,
            Now=datetime.now(timezone.utc),
        )

    with pytest.raises(LeauditException) as user_exc:
        check_scope(101, "conv-a")
    assert user_exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN

    with pytest.raises(LeauditException) as conversation_exc:
        check_scope(100, "conv-b")
    assert conversation_exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_get_attachment_requires_request_conversation_when_provided():
    """A mismatched conversation is rejected with 403 even when completion is not required.

    The error message is expected to mention "当前会话".
    """
    attachment = {
        "attachment_id": "att-1",
        "tenant_code": "tenant-a",
        "user_id": 100,
        "conversation_id": "conv-a",
        "indexing_status": "completed",
        "expires_at": datetime.now(timezone.utc) + timedelta(days=1),
        "deleted_at": None,
    }
    svc = _service()

    with pytest.raises(LeauditException) as exc:
        svc._assert_attachment_scope(
            attachment,
            CurrentUserId=100,
            TenantCode="tenant-a",
            ConversationId="conv-b",
            RequireCompleted=False,
            Now=datetime.now(timezone.utc),
        )

    raised = exc.value
    assert raised.status == StatusCodeEnum.HTTP_403_FORBIDDEN
    assert "当前会话" in raised.message
|
||||
|
||||
|
||||
def test_get_attachment_rejects_same_user_attachment_from_another_conversation(monkeypatch):
    """``GetAttachment`` raises 403 when the stored conversation differs from the request's.

    The record lookup is stubbed so the check runs against a fixed record
    owned by the same user and tenant but bound to ``conv-a``.
    """
    stored = {
        "attachment_id": "att-1",
        "tenant_code": "tenant-a",
        "user_id": 100,
        "conversation_id": "conv-a",
        "indexing_status": "completed",
        "expires_at": datetime.now(timezone.utc) + timedelta(days=1),
        "deleted_at": None,
    }

    async def _stub_lookup(_attachment_id):
        return stored

    svc = _service()
    monkeypatch.setattr(svc, "_get_attachment_record", _stub_lookup)

    request = dict(
        CurrentUserId=100,
        UserArea=None,
        UserRole=None,
        TenantCode="tenant-a",
        TenantName=None,
        ConversationId="conv-b",
        AttachmentId="att-1",
    )
    with pytest.raises(LeauditException) as exc:
        asyncio.run(svc.GetAttachment(**request))

    assert exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_delete_attachment_rejects_same_user_attachment_from_another_conversation(monkeypatch):
    """``DeleteAttachment`` raises 403 when the stored conversation differs from the request's.

    The record lookup is stubbed so the check runs against a fixed record
    owned by the same user and tenant but bound to ``conv-a``.
    """
    stored = {
        "attachment_id": "att-1",
        "tenant_code": "tenant-a",
        "user_id": 100,
        "conversation_id": "conv-a",
        "indexing_status": "completed",
        "expires_at": datetime.now(timezone.utc) + timedelta(days=1),
        "deleted_at": None,
    }

    async def _stub_lookup(_attachment_id):
        return stored

    svc = _service()
    monkeypatch.setattr(svc, "_get_attachment_record", _stub_lookup)

    request = dict(
        CurrentUserId=100,
        UserArea=None,
        UserRole=None,
        TenantCode="tenant-a",
        TenantName=None,
        ConversationId="conv-b",
        AttachmentId="att-1",
    )
    with pytest.raises(LeauditException) as exc:
        asyncio.run(svc.DeleteAttachment(**request))

    assert exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_validate_attachment_scope_rejects_other_tenant():
    """Scope validation raises 403 when only the tenant code differs from the record's."""
    attachment = {
        "attachment_id": "att-1",
        "tenant_code": "tenant-a",
        "user_id": 100,
        "conversation_id": "conv-a",
        "indexing_status": "completed",
        "expires_at": datetime.now(timezone.utc) + timedelta(days=1),
        "deleted_at": None,
    }
    svc = _service()

    # User and conversation match the record; only the tenant is foreign.
    caller = dict(
        CurrentUserId=100,
        TenantCode="tenant-b",
        ConversationId="conv-a",
        RequireCompleted=True,
    )
    with pytest.raises(LeauditException) as exc:
        svc._assert_attachment_scope(attachment, Now=datetime.now(timezone.utc), **caller)

    assert exc.value.status == StatusCodeEnum.HTTP_403_FORBIDDEN
|
||||
|
||||
|
||||
def test_validate_attachment_scope_rejects_expired_or_incomplete_attachment():
    """Expired or still-indexing attachments are rejected with 400 when completion is required.

    The expired case must mention "已过期" and the incomplete case
    "尚未完成" in the error message.
    """
    svc = _service()
    expired = {
        "attachment_id": "att-1",
        "tenant_code": "tenant-a",
        "user_id": 100,
        "conversation_id": "conv-a",
        "indexing_status": "completed",
        "expires_at": datetime.now(timezone.utc) - timedelta(seconds=1),
        "deleted_at": None,
    }
    # Same owner, but not yet expired and still being indexed.
    waiting = dict(
        expired,
        indexing_status="indexing",
        expires_at=datetime.now(timezone.utc) + timedelta(days=1),
    )

    def check_scope(attachment):
        # The caller matches the record owner in both cases.
        svc._assert_attachment_scope(
            attachment,
            CurrentUserId=100,
            TenantCode="tenant-a",
            ConversationId="conv-a",
            RequireCompleted=True,
            Now=datetime.now(timezone.utc),
        )

    with pytest.raises(LeauditException) as expired_exc:
        check_scope(expired)
    assert expired_exc.value.status == StatusCodeEnum.HTTP_400_BAD_REQUEST
    assert "已过期" in expired_exc.value.message

    with pytest.raises(LeauditException) as waiting_exc:
        check_scope(waiting)
    assert waiting_exc.value.status == StatusCodeEnum.HTTP_400_BAD_REQUEST
    assert "尚未完成" in waiting_exc.value.message
|
||||
|
||||
|
||||
def test_build_chunks_includes_isolation_metadata():
    """The first built chunk carries tenant, user, conversation and attachment metadata."""
    produced = _service().BuildChunks(
        AttachmentId="att-1",
        TenantCode="tenant-a",
        UserId=100,
        ConversationId="conv-a",
        FileName="处罚材料.docx",
        PageTexts=[(1, "第一段违法事实。\n\n第二段处罚线索。")],
        ChunkMaxSize=20,
        ChunkOverlap=0,
    )

    assert produced
    first_meta = produced[0]["metadata"]

    # The user id is expected as a string, the page number as an int.
    expected = {
        "tenant_code": "tenant-a",
        "user_id": "100",
        "conversation_id": "conv-a",
        "attachment_id": "att-1",
        "source_scope": "chat_attachment",
        "document_name": "处罚材料.docx",
        "page": 1,
    }
    for key, value in expected.items():
        assert first_meta[key] == value
|
||||
|
||||
|
||||
def test_resolve_active_attachment_id_uses_user_conversation_tenant_and_completed_state(monkeypatch):
    """The single-attachment resolver filters by owner, conversation and state.

    A fake async session records the SQL text and bound parameters so the
    test can assert on the query without a database.
    """
    service = _service()
    captured_sql = {}
    captured_params = {}

    class FakeResult:
        """Result stub that answers both single-row and multi-row access."""

        def mappings(self):
            return self

        def first(self):
            return {"attachment_id": "att-active"}

        def all(self):
            return [{"attachment_id": "att-active"}]

    class FakeSession:
        """Async context manager that records the executed statement."""

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def execute(self, statement, params=None):
            captured_sql["value"] = str(statement)
            captured_params.update(params or {})
            return FakeResult()

    class FakeSessionFactory:
        def __call__(self):
            return FakeSession()

    # Set the class-level flag through monkeypatch so it is restored at
    # teardown; a direct assignment would leak into every later test.
    # NOTE(review): presumably this makes the service skip its schema check — confirm.
    monkeypatch.setattr(type(service), "_attachment_schema_checked", True, raising=False)
    monkeypatch.setattr(
        "fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl.GetAsyncSession",
        FakeSessionFactory(),
    )

    attachment_id = asyncio.run(
        service.ResolveActiveAttachmentIdForConversation(
            CurrentUserId=100,
            TenantCode="tenant-a",
            UserArea="云浮",
            ConversationId="conv-a",
        )
    )

    assert attachment_id == "att-active"
    assert "user_id = :user_id" in captured_sql["value"]
    assert "conversation_id = :conversation_id" in captured_sql["value"]
    assert "indexing_status = 'completed'" in captured_sql["value"]
    assert "expires_at > NOW()" in captured_sql["value"]
    assert "deleted_at IS NULL" in captured_sql["value"]
    assert captured_params == {
        "user_id": 100,
        "conversation_id": "conv-a",
        "tenant_code": "tenant-a",
        "user_area": "云浮",
    }
|
||||
|
||||
|
||||
def test_resolve_active_attachment_ids_returns_all_completed_conversation_attachments(monkeypatch):
    """The multi-attachment resolver returns every row, ordered and unlimited.

    A fake async session records the SQL text and bound parameters so the
    test can assert on the query without a database.
    """
    service = _service()
    captured_sql = {}
    captured_params = {}

    class FakeResult:
        """Result stub that yields two attachment rows."""

        def mappings(self):
            return self

        def all(self):
            return [{"attachment_id": "att-1"}, {"attachment_id": "att-2"}]

    class FakeSession:
        """Async context manager that records the executed statement."""

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        async def execute(self, statement, params=None):
            captured_sql["value"] = str(statement)
            captured_params.update(params or {})
            return FakeResult()

    class FakeSessionFactory:
        def __call__(self):
            return FakeSession()

    # Set the class-level flag through monkeypatch so it is restored at
    # teardown; a direct assignment would leak into every later test.
    # NOTE(review): presumably this makes the service skip its schema check — confirm.
    monkeypatch.setattr(type(service), "_attachment_schema_checked", True, raising=False)
    monkeypatch.setattr(
        "fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl.GetAsyncSession",
        FakeSessionFactory(),
    )

    attachment_ids = asyncio.run(
        service.ResolveActiveAttachmentIdsForConversation(
            CurrentUserId=100,
            TenantCode="tenant-a",
            UserArea="云浮",
            ConversationId="conv-a",
        )
    )

    assert attachment_ids == ["att-1", "att-2"]
    assert "ORDER BY indexing_completed_at DESC NULLS LAST, created_at DESC" in captured_sql["value"]
    assert "LIMIT 1" not in captured_sql["value"]
    assert captured_params == {
        "user_id": 100,
        "conversation_id": "conv-a",
        "tenant_code": "tenant-a",
        "user_area": "云浮",
    }
|
||||
@@ -79,6 +79,232 @@ async def _run_streaming_task() -> list[dict]:
|
||||
return service._task_events[task_id]
|
||||
|
||||
|
||||
async def _run_streaming_task_with_attachment() -> tuple[list[dict], list[dict]]:
    """Drive ``_run_message_task`` with one explicit attachment and stubbed retrieval.

    Returns:
        The events recorded for the task, and the context chunks that were
        handed to the patched ``generate_stream``.
    """
    task_id = "task-attachment-test"
    seen_chunks: list[dict] = []

    service = RagChatServiceImpl()
    service._task_events[task_id] = []
    service._task_done[task_id] = False
    service._task_locks[task_id] = asyncio.Lock()

    async def fake_retrieve_context(dataset_id, query):
        kb_chunk = {
            "dataset_id": dataset_id,
            "document_id": "law-doc-1",
            "document_name": "处罚依据.md",
            "source": "处罚依据.md",
            "id": "law-segment-1",
            "score": 0.88,
            "hit_count": 2,
            "text": "正式知识库中的处罚依据",
        }
        return [kb_chunk], "正式法规知识库"

    async def fake_retrieve_attachment_context(**kwargs):
        attachment_chunk = {
            "attachment_id": kwargs["AttachmentId"],
            "document_name": "用户上传.docx",
            "source": "用户上传.docx",
            "id": "attachment-segment-1",
            "score": 0.96,
            "text": "上传文档中的违法事实",
            "source_scope": "chat_attachment",
            "data_source_type": "chat_attachment",
        }
        return [attachment_chunk], "用户上传.docx"

    async def fake_generate_stream(**kwargs):
        # Record what the task passes to generation, then delegate.
        seen_chunks.extend(kwargs["context_chunks"])
        async for event in _fake_generate_stream(**kwargs):
            yield event

    async def noop_async(*args, **kwargs):
        return None

    service._retrieve_context = fake_retrieve_context
    service._retrieve_attachment_context = fake_retrieve_attachment_context
    service._finalize_message_record = noop_async
    service._maybe_schedule_auto_title = noop_async

    stream_patch = patch(
        "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream",
        fake_generate_stream,
    )
    followups_patch = patch(
        "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups",
        return_value=[],
    )
    with stream_patch, followups_patch:
        await service._run_message_task(
            task_id=task_id,
            conversation_id="conversation-test",
            message_id="message-test",
            query="这份材料的违法内容会受到什么处罚",
            app={"dataset_id": 7},
            current_user_id=100,
            tenant_code="tenant-a",
            attachment_id="attachment-1",
        )

    return service._task_events[task_id], seen_chunks
|
||||
|
||||
|
||||
async def _run_streaming_task_with_conversation_attachment_fallback() -> list[dict]:
    """Drive ``_run_message_task`` without an attachment id in the request.

    The conversation-level attachment resolver is stubbed to return one
    attachment and asserts the conversation and user it is called with.

    Returns:
        The context chunks that were handed to the patched ``generate_stream``.
    """
    task_id = "task-attachment-fallback-test"
    seen_chunks: list[dict] = []

    service = RagChatServiceImpl()
    service._task_events[task_id] = []
    service._task_done[task_id] = False
    service._task_locks[task_id] = asyncio.Lock()

    async def fake_resolve_attachment_ids_for_conversation(**kwargs):
        assert kwargs["ConversationId"] == "conversation-test"
        assert kwargs["CurrentUserId"] == 100
        return ["attachment-1"]

    async def fake_retrieve_context(dataset_id, query):
        kb_chunk = {
            "dataset_id": dataset_id,
            "document_id": "law-doc-1",
            "document_name": "处罚依据.md",
            "source": "处罚依据.md",
            "id": "law-segment-1",
            "score": 0.88,
            "hit_count": 2,
            "text": "正式知识库中的处罚依据",
        }
        return [kb_chunk], "正式法规知识库"

    async def fake_retrieve_attachment_context(**kwargs):
        attachment_chunk = {
            "attachment_id": kwargs["AttachmentId"],
            "document_name": "用户上传.docx",
            "source": "用户上传.docx",
            "id": "attachment-segment-1",
            "score": 0.96,
            "text": "上传文档中的违法事实",
            "source_scope": "chat_attachment",
            "data_source_type": "chat_attachment",
        }
        return [attachment_chunk], "用户上传.docx"

    async def fake_generate_stream(**kwargs):
        # Record what the task passes to generation, then delegate.
        seen_chunks.extend(kwargs["context_chunks"])
        async for event in _fake_generate_stream(**kwargs):
            yield event

    async def noop_async(*args, **kwargs):
        return None

    service._resolve_attachment_ids_for_conversation = fake_resolve_attachment_ids_for_conversation
    service._retrieve_context = fake_retrieve_context
    service._retrieve_attachment_context = fake_retrieve_attachment_context
    service._finalize_message_record = noop_async
    service._maybe_schedule_auto_title = noop_async

    stream_patch = patch(
        "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream",
        fake_generate_stream,
    )
    followups_patch = patch(
        "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups",
        return_value=[],
    )
    with stream_patch, followups_patch:
        await service._run_message_task(
            task_id=task_id,
            conversation_id="conversation-test",
            message_id="message-test",
            query="江小妹违法了什么",
            app={"dataset_id": 7},
            current_user_id=100,
            tenant_code="tenant-a",
            attachment_id=None,
        )

    return seen_chunks
|
||||
|
||||
|
||||
async def _run_streaming_task_with_multiple_attachments() -> tuple[list[dict], list[dict]]:
    """Drive ``_run_message_task`` with two attachment ids and stubbed retrieval.

    The attachment retrieval stub derives every chunk field from the
    attachment id it is called with, so chunks stay distinguishable.

    Returns:
        The events recorded for the task, and the context chunks that were
        handed to the patched ``generate_stream``.
    """
    task_id = "task-multiple-attachments-test"
    seen_chunks: list[dict] = []

    service = RagChatServiceImpl()
    service._task_events[task_id] = []
    service._task_done[task_id] = False
    service._task_locks[task_id] = asyncio.Lock()

    async def fake_retrieve_context(dataset_id, query):
        kb_chunk = {
            "dataset_id": dataset_id,
            "document_id": "law-doc-1",
            "document_name": "处罚依据.md",
            "source": "处罚依据.md",
            "id": "law-segment-1",
            "score": 0.88,
            "hit_count": 2,
            "text": "正式知识库中的处罚依据",
        }
        return [kb_chunk], "正式法规知识库"

    async def fake_retrieve_attachment_context(**kwargs):
        attachment_id = kwargs["AttachmentId"]
        attachment_chunk = {
            "attachment_id": attachment_id,
            "document_name": f"{attachment_id}.docx",
            "source": f"{attachment_id}.docx",
            "id": f"{attachment_id}-segment-1",
            "score": 0.96,
            "text": f"{attachment_id} 中的违法事实",
            "source_scope": "chat_attachment",
            "data_source_type": "chat_attachment",
        }
        return [attachment_chunk], f"{attachment_id}.docx"

    async def fake_generate_stream(**kwargs):
        # Record what the task passes to generation, then delegate.
        seen_chunks.extend(kwargs["context_chunks"])
        async for event in _fake_generate_stream(**kwargs):
            yield event

    async def noop_async(*args, **kwargs):
        return None

    service._retrieve_context = fake_retrieve_context
    service._retrieve_attachment_context = fake_retrieve_attachment_context
    service._finalize_message_record = noop_async
    service._maybe_schedule_auto_title = noop_async

    stream_patch = patch(
        "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream",
        fake_generate_stream,
    )
    followups_patch = patch(
        "fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups",
        return_value=[],
    )
    with stream_patch, followups_patch:
        await service._run_message_task(
            task_id=task_id,
            conversation_id="conversation-test",
            message_id="message-test",
            query="这些材料的违法内容会受到什么处罚",
            app={"dataset_id": 7},
            current_user_id=100,
            tenant_code="tenant-a",
            attachment_ids=["attachment-1", "attachment-2"],
        )

    return service._task_events[task_id], seen_chunks
|
||||
|
||||
|
||||
def test_streaming_message_end_includes_retriever_resources():
|
||||
events = asyncio.run(_run_streaming_task())
|
||||
|
||||
@@ -89,3 +315,87 @@ def test_streaming_message_end_includes_retriever_resources():
|
||||
assert resources[0]["dataset_id"] == "7"
|
||||
assert resources[0]["dataset_name"] == "测试知识库"
|
||||
assert resources[0]["document_name"] == "引用文档.pdf"
|
||||
|
||||
|
||||
def test_message_task_merges_attachment_facts_and_formal_kb_context():
    """Attachment context comes before formal KB context, in chunks and in cited resources."""
    task_events, chunks = asyncio.run(_run_streaming_task_with_attachment())

    end_event = next(evt for evt in task_events if evt.get("event") == "message_end")
    cited = end_event["metadata"].get("retriever_resources")

    assert [item.get("source_scope") for item in chunks] == ["chat_attachment", "formal_kb"]
    assert chunks[0]["document_name"] == "用户上传.docx"
    assert chunks[1]["document_name"] == "处罚依据.md"
    assert cited[0]["data_source_type"] == "chat_attachment"
    assert cited[1]["data_source_type"] == "formal_kb"
|
||||
|
||||
|
||||
def test_message_task_uses_active_conversation_attachment_when_request_omits_attachment_id():
    """With no attachment id in the request, the conversation's resolved attachment is used."""
    chunks = asyncio.run(_run_streaming_task_with_conversation_attachment_fallback())

    assert [item.get("source_scope") for item in chunks] == ["chat_attachment", "formal_kb"]
    assert chunks[0]["attachment_id"] == "attachment-1"
|
||||
|
||||
|
||||
def test_message_task_merges_multiple_attachment_contexts_before_formal_kb():
    """Both attachments' chunks come, in request order, before the formal KB chunk."""
    task_events, chunks = asyncio.run(_run_streaming_task_with_multiple_attachments())

    end_event = next(evt for evt in task_events if evt.get("event") == "message_end")
    cited = end_event["metadata"].get("retriever_resources")

    assert [item.get("source_scope") for item in chunks] == [
        "chat_attachment",
        "chat_attachment",
        "formal_kb",
    ]
    leading_ids = [item.get("attachment_id") for item in chunks[:2]]
    assert leading_ids == ["attachment-1", "attachment-2"]
    assert cited[0]["data_source_type"] == "chat_attachment"
    assert cited[1]["data_source_type"] == "chat_attachment"
    assert cited[2]["data_source_type"] == "formal_kb"
|
||||
|
||||
|
||||
def test_message_attachment_metadata_is_stored_for_history_display():
    """Attachment records are converted to the file entries stored on the user message.

    A document attachment is expected to get type ``file`` and a PNG
    attachment type ``image``; all other fields are copied or fixed.
    """
    attachments = [
        {
            "attachment_id": "att-doc",
            "original_name": "处罚材料.docx",
            "content_type": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
            "file_size": 1234,
        },
        {
            "attachment_id": "att-img",
            "original_name": "现场照片.png",
            "content_type": "image/png",
            "file_size": 4567,
        },
    ]
    expected_doc = {
        "id": "att-doc",
        "upload_file_id": "att-doc",
        "name": "处罚材料.docx",
        "fileName": "处罚材料.docx",
        "type": "file",
        "transfer_method": "local_file",
        "contentType": "application/vnd.openxmlformats-officedocument.wordprocessingml.document",
        "fileSize": 1234,
        "belongs_to": "user",
        "usage": "temporary_attachment",
    }
    expected_image = {
        "id": "att-img",
        "upload_file_id": "att-img",
        "name": "现场照片.png",
        "fileName": "现场照片.png",
        "type": "image",
        "transfer_method": "local_file",
        "contentType": "image/png",
        "fileSize": 4567,
        "belongs_to": "user",
        "usage": "temporary_attachment",
    }

    built = RagChatServiceImpl()._build_user_message_files(attachments)

    assert built == [expected_doc, expected_image]
|
||||
|
||||
Reference in New Issue
Block a user