423 lines
17 KiB
Python
423 lines
17 KiB
Python
"""认证服务实现。"""
|
||
|
||
from __future__ import annotations
|
||
|
||
from datetime import datetime, timezone
|
||
from typing import Any
|
||
|
||
from fastapi_common.fastapi_common_logger import logger
|
||
from fastapi_common.fastapi_common_security.jwtService import JwtService
|
||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
||
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
||
|
||
from fastapi_modules.fastapi_leaudit.domian.vo.auth.loginTokenVo import LoginTokenVO
|
||
from fastapi_modules.fastapi_leaudit.services.authService import IAuthService
|
||
from fastapi_modules.fastapi_leaudit.services.impl.tenantResolver import TenantResolver
|
||
|
||
|
||
class AuthServiceImpl(IAuthService):
    """Authentication service implementation.

    Reads and writes the ``sso_users`` table with raw SQL. Some optional columns
    (password/status/deleted_at/try fields/tenant fields) may be missing on
    databases that have not finished migration; such columns are selected as typed
    ``NULL`` aliases so the result rows always have the same keys.
    """

    def __init__(self) -> None:
        self.TenantResolver = TenantResolver()
        # Real column names of ``sso_users``, discovered lazily from information_schema.
        self._sso_user_columns_cache: set[str] | None = None

    @staticmethod
    def _naive_utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Intended for ``timestamp without time zone`` columns. ``datetime.utcnow()``
        is deprecated since Python 3.12; this expression yields the same value.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def _get_sso_user_columns(self, session) -> set[str]:
        """Return the actual column names of ``sso_users`` in ``current_schema()``.

        Lets queries degrade gracefully on environments that have not finished the
        tenant-column migration. A non-empty result is cached on the instance. An
        empty result (table not visible in the current schema) is deliberately not
        cached: otherwise every optional column, including ``password``, would be
        read as NULL for the lifetime of this service.
        """
        if self._sso_user_columns_cache is not None:
            return self._sso_user_columns_cache

        from sqlalchemy import text

        rows = await session.execute(
            text(
                """
                SELECT column_name
                FROM information_schema.columns
                WHERE table_schema = current_schema()
                  AND table_name = 'sso_users'
                """
            )
        )
        columns = {str(row[0]) for row in rows.fetchall()}
        if columns:
            self._sso_user_columns_cache = columns
        return columns

    @staticmethod
    def _optional_sso_user_column(columns: set[str], column: str, pg_type: str = "varchar") -> str:
        """Return ``column`` if it exists, else a ``NULL::<pg_type> AS <column>`` placeholder.

        ``column``/``pg_type`` are internal constants, never user input, so string
        interpolation into SQL is safe here.
        """
        if column in columns:
            return column
        return f"NULL::{pg_type} AS {column}"

    async def _build_sso_user_select_fields(
        self,
        session,
        *,
        include_password: bool = False,
        include_status: bool = False,
        include_deleted_at: bool = False,
        include_try_fields: bool = False,
    ) -> str:
        """Build the ``sso_users`` SELECT list, tolerating legacy schemas.

        The base identity columns are always selected as-is. Optional columns
        (enabled by the keyword flags, plus area/tenant/department fields) fall back
        to typed NULL aliases when they do not exist.

        Returns:
            A comma-separated column list ready to be placed after ``SELECT``.
        """
        columns = await self._get_sso_user_columns(session)
        fields = [
            "id",
            "sub",
            "username",
            "nick_name",
            "phone_number",
            "email",
            "ou_id",
            "ou_name",
            "is_leader",
        ]
        if include_password:
            fields.append(self._optional_sso_user_column(columns, "password"))
        if include_status:
            fields.append(self._optional_sso_user_column(columns, "status", "integer"))
        if include_deleted_at:
            fields.append(self._optional_sso_user_column(columns, "deleted_at", "timestamp"))
        if include_try_fields:
            fields.append(self._optional_sso_user_column(columns, "try_count", "integer"))
            fields.append(self._optional_sso_user_column(columns, "try_login_time", "timestamp"))
        fields.extend(
            [
                self._optional_sso_user_column(columns, "area"),
                self._optional_sso_user_column(columns, "tenant_code"),
                self._optional_sso_user_column(columns, "tenant_name"),
                self._optional_sso_user_column(columns, "dep_name"),
                self._optional_sso_user_column(columns, "dep_short_name"),
            ]
        )
        return ", ".join(fields)

    @staticmethod
    def _password_matches(stored: Any, supplied: Any) -> bool:
        """Compare a plaintext stored password with the supplied one in constant time.

        Fails closed: a missing or empty stored password, or non-string input, never
        matches. This blocks logging into accounts with no password set, such as
        rows created by OAuth login.
        """
        import hmac

        if not isinstance(stored, str) or not isinstance(supplied, str) or not stored:
            return False
        # Compare bytes: compare_digest rejects non-ASCII str operands.
        return hmac.compare_digest(stored.encode("utf-8"), supplied.encode("utf-8"))

    async def PasswordLogin(self, Sub: str, Password: str) -> LoginTokenVO:
        """Log in with an account identifier and password.

        The identifier matches either ``sub`` or ``username``, because the legacy
        system used both. For example, the UI may show the username ``admin`` while
        the ``sub`` is ``000``. An exact ``sub`` match wins, then the lowest ``id``.

        Passwords are still stored in plaintext in the legacy table and should be
        migrated to a hash. Until then, the comparison is constant-time and fails
        closed on a missing or empty stored password.

        Every failure raises the same 401 message, so callers cannot tell an unknown
        account from a wrong password.

        Raises:
            LeauditException: 401 when the account is missing, deleted, not in
                status 0, or the password does not match.
        """
        async with GetAsyncSession() as session:
            from sqlalchemy import text

            select_fields = await self._build_sso_user_select_fields(
                session,
                include_password=True,
                include_status=True,
                include_deleted_at=True,
                include_try_fields=True,
            )
            columns = await self._get_sso_user_columns(session)
            # WHERE cannot reference the ``NULL::timestamp AS deleted_at`` select alias,
            # so only filter on the real column when it exists; otherwise the query errors.
            deleted_filter = "deleted_at IS NULL AND " if "deleted_at" in columns else ""
            result = await session.execute(
                text(
                    f"SELECT {select_fields} "
                    "FROM sso_users "
                    f"WHERE {deleted_filter}(sub = :identifier OR username = :identifier) "
                    "ORDER BY CASE WHEN sub = :identifier THEN 0 ELSE 1 END, id ASC "
                    "LIMIT 1"
                ),
                {"identifier": Sub},
            )
            row = result.fetchone()

            if not row:
                logger.warning("登录失败: 用户不存在 - identifier=%s", Sub)
                raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")

            user = dict(row._mapping)

            if user.get("deleted_at") is not None:
                logger.warning("登录失败: 账号已删除 - sub=%s", Sub)
                raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")

            if user.get("status") != 0:
                logger.warning("登录失败: 账号已禁用 - sub=%s, status=%s", Sub, user.get("status"))
                raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")

            if not self._password_matches(user.get("password"), Password):
                logger.warning("登录失败: 密码错误 - sub=%s", Sub)
                raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")

            await self._ensure_default_role(session, user["id"])
            return await self._buildLoginResponse(user, session)

    async def OAuthLogin(
        self,
        Sub: str,
        Username: str | None,
        Nickname: str | None,
        Email: str | None,
        PhoneNumber: str | None,
        OuId: str | None,
        OuName: str | None,
        IsLeader: bool | None,
        Area: str | None,
        ExpiresIn: int,
    ) -> LoginTokenVO:
        """Log in via OAuth, upserting the ``sso_users`` row keyed by ``sub``.

        ``Area`` is intentionally ignored: it must not be overridable by a front-end
        login request. New users get only basic account fields, and the area is left
        for a trusted back-end source to fill in. ``ExpiresIn`` is also unused.

        For an existing user, provider-supplied values replace the stored ones, but a
        missing value (None/empty) keeps the stored value rather than wiping it.

        Raises:
            LeauditException: 401 if the existing account is deleted or not in
                status 0; 500 if the row cannot be read back after the upsert.
        """
        del Area, ExpiresIn

        async with GetAsyncSession() as session:
            from sqlalchemy import text

            select_fields = await self._build_sso_user_select_fields(
                session,
                include_status=True,
                include_deleted_at=True,
            )
            result = await session.execute(
                text(
                    f"SELECT {select_fields} "
                    "FROM sso_users WHERE sub = :sub"
                ),
                {"sub": Sub},
            )
            row = result.fetchone()

            if row:
                user = dict(row._mapping)
                if user.get("deleted_at") is not None or user.get("status") != 0:
                    raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号已被禁用或删除")

                await session.execute(
                    text(
                        "UPDATE sso_users SET username = :username, nick_name = :nick, "
                        "email = :email, phone_number = :phone, ou_id = :ou_id, "
                        "ou_name = :ou_name, is_leader = :is_leader, "
                        "updated_at = :now WHERE id = :id"
                    ),
                    {
                        "username": Username or user.get("username") or Sub,
                        "nick": Nickname or user.get("nick_name") or Username or Sub,
                        # Keep stored contact info when the provider omits it, like the other fields.
                        "email": Email or user.get("email"),
                        "phone": PhoneNumber or user.get("phone_number"),
                        "ou_id": OuId or user.get("ou_id") or "",
                        "ou_name": OuName or user.get("ou_name") or "",
                        "is_leader": IsLeader if IsLeader is not None else user.get("is_leader"),
                        "now": self._naive_utcnow(),
                        "id": user["id"],
                    },
                )
            else:
                await session.execute(
                    text(
                        "INSERT INTO sso_users (sub, username, nick_name, email, phone_number, "
                        "ou_id, ou_name, is_leader, status, created_at, updated_at) "
                        "VALUES (:sub, :username, :nick, :email, :phone, :ou_id, "
                        ":ou_name, :is_leader, 0, :now, :now) RETURNING id"
                    ),
                    {
                        "sub": Sub,
                        "username": Username or Sub,
                        "nick": Nickname or Username or Sub,
                        "email": Email,
                        "phone": PhoneNumber,
                        "ou_id": OuId or "",
                        "ou_name": OuName or "",
                        "is_leader": bool(IsLeader),
                        "now": self._naive_utcnow(),
                    },
                )
                # The default role is ensured once below, for both new and existing users.

            # Re-read the row so the login response reflects the values just written.
            select_fields = await self._build_sso_user_select_fields(session)
            result = await session.execute(
                text(
                    f"SELECT {select_fields} "
                    "FROM sso_users WHERE sub = :sub"
                ),
                {"sub": Sub},
            )
            row = result.fetchone()
            user = dict(row._mapping) if row else {}

            if not user:
                raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "用户创建失败")

            await self._ensure_default_role(session, user["id"])
            return await self._buildLoginResponse(user, session)

    async def GetCurrentUser(self, UserId: int) -> dict:
        """Return the current user's info: profile, roles/permissions and tenant context.

        Also ensures the user has at least the default ``common`` role, which may
        write to ``user_role``.

        Raises:
            LeauditException: 404 if the user id does not exist; 401 if the account
                is deleted or not in status 0.
        """
        async with GetAsyncSession() as session:
            from sqlalchemy import text

            select_fields = await self._build_sso_user_select_fields(
                session,
                include_status=True,
                include_deleted_at=True,
            )
            result = await session.execute(
                text(
                    f"SELECT {select_fields} "
                    "FROM sso_users WHERE id = :uid"
                ),
                {"uid": UserId},
            )
            row = result.fetchone()
            if not row:
                raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "用户不存在")

            user = dict(row._mapping)
            if user.get("deleted_at") is not None or user.get("status") != 0:
                raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号已被禁用或删除")

            await self._ensure_default_role(session, user["id"])
            identity = await self._loadUserIdentity(session, user["id"])
            user_info = self._buildUserInfo(user, identity)
            tenant_resolution = await self.TenantResolver.ResolveUserContext(
                Area=user.get("area"),
                TenantCode=user.get("tenant_code"),
                TenantName=user.get("tenant_name"),
                Source="current_user",
            )
            user_info["tenant_code"] = tenant_resolution.tenant_code
            user_info["tenant_name"] = tenant_resolution.tenant_name or user.get("tenant_name")
            user_info["tenant_type"] = tenant_resolution.tenant_type
            return user_info

    async def _buildLoginResponse(self, user: dict[str, Any], session) -> LoginTokenVO:
        """Build the login response: load roles/permissions, resolve tenant, issue a JWT."""
        identity = await self._loadUserIdentity(session, user["id"])
        user_info = self._buildUserInfo(user, identity)
        tenant_resolution = await self.TenantResolver.ResolveUserContext(
            Area=user.get("area"),
            TenantCode=user.get("tenant_code"),
            TenantName=user.get("tenant_name"),
        )
        user_info["tenant_code"] = tenant_resolution.tenant_code
        user_info["tenant_type"] = tenant_resolution.tenant_type
        user_info["tenant_name"] = tenant_resolution.tenant_name or user.get("tenant_name")

        tokens = JwtService.generate(
            userId=user["id"],
            username=user.get("username") or user.get("sub", ""),
            nickName=user.get("nick_name") or "",
            ouId=user.get("ou_id") or "",
            ouName=user.get("ou_name") or "",
            roles=identity["roles"],
            permissions=identity["permissions"],
            area=user.get("area"),
            tenantCode=tenant_resolution.tenant_code,
            tenantName=tenant_resolution.tenant_name or user.get("tenant_name"),
            tenantType=tenant_resolution.tenant_type,
            userRole=identity["primary_role"],
        )

        return LoginTokenVO(
            access_token=tokens["access_token"],
            token_type="Bearer",
            expires_in=tokens["expires_in"],
            issued_time=tokens.get("issued_time", ""),
            user_info=user_info,
        )

    async def _ensure_default_role(self, session, user_id: int) -> None:
        """Ensure the user has at least one role, assigning ``common`` if they have none.

        Does nothing (besides logging a warning) when no ``common`` role exists.
        The insert is idempotent via ``ON CONFLICT (user_id, role_id) DO NOTHING``.
        """
        from sqlalchemy import text

        role_count = await session.execute(
            text("SELECT COUNT(*) FROM user_role WHERE user_id = :uid"),
            {"uid": user_id},
        )
        if (role_count.scalar_one() or 0) > 0:
            return

        common_role = await session.execute(
            text("SELECT id FROM roles WHERE role_key = 'common' LIMIT 1")
        )
        common_role_id = common_role.scalar_one_or_none()
        if common_role_id is None:
            logger.warning("默认角色 common 不存在,无法为用户 %s 自动分配角色", user_id)
            return

        await session.execute(
            text(
                "INSERT INTO user_role (user_id, role_id, created_at, updated_at) "
                "VALUES (:uid, :rid, :now, :now) "
                "ON CONFLICT (user_id, role_id) DO NOTHING"
            ),
            # NOTE(review): uses naive UTC like the sso_users writes, on the assumption
            # that user_role timestamps are also `timestamp without time zone`
            # (an aware datetime into such a column fails under asyncpg) — confirm schema.
            {"uid": user_id, "rid": common_role_id, "now": self._naive_utcnow()},
        )
        logger.info("已为用户 %s 自动补默认角色 common", user_id)

    async def _loadUserIdentity(self, session, user_id: int) -> dict[str, Any]:
        """Load the user's roles and effective permissions.

        Roles are ordered by ``priority`` (NULL treated as 0) descending, then by
        role id. The first role is the primary role, and ``["common"]`` is used when
        the user has none. Permissions are the sorted set of granted keys minus any
        key denied (``grant_type == "DENY"``) by any of the user's roles.

        Returns:
            ``{"roles": list[str], "primary_role": str, "permissions": list[str]}``.
        """
        from sqlalchemy import text

        role_rows = await session.execute(
            text(
                "SELECT r.role_key, COALESCE(r.priority, 0) AS priority "
                "FROM user_role ur "
                "JOIN roles r ON ur.role_id = r.id "
                "WHERE ur.user_id = :uid "
                "ORDER BY COALESCE(r.priority, 0) DESC, r.id ASC"
            ),
            {"uid": user_id},
        )
        roles = [row[0] for row in role_rows.fetchall()]
        if not roles:
            roles = ["common"]

        perm_rows = await session.execute(
            text(
                "SELECT p.permission_key, rp.grant_type "
                "FROM user_role ur "
                "JOIN roles r ON ur.role_id = r.id "
                "JOIN role_permissions rp ON r.id = rp.role_id "
                "JOIN permissions p ON rp.permission_id = p.id "
                "WHERE ur.user_id = :uid"
            ),
            {"uid": user_id},
        )

        grants: set[str] = set()
        denies: set[str] = set()
        for permission_key, grant_type in perm_rows.fetchall():
            # An explicit DENY from any role overrides grants from every other role.
            if grant_type == "DENY":
                denies.add(permission_key)
            else:
                grants.add(permission_key)

        permissions = sorted(grants - denies)
        return {
            "roles": roles,
            "primary_role": roles[0],
            "permissions": permissions,
        }

    @staticmethod
    def _buildUserInfo(user: dict[str, Any], identity: dict[str, Any]) -> dict[str, Any]:
        """Assemble the unified user-info dict from an ``sso_users`` row and its identity.

        ``tenant_type`` is left as None; callers fill the tenant fields from the
        tenant resolver.
        """
        return {
            "user_id": user["id"],
            "sub": user.get("sub"),
            "username": user.get("username"),
            "nick_name": user.get("nick_name"),
            "email": user.get("email"),
            "phone_number": user.get("phone_number"),
            "ou_id": user.get("ou_id"),
            "ou_name": user.get("ou_name"),
            "is_leader": user.get("is_leader"),
            "area": user.get("area"),
            "tenant_code": user.get("tenant_code"),
            "user_role": identity["primary_role"],
            "roles": identity["roles"],
            "permissions": identity["permissions"],
            "tenant_name": user.get("tenant_name"),
            "tenant_type": None,
            "dep_name": user.get("dep_name"),
            "dep_short_name": user.get("dep_short_name"),
        }
|