chore: initial commit — leaudit-platform project skeleton

17-table PostgreSQL schema with full Chinese column comments,
FastAPI project structure (admin/common/modules),
DSL rule files, and schema migration scripts.
This commit is contained in:
wren
2026-04-27 16:48:22 +08:00
commit 535d97a70c
142 changed files with 25219 additions and 0 deletions
@@ -0,0 +1,15 @@
"""控制器包(需要 JWT 鉴权)。"""
from typing import Any
from fastapi import APIRouter, Depends, Request
from fastapi_common.fastapi_common_security.security import verify_access_token
async def jwt_auth_dependency(RequestObj: Request) -> dict[str, Any]:
"""JWT 鉴权依赖。"""
return verify_access_token(RequestObj)
router = APIRouter(dependencies=[Depends(jwt_auth_dependency)])
@@ -0,0 +1,42 @@
"""评查控制器。"""
from fastapi_common.fastapi_common_web.controller import BaseController
from fastapi_common.fastapi_common_web.domain.responses import Result
from fastapi_modules.fastapi_leaudit.domian.Dto.auditDto import AuditRunDTO
from fastapi_modules.fastapi_leaudit.domian.vo.auditVo import AuditRunVO, AuditResultVO
from fastapi_modules.fastapi_leaudit.services import IAuditService
from fastapi_modules.fastapi_leaudit.services.impl.auditServiceImpl import AuditServiceImpl
class AuditController(BaseController):
"""评查控制器。"""
def __init__(self):
super().__init__(prefix="/audit", tags=["评查"])
self.AuditService: IAuditService = AuditServiceImpl()
@self.router.post("/run", response_model=Result[AuditRunVO])
async def RunAudit(body: AuditRunDTO):
"""触发文档评查
对指定文档执行 LeAudit 完整评查链路。
"""
run = await self.AuditService.Run(
DocumentId=body.documentId,
RuleType=body.ruleType,
Force=body.force,
)
return Result.success(data=run)
@self.router.get("/run/{RunId}", response_model=Result[AuditRunVO])
async def GetRunStatus(RunId: int):
"""查询评查运行状态。"""
run = await self.AuditService.GetRunStatus(RunId)
return Result.success(data=run)
@self.router.get("/result/{RunId}", response_model=Result[AuditResultVO])
async def GetResult(RunId: int):
"""获取评查结果。"""
result = await self.AuditService.GetResult(RunId)
return Result.success(data=result)
@@ -0,0 +1,5 @@
"""认证控制器包(无鉴权)。"""
from fastapi import APIRouter
router = APIRouter()
@@ -0,0 +1,97 @@
"""认证控制器。
路由路径与旧项目完全一致:
POST /auth/login — 统一登录(OAuth + 密码自动检测)
POST /auth/password_login — 账密登录
响应格式按新项目规范使用 Result。
"""
from fastapi import Request
from fastapi.responses import JSONResponse
from fastapi_common.fastapi_common_web.controller import BaseController
from fastapi_common.fastapi_common_web.domain.responses import Result
from fastapi_common.fastapi_common_logger import logger
from fastapi_modules.fastapi_leaudit.domian.Dto.auth.loginDto import PasswordLoginDTO, OAuthLoginDTO
from fastapi_modules.fastapi_leaudit.domian.vo.auth.loginTokenVo import LoginTokenVO
from fastapi_modules.fastapi_leaudit.services import IAuthService
from fastapi_modules.fastapi_leaudit.services.impl.authServiceImpl import AuthServiceImpl
class AuthController(BaseController):
"""认证控制器。"""
def __init__(self):
super().__init__(prefix="/auth", tags=["认证"])
self.AuthService: IAuthService = AuthServiceImpl()
@self.router.post("/login")
async def Login(RequestObj: Request):
"""统一登录接口。
自动检测登录方式:
- 含 userInfo.sub → OAuth 登录
- 含 username + password → 密码登录
"""
try:
requestData = await RequestObj.json()
if "userInfo" in requestData and isinstance(requestData["userInfo"], dict) and "sub" in requestData["userInfo"]:
logger.info("检测到 OAuth 登录请求")
ui = requestData["userInfo"]
vo = await self.AuthService.OAuthLogin(
Sub=ui["sub"],
Username=ui.get("username"),
Nickname=ui.get("nickname"),
Email=ui.get("email"),
PhoneNumber=ui.get("phone_number"),
OuId=ui.get("ou_id"),
OuName=ui.get("ou_name"),
IsLeader=ui.get("is_leader"),
Area=requestData.get("area"),
ExpiresIn=requestData.get("expiresIn", 3600),
)
elif "username" in requestData and "password" in requestData:
logger.info(f"检测到密码登录请求 - username={requestData['username']}")
vo = await self.AuthService.PasswordLogin(
Sub=requestData["username"],
Password=requestData["password"],
)
else:
return JSONResponse(status_code=400, content={"code": 400, "message": "无效的登录请求格式", "data": None})
return JSONResponse(status_code=200, content={
"code": 200,
"message": "ok",
"data": {
"access_token": vo.access_token,
"token_type": vo.token_type,
"expires_in": vo.expires_in,
"issued_time": vo.issued_time,
"user_info": vo.user_info,
},
})
except Exception as e:
logger.error(f"登录失败: {e}")
return JSONResponse(status_code=401, content={
"code": 401, "message": str(e), "data": None,
})
@self.router.post("/password_login")
async def PasswordLogin(RequestObj: Request):
"""账密登录。校验 sso_users 表 sub + password。"""
try:
requestData = await RequestObj.json()
dto = PasswordLoginDTO(**requestData)
vo = await self.AuthService.PasswordLogin(Sub=dto.sub, Password=dto.password)
return JSONResponse(status_code=200, content={
"code": 200, "message": "ok", "data": vo.model_dump(),
})
except Exception as e:
logger.error(f"密码登录失败: {e}")
return JSONResponse(status_code=401, content={
"code": 401, "message": str(e), "data": None,
})
@@ -0,0 +1,11 @@
"""评查 DTO(仅控制器层使用)。"""
from pydantic import BaseModel, Field
class AuditRunDTO(BaseModel):
"""触发评查请求。"""
documentId: int = Field(..., description="文档ID")
ruleType: str | None = Field(None, description="指定规则类型编码")
force: bool = Field(False, description="是否强制重跑")
@@ -0,0 +1,31 @@
"""认证 DTO(仅控制器层使用)。"""
from pydantic import BaseModel, Field
class PasswordLoginDTO(BaseModel):
"""账密登录请求。"""
sub: str = Field(..., description="账号")
password: str = Field(..., description="密码")
class OAuthUserInfo(BaseModel):
"""OAuth 用户信息。"""
sub: str = Field(..., description="IDaaS 用户唯一标识")
username: str | None = Field(None, description="用户名/工号")
nickname: str | None = Field(None, description="用户昵称")
email: str | None = Field(None, description="邮箱")
phone_number: str | None = Field(None, description="手机号")
ou_id: str | None = Field(None, description="组织单位ID")
ou_name: str | None = Field(None, description="组织单位名称")
is_leader: bool | None = Field(False, description="是否为负责人")
class OAuthLoginDTO(BaseModel):
"""OAuth 登录请求。"""
userInfo: OAuthUserInfo = Field(..., description="OAuth 用户信息")
expiresIn: int = Field(..., description="OAuth token 过期时间(秒)")
area: str | None = Field(None, description="用户所属地区")
@@ -0,0 +1,33 @@
"""评查 VO。"""
from datetime import datetime
from pydantic import BaseModel, Field
class AuditRunVO(BaseModel):
"""评查运行响应。"""
runId: int = Field(..., description="运行ID")
documentId: int = Field(..., description="文档ID")
runNo: int = Field(..., description="执行序号")
status: str = Field(..., description="状态")
phase: str | None = Field(None, description="draft/executed")
totalScore: float | None = Field(None, description="总分")
passedCount: int | None = Field(None, description="通过数")
failedCount: int | None = Field(None, description="失败数")
startedAt: datetime | None = Field(None, description="开始时间")
finishedAt: datetime | None = Field(None, description="结束时间")
class AuditResultVO(BaseModel):
"""评查结果响应。"""
runId: int = Field(..., description="运行ID")
totalScore: float | None = Field(None, description="总分")
passedCount: int = Field(0, description="通过数")
failedCount: int = Field(0, description="失败数")
skippedCount: int = Field(0, description="跳过数")
phase: str | None = Field(None, description="draft/executed")
rescueApplied: bool = Field(False, description="是否执行 rescue")
rules: list[dict] = Field(default_factory=list, description="规则结果列表")
@@ -0,0 +1,13 @@
"""认证 VO(控制器层 + 服务层使用)。"""
from pydantic import BaseModel, Field
class LoginTokenVO(BaseModel):
"""登录响应。"""
access_token: str = Field(..., description="JWT Token")
token_type: str = Field("Bearer", description="Token 类型")
expires_in: int = Field(..., description="Token 过期时间(秒)")
issued_time: str = Field(..., description="签发时间 YYYY-MM-DD HH:MM:SS")
user_info: dict = Field(..., description="用户信息")
@@ -0,0 +1,26 @@
"""规则 VO。"""
from pydantic import BaseModel, Field
class RuleSetVO(BaseModel):
"""规则集响应。"""
id: int = Field(..., description="规则集ID")
ruleType: str = Field(..., description="业务规则类型编码")
ruleName: str = Field(..., description="规则集名称")
domainType: str | None = Field(None, description="域类型")
currentVersionId: int | None = Field(None, description="当前激活版本ID")
status: str = Field(..., description="draft/active/inactive/archived")
class RuleVersionVO(BaseModel):
"""规则版本响应。"""
id: int = Field(..., description="版本ID")
ruleSetId: int = Field(..., description="所属规则集ID")
versionNo: str = Field(..., description="版本号")
status: str = Field(..., description="draft/validated/published/rolled_back")
ossUrl: str = Field(..., description="YAML 文件 OSS 地址")
changeNote: str | None = Field(None, description="变更说明")
publishedAt: str | None = Field(None, description="发布时间")
@@ -0,0 +1,82 @@
"""leaudit bridge — use leaudit's full pipeline with docauditai's database storage.
Directly calls leaudit's OCR → extraction → evaluation pipeline
and persists results into docauditai's PostgreSQL via PostgREST.
Configuration switch (in env.{port}):
PIPELINE_MODE=leaudit → use leaudit pipeline
"""
from leaudit_bridge.client_factory import (
create_ocr_client,
create_llm_client,
create_vlm_client,
)
from leaudit_bridge.ocr_bridge import BridgeOCRClient
from leaudit_bridge.pipeline import LauditPipeline, PipelineResult
from leaudit_bridge.rules_loader import RulesLoader
from leaudit_bridge.storage_adapter import StorageAdapter
def is_leaudit_mode() -> bool:
"""Check if the system is configured to use the leaudit pipeline."""
from core.config import PIPELINE_MODE
return PIPELINE_MODE == "leaudit"
def create_pipeline(rules_path: str | None = None) -> LauditPipeline:
"""Create a fully configured LauditPipeline from current config.
Wraps the raw OCR client with DocNormalizationAdapter so that a single
``.ocr()`` call produces a fully enriched OcrResult with:
- Document classification (type_id + rules_file_path)
- Dossier segmentation (sub-document page mapping)
- Seal/signature enrichment (text, seal_id, party_id)
- Normalized markdown (seal blocks + page separators)
Args:
rules_path: If provided, forces the adapter to use this rules file
for classification and segmentation. When None, the adapter
uses the RulesFileRegistry to classify from document content,
enabling auto-detection of sub-types (e.g. 行政许可 variants).
"""
from pathlib import Path
from leaudit.doc_normalization.adapter import DocNormalizationAdapter
from leaudit.doc_normalization.doc_classifier import RulesFileRegistry
raw_ocr = create_ocr_client()
llm_client = create_llm_client()
vlm_client = create_vlm_client()
# Build registry from rules/ directory for content-based classification
registry = None
if rules_path is None:
rules_dir = Path(__file__).resolve().parents[1] / "rules"
if rules_dir.is_dir():
registry = RulesFileRegistry.from_directory(rules_dir)
ocr_client = DocNormalizationAdapter(
ocr_client=raw_ocr,
registry=registry,
llm_client=llm_client,
vlm_client=vlm_client,
force_rules_path=rules_path,
)
ocr_client = BridgeOCRClient(ocr_client, vlm_client=vlm_client)
return LauditPipeline(
ocr_client=ocr_client,
llm_client=llm_client,
)
__all__ = [
"LauditPipeline",
"PipelineResult",
"StorageAdapter",
"RulesLoader",
"create_ocr_client",
"create_llm_client",
"create_pipeline",
"is_leaudit_mode",
]
@@ -0,0 +1,128 @@
"""Extract case number (案件编号) from leaudit OcrResult.
Port of docauditai's ``extract_case_number_for_regex`` adapted for leaudit's
OcrResult model. Uses regex patterns first; falls back to LLM when available.
"""
from __future__ import annotations
import logging
import re
from typing import Any
from leaudit.ocr.models import OcrResult
log = logging.getLogger(__name__)
# Regex patterns for case number extraction
_PATTERNS: list[tuple[str, str]] = [
# Direct match: "案件编号:梅烟专罚〔2024〕第XX号"
(r"案件编号[:]\s*(.*?)(?:\n|$)", "direct"),
# From 卷宗/卷 宗 header: extract content between 卷宗 and 案由
(r"\s*宗([\s\S]*?)案\s*由", "file_content"),
# Standalone pattern: e.g. 梅烟专罚〔2024〕第001号
(r"[\u4e00-\u9fa5]{2,8}[专罚处决][〔(\(\[]\d{4}[〕)\)\]][\u4e00-\u9fa5]*\d+号", "standalone"),
# Dossier number: e.g. "2024 年度 郁烟 第 71 号"
(r"\d{4}\s*年度\s*[\u4e00-\u9fa5]{1,6}\s*第\s*\d+\s*号", "dossier"),
]
# Chinese number format within 卷宗 extracted text
_YEAR_NUMBER_RE = re.compile(r"\d{4}[\u4e00-\u9fa5]+\d+号")
def extract_case_number(ocr_result: OcrResult) -> str | None:
"""Extract case number from OCR result using regex patterns.
Searches across all pages but prioritizes early pages where case numbers
typically appear (封面, 卷宗封面).
Args:
ocr_result: OCR result with pages containing text.
Returns:
Extracted case number string, or None if not found.
"""
# Build text from first few pages (case numbers appear early)
pages_to_check = ocr_result.pages[:5] if len(ocr_result.pages) > 5 else ocr_result.pages
text = "\n".join(p.text for p in pages_to_check)
if not text.strip():
return None
for pattern, ptype in _PATTERNS:
match = re.search(pattern, text)
if not match:
continue
if ptype == "direct":
return match.group(1).strip()
if ptype == "file_content":
content = match.group(1).strip()
num_match = _YEAR_NUMBER_RE.search(content)
if num_match:
return num_match.group()
if ptype == "standalone":
return match.group()
if ptype == "dossier":
return match.group()
return None
async def extract_case_number_with_llm(
ocr_result: OcrResult,
llm_client: Any = None,
) -> str | None:
"""Extract case number using regex first, then LLM fallback.
Args:
ocr_result: OCR result with pages containing text.
llm_client: Optional LLM client for fallback extraction.
Returns:
Extracted case number string, or None if not found.
"""
# Try regex first (fast, no API call)
result = extract_case_number(ocr_result)
if result:
return result
# LLM fallback
if llm_client is None:
return None
text = "\n".join(p.text for p in ocr_result.pages[:5])
if not text.strip():
return None
try:
from leaudit.llm.base import BaseLLMClient, LLMRequest, LLMMessage
if not isinstance(llm_client, BaseLLMClient):
return None
prompt = (
"请从以下法律文书文本中提取案件编号。"
"案件编号通常格式如:梅烟专罚〔2024〕第001号。\n"
"只返回JSON: {\"case_number\": \"案件编号\"}\n\n"
f"文本:\n{text[:2000]}"
)
request = LLMRequest(
messages=[LLMMessage(role="user", content=prompt)],
response_format={"type": "json_object"},
max_tokens=512,
)
response = await llm_client.complete(request)
import json
parsed = json.loads(response.content)
case_number = parsed.get("case_number")
if case_number and case_number != "未找到":
if re.search(r"[\u4e00-\u9fa5]|[\d()]", case_number):
return case_number
except Exception as e:
log.warning("LLM case number extraction failed: %s", e)
return None
@@ -0,0 +1,80 @@
"""Create leaudit OCR/LLM/VLM clients from docauditai's env.{port} config."""
from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from core.config import (
OCR_CONFIG,
DEFAULT_BASE_URL,
DEFAULT_LLM_MODEL,
DEFAULT_API_KEY,
DEFAULT_VLM_BASE_URL,
DEFAULT_VLM_MODEL,
)
if TYPE_CHECKING:
from leaudit.llm.base import BaseLLMClient
from leaudit.llm.vlm_base import BaseVLMClient
from leaudit.ocr.base import BaseOCRClient
log = logging.getLogger(__name__)
def create_ocr_client() -> BaseOCRClient:
"""Create a leaudit ChandraOCRClient from LEAUDIT_OCR_URL config."""
import os
from leaudit.ocr.chandra_client import ChandraOCRClient
base_url = os.getenv("LEAUDIT_OCR_URL", "").rstrip("/")
if not base_url:
base_url = OCR_CONFIG["API_URL"].rsplit("/api/v1/ocr", 1)[0]
timeout = float(OCR_CONFIG["TIMEOUT"])
client = ChandraOCRClient(
base_url=base_url,
timeout=timeout,
include_images=True,
)
log.info("leaudit OCR client created: %s (timeout=%ss)", base_url, timeout)
return client
def create_llm_client() -> BaseLLMClient:
"""Create a leaudit OpenAICompatibleClient from docauditai's LLM config."""
from leaudit.llm.openai_client import OpenAICompatibleClient
base_url = DEFAULT_BASE_URL
model = DEFAULT_LLM_MODEL
api_key = DEFAULT_API_KEY or "no-key"
client = OpenAICompatibleClient(
api_key=api_key,
base_url=base_url,
default_model=model,
timeout=120.0,
)
log.info("leaudit LLM client created: %s (model=%s)", base_url, model)
return client
def create_vlm_client() -> BaseVLMClient | None:
"""Create a leaudit QwenVLMClient from docauditai's VLM config."""
from leaudit.llm.qwen_vlm_client import QwenVLMClient
base_url = DEFAULT_VLM_BASE_URL
model = DEFAULT_VLM_MODEL
api_key = DEFAULT_API_KEY or "no-key"
if not base_url or not model:
log.info("leaudit VLM client skipped: no VLM config")
return None
client = QwenVLMClient(
base_url=base_url,
api_key=api_key,
model=model,
)
log.info("leaudit VLM client created: %s (model=%s)", base_url, model)
return client
@@ -0,0 +1,132 @@
"""Build leaudit execution context from docauditai document data.
Currently leaudit's pipeline in docauditai bypasses leaudit's own
``AuditCtx`` / ``AuditService`` and calls engine modules directly.
This module encapsulates the pre-execution setup that currently lives
inlined in ``pipeline.py`` and ``tasks.py``:
- Resolve local file path (download from OSS to temp if needed)
- Determine RulesFile (from document metadata, type binding, or
content classification)
- Prepare OCR/LLM/VLM client references
"""
from __future__ import annotations
import logging
import os
import tempfile
from dataclasses import dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING
from leaudit.dsl.schema import RulesFile
from leaudit.llm.base import BaseLLMClient
from leaudit.ocr.base import BaseOCRClient
if TYPE_CHECKING:
from leaudit.llm.vlm_base import BaseVLMClient
log = logging.getLogger(__name__)
@dataclass
class ExecutionContext:
"""Everything leaudit needs to run for one document."""
document_id: int
file_path: Path
rules_file: RulesFile
ocr_client: BaseOCRClient
llm_client: BaseLLMClient | None = None
vlm_client: object | None = None
source_port: int = 8000
tmp_path: Path | None = None
metadata: dict = field(default_factory=dict)
def cleanup(self) -> None:
"""Remove temporary file if one was created."""
if self.tmp_path is not None:
try:
os.remove(self.tmp_path)
except OSError:
pass
class CtxBuilder:
"""Build :class:`ExecutionContext` from docauditai document data.
Handles the glue between docauditai's document model and leaudit's
execution expectations — primarily file-path resolution and rules
selection.
"""
def __init__(
self,
ocr_client: BaseOCRClient | None = None,
llm_client: BaseLLMClient | None = None,
vlm_client: object | None = None,
) -> None:
self.ocr_client = ocr_client
self.llm_client = llm_client
self.vlm_client = vlm_client
async def build(
self,
document_id: int,
file_path: str | Path | None = None,
file_content: bytes | None = None,
filename: str | None = None,
rules_file: RulesFile | None = None,
*,
source_port: int = 8000,
) -> ExecutionContext:
"""Build a ready-to-use execution context.
At least one of *file_path* or (*file_content* + *filename*)
must be provided.
Args:
document_id: docauditai document ID.
file_path: Existing local path to the document file.
file_content: Raw bytes (from DB or OSS) — a temp file is
created.
filename: Required when *file_content* is given.
rules_file: Pre-loaded RulesFile. When None, the caller
must resolve after OCR classification.
source_port: Instance port.
Returns:
ExecutionContext ready for pipeline.run().
"""
tmp_path: Path | None = None
if file_path is not None:
resolved = Path(file_path)
elif file_content is not None and filename is not None:
suffix = self._suffix(filename)
tmp = tempfile.NamedTemporaryFile(suffix=suffix, delete=False)
tmp.write(file_content)
tmp.close()
resolved = Path(tmp.name)
tmp_path = resolved
else:
raise ValueError(
"Either file_path or (file_content + filename) is required"
)
return ExecutionContext(
document_id=document_id,
file_path=resolved,
rules_file=rules_file, # type: ignore[arg-type]
ocr_client=self.ocr_client, # type: ignore[arg-type]
llm_client=self.llm_client,
vlm_client=self.vlm_client,
source_port=source_port,
tmp_path=tmp_path,
)
@staticmethod
def _suffix(filename: str) -> str:
_, ext = os.path.splitext(filename)
return ext if ext else ".pdf"
@@ -0,0 +1,272 @@
"""Bridge-side OCR post-processing for leaudit integration.
Keeps docauditai-specific fixes outside ``services/leaudit/**``:
- DOCX embedded-image visuals can be refined once more with the VLM after
the merged ``OcrResult`` is built.
- Cross-page seals with missing completeness flags are normalized so the
legacy compatibility checks have a stable shape to consume.
"""
from __future__ import annotations
import logging
from io import BytesIO
from pathlib import Path
from leaudit.ocr.base import BaseOCRClient
from leaudit.ocr.models import OcrResult, VisualManifestItem
log = logging.getLogger(__name__)
class BridgeOCRClient(BaseOCRClient):
"""Wrap an OCR client and apply integration-side post-processing."""
def __init__(
self,
inner: BaseOCRClient,
*,
vlm_client: object | None = None,
vlm_concurrency: int = 6,
) -> None:
self.inner = inner
self.vlm_client = vlm_client
self.vlm_concurrency = vlm_concurrency
async def ocr(self, file_path: Path | str) -> OcrResult:
path = Path(file_path)
result = await self.inner.ocr(path)
await postprocess_ocr_result(
result,
file_path=path,
vlm_client=self.vlm_client,
vlm_concurrency=self.vlm_concurrency,
)
return result
async def postprocess_ocr_result(
ocr_result: OcrResult,
*,
file_path: Path,
vlm_client: object | None = None,
vlm_concurrency: int = 6,
) -> OcrResult:
"""Apply bridge-side visual repairs without touching leaudit core."""
suffix = file_path.suffix.lower()
if suffix not in {".docx", ".doc", ".wps"}:
return ocr_result
await _maybe_refine_docx_visuals(
ocr_result,
vlm_client=vlm_client,
concurrency=vlm_concurrency,
)
await _inject_docx_signature_candidates(
ocr_result,
vlm_client=vlm_client,
)
_normalize_cross_page_seals(ocr_result)
return ocr_result
async def _maybe_refine_docx_visuals(
ocr_result: OcrResult,
*,
vlm_client: object | None,
concurrency: int,
) -> None:
vm = ocr_result.visual_manifest
if vlm_client is None or vm is None:
return
if not (vm.seals or vm.signatures or vm.cross_page_seals):
return
try:
from leaudit.ocr.visual_classifier import refine_visual_manifest
await refine_visual_manifest(
ocr_result,
vlm_client,
concurrency=concurrency,
)
except Exception as exc:
log.warning("bridge visual refinement skipped: %s", exc)
async def _inject_docx_signature_candidates(
ocr_result: OcrResult,
*,
vlm_client: object | None,
) -> None:
"""Probe likely handwritten-signature zones on DOCX parent images."""
if vlm_client is None:
return
try:
from PIL import Image
except ImportError:
log.warning("Pillow unavailable, skip DOCX signature candidate probing")
return
parent_to_items: dict[str, list[VisualManifestItem]] = {}
for bucket in (
ocr_result.visual_manifest.seals or [],
ocr_result.visual_manifest.signatures or [],
ocr_result.visual_manifest.cross_page_seals or [],
):
for item in bucket:
parent_key = getattr(item, "parent_image_key", None)
if parent_key:
parent_to_items.setdefault(parent_key, []).append(item)
for parent_key, items in parent_to_items.items():
if any((it.label or "") == "signature" for it in items):
continue
parent_bytes = ocr_result.get_image_bytes(parent_key)
if not parent_bytes:
continue
try:
image = Image.open(BytesIO(parent_bytes))
except Exception as exc:
log.warning("failed to open parent image %s: %s", parent_key, exc)
continue
width, height = image.size
for candidate_bbox in _signature_candidate_boxes(items, width, height):
try:
crop = image.crop(tuple(candidate_bbox))
buf = BytesIO()
crop.save(buf, format="PNG")
result = await _classify_signature_candidate(
vlm_client,
buf.getvalue(),
"这是合同签章页里疑似法人签名的候选区域,请优先判断是否为手写签名。",
)
except Exception as exc:
log.warning("signature probe failed for %s: %s", parent_key, exc)
continue
if getattr(result, "kind", None) != "signature":
continue
page_num = _infer_parent_page_num(items)
ocr_result.visual_manifest.signatures.append(
VisualManifestItem(
page_num=page_num,
bbox=candidate_bbox,
label="signature",
confidence=getattr(result, "confidence", 0.9) or 0.9,
text_match=(getattr(result, "text", None) or "").strip() or None,
alt_text="docx_signature_candidate",
image_key=parent_key,
parent_image_key=parent_key,
)
)
break
async def _classify_signature_candidate(
vlm_client: object,
image_bytes: bytes,
user_hint: str,
) -> object:
"""Classify with one retry using a fresh VLM client when needed."""
try:
return await vlm_client.classify_visual(image_bytes, user_hint=user_hint)
except Exception as exc:
log.warning("signature probe primary VLM failed, retrying fresh client: %s", exc)
try:
from leaudit.llm.qwen_vlm_client import QwenVLMClient
fresh = QwenVLMClient(
base_url=getattr(vlm_client, "base_url"),
api_key=getattr(vlm_client, "api_key", ""),
model=getattr(vlm_client, "model"),
timeout=getattr(vlm_client, "timeout", 90.0),
)
try:
return await fresh.classify_visual(image_bytes, user_hint=user_hint)
finally:
await fresh.close()
except Exception as exc:
raise RuntimeError(exc) from exc
def _signature_candidate_boxes(
items: list[VisualManifestItem],
width: int,
height: int,
) -> list[list[int]]:
candidates: list[list[int]] = []
seen: set[tuple[int, int, int, int]] = set()
for item in items:
seal_type = getattr(item, "seal_type", None)
label = getattr(item, "label", None)
bbox = getattr(item, "bbox", None) or []
if len(bbox) != 4:
continue
x1, y1, x2, y2 = bbox
box_w = max(1, x2 - x1)
box_h = max(1, y2 - y1)
ratio = box_w / box_h
if seal_type == "法人章" or label == "法人章":
continue
if not (0.75 <= ratio <= 1.35):
continue
if box_w < width * 0.10 or box_h < height * 0.10:
continue
cand = [
max(0, int(x1 - box_w * 0.25)),
max(0, int(y1 + box_h * 0.50)),
min(width, int(x2 + box_w * 0.25)),
min(height, int(y2 + box_h * 0.95)),
]
if cand[2] - cand[0] < 24 or cand[3] - cand[1] < 24:
continue
key = tuple(cand)
if key not in seen:
seen.add(key)
candidates.append(cand)
return candidates
def _infer_parent_page_num(items: list[VisualManifestItem]) -> int:
for item in items:
page_num = getattr(item, "page_num", None)
if isinstance(page_num, int):
return page_num
return 0
def _normalize_cross_page_seals(ocr_result: OcrResult) -> None:
"""Fill obvious completeness defaults for bridge-side checks."""
for item in ocr_result.visual_manifest.cross_page_seals or []:
if item.pages and len(item.pages) >= 2:
item.is_complete = True
continue
bbox = item.bbox or []
if len(bbox) == 4:
width = max(1, bbox[2] - bbox[0])
height = max(1, bbox[3] - bbox[1])
ratio = width / height
# DOCX embedded images often contain a complete round seal near the
# page edge; Chandra may still classify it as a seam-seal half by
# geometry. A near-square crop is a strong signal that the visible
# stamp is already complete.
if 0.65 <= ratio <= 1.35:
item.is_complete = True
continue
if item.is_complete is not None:
continue
if item.pages and len(item.pages) == 1:
item.is_complete = False
@@ -0,0 +1,268 @@
"""Main leaudit pipeline orchestrator: OCR → Extract → Evaluate.
Uses leaudit's own pipeline directly (no conversion),
stores results into docauditai's database via StorageAdapter.
"""
from __future__ import annotations
import logging
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Any
from leaudit.dsl.schema import RulesFile
from leaudit.engine.case_file_evaluator import evaluate_extraction
from leaudit.engine.models import EvaluationResult
from leaudit.extraction.bundle import ExtractionBundle
from leaudit.extraction.dispatcher import dispatch_extract
from leaudit.extraction.phase_detection import determine_phase
from leaudit.llm.base import BaseLLMClient
from leaudit.ocr.base import BaseOCRClient
from leaudit.ocr.models import OcrResult
from leaudit_bridge.storage_adapter import StorageAdapter
log = logging.getLogger(__name__)
@dataclass
class PipelineResult:
"""Complete result from the leaudit pipeline."""
ocr_result: OcrResult
extraction_bundle: ExtractionBundle
evaluation_result: EvaluationResult
detected_phase: str
timing: dict[str, float] = field(default_factory=dict)
errors: list[str] = field(default_factory=list)
class LauditPipeline:
"""Run leaudit's full OCR → extraction → evaluation pipeline.
Does NOT use leaudit's own SQLAlchemy storage.
All results are written to docauditai's database via StorageAdapter.
"""
def __init__(
self,
ocr_client: BaseOCRClient,
llm_client: BaseLLMClient | None = None,
storage_adapter: StorageAdapter | None = None,
) -> None:
self.ocr_client = ocr_client
self.llm_client = llm_client
self.storage = storage_adapter or StorageAdapter()
async def run(
self,
document_id: int,
file_path: str | Path,
rules_file: RulesFile | None = None,
*,
source_port: int = 8000,
) -> PipelineResult:
"""Execute the full pipeline for one document.
Args:
document_id: docauditai document ID for DB writes.
file_path: Path to the document file (PDF/DOCX/etc).
rules_file: leaudit RulesFile (parsed from YAML). When None,
the pipeline attempts to load rules from the OCR result's
``rules_file_path`` (set by the classifier).
source_port: Instance port for context switching.
Returns:
PipelineResult with all intermediate and final outputs.
"""
file_path = Path(file_path)
errors: list[str] = []
timing: dict[str, float] = {}
# --- Phase 1: Update status to Cutting ---
await self.storage.update_document_status(document_id, "Cutting")
# --- Phase 2: OCR ---
t0 = time.time()
log.info("[%d] OCR starting: %s", document_id, file_path.name)
ocr_result = await self._run_ocr(file_path)
timing["ocr"] = round(time.time() - t0, 2)
log.info(
"[%d] OCR done: %d pages, %.1fs",
document_id,
len(ocr_result.pages),
timing["ocr"],
)
# --- Resolve rules_file after OCR if not provided ---
if rules_file is None:
rules_file = await self._resolve_rules_from_ocr(ocr_result, document_id)
if rules_file is None:
raise ValueError(
f"Cannot resolve rules_file for document {document_id}. "
"Neither passed explicitly nor classified from OCR content."
)
# --- Save OCR result ---
await self.storage.save_ocr_result(document_id, ocr_result)
# --- Extract & save case number (案件编号) ---
await self._extract_and_save_case_number(document_id, ocr_result)
# --- Phase 3: Extraction ---
t0 = time.time()
await self.storage.update_document_status(document_id, "Extractioning")
log.info("[%d] Extraction starting", document_id)
extraction_bundle = await dispatch_extract(
ocr_result,
rules_file,
llm_client=self.llm_client,
phase="executed",
)
timing["extraction"] = round(time.time() - t0, 2)
if extraction_bundle.all_errors:
errors.extend(extraction_bundle.all_errors)
log.warning(
"[%d] Extraction completed with %d errors",
document_id,
len(extraction_bundle.all_errors),
)
log.info(
"[%d] Extraction done: %d fields, %.1fs",
document_id,
len(extraction_bundle.fields),
timing["extraction"],
)
# --- Save extraction result ---
await self.storage.save_extraction_result(document_id, extraction_bundle)
# --- Resolve field positions from OCR chunks ---
from leaudit.extraction.coordinate_resolver import resolve_bundle_positions
resolve_bundle_positions(extraction_bundle, ocr_result)
positioned_count = sum(
1 for fv in extraction_bundle.fields.values() if fv.position is not None
)
log.info(
"[%d] Coordinate resolution: %d/%d fields positioned",
document_id,
positioned_count,
len(extraction_bundle.fields),
)
# --- Phase 4: Phase detection ---
visual_manifest = extraction_bundle.visual_manifest or ocr_result.visual_manifest
detected_phase = await determine_phase(
extraction_bundle.fields,
llm_client=self.llm_client,
visual_manifest=visual_manifest,
)
log.info("[%d] Detected phase: %s", document_id, detected_phase)
# --- Phase 5: Evaluation ---
t0 = time.time()
await self.storage.update_document_status(document_id, "Evaluationing")
log.info("[%d] Evaluation starting (phase=%s)", document_id, detected_phase)
external_mocks: dict[str, Any] = {}
if self.llm_client is not None:
external_mocks["llm_client"] = self.llm_client
external_mocks["rules_file"] = rules_file
evaluation_result = await evaluate_extraction(
rules_file,
extraction_bundle,
visual_manifest=visual_manifest,
phase=detected_phase,
external_mocks=external_mocks,
)
timing["evaluation"] = round(time.time() - t0, 2)
log.info(
"[%d] Evaluation done: %d passed, %d failed, %d skipped, %.1fs",
document_id,
evaluation_result.passed_count,
evaluation_result.failed_count,
evaluation_result.skipped_count,
timing["evaluation"],
)
# --- Save evaluation results ---
await self.storage.save_evaluation_results(
document_id, rules_file, evaluation_result, extraction_bundle,
)
# --- Phase 6: Finalize ---
timing["total"] = round(sum(timing.values()), 2)
await self.storage.update_document_status(document_id, "Processed")
log.info(
"[%d] Pipeline complete: phase=%s, timing=%s",
document_id,
detected_phase,
timing,
)
return PipelineResult(
ocr_result=ocr_result,
extraction_bundle=extraction_bundle,
evaluation_result=evaluation_result,
detected_phase=detected_phase,
timing=timing,
errors=errors,
)
async def _run_ocr(self, file_path: Path) -> OcrResult:
"""Run OCR with error handling."""
try:
return await self.ocr_client.ocr(file_path)
except Exception as e:
log.error("OCR failed for %s: %s", file_path.name, e)
raise
async def _extract_and_save_case_number(
self, document_id: int, ocr_result: OcrResult,
) -> None:
"""Extract case number from OCR and write to database."""
from leaudit_bridge.case_number_extractor import (
extract_case_number_with_llm,
)
case_number = await extract_case_number_with_llm(
ocr_result, llm_client=self.llm_client,
)
if case_number:
await self.storage.update_document_number(document_id, case_number)
log.info("[%d] Case number: %s", document_id, case_number)
else:
log.info("[%d] No case number found", document_id)
async def _resolve_rules_from_ocr(
self, ocr_result: OcrResult, document_id: int,
) -> RulesFile | None:
"""Load rules_file from OCR classification result."""
from leaudit.dsl.loader import load_rules_file
rfp = ocr_result.rules_file_path
if not rfp:
log.warning(
"[%d] No rules_file_path in OCR result, cannot resolve rules",
document_id,
)
return None
try:
rules_file = load_rules_file(rfp)
log.info(
"[%d] Resolved rules from classification: %s (%d rules)",
document_id, rfp, len(rules_file.flat_rules),
)
return rules_file
except Exception as e:
log.error("[%d] Failed to load rules from %s: %s", document_id, rfp, e)
return None
@@ -0,0 +1,303 @@
"""Adapt leaudit raw results into docauditai's standardized format.
Currently this logic is inlined in ``storage_adapter.py``
(``_rule_result_to_row``, ``_bundle_to_extracted``, ``_ocr_to_dict``).
This module extracts the conversion layer so storage_adapter focuses
on persistence (the "adapter" part of result_adapter).
"""
from __future__ import annotations
import logging
import re
from typing import Any
from leaudit.dsl.schema import Rule as DslRule
from leaudit.dsl.schema import RulesFile
from leaudit.engine.models import EvaluationResult, RuleResult
from leaudit.extraction.bundle import ExtractionBundle
from leaudit.extraction.models import FieldValue
from leaudit.ocr.models import OcrResult
log = logging.getLogger(__name__)
# ---------------------------------------------------------------------------
# OCR result → dict
# ---------------------------------------------------------------------------
def ocr_to_dict(ocr: OcrResult) -> dict[str, Any]:
"""Convert OcrResult to a JSON-safe dict for storage."""
result: dict[str, Any] = {
"numPages": len(ocr.pages),
"full_text": ocr.full_text,
"pages": [],
}
for page in ocr.pages:
page_dict: dict[str, Any] = {
"page_num": page.page_num,
"text": page.text,
"page_box": page.page_box,
}
if page.chunks:
page_dict["chunks"] = [
(
{"bbox": c["bbox"], "content": c["content"], "label": c.get("label")}
if isinstance(c, dict) and "bbox" in c and "content" in c
else {
"bbox": c.bbox if hasattr(c, "bbox") else None,
"content": c.content if hasattr(c, "content") else str(c),
"label": c.label if hasattr(c, "label") else None,
}
)
for c in page.chunks
]
if page.bboxes:
page_dict["bboxes"] = page.bboxes
result["pages"].append(page_dict)
if ocr.visual_manifest:
result["visual_manifest"] = ocr.visual_manifest.model_dump(mode="json")
if ocr.images:
result["images"] = ocr.images
return result
# ---------------------------------------------------------------------------
# ExtractionBundle → dict
# ---------------------------------------------------------------------------
def bundle_to_extracted(bundle: ExtractionBundle) -> dict[str, Any]:
"""Convert ExtractionBundle to docauditai's extracted_results format."""
fields: dict[str, Any] = {}
for name, fv in bundle.fields.items():
if isinstance(fv, FieldValue):
field_data = {
"value": fv.value,
"confidence": float(fv.confidence) if fv.confidence else 0.0,
}
if fv.position is not None:
field_data["position"] = fv.position.model_dump(mode="json")
fields[name] = field_data
else:
fields[name] = {"value": fv}
multi_entity: dict[str, Any] = {}
for name, rows in bundle.multi_entity.items():
multi_entity[name] = [
{
k: (v.value if isinstance(v, FieldValue) else v)
for k, v in row.items()
}
if isinstance(row, dict)
else {"value": row}
for row in rows
]
return {
"fields": fields,
"multi_entity": multi_entity,
"derived": dict(bundle.derived) if bundle.derived else {},
"is_case_file": bundle.is_case_file,
}
# ---------------------------------------------------------------------------
# EvaluationResult → per-rule rows
# ---------------------------------------------------------------------------
def rule_result_to_row(
document_id: int,
run_id: int,
rule_result: RuleResult,
rule: Any | None,
bundle: ExtractionBundle,
) -> dict[str, Any]:
"""Convert one RuleResult to a database row dict.
Args:
document_id: docauditai document ID.
run_id: ``leaudit_audit_runs.id`` for this execution.
rule_result: Single rule evaluation result from leaudit.
rule: DSL Rule definition (for metadata lookups).
bundle: The extraction bundle (for field position lookups).
"""
passed = rule_result.passed
pass_msg = ""
fail_msg = ""
if rule_result.messages:
pass_msg = rule_result.messages.get("pass", "")
fail_msg = rule_result.messages.get("fail", "")
elif isinstance(rule, DslRule) and rule.messages:
pass_msg = rule.messages.get("pass", "")
fail_msg = rule.messages.get("fail", "")
relevant_fields = _extract_relevant_fields(rule, bundle)
field_positions = _extract_relevant_field_positions(rule, bundle)
remediation = None
if rule_result.remediation:
remediation = rule_result.remediation.model_dump(mode="json")
rule_meta: dict[str, Any] = {}
if isinstance(rule, DslRule):
if rule.references_laws:
rule_meta["references_laws"] = rule.references_laws
if rule.desc:
rule_meta["desc"] = rule.desc
if rule.group:
rule_meta["group"] = rule.group
return {
"document_id": document_id,
"run_id": run_id,
"rule_id": rule_result.rule_id,
"rule_name": rule_result.name,
"risk": rule_result.risk or "medium",
"score": rule_result.score,
"passed": passed,
"status": rule_result.status,
"skip_reason": rule_result.skip_reason,
"confidence": rule_result.confidence,
"pass_message": pass_msg,
"fail_message": fail_msg,
"stages": [s.model_dump(mode="json") for s in (rule_result.stages or [])],
"extracted_fields": relevant_fields,
"field_positions": field_positions,
"remediation": remediation,
"rule_meta": rule_meta,
"rescue_applied": False,
"rescue_passed": None,
}
def evaluation_summary(eval_result: EvaluationResult) -> dict[str, Any]:
"""Extract summary fields from an EvaluationResult."""
return {
"total_score": eval_result.total_score,
"passed_count": eval_result.passed_count,
"failed_count": eval_result.failed_count,
"skipped_count": eval_result.skipped_count,
"result_status": _result_status(eval_result),
}
def _result_status(eval_result: EvaluationResult) -> str:
if eval_result.errors:
return "error"
if eval_result.failed_count == 0 and eval_result.skipped_count == 0:
return "pass"
if eval_result.failed_count > 0:
return "fail"
return "partial"
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------
def _extract_relevant_fields(
rule: Any, bundle: ExtractionBundle,
) -> dict[str, Any]:
"""Extract field values referenced by a rule's stages."""
relevant: dict[str, Any] = {}
if not rule or not hasattr(rule, "stages") or not rule.stages:
return relevant
for stage in rule.stages:
stage_data = (
stage.model_dump(exclude_none=True)
if hasattr(stage, "model_dump")
else {}
)
extra = stage.extra if hasattr(stage, "extra") else {}
field_names: list[str] = []
for key in ("field", "field1", "field2", "fields"):
val = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
if isinstance(val, list):
field_names.extend(f for f in val if isinstance(f, str))
elif isinstance(val, str):
field_names.append(val)
pairs = stage_data.get("pairs")
if isinstance(pairs, list):
for pair in pairs:
if not isinstance(pair, dict):
continue
for key in ("source", "target", "a", "b"):
ref = pair.get(key)
if isinstance(ref, str):
field_names.append(ref)
prompt = stage_data.get("prompt")
if isinstance(prompt, str):
for m in re.finditer(r"\{\{\s*([^{}]+?)\s*\}\}", prompt):
field_names.append(m.group(1).strip())
for f in field_names:
if f in relevant:
continue
if f not in bundle.fields:
continue
fv = bundle.fields[f]
relevant[f] = fv.value if isinstance(fv, FieldValue) else fv
return relevant
def _extract_relevant_field_positions(
rule: Any, bundle: ExtractionBundle,
) -> dict[str, Any]:
"""Extract position data for fields referenced by a rule's stages."""
positions: dict[str, Any] = {}
if not rule or not hasattr(rule, "stages") or not rule.stages:
return positions
for stage in rule.stages:
stage_data = (
stage.model_dump(exclude_none=True)
if hasattr(stage, "model_dump")
else {}
)
extra = stage.extra if hasattr(stage, "extra") else {}
field_names: list[str] = []
for key in ("field", "field1", "field2", "fields"):
val = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
if isinstance(val, list):
field_names.extend(f for f in val if isinstance(f, str))
elif isinstance(val, str):
field_names.append(val)
pairs = stage_data.get("pairs")
if isinstance(pairs, list):
for pair in pairs:
if not isinstance(pair, dict):
continue
for key in ("source", "target", "a", "b"):
ref = pair.get(key)
if isinstance(ref, str):
field_names.append(ref)
prompt = stage_data.get("prompt")
if isinstance(prompt, str):
for m in re.finditer(r"\{\{\s*([^{}]+?)\s*\}\}", prompt):
field_names.append(m.group(1).strip())
for f in field_names:
if f in positions:
continue
fv = bundle.fields.get(f)
if fv is not None and isinstance(fv, FieldValue) and fv.position is not None:
positions[f] = fv.position.model_dump(mode="json")
return positions
@@ -0,0 +1,55 @@
"""Load leaudit YAML RulesFile from filesystem or MinIO."""
from __future__ import annotations
import logging
from pathlib import Path
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from leaudit.dsl.schema import RulesFile
log = logging.getLogger(__name__)
_DEFAULT_RULES_DIR = Path(__file__).resolve().parents[2] / "rules"
class RulesLoader:
"""Load and cache leaudit RulesFile from YAML files."""
def __init__(self, rules_dir: str | Path | None = None) -> None:
self._rules_dir = Path(rules_dir) if rules_dir else _DEFAULT_RULES_DIR
self._cache: dict[str, RulesFile] = {}
def load(self, rules_path: str) -> RulesFile:
"""Load a RulesFile by relative path under rules_dir, or absolute path."""
from leaudit.dsl.loader import load_rules_file
if rules_path in self._cache:
return self._cache[rules_path]
p = Path(rules_path)
if not p.is_absolute():
p = self._rules_dir / p
log.info("Loading RulesFile: %s", p)
rules_file = load_rules_file(p)
self._cache[rules_path] = rules_file
return rules_file
def load_from_yaml_text(self, yaml_text: str, cache_key: str | None = None) -> RulesFile:
"""Parse a RulesFile from raw YAML string."""
from leaudit.dsl.loader import parse_rules_yaml_text
if cache_key and cache_key in self._cache:
return self._cache[cache_key]
rules_file = parse_rules_yaml_text(yaml_text)
if cache_key:
self._cache[cache_key] = rules_file
return rules_file
def clear_cache(self) -> None:
self._cache.clear()
@@ -0,0 +1,364 @@
"""Storage adapter — write leaudit results into docauditai's PostgreSQL via PostgREST.
Converts leaudit's OcrResult, ExtractionBundle, EvaluationResult
into docauditai's table format and writes via PostgRESTClient.
Uses the new `leaudit_evaluation_results` table for per-rule results.
"""
from __future__ import annotations
import logging
import re
from typing import Any
from leaudit.dsl.schema import RulesFile
from leaudit.engine.models import EvaluationResult, RuleResult
from leaudit.extraction.bundle import ExtractionBundle
from leaudit.extraction.models import FieldValue
from leaudit.ocr.models import OcrResult
log = logging.getLogger(__name__)
def _get_postgrest_client():
"""Lazy import to avoid circular dependency at module load."""
from core.postgrest.client import get_postgrest_client
return get_postgrest_client()
class StorageAdapter:
"""Write leaudit pipeline results to docauditai's database."""
# ---- Document status ----
async def update_document_status(self, document_id: int, status: str) -> None:
"""Update the document's processing status."""
client = _get_postgrest_client()
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={"status": status},
)
log.debug("[%d] Status updated: %s", document_id, status)
# ---- Document number (案件编号) ----
async def update_document_number(self, document_id: int, document_number: str) -> None:
"""Update the document's case number (document_number field)."""
client = _get_postgrest_client()
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={"document_number": document_number},
)
log.info("[%d] document_number updated: %s", document_id, document_number)
# ---- OCR result ----
async def save_ocr_result(self, document_id: int, ocr_result: OcrResult) -> None:
"""Save OCR result to documents.ocr_result and raw_full_text_original."""
client = _get_postgrest_client()
ocr_dict = _ocr_to_dict(ocr_result)
full_text = ocr_result.full_text or "\n".join(p.text for p in ocr_result.pages)
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={
"ocr_result": ocr_dict,
"raw_full_text_original": full_text,
},
)
log.info("[%d] OCR result saved (%d pages)", document_id, len(ocr_result.pages))
# ---- Extraction result ----
async def save_extraction_result(
self, document_id: int, bundle: ExtractionBundle,
) -> None:
"""Save extraction result to documents.extracted_results."""
client = _get_postgrest_client()
extracted = _bundle_to_extracted(bundle)
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={"extracted_results": extracted},
)
log.info(
"[%d] Extraction result saved (%d fields)",
document_id,
len(bundle.fields),
)
# ---- Evaluation results ----
async def save_evaluation_results(
self,
document_id: int,
rules_file: RulesFile,
evaluation: EvaluationResult,
bundle: ExtractionBundle,
) -> None:
"""Save evaluation results to leaudit_evaluation_results table.
One row per rule. Deletes existing results for the document first,
then inserts fresh rows.
"""
client = _get_postgrest_client()
# Delete existing results for this document
await client.delete(
table="leaudit_evaluation_results",
filters={"document_id": f"eq.{document_id}"},
)
# Build rule_id → rule metadata lookup
rule_meta = {}
for rule in rules_file.flat_rules:
rule_meta[rule.rule_id] = rule
# Insert one row per rule result
for rule_result in evaluation.rules:
rule = rule_meta.get(rule_result.rule_id)
row = _rule_result_to_row(document_id, rule_result, rule, bundle)
await client.insert(table="leaudit_evaluation_results", data=row)
log.info(
"[%d] Evaluation results saved: %d passed, %d failed, %d skipped",
document_id,
evaluation.passed_count,
evaluation.failed_count,
evaluation.skipped_count,
)
# ---- Serialization helpers ----
def _ocr_to_dict(ocr: OcrResult) -> dict[str, Any]:
"""Convert OcrResult to a JSON-safe dict for PostgREST storage."""
result: dict[str, Any] = {
"numPages": len(ocr.pages),
"full_text": ocr.full_text,
"pages": [],
}
for page in ocr.pages:
page_dict: dict[str, Any] = {
"page_num": page.page_num,
"text": page.text,
"page_box": page.page_box,
}
if page.chunks:
page_dict["chunks"] = [
(
{"bbox": c["bbox"], "content": c["content"], "label": c.get("label")}
if isinstance(c, dict) and "bbox" in c and "content" in c
else {
"bbox": c.bbox if hasattr(c, "bbox") else None,
"content": c.content if hasattr(c, "content") else str(c),
"label": c.label if hasattr(c, "label") else None,
}
)
for c in page.chunks
]
if page.bboxes:
page_dict["bboxes"] = page.bboxes
result["pages"].append(page_dict)
if ocr.visual_manifest:
result["visual_manifest"] = ocr.visual_manifest.model_dump(mode="json")
if ocr.images:
result["images"] = ocr.images
return result
def _bundle_to_extracted(bundle: ExtractionBundle) -> dict[str, Any]:
"""Convert ExtractionBundle to docauditai's extracted_results format."""
fields: dict[str, Any] = {}
for name, fv in bundle.fields.items():
if isinstance(fv, FieldValue):
field_data = {
"value": fv.value,
"confidence": float(fv.confidence) if fv.confidence else 0.0,
}
if fv.position is not None:
field_data["position"] = fv.position.model_dump(mode="json")
fields[name] = field_data
else:
fields[name] = {"value": fv}
multi_entity: dict[str, Any] = {}
for name, rows in bundle.multi_entity.items():
multi_entity[name] = [
{
k: (v.value if isinstance(v, FieldValue) else v)
for k, v in row.items()
}
if isinstance(row, dict) else {"value": row}
for row in rows
]
return {
"fields": fields,
"multi_entity": multi_entity,
"derived": dict(bundle.derived) if bundle.derived else {},
"is_case_file": bundle.is_case_file,
}
def _extract_relevant_fields(rule: Any, bundle: ExtractionBundle) -> dict[str, Any]:
"""Extract field values referenced by a rule's stages."""
relevant: dict[str, Any] = {}
if not rule or not hasattr(rule, "stages") or not rule.stages:
return relevant
for stage in rule.stages:
stage_data = stage.model_dump(exclude_none=True) if hasattr(stage, "model_dump") else {}
extra = stage.extra if hasattr(stage, "extra") else {}
field_names: list[str] = []
for key in ("field", "field1", "field2", "fields"):
val = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
if isinstance(val, list):
field_names.extend(f for f in val if isinstance(f, str))
elif isinstance(val, str):
field_names.append(val)
pairs = stage_data.get("pairs")
if isinstance(pairs, list):
for pair in pairs:
if not isinstance(pair, dict):
continue
for key in ("source", "target", "a", "b"):
ref = pair.get(key)
if isinstance(ref, str):
field_names.append(ref)
prompt = stage_data.get("prompt")
if isinstance(prompt, str):
for m in re.finditer(r"\{\{\s*([^{}]+?)\s*\}\}", prompt):
field_names.append(m.group(1).strip())
for f in field_names:
if f in relevant:
continue
if f not in bundle.fields:
continue
fv = bundle.fields[f]
relevant[f] = fv.value if isinstance(fv, FieldValue) else fv
return relevant
def _extract_relevant_field_positions(
rule: Any,
bundle: ExtractionBundle,
) -> dict[str, Any]:
"""Extract position data for fields referenced by a rule's stages."""
positions: dict[str, Any] = {}
if not rule or not hasattr(rule, "stages") or not rule.stages:
return positions
for stage in rule.stages:
stage_data = stage.model_dump(exclude_none=True) if hasattr(stage, "model_dump") else {}
extra = stage.extra if hasattr(stage, "extra") else {}
field_names: list[str] = []
for key in ("field", "field1", "field2", "fields"):
val = getattr(stage, key, None) or extra.get(key) or stage_data.get(key)
if isinstance(val, list):
field_names.extend(f for f in val if isinstance(f, str))
elif isinstance(val, str):
field_names.append(val)
pairs = stage_data.get("pairs")
if isinstance(pairs, list):
for pair in pairs:
if not isinstance(pair, dict):
continue
for key in ("source", "target", "a", "b"):
ref = pair.get(key)
if isinstance(ref, str):
field_names.append(ref)
prompt = stage_data.get("prompt")
if isinstance(prompt, str):
for m in re.finditer(r"\{\{\s*([^{}]+?)\s*\}\}", prompt):
field_names.append(m.group(1).strip())
for f in field_names:
if f in positions:
continue
fv = bundle.fields.get(f)
if fv is not None and isinstance(fv, FieldValue) and fv.position is not None:
positions[f] = fv.position.model_dump(mode="json")
return positions
def _rule_result_to_row(
document_id: int,
rule_result: RuleResult,
rule: Any | None,
bundle: ExtractionBundle,
) -> dict[str, Any]:
"""Convert a RuleResult to a leaudit_evaluation_results row."""
passed = rule_result.passed
# Resolve messages: rule_result → rule definition
pass_msg = ""
fail_msg = ""
if rule_result.messages:
pass_msg = rule_result.messages.get("pass", "")
fail_msg = rule_result.messages.get("fail", "")
elif rule:
from leaudit.dsl.schema import Rule as DslRule
if isinstance(rule, DslRule) and rule.messages:
pass_msg = rule.messages.get("pass", "")
fail_msg = rule.messages.get("fail", "")
# Extract relevant fields
relevant_fields = _extract_relevant_fields(rule, bundle)
# Remediation (if present)
remediation = None
if rule_result.remediation:
remediation = rule_result.remediation.model_dump(mode="json")
# Rule metadata (references_laws, etc.)
rule_meta_data: dict[str, Any] = {}
if rule:
from leaudit.dsl.schema import Rule as DslRule
if isinstance(rule, DslRule):
if rule.references_laws:
rule_meta_data["references_laws"] = rule.references_laws
if rule.desc:
rule_meta_data["desc"] = rule.desc
if rule.group:
rule_meta_data["group"] = rule.group
return {
"document_id": document_id,
"rule_id": rule_result.rule_id,
"rule_name": rule_result.name,
"risk": rule_result.risk or "medium",
"score": rule_result.score,
"passed": passed,
"status": rule_result.status,
"skip_reason": rule_result.skip_reason,
"confidence": rule_result.confidence,
"pass_message": pass_msg,
"fail_message": fail_msg,
"stages": [s.model_dump(mode="json") for s in (rule_result.stages or [])],
"extracted_fields": relevant_fields,
"field_positions": _extract_relevant_field_positions(rule, bundle),
"remediation": remediation,
"rule_meta": rule_meta_data,
}
@@ -0,0 +1,201 @@
"""Celery task for leaudit pipeline processing.
Activated when PIPELINE_MODE=leaudit in env.{port} config.
Replaces the legacy OCR → extraction → evaluation pipeline with
leaudit's YAML-rules-driven approach.
"""
from __future__ import annotations
import asyncio
import os
import tempfile
import time
from typing import Any, Dict, Optional
from core.celery_app_limited import celery_app
from core.postgrest.client import get_postgrest_client
from core.logger import log
from leaudit_bridge import create_pipeline, RulesLoader
@celery_app.task(bind=True, name="leaudit.process_document")
def leaudit_process_document(
self,
document_id: int,
file_content: bytes,
filename: str,
upload_info: Optional[Dict[str, Any]] = None,
source_port: Optional[int] = None,
rules_path: Optional[str] = None,
):
"""Process a document using leaudit's full pipeline.
Steps: OCR → Extraction → Evaluation → Store in docauditai DB.
"""
task_id = self.request.id
log.task.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")
if source_port:
from core.utils.instance_context import set_instance_environment
instance_name = set_instance_environment(source_port)
log.task.info(
f"[任务ID: {task_id}] 实例环境: {instance_name} (端口: {source_port})"
)
if upload_info is None:
upload_info = {}
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
rules_path_resolved = rules_path or _resolve_rules_path(document_id, loop)
# For types with a known mapping (e.g. 行政处罚), pre-load rules_file.
# For types that need content classification (e.g. 行政许可 sub-types),
# rules_path will be None → adapter classifies after OCR → pipeline
# loads rules from ocr_result.rules_file_path.
rules_file = None
if rules_path_resolved:
loader = RulesLoader()
rules_file = loader.load(rules_path_resolved)
log.task.info(
f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} "
f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)"
)
else:
log.task.info(
f"[任务ID: {task_id}] No fixed rules_path — "
"will classify from document content after OCR"
)
suffix = _get_suffix(filename)
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp:
temp.write(file_content)
temp_path = temp.name
pipeline = create_pipeline(rules_path=rules_path_resolved)
t0 = time.time()
result = loop.run_until_complete(
pipeline.run(
document_id=document_id,
file_path=temp_path,
rules_file=rules_file,
source_port=source_port or int(os.getenv("APP_PORT", "8000")),
)
)
elapsed = round(time.time() - t0, 2)
try:
os.remove(temp_path)
except OSError:
pass
log.task.info(
f"[任务ID: {task_id}] leaudit管线完成: phase={result.detected_phase}, "
f"timing={result.timing}, 总耗时={elapsed:.1f}s"
)
return {
"status": "success",
"document_id": document_id,
"phase": result.detected_phase,
"timing": result.timing,
"errors": result.errors,
}
except Exception as e:
log.task.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
try:
loop.run_until_complete(_update_status_safe(document_id, "Failed"))
except Exception:
pass
raise
finally:
loop.close()
# type_id → rules directory mapping (only fixed-mapping types)
# 行政许可 (type_id=2) has 9 sub-types, NOT mapped here —
# must come from document metadata (rules_file_path) or content classification.
_TYPE_ID_RULES_MAP: dict[int, str] = {
3: "行政处罚",
}
def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None:
"""Resolve rules_path: config override → document metadata → type_id mapping."""
from core.config import LEAUDIT_CONFIG
# 1. Config override (when explicitly set)
config_path = LEAUDIT_CONFIG.get("RULES_PATH", "")
if config_path:
return config_path
try:
client = get_postgrest_client()
doc = loop.run_until_complete(
client.select(
table="documents",
filters={"id": f"eq.{document_id}"},
single=True,
)
)
if not doc:
return None
# 2. Document-level override
rfp = doc.get("rules_file_path")
if rfp:
return rfp
# 3. type_id mapping
type_id = doc.get("type_id")
if type_id and type_id in _TYPE_ID_RULES_MAP:
return f"{_TYPE_ID_RULES_MAP[type_id]}/rules.yaml"
except Exception as e:
log.task.warning(f"Failed to resolve rules_path from document: {e}")
return None
async def _update_status_safe(document_id: int, status: str) -> None:
"""Safely update document status, ignoring errors."""
try:
client = get_postgrest_client()
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={"status": status},
)
except Exception:
pass
def _get_suffix(filename: str) -> str:
"""Extract file suffix from filename."""
_, ext = os.path.splitext(filename)
return ext if ext else ".pdf"
def dispatch_leaudit_task(
document_id: int,
file_content: bytes,
filename: str,
upload_info: Optional[Dict[str, Any]] = None,
source_port: Optional[int] = None,
rules_path: Optional[str] = None,
):
"""Dispatch a leaudit processing task."""
return leaudit_process_document.apply_async(
args=[document_id, file_content, filename],
kwargs={
"upload_info": upload_info,
"source_port": source_port or int(os.getenv("APP_PORT", "8000")),
"rules_path": rules_path,
},
)
@@ -0,0 +1,11 @@
"""LeAudit 模型导出。"""
from fastapi_modules.fastapi_leaudit.models.leauditDocument import LeauditDocument
from fastapi_modules.fastapi_leaudit.models.leauditDocumentFile import LeauditDocumentFile
from fastapi_modules.fastapi_leaudit.models.leauditAuditRun import LeauditAuditRun
__all__ = [
"LeauditDocument",
"LeauditDocumentFile",
"LeauditAuditRun",
]
@@ -0,0 +1,77 @@
"""JWT Token 模型 —— jwt_tokens 表。
记录每次签发的 Token 生命周期:签发、刷新、撤销。
"""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import BigInteger, Boolean, DateTime, String, select, update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column
from fastapi_common.fastapi_common_web.models import BaseModel
class JwtToken(BaseModel):
"""JWT Token 记录表。"""
__tablename__ = "jwt_tokens"
Id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
userId: Mapped[int] = mapped_column(BigInteger, comment="用户ID")
tokenJti: Mapped[str] = mapped_column(String(128), comment="Token JTI")
tokenHash: Mapped[str] = mapped_column(String(128), comment="Access Token SHA256")
refreshTokenHash: Mapped[str | None] = mapped_column(String(128), comment="Refresh Token SHA256")
tokenType: Mapped[str] = mapped_column(String(32), default="ACCESS", comment="ACCESS/REFRESH")
deviceId: Mapped[str | None] = mapped_column(String(128))
deviceName: Mapped[str | None] = mapped_column(String(256))
userAgent: Mapped[str | None] = mapped_column(String(512))
ipAddress: Mapped[str | None] = mapped_column(String(64))
issuedAt: Mapped[datetime] = mapped_column(DateTime(timezone=True))
expiresAt: Mapped[datetime] = mapped_column(DateTime(timezone=True))
refreshExpiresAt: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
lastUsedAt: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
isRevoked: Mapped[bool] = mapped_column(Boolean, default=False)
revokedAt: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
revokeReason: Mapped[str | None] = mapped_column(String(256))
@classmethod
async def get_by_jti(cls, session: AsyncSession, jti: str) -> "JwtToken | None":
"""按 JTI 查询 Token 记录。"""
return await session.scalar(select(cls).where(cls.tokenJti == jti))
@classmethod
async def revoke_by_jti(cls, session: AsyncSession, jti: str, reason: str = "") -> None:
"""撤销指定 JTI 的 Token。"""
await session.execute(
update(cls)
.where(cls.tokenJti == jti)
.values(isRevoked=True, revokedAt=datetime.now(), revokeReason=reason)
)
@classmethod
async def revoke_all_user_tokens(cls, session: AsyncSession, userId: int, reason: str = "") -> list[str]:
"""撤销用户的所有活跃 Token,返回被撤销的 JTI 列表。"""
result = await session.execute(
select(cls.tokenJti).where(cls.userId == userId, cls.isRevoked == False)
)
jtis = [row[0] for row in result.fetchall()]
await session.execute(
update(cls)
.where(cls.userId == userId, cls.isRevoked == False)
.values(isRevoked=True, revokedAt=datetime.now(), revokeReason=reason)
)
return jtis
@classmethod
async def cleanup_expired(cls, session: AsyncSession, before: datetime) -> int:
"""清理过期的 Token 记录,返回删除数。"""
result = await session.execute(
select(cls).where(cls.expiresAt < before)
)
rows = result.scalars().all()
for row in rows:
await session.delete(row)
return len(rows)
@@ -0,0 +1,60 @@
"""LeAudit AuditRun 模型 —— leaudit_audit_runs 表。"""
from __future__ import annotations
from datetime import datetime
from sqlalchemy import BigInteger, Boolean, DateTime, Integer, Numeric, String
from sqlalchemy.orm import Mapped, mapped_column
from fastapi_common.fastapi_common_web.models import BaseModel
class LeauditAuditRun(BaseModel):
"""评查运行主表。"""
__tablename__ = "leaudit_audit_runs"
Id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
documentId: Mapped[int] = mapped_column(BigInteger, comment="关联 leaudit_documents.id")
documentFileId: Mapped[int | None] = mapped_column(BigInteger, comment="输入文件ID")
runNo: Mapped[int] = mapped_column(Integer, comment="同一文档第几次执行")
triggerSource: Mapped[str] = mapped_column(String(64), comment="upload/manual/retry/migration/batch")
triggerUserId: Mapped[int | None] = mapped_column(BigInteger, comment="触发人")
taskId: Mapped[str | None] = mapped_column(String(128), comment="Celery 任务 ID")
# 状态
status: Mapped[str] = mapped_column(String(64), default="pending", comment="pending/processing/completed/failed/cancelled")
phase: Mapped[str | None] = mapped_column(String(32), comment="draft/executed")
# 规则溯源
ruleSetId: Mapped[int] = mapped_column(BigInteger, comment="关联 leaudit_rule_sets.id")
ruleVersionId: Mapped[int] = mapped_column(BigInteger, comment="关联 leaudit_rule_versions.id")
ruleTypeId: Mapped[str | None] = mapped_column(String(256), comment="LeAudit metadata.type_id")
ruleSourceOssUrl: Mapped[str | None] = mapped_column(String(2048), comment="规则 YAML OSS 地址")
ruleSourceSha256: Mapped[str | None] = mapped_column(String(64), comment="规则文件 SHA256")
ruleLocalCachePath: Mapped[str | None] = mapped_column(String(1024), comment="本地缓存路径")
# 模型快照
engineVersion: Mapped[str | None] = mapped_column(String(64))
llmProvider: Mapped[str | None] = mapped_column(String(64))
llmModel: Mapped[str | None] = mapped_column(String(128))
vlmProvider: Mapped[str | None] = mapped_column(String(64))
vlmModel: Mapped[str | None] = mapped_column(String(128))
ocrProvider: Mapped[str | None] = mapped_column(String(64))
ocrModel: Mapped[str | None] = mapped_column(String(128))
# Rescue
rescueMode: Mapped[str | None] = mapped_column(String(32), comment="off/tier1/auto")
rescueApplied: Mapped[bool] = mapped_column(Boolean, default=False, comment="是否执行 rescue")
# 结果汇总
totalScore: Mapped[float | None] = mapped_column(Numeric(10, 2))
passedCount: Mapped[int | None] = mapped_column(Integer)
failedCount: Mapped[int | None] = mapped_column(Integer)
skippedCount: Mapped[int | None] = mapped_column(Integer)
resultStatus: Mapped[str | None] = mapped_column(String(32), comment="pass/fail/partial/error")
# 时间
startedAt: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
finishedAt: Mapped[datetime | None] = mapped_column(DateTime(timezone=True))
@@ -0,0 +1,44 @@
"""LeAudit 域文档镜像模型 —— leaudit_documents 表。
通过 biz_document_id 关联业务 documents 表,不复制业务字段。
"""
from __future__ import annotations
from sqlalchemy import BigInteger, String, ForeignKey
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Mapped, mapped_column
from fastapi_common.fastapi_common_web.models import BaseModel
class LeauditDocument(BaseModel):
"""LeAudit 文档镜像表。"""
__tablename__ = "leaudit_documents"
Id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
bizDocumentId: Mapped[int] = mapped_column(BigInteger, unique=True, comment="关联业务 documents.id")
typeId: Mapped[int | None] = mapped_column(BigInteger, comment="文档类型ID")
processingStatus: Mapped[str | None] = mapped_column(String(64), default="waiting", comment="waiting/processing/completed/failed")
currentRunId: Mapped[int | None] = mapped_column(BigInteger, comment="最新有效 run id")
@classmethod
async def get_by_biz_id(cls, session: AsyncSession, bizDocumentId: int) -> "LeauditDocument | None":
"""按业务文档 ID 查询。"""
from sqlalchemy import select
return await session.scalar(select(cls).where(cls.bizDocumentId == bizDocumentId))
@classmethod
async def upsert_by_biz_id(cls, session: AsyncSession, bizDocumentId: int, **fields) -> "LeauditDocument":
"""按业务文档 ID 创建或更新。"""
from sqlalchemy import select
doc = await session.scalar(select(cls).where(cls.bizDocumentId == bizDocumentId))
if doc is None:
doc = cls(bizDocumentId=bizDocumentId, **fields)
session.add(doc)
else:
for k, v in fields.items():
setattr(doc, k, v)
await session.flush()
return doc
@@ -0,0 +1,28 @@
"""LeAudit 文档文件模型 —— leaudit_document_files 表。"""
from __future__ import annotations
from sqlalchemy import BigInteger, Boolean, String
from sqlalchemy.orm import Mapped, mapped_column
from fastapi_common.fastapi_common_web.models import BaseModel
class LeauditDocumentFile(BaseModel):
"""文档文件表。"""
__tablename__ = "leaudit_document_files"
Id: Mapped[int] = mapped_column(BigInteger, primary_key=True, autoincrement=True)
documentId: Mapped[int] = mapped_column(BigInteger, comment="关联 leaudit_documents.id")
fileRole: Mapped[str] = mapped_column(String(64), comment="original/converted_pdf/merged_pdf/temp_input")
fileName: Mapped[str] = mapped_column(String(512), comment="文件名")
fileExt: Mapped[str | None] = mapped_column(String(32), comment="扩展名")
mimeType: Mapped[str | None] = mapped_column(String(128), comment="MIME")
fileSize: Mapped[int | None] = mapped_column(BigInteger, comment="文件大小")
sha256: Mapped[str | None] = mapped_column(String(64), comment="SHA256")
localPath: Mapped[str | None] = mapped_column(String(1024), comment="本地路径")
ossUrl: Mapped[str | None] = mapped_column(String(2048), comment="OSS 地址")
storageProvider: Mapped[str | None] = mapped_column(String(32), comment="oss/minio/local")
isActive: Mapped[bool] = mapped_column(Boolean, default=True, comment="当前生效文件")
createdBy: Mapped[int | None] = mapped_column(BigInteger, comment="上传人")
@@ -0,0 +1,8 @@
"""LeAudit 服务层导出。"""
from fastapi_modules.fastapi_leaudit.services.auditService import IAuditService
from fastapi_modules.fastapi_leaudit.services.authService import IAuthService
from fastapi_modules.fastapi_leaudit.services.permissionService import IPermissionService
from fastapi_modules.fastapi_leaudit.services.ruleService import IRuleService
__all__ = ["IAuditService", "IAuthService", "IPermissionService", "IRuleService"]
@@ -0,0 +1,24 @@
"""评查服务接口。"""
from abc import ABC, abstractmethod
from fastapi_modules.fastapi_leaudit.domian.vo.auditVo import AuditRunVO, AuditResultVO
class IAuditService(ABC):
"""评查服务接口。"""
@abstractmethod
async def Run(self, DocumentId: int) -> AuditRunVO:
"""触发文档评查。"""
...
@abstractmethod
async def GetRunStatus(self, RunId: int) -> AuditRunVO:
"""查询评查运行状态。"""
...
@abstractmethod
async def GetResult(self, RunId: int) -> AuditResultVO:
"""获取评查结果。"""
...
@@ -0,0 +1,22 @@
"""认证服务接口。"""
from abc import ABC, abstractmethod
from fastapi_modules.fastapi_leaudit.domian.vo.auth.loginTokenVo import LoginTokenVO
class IAuthService(ABC):
"""认证服务接口。"""
@abstractmethod
async def PasswordLogin(self, Sub: str, Password: str) -> LoginTokenVO:
"""账密登录。"""
...
@abstractmethod
async def OAuthLogin(self, Sub: str, Username: str | None, Nickname: str | None,
Email: str | None, PhoneNumber: str | None,
OuId: str | None, OuName: str | None,
IsLeader: bool | None, Area: str | None, ExpiresIn: int) -> LoginTokenVO:
"""OAuth 登录。"""
...
@@ -0,0 +1,65 @@
"""评查服务实现。
编排 LeAudit 引擎执行链路:
文档 → OCR → Extract → Evaluate → Rescue → Persist
"""
from fastapi_common.fastapi_common_logger import logger
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.domian.vo.auditVo import AuditRunVO, AuditResultVO
from fastapi_modules.fastapi_leaudit.models import LeauditAuditRun
from fastapi_modules.fastapi_leaudit.services import IAuditService
class AuditServiceImpl(IAuditService):
"""评查服务实现。"""
async def Run(self, DocumentId: int, RuleType: str | None = None, Force: bool = False) -> AuditRunVO:
"""触发文档评查。
实际执行流程由 Celery 任务异步处理。
"""
async with GetAsyncSession() as session:
# TODO: 从 bridge 层获取 pipeline,提交 Celery 任务
logger.info(f"触发评查: documentId={DocumentId}, ruleType={RuleType}")
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "Celery 任务集成待实现")
async def GetRunStatus(self, RunId: int) -> AuditRunVO:
"""查询评查运行状态。"""
async with GetAsyncSession() as session:
run = await session.get(LeauditAuditRun, RunId)
if not run:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查运行记录不存在")
return AuditRunVO(
runId=run.Id,
documentId=run.documentId,
runNo=run.runNo,
status=run.status,
phase=run.phase,
totalScore=float(run.totalScore) if run.totalScore else None,
passedCount=run.passedCount,
failedCount=run.failedCount,
startedAt=run.startedAt,
finishedAt=run.finishedAt,
)
async def GetResult(self, RunId: int) -> AuditResultVO:
"""获取评查结果。"""
async with GetAsyncSession() as session:
run = await session.get(LeauditAuditRun, RunId)
if not run:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查运行记录不存在")
# TODO: 从 leaudit_rule_results 表查询规则级结果
return AuditResultVO(
runId=run.Id,
totalScore=float(run.totalScore) if run.totalScore else None,
passedCount=run.passedCount or 0,
failedCount=run.failedCount or 0,
skippedCount=run.skippedCount or 0,
phase=run.phase,
rescueApplied=run.rescueApplied or False,
rules=[],
)
@@ -0,0 +1,162 @@
"""认证服务实现。
从旧项目 app/routes/auth.py 和 app/auth/auth.py 迁移,业务逻辑完全不变。
仅重组为 Controller → Service(interface+impl) → Model 结构。
"""
from fastapi_common.fastapi_common_logger import logger
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_security.jwtService import JwtService
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_modules.fastapi_leaudit.domian.vo.auth.loginTokenVo import LoginTokenVO
from fastapi_modules.fastapi_leaudit.services.authService import IAuthService
class AuthServiceImpl(IAuthService):
"""认证服务实现。"""
async def PasswordLogin(self, Sub: str, Password: str) -> LoginTokenVO:
"""账密登录。
校验 sso_users 表:sub + password + status=0 + deleted_at IS NULL。
安全:统一错误提示"账号或密码错误",防止用户枚举。
"""
async with GetAsyncSession() as session:
from sqlalchemy import select, text
result = await session.execute(
text("SELECT id, sub, username, nick_name, phone_number, email, "
"ou_id, ou_name, is_leader, password, status, deleted_at, "
"try_count, try_login_time, area, tenant_name, dep_name, dep_short_name "
"FROM sso_users WHERE sub = :sub"),
{"sub": Sub},
)
row = result.fetchone()
if not row:
logger.warning(f"登录失败: 用户不存在 - sub={Sub}")
raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")
user = dict(row._mapping)
if user.get("deleted_at") is not None:
logger.warning(f"登录失败: 账号已删除 - sub={Sub}")
raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")
if user.get("status") != 0:
logger.warning(f"登录失败: 账号已禁用 - sub={Sub}, status={user.get('status')}")
raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")
if user.get("password") != Password:
logger.warning(f"登录失败: 密码错误 - sub={Sub}")
raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号或密码错误")
return await self._buildLoginResponse(user, session)
async def OAuthLogin(self, Sub: str, Username: str | None, Nickname: str | None,
Email: str | None, PhoneNumber: str | None,
OuId: str | None, OuName: str | None,
IsLeader: bool | None, Area: str | None, ExpiresIn: int) -> LoginTokenVO:
"""OAuth 登录。验证 sub 是否存在,不存在则自动创建用户。"""
async with GetAsyncSession() as session:
from sqlalchemy import select, text
from datetime import datetime, timezone
result = await session.execute(
text("SELECT id, sub, username, nick_name, phone_number, email, "
"ou_id, ou_name, is_leader, status, deleted_at, "
"area, tenant_name, dep_name, dep_short_name "
"FROM sso_users WHERE sub = :sub"),
{"sub": Sub},
)
row = result.fetchone()
if row:
user = dict(row._mapping)
if user.get("deleted_at") is not None or user.get("status") != 0:
raise LeauditException(StatusCodeEnum.HTTP_401_UNAUTHORIZED, "账号已被禁用或删除")
# 更新最后登录信息
await session.execute(
text("UPDATE sso_users SET username = :username, nick_name = :nick, "
"email = :email, phone_number = :phone, ou_id = :ou_id, "
"ou_name = :ou_name, is_leader = :is_leader, area = :area, "
"updated_at = :now WHERE id = :id"),
{"username": Username, "nick": Nickname, "email": Email,
"phone": PhoneNumber, "ou_id": OuId, "ou_name": OuName,
"is_leader": IsLeader, "area": Area,
"now": datetime.now(timezone.utc), "id": user["id"]},
)
else:
# 自动创建用户
await session.execute(
text("INSERT INTO sso_users (sub, username, nick_name, email, "
"phone_number, ou_id, ou_name, is_leader, area, status) "
"VALUES (:sub, :username, :nick, :email, :phone, :ou_id, "
":ou_name, :is_leader, :area, 0)"),
{"sub": Sub, "username": Username, "nick": Nickname, "email": Email,
"phone": PhoneNumber, "ou_id": OuId, "ou_name": OuName,
"is_leader": IsLeader, "area": Area},
)
await session.commit()
result = await session.execute(
text("SELECT id, sub, username, nick_name, phone_number, email, "
"ou_id, ou_name, is_leader, area, tenant_name, dep_name, dep_short_name "
"FROM sso_users WHERE sub = :sub"),
{"sub": Sub},
)
row = result.fetchone()
user = dict(row._mapping) if row else {}
if not user:
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "用户创建失败")
return await self._buildLoginResponse(user, session)
async def _buildLoginResponse(self, user: dict, session) -> LoginTokenVO:
"""组装登录响应:查询角色 → 签发 JWT → 返回 LoginTokenVO。"""
from sqlalchemy import text
# 查询用户角色
roleResult = await session.execute(
text("SELECT r.role_key FROM user_role ur "
"JOIN roles r ON ur.role_id = r.id "
"WHERE ur.user_id = :uid LIMIT 1"),
{"uid": user["id"]},
)
roleRow = roleResult.fetchone()
userRole = roleRow[0] if roleRow else "common"
# 签发 JWT
expiresIn = 3600 # 默认 1 小时
tokens = JwtService.generate(
userId=user["id"],
username=user.get("username") or user.get("sub", ""),
nickName=user.get("nick_name") or "",
ouId=user.get("ou_id") or "",
ouName=user.get("ou_name") or "",
area=user.get("area"),
userRole=userRole,
)
return LoginTokenVO(
access_token=tokens["access_token"],
token_type="Bearer",
expires_in=expiresIn,
issued_time=tokens.get("issued_time", ""),
user_info={
"user_id": user["id"],
"sub": user.get("sub"),
"username": user.get("username"),
"nick_name": user.get("nick_name"),
"email": user.get("email"),
"phone_number": user.get("phone_number"),
"ou_id": user.get("ou_id"),
"ou_name": user.get("ou_name"),
"is_leader": user.get("is_leader"),
"area": user.get("area"),
"role": userRole,
},
)
@@ -0,0 +1,141 @@
"""权限服务实现。
从旧项目 app/rbac/permission_checker_v2.py 迁移,业务逻辑完全不变。
- 数据库驱动 GRANT/DENY 机制(DENY 优先级更高)
- 支持通配符:document:*:*、*:*:* 等
- 实时查询(无 Redis 缓存)
仅改造为 SQLAlchemy 会话 + 项目统一配置。
"""
from fastapi_common.fastapi_common_logger import logger
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_modules.fastapi_leaudit.services.permissionService import IPermissionService
class PermissionServiceImpl(IPermissionService):
"""权限检查服务实现。"""
async def CheckPermission(self, UserId: int, PermissionKey: str) -> bool:
"""检查用户是否拥有指定权限。
GRANT/DENY 优先级:
1. 精确 DENY → 拒绝
2. 通配符 DENY → 拒绝
3. 精确 GRANT → 通过
4. 通配符 GRANT → 通过
5. 无匹配 → 拒绝
"""
try:
grants, denies = await self._getUserPermissions(UserId)
# DENY 优先
if PermissionKey in denies:
logger.debug(f"[DENY] 精确拒绝: user={UserId}, perm={PermissionKey}")
return False
if self._matchWildcard(PermissionKey, denies):
logger.debug(f"[DENY] 通配符拒绝: user={UserId}, perm={PermissionKey}")
return False
# GRANT
if PermissionKey in grants:
logger.debug(f"[GRANT] 精确授权: user={UserId}, perm={PermissionKey}")
return True
if self._matchWildcard(PermissionKey, grants):
logger.debug(f"[GRANT] 通配符授权: user={UserId}, perm={PermissionKey}")
return True
logger.debug(f"[DENY] 无匹配权限: user={UserId}, perm={PermissionKey}")
return False
except Exception as e:
logger.error(f"权限检查异常: user={UserId}, perm={PermissionKey}, error={e}")
return False # 安全优先:异常时拒绝
async def HasAnyPermission(self, UserId: int, PermissionKeys: list[str]) -> bool:
"""OR 逻辑:任一权限通过即返回 True。"""
for key in PermissionKeys:
if await self.CheckPermission(UserId, key):
return True
return False
async def HasAllPermissions(self, UserId: int, PermissionKeys: list[str]) -> bool:
"""AND 逻辑:全部权限通过才返回 True。"""
for key in PermissionKeys:
if not await self.CheckPermission(UserId, key):
return False
return True
# ------------------------------------------------------------------
# 内部方法
# ------------------------------------------------------------------
async def _getUserPermissions(self, UserId: int) -> tuple[set[str], set[str]]:
"""从数据库查询用户的 GRANT 和 DENY 权限集合。"""
grants: set[str] = set()
denies: set[str] = set()
async with GetAsyncSession() as session:
from sqlalchemy import text
result = await session.execute(
text(
"SELECT p.permission_key, rp.grant_type "
"FROM sso_users u "
"JOIN user_role ur ON u.id = ur.user_id "
"JOIN roles r ON ur.role_id = r.id "
"JOIN role_permissions rp ON r.id = rp.role_id "
"JOIN permissions p ON rp.permission_id = p.id "
"WHERE u.id = :uid"
),
{"uid": UserId},
)
rows = result.fetchall()
if not rows:
logger.warning(f"用户无角色或权限: user_id={UserId}")
return grants, denies
for row in rows:
permKey = row[0]
grantType = row[1]
if grantType == "GRANT":
grants.add(permKey)
elif grantType == "DENY":
denies.add(permKey)
logger.debug(f"用户权限: user={UserId}, grants={len(grants)}, denies={len(denies)}")
return grants, denies
@staticmethod
def _matchWildcard(PermissionKey: str, PermissionSet: set[str]) -> bool:
"""检查 PermissionKey 是否匹配集合中的任一通配符模式。"""
for pattern in PermissionSet:
if PermissionServiceImpl._wildcardMatch(PermissionKey, pattern):
return True
return False
@staticmethod
def _wildcardMatch(PermissionKey: str, Pattern: str) -> bool:
"""通配符匹配。
_wildcardMatch("document:read:all", "document:*:*") → True
_wildcardMatch("document:read:all", "document:read:*") → True
_wildcardMatch("document:read:all", "evaluation:*:*") → False
_wildcardMatch("document:read:all", "*:*:*") → True
"""
keyParts = PermissionKey.split(":")
patternParts = Pattern.split(":")
if len(keyParts) != len(patternParts):
return False
for keyPart, patternPart in zip(keyParts, patternParts):
if patternPart == "*":
continue
if keyPart != patternPart:
return False
return True
@@ -0,0 +1,27 @@
"""权限服务接口。"""
from abc import ABC, abstractmethod
class IPermissionService(ABC):
"""权限检查服务接口。
权限格式:module:resource:action(如 document:list:read
支持通配符:document:*:*、*:*:* 等
支持 GRANT/DENY 机制(DENY 优先级更高)
"""
@abstractmethod
async def CheckPermission(self, UserId: int, PermissionKey: str) -> bool:
"""检查用户是否拥有指定权限。"""
...
@abstractmethod
async def HasAnyPermission(self, UserId: int, PermissionKeys: list[str]) -> bool:
"""检查用户是否拥有任意一个权限(OR 逻辑)。"""
...
@abstractmethod
async def HasAllPermissions(self, UserId: int, PermissionKeys: list[str]) -> bool:
"""检查用户是否拥有所有权限(AND 逻辑)。"""
...
@@ -0,0 +1,24 @@
"""规则服务接口。"""
from abc import ABC, abstractmethod
from fastapi_modules.fastapi_leaudit.domian.vo.ruleVo import RuleSetVO, RuleVersionVO
class IRuleService(ABC):
"""规则服务接口。"""
@abstractmethod
async def ListSets(self) -> list[RuleSetVO]:
"""列出所有规则集。"""
...
@abstractmethod
async def GetVersions(self, RuleType: str) -> list[RuleVersionVO]:
"""获取规则集的所有版本。"""
...
@abstractmethod
async def Publish(self, RuleType: str, VersionId: int) -> RuleVersionVO:
"""发布指定版本。"""
...