#!/usr/bin/env python3 """Migrate legacy users from docauditai into leaudit_platform. Default mode is dry-run. Use --apply to write data. What gets migrated: - sso_users - user_role What gets reused: - target roles already seeded in leaudit_platform Rules: - preserve legacy user id when inserting into the new database - map missing legacy roles to `common` - normalize area with trim and alias mapping - never overwrite target area with empty value """ from __future__ import annotations import argparse import asyncio import os from collections import Counter, defaultdict from dataclasses import dataclass from pathlib import Path from typing import Any import asyncpg ROOT = Path(__file__).resolve().parents[1] APP_TOML = ROOT / "app.toml" ALLOWED_ROLE_KEYS = {"provincial_admin", "admin", "common", "super_admin"} DEFAULT_ROLE_KEY = "common" AREA_ALIASES = { "省": "省局", "省厅": "省局", "省公司": "省局", } @dataclass class LegacyUser: id: int sub: str username: str nick_name: str phone_number: str | None email: str | None ou_id: str ou_name: str status: int is_leader: bool created_at: Any updated_at: Any deleted_at: Any password: str | None try_count: int | None try_login_time: Any area: str | None mq_person_uuid: str | None mq_account_uuid: str | None mq_synced_at: Any tenant_name: str | None dep_short_name: str | None dep_name: str | None def load_target_dsn() -> str: try: import tomllib except ImportError: # pragma: no cover import tomli as tomllib with APP_TOML.open("rb") as fh: config = tomllib.load(fh) db = config["DB"] return ( f"postgresql://{db['USER']}:{db['PASSWORD']}" f"@{db['HOST']}:{db['PORT']}/{db['NAME']}" ) def build_legacy_dsn(args: argparse.Namespace) -> str: return ( f"postgresql://{args.legacy_user}:{args.legacy_password}" f"@{args.legacy_host}:{args.legacy_port}/{args.legacy_db}" ) def normalize_area(value: str | None) -> str | None: if value is None: return None normalized = value.strip() if not normalized: return None return AREA_ALIASES.get(normalized, normalized) def choose_roles(role_keys: list[str]) -> list[str]: cleaned: list[str] = [] for role_key in role_keys: if role_key in ALLOWED_ROLE_KEYS: cleaned.append(role_key) if not cleaned: cleaned = [DEFAULT_ROLE_KEY] return sorted(set(cleaned)) async def fetch_legacy_users(conn: asyncpg.Connection) -> dict[int, LegacyUser]: rows = await conn.fetch( """ SELECT id, sub, username, nick_name, phone_number, email, ou_id, ou_name, status, is_leader, created_at, updated_at, deleted_at, password, try_count, try_login_time, area, mq_person_uuid, mq_account_uuid, mq_synced_at, tenant_name, dep_short_name, dep_name FROM sso_users ORDER BY id """ ) return { row["id"]: LegacyUser(**dict(row)) for row in rows } async def fetch_legacy_user_roles(conn: asyncpg.Connection) -> dict[int, list[str]]: rows = await conn.fetch( """ SELECT ur.user_id, r.role_key FROM user_role ur JOIN roles r ON r.id = ur.role_id ORDER BY ur.user_id, r.id """ ) result: dict[int, list[str]] = defaultdict(list) for row in rows: result[row["user_id"]].append(row["role_key"]) return result async def fetch_target_roles(conn: asyncpg.Connection) -> dict[str, int]: rows = await conn.fetch("SELECT id, role_key FROM roles ORDER BY id") return {row["role_key"]: row["id"] for row in rows} async def fetch_target_users(conn: asyncpg.Connection) -> tuple[dict[str, dict[str, Any]], dict[int, dict[str, Any]]]: rows = await conn.fetch("SELECT id, sub, area FROM sso_users") normalized = [dict(row) for row in rows] by_sub = {row["sub"]: row for row in normalized} by_id = {row["id"]: row for row in normalized} return by_sub, by_id async def ensure_target_ready(conn: asyncpg.Connection) -> None: missing = [] for table_name in ( "sso_users", "roles", "user_role", "permissions", "role_permissions", "sys_routes", "role_route", ): exists = await conn.fetchval( """ SELECT EXISTS ( SELECT 1 FROM information_schema.tables WHERE table_schema = 'public' AND table_name = $1 ) """, table_name, ) if not exists: missing.append(table_name) if missing: raise RuntimeError(f"target database is missing required tables: {', '.join(missing)}") async def upsert_user(conn: asyncpg.Connection, user: LegacyUser, existing_sub_row: dict[str, Any] | None) -> int: area = normalize_area(user.area) if existing_sub_row: await conn.execute( """ UPDATE sso_users SET username = $2, nick_name = $3, phone_number = $4, email = $5, ou_id = $6, ou_name = $7, status = $8, is_leader = $9, updated_at = $10, deleted_at = $11, password = $12, try_count = $13, try_login_time = $14, area = COALESCE($15, area), mq_person_uuid = $16, mq_account_uuid = $17, mq_synced_at = $18, tenant_name = $19, dep_short_name = $20, dep_name = $21 WHERE sub = $1 """, user.sub, user.username, user.nick_name, user.phone_number, user.email, user.ou_id, user.ou_name, user.status, user.is_leader, user.updated_at, user.deleted_at, user.password, user.try_count, user.try_login_time, area, user.mq_person_uuid, user.mq_account_uuid, user.mq_synced_at, user.tenant_name, user.dep_short_name, user.dep_name, ) return int(existing_sub_row["id"]) await conn.execute( """ INSERT INTO sso_users ( id, sub, username, nick_name, phone_number, email, ou_id, ou_name, status, is_leader, created_at, updated_at, deleted_at, password, try_count, try_login_time, area, mq_person_uuid, mq_account_uuid, mq_synced_at, tenant_name, dep_short_name, dep_name ) VALUES ( $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22, $23 ) """, user.id, user.sub, user.username, user.nick_name, user.phone_number, user.email, user.ou_id, user.ou_name, user.status, user.is_leader, user.created_at, user.updated_at, user.deleted_at, user.password, user.try_count, user.try_login_time, area, user.mq_person_uuid, user.mq_account_uuid, user.mq_synced_at, user.tenant_name, user.dep_short_name, user.dep_name, ) return user.id async def assign_roles( conn: asyncpg.Connection, target_user_id: int, role_keys: list[str], role_key_to_id: dict[str, int], ) -> None: for role_key in role_keys: role_id = role_key_to_id.get(role_key) if role_id is None: raise RuntimeError(f"target role not found: {role_key}") await conn.execute( """ INSERT INTO user_role (user_id, role_id, created_at, updated_at) VALUES ($1, $2, NOW(), NOW()) ON CONFLICT (user_id, role_id) DO NOTHING """, target_user_id, role_id, ) async def sync_user_sequence(conn: asyncpg.Connection) -> None: await conn.execute( """ SELECT setval( pg_get_serial_sequence('sso_users', 'id'), COALESCE((SELECT MAX(id) FROM sso_users), 1), true ) """ ) async def main_async(args: argparse.Namespace) -> int: legacy_conn = await asyncpg.connect(build_legacy_dsn(args)) target_conn = await asyncpg.connect(load_target_dsn()) try: await ensure_target_ready(target_conn) legacy_users = await fetch_legacy_users(legacy_conn) legacy_user_roles = await fetch_legacy_user_roles(legacy_conn) target_role_map = await fetch_target_roles(target_conn) target_by_sub, target_by_id = await fetch_target_users(target_conn) missing_target_roles = sorted(set(ALLOWED_ROLE_KEYS) - set(target_role_map)) if missing_target_roles: raise RuntimeError( "target database is missing seeded roles: " + ", ".join(missing_target_roles) ) summary = Counter() role_summary = Counter() id_conflicts: list[tuple[int, str, int, str]] = [] async def process_all() -> None: for user_id, user in legacy_users.items(): source_roles = legacy_user_roles.get(user_id, []) desired_roles = choose_roles(source_roles) role_summary.update(desired_roles) if not source_roles: summary["default_common_role"] += 1 existing_sub_row = target_by_sub.get(user.sub) existing_id_row = target_by_id.get(user.id) if existing_id_row and (existing_sub_row is None or existing_sub_row["sub"] != user.sub): id_conflicts.append((user.id, user.sub, int(existing_id_row["id"]), existing_id_row["sub"])) summary["id_conflict"] += 1 continue if existing_sub_row: summary["update_user"] += 1 else: summary["insert_user"] += 1 if args.apply: target_user_id = await upsert_user(target_conn, user, existing_sub_row) await assign_roles(target_conn, target_user_id, desired_roles, target_role_map) target_by_sub[user.sub] = {"id": target_user_id, "sub": user.sub, "area": normalize_area(user.area)} target_by_id[user.id] = {"id": target_user_id, "sub": user.sub, "area": normalize_area(user.area)} if args.apply: async with target_conn.transaction(): await process_all() await sync_user_sequence(target_conn) else: await process_all() print("=== Migration Summary ===") print(f"mode: {'APPLY' if args.apply else 'DRY_RUN'}") print(f"legacy_users_total: {len(legacy_users)}") for key in sorted(summary): print(f"{key}: {summary[key]}") print("role_assignment_plan:") for role_key, count in sorted(role_summary.items()): print(f" {role_key}: {count}") print(f"id_conflicts: {len(id_conflicts)}") if id_conflicts: print("sample_id_conflicts:") for conflict in id_conflicts[:20]: print(f" legacy_id={conflict[0]} legacy_sub={conflict[1]} target_id={conflict[2]} target_sub={conflict[3]}") if not args.apply: print("dry-run complete; rerun with --apply to write data") return 0 finally: await legacy_conn.close() await target_conn.close() def parse_args() -> argparse.Namespace: parser = argparse.ArgumentParser(description="Migrate legacy users into leaudit_platform") parser.add_argument("--legacy-host", default=os.getenv("LEGACY_DB_HOST", "172.16.0.81")) parser.add_argument("--legacy-port", type=int, default=int(os.getenv("LEGACY_DB_PORT", "54302"))) parser.add_argument("--legacy-db", default=os.getenv("LEGACY_DB_NAME", "docauditai")) parser.add_argument("--legacy-user", default=os.getenv("LEGACY_DB_USER", "docauditai_admin")) parser.add_argument("--legacy-password", default=os.getenv("LEGACY_DB_PASSWORD", "zhfw*123*")) parser.add_argument("--apply", action="store_true", help="write data into target database") return parser.parse_args() def main() -> int: args = parse_args() return asyncio.run(main_async(args)) if __name__ == "__main__": raise SystemExit(main())