406 lines
12 KiB
Python
406 lines
12 KiB
Python
#!/usr/bin/env python3
|
|
"""Migrate legacy users from docauditai into leaudit_platform.
|
|
|
|
Default mode is dry-run. Use --apply to write data.
|
|
|
|
What gets migrated:
|
|
- sso_users
|
|
- user_role
|
|
|
|
What gets reused:
|
|
- target roles already seeded in leaudit_platform
|
|
|
|
Rules:
|
|
- preserve legacy user id when inserting into the new database
|
|
- map missing legacy roles to `common`
|
|
- normalize area with trim and alias mapping
|
|
- never overwrite target area with empty value
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import asyncio
|
|
import os
|
|
from collections import Counter, defaultdict
|
|
from dataclasses import dataclass
|
|
from pathlib import Path
|
|
from typing import Any
|
|
|
|
import asyncpg
|
|
|
|
|
|
# Directory one level above the folder that contains this script.
ROOT = Path(__file__).resolve().parent.parent
# TOML file whose ``DB`` table describes the target database connection.
APP_TOML = ROOT / "app.toml"

# Legacy role keys that are carried over unchanged; any other key is dropped.
ALLOWED_ROLE_KEYS = {"super_admin", "provincial_admin", "admin", "common"}
# Role used when a user is left with none of the allowed roles.
DEFAULT_ROLE_KEY = "common"
# Area spellings that are folded into the single canonical value "省局".
AREA_ALIASES = dict.fromkeys(("省", "省厅", "省公司"), "省局")
|
|
|
|
|
|
@dataclass
|
|
class LegacyUser:
|
|
id: int
|
|
sub: str
|
|
username: str
|
|
nick_name: str
|
|
phone_number: str | None
|
|
email: str | None
|
|
ou_id: str
|
|
ou_name: str
|
|
status: int
|
|
is_leader: bool
|
|
created_at: Any
|
|
updated_at: Any
|
|
deleted_at: Any
|
|
password: str | None
|
|
try_count: int | None
|
|
try_login_time: Any
|
|
area: str | None
|
|
mq_person_uuid: str | None
|
|
mq_account_uuid: str | None
|
|
mq_synced_at: Any
|
|
tenant_name: str | None
|
|
dep_short_name: str | None
|
|
dep_name: str | None
|
|
|
|
|
|
def load_target_dsn() -> str:
    """Build the target-database DSN from the ``DB`` table of ``app.toml``.

    The keys ``USER``, ``PASSWORD``, ``HOST``, ``PORT`` and ``NAME`` are read
    from ``config["DB"]``. User and password are percent-encoded so that
    characters such as ``@``, ``:``, ``/`` or ``%`` cannot be mistaken for URL
    delimiters when the DSN is parsed.

    Returns:
        A ``postgresql://user:password@host:port/name`` connection string.

    Raises:
        OSError: If ``APP_TOML`` cannot be opened.
        KeyError: If the ``DB`` table or one of its keys is missing.
    """
    from urllib.parse import quote

    # ``tomllib`` is only in the standard library on newer Pythons; fall back
    # to the API-compatible ``tomli`` package elsewhere.
    try:
        import tomllib
    except ImportError:  # pragma: no cover
        import tomli as tomllib

    with APP_TOML.open("rb") as fh:
        config = tomllib.load(fh)
    db = config["DB"]

    # safe="" so that "/" is escaped too; it is a delimiter inside a URL.
    user = quote(str(db["USER"]), safe="")
    password = quote(str(db["PASSWORD"]), safe="")
    return f"postgresql://{user}:{password}@{db['HOST']}:{db['PORT']}/{db['NAME']}"
|
|
|
|
|
|
def build_legacy_dsn(args: argparse.Namespace) -> str:
    """Build the legacy-database DSN from the parsed command-line options.

    Args:
        args: Namespace providing ``legacy_user``, ``legacy_password``,
            ``legacy_host``, ``legacy_port`` and ``legacy_db``.

    Returns:
        A ``postgresql://user:password@host:port/db`` connection string with
        the user and password percent-encoded, so credentials containing
        ``@``, ``:``, ``/``, ``#`` or ``%`` survive URL parsing intact.
    """
    from urllib.parse import quote

    # safe="" so that "/" is escaped too; it is a delimiter inside a URL.
    user = quote(str(args.legacy_user), safe="")
    password = quote(str(args.legacy_password), safe="")
    return (
        f"postgresql://{user}:{password}"
        f"@{args.legacy_host}:{args.legacy_port}/{args.legacy_db}"
    )
|
|
|
|
|
|
def normalize_area(value: str | None) -> str | None:
    """Return the canonical form of an area string.

    Surrounding whitespace is stripped and the result is looked up in
    ``AREA_ALIASES``; values without an alias are returned as stripped.

    Returns:
        The canonical area, or ``None`` when *value* is ``None`` or contains
        only whitespace.
    """
    stripped = "" if value is None else value.strip()
    if stripped:
        return AREA_ALIASES.get(stripped, stripped)
    return None
|
|
|
|
|
|
def choose_roles(role_keys: list[str]) -> list[str]:
    """Reduce a user's legacy role keys to the roles to assign in the target.

    Keys outside ``ALLOWED_ROLE_KEYS`` are discarded and duplicates collapse.

    Returns:
        The surviving keys in sorted order, or ``[DEFAULT_ROLE_KEY]`` when
        nothing survives.
    """
    kept = {key for key in role_keys if key in ALLOWED_ROLE_KEYS}
    if kept:
        return sorted(kept)
    return [DEFAULT_ROLE_KEY]
|
|
|
|
|
|
async def fetch_legacy_users(conn: asyncpg.Connection) -> dict[int, LegacyUser]:
    """Load every row of the legacy ``sso_users`` table.

    Args:
        conn: Open connection to the legacy database.

    Returns:
        ``LegacyUser`` objects keyed by their ``id``, inserted in ascending
        ``id`` order (the query sorts by ``id``).
    """
    query = """
        SELECT id, sub, username, nick_name, phone_number, email,
               ou_id, ou_name, status, is_leader, created_at, updated_at,
               deleted_at, password, try_count, try_login_time, area,
               mq_person_uuid, mq_account_uuid, mq_synced_at,
               tenant_name, dep_short_name, dep_name
        FROM sso_users
        ORDER BY id
    """
    users: dict[int, LegacyUser] = {}
    for record in await conn.fetch(query):
        # Selected column names equal the dataclass field names one-to-one.
        users[record["id"]] = LegacyUser(**dict(record))
    return users
|
|
|
|
|
|
async def fetch_legacy_user_roles(conn: asyncpg.Connection) -> dict[int, list[str]]:
    """Collect the legacy role keys assigned to each user.

    Args:
        conn: Open connection to the legacy database.

    Returns:
        A ``defaultdict(list)`` mapping user id to that user's role keys, in
        the order produced by the query (by user id, then role id). Users
        without any ``user_role`` row have no entry.
    """
    query = """
        SELECT ur.user_id, r.role_key
        FROM user_role ur
        JOIN roles r ON r.id = ur.role_id
        ORDER BY ur.user_id, r.id
    """
    roles_by_user: dict[int, list[str]] = defaultdict(list)
    for record in await conn.fetch(query):
        roles_by_user[record["user_id"]].append(record["role_key"])
    return roles_by_user
|
|
|
|
|
|
async def fetch_target_roles(conn: asyncpg.Connection) -> dict[str, int]:
    """Return the target database's roles as a ``role_key -> id`` mapping.

    Rows are read in ascending ``id`` order, so if a ``role_key`` occurred more
    than once the highest id would be the one kept.
    """
    role_ids: dict[str, int] = {}
    for record in await conn.fetch("SELECT id, role_key FROM roles ORDER BY id"):
        role_ids[record["role_key"]] = record["id"]
    return role_ids
|
|
|
|
|
|
async def fetch_target_users(conn: asyncpg.Connection) -> tuple[dict[str, dict[str, Any]], dict[int, dict[str, Any]]]:
    """Index the users already present in the target ``sso_users`` table.

    Args:
        conn: Open connection to the target database.

    Returns:
        ``(by_sub, by_id)``: two lookups over the same row dicts (each holding
        ``id``, ``sub`` and ``area``), keyed by ``sub`` and by ``id``.
    """
    by_sub: dict[str, dict[str, Any]] = {}
    by_id: dict[int, dict[str, Any]] = {}
    for record in await conn.fetch("SELECT id, sub, area FROM sso_users"):
        entry = dict(record)
        # Both indexes reference the same dict object for a given row.
        by_sub[entry["sub"]] = entry
        by_id[entry["id"]] = entry
    return by_sub, by_id
|
|
|
|
|
|
async def ensure_target_ready(conn: asyncpg.Connection) -> None:
    """Verify that the target database has the tables this script relies on.

    Each required table is looked up in ``information_schema.tables`` under
    the ``public`` schema.

    Raises:
        RuntimeError: Listing every required table that was not found.
    """
    required_tables = (
        "sso_users",
        "roles",
        "user_role",
        "permissions",
        "role_permissions",
        "sys_routes",
        "role_route",
    )
    probe = """
        SELECT EXISTS (
            SELECT 1
            FROM information_schema.tables
            WHERE table_schema = 'public' AND table_name = $1
        )
    """
    # One probe per table, issued sequentially in the order listed above.
    absent = [name for name in required_tables if not await conn.fetchval(probe, name)]
    if absent:
        raise RuntimeError(f"target database is missing required tables: {', '.join(absent)}")
|
|
|
|
|
|
async def upsert_user(conn: asyncpg.Connection, user: LegacyUser, existing_sub_row: dict[str, Any] | None) -> int:
    """Write one legacy user into the target ``sso_users`` table.

    Args:
        conn: Open connection to the target database.
        user: The legacy user to write.
        existing_sub_row: The target row that already has ``user.sub``, or
            ``None`` when there is no such row.

    Returns:
        The id of the target row: the existing row's id after an update, or
        ``user.id`` after an insert.

    When a row exists it is updated by ``sub``; its ``id`` and ``created_at``
    are not touched, and ``area`` only changes when the normalized legacy area
    is not NULL. Otherwise a new row is inserted under the legacy id.
    """
    area = normalize_area(user.area)

    # Parameters shared by both statements, in positional order:
    # UPDATE $1..$9 / INSERT $2..$10.
    leading = (
        user.sub,
        user.username,
        user.nick_name,
        user.phone_number,
        user.email,
        user.ou_id,
        user.ou_name,
        user.status,
        user.is_leader,
    )
    # UPDATE $10..$21 / INSERT $12..$23 (INSERT puts created_at in between).
    trailing = (
        user.updated_at,
        user.deleted_at,
        user.password,
        user.try_count,
        user.try_login_time,
        area,
        user.mq_person_uuid,
        user.mq_account_uuid,
        user.mq_synced_at,
        user.tenant_name,
        user.dep_short_name,
        user.dep_name,
    )

    if not existing_sub_row:
        insert_sql = """
            INSERT INTO sso_users (
                id, sub, username, nick_name, phone_number, email, ou_id, ou_name,
                status, is_leader, created_at, updated_at, deleted_at, password,
                try_count, try_login_time, area, mq_person_uuid, mq_account_uuid,
                mq_synced_at, tenant_name, dep_short_name, dep_name
            ) VALUES (
                $1, $2, $3, $4, $5, $6, $7, $8,
                $9, $10, $11, $12, $13, $14,
                $15, $16, $17, $18, $19,
                $20, $21, $22, $23
            )
        """
        await conn.execute(insert_sql, user.id, *leading, user.created_at, *trailing)
        return user.id

    update_sql = """
        UPDATE sso_users
        SET username = $2,
            nick_name = $3,
            phone_number = $4,
            email = $5,
            ou_id = $6,
            ou_name = $7,
            status = $8,
            is_leader = $9,
            updated_at = $10,
            deleted_at = $11,
            password = $12,
            try_count = $13,
            try_login_time = $14,
            area = COALESCE($15, area),
            mq_person_uuid = $16,
            mq_account_uuid = $17,
            mq_synced_at = $18,
            tenant_name = $19,
            dep_short_name = $20,
            dep_name = $21
        WHERE sub = $1
    """
    await conn.execute(update_sql, *leading, *trailing)
    return int(existing_sub_row["id"])
|
|
|
|
|
|
async def assign_roles(
    conn: asyncpg.Connection,
    target_user_id: int,
    role_keys: list[str],
    role_key_to_id: dict[str, int],
) -> None:
    """Grant each of ``role_keys`` to a target user.

    Args:
        conn: Open connection to the target database.
        target_user_id: Id of the user in the target ``sso_users`` table.
        role_keys: Role keys to grant, processed in order.
        role_key_to_id: Mapping from role key to target role id.

    Raises:
        RuntimeError: On the first role key that has no id in
            ``role_key_to_id``; roles earlier in the list have already been
            inserted by then.
    """
    # An existing (user_id, role_id) pair is left untouched.
    grant_sql = """
        INSERT INTO user_role (user_id, role_id, created_at, updated_at)
        VALUES ($1, $2, NOW(), NOW())
        ON CONFLICT (user_id, role_id) DO NOTHING
    """
    for key in role_keys:
        resolved_id = role_key_to_id.get(key)
        if resolved_id is None:
            raise RuntimeError(f"target role not found: {key}")
        await conn.execute(grant_sql, target_user_id, resolved_id)
|
|
|
|
|
|
async def sync_user_sequence(conn: asyncpg.Connection) -> None:
    """Move the ``sso_users.id`` sequence past the highest id in the table.

    Rows are inserted with explicit legacy ids, which does not advance the
    sequence; without this step later sequence-generated ids could collide
    with migrated rows.

    Args:
        conn: Open connection to the target database.
    """
    await conn.execute(
        """
        SELECT setval(
            pg_get_serial_sequence('sso_users', 'id'),
            COALESCE(MAX(id), 1),
            -- is_called: true  -> the next generated id is MAX(id) + 1
            --            false -> table is empty, so the next id is 1, not 2
            MAX(id) IS NOT NULL
        )
        FROM sso_users
        """
    )
|
|
|
|
|
|
async def main_async(args: argparse.Namespace) -> int:
    """Run the migration, or a dry-run of it, and print a summary.

    Args:
        args: Parsed command-line options. ``apply`` selects write mode; the
            ``legacy_*`` fields locate the legacy database.

    Returns:
        ``0`` once the summary has been printed.

    Raises:
        RuntimeError: If the target database lacks a required table or one of
            the roles in ``ALLOWED_ROLE_KEYS``.
    """
    legacy_conn = await asyncpg.connect(build_legacy_dsn(args))
    try:
        # Opened inside the outer ``try`` so that a failure to read app.toml
        # or to reach the target database still closes the legacy connection.
        target_conn = await asyncpg.connect(load_target_dsn())
        try:
            return await _migrate(args, legacy_conn, target_conn)
        finally:
            await target_conn.close()
    finally:
        await legacy_conn.close()


async def _migrate(
    args: argparse.Namespace,
    legacy_conn: asyncpg.Connection,
    target_conn: asyncpg.Connection,
) -> int:
    """Plan (and with ``args.apply`` perform) the migration over open connections."""
    await ensure_target_ready(target_conn)

    legacy_users = await fetch_legacy_users(legacy_conn)
    legacy_user_roles = await fetch_legacy_user_roles(legacy_conn)
    target_role_map = await fetch_target_roles(target_conn)
    target_by_sub, target_by_id = await fetch_target_users(target_conn)

    missing_target_roles = sorted(set(ALLOWED_ROLE_KEYS) - set(target_role_map))
    if missing_target_roles:
        raise RuntimeError(
            "target database is missing seeded roles: " + ", ".join(missing_target_roles)
        )

    summary: Counter[str] = Counter()
    role_summary: Counter[str] = Counter()
    id_conflicts: list[tuple[int, str, int, str]] = []

    async def process_all() -> None:
        for user_id, user in legacy_users.items():
            source_roles = legacy_user_roles.get(user_id, [])
            desired_roles = choose_roles(source_roles)
            role_summary.update(desired_roles)
            # Count every user who receives the fallback role: users without
            # any legacy role and users whose legacy roles were all dropped.
            if desired_roles == [DEFAULT_ROLE_KEY] and DEFAULT_ROLE_KEY not in source_roles:
                summary["default_common_role"] += 1

            existing_sub_row = target_by_sub.get(user.sub)
            existing_id_row = target_by_id.get(user.id)
            # An insert would reuse the legacy id. If that id already belongs
            # to a target row with another sub, skip the user and report it.
            # (When the sub exists the row is updated by sub, so the id is
            # never inserted and cannot collide.)
            if existing_sub_row is None and existing_id_row is not None:
                id_conflicts.append(
                    (user.id, user.sub, int(existing_id_row["id"]), existing_id_row["sub"])
                )
                summary["id_conflict"] += 1
                continue

            if existing_sub_row:
                summary["update_user"] += 1
            else:
                summary["insert_user"] += 1

            if args.apply:
                target_user_id = await upsert_user(target_conn, user, existing_sub_row)
                await assign_roles(target_conn, target_user_id, desired_roles, target_role_map)

                # Keep the lookups in step with what was written. The update
                # uses COALESCE, so an empty legacy area keeps the old value.
                area = normalize_area(user.area)
                if area is None and existing_sub_row:
                    area = existing_sub_row.get("area")
                cached_row = {"id": target_user_id, "sub": user.sub, "area": area}
                target_by_sub[user.sub] = cached_row
                # Keyed by the id the row has in the target, which differs
                # from the legacy id when an existing row was updated.
                target_by_id[target_user_id] = cached_row

    if args.apply:
        # One transaction: any failure rolls back every user, role and the
        # sequence adjustment together.
        async with target_conn.transaction():
            await process_all()
            await sync_user_sequence(target_conn)
    else:
        await process_all()

    _print_summary(args, len(legacy_users), summary, role_summary, id_conflicts)
    return 0


def _print_summary(
    args: argparse.Namespace,
    legacy_total: int,
    summary: Counter[str],
    role_summary: Counter[str],
    id_conflicts: list[tuple[int, str, int, str]],
) -> None:
    """Print the counters gathered during the run to standard output."""
    print("=== Migration Summary ===")
    print(f"mode: {'APPLY' if args.apply else 'DRY_RUN'}")
    print(f"legacy_users_total: {legacy_total}")
    for key in sorted(summary):
        print(f"{key}: {summary[key]}")
    print("role_assignment_plan:")
    for role_key, count in sorted(role_summary.items()):
        print(f" {role_key}: {count}")
    print(f"id_conflicts: {len(id_conflicts)}")
    if id_conflicts:
        print("sample_id_conflicts:")
        # Only the first 20 conflicts are listed.
        for legacy_id, legacy_sub, target_id, target_sub in id_conflicts[:20]:
            print(f" legacy_id={legacy_id} legacy_sub={legacy_sub} target_id={target_id} target_sub={target_sub}")

    if not args.apply:
        print("dry-run complete; rerun with --apply to write data")
|
|
|
|
|
|
def parse_args() -> argparse.Namespace:
|
|
parser = argparse.ArgumentParser(description="Migrate legacy users into leaudit_platform")
|
|
parser.add_argument("--legacy-host", default=os.getenv("LEGACY_DB_HOST", "172.16.0.81"))
|
|
parser.add_argument("--legacy-port", type=int, default=int(os.getenv("LEGACY_DB_PORT", "54302")))
|
|
parser.add_argument("--legacy-db", default=os.getenv("LEGACY_DB_NAME", "docauditai"))
|
|
parser.add_argument("--legacy-user", default=os.getenv("LEGACY_DB_USER", "docauditai_admin"))
|
|
parser.add_argument("--legacy-password", default=os.getenv("LEGACY_DB_PASSWORD", "zhfw*123*"))
|
|
parser.add_argument("--apply", action="store_true", help="write data into target database")
|
|
return parser.parse_args()
|
|
|
|
|
|
def main() -> int:
    """Parse the command line and run the migration to completion.

    Returns:
        The value returned by ``main_async``, for use as the exit status.
    """
    return asyncio.run(main_async(parse_args()))
|
|
|
|
|
|
if __name__ == "__main__":
    # Hand main()'s return value to the interpreter as the exit status.
    exit_status = main()
    raise SystemExit(exit_status)
|