128 lines
4.0 KiB
Python
128 lines
4.0 KiB
Python
"""RAG 知识库默认标记测试。"""
|
|
|
|
import asyncio
|
|
from unittest.mock import patch
|
|
|
|
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
|
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
|
|
|
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
|
|
|
|
|
|
class _ScalarResult:
    """Stand-in for a query result whose only supported accessor is ``scalar_one()``.

    Used by the fake session to answer the "other default datasets" count query.
    """

    def __init__(self, value):
        # The single value handed back by ``scalar_one()``; left public for inspection.
        self.value = value

    def scalar_one(self):
        """Return the value this fake result was built with."""
        stored = self.value
        return stored
|
|
|
|
|
|
class _ExecuteResult:
    """Generic fake result for every statement that is not the default-count query.

    Supports the ``result.mappings().first()`` chain and always yields ``{"id": 1}``.
    """

    def mappings(self):
        """Return this object itself so ``.first()`` can be chained directly."""
        return self

    def first(self):
        """Return a newly built one-row mapping whose ``id`` is 1."""
        row = {"id": 1}
        return row
|
|
|
|
|
|
class _FakeSession:
    """Minimal async DB-session double that logs every executed statement.

    A statement whose rendered text contains both ``COUNT(1)`` and
    ``is_default = TRUE`` is answered with ``other_default_count`` (wrapped in
    ``_ScalarResult``); any other statement receives an ``_ExecuteResult``.
    """

    def __init__(self, other_default_count: int):
        # Value reported by the default-count query.
        self.other_default_count = other_default_count
        # Rendered SQL text of each execute() call, in call order.
        self.executed_sql = []
        # Bound parameters of each execute() call; {} when none (or falsy) were passed.
        self.executed_params = []

    async def execute(self, sql, params=None):
        """Record ``sql``/``params`` and return the matching canned result."""
        rendered = str(sql)
        self.executed_sql.append(rendered)
        self.executed_params.append(params or {})

        is_default_count_query = all(
            marker in rendered for marker in ("COUNT(1)", "is_default = TRUE")
        )
        if not is_default_count_query:
            return _ExecuteResult()
        return _ScalarResult(self.other_default_count)
|
|
|
|
|
|
class _FakeSessionContext:
    """Async context manager that yields a pre-built session object.

    Stands in for the object returned by the patched ``GetAsyncSession``; it
    does no setup or teardown and never suppresses exceptions.
    """

    def __init__(self, session):
        # Object returned unchanged from ``__aenter__``.
        self.session = session

    async def __aenter__(self):
        """Hand out the wrapped session."""
        session = self.session
        return session

    async def __aexit__(self, exc_type, exc, tb):
        """Do nothing on exit; a falsy result lets body exceptions propagate."""
        return None
|
|
|
|
|
|
async def _update_default_flag_with_other_count(other_default_count: int):
    """Call ``UpdateAdminDataset`` to clear ``is_default`` on dataset 10, fully faked.

    The service's dataset lookup, tenant/area resolution and the scope/schema/
    linked-app hooks are replaced on the instance with async stubs, and
    ``GetAsyncSession`` is patched to yield a ``_FakeSession`` whose default-count
    query reports ``other_default_count``.

    Args:
        other_default_count: value the fake session returns for the
            ``COUNT(1) ... is_default = TRUE`` query.

    Returns:
        ``(result, fake_session)``: the service's return value and the fake
        session holding every executed statement and its parameters.
    """
    service = RagDatasetServiceImpl()
    session = _FakeSession(other_default_count=other_default_count)

    # Row for the dataset being edited: currently marked as the default of tenant MZ.
    current_row = {
        "id": 10,
        "name": "旧默认知识库",
        "description": "",
        "area": "梅州",
        "tenant_code": "MZ",
        "tenant_name": "梅州",
        "is_public": False,
        "is_default": True,
        "status": 1,
        "document_count": 0,
        "total_chunks": 0,
        "chunk_max_size": 800,
        "chunk_min_size": 20,
        "sort_order": 0,
        "retrieval_model": {},
        "created_at": None,
        "updated_at": None,
    }

    async def stub_dataset_row(dataset_id):
        return current_row

    async def stub_tenant_context(**kwargs):
        return {"tenant_code": "MZ", "tenant_name": "梅州", "tenant_type": "CUSTOM", "area": "梅州"}

    async def stub_area_input(**kwargs):
        # Third element is an opaque placeholder; its meaning is defined by the service.
        return "梅州", "MZ", object()

    async def stub_noop(*args, **kwargs):
        return None

    # Install stubs in the same order as the attributes are listed here.
    replacements = {
        "_get_dataset_row": stub_dataset_row,
        "_resolve_tenant_context": stub_tenant_context,
        "_resolve_dataset_area_input": stub_area_input,
        "_assert_manage_area_scope": stub_noop,
        "_ensure_rag_tenant_schema": stub_noop,
        "_ensure_linked_app": stub_noop,
    }
    for attr_name, stub in replacements.items():
        setattr(service, attr_name, stub)

    session_factory_path = (
        "fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl.GetAsyncSession"
    )
    with patch(session_factory_path, return_value=_FakeSessionContext(session)):
        outcome = await service.UpdateAdminDataset(
            CurrentUserId=1,
            UserArea="梅州",
            UserRole="super_admin",
            TenantCode="MZ",
            TenantName="梅州",
            DatasetId=10,
            Body={"is_default": False},
        )

    return outcome, session
|
|
|
|
|
|
def test_update_admin_dataset_allows_unsetting_default_when_same_tenant_has_another_default():
    """Clearing ``is_default`` succeeds while the tenant still has another default dataset."""
    outcome, session = asyncio.run(
        _update_default_flag_with_other_count(other_default_count=1)
    )

    assert outcome is not None

    # Neither needle contains "\n", so a hit in the joined log is always a hit
    # inside a single executed statement.
    sql_log = "\n".join(session.executed_sql)
    assert "UPDATE rag_dataset" in sql_log
    # Some executed statement filters on tenant_code ...
    assert "tenant_code = :tenant_code" in sql_log
    # ... and some call bound that parameter to the dataset's tenant.
    assert any(bound.get("tenant_code") == "MZ" for bound in session.executed_params)
|
|
|
|
|
|
def test_update_admin_dataset_rejects_unsetting_only_default_in_tenant():
    """Clearing ``is_default`` is refused with HTTP 400 when no other default exists.

    The "nothing was raised" failure is signalled with an explicit
    ``AssertionError`` instead of ``assert False``. Assert statements are removed
    under ``python -O``, which would otherwise let this test pass even when the
    service accepted the update.
    """
    try:
        asyncio.run(_update_default_flag_with_other_count(other_default_count=0))
    except LeauditException as exc:
        # ``exc`` is unbound after the except clause, so keep a reference.
        raised = exc
    else:
        raise AssertionError("expected LeauditException")

    assert raised.status == StatusCodeEnum.HTTP_400_BAD_REQUEST
|