103 lines
3.3 KiB
Python
103 lines
3.3 KiB
Python
"""控制器注册器 —— 扫描 controller_packages 并注册路由。"""
|
|
|
|
from __future__ import annotations

import importlib
import logging
import pkgutil
from pathlib import Path
from typing import Any

from fastapi import APIRouter, FastAPI
|
|
|
|
# Dotted import paths of the packages that ``register_controllers`` scans.
controller_packages = ["fastapi_modules.fastapi_leaudit.controllers"]
|
|
|
|
|
|
def register_controllers(app: FastAPI) -> None:
    """Scan every package in ``controller_packages`` and register its controllers on *app*.

    For each configured package, first collect the package-level ``router``
    objects (their ``dependencies`` are applied to controllers below them),
    then register every ``BaseController`` subclass found in the package.

    A package that cannot be imported is skipped with a logged warning
    instead of aborting application start-up.

    Args:
        app: The FastAPI application that receives the controller routers.
    """
    logger = logging.getLogger(__name__)
    for package_name in controller_packages:
        try:
            pkg = importlib.import_module(package_name)
        except ImportError:
            # Best-effort: a missing package must not crash start-up, but the
            # skip is logged so absent routes can be diagnosed.
            logger.warning(
                "Skipping controller package %r: import failed",
                package_name,
                exc_info=True,
            )
            continue

        package_routers: dict[str, APIRouter] = {}
        # Use __path__, not the parent of __file__. A namespace package has
        # __file__ = None, and Path("").parent is "." (the current working
        # directory). A namespace package may also span several directories.
        search_dirs = list(getattr(pkg, "__path__", None) or [])
        if not search_dirs:
            # Not a package: the directory is never scanned, only pkg.router is read.
            search_dirs = [str(Path(pkg.__file__ or "").parent)]
        for pkg_dir in search_dirs:
            _collect_package_routers(pkg, pkg_dir, package_name, package_routers)
        _register_from_package(pkg, package_name, package_routers, app)
|
|
|
|
|
|
def _collect_package_routers(
    pkg: Any, pkg_dir: str, pkg_name: str, routers: dict[str, APIRouter],
) -> None:
    """Recursively record package-level ``router`` objects into *routers*.

    If *pkg* exposes a ``router`` attribute that is an ``APIRouter``, it is
    stored as ``routers[pkg_name]``. Sub-packages found in *pkg_dir* are then
    visited recursively. Names starting with ``_`` are skipped, and plain
    modules are not inspected here. Sub-packages that fail to import are
    skipped with a logged warning.

    Args:
        pkg: The imported package (or module) to inspect.
        pkg_dir: Filesystem directory whose sub-packages are scanned.
        pkg_name: Dotted import name of *pkg*, used as the key in *routers*.
        routers: Output mapping, mutated in place.
    """
    router = getattr(pkg, "router", None)
    if isinstance(router, APIRouter):
        routers[pkg_name] = router

    if not hasattr(pkg, "__path__"):
        return

    logger = logging.getLogger(__name__)
    for _, sub_name, is_pkg in pkgutil.iter_modules([pkg_dir]):
        # Only sub-packages (their __init__) can carry a package-level router.
        if not is_pkg or sub_name.startswith("_"):
            continue

        sub_full = f"{pkg_name}.{sub_name}"
        try:
            sub_pkg = importlib.import_module(sub_full)
        except ImportError:
            # Best-effort, but visible: a broken sub-package would otherwise
            # silently lose its router dependencies.
            logger.warning(
                "Skipping controller sub-package %r: import failed",
                sub_full,
                exc_info=True,
            )
            continue

        _collect_package_routers(sub_pkg, str(Path(pkg_dir) / sub_name), sub_full, routers)
|
|
|
|
|
|
def _register_from_package(
    pkg: Any, pkg_name: str, package_routers: dict[str, APIRouter], app: FastAPI,
) -> None:
    """Instantiate and mount every ``BaseController`` subclass found under *pkg*.

    Every module of *pkg* whose name does not start with ``_`` is imported.
    Each ``BaseController`` subclass *defined in* that module (not merely
    imported into it) is instantiated. Its ``router`` is included on *app*
    once, with the ``/api`` prefix and the ``dependencies`` of the nearest
    package-level router from *package_routers*.

    Sub-packages are scanned recursively. Modules that fail to import are
    skipped with a logged warning.

    Args:
        pkg: Imported package to scan; objects without ``__path__`` are ignored.
        pkg_name: Dotted import name of *pkg*.
        package_routers: Package-level routers keyed by dotted package name.
        app: Application that receives the controller routers.
    """
    # Deferred import kept from the original wiring (presumably avoids an import cycle).
    from fastapi_common.fastapi_common_web.controller import BaseController

    if not hasattr(pkg, "__path__"):
        return

    logger = logging.getLogger(__name__)
    # Scan __path__ rather than the parent of __file__. A namespace package
    # has __file__ = None, which would resolve to "." (the current working directory).
    for _, module_name, is_pkg in pkgutil.iter_modules(list(pkg.__path__)):
        if module_name.startswith("_"):
            continue

        full_name = f"{pkg_name}.{module_name}"
        try:
            mod = importlib.import_module(full_name)
        except ImportError:
            logger.warning(
                "Skipping controller module %r: import failed",
                full_name,
                exc_info=True,
            )
            continue

        if is_pkg:
            _register_from_package(mod, full_name, package_routers, app)
            continue

        # Keep only classes whose __module__ is this module. Otherwise a
        # controller (or an intermediate base) imported into several modules
        # would be instantiated and mounted once per importing module.
        # dict.fromkeys also collapses aliases of one class while keeping
        # order, and the list is built before any controller is instantiated.
        controllers = list(dict.fromkeys(
            obj
            for obj in vars(mod).values()
            if isinstance(obj, type)
            and issubclass(obj, BaseController)
            and obj is not BaseController
            and obj.__module__ == mod.__name__
        ))
        if not controllers:
            continue

        # The nearest package router depends only on the package, so resolve it once.
        target_router = _resolve_target_router(pkg_name, module_name, package_routers)
        for controller_cls in controllers:
            app.include_router(
                controller_cls().router,
                prefix="/api",
                dependencies=target_router.dependencies,
            )
|
|
|
|
|
|
def _resolve_target_router(
    pkg_name: str, module_name: str, package_routers: dict[str, APIRouter],
) -> APIRouter:
    """Return the package-level router closest to *pkg_name*.

    Candidates are checked from the most specific dotted name outward
    (``a.b.c``, then ``a.b``, then ``a``). If none of them has an entry in
    *package_routers*, a fresh ``APIRouter()`` is returned.

    ``module_name`` is accepted for call-site compatibility but is not used
    by the lookup.
    """
    candidate = pkg_name
    while True:
        if candidate in package_routers:
            return package_routers[candidate]
        if "." not in candidate:
            return APIRouter()
        # Drop the last dotted component and try the enclosing package.
        candidate = candidate.rpartition(".")[0]
|