"""租户主数据服务实现。""" from __future__ import annotations from typing import Any from sqlalchemy import text from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException from fastapi_modules.fastapi_leaudit.domian.Dto.tenantDto import ( TenantCreateDTO, TenantStatusUpdateDTO, TenantUpdateDTO, ) from fastapi_modules.fastapi_leaudit.services.impl.ruleTenantMaterializer import ( GetRuleTenantMaterializerSingleton, RuleTenantMaterializer, ) from fastapi_modules.fastapi_leaudit.services.tenantService import ITenantService class TenantServiceImpl(ITenantService): """租户主数据服务实现。""" _BUILTIN_TENANT_CODES: tuple[str, ...] = ("PUBLIC", "PROVINCIAL") _SUPPORTED_FEATURE_KEYS: tuple[str, ...] = ( "home.entry_module", "documents.upload", "rag.dataset", ) def __init__(self, RuleTenantMaterializer: RuleTenantMaterializer | None = None) -> None: self._table_exists_cache: dict[str, bool] = {} self.RuleTenantMaterializer = RuleTenantMaterializer or GetRuleTenantMaterializerSingleton() async def ListTenants(self, IncludeDisabled: bool = False) -> list[dict[str, Any]]: if not await self._table_exists("sys_tenants"): return await self._list_legacy_tenants() filters = ["t.deleted_at IS NULL"] params: dict[str, Any] = {} if not IncludeDisabled: filters.append("t.is_enabled = TRUE") where_sql = " AND ".join(filters) async with GetAsyncSession() as session: rows = ( await session.execute( text( f""" SELECT t.tenant_code, t.tenant_name, t.tenant_short_name, t.tenant_type, t.parent_tenant_code, t.display_order, t.is_enabled, t.is_builtin, t.is_public, t.can_host_entry_module, t.can_host_documents, t.can_host_rag, t.can_host_templates, t.ext, COALESCE( ARRAY( SELECT f.feature_key FROM sys_tenant_feature_flags f WHERE f.tenant_code = t.tenant_code AND f.deleted_at IS NULL AND f.is_enabled = TRUE ORDER BY f.feature_key ASC ), ARRAY[]::VARCHAR[] ) AS feature_keys, COALESCE( ARRAY( SELECT a.alias_value FROM sys_tenant_aliases a WHERE a.tenant_code = t.tenant_code AND a.deleted_at IS NULL AND a.is_enabled = TRUE ORDER BY CASE a.alias_type WHEN 'DISPLAY' THEN 1 WHEN 'SHORT_NAME' THEN 2 WHEN 'LEGACY_AREA' THEN 3 ELSE 9 END ASC, a.id ASC ), ARRAY[]::VARCHAR[] ) AS alias_values FROM sys_tenants t WHERE {where_sql} ORDER BY t.display_order ASC, t.id ASC """ ), params, ) ).mappings().all() return [dict(row) for row in rows] async def ListTenantOptions(self, FeatureKey: str | None = None) -> list[dict[str, Any]]: if not await self._table_exists("sys_tenants"): items = await self._list_legacy_tenant_options() return items filters = ["t.deleted_at IS NULL", "t.is_enabled = TRUE"] params: dict[str, Any] = {} join_sql = "" if FeatureKey and FeatureKey.strip() and await self._table_exists("sys_tenant_feature_flags"): join_sql = """ JOIN sys_tenant_feature_flags f ON f.tenant_code = t.tenant_code AND f.deleted_at IS NULL AND f.is_enabled = TRUE AND f.feature_key = :feature_key """ params["feature_key"] = FeatureKey.strip() where_sql = " AND ".join(filters) async with GetAsyncSession() as session: rows = ( await session.execute( text( f""" SELECT t.tenant_code, t.tenant_name, t.tenant_short_name, t.tenant_type, t.is_public, t.display_order FROM sys_tenants t {join_sql} WHERE {where_sql} ORDER BY t.display_order ASC, t.id ASC """ ), params, ) ).mappings().all() return [dict(row) for row in rows] async def GetTenant(self, TenantCode: str) -> dict[str, Any] | None: tenant_code = str(TenantCode or "").strip() if not tenant_code: return None if not await self._table_exists("sys_tenants"): for item in await self._list_legacy_tenants(): if str(item.get("tenant_code") or "").strip() == tenant_code: return item return None async with GetAsyncSession() as session: row = ( await session.execute( text( """ SELECT t.tenant_code, t.tenant_name, t.tenant_short_name, t.tenant_type, t.parent_tenant_code, t.display_order, t.is_enabled, t.is_builtin, t.is_public, t.can_host_entry_module, t.can_host_documents, t.can_host_rag, t.can_host_templates, t.ext FROM sys_tenants t WHERE t.tenant_code = :tenant_code AND t.deleted_at IS NULL LIMIT 1 """ ), {"tenant_code": tenant_code}, ) ).mappings().first() return dict(row) if row else None async def GetTenantFeatures(self, TenantCode: str) -> list[str]: tenant_code = str(TenantCode or "").strip() if not tenant_code: return [] if not await self._table_exists("sys_tenant_feature_flags"): return [] async with GetAsyncSession() as session: rows = ( await session.execute( text( """ SELECT feature_key FROM sys_tenant_feature_flags WHERE tenant_code = :tenant_code AND deleted_at IS NULL AND is_enabled = TRUE ORDER BY feature_key ASC """ ), {"tenant_code": tenant_code}, ) ).all() return [str(row[0]) for row in rows] async def GetTenantAliases(self, TenantCode: str) -> list[str]: tenant_code = str(TenantCode or "").strip() if not tenant_code: return [] return await self._getTenantAliases(tenant_code) async def CreateTenant(self, CurrentUserId: int, Body: TenantCreateDTO) -> dict[str, Any]: del CurrentUserId await self._ensureWritableTenantFoundation() tenant_code = self._normalizeTenantCode(Body.tenant_code) tenant_name = self._normalizeRequiredText(Body.tenant_name, "租户名称") tenant_short_name = self._normalizeOptionalText(Body.tenant_short_name) or tenant_name tenant_type = self._normalizeTenantType(Body.tenant_type) parent_tenant_code = self._normalizeOptionalCode(Body.parent_tenant_code) feature_keys = self._normalizeFeatureKeys(Body.feature_keys) alias_values = self._normalizeAliasValues(Body.alias_values, tenant_name=tenant_name, tenant_short_name=tenant_short_name) ext = Body.ext or {} async with GetAsyncSession() as session: exists = ( await session.execute( text( """ SELECT 1 FROM sys_tenants WHERE tenant_code = :tenant_code AND deleted_at IS NULL LIMIT 1 """ ), {"tenant_code": tenant_code}, ) ).scalar_one_or_none() if exists: raise LeauditException(StatusCodeEnum.HTTP_409_CONFLICT, f"租户编码已存在: {tenant_code}") if parent_tenant_code: await self._assertTenantExists(session, parent_tenant_code, "父级租户不存在") await session.execute( text( """ INSERT INTO sys_tenants ( tenant_code, tenant_name, tenant_short_name, tenant_type, parent_tenant_code, display_order, is_enabled, is_builtin, is_public, can_host_entry_module, can_host_documents, can_host_rag, can_host_templates, ext, created_at, updated_at, deleted_at ) VALUES ( :tenant_code, :tenant_name, :tenant_short_name, :tenant_type, :parent_tenant_code, :display_order, :is_enabled, FALSE, :is_public, :can_host_entry_module, :can_host_documents, :can_host_rag, :can_host_templates, CAST(:ext AS jsonb), NOW(), NOW(), NULL ) """ ), { "tenant_code": tenant_code, "tenant_name": tenant_name, "tenant_short_name": tenant_short_name, "tenant_type": tenant_type, "parent_tenant_code": parent_tenant_code, "display_order": int(Body.display_order or 0), "is_enabled": bool(Body.is_enabled), "is_public": bool(Body.is_public), "can_host_entry_module": bool(Body.can_host_entry_module), "can_host_documents": bool(Body.can_host_documents), "can_host_rag": bool(Body.can_host_rag), "can_host_templates": bool(Body.can_host_templates), "ext": self._dumpJson(ext), }, ) await self._replaceTenantAliases(session, tenant_code, alias_values) await self._replaceTenantFeatures(session, tenant_code, feature_keys) await session.commit() await self.RuleTenantMaterializer.MaterializeTenant(tenant_code) return await self._getTenantDetailOrFail(tenant_code) async def UpdateTenant(self, CurrentUserId: int, TenantCode: str, Body: TenantUpdateDTO) -> dict[str, Any]: del CurrentUserId await self._ensureWritableTenantFoundation() tenant_code = self._normalizeTenantCode(TenantCode) async with GetAsyncSession() as session: current = await self._loadTenantRow(session, tenant_code) if not current: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "租户不存在") is_builtin = bool(current.get("is_builtin")) tenant_name = self._normalizeOptionalText(Body.tenant_name) if Body.tenant_name is not None else str(current.get("tenant_name") or "") if not tenant_name: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "租户名称不能为空") tenant_short_name = ( self._normalizeOptionalText(Body.tenant_short_name) if Body.tenant_short_name is not None else self._normalizeOptionalText(current.get("tenant_short_name")) ) or tenant_name tenant_type = self._normalizeTenantType(Body.tenant_type) if Body.tenant_type is not None else str(current.get("tenant_type") or "CUSTOM") parent_tenant_code = self._normalizeOptionalCode(Body.parent_tenant_code) if Body.parent_tenant_code is not None else current.get("parent_tenant_code") if parent_tenant_code == tenant_code: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "父级租户不能指向自身") if parent_tenant_code: await self._assertTenantExists(session, str(parent_tenant_code), "父级租户不存在", exclude_tenant_code=tenant_code) if is_builtin and tenant_code in self._BUILTIN_TENANT_CODES: tenant_type = str(current.get("tenant_type") or tenant_type) parent_tenant_code = current.get("parent_tenant_code") await session.execute( text( """ UPDATE sys_tenants SET tenant_name = :tenant_name, tenant_short_name = :tenant_short_name, tenant_type = :tenant_type, parent_tenant_code = :parent_tenant_code, display_order = :display_order, is_public = :is_public, can_host_entry_module = :can_host_entry_module, can_host_documents = :can_host_documents, can_host_rag = :can_host_rag, can_host_templates = :can_host_templates, ext = CAST(:ext AS jsonb), updated_at = NOW() WHERE tenant_code = :tenant_code AND deleted_at IS NULL """ ), { "tenant_code": tenant_code, "tenant_name": tenant_name, "tenant_short_name": tenant_short_name, "tenant_type": tenant_type, "parent_tenant_code": parent_tenant_code, "display_order": int(Body.display_order if Body.display_order is not None else current.get("display_order") or 0), "is_public": bool(Body.is_public if Body.is_public is not None else current.get("is_public")), "can_host_entry_module": bool( Body.can_host_entry_module if Body.can_host_entry_module is not None else current.get("can_host_entry_module") ), "can_host_documents": bool( Body.can_host_documents if Body.can_host_documents is not None else current.get("can_host_documents") ), "can_host_rag": bool(Body.can_host_rag if Body.can_host_rag is not None else current.get("can_host_rag")), "can_host_templates": bool( Body.can_host_templates if Body.can_host_templates is not None else current.get("can_host_templates") ), "ext": self._dumpJson(Body.ext if Body.ext is not None else (current.get("ext") or {})), }, ) if Body.alias_values is not None: alias_values = self._normalizeAliasValues(Body.alias_values, tenant_name=tenant_name, tenant_short_name=tenant_short_name) await self._replaceTenantAliases(session, tenant_code, alias_values) if Body.feature_keys is not None: feature_keys = self._normalizeFeatureKeys(Body.feature_keys) await self._replaceTenantFeatures(session, tenant_code, feature_keys) await session.commit() return await self._getTenantDetailOrFail(tenant_code) async def UpdateTenantStatus(self, CurrentUserId: int, TenantCode: str, Body: TenantStatusUpdateDTO) -> dict[str, Any]: del CurrentUserId await self._ensureWritableTenantFoundation() tenant_code = self._normalizeTenantCode(TenantCode) is_enabled = bool(Body.is_enabled) async with GetAsyncSession() as session: current = await self._loadTenantRow(session, tenant_code) if not current: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "租户不存在") if bool(current.get("is_builtin")) and tenant_code in self._BUILTIN_TENANT_CODES and not is_enabled: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "内建核心租户不允许禁用") if not is_enabled: await self._assertTenantCanBeDisabled(session, tenant_code) await session.execute( text( """ UPDATE sys_tenants SET is_enabled = :is_enabled, updated_at = NOW() WHERE tenant_code = :tenant_code AND deleted_at IS NULL """ ), {"tenant_code": tenant_code, "is_enabled": is_enabled}, ) await session.commit() return await self._getTenantDetailOrFail(tenant_code) async def _table_exists(self, table_name: str) -> bool: cached = self._table_exists_cache.get(table_name) if cached is not None: return cached async with GetAsyncSession() as session: exists = bool( ( await session.execute( text( """ SELECT EXISTS ( SELECT 1 FROM information_schema.tables WHERE table_schema = current_schema() AND table_name = :table_name ) """ ), {"table_name": table_name}, ) ).scalar_one() ) self._table_exists_cache[table_name] = exists return exists async def _list_legacy_tenants(self) -> list[dict[str, Any]]: async with GetAsyncSession() as session: user_rows = ( await session.execute( text( """ SELECT DISTINCT COALESCE(NULLIF(area, ''), '') AS area FROM sso_users WHERE deleted_at IS NULL AND COALESCE(NULLIF(area, ''), '') <> '' ORDER BY COALESCE(NULLIF(area, ''), '') ASC """ ) ) ).all() items: list[dict[str, Any]] = [ { "tenant_code": "PUBLIC", "tenant_name": "公共", "tenant_short_name": "公共", "tenant_type": "PUBLIC", "parent_tenant_code": None, "display_order": 0, "is_enabled": True, "is_builtin": True, "is_public": True, "can_host_entry_module": True, "can_host_documents": True, "can_host_rag": True, "can_host_templates": True, "ext": {}, "feature_keys": [], } ] for index, row in enumerate(user_rows, start=1): area = str(row[0] or "").strip() if not area or area == "公共": continue items.append( { "tenant_code": area, "tenant_name": area, "tenant_short_name": area, "tenant_type": "LEGACY_AREA", "parent_tenant_code": None, "display_order": index * 10, "is_enabled": True, "is_builtin": False, "is_public": False, "can_host_entry_module": True, "can_host_documents": True, "can_host_rag": True, "can_host_templates": True, "ext": {}, "feature_keys": [], } ) return items async def _list_legacy_tenant_options(self) -> list[dict[str, Any]]: return [ { "tenant_code": str(item.get("tenant_code") or ""), "tenant_name": str(item.get("tenant_name") or ""), "tenant_short_name": item.get("tenant_short_name"), "tenant_type": item.get("tenant_type"), "is_public": bool(item.get("is_public")), "display_order": item.get("display_order"), } for item in await self._list_legacy_tenants() ] async def _ensureWritableTenantFoundation(self) -> None: required_tables = ("sys_tenants", "sys_tenant_aliases", "sys_tenant_feature_flags") missing = [table_name for table_name in required_tables if not await self._table_exists(table_name)] if missing: joined = ", ".join(missing) raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, f"租户主数据底座未初始化,缺少表: {joined}") async def _getTenantDetailOrFail(self, tenant_code: str) -> dict[str, Any]: item = await self.GetTenant(tenant_code) if not item: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "租户不存在") item["feature_keys"] = await self.GetTenantFeatures(tenant_code) item["alias_values"] = await self.GetTenantAliases(tenant_code) return item async def _getTenantAliases(self, tenant_code: str) -> list[str]: if not await self._table_exists("sys_tenant_aliases"): return [] async with GetAsyncSession() as session: rows = ( await session.execute( text( """ SELECT alias_value FROM sys_tenant_aliases WHERE tenant_code = :tenant_code AND deleted_at IS NULL AND is_enabled = TRUE ORDER BY CASE alias_type WHEN 'DISPLAY' THEN 1 WHEN 'SHORT_NAME' THEN 2 WHEN 'LEGACY_AREA' THEN 3 ELSE 9 END ASC, id ASC """ ), {"tenant_code": tenant_code}, ) ).all() return [str(row[0]).strip() for row in rows if str(row[0] or "").strip()] async def _loadTenantRow(self, session: Any, tenant_code: str) -> dict[str, Any] | None: row = ( await session.execute( text( """ SELECT tenant_code, tenant_name, tenant_short_name, tenant_type, parent_tenant_code, display_order, is_enabled, is_builtin, is_public, can_host_entry_module, can_host_documents, can_host_rag, can_host_templates, ext FROM sys_tenants WHERE tenant_code = :tenant_code AND deleted_at IS NULL LIMIT 1 """ ), {"tenant_code": tenant_code}, ) ).mappings().first() return dict(row) if row else None async def _assertTenantExists( self, session: Any, tenant_code: str, error_message: str, exclude_tenant_code: str | None = None, ) -> None: params: dict[str, Any] = {"tenant_code": tenant_code} sql = """ SELECT 1 FROM sys_tenants WHERE tenant_code = :tenant_code AND deleted_at IS NULL """ if exclude_tenant_code: sql += " AND tenant_code <> :exclude_tenant_code" params["exclude_tenant_code"] = exclude_tenant_code sql += " LIMIT 1" exists = (await session.execute(text(sql), params)).scalar_one_or_none() if not exists: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, error_message) async def _assertTenantCanBeDisabled(self, session: Any, tenant_code: str) -> None: references = await self._collectDisableReferences(session, tenant_code) if not references: return joined = ";".join(references) raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, f"当前租户仍被引用,不能禁用:{joined}") async def _collectDisableReferences(self, session: Any, tenant_code: str) -> list[str]: references: list[str] = [] if await self._table_exists("sys_tenants"): child_count = int( ( await session.execute( text( """ SELECT COUNT(*) FROM sys_tenants WHERE parent_tenant_code = :tenant_code AND deleted_at IS NULL AND is_enabled = TRUE """ ), {"tenant_code": tenant_code}, ) ).scalar_one() ) if child_count > 0: references.append(f"存在 {child_count} 个启用中的子租户") if await self._table_exists("leaudit_entry_module_tenants"): entry_module_count = int( ( await session.execute( text( """ SELECT COUNT(DISTINCT emt.entry_module_id) FROM leaudit_entry_module_tenants emt WHERE emt.tenant_code = :tenant_code AND emt.deleted_at IS NULL AND COALESCE(emt.is_enabled, TRUE) = TRUE """ ), {"tenant_code": tenant_code}, ) ).scalar_one() ) if entry_module_count > 0: references.append(f"仍绑定 {entry_module_count} 个入口模块") if await self._table_exists("sso_users") and await self._column_exists("sso_users", "tenant_code"): user_count = int( ( await session.execute( text( """ SELECT COUNT(*) FROM sso_users WHERE tenant_code = :tenant_code AND deleted_at IS NULL AND status = 0 """ ), {"tenant_code": tenant_code}, ) ).scalar_one() ) if user_count > 0: references.append(f"仍有 {user_count} 个启用用户归属该租户") return references async def _replaceTenantAliases(self, session: Any, tenant_code: str, alias_values: list[str]) -> None: await session.execute( text( """ DELETE FROM sys_tenant_aliases WHERE tenant_code = :tenant_code """ ), {"tenant_code": tenant_code}, ) for index, alias in enumerate(alias_values): alias_type = "DISPLAY" if index == 0 else "SHORT_NAME" await session.execute( text( """ INSERT INTO sys_tenant_aliases ( tenant_code, alias_type, alias_value, is_enabled, created_at, updated_at, deleted_at ) VALUES ( :tenant_code, :alias_type, :alias_value, TRUE, NOW(), NOW(), NULL ) """ ), { "tenant_code": tenant_code, "alias_type": alias_type, "alias_value": alias, }, ) async def _replaceTenantFeatures(self, session: Any, tenant_code: str, feature_keys: list[str]) -> None: await session.execute( text( """ DELETE FROM sys_tenant_feature_flags WHERE tenant_code = :tenant_code """ ), {"tenant_code": tenant_code}, ) for feature_key in feature_keys: await session.execute( text( """ INSERT INTO sys_tenant_feature_flags ( tenant_code, feature_key, is_enabled, created_at, updated_at, deleted_at ) VALUES ( :tenant_code, :feature_key, TRUE, NOW(), NOW(), NULL ) """ ), { "tenant_code": tenant_code, "feature_key": feature_key, }, ) def _normalizeTenantCode(self, value: str | None) -> str: tenant_code = str(value or "").strip() if not tenant_code: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "租户编码不能为空") if len(tenant_code) > 64: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "租户编码长度不能超过 64") return tenant_code def _normalizeOptionalCode(self, value: str | None) -> str | None: normalized = self._normalizeOptionalText(value) return normalized or None def _normalizeRequiredText(self, value: str | None, field_name: str) -> str: normalized = self._normalizeOptionalText(value) if not normalized: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, f"{field_name}不能为空") return normalized @staticmethod def _normalizeOptionalText(value: Any) -> str | None: if value is None: return None text_value = str(value).strip() return text_value or None def _normalizeTenantType(self, value: str | None) -> str: tenant_type = self._normalizeOptionalText(value) or "CUSTOM" if len(tenant_type) > 32: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "租户类型长度不能超过 32") return tenant_type.upper() def _normalizeFeatureKeys(self, values: list[str] | None) -> list[str]: normalized: list[str] = [] seen: set[str] = set() for item in values or []: feature_key = self._normalizeOptionalText(item) if not feature_key: continue if feature_key not in self._SUPPORTED_FEATURE_KEYS: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, f"不支持的功能标识: {feature_key}") if feature_key in seen: continue seen.add(feature_key) normalized.append(feature_key) return normalized def _normalizeAliasValues(self, values: list[str] | None, *, tenant_name: str, tenant_short_name: str) -> list[str]: normalized: list[str] = [] seen: set[str] = set() seeds = [tenant_name, tenant_short_name, *(values or [])] for item in seeds: alias = self._normalizeOptionalText(item) if not alias or alias in seen: continue if len(alias) > 100: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, f"租户别名过长: {alias}") seen.add(alias) normalized.append(alias) return normalized async def _column_exists(self, table_name: str, column_name: str) -> bool: cache_key = f"{table_name}.{column_name}" cached = self._table_exists_cache.get(cache_key) if cached is not None: return cached async with GetAsyncSession() as session: exists = bool( ( await session.execute( text( """ SELECT EXISTS ( SELECT 1 FROM information_schema.columns WHERE table_schema = current_schema() AND table_name = :table_name AND column_name = :column_name ) """ ), {"table_name": table_name, "column_name": column_name}, ) ).scalar_one() ) self._table_exists_cache[cache_key] = exists return exists @staticmethod def _dumpJson(value: dict[str, Any]) -> str: import json return json.dumps(value, ensure_ascii=False)