chore: initial commit — leaudit-platform project skeleton
17-table PostgreSQL schema with full Chinese column comments, FastAPI project structure (admin/common/modules), DSL rule files, and schema migration scripts.
This commit is contained in:
@@ -0,0 +1,15 @@
|
||||
"""控制器包(需要 JWT 鉴权)。"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi import APIRouter, Depends, Request
|
||||
|
||||
from fastapi_common.fastapi_common_security.security import verify_access_token
|
||||
|
||||
|
||||
async def jwt_auth_dependency(RequestObj: Request) -> dict[str, Any]:
|
||||
"""JWT 鉴权依赖。"""
|
||||
return verify_access_token(RequestObj)
|
||||
|
||||
|
||||
router = APIRouter(dependencies=[Depends(jwt_auth_dependency)])
|
||||
@@ -0,0 +1,42 @@
|
||||
"""评查控制器。"""
|
||||
|
||||
from fastapi_common.fastapi_common_web.controller import BaseController
|
||||
from fastapi_common.fastapi_common_web.domain.responses import Result
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.domian.Dto.auditDto import AuditRunDTO
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.auditVo import AuditRunVO, AuditResultVO
|
||||
from fastapi_modules.fastapi_leaudit.services import IAuditService
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.auditServiceImpl import AuditServiceImpl
|
||||
|
||||
|
||||
class AuditController(BaseController):
|
||||
"""评查控制器。"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(prefix="/audit", tags=["评查"])
|
||||
self.AuditService: IAuditService = AuditServiceImpl()
|
||||
|
||||
@self.router.post("/run", response_model=Result[AuditRunVO])
|
||||
async def RunAudit(body: AuditRunDTO):
|
||||
"""触发文档评查
|
||||
|
||||
对指定文档执行 LeAudit 完整评查链路。
|
||||
"""
|
||||
run = await self.AuditService.Run(
|
||||
DocumentId=body.documentId,
|
||||
RuleType=body.ruleType,
|
||||
Force=body.force,
|
||||
)
|
||||
return Result.success(data=run)
|
||||
|
||||
@self.router.get("/run/{RunId}", response_model=Result[AuditRunVO])
|
||||
async def GetRunStatus(RunId: int):
|
||||
"""查询评查运行状态。"""
|
||||
run = await self.AuditService.GetRunStatus(RunId)
|
||||
return Result.success(data=run)
|
||||
|
||||
@self.router.get("/result/{RunId}", response_model=Result[AuditResultVO])
|
||||
async def GetResult(RunId: int):
|
||||
"""获取评查结果。"""
|
||||
result = await self.AuditService.GetResult(RunId)
|
||||
return Result.success(data=result)
|
||||
@@ -0,0 +1,5 @@
|
||||
"""认证控制器包(无鉴权)。"""
|
||||
|
||||
from fastapi import APIRouter
|
||||
|
||||
router = APIRouter()
|
||||
@@ -0,0 +1,97 @@
|
||||
"""认证控制器。
|
||||
|
||||
路由路径与旧项目完全一致:
|
||||
POST /auth/login — 统一登录(OAuth + 密码自动检测)
|
||||
POST /auth/password_login — 账密登录
|
||||
|
||||
响应格式按新项目规范使用 Result。
|
||||
"""
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
|
||||
from fastapi_common.fastapi_common_web.controller import BaseController
|
||||
from fastapi_common.fastapi_common_web.domain.responses import Result
|
||||
from fastapi_common.fastapi_common_logger import logger
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.domian.Dto.auth.loginDto import PasswordLoginDTO, OAuthLoginDTO
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.auth.loginTokenVo import LoginTokenVO
|
||||
from fastapi_modules.fastapi_leaudit.services import IAuthService
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.authServiceImpl import AuthServiceImpl
|
||||
|
||||
|
||||
class AuthController(BaseController):
|
||||
"""认证控制器。"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(prefix="/auth", tags=["认证"])
|
||||
self.AuthService: IAuthService = AuthServiceImpl()
|
||||
|
||||
@self.router.post("/login")
|
||||
async def Login(RequestObj: Request):
|
||||
"""统一登录接口。
|
||||
|
||||
自动检测登录方式:
|
||||
- 含 userInfo.sub → OAuth 登录
|
||||
- 含 username + password → 密码登录
|
||||
"""
|
||||
try:
|
||||
requestData = await RequestObj.json()
|
||||
|
||||
if "userInfo" in requestData and isinstance(requestData["userInfo"], dict) and "sub" in requestData["userInfo"]:
|
||||
logger.info("检测到 OAuth 登录请求")
|
||||
ui = requestData["userInfo"]
|
||||
vo = await self.AuthService.OAuthLogin(
|
||||
Sub=ui["sub"],
|
||||
Username=ui.get("username"),
|
||||
Nickname=ui.get("nickname"),
|
||||
Email=ui.get("email"),
|
||||
PhoneNumber=ui.get("phone_number"),
|
||||
OuId=ui.get("ou_id"),
|
||||
OuName=ui.get("ou_name"),
|
||||
IsLeader=ui.get("is_leader"),
|
||||
Area=requestData.get("area"),
|
||||
ExpiresIn=requestData.get("expiresIn", 3600),
|
||||
)
|
||||
elif "username" in requestData and "password" in requestData:
|
||||
logger.info(f"检测到密码登录请求 - username={requestData['username']}")
|
||||
vo = await self.AuthService.PasswordLogin(
|
||||
Sub=requestData["username"],
|
||||
Password=requestData["password"],
|
||||
)
|
||||
else:
|
||||
return JSONResponse(status_code=400, content={"code": 400, "message": "无效的登录请求格式", "data": None})
|
||||
|
||||
return JSONResponse(status_code=200, content={
|
||||
"code": 200,
|
||||
"message": "ok",
|
||||
"data": {
|
||||
"access_token": vo.access_token,
|
||||
"token_type": vo.token_type,
|
||||
"expires_in": vo.expires_in,
|
||||
"issued_time": vo.issued_time,
|
||||
"user_info": vo.user_info,
|
||||
},
|
||||
})
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"登录失败: {e}")
|
||||
return JSONResponse(status_code=401, content={
|
||||
"code": 401, "message": str(e), "data": None,
|
||||
})
|
||||
|
||||
@self.router.post("/password_login")
|
||||
async def PasswordLogin(RequestObj: Request):
|
||||
"""账密登录。校验 sso_users 表 sub + password。"""
|
||||
try:
|
||||
requestData = await RequestObj.json()
|
||||
dto = PasswordLoginDTO(**requestData)
|
||||
vo = await self.AuthService.PasswordLogin(Sub=dto.sub, Password=dto.password)
|
||||
return JSONResponse(status_code=200, content={
|
||||
"code": 200, "message": "ok", "data": vo.model_dump(),
|
||||
})
|
||||
except Exception as e:
|
||||
logger.error(f"密码登录失败: {e}")
|
||||
return JSONResponse(status_code=401, content={
|
||||
"code": 401, "message": str(e), "data": None,
|
||||
})
|
||||
@@ -0,0 +1,11 @@
|
||||
"""评查 DTO(仅控制器层使用)。"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AuditRunDTO(BaseModel):
|
||||
"""触发评查请求。"""
|
||||
|
||||
documentId: int = Field(..., description="文档ID")
|
||||
ruleType: str | None = Field(None, description="指定规则类型编码")
|
||||
force: bool = Field(False, description="是否强制重跑")
|
||||
@@ -0,0 +1,31 @@
|
||||
"""认证 DTO(仅控制器层使用)。"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class PasswordLoginDTO(BaseModel):
|
||||
"""账密登录请求。"""
|
||||
|
||||
sub: str = Field(..., description="账号")
|
||||
password: str = Field(..., description="密码")
|
||||
|
||||
|
||||
class OAuthUserInfo(BaseModel):
|
||||
"""OAuth 用户信息。"""
|
||||
|
||||
sub: str = Field(..., description="IDaaS 用户唯一标识")
|
||||
username: str | None = Field(None, description="用户名/工号")
|
||||
nickname: str | None = Field(None, description="用户昵称")
|
||||
email: str | None = Field(None, description="邮箱")
|
||||
phone_number: str | None = Field(None, description="手机号")
|
||||
ou_id: str | None = Field(None, description="组织单位ID")
|
||||
ou_name: str | None = Field(None, description="组织单位名称")
|
||||
is_leader: bool | None = Field(False, description="是否为负责人")
|
||||
|
||||
|
||||
class OAuthLoginDTO(BaseModel):
|
||||
"""OAuth 登录请求。"""
|
||||
|
||||
userInfo: OAuthUserInfo = Field(..., description="OAuth 用户信息")
|
||||
expiresIn: int = Field(..., description="OAuth token 过期时间(秒)")
|
||||
area: str | None = Field(None, description="用户所属地区")
|
||||
@@ -0,0 +1,33 @@
|
||||
"""评查 VO。"""
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class AuditRunVO(BaseModel):
|
||||
"""评查运行响应。"""
|
||||
|
||||
runId: int = Field(..., description="运行ID")
|
||||
documentId: int = Field(..., description="文档ID")
|
||||
runNo: int = Field(..., description="执行序号")
|
||||
status: str = Field(..., description="状态")
|
||||
phase: str | None = Field(None, description="draft/executed")
|
||||
totalScore: float | None = Field(None, description="总分")
|
||||
passedCount: int | None = Field(None, description="通过数")
|
||||
failedCount: int | None = Field(None, description="失败数")
|
||||
startedAt: datetime | None = Field(None, description="开始时间")
|
||||
finishedAt: datetime | None = Field(None, description="结束时间")
|
||||
|
||||
|
||||
class AuditResultVO(BaseModel):
|
||||
"""评查结果响应。"""
|
||||
|
||||
runId: int = Field(..., description="运行ID")
|
||||
totalScore: float | None = Field(None, description="总分")
|
||||
passedCount: int = Field(0, description="通过数")
|
||||
failedCount: int = Field(0, description="失败数")
|
||||
skippedCount: int = Field(0, description="跳过数")
|
||||
phase: str | None = Field(None, description="draft/executed")
|
||||
rescueApplied: bool = Field(False, description="是否执行 rescue")
|
||||
rules: list[dict] = Field(default_factory=list, description="规则结果列表")
|
||||
@@ -0,0 +1,13 @@
|
||||
"""认证 VO(控制器层 + 服务层使用)。"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class LoginTokenVO(BaseModel):
|
||||
"""登录响应。"""
|
||||
|
||||
access_token: str = Field(..., description="JWT Token")
|
||||
token_type: str = Field("Bearer", description="Token 类型")
|
||||
expires_in: int = Field(..., description="Token 过期时间(秒)")
|
||||
issued_time: str = Field(..., description="签发时间 YYYY-MM-DD HH:MM:SS")
|
||||
user_info: dict = Field(..., description="用户信息")
|
||||
@@ -0,0 +1,26 @@
|
||||
"""规则 VO。"""
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RuleSetVO(BaseModel):
|
||||
"""规则集响应。"""
|
||||
|
||||
id: int = Field(..., description="规则集ID")
|
||||
ruleType: str = Field(..., description="业务规则类型编码")
|
||||
ruleName: str = Field(..., description="规则集名称")
|
||||
domainType: str | None = Field(None, description="域类型")
|
||||
currentVersionId: int | None = Field(None, description="当前激活版本ID")
|
||||
status: str = Field(..., description="draft/active/inactive/archived")
|
||||
|
||||
|
||||
class RuleVersionVO(BaseModel):
|
||||
"""规则版本响应。"""
|
||||
|
||||
id: int = Field(..., description="版本ID")
|
||||
ruleSetId: int = Field(..., description="所属规则集ID")
|
||||
versionNo: str = Field(..., description="版本号")
|
||||
status: str = Field(..., description="draft/validated/published/rolled_back")
|
||||
ossUrl: str = Field(..., description="YAML 文件 OSS 地址")
|
||||
changeNote: str | None = Field(None, description="变更说明")
|
||||
publishedAt: str | None = Field(None, description="发布时间")
|
||||
@@ -0,0 +1,82 @@
|
||||
"""leaudit bridge — use leaudit's full pipeline with docauditai's database storage.
|
||||
|
||||
Directly calls leaudit's OCR → extraction → evaluation pipeline
|
||||
and persists results into docauditai's PostgreSQL via PostgREST.
|
||||
|
||||
Configuration switch (in env.{port}):
|
||||
PIPELINE_MODE=leaudit → use leaudit pipeline
|
||||
"""
|
||||
|
||||
from leaudit_bridge.client_factory import (
|
||||
create_ocr_client,
|
||||
create_llm_client,
|
||||
create_vlm_client,
|
||||
)
|
||||
from leaudit_bridge.ocr_bridge import BridgeOCRClient
|
||||
from leaudit_bridge.pipeline import LauditPipeline, PipelineResult
|
||||
from leaudit_bridge.rules_loader import RulesLoader
|
||||
from leaudit_bridge.storage_adapter import StorageAdapter
|
||||
|
||||
|
||||
def is_leaudit_mode() -> bool:
|
||||
"""Check if the system is configured to use the leaudit pipeline."""
|
||||
from core.config import PIPELINE_MODE
|
||||
return PIPELINE_MODE == "leaudit"
|
||||
|
||||
|
||||
def create_pipeline(rules_path: str | None = None) -> LauditPipeline:
|
||||
"""Create a fully configured LauditPipeline from current config.
|
||||
|
||||
Wraps the raw OCR client with DocNormalizationAdapter so that a single
|
||||
``.ocr()`` call produces a fully enriched OcrResult with:
|
||||
- Document classification (type_id + rules_file_path)
|
||||
- Dossier segmentation (sub-document page mapping)
|
||||
- Seal/signature enrichment (text, seal_id, party_id)
|
||||
- Normalized markdown (seal blocks + page separators)
|
||||
|
||||
Args:
|
||||
rules_path: If provided, forces the adapter to use this rules file
|
||||
for classification and segmentation. When None, the adapter
|
||||
uses the RulesFileRegistry to classify from document content,
|
||||
enabling auto-detection of sub-types (e.g. 行政许可 variants).
|
||||
"""
|
||||
from pathlib import Path
|
||||
from leaudit.doc_normalization.adapter import DocNormalizationAdapter
|
||||
from leaudit.doc_normalization.doc_classifier import RulesFileRegistry
|
||||
|
||||
raw_ocr = create_ocr_client()
|
||||
llm_client = create_llm_client()
|
||||
vlm_client = create_vlm_client()
|
||||
|
||||
# Build registry from rules/ directory for content-based classification
|
||||
registry = None
|
||||
if rules_path is None:
|
||||
rules_dir = Path(__file__).resolve().parents[1] / "rules"
|
||||
if rules_dir.is_dir():
|
||||
registry = RulesFileRegistry.from_directory(rules_dir)
|
||||
|
||||
ocr_client = DocNormalizationAdapter(
|
||||
ocr_client=raw_ocr,
|
||||
registry=registry,
|
||||
llm_client=llm_client,
|
||||
vlm_client=vlm_client,
|
||||
force_rules_path=rules_path,
|
||||
)
|
||||
ocr_client = BridgeOCRClient(ocr_client, vlm_client=vlm_client)
|
||||
|
||||
return LauditPipeline(
|
||||
ocr_client=ocr_client,
|
||||
llm_client=llm_client,
|
||||
)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"LauditPipeline",
|
||||
"PipelineResult",
|
||||
"StorageAdapter",
|
||||
"RulesLoader",
|
||||
"create_ocr_client",
|
||||
"create_llm_client",
|
||||
"create_pipeline",
|
||||
"is_leaudit_mode",
|
||||
]
|
||||
@@ -0,0 +1,128 @@
|
||||
"""Extract case number (案件编号) from leaudit OcrResult.
|
||||
|
||||
Port of docauditai's ``extract_case_number_for_regex`` adapted for leaudit's
|
||||
OcrResult model. Uses regex patterns first; falls back to LLM when available.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from leaudit.ocr.models import OcrResult
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
# Regex patterns for case number extraction
|
||||
_PATTERNS: list[tuple[str, str]] = [
|
||||
# Direct match: "案件编号:梅烟专罚〔2024〕第XX号"
|
||||
(r"案件编号[::]\s*(.*?)(?:\n|$)", "direct"),
|
||||
# From 卷宗/卷 宗 header: extract content between 卷宗 and 案由
|
||||
(r"卷\s*宗([\s\S]*?)案\s*由", "file_content"),
|
||||
# Standalone pattern: e.g. 梅烟专罚〔2024〕第001号
|
||||
(r"[\u4e00-\u9fa5]{2,8}[专罚处决][〔(\(\[]\d{4}[〕)\)\]][\u4e00-\u9fa5]*\d+号", "standalone"),
|
||||
# Dossier number: e.g. "2024 年度 郁烟 第 71 号"
|
||||
(r"\d{4}\s*年度\s*[\u4e00-\u9fa5]{1,6}\s*第\s*\d+\s*号", "dossier"),
|
||||
]
|
||||
|
||||
# Chinese number format within 卷宗 extracted text
|
||||
_YEAR_NUMBER_RE = re.compile(r"\d{4}[\u4e00-\u9fa5]+\d+号")
|
||||
|
||||
|
||||
def extract_case_number(ocr_result: OcrResult) -> str | None:
|
||||
"""Extract case number from OCR result using regex patterns.
|
||||
|
||||
Searches across all pages but prioritizes early pages where case numbers
|
||||
typically appear (封面, 卷宗封面).
|
||||
|
||||
Args:
|
||||
ocr_result: OCR result with pages containing text.
|
||||
|
||||
Returns:
|
||||
Extracted case number string, or None if not found.
|
||||
"""
|
||||
# Build text from first few pages (case numbers appear early)
|
||||
pages_to_check = ocr_result.pages[:5] if len(ocr_result.pages) > 5 else ocr_result.pages
|
||||
text = "\n".join(p.text for p in pages_to_check)
|
||||
|
||||
if not text.strip():
|
||||
return None
|
||||
|
||||
for pattern, ptype in _PATTERNS:
|
||||
match = re.search(pattern, text)
|
||||
if not match:
|
||||
continue
|
||||
|
||||
if ptype == "direct":
|
||||
return match.group(1).strip()
|
||||
|
||||
if ptype == "file_content":
|
||||
content = match.group(1).strip()
|
||||
num_match = _YEAR_NUMBER_RE.search(content)
|
||||
if num_match:
|
||||
return num_match.group()
|
||||
|
||||
if ptype == "standalone":
|
||||
return match.group()
|
||||
|
||||
if ptype == "dossier":
|
||||
return match.group()
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def extract_case_number_with_llm(
|
||||
ocr_result: OcrResult,
|
||||
llm_client: Any = None,
|
||||
) -> str | None:
|
||||
"""Extract case number using regex first, then LLM fallback.
|
||||
|
||||
Args:
|
||||
ocr_result: OCR result with pages containing text.
|
||||
llm_client: Optional LLM client for fallback extraction.
|
||||
|
||||
Returns:
|
||||
Extracted case number string, or None if not found.
|
||||
"""
|
||||
# Try regex first (fast, no API call)
|
||||
result = extract_case_number(ocr_result)
|
||||
if result:
|
||||
return result
|
||||
|
||||
# LLM fallback
|
||||
if llm_client is None:
|
||||
return None
|
||||
|
||||
text = "\n".join(p.text for p in ocr_result.pages[:5])
|
||||
if not text.strip():
|
||||
return None
|
||||
|
||||
try:
|
||||
from leaudit.llm.base import BaseLLMClient, LLMRequest, LLMMessage
|
||||
if not isinstance(llm_client, BaseLLMClient):
|
||||
return None
|
||||
|
||||
prompt = (
|
||||
"请从以下法律文书文本中提取案件编号。"
|
||||
"案件编号通常格式如:梅烟专罚〔2024〕第001号。\n"
|
||||
"只返回JSON: {\"case_number\": \"案件编号\"}\n\n"
|
||||
f"文本:\n{text[:2000]}"
|
||||
)
|
||||
request = LLMRequest(
|
||||
messages=[LLMMessage(role="user", content=prompt)],
|
||||
response_format={"type": "json_object"},
|
||||
max_tokens=512,
|
||||
)
|
||||
response = await llm_client.complete(request)
|
||||
|
||||
import json
|
||||
parsed = json.loads(response.content)
|
||||
case_number = parsed.get("case_number")
|
||||
if case_number and case_number != "未找到":
|
||||
if re.search(r"[\u4e00-\u9fa5]|[\d()]", case_number):
|
||||
return case_number
|
||||
|
||||
except Exception as e:
|
||||
log.warning("LLM case number extraction failed: %s", e)
|
||||
|
||||
return None
|
||||
@@ -0,0 +1,80 @@
|
||||
"""Create leaudit OCR/LLM/VLM clients from docauditai's env.{port} config."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.config import (
|
||||
OCR_CONFIG,
|
||||
DEFAULT_BASE_URL,
|
||||
DEFAULT_LLM_MODEL,
|
||||
DEFAULT_API_KEY,
|
||||
DEFAULT_VLM_BASE_URL,
|
||||
DEFAULT_VLM_MODEL,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from leaudit.llm.base import BaseLLMClient
|
||||
from leaudit.llm.vlm_base import BaseVLMClient
|
||||
from leaudit.ocr.base import BaseOCRClient
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def create_ocr_client() -> BaseOCRClient:
|
||||
"""Create a leaudit ChandraOCRClient from LEAUDIT_OCR_URL config."""
|
||||
import os
|
||||
from leaudit.ocr.chandra_client import ChandraOCRClient
|
||||
|
||||
base_url = os.getenv("LEAUDIT_OCR_URL", "").rstrip("/")
|
||||
if not base_url:
|
||||
base_url = OCR_CONFIG["API_URL"].rsplit("/api/v1/ocr", 1)[0]
|
||||
timeout = float(OCR_CONFIG["TIMEOUT"])
|
||||
|
||||
client = ChandraOCRClient(
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
include_images=True,
|
||||
)
|
||||
log.info("leaudit OCR client created: %s (timeout=%ss)", base_url, timeout)
|
||||
return client
|
||||
|
||||
|
||||
def create_llm_client() -> BaseLLMClient:
|
||||
"""Create a leaudit OpenAICompatibleClient from docauditai's LLM config."""
|
||||
from leaudit.llm.openai_client import OpenAICompatibleClient
|
||||
|
||||
base_url = DEFAULT_BASE_URL
|
||||
model = DEFAULT_LLM_MODEL
|
||||
api_key = DEFAULT_API_KEY or "no-key"
|
||||
|
||||
client = OpenAICompatibleClient(
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
default_model=model,
|
||||
timeout=120.0,
|
||||
)
|
||||
log.info("leaudit LLM client created: %s (model=%s)", base_url, model)
|
||||
return client
|
||||
|
||||
|
||||
def create_vlm_client() -> BaseVLMClient | None:
|
||||
"""Create a leaudit QwenVLMClient from docauditai's VLM config."""
|
||||
from leaudit.llm.qwen_vlm_client import QwenVLMClient
|
||||
|
||||
base_url = DEFAULT_VLM_BASE_URL
|
||||
model = DEFAULT_VLM_MODEL
|
||||
api_key = DEFAULT_API_KEY or "no-key"
|
||||
|
||||
if not base_url or not model:
|
||||
log.info("leaudit VLM client skipped: no VLM config")
|
||||
return None
|
||||
|
||||
client = QwenVLMClient(
|
||||
base_url=base_url,
|
||||
api_key=api_key,
|
||||
model=model,
|
||||
)
|
||||
log.info("leaudit VLM client created: %s (model=%s)", base_url, model)
|
||||
return client
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Build leaudit execution context from docauditai document data.
|
||||
|
||||
Currently leaudit's pipeline in docauditai bypasses leaudit's own
|
||||
``AuditCtx`` / ``AuditService`` and calls engine modules directly.
|
||||
This module encapsulates the pre-execution setup that currently lives
|
||||
inlined in ``pipeline.py`` and ``tasks.py``:
|
||||
|
||||
- Resolve local file path (download from OSS to temp if needed)
|
||||
- Determine RulesFile (from document metadata, type binding, or
|
||||
content classification)
|
||||
- Prepare OCR/LLM/VLM client references
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from leaudit.dsl.schema import RulesFile
|
||||
from leaudit.llm.base import BaseLLMClient
|
||||
from leaudit.ocr.base import BaseOCRClient
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from leaudit.llm.vlm_base import BaseVLMClient
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExecutionContext:
|
||||
"""Everything leaudit needs to run for one document."""
|
||||
|
||||
document_id: int
|
||||
file_path: Path
|
||||
rules_file: RulesFile
|
||||
ocr_client: BaseOCRClient
|
||||
llm_client: BaseLLMClient | None = None
|
||||
vlm_client: object | None = None
|
||||
source_port: int = 8000
|
||||
tmp_path: Path | None = None
|
||||
metadata: dict = field(default_factory=dict)
|
||||
|
||||
def cleanup(self) -> None:
|
||||
"""Remove temporary file if one was created."""
|
||||
if self.tmp_path is not None:
|
||||
try:
|
||||
os.remove(self.tmp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
|
||||
class CtxBuilder:
|
||||
"""Build :class:`ExecutionContext` from docauditai document data.
|
||||
|
||||
Handles the glue between docauditai's document model and leaudit's
|
||||
execution expectations — primarily file-path resolution and rules
|
||||
selection.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ocr_client: BaseOCRClient | None = None,
|
||||
llm_client: BaseLLMClient | None = None,
|
||||
vlm_client: object | None = None,
|
||||
) -> None:
|
||||
self.ocr_client = ocr_client
|
||||
self.llm_client = llm_client
|
||||
self.vlm_client = vlm_client
|
||||
|
||||
async def build(
|
||||
self,
|
||||
document_id: int,
|
||||
file_path: str | Path | None = None,
|
||||
file_content: bytes | None = None,
|
||||
filename: str | None = None,
|
||||
rules_file: RulesFile | None = None,
|
||||
*,
|
||||
source_port: int = 8000,
|
||||
) -> ExecutionContext:
|
||||
"""Build a ready-to-use execution context.
|
||||
|
||||
At least one of *file_path* or (*file_content* + *filename*)
|
||||
must be provided.
|
||||
|
||||
Args:
|
||||
document_id: docauditai document ID.
|
||||
file_path: Existing local path to the document file.
|
||||
file_content: Raw bytes (from DB or OSS) — a temp file is
|
||||
created.
|
||||
filename: Required when *file_content* is given.
|
||||
rules_file: Pre-loaded RulesFile. When None, the caller
|
||||
must resolve after OCR classification.
|
||||
source_port: Instance port.
|
||||
|
||||
Returns:
|
||||
ExecutionContext ready for pipeline.run().
|
||||
"""
|
||||
tmp_path: Path | None = None
|
||||
|
||||
if file_path is not None:
|
||||
resolved = Path(file_path)
|
||||
elif file_content is not None and filename is not None:
|
||||
suffix = self._suffix(filename)
|
||||
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
|
||||
tmp.write(file_content)
|
||||
tmp.close()
|
||||
resolved = Path(tmp.name)
|
||||
tmp_path = resolved
|
||||
else:
|
||||
raise ValueError(
|
||||
"Either file_path or (file_content + filename) is required"
|
||||
)
|
||||
|
||||
return ExecutionContext(
|
||||
document_id=document_id,
|
||||
file_path=resolved,
|
||||
rules_file=rules_file, # type: ignore[arg-type]
|
||||
ocr_client=self.ocr_client, # type: ignore[arg-type]
|
||||
llm_client=self.llm_client,
|
||||
vlm_client=self.vlm_client,
|
||||
source_port=source_port,
|
||||
tmp_path=tmp_path,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _suffix(filename: str) -> str:
|
||||
_, ext = os.path.splitext(filename)
|
||||
return ext if ext else ".pdf"
|
||||
@@ -0,0 +1,272 @@
|
||||
"""Bridge-side OCR post-processing for leaudit integration.
|
||||
|
||||
Keeps docauditai-specific fixes outside ``services/leaudit/**``:
|
||||
- DOCX embedded-image visuals can be refined once more with the VLM after
|
||||
the merged ``OcrResult`` is built.
|
||||
- Cross-page seals with missing completeness flags are normalized so the
|
||||
legacy compatibility checks have a stable shape to consume.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
|
||||
from leaudit.ocr.base import BaseOCRClient
|
||||
from leaudit.ocr.models import OcrResult, VisualManifestItem
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BridgeOCRClient(BaseOCRClient):
|
||||
"""Wrap an OCR client and apply integration-side post-processing."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
inner: BaseOCRClient,
|
||||
*,
|
||||
vlm_client: object | None = None,
|
||||
vlm_concurrency: int = 6,
|
||||
) -> None:
|
||||
self.inner = inner
|
||||
self.vlm_client = vlm_client
|
||||
self.vlm_concurrency = vlm_concurrency
|
||||
|
||||
async def ocr(self, file_path: Path | str) -> OcrResult:
|
||||
path = Path(file_path)
|
||||
result = await self.inner.ocr(path)
|
||||
await postprocess_ocr_result(
|
||||
result,
|
||||
file_path=path,
|
||||
vlm_client=self.vlm_client,
|
||||
vlm_concurrency=self.vlm_concurrency,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
async def postprocess_ocr_result(
|
||||
ocr_result: OcrResult,
|
||||
*,
|
||||
file_path: Path,
|
||||
vlm_client: object | None = None,
|
||||
vlm_concurrency: int = 6,
|
||||
) -> OcrResult:
|
||||
"""Apply bridge-side visual repairs without touching leaudit core."""
|
||||
suffix = file_path.suffix.lower()
|
||||
if suffix not in {".docx", ".doc", ".wps"}:
|
||||
return ocr_result
|
||||
|
||||
await _maybe_refine_docx_visuals(
|
||||
ocr_result,
|
||||
vlm_client=vlm_client,
|
||||
concurrency=vlm_concurrency,
|
||||
)
|
||||
await _inject_docx_signature_candidates(
|
||||
ocr_result,
|
||||
vlm_client=vlm_client,
|
||||
)
|
||||
_normalize_cross_page_seals(ocr_result)
|
||||
return ocr_result
|
||||
|
||||
|
||||
async def _maybe_refine_docx_visuals(
|
||||
ocr_result: OcrResult,
|
||||
*,
|
||||
vlm_client: object | None,
|
||||
concurrency: int,
|
||||
) -> None:
|
||||
vm = ocr_result.visual_manifest
|
||||
if vlm_client is None or vm is None:
|
||||
return
|
||||
if not (vm.seals or vm.signatures or vm.cross_page_seals):
|
||||
return
|
||||
|
||||
try:
|
||||
from leaudit.ocr.visual_classifier import refine_visual_manifest
|
||||
|
||||
await refine_visual_manifest(
|
||||
ocr_result,
|
||||
vlm_client,
|
||||
concurrency=concurrency,
|
||||
)
|
||||
except Exception as exc:
|
||||
log.warning("bridge visual refinement skipped: %s", exc)
|
||||
|
||||
|
||||
async def _inject_docx_signature_candidates(
|
||||
ocr_result: OcrResult,
|
||||
*,
|
||||
vlm_client: object | None,
|
||||
) -> None:
|
||||
"""Probe likely handwritten-signature zones on DOCX parent images."""
|
||||
if vlm_client is None:
|
||||
return
|
||||
|
||||
try:
|
||||
from PIL import Image
|
||||
except ImportError:
|
||||
log.warning("Pillow unavailable, skip DOCX signature candidate probing")
|
||||
return
|
||||
|
||||
parent_to_items: dict[str, list[VisualManifestItem]] = {}
|
||||
for bucket in (
|
||||
ocr_result.visual_manifest.seals or [],
|
||||
ocr_result.visual_manifest.signatures or [],
|
||||
ocr_result.visual_manifest.cross_page_seals or [],
|
||||
):
|
||||
for item in bucket:
|
||||
parent_key = getattr(item, "parent_image_key", None)
|
||||
if parent_key:
|
||||
parent_to_items.setdefault(parent_key, []).append(item)
|
||||
|
||||
for parent_key, items in parent_to_items.items():
|
||||
if any((it.label or "") == "signature" for it in items):
|
||||
continue
|
||||
parent_bytes = ocr_result.get_image_bytes(parent_key)
|
||||
if not parent_bytes:
|
||||
continue
|
||||
|
||||
try:
|
||||
image = Image.open(BytesIO(parent_bytes))
|
||||
except Exception as exc:
|
||||
log.warning("failed to open parent image %s: %s", parent_key, exc)
|
||||
continue
|
||||
|
||||
width, height = image.size
|
||||
for candidate_bbox in _signature_candidate_boxes(items, width, height):
|
||||
try:
|
||||
crop = image.crop(tuple(candidate_bbox))
|
||||
buf = BytesIO()
|
||||
crop.save(buf, format="PNG")
|
||||
result = await _classify_signature_candidate(
|
||||
vlm_client,
|
||||
buf.getvalue(),
|
||||
"这是合同签章页里疑似法人签名的候选区域,请优先判断是否为手写签名。",
|
||||
)
|
||||
except Exception as exc:
|
||||
log.warning("signature probe failed for %s: %s", parent_key, exc)
|
||||
continue
|
||||
|
||||
if getattr(result, "kind", None) != "signature":
|
||||
continue
|
||||
|
||||
page_num = _infer_parent_page_num(items)
|
||||
ocr_result.visual_manifest.signatures.append(
|
||||
VisualManifestItem(
|
||||
page_num=page_num,
|
||||
bbox=candidate_bbox,
|
||||
label="signature",
|
||||
confidence=getattr(result, "confidence", 0.9) or 0.9,
|
||||
text_match=(getattr(result, "text", None) or "").strip() or None,
|
||||
alt_text="docx_signature_candidate",
|
||||
image_key=parent_key,
|
||||
parent_image_key=parent_key,
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
|
||||
async def _classify_signature_candidate(
|
||||
vlm_client: object,
|
||||
image_bytes: bytes,
|
||||
user_hint: str,
|
||||
) -> object:
|
||||
"""Classify with one retry using a fresh VLM client when needed."""
|
||||
try:
|
||||
return await vlm_client.classify_visual(image_bytes, user_hint=user_hint)
|
||||
except Exception as exc:
|
||||
log.warning("signature probe primary VLM failed, retrying fresh client: %s", exc)
|
||||
|
||||
try:
|
||||
from leaudit.llm.qwen_vlm_client import QwenVLMClient
|
||||
|
||||
fresh = QwenVLMClient(
|
||||
base_url=getattr(vlm_client, "base_url"),
|
||||
api_key=getattr(vlm_client, "api_key", ""),
|
||||
model=getattr(vlm_client, "model"),
|
||||
timeout=getattr(vlm_client, "timeout", 90.0),
|
||||
)
|
||||
try:
|
||||
return await fresh.classify_visual(image_bytes, user_hint=user_hint)
|
||||
finally:
|
||||
await fresh.close()
|
||||
except Exception as exc:
|
||||
raise RuntimeError(exc) from exc
|
||||
|
||||
|
||||
def _signature_candidate_boxes(
|
||||
items: list[VisualManifestItem],
|
||||
width: int,
|
||||
height: int,
|
||||
) -> list[list[int]]:
|
||||
candidates: list[list[int]] = []
|
||||
seen: set[tuple[int, int, int, int]] = set()
|
||||
|
||||
for item in items:
|
||||
seal_type = getattr(item, "seal_type", None)
|
||||
label = getattr(item, "label", None)
|
||||
bbox = getattr(item, "bbox", None) or []
|
||||
if len(bbox) != 4:
|
||||
continue
|
||||
|
||||
x1, y1, x2, y2 = bbox
|
||||
box_w = max(1, x2 - x1)
|
||||
box_h = max(1, y2 - y1)
|
||||
ratio = box_w / box_h
|
||||
|
||||
if seal_type == "法人章" or label == "法人章":
|
||||
continue
|
||||
if not (0.75 <= ratio <= 1.35):
|
||||
continue
|
||||
if box_w < width * 0.10 or box_h < height * 0.10:
|
||||
continue
|
||||
|
||||
cand = [
|
||||
max(0, int(x1 - box_w * 0.25)),
|
||||
max(0, int(y1 + box_h * 0.50)),
|
||||
min(width, int(x2 + box_w * 0.25)),
|
||||
min(height, int(y2 + box_h * 0.95)),
|
||||
]
|
||||
if cand[2] - cand[0] < 24 or cand[3] - cand[1] < 24:
|
||||
continue
|
||||
key = tuple(cand)
|
||||
if key not in seen:
|
||||
seen.add(key)
|
||||
candidates.append(cand)
|
||||
|
||||
return candidates
|
||||
|
||||
|
||||
def _infer_parent_page_num(items: list[VisualManifestItem]) -> int:
|
||||
for item in items:
|
||||
page_num = getattr(item, "page_num", None)
|
||||
if isinstance(page_num, int):
|
||||
return page_num
|
||||
return 0
|
||||
|
||||
|
||||
def _normalize_cross_page_seals(ocr_result: OcrResult) -> None:
|
||||
"""Fill obvious completeness defaults for bridge-side checks."""
|
||||
for item in ocr_result.visual_manifest.cross_page_seals or []:
|
||||
if item.pages and len(item.pages) >= 2:
|
||||
item.is_complete = True
|
||||
continue
|
||||
|
||||
bbox = item.bbox or []
|
||||
if len(bbox) == 4:
|
||||
width = max(1, bbox[2] - bbox[0])
|
||||
height = max(1, bbox[3] - bbox[1])
|
||||
ratio = width / height
|
||||
# DOCX embedded images often contain a complete round seal near the
|
||||
# page edge; Chandra may still classify it as a seam-seal half by
|
||||
# geometry. A near-square crop is a strong signal that the visible
|
||||
# stamp is already complete.
|
||||
if 0.65 <= ratio <= 1.35:
|
||||
item.is_complete = True
|
||||
continue
|
||||
|
||||
if item.is_complete is not None:
|
||||
continue
|
||||
if item.pages and len(item.pages) == 1:
|
||||
item.is_complete = False
|
||||
@@ -0,0 +1,268 @@
|
||||
"""Main leaudit pipeline orchestrator: OCR → Extract → Evaluate.
|
||||
|
||||
Uses leaudit's own pipeline directly (no conversion),
|
||||
stores results into docauditai's database via StorageAdapter.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
|
||||
from leaudit.dsl.schema import RulesFile
|
||||
from leaudit.engine.case_file_evaluator import evaluate_extraction
|
||||
from leaudit.engine.models import EvaluationResult
|
||||
from leaudit.extraction.bundle import ExtractionBundle
|
||||
from leaudit.extraction.dispatcher import dispatch_extract
|
||||
from leaudit.extraction.phase_detection import determine_phase
|
||||
from leaudit.llm.base import BaseLLMClient
|
||||
from leaudit.ocr.base import BaseOCRClient
|
||||
from leaudit.ocr.models import OcrResult
|
||||
|
||||
from leaudit_bridge.storage_adapter import StorageAdapter
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PipelineResult:
|
||||
"""Complete result from the leaudit pipeline."""
|
||||
|
||||
ocr_result: OcrResult
|
||||
extraction_bundle: ExtractionBundle
|
||||
evaluation_result: EvaluationResult
|
||||
detected_phase: str
|
||||
timing: dict[str, float] = field(default_factory=dict)
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
class LauditPipeline:
|
||||
"""Run leaudit's full OCR → extraction → evaluation pipeline.
|
||||
|
||||
Does NOT use leaudit's own SQLAlchemy storage.
|
||||
All results are written to docauditai's database via StorageAdapter.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
ocr_client: BaseOCRClient,
|
||||
llm_client: BaseLLMClient | None = None,
|
||||
storage_adapter: StorageAdapter | None = None,
|
||||
) -> None:
|
||||
self.ocr_client = ocr_client
|
||||
self.llm_client = llm_client
|
||||
self.storage = storage_adapter or StorageAdapter()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
document_id: int,
|
||||
file_path: str | Path,
|
||||
rules_file: RulesFile | None = None,
|
||||
*,
|
||||
source_port: int = 8000,
|
||||
) -> PipelineResult:
|
||||
"""Execute the full pipeline for one document.
|
||||
|
||||
Args:
|
||||
document_id: docauditai document ID for DB writes.
|
||||
file_path: Path to the document file (PDF/DOCX/etc).
|
||||
rules_file: leaudit RulesFile (parsed from YAML). When None,
|
||||
the pipeline attempts to load rules from the OCR result's
|
||||
``rules_file_path`` (set by the classifier).
|
||||
source_port: Instance port for context switching.
|
||||
|
||||
Returns:
|
||||
PipelineResult with all intermediate and final outputs.
|
||||
"""
|
||||
file_path = Path(file_path)
|
||||
errors: list[str] = []
|
||||
timing: dict[str, float] = {}
|
||||
|
||||
# --- Phase 1: Update status to Cutting ---
|
||||
await self.storage.update_document_status(document_id, "Cutting")
|
||||
|
||||
# --- Phase 2: OCR ---
|
||||
t0 = time.time()
|
||||
log.info("[%d] OCR starting: %s", document_id, file_path.name)
|
||||
ocr_result = await self._run_ocr(file_path)
|
||||
timing["ocr"] = round(time.time() - t0, 2)
|
||||
log.info(
|
||||
"[%d] OCR done: %d pages, %.1fs",
|
||||
document_id,
|
||||
len(ocr_result.pages),
|
||||
timing["ocr"],
|
||||
)
|
||||
|
||||
# --- Resolve rules_file after OCR if not provided ---
|
||||
if rules_file is None:
|
||||
rules_file = await self._resolve_rules_from_ocr(ocr_result, document_id)
|
||||
if rules_file is None:
|
||||
raise ValueError(
|
||||
f"Cannot resolve rules_file for document {document_id}. "
|
||||
"Neither passed explicitly nor classified from OCR content."
|
||||
)
|
||||
|
||||
# --- Save OCR result ---
|
||||
await self.storage.save_ocr_result(document_id, ocr_result)
|
||||
|
||||
# --- Extract & save case number (案件编号) ---
|
||||
await self._extract_and_save_case_number(document_id, ocr_result)
|
||||
|
||||
# --- Phase 3: Extraction ---
|
||||
t0 = time.time()
|
||||
await self.storage.update_document_status(document_id, "Extractioning")
|
||||
log.info("[%d] Extraction starting", document_id)
|
||||
|
||||
extraction_bundle = await dispatch_extract(
|
||||
ocr_result,
|
||||
rules_file,
|
||||
llm_client=self.llm_client,
|
||||
phase="executed",
|
||||
)
|
||||
timing["extraction"] = round(time.time() - t0, 2)
|
||||
|
||||
if extraction_bundle.all_errors:
|
||||
errors.extend(extraction_bundle.all_errors)
|
||||
log.warning(
|
||||
"[%d] Extraction completed with %d errors",
|
||||
document_id,
|
||||
len(extraction_bundle.all_errors),
|
||||
)
|
||||
|
||||
log.info(
|
||||
"[%d] Extraction done: %d fields, %.1fs",
|
||||
document_id,
|
||||
len(extraction_bundle.fields),
|
||||
timing["extraction"],
|
||||
)
|
||||
|
||||
# --- Save extraction result ---
|
||||
await self.storage.save_extraction_result(document_id, extraction_bundle)
|
||||
|
||||
# --- Resolve field positions from OCR chunks ---
|
||||
from leaudit.extraction.coordinate_resolver import resolve_bundle_positions
|
||||
resolve_bundle_positions(extraction_bundle, ocr_result)
|
||||
positioned_count = sum(
|
||||
1 for fv in extraction_bundle.fields.values() if fv.position is not None
|
||||
)
|
||||
log.info(
|
||||
"[%d] Coordinate resolution: %d/%d fields positioned",
|
||||
document_id,
|
||||
positioned_count,
|
||||
len(extraction_bundle.fields),
|
||||
)
|
||||
|
||||
# --- Phase 4: Phase detection ---
|
||||
visual_manifest = extraction_bundle.visual_manifest or ocr_result.visual_manifest
|
||||
detected_phase = await determine_phase(
|
||||
extraction_bundle.fields,
|
||||
llm_client=self.llm_client,
|
||||
visual_manifest=visual_manifest,
|
||||
)
|
||||
log.info("[%d] Detected phase: %s", document_id, detected_phase)
|
||||
|
||||
# --- Phase 5: Evaluation ---
|
||||
t0 = time.time()
|
||||
await self.storage.update_document_status(document_id, "Evaluationing")
|
||||
log.info("[%d] Evaluation starting (phase=%s)", document_id, detected_phase)
|
||||
|
||||
external_mocks: dict[str, Any] = {}
|
||||
if self.llm_client is not None:
|
||||
external_mocks["llm_client"] = self.llm_client
|
||||
external_mocks["rules_file"] = rules_file
|
||||
|
||||
evaluation_result = await evaluate_extraction(
|
||||
rules_file,
|
||||
extraction_bundle,
|
||||
visual_manifest=visual_manifest,
|
||||
phase=detected_phase,
|
||||
external_mocks=external_mocks,
|
||||
)
|
||||
timing["evaluation"] = round(time.time() - t0, 2)
|
||||
log.info(
|
||||
"[%d] Evaluation done: %d passed, %d failed, %d skipped, %.1fs",
|
||||
document_id,
|
||||
evaluation_result.passed_count,
|
||||
evaluation_result.failed_count,
|
||||
evaluation_result.skipped_count,
|
||||
timing["evaluation"],
|
||||
)
|
||||
|
||||
# --- Save evaluation results ---
|
||||
await self.storage.save_evaluation_results(
|
||||
document_id, rules_file, evaluation_result, extraction_bundle,
|
||||
)
|
||||
|
||||
# --- Phase 6: Finalize ---
|
||||
timing["total"] = round(sum(timing.values()), 2)
|
||||
await self.storage.update_document_status(document_id, "Processed")
|
||||
|
||||
log.info(
|
||||
"[%d] Pipeline complete: phase=%s, timing=%s",
|
||||
document_id,
|
||||
detected_phase,
|
||||
timing,
|
||||
)
|
||||
|
||||
return PipelineResult(
|
||||
ocr_result=ocr_result,
|
||||
extraction_bundle=extraction_bundle,
|
||||
evaluation_result=evaluation_result,
|
||||
detected_phase=detected_phase,
|
||||
timing=timing,
|
||||
errors=errors,
|
||||
)
|
||||
|
||||
async def _run_ocr(self, file_path: Path) -> OcrResult:
|
||||
"""Run OCR with error handling."""
|
||||
try:
|
||||
return await self.ocr_client.ocr(file_path)
|
||||
except Exception as e:
|
||||
log.error("OCR failed for %s: %s", file_path.name, e)
|
||||
raise
|
||||
|
||||
async def _extract_and_save_case_number(
|
||||
self, document_id: int, ocr_result: OcrResult,
|
||||
) -> None:
|
||||
"""Extract case number from OCR and write to database."""
|
||||
from leaudit_bridge.case_number_extractor import (
|
||||
extract_case_number_with_llm,
|
||||
)
|
||||
|
||||
case_number = await extract_case_number_with_llm(
|
||||
ocr_result, llm_client=self.llm_client,
|
||||
)
|
||||
if case_number:
|
||||
await self.storage.update_document_number(document_id, case_number)
|
||||
log.info("[%d] Case number: %s", document_id, case_number)
|
||||
else:
|
||||
log.info("[%d] No case number found", document_id)
|
||||
|
||||
async def _resolve_rules_from_ocr(
|
||||
self, ocr_result: OcrResult, document_id: int,
|
||||
) -> RulesFile | None:
|
||||
"""Load rules_file from OCR classification result."""
|
||||
from leaudit.dsl.loader import load_rules_file
|
||||
|
||||
rfp = ocr_result.rules_file_path
|
||||
if not rfp:
|
||||
log.warning(
|
||||
"[%d] No rules_file_path in OCR result, cannot resolve rules",
|
||||
document_id,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
rules_file = load_rules_file(rfp)
|
||||
log.info(
|
||||
"[%d] Resolved rules from classification: %s (%d rules)",
|
||||
document_id, rfp, len(rules_file.flat_rules),
|
||||
)
|
||||
return rules_file
|
||||
except Exception as e:
|
||||
log.error("[%d] Failed to load rules from %s: %s", document_id, rfp, e)
|
||||
return None
|
||||
@@ -0,0 +1,303 @@
|
||||
"""Adapt leaudit raw results into docauditai's standardized format.
|
||||
|
||||
Currently this logic is inlined in ``storage_adapter.py``
|
||||
(``_rule_result_to_row``, ``_bundle_to_extracted``, ``_ocr_to_dict``).
|
||||
This module extracts the conversion layer so storage_adapter focuses
|
||||
on persistence (the "adapter" part of result_adapter).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from leaudit.dsl.schema import Rule as DslRule
|
||||
from leaudit.dsl.schema import RulesFile
|
||||
from leaudit.engine.models import EvaluationResult, RuleResult
|
||||
from leaudit.extraction.bundle import ExtractionBundle
|
||||
from leaudit.extraction.models import FieldValue
|
||||
from leaudit.ocr.models import OcrResult
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# OCR result → dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def ocr_to_dict(ocr: OcrResult) -> dict[str, Any]:
|
||||
"""Convert OcrResult to a JSON-safe dict for storage."""
|
||||
result: dict[str, Any] = {
|
||||
"numPages": len(ocr.pages),
|
||||
"full_text": ocr.full_text,
|
||||
"pages": [],
|
||||
}
|
||||
|
||||
for page in ocr.pages:
|
||||
page_dict: dict[str, Any] = {
|
||||
"page_num": page.page_num,
|
||||
"text": page.text,
|
||||
"page_box": page.page_box,
|
||||
}
|
||||
if page.chunks:
|
||||
page_dict["chunks"] = [
|
||||
(
|
||||
{"bbox": c["bbox"], "content": c["content"], "label": c.get("label")}
|
||||
if isinstance(c, dict) and "bbox" in c and "content" in c
|
||||
else {
|
||||
"bbox": c.bbox if hasattr(c, "bbox") else None,
|
||||
"content": c.content if hasattr(c, "content") else str(c),
|
||||
"label": c.label if hasattr(c, "label") else None,
|
||||
}
|
||||
)
|
||||
for c in page.chunks
|
||||
]
|
||||
if page.bboxes:
|
||||
page_dict["bboxes"] = page.bboxes
|
||||
result["pages"].append(page_dict)
|
||||
|
||||
if ocr.visual_manifest:
|
||||
result["visual_manifest"] = ocr.visual_manifest.model_dump(mode="json")
|
||||
|
||||
if ocr.images:
|
||||
result["images"] = ocr.images
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# ExtractionBundle → dict
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def bundle_to_extracted(bundle: ExtractionBundle) -> dict[str, Any]:
|
||||
"""Convert ExtractionBundle to docauditai's extracted_results format."""
|
||||
fields: dict[str, Any] = {}
|
||||
for name, fv in bundle.fields.items():
|
||||
if isinstance(fv, FieldValue):
|
||||
field_data = {
|
||||
"value": fv.value,
|
||||
"confidence": float(fv.confidence) if fv.confidence else 0.0,
|
||||
}
|
||||
if fv.position is not None:
|
||||
field_data["position"] = fv.position.model_dump(mode="json")
|
||||
fields[name] = field_data
|
||||
else:
|
||||
fields[name] = {"value": fv}
|
||||
|
||||
multi_entity: dict[str, Any] = {}
|
||||
for name, rows in bundle.multi_entity.items():
|
||||
multi_entity[name] = [
|
||||
{
|
||||
k: (v.value if isinstance(v, FieldValue) else v)
|
||||
for k, v in row.items()
|
||||
}
|
||||
if isinstance(row, dict)
|
||||
else {"value": row}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return {
|
||||
"fields": fields,
|
||||
"multi_entity": multi_entity,
|
||||
"derived": dict(bundle.derived) if bundle.derived else {},
|
||||
"is_case_file": bundle.is_case_file,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# EvaluationResult → per-rule rows
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def rule_result_to_row(
|
||||
document_id: int,
|
||||
run_id: int,
|
||||
rule_result: RuleResult,
|
||||
rule: Any | None,
|
||||
bundle: ExtractionBundle,
|
||||
) -> dict[str, Any]:
|
||||
"""Convert one RuleResult to a database row dict.
|
||||
|
||||
Args:
|
||||
document_id: docauditai document ID.
|
||||
run_id: ``leaudit_audit_runs.id`` for this execution.
|
||||
rule_result: Single rule evaluation result from leaudit.
|
||||
rule: DSL Rule definition (for metadata lookups).
|
||||
bundle: The extraction bundle (for field position lookups).
|
||||
"""
|
||||
passed = rule_result.passed
|
||||
|
||||
pass_msg = ""
|
||||
fail_msg = ""
|
||||
if rule_result.messages:
|
||||
pass_msg = rule_result.messages.get("pass", "")
|
||||
fail_msg = rule_result.messages.get("fail", "")
|
||||
elif isinstance(rule, DslRule) and rule.messages:
|
||||
pass_msg = rule.messages.get("pass", "")
|
||||
fail_msg = rule.messages.get("fail", "")
|
||||
|
||||
relevant_fields = _extract_relevant_fields(rule, bundle)
|
||||
field_positions = _extract_relevant_field_positions(rule, bundle)
|
||||
|
||||
remediation = None
|
||||
if rule_result.remediation:
|
||||
remediation = rule_result.remediation.model_dump(mode="json")
|
||||
|
||||
rule_meta: dict[str, Any] = {}
|
||||
if isinstance(rule, DslRule):
|
||||
if rule.references_laws:
|
||||
rule_meta["references_laws"] = rule.references_laws
|
||||
if rule.desc:
|
||||
rule_meta["desc"] = rule.desc
|
||||
if rule.group:
|
||||
rule_meta["group"] = rule.group
|
||||
|
||||
return {
|
||||
"document_id": document_id,
|
||||
"run_id": run_id,
|
||||
"rule_id": rule_result.rule_id,
|
||||
"rule_name": rule_result.name,
|
||||
"risk": rule_result.risk or "medium",
|
||||
"score": rule_result.score,
|
||||
"passed": passed,
|
||||
"status": rule_result.status,
|
||||
"skip_reason": rule_result.skip_reason,
|
||||
"confidence": rule_result.confidence,
|
||||
"pass_message": pass_msg,
|
||||
"fail_message": fail_msg,
|
||||
"stages": [s.model_dump(mode="json") for s in (rule_result.stages or [])],
|
||||
"extracted_fields": relevant_fields,
|
||||
"field_positions": field_positions,
|
||||
"remediation": remediation,
|
||||
"rule_meta": rule_meta,
|
||||
"rescue_applied": False,
|
||||
"rescue_passed": None,
|
||||
}
|
||||
|
||||
|
||||
def evaluation_summary(eval_result: EvaluationResult) -> dict[str, Any]:
|
||||
"""Extract summary fields from an EvaluationResult."""
|
||||
return {
|
||||
"total_score": eval_result.total_score,
|
||||
"passed_count": eval_result.passed_count,
|
||||
"failed_count": eval_result.failed_count,
|
||||
"skipped_count": eval_result.skipped_count,
|
||||
"result_status": _result_status(eval_result),
|
||||
}
|
||||
|
||||
|
||||
def _result_status(eval_result: EvaluationResult) -> str:
|
||||
if eval_result.errors:
|
||||
return "error"
|
||||
if eval_result.failed_count == 0 and eval_result.skipped_count == 0:
|
||||
return "pass"
|
||||
if eval_result.failed_count > 0:
|
||||
return "fail"
|
||||
return "partial"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Internal helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _extract_relevant_fields(
|
||||
rule: Any, bundle: ExtractionBundle,
|
||||
) -> dict[str, Any]:
|
||||
"""Extract field values referenced by a rule's stages."""
|
||||
relevant: dict[str, Any] = {}
|
||||
if not rule or not hasattr(rule, "stages") or not rule.stages:
|
||||
return relevant
|
||||
|
||||
for stage in rule.stages:
|
||||
stage_data = (
|
||||
stage.model_dump(exclude_none=True)
|
||||
if hasattr(stage, "model_dump")
|
||||
else {}
|
||||
)
|
||||
extra = stage.extra if hasattr(stage, "extra") else {}
|
||||
field_names: list[str] = []
|
||||
|
||||
for key in ("field", "field1", "field2", "fields"):
|
||||
val = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
|
||||
if isinstance(val, list):
|
||||
field_names.extend(f for f in val if isinstance(f, str))
|
||||
elif isinstance(val, str):
|
||||
field_names.append(val)
|
||||
|
||||
pairs = stage_data.get("pairs")
|
||||
if isinstance(pairs, list):
|
||||
for pair in pairs:
|
||||
if not isinstance(pair, dict):
|
||||
continue
|
||||
for key in ("source", "target", "a", "b"):
|
||||
ref = pair.get(key)
|
||||
if isinstance(ref, str):
|
||||
field_names.append(ref)
|
||||
|
||||
prompt = stage_data.get("prompt")
|
||||
if isinstance(prompt, str):
|
||||
for m in re.finditer(r"\{\{\s*([^{}]+?)\s*\}\}", prompt):
|
||||
field_names.append(m.group(1).strip())
|
||||
|
||||
for f in field_names:
|
||||
if f in relevant:
|
||||
continue
|
||||
if f not in bundle.fields:
|
||||
continue
|
||||
fv = bundle.fields[f]
|
||||
relevant[f] = fv.value if isinstance(fv, FieldValue) else fv
|
||||
|
||||
return relevant
|
||||
|
||||
|
||||
def _extract_relevant_field_positions(
|
||||
rule: Any, bundle: ExtractionBundle,
|
||||
) -> dict[str, Any]:
|
||||
"""Extract position data for fields referenced by a rule's stages."""
|
||||
positions: dict[str, Any] = {}
|
||||
if not rule or not hasattr(rule, "stages") or not rule.stages:
|
||||
return positions
|
||||
|
||||
for stage in rule.stages:
|
||||
stage_data = (
|
||||
stage.model_dump(exclude_none=True)
|
||||
if hasattr(stage, "model_dump")
|
||||
else {}
|
||||
)
|
||||
extra = stage.extra if hasattr(stage, "extra") else {}
|
||||
field_names: list[str] = []
|
||||
|
||||
for key in ("field", "field1", "field2", "fields"):
|
||||
val = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
|
||||
if isinstance(val, list):
|
||||
field_names.extend(f for f in val if isinstance(f, str))
|
||||
elif isinstance(val, str):
|
||||
field_names.append(val)
|
||||
|
||||
pairs = stage_data.get("pairs")
|
||||
if isinstance(pairs, list):
|
||||
for pair in pairs:
|
||||
if not isinstance(pair, dict):
|
||||
continue
|
||||
for key in ("source", "target", "a", "b"):
|
||||
ref = pair.get(key)
|
||||
if isinstance(ref, str):
|
||||
field_names.append(ref)
|
||||
|
||||
prompt = stage_data.get("prompt")
|
||||
if isinstance(prompt, str):
|
||||
for m in re.finditer(r"\{\{\s*([^{}]+?)\s*\}\}", prompt):
|
||||
field_names.append(m.group(1).strip())
|
||||
|
||||
for f in field_names:
|
||||
if f in positions:
|
||||
continue
|
||||
fv = bundle.fields.get(f)
|
||||
if fv is not None and isinstance(fv, FieldValue) and fv.position is not None:
|
||||
positions[f] = fv.position.model_dump(mode="json")
|
||||
|
||||
return positions
|
||||
@@ -0,0 +1,55 @@
|
||||
"""Load leaudit YAML RulesFile from filesystem or MinIO."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from leaudit.dsl.schema import RulesFile
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
_DEFAULT_RULES_DIR = Path(__file__).resolve().parents[2] / "rules"
|
||||
|
||||
|
||||
class RulesLoader:
|
||||
"""Load and cache leaudit RulesFile from YAML files."""
|
||||
|
||||
def __init__(self, rules_dir: str | Path | None = None) -> None:
|
||||
self._rules_dir = Path(rules_dir) if rules_dir else _DEFAULT_RULES_DIR
|
||||
self._cache: dict[str, RulesFile] = {}
|
||||
|
||||
def load(self, rules_path: str) -> RulesFile:
|
||||
"""Load a RulesFile by relative path under rules_dir, or absolute path."""
|
||||
from leaudit.dsl.loader import load_rules_file
|
||||
|
||||
if rules_path in self._cache:
|
||||
return self._cache[rules_path]
|
||||
|
||||
p = Path(rules_path)
|
||||
if not p.is_absolute():
|
||||
p = self._rules_dir / p
|
||||
|
||||
log.info("Loading RulesFile: %s", p)
|
||||
rules_file = load_rules_file(p)
|
||||
self._cache[rules_path] = rules_file
|
||||
return rules_file
|
||||
|
||||
def load_from_yaml_text(self, yaml_text: str, cache_key: str | None = None) -> RulesFile:
|
||||
"""Parse a RulesFile from raw YAML string."""
|
||||
from leaudit.dsl.loader import parse_rules_yaml_text
|
||||
|
||||
if cache_key and cache_key in self._cache:
|
||||
return self._cache[cache_key]
|
||||
|
||||
rules_file = parse_rules_yaml_text(yaml_text)
|
||||
|
||||
if cache_key:
|
||||
self._cache[cache_key] = rules_file
|
||||
|
||||
return rules_file
|
||||
|
||||
def clear_cache(self) -> None:
|
||||
self._cache.clear()
|
||||
@@ -0,0 +1,364 @@
|
||||
"""Storage adapter — write leaudit results into docauditai's PostgreSQL via PostgREST.
|
||||
|
||||
Converts leaudit's OcrResult, ExtractionBundle, EvaluationResult
|
||||
into docauditai's table format and writes via PostgRESTClient.
|
||||
|
||||
Uses the new `leaudit_evaluation_results` table for per-rule results.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from typing import Any
|
||||
|
||||
from leaudit.dsl.schema import RulesFile
|
||||
from leaudit.engine.models import EvaluationResult, RuleResult
|
||||
from leaudit.extraction.bundle import ExtractionBundle
|
||||
from leaudit.extraction.models import FieldValue
|
||||
from leaudit.ocr.models import OcrResult
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _get_postgrest_client():
|
||||
"""Lazy import to avoid circular dependency at module load."""
|
||||
from core.postgrest.client import get_postgrest_client
|
||||
return get_postgrest_client()
|
||||
|
||||
|
||||
class StorageAdapter:
|
||||
"""Write leaudit pipeline results to docauditai's database."""
|
||||
|
||||
# ---- Document status ----
|
||||
|
||||
async def update_document_status(self, document_id: int, status: str) -> None:
|
||||
"""Update the document's processing status."""
|
||||
client = _get_postgrest_client()
|
||||
await client.update(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
data={"status": status},
|
||||
)
|
||||
log.debug("[%d] Status updated: %s", document_id, status)
|
||||
|
||||
# ---- Document number (案件编号) ----
|
||||
|
||||
async def update_document_number(self, document_id: int, document_number: str) -> None:
|
||||
"""Update the document's case number (document_number field)."""
|
||||
client = _get_postgrest_client()
|
||||
await client.update(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
data={"document_number": document_number},
|
||||
)
|
||||
log.info("[%d] document_number updated: %s", document_id, document_number)
|
||||
|
||||
# ---- OCR result ----
|
||||
|
||||
async def save_ocr_result(self, document_id: int, ocr_result: OcrResult) -> None:
|
||||
"""Save OCR result to documents.ocr_result and raw_full_text_original."""
|
||||
client = _get_postgrest_client()
|
||||
|
||||
ocr_dict = _ocr_to_dict(ocr_result)
|
||||
full_text = ocr_result.full_text or "\n".join(p.text for p in ocr_result.pages)
|
||||
|
||||
await client.update(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
data={
|
||||
"ocr_result": ocr_dict,
|
||||
"raw_full_text_original": full_text,
|
||||
},
|
||||
)
|
||||
log.info("[%d] OCR result saved (%d pages)", document_id, len(ocr_result.pages))
|
||||
|
||||
# ---- Extraction result ----
|
||||
|
||||
async def save_extraction_result(
|
||||
self, document_id: int, bundle: ExtractionBundle,
|
||||
) -> None:
|
||||
"""Save extraction result to documents.extracted_results."""
|
||||
client = _get_postgrest_client()
|
||||
|
||||
extracted = _bundle_to_extracted(bundle)
|
||||
|
||||
await client.update(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
data={"extracted_results": extracted},
|
||||
)
|
||||
log.info(
|
||||
"[%d] Extraction result saved (%d fields)",
|
||||
document_id,
|
||||
len(bundle.fields),
|
||||
)
|
||||
|
||||
# ---- Evaluation results ----
|
||||
|
||||
async def save_evaluation_results(
|
||||
self,
|
||||
document_id: int,
|
||||
rules_file: RulesFile,
|
||||
evaluation: EvaluationResult,
|
||||
bundle: ExtractionBundle,
|
||||
) -> None:
|
||||
"""Save evaluation results to leaudit_evaluation_results table.
|
||||
|
||||
One row per rule. Deletes existing results for the document first,
|
||||
then inserts fresh rows.
|
||||
"""
|
||||
client = _get_postgrest_client()
|
||||
|
||||
# Delete existing results for this document
|
||||
await client.delete(
|
||||
table="leaudit_evaluation_results",
|
||||
filters={"document_id": f"eq.{document_id}"},
|
||||
)
|
||||
|
||||
# Build rule_id → rule metadata lookup
|
||||
rule_meta = {}
|
||||
for rule in rules_file.flat_rules:
|
||||
rule_meta[rule.rule_id] = rule
|
||||
|
||||
# Insert one row per rule result
|
||||
for rule_result in evaluation.rules:
|
||||
rule = rule_meta.get(rule_result.rule_id)
|
||||
row = _rule_result_to_row(document_id, rule_result, rule, bundle)
|
||||
await client.insert(table="leaudit_evaluation_results", data=row)
|
||||
|
||||
log.info(
|
||||
"[%d] Evaluation results saved: %d passed, %d failed, %d skipped",
|
||||
document_id,
|
||||
evaluation.passed_count,
|
||||
evaluation.failed_count,
|
||||
evaluation.skipped_count,
|
||||
)
|
||||
|
||||
|
||||
# ---- Serialization helpers ----
|
||||
|
||||
|
||||
def _ocr_to_dict(ocr: OcrResult) -> dict[str, Any]:
|
||||
"""Convert OcrResult to a JSON-safe dict for PostgREST storage."""
|
||||
result: dict[str, Any] = {
|
||||
"numPages": len(ocr.pages),
|
||||
"full_text": ocr.full_text,
|
||||
"pages": [],
|
||||
}
|
||||
|
||||
for page in ocr.pages:
|
||||
page_dict: dict[str, Any] = {
|
||||
"page_num": page.page_num,
|
||||
"text": page.text,
|
||||
"page_box": page.page_box,
|
||||
}
|
||||
if page.chunks:
|
||||
page_dict["chunks"] = [
|
||||
(
|
||||
{"bbox": c["bbox"], "content": c["content"], "label": c.get("label")}
|
||||
if isinstance(c, dict) and "bbox" in c and "content" in c
|
||||
else {
|
||||
"bbox": c.bbox if hasattr(c, "bbox") else None,
|
||||
"content": c.content if hasattr(c, "content") else str(c),
|
||||
"label": c.label if hasattr(c, "label") else None,
|
||||
}
|
||||
)
|
||||
for c in page.chunks
|
||||
]
|
||||
if page.bboxes:
|
||||
page_dict["bboxes"] = page.bboxes
|
||||
result["pages"].append(page_dict)
|
||||
|
||||
if ocr.visual_manifest:
|
||||
result["visual_manifest"] = ocr.visual_manifest.model_dump(mode="json")
|
||||
|
||||
if ocr.images:
|
||||
result["images"] = ocr.images
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _bundle_to_extracted(bundle: ExtractionBundle) -> dict[str, Any]:
|
||||
"""Convert ExtractionBundle to docauditai's extracted_results format."""
|
||||
fields: dict[str, Any] = {}
|
||||
for name, fv in bundle.fields.items():
|
||||
if isinstance(fv, FieldValue):
|
||||
field_data = {
|
||||
"value": fv.value,
|
||||
"confidence": float(fv.confidence) if fv.confidence else 0.0,
|
||||
}
|
||||
if fv.position is not None:
|
||||
field_data["position"] = fv.position.model_dump(mode="json")
|
||||
fields[name] = field_data
|
||||
else:
|
||||
fields[name] = {"value": fv}
|
||||
|
||||
multi_entity: dict[str, Any] = {}
|
||||
for name, rows in bundle.multi_entity.items():
|
||||
multi_entity[name] = [
|
||||
{
|
||||
k: (v.value if isinstance(v, FieldValue) else v)
|
||||
for k, v in row.items()
|
||||
}
|
||||
if isinstance(row, dict) else {"value": row}
|
||||
for row in rows
|
||||
]
|
||||
|
||||
return {
|
||||
"fields": fields,
|
||||
"multi_entity": multi_entity,
|
||||
"derived": dict(bundle.derived) if bundle.derived else {},
|
||||
"is_case_file": bundle.is_case_file,
|
||||
}
|
||||
|
||||
|
||||
def _extract_relevant_fields(rule: Any, bundle: ExtractionBundle) -> dict[str, Any]:
|
||||
"""Extract field values referenced by a rule's stages."""
|
||||
relevant: dict[str, Any] = {}
|
||||
|
||||
if not rule or not hasattr(rule, "stages") or not rule.stages:
|
||||
return relevant
|
||||
|
||||
for stage in rule.stages:
|
||||
stage_data = stage.model_dump(exclude_none=True) if hasattr(stage, "model_dump") else {}
|
||||
extra = stage.extra if hasattr(stage, "extra") else {}
|
||||
field_names: list[str] = []
|
||||
|
||||
for key in ("field", "field1", "field2", "fields"):
|
||||
val = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
|
||||
if isinstance(val, list):
|
||||
field_names.extend(f for f in val if isinstance(f, str))
|
||||
elif isinstance(val, str):
|
||||
field_names.append(val)
|
||||
|
||||
pairs = stage_data.get("pairs")
|
||||
if isinstance(pairs, list):
|
||||
for pair in pairs:
|
||||
if not isinstance(pair, dict):
|
||||
continue
|
||||
for key in ("source", "target", "a", "b"):
|
||||
ref = pair.get(key)
|
||||
if isinstance(ref, str):
|
||||
field_names.append(ref)
|
||||
|
||||
prompt = stage_data.get("prompt")
|
||||
if isinstance(prompt, str):
|
||||
for m in re.finditer(r"\{\{\s*([^{}]+?)\s*\}\}", prompt):
|
||||
field_names.append(m.group(1).strip())
|
||||
|
||||
for f in field_names:
|
||||
if f in relevant:
|
||||
continue
|
||||
if f not in bundle.fields:
|
||||
continue
|
||||
fv = bundle.fields[f]
|
||||
relevant[f] = fv.value if isinstance(fv, FieldValue) else fv
|
||||
|
||||
return relevant
|
||||
|
||||
|
||||
def _extract_relevant_field_positions(
|
||||
rule: Any,
|
||||
bundle: ExtractionBundle,
|
||||
) -> dict[str, Any]:
|
||||
"""Extract position data for fields referenced by a rule's stages."""
|
||||
positions: dict[str, Any] = {}
|
||||
if not rule or not hasattr(rule, "stages") or not rule.stages:
|
||||
return positions
|
||||
|
||||
for stage in rule.stages:
|
||||
stage_data = stage.model_dump(exclude_none=True) if hasattr(stage, "model_dump") else {}
|
||||
extra = stage.extra if hasattr(stage, "extra") else {}
|
||||
field_names: list[str] = []
|
||||
|
||||
for key in ("field", "field1", "field2", "fields"):
|
||||
val = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
|
||||
if isinstance(val, list):
|
||||
field_names.extend(f for f in val if isinstance(f, str))
|
||||
elif isinstance(val, str):
|
||||
field_names.append(val)
|
||||
|
||||
pairs = stage_data.get("pairs")
|
||||
if isinstance(pairs, list):
|
||||
for pair in pairs:
|
||||
if not isinstance(pair, dict):
|
||||
continue
|
||||
for key in ("source", "target", "a", "b"):
|
||||
ref = pair.get(key)
|
||||
if isinstance(ref, str):
|
||||
field_names.append(ref)
|
||||
|
||||
prompt = stage_data.get("prompt")
|
||||
if isinstance(prompt, str):
|
||||
for m in re.finditer(r"\{\{\s*([^{}]+?)\s*\}\}", prompt):
|
||||
field_names.append(m.group(1).strip())
|
||||
|
||||
for f in field_names:
|
||||
if f in positions:
|
||||
continue
|
||||
fv = bundle.fields.get(f)
|
||||
if fv is not None and isinstance(fv, FieldValue) and fv.position is not None:
|
||||
positions[f] = fv.position.model_dump(mode="json")
|
||||
return positions
|
||||
|
||||
|
||||
def _rule_result_to_row(
|
||||
document_id: int,
|
||||
rule_result: RuleResult,
|
||||
rule: Any | None,
|
||||
bundle: ExtractionBundle,
|
||||
) -> dict[str, Any]:
|
||||
"""Convert a RuleResult to a leaudit_evaluation_results row."""
|
||||
passed = rule_result.passed
|
||||
|
||||
# Resolve messages: rule_result → rule definition
|
||||
pass_msg = ""
|
||||
fail_msg = ""
|
||||
if rule_result.messages:
|
||||
pass_msg = rule_result.messages.get("pass", "")
|
||||
fail_msg = rule_result.messages.get("fail", "")
|
||||
elif rule:
|
||||
from leaudit.dsl.schema import Rule as DslRule
|
||||
if isinstance(rule, DslRule) and rule.messages:
|
||||
pass_msg = rule.messages.get("pass", "")
|
||||
fail_msg = rule.messages.get("fail", "")
|
||||
|
||||
# Extract relevant fields
|
||||
relevant_fields = _extract_relevant_fields(rule, bundle)
|
||||
|
||||
# Remediation (if present)
|
||||
remediation = None
|
||||
if rule_result.remediation:
|
||||
remediation = rule_result.remediation.model_dump(mode="json")
|
||||
|
||||
# Rule metadata (references_laws, etc.)
|
||||
rule_meta_data: dict[str, Any] = {}
|
||||
if rule:
|
||||
from leaudit.dsl.schema import Rule as DslRule
|
||||
if isinstance(rule, DslRule):
|
||||
if rule.references_laws:
|
||||
rule_meta_data["references_laws"] = rule.references_laws
|
||||
if rule.desc:
|
||||
rule_meta_data["desc"] = rule.desc
|
||||
if rule.group:
|
||||
rule_meta_data["group"] = rule.group
|
||||
|
||||
return {
|
||||
"document_id": document_id,
|
||||
"rule_id": rule_result.rule_id,
|
||||
"rule_name": rule_result.name,
|
||||
"risk": rule_result.risk or "medium",
|
||||
"score": rule_result.score,
|
||||
"passed": passed,
|
||||
"status": rule_result.status,
|
||||
"skip_reason": rule_result.skip_reason,
|
||||
"confidence": rule_result.confidence,
|
||||
"pass_message": pass_msg,
|
||||
"fail_message": fail_msg,
|
||||
"stages": [s.model_dump(mode="json") for s in (rule_result.stages or [])],
|
||||
"extracted_fields": relevant_fields,
|
||||
"field_positions": _extract_relevant_field_positions(rule, bundle),
|
||||
"remediation": remediation,
|
||||
"rule_meta": rule_meta_data,
|
||||
}
|
||||
@@ -0,0 +1,201 @@
|
||||
"""Celery task for leaudit pipeline processing.
|
||||
|
||||
Activated when PIPELINE_MODE=leaudit in env.{port} config.
|
||||
Replaces the legacy OCR → extraction → evaluation pipeline with
|
||||
leaudit's YAML-rules-driven approach.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from core.celery_app_limited import celery_app
|
||||
from core.postgrest.client import get_postgrest_client
|
||||
from core.logger import log
|
||||
|
||||
from leaudit_bridge import create_pipeline, RulesLoader
|
||||
|
||||
|
||||
@celery_app.task(bind=True, name="leaudit.process_document")
|
||||
def leaudit_process_document(
|
||||
self,
|
||||
document_id: int,
|
||||
file_content: bytes,
|
||||
filename: str,
|
||||
upload_info: Optional[Dict[str, Any]] = None,
|
||||
source_port: Optional[int] = None,
|
||||
rules_path: Optional[str] = None,
|
||||
):
|
||||
"""Process a document using leaudit's full pipeline.
|
||||
|
||||
Steps: OCR → Extraction → Evaluation → Store in docauditai DB.
|
||||
"""
|
||||
task_id = self.request.id
|
||||
log.task.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")
|
||||
|
||||
if source_port:
|
||||
from core.utils.instance_context import set_instance_environment
|
||||
instance_name = set_instance_environment(source_port)
|
||||
log.task.info(
|
||||
f"[任务ID: {task_id}] 实例环境: {instance_name} (端口: {source_port})"
|
||||
)
|
||||
|
||||
if upload_info is None:
|
||||
upload_info = {}
|
||||
|
||||
loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(loop)
|
||||
|
||||
try:
|
||||
rules_path_resolved = rules_path or _resolve_rules_path(document_id, loop)
|
||||
|
||||
# For types with a known mapping (e.g. 行政处罚), pre-load rules_file.
|
||||
# For types that need content classification (e.g. 行政许可 sub-types),
|
||||
# rules_path will be None → adapter classifies after OCR → pipeline
|
||||
# loads rules from ocr_result.rules_file_path.
|
||||
rules_file = None
|
||||
if rules_path_resolved:
|
||||
loader = RulesLoader()
|
||||
rules_file = loader.load(rules_path_resolved)
|
||||
log.task.info(
|
||||
f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} "
|
||||
f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)"
|
||||
)
|
||||
else:
|
||||
log.task.info(
|
||||
f"[任务ID: {task_id}] No fixed rules_path — "
|
||||
"will classify from document content after OCR"
|
||||
)
|
||||
|
||||
suffix = _get_suffix(filename)
|
||||
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp:
|
||||
temp.write(file_content)
|
||||
temp_path = temp.name
|
||||
|
||||
pipeline = create_pipeline(rules_path=rules_path_resolved)
|
||||
|
||||
t0 = time.time()
|
||||
result = loop.run_until_complete(
|
||||
pipeline.run(
|
||||
document_id=document_id,
|
||||
file_path=temp_path,
|
||||
rules_file=rules_file,
|
||||
source_port=source_port or int(os.getenv("APP_PORT", "8000")),
|
||||
)
|
||||
)
|
||||
elapsed = round(time.time() - t0, 2)
|
||||
|
||||
try:
|
||||
os.remove(temp_path)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
log.task.info(
|
||||
f"[任务ID: {task_id}] leaudit管线完成: phase={result.detected_phase}, "
|
||||
f"timing={result.timing}, 总耗时={elapsed:.1f}s"
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"document_id": document_id,
|
||||
"phase": result.detected_phase,
|
||||
"timing": result.timing,
|
||||
"errors": result.errors,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
log.task.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
|
||||
try:
|
||||
loop.run_until_complete(_update_status_safe(document_id, "Failed"))
|
||||
except Exception:
|
||||
pass
|
||||
raise
|
||||
|
||||
finally:
|
||||
loop.close()
|
||||
|
||||
|
||||
# type_id → rules directory mapping (only fixed-mapping types)
|
||||
# 行政许可 (type_id=2) has 9 sub-types, NOT mapped here —
|
||||
# must come from document metadata (rules_file_path) or content classification.
|
||||
_TYPE_ID_RULES_MAP: dict[int, str] = {
|
||||
3: "行政处罚",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None:
|
||||
"""Resolve rules_path: config override → document metadata → type_id mapping."""
|
||||
from core.config import LEAUDIT_CONFIG
|
||||
|
||||
# 1. Config override (when explicitly set)
|
||||
config_path = LEAUDIT_CONFIG.get("RULES_PATH", "")
|
||||
if config_path:
|
||||
return config_path
|
||||
|
||||
try:
|
||||
client = get_postgrest_client()
|
||||
doc = loop.run_until_complete(
|
||||
client.select(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
single=True,
|
||||
)
|
||||
)
|
||||
if not doc:
|
||||
return None
|
||||
|
||||
# 2. Document-level override
|
||||
rfp = doc.get("rules_file_path")
|
||||
if rfp:
|
||||
return rfp
|
||||
|
||||
# 3. type_id mapping
|
||||
type_id = doc.get("type_id")
|
||||
if type_id and type_id in _TYPE_ID_RULES_MAP:
|
||||
return f"{_TYPE_ID_RULES_MAP[type_id]}/rules.yaml"
|
||||
except Exception as e:
|
||||
log.task.warning(f"Failed to resolve rules_path from document: {e}")
|
||||
|
||||
return None
|
||||
|
||||
|
||||
async def _update_status_safe(document_id: int, status: str) -> None:
|
||||
"""Safely update document status, ignoring errors."""
|
||||
try:
|
||||
client = get_postgrest_client()
|
||||
await client.update(
|
||||
table="documents",
|
||||
filters={"id": f"eq.{document_id}"},
|
||||
data={"status": status},
|
||||
)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def _get_suffix(filename: str) -> str:
|
||||
"""Extract file suffix from filename."""
|
||||
_, ext = os.path.splitext(filename)
|
||||
return ext if ext else ".pdf"
|
||||
|
||||
|
||||
def dispatch_leaudit_task(
|
||||
document_id: int,
|
||||
file_content: bytes,
|
||||
filename: str,
|
||||
upload_info: Optional[Dict[str, Any]] = None,
|
||||
source_port: Optional[int] = None,
|
||||
rules_path: Optional[str] = None,
|
||||
):
|
||||
"""Dispatch a leaudit processing task."""
|
||||
return leaudit_process_document.apply_async(
|
||||
args=[document_id, file_content, filename],
|
||||
kwargs={
|
||||
"upload_info": upload_info,
|
||||
"source_port": source_port or int(os.getenv("APP_PORT", "8000")),
|
||||
"rules_path": rules_path,
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,11 @@
|
||||
"""LeAudit 模型导出。"""
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.models.leauditDocument import LeauditDocument
|
||||
from fastapi_modules.fastapi_leaudit.models.leauditDocumentFile import LeauditDocumentFile
|
||||
from fastapi_modules.fastapi_leaudit.models.leauditAuditRun import LeauditAuditRun
|
||||
|
||||
__all__ = [
|
||||
"LeauditDocument",
|
||||
"LeauditDocumentFile",
|
||||
"LeauditAuditRun",
|
||||
]
|
||||
@@ -0,0 +1,77 @@
|
||||
"""JWT Token 模型 —— jwt_tokens 表。
|
||||
|
||||
记录每次签发的 Token 生命周期:签发、刷新、撤销。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BigInteger, Boolean, DateTime, String, select, update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from fastapi_common.fastapi_common_web.models import BaseModel
|
||||
|
||||
|
||||
class JwtToken(BaseModel):
|
||||
"""JWT Token 记录表。"""
|
||||
|
||||
__tablename__ = "jwt_tokens"
|
||||
|
||||
Id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
userId: Mapped[int] = mapped_column(BigInteger, comment="用户ID")
|
||||
tokenJti: Mapped[str] = mapped_column(String(128), comment="Token JTI")
|
||||
tokenHash: Mapped[str] = mapped_column(String(128), comment="Access Token SHA256")
|
||||
refreshTokenHash: Mapped[str | None] = mapped_column(String(128), comment="Refresh Token SHA256")
|
||||
tokenType: Mapped[str] = mapped_column(String(32), default="ACCESS", comment="ACCESS/REFRESH")
|
||||
deviceId: Mapped[str | None] = mapped_column(String(128))
|
||||
deviceName: Mapped[str | None] = mapped_column(String(256))
|
||||
userAgent: Mapped[str | None] = mapped_column(String(512))
|
||||
ipAddress: Mapped[str | None] = mapped_column(String(64))
|
||||
issuedAt: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
||||
expiresAt: Mapped[datetime] = mapped_column(DateTime(timezone=True))
|
||||
refreshExpiresAt: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
lastUsedAt: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
isRevoked: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
revokedAt: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
revokeReason: Mapped[str | None] = mapped_column(String(256))
|
||||
|
||||
@classmethod
|
||||
async def get_by_jti(cls, session: AsyncSession, jti: str) -> "JwtToken | None":
|
||||
"""按 JTI 查询 Token 记录。"""
|
||||
return await session.scalar(select(cls).where(cls.tokenJti == jti))
|
||||
|
||||
@classmethod
|
||||
async def revoke_by_jti(cls, session: AsyncSession, jti: str, reason: str = "") -> None:
|
||||
"""撤销指定 JTI 的 Token。"""
|
||||
await session.execute(
|
||||
update(cls)
|
||||
.where(cls.tokenJti == jti)
|
||||
.values(isRevoked=True, revokedAt=datetime.now(), revokeReason=reason)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def revoke_all_user_tokens(cls, session: AsyncSession, userId: int, reason: str = "") -> list[str]:
|
||||
"""撤销用户的所有活跃 Token,返回被撤销的 JTI 列表。"""
|
||||
result = await session.execute(
|
||||
select(cls.tokenJti).where(cls.userId == userId, cls.isRevoked == False)
|
||||
)
|
||||
jtis = [row[0] for row in result.fetchall()]
|
||||
await session.execute(
|
||||
update(cls)
|
||||
.where(cls.userId == userId, cls.isRevoked == False)
|
||||
.values(isRevoked=True, revokedAt=datetime.now(), revokeReason=reason)
|
||||
)
|
||||
return jtis
|
||||
|
||||
@classmethod
|
||||
async def cleanup_expired(cls, session: AsyncSession, before: datetime) -> int:
|
||||
"""清理过期的 Token 记录,返回删除数。"""
|
||||
result = await session.execute(
|
||||
select(cls).where(cls.expiresAt < before)
|
||||
)
|
||||
rows = result.scalars().all()
|
||||
for row in rows:
|
||||
await session.delete(row)
|
||||
return len(rows)
|
||||
@@ -0,0 +1,60 @@
|
||||
"""LeAudit AuditRun 模型 —— leaudit_audit_runs 表。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
from sqlalchemy import BigInteger, Boolean, DateTime, Integer, Numeric, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from fastapi_common.fastapi_common_web.models import BaseModel
|
||||
|
||||
|
||||
class LeauditAuditRun(BaseModel):
|
||||
"""评查运行主表。"""
|
||||
|
||||
__tablename__ = "leaudit_audit_runs"
|
||||
|
||||
Id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
documentId: Mapped[int] = mapped_column(BigInteger, comment="关联 leaudit_documents.id")
|
||||
documentFileId: Mapped[int | None] = mapped_column(BigInteger, comment="输入文件ID")
|
||||
runNo: Mapped[int] = mapped_column(Integer, comment="同一文档第几次执行")
|
||||
triggerSource: Mapped[str] = mapped_column(String(64), comment="upload/manual/retry/migration/batch")
|
||||
triggerUserId: Mapped[int | None] = mapped_column(BigInteger, comment="触发人")
|
||||
taskId: Mapped[str | None] = mapped_column(String(128), comment="Celery 任务 ID")
|
||||
|
||||
# 状态
|
||||
status: Mapped[str] = mapped_column(String(64), default="pending", comment="pending/processing/completed/failed/cancelled")
|
||||
phase: Mapped[str | None] = mapped_column(String(32), comment="draft/executed")
|
||||
|
||||
# 规则溯源
|
||||
ruleSetId: Mapped[int] = mapped_column(BigInteger, comment="关联 leaudit_rule_sets.id")
|
||||
ruleVersionId: Mapped[int] = mapped_column(BigInteger, comment="关联 leaudit_rule_versions.id")
|
||||
ruleTypeId: Mapped[str | None] = mapped_column(String(256), comment="LeAudit metadata.type_id")
|
||||
ruleSourceOssUrl: Mapped[str | None] = mapped_column(String(2048), comment="规则 YAML OSS 地址")
|
||||
ruleSourceSha256: Mapped[str | None] = mapped_column(String(64), comment="规则文件 SHA256")
|
||||
ruleLocalCachePath: Mapped[str | None] = mapped_column(String(1024), comment="本地缓存路径")
|
||||
|
||||
# 模型快照
|
||||
engineVersion: Mapped[str | None] = mapped_column(String(64))
|
||||
llmProvider: Mapped[str | None] = mapped_column(String(64))
|
||||
llmModel: Mapped[str | None] = mapped_column(String(128))
|
||||
vlmProvider: Mapped[str | None] = mapped_column(String(64))
|
||||
vlmModel: Mapped[str | None] = mapped_column(String(128))
|
||||
ocrProvider: Mapped[str | None] = mapped_column(String(64))
|
||||
ocrModel: Mapped[str | None] = mapped_column(String(128))
|
||||
|
||||
# Rescue
|
||||
rescueMode: Mapped[str | None] = mapped_column(String(32), comment="off/tier1/auto")
|
||||
rescueApplied: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否执行 rescue")
|
||||
|
||||
# 结果汇总
|
||||
totalScore: Mapped[float | None] = mapped_column(Numeric(10, 2))
|
||||
passedCount: Mapped[int | None] = mapped_column(Integer)
|
||||
failedCount: Mapped[int | None] = mapped_column(Integer)
|
||||
skippedCount: Mapped[int | None] = mapped_column(Integer)
|
||||
resultStatus: Mapped[str | None] = mapped_column(String(32), comment="pass/fail/partial/error")
|
||||
|
||||
# 时间
|
||||
startedAt: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
finishedAt: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
|
||||
@@ -0,0 +1,44 @@
|
||||
"""LeAudit 域文档镜像模型 —— leaudit_documents 表。
|
||||
|
||||
通过 biz_document_id 关联业务 documents 表,不复制业务字段。
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import BigInteger, String, ForeignKey
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from fastapi_common.fastapi_common_web.models import BaseModel
|
||||
|
||||
|
||||
class LeauditDocument(BaseModel):
|
||||
"""LeAudit 文档镜像表。"""
|
||||
|
||||
__tablename__ = "leaudit_documents"
|
||||
|
||||
Id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
bizDocumentId: Mapped[int] = mapped_column(BigInteger, unique=True, comment="关联业务 documents.id")
|
||||
typeId: Mapped[int | None] = mapped_column(BigInteger, comment="文档类型ID")
|
||||
processingStatus: Mapped[str | None] = mapped_column(String(64), default="waiting", comment="waiting/processing/completed/failed")
|
||||
currentRunId: Mapped[int | None] = mapped_column(BigInteger, comment="最新有效 run id")
|
||||
|
||||
@classmethod
|
||||
async def get_by_biz_id(cls, session: AsyncSession, bizDocumentId: int) -> "LeauditDocument | None":
|
||||
"""按业务文档 ID 查询。"""
|
||||
from sqlalchemy import select
|
||||
return await session.scalar(select(cls).where(cls.bizDocumentId == bizDocumentId))
|
||||
|
||||
@classmethod
|
||||
async def upsert_by_biz_id(cls, session: AsyncSession, bizDocumentId: int, **fields) -> "LeauditDocument":
|
||||
"""按业务文档 ID 创建或更新。"""
|
||||
from sqlalchemy import select
|
||||
doc = await session.scalar(select(cls).where(cls.bizDocumentId == bizDocumentId))
|
||||
if doc is None:
|
||||
doc = cls(bizDocumentId=bizDocumentId, **fields)
|
||||
session.add(doc)
|
||||
else:
|
||||
for k, v in fields.items():
|
||||
setattr(doc, k, v)
|
||||
await session.flush()
|
||||
return doc
|
||||
@@ -0,0 +1,28 @@
|
||||
"""LeAudit 文档文件模型 —— leaudit_document_files 表。"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from sqlalchemy import BigInteger, Boolean, String
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from fastapi_common.fastapi_common_web.models import BaseModel
|
||||
|
||||
|
||||
class LeauditDocumentFile(BaseModel):
|
||||
"""文档文件表。"""
|
||||
|
||||
__tablename__ = "leaudit_document_files"
|
||||
|
||||
Id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
|
||||
documentId: Mapped[int] = mapped_column(BigInteger, comment="关联 leaudit_documents.id")
|
||||
fileRole: Mapped[str] = mapped_column(String(64), comment="original/converted_pdf/merged_pdf/temp_input")
|
||||
fileName: Mapped[str] = mapped_column(String(512), comment="文件名")
|
||||
fileExt: Mapped[str | None] = mapped_column(String(32), comment="扩展名")
|
||||
mimeType: Mapped[str | None] = mapped_column(String(128), comment="MIME")
|
||||
fileSize: Mapped[int | None] = mapped_column(BigInteger, comment="文件大小")
|
||||
sha256: Mapped[str | None] = mapped_column(String(64), comment="SHA256")
|
||||
localPath: Mapped[str | None] = mapped_column(String(1024), comment="本地路径")
|
||||
ossUrl: Mapped[str | None] = mapped_column(String(2048), comment="OSS 地址")
|
||||
storageProvider: Mapped[str | None] = mapped_column(String(32), comment="oss/minio/local")
|
||||
isActive: Mapped[bool] = mapped_column(Boolean, default=True, comment="当前生效文件")
|
||||
createdBy: Mapped[int | None] = mapped_column(BigInteger, comment="上传人")
|
||||
@@ -0,0 +1,8 @@
|
||||
"""LeAudit 服务层导出。"""
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.services.auditService import IAuditService
|
||||
from fastapi_modules.fastapi_leaudit.services.authService import IAuthService
|
||||
from fastapi_modules.fastapi_leaudit.services.permissionService import IPermissionService
|
||||
from fastapi_modules.fastapi_leaudit.services.ruleService import IRuleService
|
||||
|
||||
__all__ = ["IAuditService", "IAuthService", "IPermissionService", "IRuleService"]
|
||||
@@ -0,0 +1,24 @@
|
||||
"""评查服务接口。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.auditVo import AuditRunVO, AuditResultVO
|
||||
|
||||
|
||||
class IAuditService(ABC):
|
||||
"""评查服务接口。"""
|
||||
|
||||
@abstractmethod
|
||||
async def Run(self, DocumentId: int) -> AuditRunVO:
|
||||
"""触发文档评查。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def GetRunStatus(self, RunId: int) -> AuditRunVO:
|
||||
"""查询评查运行状态。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def GetResult(self, RunId: int) -> AuditResultVO:
|
||||
"""获取评查结果。"""
|
||||
...
|
||||
@@ -0,0 +1,22 @@
|
||||
"""认证服务接口。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.auth.loginTokenVo import LoginTokenVO
|
||||
|
||||
|
||||
class IAuthService(ABC):
|
||||
"""认证服务接口。"""
|
||||
|
||||
@abstractmethod
|
||||
async def PasswordLogin(self, Sub: str, Password: str) -> LoginTokenVO:
|
||||
"""账密登录。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def OAuthLogin(self, Sub: str, Username: str | None, Nickname: str | None,
|
||||
Email: str | None, PhoneNumber: str | None,
|
||||
OuId: str | None, OuName: str | None,
|
||||
IsLeader: bool | None, Area: str | None, ExpiresIn: int) -> LoginTokenVO:
|
||||
"""OAuth 登录。"""
|
||||
...
|
||||
@@ -0,0 +1,65 @@
|
||||
"""评查服务实现。
|
||||
|
||||
编排 LeAudit 引擎执行链路:
|
||||
文档 → OCR → Extract → Evaluate → Rescue → Persist
|
||||
"""
|
||||
|
||||
from fastapi_common.fastapi_common_logger import logger
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
||||
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.auditVo import AuditRunVO, AuditResultVO
|
||||
from fastapi_modules.fastapi_leaudit.models import LeauditAuditRun
|
||||
from fastapi_modules.fastapi_leaudit.services import IAuditService
|
||||
|
||||
|
||||
class AuditServiceImpl(IAuditService):
|
||||
"""评查服务实现。"""
|
||||
|
||||
async def Run(self, DocumentId: int, RuleType: str | None = None, Force: bool = False) -> AuditRunVO:
|
||||
"""触发文档评查。
|
||||
|
||||
实际执行流程由 Celery 任务异步处理。
|
||||
"""
|
||||
async with GetAsyncSession() as session:
|
||||
# TODO: 从 bridge 层获取 pipeline,提交 Celery 任务
|
||||
logger.info(f"触发评查: documentId={DocumentId}, ruleType={RuleType}")
|
||||
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "Celery 任务集成待实现")
|
||||
|
||||
async def GetRunStatus(self, RunId: int) -> AuditRunVO:
|
||||
"""查询评查运行状态。"""
|
||||
async with GetAsyncSession() as session:
|
||||
run = await session.get(LeauditAuditRun, RunId)
|
||||
if not run:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查运行记录不存在")
|
||||
return AuditRunVO(
|
||||
runId=run.Id,
|
||||
documentId=run.documentId,
|
||||
runNo=run.runNo,
|
||||
status=run.status,
|
||||
phase=run.phase,
|
||||
totalScore=float(run.totalScore) if run.totalScore else None,
|
||||
passedCount=run.passedCount,
|
||||
failedCount=run.failedCount,
|
||||
startedAt=run.startedAt,
|
||||
finishedAt=run.finishedAt,
|
||||
)
|
||||
|
||||
async def GetResult(self, RunId: int) -> AuditResultVO:
|
||||
"""获取评查结果。"""
|
||||
async with GetAsyncSession() as session:
|
||||
run = await session.get(LeauditAuditRun, RunId)
|
||||
if not run:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查运行记录不存在")
|
||||
# TODO: 从 leaudit_rule_results 表查询规则级结果
|
||||
return AuditResultVO(
|
||||
runId=run.Id,
|
||||
totalScore=float(run.totalScore) if run.totalScore else None,
|
||||
passedCount=run.passedCount or 0,
|
||||
failedCount=run.failedCount or 0,
|
||||
skippedCount=run.skippedCount or 0,
|
||||
phase=run.phase,
|
||||
rescueApplied=run.rescueApplied or False,
|
||||
rules=[],
|
||||
)
|
||||
@@ -0,0 +1,162 @@
|
||||
"""认证服务实现。
|
||||
|
||||
从旧项目 app/routes/auth.py 和 app/auth/auth.py 迁移,业务逻辑完全不变。
|
||||
仅重组为 Controller → Service(interface+impl) → Model 结构。
|
||||
"""
|
||||
|
||||
from fastapi_common.fastapi_common_logger import logger
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
from fastapi_common.fastapi_common_security.jwtService import JwtService
|
||||
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
||||
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.auth.loginTokenVo import LoginTokenVO
|
||||
from fastapi_modules.fastapi_leaudit.services.authService import IAuthService
|
||||
|
||||
|
||||
class AuthServiceImpl(IAuthService):
|
||||
"""认证服务实现。"""
|
||||
|
||||
async def PasswordLogin(self, Sub: str, Password: str) -> LoginTokenVO:
|
||||
"""账密登录。
|
||||
|
||||
校验 sso_users 表:sub + password + status=0 + deleted_at IS NULL。
|
||||
安全:统一错误提示"账号或密码错误",防止用户枚举。
|
||||
"""
|
||||
async with GetAsyncSession() as session:
|
||||
from sqlalchemy import select, text
|
||||
|
||||
result = await session.execute(
|
||||
text("SELECT id, sub, username, nick_name, phone_number, email, "
|
||||
"ou_id, ou_name, is_leader, password, status, deleted_at, "
|
||||
"try_count, try_login_time, area, tenant_name, dep_name, dep_short_name "
|
||||
"FROM sso_users WHERE sub = :sub"),
|
||||
{"sub": Sub},
|
||||
)
|
||||
row = result.fetchone()
|
||||
|
||||
if not row:
|
||||
logger.warning(f"登录失败: 用户不存在 - sub={Sub}")
|
||||
raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")
|
||||
|
||||
user = dict(row._mapping)
|
||||
|
||||
if user.get("deleted_at") is not None:
|
||||
logger.warning(f"登录失败: 账号已删除 - sub={Sub}")
|
||||
raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")
|
||||
|
||||
if user.get("status") != 0:
|
||||
logger.warning(f"登录失败: 账号已禁用 - sub={Sub}, status={user.get('status')}")
|
||||
raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")
|
||||
|
||||
if user.get("password") != Password:
|
||||
logger.warning(f"登录失败: 密码错误 - sub={Sub}")
|
||||
raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")
|
||||
|
||||
return await self._buildLoginResponse(user, session)
|
||||
|
||||
async def OAuthLogin(self, Sub: str, Username: str | None, Nickname: str | None,
|
||||
Email: str | None, PhoneNumber: str | None,
|
||||
OuId: str | None, OuName: str | None,
|
||||
IsLeader: bool | None, Area: str | None, ExpiresIn: int) -> LoginTokenVO:
|
||||
"""OAuth 登录。验证 sub 是否存在,不存在则自动创建用户。"""
|
||||
async with GetAsyncSession() as session:
|
||||
from sqlalchemy import select, text
|
||||
from datetime import datetime, timezone
|
||||
|
||||
result = await session.execute(
|
||||
text("SELECT id, sub, username, nick_name, phone_number, email, "
|
||||
"ou_id, ou_name, is_leader, status, deleted_at, "
|
||||
"area, tenant_name, dep_name, dep_short_name "
|
||||
"FROM sso_users WHERE sub = :sub"),
|
||||
{"sub": Sub},
|
||||
)
|
||||
row = result.fetchone()
|
||||
|
||||
if row:
|
||||
user = dict(row._mapping)
|
||||
if user.get("deleted_at") is not None or user.get("status") != 0:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号已被禁用或删除")
|
||||
# 更新最后登录信息
|
||||
await session.execute(
|
||||
text("UPDATE sso_users SET username = :username, nick_name = :nick, "
|
||||
"email = :email, phone_number = :phone, ou_id = :ou_id, "
|
||||
"ou_name = :ou_name, is_leader = :is_leader, area = :area, "
|
||||
"updated_at = :now WHERE id = :id"),
|
||||
{"username": Username, "nick": Nickname, "email": Email,
|
||||
"phone": PhoneNumber, "ou_id": OuId, "ou_name": OuName,
|
||||
"is_leader": IsLeader, "area": Area,
|
||||
"now": datetime.now(timezone.utc), "id": user["id"]},
|
||||
)
|
||||
else:
|
||||
# 自动创建用户
|
||||
await session.execute(
|
||||
text("INSERT INTO sso_users (sub, username, nick_name, email, "
|
||||
"phone_number, ou_id, ou_name, is_leader, area, status) "
|
||||
"VALUES (:sub, :username, :nick, :email, :phone, :ou_id, "
|
||||
":ou_name, :is_leader, :area, 0)"),
|
||||
{"sub": Sub, "username": Username, "nick": Nickname, "email": Email,
|
||||
"phone": PhoneNumber, "ou_id": OuId, "ou_name": OuName,
|
||||
"is_leader": IsLeader, "area": Area},
|
||||
)
|
||||
await session.commit()
|
||||
|
||||
result = await session.execute(
|
||||
text("SELECT id, sub, username, nick_name, phone_number, email, "
|
||||
"ou_id, ou_name, is_leader, area, tenant_name, dep_name, dep_short_name "
|
||||
"FROM sso_users WHERE sub = :sub"),
|
||||
{"sub": Sub},
|
||||
)
|
||||
row = result.fetchone()
|
||||
user = dict(row._mapping) if row else {}
|
||||
|
||||
if not user:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "用户创建失败")
|
||||
|
||||
return await self._buildLoginResponse(user, session)
|
||||
|
||||
async def _buildLoginResponse(self, user: dict, session) -> LoginTokenVO:
|
||||
"""组装登录响应:查询角色 → 签发 JWT → 返回 LoginTokenVO。"""
|
||||
from sqlalchemy import text
|
||||
|
||||
# 查询用户角色
|
||||
roleResult = await session.execute(
|
||||
text("SELECT r.role_key FROM user_role ur "
|
||||
"JOIN roles r ON ur.role_id = r.id "
|
||||
"WHERE ur.user_id = :uid LIMIT 1"),
|
||||
{"uid": user["id"]},
|
||||
)
|
||||
roleRow = roleResult.fetchone()
|
||||
userRole = roleRow[0] if roleRow else "common"
|
||||
|
||||
# 签发 JWT
|
||||
expiresIn = 3600 # 默认 1 小时
|
||||
tokens = JwtService.generate(
|
||||
userId=user["id"],
|
||||
username=user.get("username") or user.get("sub", ""),
|
||||
nickName=user.get("nick_name") or "",
|
||||
ouId=user.get("ou_id") or "",
|
||||
ouName=user.get("ou_name") or "",
|
||||
area=user.get("area"),
|
||||
userRole=userRole,
|
||||
)
|
||||
|
||||
return LoginTokenVO(
|
||||
access_token=tokens["access_token"],
|
||||
token_type="Bearer",
|
||||
expires_in=expiresIn,
|
||||
issued_time=tokens.get("issued_time", ""),
|
||||
user_info={
|
||||
"user_id": user["id"],
|
||||
"sub": user.get("sub"),
|
||||
"username": user.get("username"),
|
||||
"nick_name": user.get("nick_name"),
|
||||
"email": user.get("email"),
|
||||
"phone_number": user.get("phone_number"),
|
||||
"ou_id": user.get("ou_id"),
|
||||
"ou_name": user.get("ou_name"),
|
||||
"is_leader": user.get("is_leader"),
|
||||
"area": user.get("area"),
|
||||
"role": userRole,
|
||||
},
|
||||
)
|
||||
@@ -0,0 +1,141 @@
|
||||
"""权限服务实现。
|
||||
|
||||
从旧项目 app/rbac/permission_checker_v2.py 迁移,业务逻辑完全不变。
|
||||
- 数据库驱动 GRANT/DENY 机制(DENY 优先级更高)
|
||||
- 支持通配符:document:*:*、*:*:* 等
|
||||
- 实时查询(无 Redis 缓存)
|
||||
仅改造为 SQLAlchemy 会话 + 项目统一配置。
|
||||
"""
|
||||
|
||||
from fastapi_common.fastapi_common_logger import logger
|
||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.services.permissionService import IPermissionService
|
||||
|
||||
|
||||
class PermissionServiceImpl(IPermissionService):
|
||||
"""权限检查服务实现。"""
|
||||
|
||||
async def CheckPermission(self, UserId: int, PermissionKey: str) -> bool:
|
||||
"""检查用户是否拥有指定权限。
|
||||
|
||||
GRANT/DENY 优先级:
|
||||
1. 精确 DENY → 拒绝
|
||||
2. 通配符 DENY → 拒绝
|
||||
3. 精确 GRANT → 通过
|
||||
4. 通配符 GRANT → 通过
|
||||
5. 无匹配 → 拒绝
|
||||
"""
|
||||
try:
|
||||
grants, denies = await self._getUserPermissions(UserId)
|
||||
|
||||
# DENY 优先
|
||||
if PermissionKey in denies:
|
||||
logger.debug(f"[DENY] 精确拒绝: user={UserId}, perm={PermissionKey}")
|
||||
return False
|
||||
|
||||
if self._matchWildcard(PermissionKey, denies):
|
||||
logger.debug(f"[DENY] 通配符拒绝: user={UserId}, perm={PermissionKey}")
|
||||
return False
|
||||
|
||||
# GRANT
|
||||
if PermissionKey in grants:
|
||||
logger.debug(f"[GRANT] 精确授权: user={UserId}, perm={PermissionKey}")
|
||||
return True
|
||||
|
||||
if self._matchWildcard(PermissionKey, grants):
|
||||
logger.debug(f"[GRANT] 通配符授权: user={UserId}, perm={PermissionKey}")
|
||||
return True
|
||||
|
||||
logger.debug(f"[DENY] 无匹配权限: user={UserId}, perm={PermissionKey}")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"权限检查异常: user={UserId}, perm={PermissionKey}, error={e}")
|
||||
return False # 安全优先:异常时拒绝
|
||||
|
||||
async def HasAnyPermission(self, UserId: int, PermissionKeys: list[str]) -> bool:
|
||||
"""OR 逻辑:任一权限通过即返回 True。"""
|
||||
for key in PermissionKeys:
|
||||
if await self.CheckPermission(UserId, key):
|
||||
return True
|
||||
return False
|
||||
|
||||
async def HasAllPermissions(self, UserId: int, PermissionKeys: list[str]) -> bool:
|
||||
"""AND 逻辑:全部权限通过才返回 True。"""
|
||||
for key in PermissionKeys:
|
||||
if not await self.CheckPermission(UserId, key):
|
||||
return False
|
||||
return True
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# 内部方法
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
async def _getUserPermissions(self, UserId: int) -> tuple[set[str], set[str]]:
|
||||
"""从数据库查询用户的 GRANT 和 DENY 权限集合。"""
|
||||
grants: set[str] = set()
|
||||
denies: set[str] = set()
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
from sqlalchemy import text
|
||||
|
||||
result = await session.execute(
|
||||
text(
|
||||
"SELECT p.permission_key, rp.grant_type "
|
||||
"FROM sso_users u "
|
||||
"JOIN user_role ur ON u.id = ur.user_id "
|
||||
"JOIN roles r ON ur.role_id = r.id "
|
||||
"JOIN role_permissions rp ON r.id = rp.role_id "
|
||||
"JOIN permissions p ON rp.permission_id = p.id "
|
||||
"WHERE u.id = :uid"
|
||||
),
|
||||
{"uid": UserId},
|
||||
)
|
||||
|
||||
rows = result.fetchall()
|
||||
if not rows:
|
||||
logger.warning(f"用户无角色或权限: user_id={UserId}")
|
||||
return grants, denies
|
||||
|
||||
for row in rows:
|
||||
permKey = row[0]
|
||||
grantType = row[1]
|
||||
if grantType == "GRANT":
|
||||
grants.add(permKey)
|
||||
elif grantType == "DENY":
|
||||
denies.add(permKey)
|
||||
|
||||
logger.debug(f"用户权限: user={UserId}, grants={len(grants)}, denies={len(denies)}")
|
||||
return grants, denies
|
||||
|
||||
@staticmethod
|
||||
def _matchWildcard(PermissionKey: str, PermissionSet: set[str]) -> bool:
|
||||
"""检查 PermissionKey 是否匹配集合中的任一通配符模式。"""
|
||||
for pattern in PermissionSet:
|
||||
if PermissionServiceImpl._wildcardMatch(PermissionKey, pattern):
|
||||
return True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def _wildcardMatch(PermissionKey: str, Pattern: str) -> bool:
|
||||
"""通配符匹配。
|
||||
|
||||
_wildcardMatch("document:read:all", "document:*:*") → True
|
||||
_wildcardMatch("document:read:all", "document:read:*") → True
|
||||
_wildcardMatch("document:read:all", "evaluation:*:*") → False
|
||||
_wildcardMatch("document:read:all", "*:*:*") → True
|
||||
"""
|
||||
keyParts = PermissionKey.split(":")
|
||||
patternParts = Pattern.split(":")
|
||||
|
||||
if len(keyParts) != len(patternParts):
|
||||
return False
|
||||
|
||||
for keyPart, patternPart in zip(keyParts, patternParts):
|
||||
if patternPart == "*":
|
||||
continue
|
||||
if keyPart != patternPart:
|
||||
return False
|
||||
|
||||
return True
|
||||
@@ -0,0 +1,27 @@
|
||||
"""权限服务接口。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class IPermissionService(ABC):
|
||||
"""权限检查服务接口。
|
||||
|
||||
权限格式:module:resource:action(如 document:list:read)
|
||||
支持通配符:document:*:*、*:*:* 等
|
||||
支持 GRANT/DENY 机制(DENY 优先级更高)
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def CheckPermission(self, UserId: int, PermissionKey: str) -> bool:
|
||||
"""检查用户是否拥有指定权限。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def HasAnyPermission(self, UserId: int, PermissionKeys: list[str]) -> bool:
|
||||
"""检查用户是否拥有任意一个权限(OR 逻辑)。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def HasAllPermissions(self, UserId: int, PermissionKeys: list[str]) -> bool:
|
||||
"""检查用户是否拥有所有权限(AND 逻辑)。"""
|
||||
...
|
||||
@@ -0,0 +1,24 @@
|
||||
"""规则服务接口。"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.ruleVo import RuleSetVO, RuleVersionVO
|
||||
|
||||
|
||||
class IRuleService(ABC):
|
||||
"""规则服务接口。"""
|
||||
|
||||
@abstractmethod
|
||||
async def ListSets(self) -> list[RuleSetVO]:
|
||||
"""列出所有规则集。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def GetVersions(self, RuleType: str) -> list[RuleVersionVO]:
|
||||
"""获取规则集的所有版本。"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def Publish(self, RuleType: str, VersionId: int) -> RuleVersionVO:
|
||||
"""发布指定版本。"""
|
||||
...
|
||||
Reference in New Issue
Block a user