"""控制器注册器 —— 扫描 controller_packages 并注册路由。""" from __future__ import annotations import importlib import pkgutil from pathlib import Path from typing import Any from fastapi import APIRouter, FastAPI controller_packages = [ "fastapi_modules.fastapi_leaudit.controllers", ] def register_controllers(app: FastAPI) -> None: """扫描所有控制器包,注册 BaseController 子类的路由。""" for package_name in controller_packages: try: pkg = importlib.import_module(package_name) except ImportError: continue package_routers: dict[str, APIRouter] = {} _collect_package_routers(pkg, str(Path(pkg.__file__ or "").parent), package_name, package_routers) _register_from_package(pkg, package_name, package_routers, app) def _collect_package_routers( pkg: Any, pkg_dir: str, pkg_name: str, routers: dict[str, APIRouter], ) -> None: """收集所有包级 router(从 __init__.py)。""" if hasattr(pkg, "router") and isinstance(pkg.router, APIRouter): routers[pkg_name] = pkg.router if not hasattr(pkg, "__path__"): return for _, sub_name, is_pkg in pkgutil.iter_modules([pkg_dir]): if sub_name.startswith("_"): continue sub_full = f"{pkg_name}.{sub_name}" if is_pkg: try: sub_pkg = importlib.import_module(sub_full) sub_dir = str(Path(pkg_dir) / sub_name) _collect_package_routers(sub_pkg, sub_dir, sub_full, routers) except ImportError: pass def _register_from_package( pkg: Any, pkg_name: str, package_routers: dict[str, APIRouter], app: FastAPI, ) -> None: """从包中注册 BaseController 子类。""" from fastapi_common.fastapi_common_web.controller import BaseController if not hasattr(pkg, "__path__"): return pkg_dir = str(Path(pkg.__file__ or "").parent) for _, module_name, is_pkg in pkgutil.iter_modules([pkg_dir]): if module_name.startswith("_"): continue if is_pkg: try: sub_pkg_name = f"{pkg_name}.{module_name}" sub_pkg = importlib.import_module(sub_pkg_name) _register_from_package(sub_pkg, sub_pkg_name, package_routers, app) except ImportError: continue continue try: mod = importlib.import_module(f"{pkg_name}.{module_name}") except ImportError: continue for _, obj in vars(mod).items(): if ( isinstance(obj, type) and issubclass(obj, BaseController) and obj is not BaseController ): instance = obj() target_router = _resolve_target_router(pkg_name, module_name, package_routers) app.include_router(instance.router, prefix="/api", dependencies=target_router.dependencies) def _resolve_target_router( pkg_name: str, module_name: str, package_routers: dict[str, APIRouter], ) -> APIRouter: """沿包路径向上查找最近的包级 router。""" parts = pkg_name.split(".") for i in range(len(parts), 0, -1): parent = ".".join(parts[:i]) if parent in package_routers: return package_routers[parent] return APIRouter()