# File: leaudit-platform-backend/scripts/import_oss_yaml_rules.py
# Repository-viewer listing: 592 lines, 22 KiB, Python

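"""Import local rule YAML files into OSS and the rule version tables.

In its default mode this script is intentionally conservative:
- uploads every local rules.yaml to the canonical OSS key;
- upserts matching rule version rows by tenant/rule_type/version_no;
- switches each rule set to the highest local version for that rule_type;
- does not delete old OSS objects or historical DB versions.

Database changes are rolled back unless --execute is passed. With
--reset-rule-domain the script first soft-deletes every live rule version
before re-importing; adding --prune-oss (reset mode only) also removes OSS
objects that are neither canonical nor referenced by an audit run.
"""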
from __future__ import annotations
import argparse
import asyncio
import hashlib
import re
from datetime import datetime
from dataclasses import dataclass
from pathlib import Path
from typing import Any
import yaml
from sqlalchemy import text
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_storage.oss_path_utils import OssPathUtils
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ruleValidator import RuleValidator
from fastapi_modules.fastapi_leaudit.services.impl.ossServiceImpl import OssServiceImpl
@dataclass(frozen=True)
class LocalRule:
    """Immutable description of a single local ``rules.yaml`` file.

    Field order is part of the constructor signature and must not change.
    """

    # Value of ``metadata.type_id`` in the YAML document.
    rule_type: str
    # Value of ``metadata.version``; equals the name of the file's directory.
    version_no: str
    # ``metadata.name``, or the rule type when no name is given.
    rule_name: str
    # ``metadata.description`` exactly as parsed; may be absent.
    description: str | None
    # Location of the YAML file on disk.
    path: Path
    # Full decoded file content.
    yaml_text: str
    # Hex SHA-256 digest of the UTF-8 encoded content.
    sha256: str
    # Length in bytes of the UTF-8 encoded content.
    size: int
def _version_sort_key(value: str) -> tuple[int, tuple[int | str, ...]]:
raw = str(value or "").strip()
normalized = raw[1:] if raw.lower().startswith("v") else raw
parts: list[int | str] = []
for part in re.split(r"([0-9]+)", normalized):
if not part:
continue
parts.append(int(part) if part.isdigit() else part)
return (0 if raw.lower().startswith("v") else 1, tuple(parts))
def _domain_type(rule_type: str) -> str:
if rule_type.startswith("contract."):
return "contract"
if rule_type.startswith("govdoc."):
return "govdoc"
if rule_type.startswith("行政卷宗."):
return "case_file"
return "custom"
def _legacy_region_for_tenant(tenant_code: str) -> str:
normalized = str(tenant_code or "").strip().upper()
return normalized or "PUBLIC"
def load_local_rules(root: Path) -> list[LocalRule]:
    """Read and validate every ``<root>/*/*/rules.yaml`` file.

    Files are visited in sorted path order. Each one must pass
    ``RuleValidator.ValidateYaml``, parse to a mapping with a ``metadata``
    mapping that has non-empty ``type_id`` and ``version``, and live in a
    directory whose name equals ``metadata.version``.

    Args:
        root: Directory that contains the ``<dir>/<version>/rules.yaml`` tree.

    Returns:
        One ``LocalRule`` per valid file, in sorted path order.

    Raises:
        RuntimeError: If any file fails a check. All problems are collected
            first so a single run reports every bad file.
    """
    validator = RuleValidator()
    items: list[LocalRule] = []
    failures: list[str] = []
    for path in sorted(root.glob("*/*/rules.yaml")):
        try:
            yaml_text = path.read_text(encoding="utf-8")
        except UnicodeDecodeError as exc:
            # Report undecodable files alongside the other validation failures
            # instead of aborting on the first one.
            failures.append(f"{path}: file is not valid UTF-8 ({exc})")
            continue

        result = validator.ValidateYaml(yaml_text)
        if not result.valid:
            failures.append(f"{path}: {'; '.join(result.errors or [])}")
            continue

        data = yaml.safe_load(yaml_text)
        if not isinstance(data, dict):
            # Guards the .get() calls below against an empty or non-mapping document.
            failures.append(f"{path}: YAML root must be a mapping")
            continue
        metadata = data.get("metadata") or {}
        if not isinstance(metadata, dict):
            failures.append(f"{path}: metadata must be a mapping")
            continue

        rule_type = str(metadata.get("type_id") or "").strip()
        version_no = str(metadata.get("version") or "").strip()
        if not rule_type or not version_no:
            failures.append(f"{path}: metadata.type_id/version is required")
            continue
        if version_no != path.parent.name:
            failures.append(
                f"{path}: metadata.version={version_no!r} does not match directory {path.parent.name!r}"
            )
            continue

        # Encode once; the digest and the byte size both derive from it.
        encoded = yaml_text.encode("utf-8")
        items.append(
            LocalRule(
                rule_type=rule_type,
                version_no=version_no,
                rule_name=str(metadata.get("name") or rule_type),
                description=metadata.get("description"),
                path=path,
                yaml_text=yaml_text,
                sha256=hashlib.sha256(encoded).hexdigest(),
                size=len(encoded),
            )
        )
    if failures:
        raise RuntimeError("Local YAML validation failed:\n" + "\n".join(failures))
    return items
async def _load_rule_sets(session) -> list[dict[str, Any]]:
    """Fetch every rule set that is not soft-deleted.

    ``tenant_code`` and ``scope_type`` are trimmed in SQL, and NULL or blank
    values are reported as ``'PUBLIC'``.

    Returns:
        Plain dicts with ``id``, ``rule_type``, ``rule_name``,
        ``current_version_id``, ``tenant_code`` and ``scope_type``.
    """
    query = text(
        """
        SELECT
            rs.id,
            rs.rule_type,
            rs.rule_name,
            rs.current_version_id,
            COALESCE(NULLIF(BTRIM(rs.tenant_code), ''), 'PUBLIC') AS tenant_code,
            COALESCE(NULLIF(BTRIM(rs.scope_type), ''), 'PUBLIC') AS scope_type
        FROM leaudit_rule_sets rs
        WHERE rs.deleted_at IS NULL
        ORDER BY rs.tenant_code ASC, rs.rule_type ASC, rs.id ASC
        """
    )
    result = await session.execute(query)
    return [dict(record) for record in result.mappings().all()]
async def _tenant_name(session, tenant_code: str) -> str | None:
    """Look up the display name of a tenant that is not soft-deleted.

    Returns:
        ``tenant_name`` converted to ``str``, or ``None`` when no row in
        ``sys_tenants`` matches ``tenant_code``.
    """
    query = text(
        """
        SELECT tenant_name
        FROM sys_tenants
        WHERE tenant_code = :tenant_code
          AND deleted_at IS NULL
        LIMIT 1
        """
    )
    result = await session.execute(query, {"tenant_code": tenant_code})
    match = result.mappings().first()
    if not match:
        return None
    return str(match["tenant_name"])
async def _ensure_rule_set(session, local: LocalRule, tenant_code: str) -> dict[str, Any]:
    """Return the live rule set for ``local.rule_type`` and tenant, creating it if needed.

    An existing row has its ``rule_name``, ``domain_type`` and ``description``
    refreshed from ``local``. Otherwise a new ``'draft'`` row is inserted with
    no current version.

    The lookup compares a normalized tenant code (trimmed, upper-cased,
    NULL/blank treated as ``'PUBLIC'``) on both sides. The rule-set listing
    query and ``import_rules`` normalize codes the same way, so an exact
    comparison here would miss legacy rows and insert a duplicate rule set.

    Args:
        session: Async database session; the caller owns commit/rollback.
        local: Local rule whose metadata is written to the rule set.
        tenant_code: Owning tenant; ``"PUBLIC"`` selects the public scope.

    Returns:
        Dict with ``id``, ``rule_type``, ``rule_name``, ``current_version_id``,
        ``tenant_code`` and ``scope_type``. For an existing row the values are
        those read before the refresh.
    """
    lookup_code = str(tenant_code or "").strip().upper() or "PUBLIC"
    row = (
        await session.execute(
            text(
                """
                SELECT id, rule_type, rule_name, current_version_id, tenant_code, scope_type
                FROM leaudit_rule_sets
                WHERE rule_type = :rule_type
                  AND UPPER(COALESCE(NULLIF(BTRIM(tenant_code), ''), 'PUBLIC')) = :tenant_code
                  AND deleted_at IS NULL
                ORDER BY id ASC
                LIMIT 1
                """
            ),
            {"rule_type": local.rule_type, "tenant_code": lookup_code},
        )
    ).mappings().first()

    domain_type = _domain_type(local.rule_type)
    if row:
        await session.execute(
            text(
                """
                UPDATE leaudit_rule_sets
                SET rule_name = :rule_name,
                    domain_type = :domain_type,
                    description = :description,
                    updated_at = NOW()
                WHERE id = :id
                """
            ),
            {
                "id": int(row["id"]),
                "rule_name": local.rule_name,
                "domain_type": domain_type,
                "description": local.description,
            },
        )
        return dict(row)

    tenant_name = await _tenant_name(session, tenant_code)
    scope_type = "PUBLIC" if tenant_code == "PUBLIC" else "TENANT"
    created = (
        await session.execute(
            text(
                """
                INSERT INTO leaudit_rule_sets (
                    rule_type, rule_name, domain_type, description, entry_module,
                    current_version_id, status, is_builtin, owner_user_id,
                    created_at, updated_at, deleted_at, region,
                    tenant_code, scope_type, source_rule_set_id, tenant_name_snapshot
                ) VALUES (
                    :rule_type, :rule_name, :domain_type, :description, NULL,
                    NULL, 'draft', FALSE, NULL,
                    NOW(), NOW(), NULL, :region,
                    :tenant_code, :scope_type, NULL, :tenant_name
                )
                RETURNING id, rule_type, rule_name, current_version_id, tenant_code, scope_type
                """
            ),
            {
                "rule_type": local.rule_type,
                "rule_name": local.rule_name,
                "domain_type": domain_type,
                "description": local.description,
                "region": _legacy_region_for_tenant(tenant_code),
                "tenant_code": tenant_code,
                "scope_type": scope_type,
                "tenant_name": tenant_name,
            },
        )
    ).mappings().first()
    return dict(created)
async def _upsert_version(session, local: LocalRule, rule_set_id: int, tenant_code: str, oss_url: str) -> tuple[int, str]:
    """Insert or refresh the version row for ``local`` inside one rule set.

    A live row with the same ``version_no`` has its file and metadata columns
    overwritten; its ``status`` is left alone. Otherwise a new ``'draft'`` row
    is inserted with the next ``version_seq`` for the rule set (the maximum is
    taken over all rows, including soft-deleted ones).

    Returns:
        ``(version_id, action)`` where action is ``"created"`` for an insert,
        ``"updated"`` when the stored SHA-256 differed from the local file, and
        ``"refreshed"`` when it was already equal.
    """
    # Columns written identically by the UPDATE and the INSERT.
    file_params = {
        "oss_url": oss_url,
        "file_sha256": local.sha256,
        "file_size": local.size,
        "metadata_type_id": local.rule_type,
        "metadata_name": local.rule_name,
        "metadata_version": local.version_no,
        "tenant_code": tenant_code,
        "scope_type": "PUBLIC" if tenant_code == "PUBLIC" else "TENANT",
    }

    lookup = await session.execute(
        text(
            """
            SELECT id, file_sha256, status, version_seq
            FROM leaudit_rule_versions
            WHERE rule_set_id = :rule_set_id
              AND version_no = :version_no
              AND deleted_at IS NULL
            ORDER BY id ASC
            LIMIT 1
            """
        ),
        {"rule_set_id": rule_set_id, "version_no": local.version_no},
    )
    current = lookup.mappings().first()

    if current:
        version_id = int(current["id"])
        await session.execute(
            text(
                """
                UPDATE leaudit_rule_versions
                SET source_type = 'oss_yaml',
                    dsl_format = 'yaml',
                    oss_url = :oss_url,
                    file_sha256 = :file_sha256,
                    file_size = :file_size,
                    metadata_type_id = :metadata_type_id,
                    metadata_name = :metadata_name,
                    metadata_version = :metadata_version,
                    tenant_code_snapshot = :tenant_code,
                    scope_type_snapshot = :scope_type,
                    updated_at = NOW()
                WHERE id = :id
                """
            ),
            {"id": version_id, **file_params},
        )
        if current["file_sha256"] != local.sha256:
            return version_id, "updated"
        return version_id, "refreshed"

    seq_result = await session.execute(
        text(
            """
            SELECT COALESCE(MAX(version_seq), 0) + 1 AS next_seq
            FROM leaudit_rule_versions
            WHERE rule_set_id = :rule_set_id
            """
        ),
        {"rule_set_id": rule_set_id},
    )
    seq_row = seq_result.mappings().first()

    insert_result = await session.execute(
        text(
            """
            INSERT INTO leaudit_rule_versions (
                rule_set_id, version_no, version_seq, status, source_type, dsl_format,
                oss_url, file_sha256, file_size, local_cache_path,
                metadata_type_id, metadata_name, metadata_version, change_note,
                editor_user_id, publisher_user_id, published_at,
                created_at, updated_at, deleted_at,
                tenant_code_snapshot, scope_type_snapshot, source_version_id
            ) VALUES (
                :rule_set_id, :version_no, :version_seq, 'draft', 'oss_yaml', 'yaml',
                :oss_url, :file_sha256, :file_size, NULL,
                :metadata_type_id, :metadata_name, :metadata_version, :change_note,
                NULL, NULL, NULL,
                NOW(), NOW(), NULL,
                :tenant_code, :scope_type, NULL
            )
            RETURNING id
            """
        ),
        {
            "rule_set_id": rule_set_id,
            "version_no": local.version_no,
            "version_seq": int(seq_row["next_seq"] or 1),
            "change_note": "从 leaudit-oss-yaml-files 全量导入",
            **file_params,
        },
    )
    inserted = insert_result.mappings().first()
    return int(inserted["id"]), "created"
async def _publish_version(session, rule_set_id: int, version_id: int) -> None:
    """Make ``version_id`` the published, current version of a rule set."""
    await session.execute(
        text(
            """
            UPDATE leaudit_rule_versions
            SET status = CASE
                    WHEN id = :current_version_id THEN 'published'
                    WHEN status = 'published' THEN 'deprecated'
                    ELSE status
                END,
                published_at = CASE
                    WHEN id = :current_version_id AND published_at IS NULL THEN NOW()
                    ELSE published_at
                END,
                updated_at = NOW()
            WHERE rule_set_id = :rule_set_id
              AND deleted_at IS NULL
            """
        ),
        {"rule_set_id": rule_set_id, "current_version_id": version_id},
    )
    await session.execute(
        text(
            """
            UPDATE leaudit_rule_sets
            SET current_version_id = :current_version_id,
                status = 'active',
                updated_at = NOW()
            WHERE id = :rule_set_id
            """
        ),
        {"rule_set_id": rule_set_id, "current_version_id": version_id},
    )


async def import_rules(root: Path, dry_run: bool) -> None:
    """Upload local rule YAML files and sync rule sets/versions for every tenant.

    Target tenants are the distinct tenant codes of the existing rule sets
    (trimmed and upper-cased), excluding ``PROVINCIAL`` and always including
    ``PUBLIC``. For each tenant and rule type, every local version is upserted
    and the highest local version (by ``_version_sort_key``) is published as
    the rule set's current version.

    Args:
        root: Directory containing the ``<dir>/<version>/rules.yaml`` tree.
        dry_run: When true nothing is uploaded, the object key stands in for
            the OSS location, and all database changes are rolled back.

    Raises:
        RuntimeError: If any local YAML file fails validation.
    """
    locals_by_type: dict[str, list[LocalRule]] = {}
    for local in load_local_rules(root):
        locals_by_type.setdefault(local.rule_type, []).append(local)
    for rules in locals_by_type.values():
        rules.sort(key=lambda item: _version_sort_key(item.version_no))

    oss = OssServiceImpl()
    # The object key depends only on rule type and version, never on the
    # tenant, so each file is uploaded once and its location reused.
    oss_url_by_key: dict[str, str] = {}
    stats = {"uploaded": 0, "created_versions": 0, "updated_versions": 0, "refreshed_versions": 0, "published_sets": 0}

    async with GetAsyncSession() as session:
        rule_sets = await _load_rule_sets(session)
        tenant_codes = sorted({str(row["tenant_code"]).strip().upper() for row in rule_sets if row.get("tenant_code")})
        tenant_codes = [code for code in tenant_codes if code != "PROVINCIAL"]
        if "PUBLIC" not in tenant_codes:
            tenant_codes.insert(0, "PUBLIC")
        print(f"local_rule_files={sum(len(v) for v in locals_by_type.values())}")
        print(f"local_rule_types={len(locals_by_type)}")
        print(f"target_tenants={','.join(tenant_codes)}")

        for tenant_code in tenant_codes:
            for _rule_type, versions in sorted(locals_by_type.items()):
                current_version_id: int | None = None
                rule_set_id: int | None = None
                for local in versions:
                    object_key = OssPathUtils.BuildRuleYamlKey(local.rule_type, local.version_no)
                    if object_key not in oss_url_by_key:
                        if dry_run:
                            oss_url_by_key[object_key] = object_key
                        else:
                            oss_url_by_key[object_key] = await oss.UploadText(
                                ObjectKey=object_key,
                                Content=local.yaml_text,
                                ContentType="application/x-yaml; charset=utf-8",
                            )
                        # Counts distinct objects, not one per tenant.
                        stats["uploaded"] += 1
                    oss_url = oss_url_by_key[object_key]

                    rule_set = await _ensure_rule_set(session, local, tenant_code)
                    rule_set_id = int(rule_set["id"])
                    version_id, action = await _upsert_version(session, local, rule_set_id, tenant_code, oss_url)
                    if action == "created":
                        stats["created_versions"] += 1
                    elif action == "updated":
                        stats["updated_versions"] += 1
                    else:
                        stats["refreshed_versions"] += 1
                    # The list is sorted ascending, so the last entry is the
                    # highest local version.
                    if local is versions[-1]:
                        current_version_id = version_id

                if rule_set_id is not None and current_version_id is not None:
                    await _publish_version(session, rule_set_id, current_version_id)
                    stats["published_sets"] += 1

        if dry_run:
            await session.rollback()
            print("dry_run=true rolled_back=true")
        else:
            await session.commit()
            print("dry_run=false committed=true")

    for key, value in stats.items():
        print(f"{key}={value}")
async def _backup_rule_domain(session) -> Path:
    """Write a text snapshot of all rule sets and rule versions to disk.

    The file is created under ``docs/规则编辑/backups`` (relative to the
    current working directory) with a timestamped name. It holds only ``--``
    comment lines, one per row, and is meant for inspection rather than for
    restoring data.

    Returns:
        Path of the file that was written.
    """
    target_dir = Path("docs/规则编辑/backups")
    target_dir.mkdir(parents=True, exist_ok=True)
    stamp = datetime.now().strftime("%Y%m%d-%H%M%S")
    backup_path = target_dir / f"rule-domain-before-reset-{stamp}.sql"

    set_result = await session.execute(
        text(
            """
            SELECT *
            FROM leaudit_rule_sets
            ORDER BY id ASC
            """
        )
    )
    rule_sets = set_result.mappings().all()
    version_result = await session.execute(
        text(
            """
            SELECT *
            FROM leaudit_rule_versions
            ORDER BY id ASC
            """
        )
    )
    rule_versions = version_result.mappings().all()

    # (label written to the file, column read from the row)
    set_fields = (
        ("id", "id"),
        ("tenant_code", "tenant_code"),
        ("rule_type", "rule_type"),
        ("current_version_id", "current_version_id"),
        ("status", "status"),
        ("deleted_at", "deleted_at"),
    )
    version_fields = (
        ("id", "id"),
        ("rule_set_id", "rule_set_id"),
        ("version_no", "version_no"),
        ("status", "status"),
        ("oss_url", "oss_url"),
        ("sha", "file_sha256"),
        ("deleted_at", "deleted_at"),
    )

    def describe(prefix: str, row, fields) -> str:
        # One comment line per row: "<prefix>label=value label=value ...".
        return prefix + " ".join(f"{label}={row.get(column)}" for label, column in fields)

    lines = [
        "-- Rule domain backup before reset",
        f"-- generated_at: {datetime.now().isoformat(timespec='seconds')}",
        f"-- rule_sets: {len(rule_sets)}",
        f"-- rule_versions: {len(rule_versions)}",
        "",
        "-- This file is an audit snapshot, not an automatic restore script.",
        "-- Use the rows below to inspect pre-reset IDs, current_version_id, oss_url and sha.",
        "",
    ]
    lines.extend(describe("-- rule_set ", row, set_fields) for row in rule_sets)
    lines.append("")
    lines.extend(describe("-- rule_version ", row, version_fields) for row in rule_versions)

    backup_path.write_text("\n".join(lines) + "\n", encoding="utf-8")
    return backup_path
async def reset_and_import_rules(root: Path, *, dry_run: bool, prune_oss: bool) -> None:
    """Reset the rule domain, optionally prune OSS, then run a full import.

    Steps, in order:
    1. Validate and load the local rules, which fixes the canonical OSS keys.
    2. Write a backup snapshot and report which stored OSS keys are deletable:
       those on existing versions that are neither canonical nor referenced by
       an audit run.
    3. In dry-run mode, roll back and stop here.
    4. Clear ``current_version_id`` on live rule sets (status ``'draft'``),
       soft-delete every live rule version, and commit.
    5. With ``prune_oss``, remove the deletable objects; a failed removal is
       printed and skipped.
    6. Run ``import_rules`` with ``dry_run=False``.

    Note that the backup file is written in dry-run mode as well.
    """
    rules_on_disk = load_local_rules(root)
    canonical_keys = set()
    for rule in rules_on_disk:
        canonical_keys.add(OssPathUtils.BuildRuleYamlKey(rule.rule_type, rule.version_no))

    async with GetAsyncSession() as session:
        backup_path = await _backup_rule_domain(session)

        audit_result = await session.execute(
            text(
                """
                SELECT DISTINCT rule_source_oss_url AS oss_url
                FROM leaudit_audit_runs
                WHERE COALESCE(rule_source_oss_url, '') <> ''
                """
            )
        )
        referenced_keys = {str(record["oss_url"]) for record in audit_result.mappings().all()}

        version_result = await session.execute(
            text(
                """
                SELECT DISTINCT oss_url
                FROM leaudit_rule_versions
                WHERE COALESCE(oss_url, '') <> ''
                """
            )
        )
        existing_version_keys = {str(record["oss_url"]) for record in version_result.mappings().all()}

        deletable_oss_keys = sorted(existing_version_keys - canonical_keys - referenced_keys)

        print(f"backup_path={backup_path}")
        print(f"canonical_keys={len(canonical_keys)}")
        print(f"existing_version_keys={len(existing_version_keys)}")
        print(f"audit_referenced_keys={len(referenced_keys)}")
        print(f"deletable_oss_keys={len(deletable_oss_keys)}")
        # Only the first 50 keys are listed to keep the output short.
        for key in deletable_oss_keys[:50]:
            print(f"deletable_oss_key={key}")

        if dry_run:
            await session.rollback()
            print("reset_dry_run=true rolled_back=true")
            return

        await session.execute(
            text(
                """
                UPDATE leaudit_rule_sets
                SET current_version_id = NULL,
                    status = 'draft',
                    updated_at = NOW()
                WHERE deleted_at IS NULL
                """
            )
        )
        await session.execute(
            text(
                """
                UPDATE leaudit_rule_versions
                SET deleted_at = NOW(),
                    status = CASE WHEN status = 'published' THEN 'deprecated' ELSE status END,
                    updated_at = NOW()
                WHERE deleted_at IS NULL
                """
            )
        )
        await session.commit()

    # NOTE(review): SOURCE lost its indentation, so it is not certain whether
    # the pruning and re-import below originally ran inside the session block.
    # Neither step uses the session; confirm against the repository copy.
    if prune_oss:
        deleted = 0
        if deletable_oss_keys:
            oss = OssServiceImpl()
            client = oss.Client._GetMinioClient()
            bucket = oss.Client.bucket
            for key in deletable_oss_keys:
                try:
                    client.remove_object(bucket, key)
                except Exception as exc:
                    # Best effort: report the failure and keep pruning.
                    print(f"delete_oss_failed key={key} error={exc}")
                else:
                    deleted += 1
        print(f"oss_deleted={deleted}")

    await import_rules(root, dry_run=False)
def main() -> None:
    """Parse the command line and run the selected import mode.

    Options:
        --root: Directory holding the local YAML tree
            (default ``leaudit-oss-yaml-files``).
        --execute: Apply changes; without it the run is a dry run.
        --reset-rule-domain: Run the reset-and-import flow instead of the
            plain import.
        --prune-oss: Passed to the reset flow only; the plain import does not
            read it.
    """
    cli = argparse.ArgumentParser()
    cli.add_argument("--root", default="leaudit-oss-yaml-files")
    for switch in ("--execute", "--reset-rule-domain", "--prune-oss"):
        cli.add_argument(switch, action="store_true")
    options = cli.parse_args()

    root = Path(options.root)
    dry_run = not options.execute
    if options.reset_rule_domain:
        job = reset_and_import_rules(root, dry_run=dry_run, prune_oss=options.prune_oss)
    else:
        job = import_rules(root, dry_run=dry_run)
    asyncio.run(job)


if __name__ == "__main__":
    main()