Files
leaudit-platform-backend/scripts/migrate_legacy_users.py
T

406 lines
12 KiB
Python

#!/usr/bin/env python3
"""Migrate legacy users from docauditai into leaudit_platform.
Default mode is dry-run. Use --apply to write data.
What gets migrated:
- sso_users
- user_role
What gets reused:
- target roles already seeded in leaudit_platform
Rules:
- preserve legacy user id when inserting into the new database
- map missing legacy roles to `common`
- normalize area with trim and alias mapping
- never overwrite target area with empty value
"""
from __future__ import annotations
import argparse
import asyncio
import os
from collections import Counter, defaultdict
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import asyncpg
# Directory two levels above this file; app.toml is expected to live there.
ROOT = Path(__file__).resolve().parent.parent
# Configuration file whose "DB" table supplies the target database settings.
APP_TOML = ROOT.joinpath("app.toml")

# Legacy role keys that are carried over unchanged.
ALLOWED_ROLE_KEYS = {
    "provincial_admin",
    "admin",
    "common",
    "super_admin",
}
# Role assigned when a legacy user has none of the allowed roles.
DEFAULT_ROLE_KEY = "common"

# Alternative spellings of an area mapped to the canonical name.
# NOTE(review): the "" key looks unreachable through normalize_area(), which
# returns None for blank input before consulting this table - confirm intent.
AREA_ALIASES = {
    "": "省局",
    "省厅": "省局",
    "省公司": "省局",
}
@dataclass
class LegacyUser:
    """One row of the legacy ``sso_users`` table.

    The field names match the columns selected by ``fetch_legacy_users`` so a
    fetched row can be expanded straight into the constructor.
    """

    # Identity
    id: int
    sub: str
    username: str
    nick_name: str
    phone_number: str | None
    email: str | None

    # Organisation unit
    ou_id: str
    ou_name: str

    # Account state and timestamps (timestamp values are passed through as-is)
    status: int
    is_leader: bool
    created_at: Any
    updated_at: Any
    deleted_at: Any

    # Login bookkeeping
    password: str | None
    try_count: int | None
    try_login_time: Any

    # Raw area string; cleaned with normalize_area() before it is written
    area: str | None

    # mq_* columns, copied verbatim
    mq_person_uuid: str | None
    mq_account_uuid: str | None
    mq_synced_at: Any

    # Tenant / department labels, copied verbatim
    tenant_name: str | None
    dep_short_name: str | None
    dep_name: str | None
def load_target_dsn() -> str:
    """Build the target PostgreSQL DSN from the ``DB`` table of ``app.toml``.

    The user name, password and database name are percent-encoded so that
    values containing URL-reserved characters (``@``, ``:``, ``/`` ...) cannot
    corrupt the DSN; asyncpg decodes them again when it parses the URL.

    Returns:
        A ``postgresql://user:password@host:port/name`` connection string.

    Raises:
        OSError: if ``APP_TOML`` cannot be opened.
        KeyError: if the ``DB`` table or one of its keys is missing.
    """
    from urllib.parse import quote

    try:
        import tomllib  # standard library from Python 3.11
    except ImportError:  # pragma: no cover
        import tomli as tomllib

    with APP_TOML.open("rb") as fh:
        db = tomllib.load(fh)["DB"]

    user = quote(str(db["USER"]), safe="")
    password = quote(str(db["PASSWORD"]), safe="")
    name = quote(str(db["NAME"]), safe="")
    return f"postgresql://{user}:{password}@{db['HOST']}:{db['PORT']}/{name}"
def build_legacy_dsn(args: argparse.Namespace) -> str:
    """Build the legacy PostgreSQL DSN from the parsed command-line arguments.

    The user name, password and database name are percent-encoded so that
    values containing URL-reserved characters (``@``, ``:``, ``/`` ...) cannot
    corrupt the DSN; asyncpg decodes them again when it parses the URL.

    Args:
        args: namespace providing ``legacy_user``, ``legacy_password``,
            ``legacy_host``, ``legacy_port`` and ``legacy_db``.

    Returns:
        A ``postgresql://user:password@host:port/db`` connection string.
    """
    from urllib.parse import quote

    user = quote(str(args.legacy_user), safe="")
    password = quote(str(args.legacy_password), safe="")
    database = quote(str(args.legacy_db), safe="")
    return (
        f"postgresql://{user}:{password}"
        f"@{args.legacy_host}:{args.legacy_port}/{database}"
    )
def normalize_area(value: str | None) -> str | None:
    """Return the canonical area name for *value*, or ``None`` when blank.

    Surrounding whitespace is removed first. ``None`` and whitespace-only
    input yield ``None``; any other value is looked up in ``AREA_ALIASES``
    and returned unchanged when it has no alias.
    """
    stripped = "" if value is None else value.strip()
    if stripped == "":
        return None
    return AREA_ALIASES.get(stripped, stripped)
def choose_roles(role_keys: list[str]) -> list[str]:
    """Reduce legacy role keys to the sorted, de-duplicated allowed subset.

    Keys outside ``ALLOWED_ROLE_KEYS`` are dropped. When nothing remains the
    result is a single-element list holding ``DEFAULT_ROLE_KEY``.
    """
    accepted = {key for key in role_keys if key in ALLOWED_ROLE_KEYS}
    if accepted:
        return sorted(accepted)
    return [DEFAULT_ROLE_KEY]
async def fetch_legacy_users(conn: asyncpg.Connection) -> dict[int, LegacyUser]:
    """Load every row of the legacy ``sso_users`` table, ordered by id.

    Returns:
        A mapping of legacy user id to the ``LegacyUser`` built from that row.
    """
    query = """
SELECT id, sub, username, nick_name, phone_number, email,
ou_id, ou_name, status, is_leader, created_at, updated_at,
deleted_at, password, try_count, try_login_time, area,
mq_person_uuid, mq_account_uuid, mq_synced_at,
tenant_name, dep_short_name, dep_name
FROM sso_users
ORDER BY id
"""
    users: dict[int, LegacyUser] = {}
    for record in await conn.fetch(query):
        # Column names match the dataclass fields, so the row expands directly.
        users[record["id"]] = LegacyUser(**dict(record))
    return users
async def fetch_legacy_user_roles(conn: asyncpg.Connection) -> dict[int, list[str]]:
    """Collect the role keys attached to each legacy user.

    Joins ``user_role`` with ``roles`` and groups the role keys by user id,
    keeping the query order (user id, then role id).

    Returns:
        A ``defaultdict(list)`` mapping user id to its role keys.
    """
    query = """
SELECT ur.user_id, r.role_key
FROM user_role ur
JOIN roles r ON r.id = ur.role_id
ORDER BY ur.user_id, r.id
"""
    grouped: dict[int, list[str]] = defaultdict(list)
    records = await conn.fetch(query)
    for record in records:
        owner_id = record["user_id"]
        grouped[owner_id].append(record["role_key"])
    return grouped
async def fetch_target_roles(conn: asyncpg.Connection) -> dict[str, int]:
    """Return a ``role_key -> id`` mapping for every row of the target ``roles`` table."""
    role_ids: dict[str, int] = {}
    for record in await conn.fetch("SELECT id, role_key FROM roles ORDER BY id"):
        role_ids[record["role_key"]] = record["id"]
    return role_ids
async def fetch_target_users(conn: asyncpg.Connection) -> tuple[dict[str, dict[str, Any]], dict[int, dict[str, Any]]]:
    """Index the users already present in the target ``sso_users`` table.

    Returns:
        ``(by_sub, by_id)``: two dictionaries over the same row dicts (keys
        ``id``, ``sub``, ``area``), keyed by ``sub`` and by ``id`` respectively.
    """
    by_sub: dict[str, dict[str, Any]] = {}
    by_id: dict[int, dict[str, Any]] = {}
    for record in await conn.fetch("SELECT id, sub, area FROM sso_users"):
        entry = dict(record)
        # Both indexes reference the same dict object for a given row.
        by_sub[entry["sub"]] = entry
        by_id[entry["id"]] = entry
    return by_sub, by_id
async def ensure_target_ready(conn: asyncpg.Connection) -> None:
    """Verify that the target database contains every table the script expects.

    Each table is looked up in ``information_schema.tables`` under the
    ``public`` schema, one query per table.

    Raises:
        RuntimeError: listing the tables that were not found.
    """
    required_tables = (
        "sso_users",
        "roles",
        "user_role",
        "permissions",
        "role_permissions",
        "sys_routes",
        "role_route",
    )
    exists_sql = """
SELECT EXISTS (
SELECT 1
FROM information_schema.tables
WHERE table_schema = 'public' AND table_name = $1
)
"""
    # The awaits run one after another, in the order of required_tables.
    missing = [
        table_name
        for table_name in required_tables
        if not await conn.fetchval(exists_sql, table_name)
    ]
    if not missing:
        return
    raise RuntimeError(f"target database is missing required tables: {', '.join(missing)}")
async def upsert_user(conn: asyncpg.Connection, user: LegacyUser, existing_sub_row: dict[str, Any] | None) -> int:
    """Write one legacy user into the target ``sso_users`` table.

    When a target row with the same ``sub`` exists it is updated in place:
    its ``id`` and ``created_at`` are left alone, and ``area`` is only
    replaced when the normalized legacy area is not ``None``. Otherwise a new
    row is inserted that keeps the legacy ``id``.

    Args:
        conn: connection to the target database.
        user: the legacy user to write.
        existing_sub_row: the target row already holding ``user.sub``, if any.

    Returns:
        The id of the target row that was updated or inserted.
    """
    area = normalize_area(user.area)

    # Columns $-bound in the same relative order by both statements.
    identity_args = (
        user.sub,
        user.username,
        user.nick_name,
        user.phone_number,
        user.email,
        user.ou_id,
        user.ou_name,
        user.status,
        user.is_leader,
    )
    trailing_args = (
        user.updated_at,
        user.deleted_at,
        user.password,
        user.try_count,
        user.try_login_time,
        area,
        user.mq_person_uuid,
        user.mq_account_uuid,
        user.mq_synced_at,
        user.tenant_name,
        user.dep_short_name,
        user.dep_name,
    )

    if not existing_sub_row:
        insert_sql = """
INSERT INTO sso_users (
id, sub, username, nick_name, phone_number, email, ou_id, ou_name,
status, is_leader, created_at, updated_at, deleted_at, password,
try_count, try_login_time, area, mq_person_uuid, mq_account_uuid,
mq_synced_at, tenant_name, dep_short_name, dep_name
) VALUES (
$1, $2, $3, $4, $5, $6, $7, $8,
$9, $10, $11, $12, $13, $14,
$15, $16, $17, $18, $19,
$20, $21, $22, $23
)
"""
        # INSERT additionally binds id first and created_at after is_leader.
        await conn.execute(
            insert_sql,
            user.id,
            *identity_args,
            user.created_at,
            *trailing_args,
        )
        return user.id

    update_sql = """
UPDATE sso_users
SET username = $2,
nick_name = $3,
phone_number = $4,
email = $5,
ou_id = $6,
ou_name = $7,
status = $8,
is_leader = $9,
updated_at = $10,
deleted_at = $11,
password = $12,
try_count = $13,
try_login_time = $14,
area = COALESCE($15, area),
mq_person_uuid = $16,
mq_account_uuid = $17,
mq_synced_at = $18,
tenant_name = $19,
dep_short_name = $20,
dep_name = $21
WHERE sub = $1
"""
    await conn.execute(update_sql, *identity_args, *trailing_args)
    return int(existing_sub_row["id"])
async def assign_roles(
    conn: asyncpg.Connection,
    target_user_id: int,
    role_keys: list[str],
    role_key_to_id: dict[str, int],
) -> None:
    """Attach each role in *role_keys* to the target user.

    Roles are inserted one at a time; a ``(user_id, role_id)`` pair that
    already exists is left untouched.

    Raises:
        RuntimeError: when a role key has no id in *role_key_to_id*. Roles
            listed before the missing one have already been inserted.
    """
    insert_sql = """
INSERT INTO user_role (user_id, role_id, created_at, updated_at)
VALUES ($1, $2, NOW(), NOW())
ON CONFLICT (user_id, role_id) DO NOTHING
"""
    for key in role_keys:
        resolved_id = role_key_to_id.get(key)
        if resolved_id is not None:
            await conn.execute(insert_sql, target_user_id, resolved_id)
            continue
        raise RuntimeError(f"target role not found: {key}")
async def sync_user_sequence(conn: asyncpg.Connection) -> None:
    """Move the ``sso_users.id`` sequence past the highest id in the table.

    Rows are inserted with explicit ids, which does not advance the sequence;
    without this step a later default-generated id could collide with a
    migrated one.

    For an empty table the sequence is set to 1 with ``is_called = false`` so
    the next generated id is 1 rather than 2.
    """
    await conn.execute(
        """
        SELECT setval(
            pg_get_serial_sequence('sso_users', 'id'),
            COALESCE(MAX(id), 1),
            MAX(id) IS NOT NULL
        )
        FROM sso_users
        """
    )
def _print_summary(
    args: argparse.Namespace,
    legacy_total: int,
    summary: Counter,
    role_summary: Counter,
    id_conflicts: list[tuple[int, str, int, str]],
) -> None:
    """Print the migration report; the layout is the same in both modes."""
    print("=== Migration Summary ===")
    print(f"mode: {'APPLY' if args.apply else 'DRY_RUN'}")
    print(f"legacy_users_total: {legacy_total}")
    for key in sorted(summary):
        print(f"{key}: {summary[key]}")
    print("role_assignment_plan:")
    for role_key, count in sorted(role_summary.items()):
        print(f" {role_key}: {count}")
    print(f"id_conflicts: {len(id_conflicts)}")
    if id_conflicts:
        print("sample_id_conflicts:")
        # Only the first 20 conflicts are listed to keep the report short.
        for conflict in id_conflicts[:20]:
            print(f" legacy_id={conflict[0]} legacy_sub={conflict[1]} target_id={conflict[2]} target_sub={conflict[3]}")
    if not args.apply:
        print("dry-run complete; rerun with --apply to write data")


async def main_async(args: argparse.Namespace) -> int:
    """Run the migration (or a dry-run of it) and print a summary.

    Users are matched to target rows by ``sub``. A matched user is updated, an
    unmatched one is inserted with its legacy id. An unmatched user whose
    legacy id is already taken in the target is skipped and reported as an id
    conflict. In apply mode all writes and the sequence sync share one
    transaction on the target connection.

    Args:
        args: parsed command-line arguments; ``args.apply`` selects write mode.

    Returns:
        ``0`` on completion.

    Raises:
        RuntimeError: if the target is missing required tables or seeded roles.
    """
    legacy_conn = await asyncpg.connect(build_legacy_dsn(args))
    try:
        # Connect inside the try so the legacy connection is closed even when
        # reading app.toml or connecting to the target fails.
        target_conn = await asyncpg.connect(load_target_dsn())
        try:
            await ensure_target_ready(target_conn)
            legacy_users = await fetch_legacy_users(legacy_conn)
            legacy_user_roles = await fetch_legacy_user_roles(legacy_conn)
            target_role_map = await fetch_target_roles(target_conn)
            target_by_sub, target_by_id = await fetch_target_users(target_conn)

            missing_target_roles = sorted(set(ALLOWED_ROLE_KEYS) - set(target_role_map))
            if missing_target_roles:
                raise RuntimeError(
                    "target database is missing seeded roles: " + ", ".join(missing_target_roles)
                )

            summary: Counter = Counter()
            role_summary: Counter = Counter()
            id_conflicts: list[tuple[int, str, int, str]] = []

            async def process_all() -> None:
                for user_id, user in legacy_users.items():
                    existing_sub_row = target_by_sub.get(user.sub)
                    existing_id_row = target_by_id.get(user.id)

                    # An insert would reuse the legacy id; skip the user when
                    # that id already belongs to a different target row.
                    if existing_sub_row is None and existing_id_row is not None:
                        id_conflicts.append(
                            (user.id, user.sub, int(existing_id_row["id"]), existing_id_row["sub"])
                        )
                        summary["id_conflict"] += 1
                        continue

                    # Counted after the conflict check so the plan only covers
                    # users that are actually migrated.
                    source_roles = legacy_user_roles.get(user_id, [])
                    desired_roles = choose_roles(source_roles)
                    role_summary.update(desired_roles)
                    if not source_roles:
                        summary["default_common_role"] += 1

                    if existing_sub_row is not None:
                        summary["update_user"] += 1
                    else:
                        summary["insert_user"] += 1

                    if args.apply:
                        target_user_id = await upsert_user(target_conn, user, existing_sub_row)
                        await assign_roles(target_conn, target_user_id, desired_roles, target_role_map)
                    elif existing_sub_row is not None:
                        target_user_id = int(existing_sub_row["id"])
                    else:
                        target_user_id = user.id

                    # Track the row in both modes so a dry-run plans exactly
                    # what --apply would do; index it under the id it has in
                    # the target, which differs from the legacy id on update.
                    tracked = {
                        "id": target_user_id,
                        "sub": user.sub,
                        "area": normalize_area(user.area),
                    }
                    target_by_sub[user.sub] = tracked
                    target_by_id[target_user_id] = tracked

            if args.apply:
                async with target_conn.transaction():
                    await process_all()
                    await sync_user_sequence(target_conn)
            else:
                await process_all()

            _print_summary(args, len(legacy_users), summary, role_summary, id_conflicts)
            return 0
        finally:
            await target_conn.close()
    finally:
        await legacy_conn.close()
def parse_args() -> argparse.Namespace:
    """Parse the command-line options for the migration script.

    Each legacy connection option falls back to a ``LEGACY_DB_*`` environment
    variable and then to a built-in default. ``--apply`` switches from dry-run
    to write mode.

    Returns:
        Namespace with ``legacy_host``, ``legacy_port`` (int), ``legacy_db``,
        ``legacy_user``, ``legacy_password`` and ``apply``.
    """
    parser = argparse.ArgumentParser(description="Migrate legacy users into leaudit_platform")
    parser.add_argument("--legacy-host", default=os.getenv("LEGACY_DB_HOST", "172.16.0.81"))
    # The default stays a string on purpose: argparse applies ``type=int`` to
    # string defaults, but only when the flag is absent. A malformed
    # LEGACY_DB_PORT therefore gives a normal usage error instead of a
    # traceback, and is ignored when --legacy-port is passed explicitly.
    parser.add_argument("--legacy-port", type=int, default=os.getenv("LEGACY_DB_PORT", "54302"))
    parser.add_argument("--legacy-db", default=os.getenv("LEGACY_DB_NAME", "docauditai"))
    parser.add_argument("--legacy-user", default=os.getenv("LEGACY_DB_USER", "docauditai_admin"))
    # NOTE(security): hard-coded fallback credential, kept only so existing
    # invocations keep working. Prefer LEGACY_DB_PASSWORD / --legacy-password
    # and remove this default once callers supply one.
    parser.add_argument("--legacy-password", default=os.getenv("LEGACY_DB_PASSWORD", "zhfw*123*"))
    parser.add_argument("--apply", action="store_true", help="write data into target database")
    return parser.parse_args()
def main() -> int:
    """Parse the command line, run the async migration and return its exit code."""
    return asyncio.run(main_async(parse_args()))


if __name__ == "__main__":
    raise SystemExit(main())