fix: restore rag chat permission feedback
This commit is contained in:
@@ -1,3 +1,5 @@
|
||||
"""RAG 聊天服务实现。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
@@ -49,6 +51,8 @@ MANUAL_TITLE_SOURCE = "manual"
|
||||
|
||||
|
||||
class RagChatServiceImpl(IRagChatService):
|
||||
"""RAG 聊天服务实现。"""
|
||||
|
||||
_message_tasks: dict[str, asyncio.Task] = {}
|
||||
_task_events: dict[str, list[dict]] = {}
|
||||
_task_done: dict[str, bool] = {}
|
||||
@@ -1143,14 +1147,17 @@ class RagChatServiceImpl(IRagChatService):
|
||||
except Exception:
|
||||
followups = []
|
||||
|
||||
sources = self._build_sources(context_chunks, dataset_name)
|
||||
if message_end_payload:
|
||||
message_end_payload.setdefault("metadata", {})["suggested_questions"] = followups
|
||||
message_end_metadata = message_end_payload.setdefault("metadata", {})
|
||||
message_end_metadata["suggested_questions"] = followups
|
||||
message_end_metadata["retriever_resources"] = sources
|
||||
await self._append_task_event(task_id, message_end_payload)
|
||||
await self._finalize_message_record(
|
||||
conversation_id=conversation_id,
|
||||
message_id=message_id,
|
||||
content=collected_answer,
|
||||
sources=self._build_sources(context_chunks, dataset_name),
|
||||
sources=sources,
|
||||
metadata={"suggested_questions": followups, "status": "completed", "task_id": task_id},
|
||||
)
|
||||
await self._maybe_schedule_auto_title(
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
"""RAG 知识库服务实现。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
@@ -42,6 +44,8 @@ from fastapi_modules.fastapi_leaudit.services.impl.tenantResolver import TenantR
|
||||
|
||||
|
||||
class RagDatasetServiceImpl(IRagDatasetService):
|
||||
"""RAG 知识库服务实现。"""
|
||||
|
||||
_ACTIVE_INDEXING_STATUSES = {"waiting", "parsing", "cleaning", "splitting", "indexing"}
|
||||
_DELETABLE_DOCUMENT_STATUSES = {"completed", "error", "paused"}
|
||||
_APP_LINK_SQL = """
|
||||
@@ -295,7 +299,32 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
if target_is_default:
|
||||
await self._clear_default_flags(session, tenant_code=resolved_tenant_code)
|
||||
elif existing.get("is_default") and Body.get("is_default") is False:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不能直接取消,请先将其他知识库设为默认")
|
||||
normalized_tenant_code = str(resolved_tenant_code or "").strip()
|
||||
default_filters = [
|
||||
"deleted_at IS NULL",
|
||||
"is_default = TRUE",
|
||||
"id <> :dataset_id",
|
||||
]
|
||||
default_params = {"dataset_id": DatasetId}
|
||||
if normalized_tenant_code:
|
||||
default_filters.append("tenant_code = :tenant_code")
|
||||
default_params["tenant_code"] = normalized_tenant_code
|
||||
else:
|
||||
default_filters.append("(tenant_code IS NULL OR BTRIM(tenant_code) = '')")
|
||||
other_default_count = (
|
||||
await session.execute(
|
||||
text(
|
||||
f"""
|
||||
SELECT COUNT(1)
|
||||
FROM rag_dataset
|
||||
WHERE {" AND ".join(default_filters)}
|
||||
"""
|
||||
),
|
||||
default_params,
|
||||
)
|
||||
).scalar_one()
|
||||
if int(other_default_count or 0) <= 0:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不能直接取消,请先将其他知识库设为默认")
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
|
||||
+1
-1
Submodule legal-platform-frontend updated: 7932689d08...20c9d4872a
@@ -0,0 +1,91 @@
|
||||
"""RAG 聊天流式引用来源测试。"""
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
|
||||
|
||||
|
||||
async def _fake_generate_stream(**kwargs):
|
||||
yield "data: " + json.dumps(
|
||||
{
|
||||
"event": "message",
|
||||
"task_id": kwargs["task_id"],
|
||||
"message_id": kwargs["message_id"],
|
||||
"conversation_id": kwargs["conversation_id"],
|
||||
"answer": "测试回答",
|
||||
},
|
||||
ensure_ascii=False,
|
||||
) + "\n\n"
|
||||
yield "data: " + json.dumps(
|
||||
{
|
||||
"event": "message_end",
|
||||
"task_id": kwargs["task_id"],
|
||||
"message_id": kwargs["message_id"],
|
||||
"conversation_id": kwargs["conversation_id"],
|
||||
"metadata": {},
|
||||
},
|
||||
ensure_ascii=False,
|
||||
) + "\n\n"
|
||||
|
||||
|
||||
async def _run_streaming_task() -> list[dict]:
|
||||
service = RagChatServiceImpl()
|
||||
task_id = "task-test"
|
||||
service._task_events[task_id] = []
|
||||
service._task_done[task_id] = False
|
||||
service._task_locks[task_id] = asyncio.Lock()
|
||||
|
||||
async def fake_retrieve_context(dataset_id, query):
|
||||
return [
|
||||
{
|
||||
"dataset_id": dataset_id,
|
||||
"document_id": "doc-1",
|
||||
"document_name": "引用文档.pdf",
|
||||
"source": "引用文档.pdf",
|
||||
"id": "segment-1",
|
||||
"score": 0.91,
|
||||
"hit_count": 3,
|
||||
"text": "引用片段内容",
|
||||
}
|
||||
], "测试知识库"
|
||||
|
||||
async def noop_async(*args, **kwargs):
|
||||
return None
|
||||
|
||||
service._retrieve_context = fake_retrieve_context
|
||||
service._finalize_message_record = noop_async
|
||||
service._maybe_schedule_auto_title = noop_async
|
||||
|
||||
with (
|
||||
patch(
|
||||
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_stream",
|
||||
_fake_generate_stream,
|
||||
),
|
||||
patch(
|
||||
"fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl.generate_followups",
|
||||
return_value=[],
|
||||
),
|
||||
):
|
||||
await service._run_message_task(
|
||||
task_id=task_id,
|
||||
conversation_id="conversation-test",
|
||||
message_id="message-test",
|
||||
query="测试问题",
|
||||
app={"dataset_id": 7},
|
||||
)
|
||||
|
||||
return service._task_events[task_id]
|
||||
|
||||
|
||||
def test_streaming_message_end_includes_retriever_resources():
|
||||
events = asyncio.run(_run_streaming_task())
|
||||
|
||||
message_end = next(event for event in events if event.get("event") == "message_end")
|
||||
|
||||
resources = message_end["metadata"].get("retriever_resources")
|
||||
assert resources
|
||||
assert resources[0]["dataset_id"] == "7"
|
||||
assert resources[0]["dataset_name"] == "测试知识库"
|
||||
assert resources[0]["document_name"] == "引用文档.pdf"
|
||||
@@ -0,0 +1,127 @@
|
||||
"""RAG 知识库默认标记测试。"""
|
||||
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
|
||||
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
||||
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
|
||||
|
||||
|
||||
class _ScalarResult:
|
||||
def __init__(self, value):
|
||||
self.value = value
|
||||
|
||||
def scalar_one(self):
|
||||
return self.value
|
||||
|
||||
|
||||
class _ExecuteResult:
|
||||
def mappings(self):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return {"id": 1}
|
||||
|
||||
|
||||
class _FakeSession:
|
||||
def __init__(self, other_default_count: int):
|
||||
self.other_default_count = other_default_count
|
||||
self.executed_sql = []
|
||||
self.executed_params = []
|
||||
|
||||
async def execute(self, sql, params=None):
|
||||
sql_text = str(sql)
|
||||
self.executed_sql.append(sql_text)
|
||||
self.executed_params.append(params or {})
|
||||
if "COUNT(1)" in sql_text and "is_default = TRUE" in sql_text:
|
||||
return _ScalarResult(self.other_default_count)
|
||||
return _ExecuteResult()
|
||||
|
||||
|
||||
class _FakeSessionContext:
|
||||
def __init__(self, session):
|
||||
self.session = session
|
||||
|
||||
async def __aenter__(self):
|
||||
return self.session
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
|
||||
async def _update_default_flag_with_other_count(other_default_count: int):
|
||||
service = RagDatasetServiceImpl()
|
||||
fake_session = _FakeSession(other_default_count=other_default_count)
|
||||
existing_row = {
|
||||
"id": 10,
|
||||
"name": "旧默认知识库",
|
||||
"description": "",
|
||||
"area": "梅州",
|
||||
"tenant_code": "MZ",
|
||||
"tenant_name": "梅州",
|
||||
"is_public": False,
|
||||
"is_default": True,
|
||||
"status": 1,
|
||||
"document_count": 0,
|
||||
"total_chunks": 0,
|
||||
"chunk_max_size": 800,
|
||||
"chunk_min_size": 20,
|
||||
"sort_order": 0,
|
||||
"retrieval_model": {},
|
||||
"created_at": None,
|
||||
"updated_at": None,
|
||||
}
|
||||
|
||||
async def fake_get_dataset_row(dataset_id):
|
||||
return existing_row
|
||||
|
||||
async def fake_resolve_tenant_context(**kwargs):
|
||||
return {"tenant_code": "MZ", "tenant_name": "梅州", "tenant_type": "CUSTOM", "area": "梅州"}
|
||||
|
||||
async def fake_resolve_dataset_area_input(**kwargs):
|
||||
return "梅州", "MZ", object()
|
||||
|
||||
async def noop_async(*args, **kwargs):
|
||||
return None
|
||||
|
||||
service._get_dataset_row = fake_get_dataset_row
|
||||
service._resolve_tenant_context = fake_resolve_tenant_context
|
||||
service._resolve_dataset_area_input = fake_resolve_dataset_area_input
|
||||
service._assert_manage_area_scope = noop_async
|
||||
service._ensure_rag_tenant_schema = noop_async
|
||||
service._ensure_linked_app = noop_async
|
||||
|
||||
with patch(
|
||||
"fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl.GetAsyncSession",
|
||||
return_value=_FakeSessionContext(fake_session),
|
||||
):
|
||||
result = await service.UpdateAdminDataset(
|
||||
CurrentUserId=1,
|
||||
UserArea="梅州",
|
||||
UserRole="super_admin",
|
||||
TenantCode="MZ",
|
||||
TenantName="梅州",
|
||||
DatasetId=10,
|
||||
Body={"is_default": False},
|
||||
)
|
||||
|
||||
return result, fake_session
|
||||
|
||||
|
||||
def test_update_admin_dataset_allows_unsetting_default_when_same_tenant_has_another_default():
|
||||
result, fake_session = asyncio.run(_update_default_flag_with_other_count(other_default_count=1))
|
||||
|
||||
assert result is not None
|
||||
assert any("UPDATE rag_dataset" in sql_text for sql_text in fake_session.executed_sql)
|
||||
assert any("tenant_code = :tenant_code" in sql_text for sql_text in fake_session.executed_sql)
|
||||
assert any(params.get("tenant_code") == "MZ" for params in fake_session.executed_params)
|
||||
|
||||
|
||||
def test_update_admin_dataset_rejects_unsetting_only_default_in_tenant():
|
||||
try:
|
||||
asyncio.run(_update_default_flag_with_other_count(other_default_count=0))
|
||||
assert False, "expected LeauditException"
|
||||
except LeauditException as exc:
|
||||
assert exc.status == StatusCodeEnum.HTTP_400_BAD_REQUEST
|
||||
Reference in New Issue
Block a user