"""Tests for the RAG knowledge-base (dataset) default flag."""

import asyncio
from unittest.mock import patch

from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl


class _ScalarResult:
    """Fake query result that hands back a preset value from ``scalar_one()``."""

    def __init__(self, value):
        self.value = value

    def scalar_one(self):
        """Return the value supplied at construction time."""
        return self.value


class _ExecuteResult:
    """Fake query result supporting the ``.mappings().first()`` call chain."""

    def mappings(self):
        """Return ``self`` so that ``first()`` can be chained directly."""
        return self

    def first(self):
        """Return a fixed single-row mapping."""
        return {"id": 1}


class _FakeSession:
    """Fake async session that records every statement passed to ``execute``.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, other_default_count: int):
        # Value returned for the "COUNT(1) ... is_default = TRUE" query.
        self.other_default_count = other_default_count
        # str() of every statement executed, in call order.
        self.executed_sql = []
        # Parameters of every statement executed ({} when none were given).
        self.executed_params = []

    async def execute(self, sql, params=None):
        """Record the statement and return a canned result.

        A statement whose text contains both ``COUNT(1)`` and
        ``is_default = TRUE`` yields a ``_ScalarResult`` wrapping
        ``other_default_count``; anything else yields an ``_ExecuteResult``.
        """
        sql_text = str(sql)
        self.executed_sql.append(sql_text)
        self.executed_params.append(params or {})
        if "COUNT(1)" in sql_text and "is_default = TRUE" in sql_text:
            return _ScalarResult(self.other_default_count)
        return _ExecuteResult()


class _FakeSessionContext:
    """Async context manager that yields a pre-built fake session."""

    def __init__(self, session):
        self.session = session

    async def __aenter__(self):
        return self.session

    async def __aexit__(self, exc_type, exc, tb):
        # Returning None (falsy) lets exceptions raised in the body propagate.
        return None


async def _update_default_flag_with_other_count(other_default_count: int):
    """Call ``UpdateAdminDataset`` with ``is_default=False`` against fakes.

    The service's collaborators are replaced with stubs and
    ``GetAsyncSession`` is patched to yield a ``_FakeSession``.

    Args:
        other_default_count: Value the fake session returns for the
            ``COUNT(1) ... is_default = TRUE`` query (presumably the number
            of other default datasets in the tenant -- confirm against the
            service implementation).

    Returns:
        A ``(result, fake_session)`` tuple: the service's return value and
        the fake session holding the recorded SQL and parameters.
    """
    service = RagDatasetServiceImpl()
    fake_session = _FakeSession(other_default_count=other_default_count)
    # The dataset being updated: currently flagged as default for tenant "MZ".
    existing_row = {
        "id": 10,
        "name": "旧默认知识库",
        "description": "",
        "area": "梅州",
        "tenant_code": "MZ",
        "tenant_name": "梅州",
        "is_public": False,
        "is_default": True,
        "status": 1,
        "document_count": 0,
        "total_chunks": 0,
        "chunk_max_size": 800,
        "chunk_min_size": 20,
        "sort_order": 0,
        "retrieval_model": {},
        "created_at": None,
        "updated_at": None,
    }

    async def fake_get_dataset_row(dataset_id):
        return existing_row

    async def fake_resolve_tenant_context(**kwargs):
        return {"tenant_code": "MZ", "tenant_name": "梅州", "tenant_type": "CUSTOM", "area": "梅州"}

    async def fake_resolve_dataset_area_input(**kwargs):
        return "梅州", "MZ", object()

    async def noop_async(*args, **kwargs):
        return None

    # Stub the collaborators on this instance only, so the test exercises
    # just the default-flag handling inside UpdateAdminDataset.
    service._get_dataset_row = fake_get_dataset_row
    service._resolve_tenant_context = fake_resolve_tenant_context
    service._resolve_dataset_area_input = fake_resolve_dataset_area_input
    service._assert_manage_area_scope = noop_async
    service._ensure_rag_tenant_schema = noop_async
    service._ensure_linked_app = noop_async

    with patch(
        "fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl.GetAsyncSession",
        return_value=_FakeSessionContext(fake_session),
    ):
        result = await service.UpdateAdminDataset(
            CurrentUserId=1,
            UserArea="梅州",
            UserRole="super_admin",
            TenantCode="MZ",
            TenantName="梅州",
            DatasetId=10,
            Body={"is_default": False},
        )
    return result, fake_session


def test_update_admin_dataset_allows_unsetting_default_when_same_tenant_has_another_default():
    """Unsetting the flag succeeds when the COUNT query reports one default."""
    result, fake_session = asyncio.run(_update_default_flag_with_other_count(other_default_count=1))

    assert result is not None
    assert any("UPDATE rag_dataset" in sql_text for sql_text in fake_session.executed_sql)
    # Some executed statement must be scoped by tenant and bound to tenant "MZ".
    assert any("tenant_code = :tenant_code" in sql_text for sql_text in fake_session.executed_sql)
    assert any(params.get("tenant_code") == "MZ" for params in fake_session.executed_params)


def test_update_admin_dataset_rejects_unsetting_only_default_in_tenant():
    """Unsetting the flag raises a 400 when the COUNT query reports zero."""
    try:
        asyncio.run(_update_default_flag_with_other_count(other_default_count=0))
    except LeauditException as exc:
        assert exc.status == StatusCodeEnum.HTTP_400_BAD_REQUEST
    else:
        # Explicit raise instead of ``assert False``: the latter is removed
        # under ``python -O`` and would let a missing exception pass silently.
        raise AssertionError("expected LeauditException")