Files

423 lines
17 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""认证服务实现。"""
from __future__ import annotations
from datetime import datetime, timezone
from typing import Any
from fastapi_common.fastapi_common_logger import logger
from fastapi_common.fastapi_common_security.jwtService import JwtService
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.domian.vo.auth.loginTokenVo import LoginTokenVO
from fastapi_modules.fastapi_leaudit.services.authService import IAuthService
from fastapi_modules.fastapi_leaudit.services.impl.tenantResolver import TenantResolver
class AuthServiceImpl(IAuthService):
    """Authentication service implementation.

    Works directly against the ``sso_users`` / ``user_role`` / ``roles`` /
    ``role_permissions`` / ``permissions`` tables with raw SQL. Optional
    ``sso_users`` columns that do not exist in the database (environments
    where a migration has not been applied yet) are selected as typed
    ``NULL`` aliases, so the same queries run against old and new schemas.
    """

    def __init__(self) -> None:
        self.TenantResolver = TenantResolver()
        # Real `sso_users` column names, loaded lazily on first use;
        # None means "not loaded yet".
        self._sso_user_columns_cache: set[str] | None = None

    @staticmethod
    def _naive_utcnow() -> datetime:
        """Return the current UTC time as a naive datetime.

        Intended for ``timestamp without time zone`` columns. Produces the same
        value as ``datetime.utcnow()``, which is deprecated since Python 3.12.
        """
        return datetime.now(timezone.utc).replace(tzinfo=None)

    async def _get_sso_user_columns(self, session) -> set[str]:
        """Return the column names that actually exist on ``sso_users``.

        Keeps queries compatible with environments where the tenant-column
        migration has not been applied yet. Columns are read from
        ``information_schema.columns`` for ``current_schema()``.

        A non-empty result is cached on the instance. An empty result (table
        not visible in the current schema) is returned but NOT cached: caching
        it would turn every optional column -- including ``password`` -- into
        ``NULL`` for the lifetime of this instance.
        """
        if self._sso_user_columns_cache is not None:
            return self._sso_user_columns_cache

        from sqlalchemy import text

        rows = await session.execute(
            text(
                """
                SELECT column_name
                FROM information_schema.columns
                WHERE table_schema = current_schema()
                  AND table_name = 'sso_users'
                """
            )
        )
        columns = {str(row[0]) for row in rows.fetchall()}
        if not columns:
            logger.warning("未在当前 schema 中读取到 sso_users 列信息,可选字段将按 NULL 处理,本次结果不缓存")
            return columns
        self._sso_user_columns_cache = columns
        return columns

    @staticmethod
    def _optional_sso_user_column(columns: set[str], column: str, pg_type: str = "varchar") -> str:
        """Return a SELECT expression for a column that may not exist.

        Args:
            columns: Real column names of ``sso_users``.
            column: Column to select. Interpolated into SQL verbatim, so it
                must be a trusted identifier (all callers pass literals).
            pg_type: PostgreSQL type used for the ``NULL`` placeholder.

        Returns:
            ``column`` itself when it exists, otherwise
            ``"NULL::<pg_type> AS <column>"`` so result rows still carry the key.
        """
        if column in columns:
            return column
        return f"NULL::{pg_type} AS {column}"

    async def _build_sso_user_select_fields(
        self,
        session,
        *,
        include_password: bool = False,
        include_status: bool = False,
        include_deleted_at: bool = False,
        include_try_fields: bool = False,
    ) -> str:
        """Build a comma-separated ``sso_users`` SELECT list tolerant of old schemas.

        The base identity columns are always selected as-is. The opt-in
        columns (password / status / deleted_at / try_*) and the always-added
        area / tenant / department columns go through
        ``_optional_sso_user_column`` and fall back to typed ``NULL`` aliases
        when missing.
        """
        columns = await self._get_sso_user_columns(session)
        optional = self._optional_sso_user_column
        fields = [
            "id",
            "sub",
            "username",
            "nick_name",
            "phone_number",
            "email",
            "ou_id",
            "ou_name",
            "is_leader",
        ]
        if include_password:
            fields.append(optional(columns, "password"))
        if include_status:
            fields.append(optional(columns, "status", "integer"))
        if include_deleted_at:
            fields.append(optional(columns, "deleted_at", "timestamp"))
        if include_try_fields:
            fields.append(optional(columns, "try_count", "integer"))
            fields.append(optional(columns, "try_login_time", "timestamp"))
        fields.extend(
            optional(columns, name)
            for name in ("area", "tenant_code", "tenant_name", "dep_name", "dep_short_name")
        )
        return ", ".join(fields)

    @staticmethod
    def _password_matches(stored: Any, supplied: Any) -> bool:
        """Compare a stored plaintext password with the supplied one in constant time.

        Empty or non-string values never match, so an account whose stored
        password is empty/NULL cannot be logged into with an empty password.
        """
        import hmac

        if not isinstance(stored, str) or not isinstance(supplied, str):
            return False
        if not stored or not supplied:
            return False
        return hmac.compare_digest(stored.encode("utf-8"), supplied.encode("utf-8"))

    async def PasswordLogin(self, Sub: str, Password: str) -> LoginTokenVO:
        """Account/password login.

        Passwords are still compared against legacy plaintext values; this
        should migrate to hash verification. The comparison is constant-time.

        The login identifier matches either ``sub`` or ``username``, as is
        common in the legacy system. This avoids the situation where the UI
        shows username ``admin`` but login only works with ``000``. An exact
        ``sub`` match wins over a ``username`` match.

        Raises:
            LeauditException: 401 with the same generic message for every
                failure (empty credentials, unknown, deleted or disabled
                account, wrong password).
        """
        from sqlalchemy import text

        if not Sub or not Password:
            logger.warning("登录失败: 账号或密码为空")
            raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")

        async with GetAsyncSession() as session:
            select_fields = await self._build_sso_user_select_fields(
                session,
                include_password=True,
                include_status=True,
                include_deleted_at=True,
                # NOTE(review): try_count / try_login_time are fetched but not
                # used below (no lockout logic yet).
                include_try_fields=True,
            )
            columns = await self._get_sso_user_columns(session)
            # WHERE cannot reference SELECT-list aliases, so only filter on
            # deleted_at when the real column exists; otherwise the NULL alias
            # from select_fields feeds the Python-side check below.
            deleted_filter = "deleted_at IS NULL AND " if "deleted_at" in columns else ""
            result = await session.execute(
                text(
                    f"SELECT {select_fields} "
                    "FROM sso_users "
                    f"WHERE {deleted_filter}(sub = :identifier OR username = :identifier) "
                    "ORDER BY CASE WHEN sub = :identifier THEN 0 ELSE 1 END, id ASC "
                    "LIMIT 1"
                ),
                {"identifier": Sub},
            )
            row = result.fetchone()
            if not row:
                logger.warning("登录失败: 用户不存在 - identifier=%s", Sub)
                raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")
            user = dict(row._mapping)
            if user.get("deleted_at") is not None:
                logger.warning("登录失败: 账号已删除 - sub=%s", Sub)
                raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")
            # Anything other than status 0 (including a NULL fallback for a
            # missing column) is treated as disabled.
            if user.get("status") != 0:
                logger.warning("登录失败: 账号已禁用 - sub=%s, status=%s", Sub, user.get("status"))
                raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")
            if not self._password_matches(user.get("password"), Password):
                logger.warning("登录失败: 密码错误 - sub=%s", Sub)
                raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")
            await self._ensure_default_role(session, user["id"])
            return await self._buildLoginResponse(user, session)

    async def OAuthLogin(
        self,
        Sub: str,
        Username: str | None,
        Nickname: str | None,
        Email: str | None,
        PhoneNumber: str | None,
        OuId: str | None,
        OuName: str | None,
        IsLeader: bool | None,
        Area: str | None,
        ExpiresIn: int,
    ) -> LoginTokenVO:
        """OAuth login: upsert the ``sso_users`` row for ``Sub`` and issue a token.

        At this stage ``area`` must not be overridable by the front-end login
        request, so ``Area`` is ignored. If the user does not exist, only the
        basic account fields are created; region fields are left for a trusted
        backend source to fill in. For an existing user, any profile field the
        IdP did not supply keeps its stored value.

        Raises:
            LeauditException: 401 if the existing account is deleted or
                disabled; 500 if the row cannot be re-read after the upsert.
        """
        # Area is deliberately ignored (see docstring); ExpiresIn is unused here.
        del Area, ExpiresIn
        from sqlalchemy import text

        async with GetAsyncSession() as session:
            select_fields = await self._build_sso_user_select_fields(
                session,
                include_status=True,
                include_deleted_at=True,
            )
            result = await session.execute(
                text(
                    f"SELECT {select_fields} "
                    "FROM sso_users WHERE sub = :sub"
                ),
                {"sub": Sub},
            )
            row = result.fetchone()
            if row:
                user = dict(row._mapping)
                if user.get("deleted_at") is not None or user.get("status") != 0:
                    raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号已被禁用或删除")
                await session.execute(
                    text(
                        "UPDATE sso_users SET username = :username, nick_name = :nick, "
                        "email = :email, phone_number = :phone, ou_id = :ou_id, "
                        "ou_name = :ou_name, is_leader = :is_leader, "
                        "updated_at = :now WHERE id = :id"
                    ),
                    {
                        "username": Username or user.get("username") or Sub,
                        "nick": Nickname or user.get("nick_name") or Username or Sub,
                        # Like every other field, keep the stored value when the
                        # IdP omits it instead of clearing it to NULL.
                        "email": Email or user.get("email"),
                        "phone": PhoneNumber or user.get("phone_number"),
                        "ou_id": OuId or user.get("ou_id") or "",
                        "ou_name": OuName or user.get("ou_name") or "",
                        "is_leader": IsLeader if IsLeader is not None else user.get("is_leader"),
                        "now": self._naive_utcnow(),
                        "id": user["id"],
                    },
                )
            else:
                await session.execute(
                    text(
                        "INSERT INTO sso_users (sub, username, nick_name, email, phone_number, "
                        "ou_id, ou_name, is_leader, status, created_at, updated_at) "
                        "VALUES (:sub, :username, :nick, :email, :phone, :ou_id, "
                        ":ou_name, :is_leader, 0, :now, :now)"
                    ),
                    {
                        "sub": Sub,
                        "username": Username or Sub,
                        "nick": Nickname or Username or Sub,
                        "email": Email,
                        "phone": PhoneNumber,
                        "ou_id": OuId or "",
                        "ou_name": OuName or "",
                        "is_leader": bool(IsLeader),
                        "now": self._naive_utcnow(),
                    },
                )

            # Re-read the row so the token reflects the persisted state.
            select_fields = await self._build_sso_user_select_fields(session)
            result = await session.execute(
                text(
                    f"SELECT {select_fields} "
                    "FROM sso_users WHERE sub = :sub"
                ),
                {"sub": Sub},
            )
            row = result.fetchone()
            user = dict(row._mapping) if row else {}
            if not user:
                raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "用户创建失败")
            # Covers both the update and the freshly-inserted path.
            await self._ensure_default_role(session, user["id"])
            return await self._buildLoginResponse(user, session)

    async def GetCurrentUser(self, UserId: int) -> dict:
        """Return the unified user-info dict for the logged-in user ``UserId``.

        Raises:
            LeauditException: 404 if the user does not exist; 401 if the
                account is deleted or disabled.
        """
        from sqlalchemy import text

        async with GetAsyncSession() as session:
            select_fields = await self._build_sso_user_select_fields(
                session,
                include_status=True,
                include_deleted_at=True,
            )
            result = await session.execute(
                text(
                    f"SELECT {select_fields} "
                    "FROM sso_users WHERE id = :uid"
                ),
                {"uid": UserId},
            )
            row = result.fetchone()
            if not row:
                raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "用户不存在")
            user = dict(row._mapping)
            if user.get("deleted_at") is not None or user.get("status") != 0:
                raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号已被禁用或删除")
            await self._ensure_default_role(session, user["id"])
            identity = await self._loadUserIdentity(session, user["id"])
            user_info = self._buildUserInfo(user, identity)
            await self._applyTenantContext(user, user_info, Source="current_user")
            return user_info

    async def _applyTenantContext(
        self,
        user: dict[str, Any],
        user_info: dict[str, Any],
        **resolve_kwargs: Any,
    ) -> Any:
        """Resolve the tenant for ``user`` and write it into ``user_info``.

        Sets ``tenant_code``, ``tenant_name`` (falling back to the stored
        ``tenant_name`` when the resolver returns none) and ``tenant_type``.
        Extra keyword arguments (e.g. ``Source``) are forwarded to
        ``TenantResolver.ResolveUserContext``.

        Returns:
            The resolver's result object.
        """
        tenant_resolution = await self.TenantResolver.ResolveUserContext(
            Area=user.get("area"),
            TenantCode=user.get("tenant_code"),
            TenantName=user.get("tenant_name"),
            **resolve_kwargs,
        )
        user_info["tenant_code"] = tenant_resolution.tenant_code
        user_info["tenant_name"] = tenant_resolution.tenant_name or user.get("tenant_name")
        user_info["tenant_type"] = tenant_resolution.tenant_type
        return tenant_resolution

    async def _buildLoginResponse(self, user: dict[str, Any], session) -> LoginTokenVO:
        """Build the login response: load roles/permissions, resolve tenant, sign a JWT."""
        identity = await self._loadUserIdentity(session, user["id"])
        user_info = self._buildUserInfo(user, identity)
        tenant_resolution = await self._applyTenantContext(user, user_info)
        tokens = JwtService.generate(
            userId=user["id"],
            username=user.get("username") or user.get("sub", ""),
            nickName=user.get("nick_name") or "",
            ouId=user.get("ou_id") or "",
            ouName=user.get("ou_name") or "",
            roles=identity["roles"],
            permissions=identity["permissions"],
            area=user.get("area"),
            tenantCode=tenant_resolution.tenant_code,
            tenantName=user_info["tenant_name"],
            tenantType=tenant_resolution.tenant_type,
            userRole=identity["primary_role"],
        )
        return LoginTokenVO(
            access_token=tokens["access_token"],
            token_type="Bearer",
            expires_in=tokens["expires_in"],
            issued_time=tokens.get("issued_time", ""),
            user_info=user_info,
        )

    async def _ensure_default_role(self, session, user_id: int) -> None:
        """Ensure the user has at least one role, assigning ``common`` if none.

        Does nothing if the user already has any role. Logs a warning and
        returns if no ``common`` role exists.
        """
        from sqlalchemy import text

        role_count = await session.execute(
            text("SELECT COUNT(*) FROM user_role WHERE user_id = :uid"),
            {"uid": user_id},
        )
        if (role_count.scalar_one() or 0) > 0:
            return
        common_role = await session.execute(
            text("SELECT id FROM roles WHERE role_key = 'common' LIMIT 1")
        )
        common_role_id = common_role.scalar_one_or_none()
        if common_role_id is None:
            logger.warning("默认角色 common 不存在,无法为用户 %s 自动分配角色", user_id)
            return
        await session.execute(
            text(
                "INSERT INTO user_role (user_id, role_id, created_at, updated_at) "
                "VALUES (:uid, :rid, :now, :now) "
                # Tolerates a concurrent request inserting the same pair.
                "ON CONFLICT (user_id, role_id) DO NOTHING"
            ),
            # NOTE(review): this passes a timezone-aware datetime, whereas the
            # sso_users writes use `_naive_utcnow()`. If user_role timestamps are
            # `timestamp without time zone`, the async driver may reject this --
            # confirm the column types.
            {"uid": user_id, "rid": common_role_id, "now": datetime.now(timezone.utc)},
        )
        logger.info("已为用户 %s 自动补默认角色 common", user_id)

    async def _loadUserIdentity(self, session, user_id: int) -> dict[str, Any]:
        """Load the user's roles and effective permissions.

        Returns:
            ``{"roles": [...], "primary_role": str, "permissions": [...]}``.
            Roles are ordered by priority (desc) then role id; ``["common"]``
            is used when the user has no role rows. Permissions are the sorted
            set of granted keys minus any key denied by any of the user's roles.
        """
        from sqlalchemy import text

        role_rows = await session.execute(
            text(
                "SELECT r.role_key, COALESCE(r.priority, 0) AS priority "
                "FROM user_role ur "
                "JOIN roles r ON ur.role_id = r.id "
                "WHERE ur.user_id = :uid "
                "ORDER BY COALESCE(r.priority, 0) DESC, r.id ASC"
            ),
            {"uid": user_id},
        )
        roles = [row[0] for row in role_rows.fetchall()] or ["common"]

        perm_rows = await session.execute(
            text(
                "SELECT p.permission_key, rp.grant_type "
                "FROM user_role ur "
                "JOIN roles r ON ur.role_id = r.id "
                "JOIN role_permissions rp ON r.id = rp.role_id "
                "JOIN permissions p ON rp.permission_id = p.id "
                "WHERE ur.user_id = :uid"
            ),
            {"uid": user_id},
        )
        grants: set[str] = set()
        denies: set[str] = set()
        for permission_key, grant_type in perm_rows.fetchall():
            # DENY from any role overrides a grant from another role.
            if grant_type == "DENY":
                denies.add(permission_key)
            else:
                grants.add(permission_key)
        return {
            "roles": roles,
            "primary_role": roles[0],
            "permissions": sorted(grants - denies),
        }

    @staticmethod
    def _buildUserInfo(user: dict[str, Any], identity: dict[str, Any]) -> dict[str, Any]:
        """Assemble the unified user-info dict from a ``sso_users`` row and its identity.

        ``tenant_type`` starts as ``None``. The tenant fields are expected to be
        filled in afterwards from tenant resolution.
        """
        return {
            "user_id": user["id"],
            "sub": user.get("sub"),
            "username": user.get("username"),
            "nick_name": user.get("nick_name"),
            "email": user.get("email"),
            "phone_number": user.get("phone_number"),
            "ou_id": user.get("ou_id"),
            "ou_name": user.get("ou_name"),
            "is_leader": user.get("is_leader"),
            "area": user.get("area"),
            "tenant_code": user.get("tenant_code"),
            "user_role": identity["primary_role"],
            "roles": identity["roles"],
            "permissions": identity["permissions"],
            "tenant_name": user.get("tenant_name"),
            "tenant_type": None,
            "dep_name": user.get("dep_name"),
            "dep_short_name": user.get("dep_short_name"),
        }