Files
leaudit-platform-backend/tests/test_rag_dataset_defaults.py
T
2026-05-22 15:36:13 +08:00

128 lines
4.0 KiB
Python

"""RAG 知识库默认标记测试。"""
import asyncio
from unittest.mock import patch
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
class _ScalarResult:
    """Fake execute result that carries exactly one scalar value.

    It stands in for whatever ``session.execute`` returns for a scalar query.
    The real type is not visible here; presumably it is a SQLAlchemy
    ``Result``. Only ``scalar_one`` is provided.
    """

    def __init__(self, value):
        # The scalar handed back by ``scalar_one``.
        self.value = value

    def scalar_one(self):
        """Return the scalar this fake was built with, unchanged."""
        stored = self.value
        return stored
class _ExecuteResult:
    """Fake execute result whose single row is always ``{"id": 1}``.

    ``mappings()`` returns the object itself, so callers can chain
    ``result.mappings().first()``.
    """

    def mappings(self):
        """Return this same object so ``.first()`` can be chained."""
        return self

    def first(self):
        """Return a newly built row mapping with ``id`` equal to 1."""
        return dict(id=1)
class _FakeSession:
    """Recording fake for the async DB session used by the service.

    Each ``execute`` call logs its SQL text and bind parameters. A statement
    containing both ``COUNT(1)`` and ``is_default = TRUE`` returns the
    configured ``other_default_count``. Every other statement returns an
    ``_ExecuteResult``.
    """

    def __init__(self, other_default_count: int):
        # Value reported by the COUNT query over default datasets.
        self.other_default_count = other_default_count
        # Parallel logs: entry i of each list belongs to the i-th execute call.
        self.executed_sql = []
        self.executed_params = []

    async def execute(self, sql, params=None):
        """Record the statement and return a canned result.

        Args:
            sql: Statement object or string; it is stored as ``str(sql)``.
            params: Optional bind-parameter mapping; a falsy value is
                stored as an empty dict.
        """
        statement = str(sql)
        self.executed_sql.append(statement)
        self.executed_params.append(params if params else {})
        is_default_count_query = (
            "COUNT(1)" in statement and "is_default = TRUE" in statement
        )
        if not is_default_count_query:
            return _ExecuteResult()
        return _ScalarResult(self.other_default_count)
class _FakeSessionContext:
    """Async context manager that hands out a pre-built session.

    It is used as the patched return value of ``GetAsyncSession``. The
    service presumably does ``async with GetAsyncSession() as session``;
    confirm against the service code.
    """

    def __init__(self, session):
        # Session yielded by ``__aenter__``.
        self.session = session

    async def __aenter__(self):
        """Yield the wrapped session."""
        return self.session

    async def __aexit__(self, exc_type, exc, tb):
        """Do nothing on exit; the implicit ``None`` lets exceptions propagate."""
async def _update_default_flag_with_other_count(other_default_count: int):
    """Run ``UpdateAdminDataset`` to clear ``is_default`` on dataset 10.

    The service's row lookup, tenant/area resolution and scope/ensure
    helpers are replaced with stubs on the instance. ``GetAsyncSession`` is
    patched to yield a ``_FakeSession``, whose COUNT query over default
    datasets reports ``other_default_count``.

    Args:
        other_default_count: Value returned by the fake COUNT query.

    Returns:
        ``(result, fake_session)``: the service's return value and the fake
        session, whose ``executed_sql`` / ``executed_params`` record every
        statement issued. Exceptions raised by the service propagate.
    """
    service = RagDatasetServiceImpl()
    session = _FakeSession(other_default_count=other_default_count)
    # Stored state of dataset 10: currently the default for tenant "MZ".
    stored_row = dict(
        id=10,
        name="旧默认知识库",
        description="",
        area="梅州",
        tenant_code="MZ",
        tenant_name="梅州",
        is_public=False,
        is_default=True,
        status=1,
        document_count=0,
        total_chunks=0,
        chunk_max_size=800,
        chunk_min_size=20,
        sort_order=0,
        retrieval_model={},
        created_at=None,
        updated_at=None,
    )

    async def stub_dataset_row(dataset_id):
        return stored_row

    async def stub_tenant_context(**kwargs):
        return {"tenant_code": "MZ", "tenant_name": "梅州", "tenant_type": "CUSTOM", "area": "梅州"}

    async def stub_area_input(**kwargs):
        return "梅州", "MZ", object()

    async def do_nothing(*args, **kwargs):
        return None

    # Instance-level overrides so no real lookups or permission checks run.
    overrides = {
        "_get_dataset_row": stub_dataset_row,
        "_resolve_tenant_context": stub_tenant_context,
        "_resolve_dataset_area_input": stub_area_input,
        "_assert_manage_area_scope": do_nothing,
        "_ensure_rag_tenant_schema": do_nothing,
        "_ensure_linked_app": do_nothing,
    }
    for attr_name, replacement in overrides.items():
        setattr(service, attr_name, replacement)

    session_patch = patch(
        "fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl.GetAsyncSession",
        return_value=_FakeSessionContext(session),
    )
    with session_patch:
        outcome = await service.UpdateAdminDataset(
            CurrentUserId=1,
            UserArea="梅州",
            UserRole="super_admin",
            TenantCode="MZ",
            TenantName="梅州",
            DatasetId=10,
            Body={"is_default": False},
        )
    return outcome, session
def test_update_admin_dataset_allows_unsetting_default_when_same_tenant_has_another_default():
    """Clearing the default flag succeeds when the tenant keeps another default.

    The fake COUNT query reports one other default dataset. The update must
    therefore complete: a result is returned, an ``UPDATE rag_dataset``
    statement is issued, and some statement filters by ``tenant_code``
    bound to "MZ".
    """
    result, session = asyncio.run(_update_default_flag_with_other_count(other_default_count=1))
    statements = session.executed_sql
    bound_params = session.executed_params

    assert result is not None
    assert any("UPDATE rag_dataset" in stmt for stmt in statements)
    assert any("tenant_code = :tenant_code" in stmt for stmt in statements)
    assert any(p.get("tenant_code") == "MZ" for p in bound_params)
def test_update_admin_dataset_rejects_unsetting_only_default_in_tenant():
    """Clearing the tenant's only default dataset is rejected with HTTP 400.

    The fake COUNT query reports zero other default datasets, so
    ``UpdateAdminDataset`` is expected to raise ``LeauditException`` with a
    400 status.
    """
    try:
        asyncio.run(_update_default_flag_with_other_count(other_default_count=0))
    except LeauditException as exc:
        assert exc.status == StatusCodeEnum.HTTP_400_BAD_REQUEST
    else:
        # The missing-exception failure is raised explicitly instead of via
        # ``assert False``: ``python -O`` strips asserts, which would make
        # this test pass silently. Keeping it in ``else`` also keeps it
        # outside the guarded ``try`` region.
        raise AssertionError("expected LeauditException")