Files
2026-04-29 22:25:06 +08:00

103 lines
3.3 KiB
Python

"""控制器注册器 —— 扫描 controller_packages 并注册路由。"""
from __future__ import annotations

import importlib
import logging
import pkgutil
from pathlib import Path
from typing import Any

from fastapi import APIRouter, FastAPI
# Dotted import names of the packages that register_controllers() imports and
# scans for package-level routers and controller classes.
controller_packages = [
"fastapi_modules.fastapi_leaudit.controllers",
]
def register_controllers(app: FastAPI) -> None:
    """Scan every package in ``controller_packages`` and register its routes.

    Each listed package is imported, its package-level routers are gathered
    with ``_collect_package_routers`` and its controllers are registered on
    *app* with ``_register_from_package``.

    A package that raises ``ImportError`` is skipped so that the remaining
    packages are still processed; the failure is logged as a warning rather
    than being dropped silently.

    Args:
        app: The FastAPI application that receives the controller routes.
    """
    log = logging.getLogger(__name__)

    for package_name in controller_packages:
        try:
            pkg = importlib.import_module(package_name)
        except ImportError:
            # Best-effort: keep going, but leave a trace so missing routes
            # can be diagnosed.
            log.warning(
                "Skipping controller package %s: import failed",
                package_name,
                exc_info=True,
            )
            continue

        pkg_file = getattr(pkg, "__file__", None)
        if pkg_file:
            pkg_dir = str(Path(pkg_file).parent)
        else:
            # Namespace packages have no __file__. Path("").parent is ".",
            # which would make the scan walk the current working directory,
            # so fall back to the first entry of the package search path.
            pkg_dir = next(iter(getattr(pkg, "__path__", ())), "")

        package_routers: dict[str, APIRouter] = {}
        _collect_package_routers(pkg, pkg_dir, package_name, package_routers)
        _register_from_package(pkg, package_name, package_routers, app)
def _collect_package_routers(
    pkg: Any, pkg_dir: str, pkg_name: str, routers: dict[str, APIRouter],
) -> None:
    """Gather package-level routers of *pkg* and its sub-packages.

    A package contributes an entry to *routers*, keyed by its dotted name,
    when it exposes a ``router`` attribute that is an ``APIRouter``.
    Sub-packages listed in *pkg_dir* are imported and visited recursively.
    Plain modules and names starting with ``_`` are skipped, and a
    sub-package that raises ``ImportError`` is ignored.

    Args:
        pkg: The imported package (or module) object to inspect.
        pkg_dir: Directory that is scanned for the children of *pkg*.
        pkg_name: Dotted import name of *pkg*.
        routers: Mapping that is filled in place with the routers found.
    """
    own_router = getattr(pkg, "router", None)
    if own_router is not None and isinstance(own_router, APIRouter):
        routers[pkg_name] = own_router

    # Only packages carry __path__; a plain module has no children to visit.
    if not hasattr(pkg, "__path__"):
        return

    for info in pkgutil.iter_modules([pkg_dir]):
        child = info.name
        if child.startswith("_") or not info.ispkg:
            continue
        dotted = f"{pkg_name}.{child}"
        try:
            child_pkg = importlib.import_module(dotted)
            _collect_package_routers(
                child_pkg, str(Path(pkg_dir, child)), dotted, routers,
            )
        except ImportError:
            continue
def _register_from_package(
    pkg: Any, pkg_name: str, package_routers: dict[str, APIRouter], app: FastAPI,
) -> None:
    """Instantiate and mount every ``BaseController`` subclass under *pkg*.

    Walks the children of *pkg*, skipping names that start with ``_`` and
    recursing into sub-packages. Each module is imported, and every class in
    its namespace that subclasses ``BaseController`` (other than
    ``BaseController`` itself) is instantiated; the instance's ``router`` is
    included on *app* under the ``/api`` prefix. The dependencies of the
    nearest package-level router, as returned by ``_resolve_target_router``,
    are attached to each included router.

    Modules and sub-packages that raise ``ImportError`` are skipped, and the
    failure is logged as a warning.

    Args:
        pkg: Imported package object to scan; objects without ``__path__``
            are ignored.
        pkg_name: Dotted import name of *pkg*.
        package_routers: Package-level routers keyed by dotted package name.
        app: Application that receives the controller routers.
    """
    # Function-scope import, kept as in the original implementation.
    from fastapi_common.fastapi_common_web.controller import BaseController

    if not hasattr(pkg, "__path__"):
        return

    log = logging.getLogger(__name__)

    # Scan the package's own search path. Deriving the directory from
    # __file__ breaks for namespace packages, where __file__ is None and the
    # derived directory becomes "." (the current working directory).
    for _, module_name, is_pkg in pkgutil.iter_modules(pkg.__path__):
        if module_name.startswith("_"):
            continue
        full_name = f"{pkg_name}.{module_name}"

        if is_pkg:
            try:
                sub_pkg = importlib.import_module(full_name)
                _register_from_package(sub_pkg, full_name, package_routers, app)
            except ImportError:
                log.warning(
                    "Skipping controller package %s: import failed",
                    full_name,
                    exc_info=True,
                )
            continue

        try:
            mod = importlib.import_module(full_name)
        except ImportError:
            log.warning(
                "Skipping controller module %s: import failed",
                full_name,
                exc_info=True,
            )
            continue

        # Snapshot the matching classes before instantiating anything, so the
        # module namespace is not iterated while controller code runs.
        # NOTE(review): classes merely imported into this module from
        # elsewhere also match and get registered here; consider filtering on
        # ``obj.__module__ == mod.__name__`` if that is not intended.
        controller_classes = [
            obj
            for obj in vars(mod).values()
            if isinstance(obj, type)
            and issubclass(obj, BaseController)
            and obj is not BaseController
        ]
        if not controller_classes:
            continue

        # The lookup depends only on the package and module, so resolve it
        # once per module instead of once per controller class.
        target_router = _resolve_target_router(pkg_name, module_name, package_routers)
        for controller_cls in controller_classes:
            instance = controller_cls()
            app.include_router(
                instance.router,
                prefix="/api",
                dependencies=target_router.dependencies,
            )
def _resolve_target_router(
    pkg_name: str, module_name: str, package_routers: dict[str, APIRouter],
) -> APIRouter:
    """Return the nearest package-level router for *pkg_name*.

    The lookup starts at *pkg_name* itself and strips one trailing dotted
    component at a time, returning the first name present in
    *package_routers*. When no candidate matches, a fresh, empty
    ``APIRouter`` is returned.

    Args:
        pkg_name: Dotted name of the package where the search starts.
        module_name: Accepted for the caller's convenience; not used by the
            lookup.
        package_routers: Package-level routers keyed by dotted package name.
    """
    candidate = pkg_name
    while True:
        if candidate in package_routers:
            return package_routers[candidate]
        # Drop the last dotted component; an empty separator means there is
        # no parent left to try.
        candidate, separator, _ = candidate.rpartition(".")
        if not separator:
            break
    return APIRouter()