feat: complete M1-M3 infrastructure — OSS client, native execution chain, rule lifecycle API, system docs

- M1: unified OSS client (upload/download/presign) + path utils + config
- M2: rule service with validate/create/publish/rollback + binding CRUD endpoints
- M3: native AuditCtx runner, file/rule resolvers, storage adapter with full persistence
- docs: SYSTEM_OVERVIEW.md as comprehensive architecture reference
- fix: double finalize — terminal state now written once by finalize_run
This commit is contained in:
wren
2026-04-28 11:49:55 +08:00
parent be9fc4856b
commit 246c0e5ded
26 changed files with 1771 additions and 188 deletions
@@ -0,0 +1,10 @@
"""规则发布 DTO。"""
from pydantic import BaseModel, Field
class RulePublishDTO(BaseModel):
"""规则版本发布/回滚请求。"""
versionId: int = Field(..., description="规则版本ID")
operatorUserId: int | None = Field(None, description="操作用户ID")
@@ -0,0 +1,9 @@
"""规则校验 DTO。"""
from pydantic import BaseModel, Field
class RuleValidateDTO(BaseModel):
"""规则 YAML 校验请求。"""
yamlText: str = Field(..., description="规则 YAML 正文")
@@ -0,0 +1,11 @@
"""规则版本创建 DTO。"""
from pydantic import BaseModel, Field
class RuleVersionCreateDTO(BaseModel):
"""创建规则版本请求。"""
yamlText: str = Field(..., description="规则 YAML 正文")
changeNote: str | None = Field(None, description="版本变更说明")
editorUserId: int | None = Field(None, description="编辑者用户ID")
@@ -1,44 +1,30 @@
"""leaudit bridge — use leaudit's full pipeline with docauditai's database storage.
"""LeAudit Bridge 模块。
Directly calls leaudit's OCR → extraction → evaluation pipeline
and persists results into docauditai's PostgreSQL via PostgREST.
Configuration switch (in env.{port}):
PIPELINE_MODE=leaudit → use leaudit pipeline
对平台暴露统一桥接入口,内部逐步从旧的手写 pipeline
迁移到原生 ``AuditCtx`` + ``AuditService`` 路线。
"""
from leaudit_bridge.client_factory import (
from fastapi_modules.fastapi_leaudit.leaudit_bridge.client_factory import (
create_ocr_client,
create_llm_client,
create_vlm_client,
)
from leaudit_bridge.ocr_bridge import BridgeOCRClient
from leaudit_bridge.pipeline import LauditPipeline, PipelineResult
from leaudit_bridge.rules_loader import RulesLoader
from leaudit_bridge.storage_adapter import StorageAdapter
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ocr_bridge import BridgeOCRClient
from fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline import LauditPipeline, PipelineResult
from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
def is_leaudit_mode() -> bool:
"""Check if the system is configured to use the leaudit pipeline."""
from core.config import PIPELINE_MODE
return PIPELINE_MODE == "leaudit"
"""新平台始终使用 leaudit pipeline"""
return True
def create_pipeline(rules_path: str | None = None) -> LauditPipeline:
"""Create a fully configured LauditPipeline from current config.
"""创建旧版兼容 LauditPipeline。
Wraps the raw OCR client with DocNormalizationAdapter so that a single
``.ocr()`` call produces a fully enriched OcrResult with:
- Document classification (type_id + rules_file_path)
- Dossier segmentation (sub-document page mapping)
- Seal/signature enrichment (text, seal_id, party_id)
- Normalized markdown (seal blocks + page separators)
Args:
rules_path: If provided, forces the adapter to use this rules file
for classification and segmentation. When None, the adapter
uses the RulesFileRegistry to classify from document content,
enabling auto-detection of sub-types (e.g. 行政许可 variants).
当前仍保留该入口兼容旧调用方,后续正式执行链应逐步切到
``NativeRunner``。
"""
from pathlib import Path
from leaudit.doc_normalization.adapter import DocNormalizationAdapter
@@ -51,7 +37,7 @@ def create_pipeline(rules_path: str | None = None) -> LauditPipeline:
# Build registry from rules/ directory for content-based classification
registry = None
if rules_path is None:
rules_dir = Path(__file__).resolve().parents[1] / "rules"
rules_dir = Path(__file__).resolve().parents[3] / "rules"
if rules_dir.is_dir():
registry = RulesFileRegistry.from_directory(rules_dir)
@@ -0,0 +1,75 @@
"""Build native leaudit ``AuditCtx`` instances from platform-side inputs."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from leaudit.config.audit_config import AuditConfig
from leaudit.services.audit_ctx import AuditCtx
from leaudit.services.audit_services import AuditServices
@dataclass(frozen=True)
class NativeAuditMetadata:
"""Platform-side metadata kept outside the native ``AuditCtx`` model."""
run_id: int
document_id: int
document_file_id: int | None = None
rule_set_id: int | None = None
rule_version_id: int | None = None
trigger_user_id: int | None = None
extras: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class NativeAuditBuildInput:
"""Everything the bridge knows before constructing a native ``AuditCtx``."""
metadata: NativeAuditMetadata
file_path: str
services: AuditServices
rules_file: Any | None = None
page_range: tuple[int, ...] | None = None
rule_source_path: str | None = None
force_rules_path: str | None = None
config_overrides: dict[str, Any] = field(default_factory=dict)
class AuditCtxBuilder:
"""Translate platform-side run inputs into leaudit's native ``AuditCtx``."""
def build(self, payload: NativeAuditBuildInput) -> AuditCtx:
"""Create a native ``AuditCtx`` ready for ``AuditService.audit``."""
config = self.build_config(
force_rules_path=payload.force_rules_path or payload.rule_source_path,
overrides=payload.config_overrides,
)
return AuditCtx(
document_id=str(payload.metadata.document_id),
rules_file=payload.rules_file,
services=payload.services,
file_path=payload.file_path,
page_range=payload.page_range,
config=config,
)
def build_config(
self,
*,
force_rules_path: str | None = None,
overrides: dict[str, Any] | None = None,
) -> AuditConfig:
"""Build native ``AuditConfig`` from platform-side overrides."""
raw = dict(overrides or {})
if force_rules_path and "force_rules_path" not in raw:
raw["force_rules_path"] = force_rules_path
return AuditConfig(**raw)
__all__ = [
"AuditCtxBuilder",
"NativeAuditBuildInput",
"NativeAuditMetadata",
]
@@ -0,0 +1,132 @@
"""Factory helpers for leaudit's native service-layer orchestration.
This module is the bridge-side assembly point for native leaudit services:
- ``AuditServices``
- ``DocNormalizationService``
- ``ExtractionService``
- ``EvaluationService``
- ``RescueService``
- ``AuditService``
The platform should not construct these objects in controllers/services
directly. Keep all leaudit-native wiring inside ``leaudit_bridge/``.
"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from leaudit.services.audit_service import AuditService
from leaudit.services.audit_services import AuditServices
from leaudit.services.doc_normalization_service import DocNormalizationService
from leaudit.services.evaluation_service import EvaluationService
from leaudit.services.extraction_service import ExtractionService
from leaudit.services.rescue_service import RescueService
from fastapi_modules.fastapi_leaudit.leaudit_bridge.client_factory import (
create_llm_client,
create_ocr_client,
create_vlm_client,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ocr_bridge import BridgeOCRClient
@dataclass(frozen=True)
class NativeServiceBundle:
"""Fully assembled native leaudit service bundle."""
audit_service: AuditService
audit_services: AuditServices
normalization_service: DocNormalizationService
extraction_service: ExtractionService
evaluation_service: EvaluationService
rescue_service: RescueService | None
class AuditServiceFactory:
"""Build native leaudit services for one platform-side run."""
def create_bundle(self, rules_path: str | None = None) -> NativeServiceBundle:
"""Create a fully wired native leaudit service bundle.
``rules_path`` is only used to force the normalization adapter's
classification path when the caller wants a fixed rules file.
"""
normalization_service, audit_services = self._create_normalization_services(
rules_path=rules_path
)
extraction_service = ExtractionService(
session=None,
llm_client=audit_services.llm_client,
)
evaluation_service = EvaluationService(session=None)
rescue_service = RescueService(
session=None,
llm_client=audit_services.llm_client,
vlm_client=audit_services.vlm_client,
extraction_service=extraction_service,
)
audit_service = AuditService(
document_service=None,
normalization_service=normalization_service,
extraction_service=extraction_service,
evaluation_service=evaluation_service,
rescue_service=rescue_service,
services=AuditServices(
llm_client=audit_services.llm_client,
vlm_client=audit_services.vlm_client,
ocr_client=audit_services.ocr_client,
normalization=normalization_service,
extraction=extraction_service,
evaluation=evaluation_service,
),
)
return NativeServiceBundle(
audit_service=audit_service,
audit_services=audit_service.services,
normalization_service=normalization_service,
extraction_service=extraction_service,
evaluation_service=evaluation_service,
rescue_service=rescue_service,
)
def _create_normalization_services(
self, rules_path: str | None = None
) -> tuple[DocNormalizationService, AuditServices]:
"""Create normalization service plus low-level shared clients."""
from leaudit.doc_normalization.adapter import DocNormalizationAdapter
from leaudit.doc_normalization.doc_classifier import RulesFileRegistry
raw_ocr = create_ocr_client()
llm_client = create_llm_client()
vlm_client = create_vlm_client()
registry = None
if rules_path is None:
rules_dir = Path(__file__).resolve().parents[3] / "rules"
if rules_dir.is_dir():
registry = RulesFileRegistry.from_directory(rules_dir)
adapter = DocNormalizationAdapter(
ocr_client=raw_ocr,
registry=registry,
llm_client=llm_client,
vlm_client=vlm_client,
force_rules_path=rules_path,
)
ocr_client = BridgeOCRClient(adapter, vlm_client=vlm_client)
normalization_service = DocNormalizationService(ocr_client)
audit_services = AuditServices(
llm_client=llm_client,
vlm_client=vlm_client,
ocr_client=raw_ocr,
normalization=normalization_service,
)
return normalization_service, audit_services
__all__ = ["AuditServiceFactory", "NativeServiceBundle"]
@@ -5,13 +5,15 @@ from __future__ import annotations
import logging
from typing import TYPE_CHECKING
from core.config import (
OCR_CONFIG,
DEFAULT_BASE_URL,
DEFAULT_LLM_MODEL,
DEFAULT_API_KEY,
DEFAULT_VLM_BASE_URL,
DEFAULT_VLM_MODEL,
from fastapi_admin.config import (
OCR_BASE_URL,
OCR_TIMEOUT,
LLM_BASE_URL,
LLM_MODEL,
LLM_API_KEY,
VLM_BASE_URL,
VLM_MODEL,
VLM_API_KEY,
)
if TYPE_CHECKING:
@@ -29,8 +31,8 @@ def create_ocr_client() -> BaseOCRClient:
base_url = os.getenv("LEAUDIT_OCR_URL", "").rstrip("/")
if not base_url:
base_url = OCR_CONFIG["API_URL"].rsplit("/api/v1/ocr", 1)[0]
timeout = float(OCR_CONFIG["TIMEOUT"])
base_url = OCR_BASE_URL.rstrip("/")
timeout = float(OCR_TIMEOUT)
client = ChandraOCRClient(
base_url=base_url,
@@ -45,9 +47,9 @@ def create_llm_client() -> BaseLLMClient:
"""Create a leaudit OpenAICompatibleClient from docauditai's LLM config."""
from leaudit.llm.openai_client import OpenAICompatibleClient
base_url = DEFAULT_BASE_URL
model = DEFAULT_LLM_MODEL
api_key = DEFAULT_API_KEY or "no-key"
base_url = LLM_BASE_URL
model = LLM_MODEL
api_key = LLM_API_KEY or "no-key"
client = OpenAICompatibleClient(
api_key=api_key,
@@ -63,9 +65,9 @@ def create_vlm_client() -> BaseVLMClient | None:
"""Create a leaudit QwenVLMClient from docauditai's VLM config."""
from leaudit.llm.qwen_vlm_client import QwenVLMClient
base_url = DEFAULT_VLM_BASE_URL
model = DEFAULT_VLM_MODEL
api_key = DEFAULT_API_KEY or "no-key"
base_url = VLM_BASE_URL
model = VLM_MODEL
api_key = VLM_API_KEY or LLM_API_KEY or "no-key"
if not base_url or not model:
log.info("leaudit VLM client skipped: no VLM config")
@@ -0,0 +1,63 @@
"""文档文件来源解析器。"""
from __future__ import annotations
from dataclasses import dataclass
from pathlib import Path
from fastapi_common.fastapi_common_logger import logger
from fastapi_common.fastapi_common_storage.oss_client import OssClient
from fastapi_modules.fastapi_leaudit.models.leauditDocumentFile import LeauditDocumentFile
@dataclass(frozen=True)
class FileSourcePayload:
"""可供执行链消费的文件载荷。"""
fileName: str
fileContent: bytes
sourceType: str
sourcePath: str | None = None
class FileSourceResolver:
"""解析文档文件来源。"""
def __init__(self, Oss: OssClient | None = None) -> None:
self.Oss = Oss or OssClient()
async def ResolvePayload(self, DocumentFile: LeauditDocumentFile) -> FileSourcePayload:
"""解析文档文件,返回任务入口可直接消费的字节载荷。"""
if DocumentFile.localPath:
LocalPath = Path(DocumentFile.localPath)
if LocalPath.is_file():
return FileSourcePayload(
fileName=DocumentFile.fileName,
fileContent=LocalPath.read_bytes(),
sourceType="local",
sourcePath=str(LocalPath),
)
if DocumentFile.ossUrl:
return await self._DownloadFromUrl(
FileName=DocumentFile.fileName,
Url=DocumentFile.ossUrl,
)
raise ValueError("当前文档文件既无可用 localPath,也无可用 ossUrl")
async def _DownloadFromUrl(self, FileName: str, Url: str) -> FileSourcePayload:
"""从 OSS 或 URL 下载文件内容。"""
try:
Content = self.Oss.DownloadBytes(Url)
except Exception as Error:
logger.error(f"下载 OSS 文件失败: url={Url}, error={Error}")
raise
return FileSourcePayload(
fileName=FileName,
fileContent=Content,
sourceType="oss",
sourcePath=Url,
)
@@ -0,0 +1,185 @@
"""Native AuditCtx runner for the platform bridge.
This module is the target execution path after deprecating the old
platform-side hand-written pipeline orchestration.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
from leaudit.services.audit_ctx import AuditCtx
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import (
AuditCtxBuilder,
NativeAuditBuildInput,
NativeAuditMetadata,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditServiceFactory import (
AuditServiceFactory,
NativeServiceBundle,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
@dataclass(frozen=True)
class NativeRunRequest:
"""Platform-side request payload for one native leaudit run."""
metadata: NativeAuditMetadata
local_file_path: str
rules_file: Any | None = None
rule_source_path: str | None = None
rules_path_override: str | None = None
page_range: tuple[int, ...] | None = None
config_overrides: dict[str, Any] = field(default_factory=dict)
@dataclass(frozen=True)
class NativeRunResult:
"""Result of one native ``AuditService.audit(ctx)`` execution."""
ctx: AuditCtx
service_bundle: NativeServiceBundle
metadata: NativeAuditMetadata
class NativeRunner:
"""Bridge-side runner that delegates orchestration to native leaudit."""
def __init__(
self,
*,
service_factory: AuditServiceFactory | None = None,
ctx_builder: AuditCtxBuilder | None = None,
storage_adapter: StorageAdapter | None = None,
) -> None:
self.service_factory = service_factory or AuditServiceFactory()
self.ctx_builder = ctx_builder or AuditCtxBuilder()
self.storage = storage_adapter or StorageAdapter()
async def run(self, request: NativeRunRequest) -> NativeRunResult:
"""Execute one native leaudit run and return the populated ctx.
Persistence is intentionally not mixed into the orchestration step.
The caller can choose when to persist the final ctx to platform tables.
"""
bundle = self.service_factory.create_bundle(
rules_path=request.rules_path_override or request.rule_source_path
)
ctx = self.ctx_builder.build(
NativeAuditBuildInput(
metadata=request.metadata,
file_path=request.local_file_path,
services=bundle.audit_services,
rules_file=request.rules_file,
page_range=request.page_range,
rule_source_path=request.rule_source_path,
force_rules_path=request.rules_path_override,
config_overrides=request.config_overrides,
)
)
ctx = await bundle.audit_service.audit(ctx)
return NativeRunResult(
ctx=ctx,
service_bundle=bundle,
metadata=request.metadata,
)
async def persist_result(self, result: NativeRunResult) -> None:
"""Persist a native run into platform-owned ``leaudit_*`` tables.
"""
document_id = result.metadata.document_id
run_id = result.metadata.run_id
ctx = result.ctx
extraction_errors = list(ctx.extraction_errors)
if not extraction_errors and ctx.extraction is not None:
extraction_errors = list(ctx.extraction.all_errors)
if ctx.normalized_doc is not None:
await self.storage.save_ocr_result(
document_id,
ctx.normalized_doc,
run_id=run_id,
)
if ctx.extraction is not None:
await self.storage.save_extraction_result(
document_id,
ctx.extraction,
run_id=run_id,
)
if ctx.evaluation is not None and ctx.rules_file is not None and ctx.extraction is not None:
await self.storage.save_evaluation_results(
document_id,
ctx.rules_file,
ctx.evaluation,
ctx.extraction,
run_id=run_id,
rule_version_id=result.metadata.rule_version_id,
)
if extraction_errors:
await self.storage.save_run_errors(
document_id,
run_id=run_id,
stage=ctx.phase or "extract",
messages=extraction_errors,
level="warning",
error_code="EXTRACTION_WARNING",
)
if ctx.fallback_tasks:
await self.storage.save_rescue_outcomes(
document_id,
run_id=run_id,
tasks=ctx.fallback_tasks,
)
await self.storage.save_run_metrics(
document_id,
run_id=run_id,
timing=dict(ctx.timing),
page_count=len(ctx.normalized_doc.pages) if ctx.normalized_doc is not None else None,
sub_document_count=len(ctx.extraction.sub_documents) if ctx.extraction is not None and getattr(ctx.extraction, "sub_documents", None) else 0,
field_count=len(ctx.extraction.fields) if ctx.extraction is not None else 0,
rule_count=len(ctx.evaluation.rules) if ctx.evaluation is not None else (len(ctx.rules_file.flat_rules) if ctx.rules_file is not None else 0),
rescue_rule_count=len(ctx.fallback_tasks),
artifact_count=self._estimate_artifact_count(ctx),
)
result_status = "review" if any(task.requires_human_review for task in ctx.fallback_tasks) else self._resolve_result_status(ctx)
await self.storage.finalize_run(
document_id,
run_id=run_id,
result_status=result_status,
rescue_applied=bool(ctx.fallback_tasks),
phase=ctx.phase,
finished=True,
)
def _estimate_artifact_count(self, ctx: AuditCtx) -> int:
"""粗略估算当前运行已经产出的平台产物数。"""
count = 0
if ctx.normalized_doc is not None:
count += 1
if ctx.extraction is not None:
count += 1
if ctx.evaluation is not None:
count += 1
if ctx.fallback_tasks:
count += len(ctx.fallback_tasks)
return count
def _resolve_result_status(self, ctx: AuditCtx) -> str:
"""按原生 AuditCtx 结果推导运行状态。"""
if ctx.evaluation is None:
return "error"
if ctx.evaluation.errors:
return "error"
if ctx.evaluation.failed_count == 0 and ctx.evaluation.skipped_count == 0:
return "pass"
if ctx.evaluation.failed_count > 0:
return "fail"
return "partial"
__all__ = ["NativeRunRequest", "NativeRunResult", "NativeRunner"]
@@ -23,7 +23,7 @@ from leaudit.llm.base import BaseLLMClient
from leaudit.ocr.base import BaseOCRClient
from leaudit.ocr.models import OcrResult
from leaudit_bridge.storage_adapter import StorageAdapter
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
log = logging.getLogger(__name__)
@@ -229,7 +229,7 @@ class LauditPipeline:
self, document_id: int, ocr_result: OcrResult,
) -> None:
"""Extract case number from OCR and write to database."""
from leaudit_bridge.case_number_extractor import (
from fastapi_modules.fastapi_leaudit.leaudit_bridge.case_number_extractor import (
extract_case_number_with_llm,
)
@@ -0,0 +1,86 @@
"""规则 YAML 校验器。"""
from __future__ import annotations
import importlib
import sys
from dataclasses import dataclass
import yaml
from pydantic import ValidationError
from leaudit.dsl.loader import parse_rules_yaml_text
from leaudit.dsl.validator import DSLValidationError, validate as validate_rules
@dataclass(frozen=True)
class RuleValidationPayload:
"""规则校验结果。"""
valid: bool
ruleType: str | None = None
ruleName: str | None = None
versionNo: str | None = None
ruleCount: int = 0
extractCount: int = 0
errors: list[str] | None = None
class RuleValidator:
"""负责规则 YAML 的语法与 DSL 语义校验。"""
_CHECK_MODULES = (
"leaudit.engine.checks.required",
"leaudit.engine.checks.compare",
"leaudit.engine.checks.format_check",
"leaudit.engine.checks.text",
"leaudit.engine.checks.multi_entity",
"leaudit.engine.checks.visual",
"leaudit.engine.checks.external",
"leaudit.engine.checks.assert_check",
"leaudit.engine.checks.code_check",
"leaudit.engine.checks.ai_check",
)
def ValidateYaml(self, YamlText: str) -> RuleValidationPayload:
"""校验 YAML 并返回摘要结果。"""
try:
RulesFile = parse_rules_yaml_text(YamlText)
self._EnsureChecksImported()
validate_rules(RulesFile, registered_primitives=None)
except yaml.YAMLError as Error:
return RuleValidationPayload(valid=False, errors=[f"YAML 语法错误: {Error}"])
except ValidationError as Error:
return RuleValidationPayload(valid=False, errors=[f"Schema 校验失败: {Error}"])
except DSLValidationError as Error:
return RuleValidationPayload(valid=False, errors=[f"DSL 校验失败: {Error}"])
except Exception as Error:
return RuleValidationPayload(valid=False, errors=[f"规则校验失败: {Error}"])
return RuleValidationPayload(
valid=True,
ruleType=RulesFile.metadata.type_id,
ruleName=RulesFile.metadata.name,
versionNo=RulesFile.metadata.version,
ruleCount=len(RulesFile.flat_rules),
extractCount=len(RulesFile.flat_extract),
errors=[],
)
def ParseValidated(self, YamlText: str):
"""解析并返回已通过完整校验的 RulesFile。"""
Validation = self.ValidateYaml(YamlText)
if not Validation.valid:
raise ValueError("; ".join(Validation.errors or ["规则校验失败"]))
return parse_rules_yaml_text(YamlText)
def _EnsureChecksImported(self) -> None:
"""确保所有检查器模块已注册。"""
for ModuleName in self._CHECK_MODULES:
try:
if ModuleName in sys.modules:
importlib.reload(sys.modules[ModuleName])
else:
importlib.import_module(ModuleName)
except Exception:
continue
@@ -0,0 +1,125 @@
"""规则版本来源解析器。"""
from __future__ import annotations
import hashlib
from dataclasses import dataclass
from pathlib import Path
from fastapi_common.fastapi_common_logger import logger
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_storage.oss_client import OssClient
from sqlalchemy import text
@dataclass(frozen=True)
class RuleVersionPayload:
"""规则文件解析结果。"""
localPath: str
sourceType: str
sourcePath: str
ruleVersionId: int | None = None
ruleTypeId: str | None = None
fileSha256: str | None = None
class RuleVersionResolver:
"""按运行记录解析规则 YAML 文件来源。"""
def __init__(self, Oss: OssClient | None = None) -> None:
self.Oss = Oss or OssClient()
async def ResolveForRun(self, RunId: int) -> RuleVersionPayload | None:
"""根据运行记录解析规则文件来源。"""
RunInfo = await self._LoadRunInfo(RunId)
if not RunInfo:
return None
LocalCachePath = RunInfo["rule_local_cache_path"]
if LocalCachePath:
CachePath = Path(LocalCachePath)
if CachePath.is_file():
return RuleVersionPayload(
localPath=str(CachePath),
sourceType="local_cache",
sourcePath=str(CachePath),
ruleVersionId=RunInfo["rule_version_id"],
ruleTypeId=RunInfo["rule_type_id"],
fileSha256=RunInfo["rule_source_sha256"],
)
SourceUrl = RunInfo["rule_source_oss_url"]
if not SourceUrl:
return None
return await self._DownloadFromUrl(
Url=SourceUrl,
RuleVersionId=RunInfo["rule_version_id"],
RuleTypeId=RunInfo["rule_type_id"],
ExpectedSha256=RunInfo["rule_source_sha256"],
)
async def _LoadRunInfo(self, RunId: int) -> dict[str, object] | None:
"""读取运行记录中的规则来源信息。"""
async with GetAsyncSession() as Session:
Result = await Session.execute(
text(
"""
SELECT
rule_version_id,
rule_type_id,
rule_source_oss_url,
rule_source_sha256,
rule_local_cache_path
FROM leaudit_audit_runs
WHERE id = :run_id
LIMIT 1
"""
),
{"run_id": RunId},
)
Row = Result.mappings().first()
return dict(Row) if Row else None
async def _DownloadFromUrl(
self,
*,
Url: str,
RuleVersionId: int | None,
RuleTypeId: str | None,
ExpectedSha256: str | None,
) -> RuleVersionPayload:
"""从 OSS 下载规则 YAML 到本地临时文件。"""
try:
Content = self.Oss.DownloadBytes(Url)
except Exception as Error:
logger.error(f"下载规则 YAML 失败: url={Url}, error={Error}")
raise
ActualSha256 = hashlib.sha256(Content).hexdigest()
if ExpectedSha256 and ActualSha256.lower() != ExpectedSha256.lower():
raise ValueError(
"规则 YAML SHA256 校验失败: "
f"expected={ExpectedSha256}, actual={ActualSha256}"
)
FilePrefix = "leaudit-rule-"
if RuleTypeId:
SafeTypeId = RuleTypeId.replace("/", "_").replace(".", "_")
FilePrefix = f"{FilePrefix}{SafeTypeId}-"
LocalPath = self.Oss.WriteTempBytes(
Content=Content,
Suffix=".yaml",
Prefix=FilePrefix,
)
return RuleVersionPayload(
localPath=LocalPath,
sourceType="oss",
sourcePath=Url,
ruleVersionId=RuleVersionId,
ruleTypeId=RuleTypeId,
fileSha256=ActualSha256,
)
@@ -1,28 +1,38 @@
"""Celery task for leaudit pipeline processing.
Activated when PIPELINE_MODE=leaudit in env.{port} config.
Replaces the legacy OCR → extraction → evaluation pipeline with
leaudit's YAML-rules-driven approach.
"""
"""LeAudit 任务入口。"""
from __future__ import annotations
import asyncio
import os
from pathlib import Path
import tempfile
import time
from typing import Any, Dict, Optional
from core.celery_app_limited import celery_app
from core.postgrest.client import get_postgrest_client
from core.logger import log
from fastapi_common.fastapi_common_logger import logger
from leaudit_bridge import create_pipeline, RulesLoader
from fastapi_admin.config import LEAUDIT_RULES_DIR
from fastapi_modules.fastapi_leaudit.leaudit_bridge.nativeRunner import (
NativeRunRequest,
NativeRunner,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.auditCtxBuilder import (
NativeAuditMetadata,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.ruleVersionResolver import (
RuleVersionResolver,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.rules_loader import RulesLoader
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
# Celery 集成待 P2 阶段实现,当前使用同步占位
# from core.celery_app_limited import celery_app
log = logger
@celery_app.task(bind=True, name="leaudit.process_document")
# P2: Celery 集成后启用 @celery_app.task 装饰器
def leaudit_process_document(
self,
document_id: int,
file_content: bytes,
filename: str,
@@ -30,43 +40,41 @@ def leaudit_process_document(
source_port: Optional[int] = None,
rules_path: Optional[str] = None,
):
"""Process a document using leaudit's full pipeline.
Steps: OCR → Extraction → Evaluation → Store in docauditai DB.
"""
task_id = self.request.id
log.task.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")
"""处理单个文档的 LeAudit 任务。"""
task_id = os.urandom(8).hex()
log.info(f"[任务ID: {task_id}] leaudit管线开始处理: {filename}")
# 新平台:region 通过参数传递,不再依赖 os.environ 切换
if source_port:
from core.utils.instance_context import set_instance_environment
instance_name = set_instance_environment(source_port)
log.task.info(
f"[任务ID: {task_id}] 实例环境: {instance_name} (端口: {source_port})"
)
log.info(f"[任务ID: {task_id}] 来源端口: {source_port}")
if upload_info is None:
upload_info = {}
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
temp_paths: list[str] = []
storage = StorageAdapter()
try:
rules_path_resolved = rules_path or _resolve_rules_path(document_id, loop)
run_id = _resolve_run_id(document_id, upload_info, loop)
rules_resolution = _resolve_rules_runtime(document_id, run_id, rules_path, loop)
loop.run_until_complete(_update_run_status_safe(run_id, "running"))
rules_path_resolved = rules_resolution["rules_path"]
# For types with a known mapping (e.g. 行政处罚), pre-load rules_file.
# For types that need content classification (e.g. 行政许可 sub-types),
# rules_path will be None → adapter classifies after OCR → pipeline
# loads rules from ocr_result.rules_file_path.
rules_file = None
if rules_path_resolved:
loader = RulesLoader()
rules_file = loader.load(rules_path_resolved)
log.task.info(
temp_rule_path = rules_resolution.get("temp_rule_path")
if isinstance(temp_rule_path, str):
temp_paths.append(temp_rule_path)
log.info(
f"[任务ID: {task_id}] RulesFile pre-loaded: {rules_path_resolved} "
f"({len(rules_file.flat_rules)} rules, {len(rules_file.flat_extract)} fields)"
)
else:
log.task.info(
log.info(
f"[任务ID: {task_id}] No fixed rules_path — "
"will classify from document content after OCR"
)
@@ -75,47 +83,79 @@ def leaudit_process_document(
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp:
temp.write(file_content)
temp_path = temp.name
temp_paths.append(temp_path)
pipeline = create_pipeline(rules_path=rules_path_resolved)
runner = NativeRunner()
t0 = time.time()
result = loop.run_until_complete(
pipeline.run(
document_id=document_id,
file_path=temp_path,
rules_file=rules_file,
source_port=source_port or int(os.getenv("APP_PORT", "8000")),
native_result = loop.run_until_complete(
runner.run(
NativeRunRequest(
metadata=NativeAuditMetadata(
run_id=run_id,
document_id=document_id,
rule_version_id=_optional_int(upload_info, "rule_version_id", "ruleVersionId"),
extras={"taskId": task_id},
),
local_file_path=temp_path,
rules_file=rules_file,
rule_source_path=rules_path_resolved,
)
)
)
loop.run_until_complete(runner.persist_result(native_result))
elapsed = round(time.time() - t0, 2)
try:
os.remove(temp_path)
except OSError:
pass
log.task.info(
f"[任务ID: {task_id}] leaudit管线完成: phase={result.detected_phase}, "
f"timing={result.timing}, 总耗时={elapsed:.1f}s"
ctx = native_result.ctx
loop.run_until_complete(_update_run_phase_safe(run_id, ctx.phase))
loop.run_until_complete(_update_run_status_safe(run_id, "completed"))
loop.run_until_complete(_update_status_safe(document_id, "completed"))
log.info(
f"[任务ID: {task_id}] leaudit管线完成: phase={ctx.phase}, "
f"timing={dict(ctx.timing)}, 总耗时={elapsed:.1f}s"
)
return {
"status": "success",
"document_id": document_id,
"phase": result.detected_phase,
"timing": result.timing,
"errors": result.errors,
"run_id": run_id,
"phase": ctx.phase,
"timing": dict(ctx.timing),
"errors": list(ctx.extraction.all_errors) if ctx.extraction is not None else list(ctx.extraction_errors),
}
except Exception as e:
log.task.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
log.error(f"[任务ID: {task_id}] leaudit管线失败: {e}", exc_info=True)
try:
loop.run_until_complete(_update_status_safe(document_id, "Failed"))
loop.run_until_complete(_update_status_safe(document_id, "failed"))
if 'run_id' in locals():
failed_phase = "persist"
if "native_result" in locals():
failed_phase = native_result.ctx.phase or failed_phase
loop.run_until_complete(
storage.fail_run(
document_id,
run_id=run_id,
phase=failed_phase,
message=str(e),
detail_json={
"taskId": task_id,
"filename": filename,
"errorType": type(e).__name__,
},
)
)
except Exception:
pass
raise
finally:
for temp_path in temp_paths:
try:
if Path(temp_path).exists():
os.remove(temp_path)
except OSError:
pass
loop.close()
@@ -129,49 +169,158 @@ _TYPE_ID_RULES_MAP: dict[int, str] = {
def _resolve_rules_path(document_id: int, loop: asyncio.AbstractEventLoop) -> str | None:
"""Resolve rules_path: config override → document metadata → type_id mapping."""
from core.config import LEAUDIT_CONFIG
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from sqlalchemy import text as sa_text
# 1. Config override (when explicitly set)
config_path = LEAUDIT_CONFIG.get("RULES_PATH", "")
if config_path:
return config_path
# 1. Config override (when explicitly set in app.toml)
if LEAUDIT_RULES_DIR:
return LEAUDIT_RULES_DIR
try:
client = get_postgrest_client()
doc = loop.run_until_complete(
client.select(
table="documents",
filters={"id": f"eq.{document_id}"},
single=True,
)
)
if not doc:
return None
async def _fetch():
async with GetAsyncSession() as session:
result = await session.execute(
sa_text("SELECT type_id FROM leaudit_documents WHERE id = :did"),
{"did": document_id},
)
row = result.fetchone()
if row and row[0] and row[0] in _TYPE_ID_RULES_MAP:
return f"{_TYPE_ID_RULES_MAP[row[0]]}/rules.yaml"
return None
# 2. Document-level override
rfp = doc.get("rules_file_path")
if rfp:
return rfp
# 3. type_id mapping
type_id = doc.get("type_id")
if type_id and type_id in _TYPE_ID_RULES_MAP:
return f"{_TYPE_ID_RULES_MAP[type_id]}/rules.yaml"
return loop.run_until_complete(_fetch())
except Exception as e:
log.task.warning(f"Failed to resolve rules_path from document: {e}")
log.warning(f"Failed to resolve rules_path from document: {e}")
return None
async def _update_status_safe(document_id: int, status: str) -> None:
"""Safely update document status, ignoring errors."""
def _resolve_rules_runtime(
document_id: int,
run_id: int,
explicit_rules_path: str | None,
loop: asyncio.AbstractEventLoop,
) -> dict[str, str | None]:
"""解析本次执行使用的规则来源。"""
if explicit_rules_path:
return {
"rules_path": explicit_rules_path,
"temp_rule_path": None,
"source_type": "explicit",
"source_path": explicit_rules_path,
}
resolver = RuleVersionResolver()
try:
client = get_postgrest_client()
await client.update(
table="documents",
filters={"id": f"eq.{document_id}"},
data={"status": status},
)
payload = loop.run_until_complete(resolver.ResolveForRun(run_id))
if payload:
log.info(
f"run_id={run_id} 规则来源已解析: sourceType={payload.sourceType}, "
f"sourcePath={payload.sourcePath}, localPath={payload.localPath}"
)
return {
"rules_path": payload.localPath,
"temp_rule_path": payload.localPath if payload.sourceType == "oss" else None,
"source_type": payload.sourceType,
"source_path": payload.sourcePath,
}
except Exception as e:
log.warning(f"Failed to resolve rule version from run: run_id={run_id}, error={e}")
fallback_rules_path = _resolve_rules_path(document_id, loop)
return {
"rules_path": fallback_rules_path,
"temp_rule_path": None,
"source_type": "legacy_fallback",
"source_path": fallback_rules_path,
}
def _resolve_run_id(
document_id: int,
upload_info: Dict[str, Any] | None,
loop: asyncio.AbstractEventLoop,
) -> int:
"""解析本次任务对应的运行 ID。"""
if upload_info:
for key in ("run_id", "runId"):
value = upload_info.get(key)
if isinstance(value, int):
return value
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from sqlalchemy import text as sa_text
async def _fetch() -> int:
async with GetAsyncSession() as session:
result = await session.execute(
sa_text("SELECT id FROM leaudit_audit_runs WHERE document_id = :did ORDER BY id DESC LIMIT 1"),
{"did": document_id},
)
row = result.fetchone()
if not row:
raise ValueError(f"未找到 document_id={document_id} 对应的 run 记录")
return int(row[0])
return loop.run_until_complete(_fetch())
def _optional_int(payload: Dict[str, Any] | None, *keys: str) -> int | None:
"""从字典中按顺序取可用整数。"""
if not payload:
return None
for key in keys:
value = payload.get(key)
if isinstance(value, int):
return value
return None
async def _update_status_safe(document_id: int, status: str) -> None:
"""Safely update document status via SQLAlchemy, ignoring errors."""
try:
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from sqlalchemy import text as sa_text
async with GetAsyncSession() as session:
await session.execute(
sa_text("UPDATE leaudit_documents SET processing_status = :s, update_time = now() WHERE id = :did"),
{"s": status, "did": document_id},
)
await session.commit()
except Exception:
pass
async def _update_run_status_safe(run_id: int, status: str) -> None:
"""安全更新运行状态。"""
try:
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from sqlalchemy import text as sa_text
async with GetAsyncSession() as session:
await session.execute(
sa_text("UPDATE leaudit_audit_runs SET status = :s, update_time = now() WHERE id = :rid"),
{"s": status, "rid": run_id},
)
await session.commit()
except Exception:
pass
async def _update_run_phase_safe(run_id: int, phase: str | None) -> None:
"""安全更新运行阶段。"""
try:
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from sqlalchemy import text as sa_text
async with GetAsyncSession() as session:
await session.execute(
sa_text("UPDATE leaudit_audit_runs SET phase = :p, update_time = now() WHERE id = :rid"),
{"p": phase, "rid": run_id},
)
await session.commit()
except Exception:
pass
@@ -190,12 +339,16 @@ def dispatch_leaudit_task(
source_port: Optional[int] = None,
rules_path: Optional[str] = None,
):
"""Dispatch a leaudit processing task."""
return leaudit_process_document.apply_async(
args=[document_id, file_content, filename],
kwargs={
"upload_info": upload_info,
"source_port": source_port or int(os.getenv("APP_PORT", "8000")),
"rules_path": rules_path,
},
"""Dispatch a leaudit processing task.
P2: Celery 集成后改用 leaudit_process_document.apply_async(...)
当前阶段直接同步调用。
"""
return leaudit_process_document(
document_id=document_id,
file_content=file_content,
filename=filename,
upload_info=upload_info,
source_port=source_port or int(os.getenv("APP_PORT", "8000")),
rules_path=rules_path,
)
@@ -2,7 +2,8 @@
from fastapi_modules.fastapi_leaudit.services.auditService import IAuditService
from fastapi_modules.fastapi_leaudit.services.authService import IAuthService
from fastapi_modules.fastapi_leaudit.services.ossService import IOssService
from fastapi_modules.fastapi_leaudit.services.permissionService import IPermissionService
from fastapi_modules.fastapi_leaudit.services.ruleService import IRuleService
__all__ = ["IAuditService", "IAuthService", "IPermissionService", "IRuleService"]
__all__ = ["IAuditService", "IAuthService", "IOssService", "IPermissionService", "IRuleService"]
@@ -9,7 +9,7 @@ class IAuditService(ABC):
"""评查服务接口。"""
@abstractmethod
async def Run(self, DocumentId: int) -> AuditRunVO:
async def Run(self, DocumentId: int, RuleType: str | None = None, Force: bool = False) -> AuditRunVO:
"""触发文档评查。"""
...
@@ -4,13 +4,22 @@
文档 → OCR → Extract → Evaluate → Rescue → Persist
"""
from datetime import datetime
from fastapi_common.fastapi_common_logger import logger
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from sqlalchemy import select, text
from fastapi_modules.fastapi_leaudit.domian.vo.auditVo import AuditRunVO, AuditResultVO
from fastapi_modules.fastapi_leaudit.models import LeauditAuditRun
from fastapi_modules.fastapi_leaudit.leaudit_bridge.fileSourceResolver import FileSourceResolver
from fastapi_modules.fastapi_leaudit.leaudit_bridge.tasks import dispatch_leaudit_task
from fastapi_modules.fastapi_leaudit.models import (
LeauditAuditRun,
LeauditDocument,
LeauditDocumentFile,
)
from fastapi_modules.fastapi_leaudit.services import IAuditService
@@ -20,12 +29,116 @@ class AuditServiceImpl(IAuditService):
async def Run(self, DocumentId: int, RuleType: str | None = None, Force: bool = False) -> AuditRunVO:
"""触发文档评查。
实际执行流程由 Celery 任务异步处理
当前阶段同步触发 bridge 执行链,后续再切换为 Celery 异步分发
"""
async with GetAsyncSession() as session:
# TODO: 从 bridge 层获取 pipeline,提交 Celery 任务
logger.info(f"触发评查: documentId={DocumentId}, ruleType={RuleType}")
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "Celery 任务集成待实现")
document = await session.get(LeauditDocument, DocumentId)
if not document:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查文档不存在")
fileResult = await session.execute(
select(LeauditDocumentFile)
.where(
LeauditDocumentFile.documentId == DocumentId,
LeauditDocumentFile.isActive.is_(True),
)
.order_by(LeauditDocumentFile.Id.desc())
.limit(1)
)
documentFile = fileResult.scalar_one_or_none()
if not documentFile:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前文档没有可执行文件版本")
runNoResult = await session.execute(
select(LeauditAuditRun.runNo)
.where(LeauditAuditRun.documentId == DocumentId)
.order_by(LeauditAuditRun.runNo.desc())
.limit(1)
)
latestRunNo = runNoResult.scalar_one_or_none() or 0
bindingResult = await session.execute(
text(
"""
SELECT
rs.id AS rule_set_id,
rs.current_version_id AS rule_version_id,
rv.oss_url AS rule_source_oss_url,
rv.file_sha256 AS rule_source_sha256,
rv.metadata_type_id AS rule_type_id
FROM leaudit_rule_type_bindings b
JOIN leaudit_rule_sets rs ON rs.id = b.rule_set_id
LEFT JOIN leaudit_rule_versions rv ON rv.id = rs.current_version_id
WHERE b.doc_type_id = :doc_type_id
AND b.is_active = true
ORDER BY b.priority DESC, b.id DESC
LIMIT 1
"""
),
{"doc_type_id": document.typeId},
)
binding = bindingResult.mappings().first()
if not binding or not binding["rule_set_id"] or not binding["rule_version_id"]:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前文档类型未绑定可用规则版本")
run = LeauditAuditRun(
documentId=DocumentId,
documentFileId=documentFile.Id,
runNo=int(latestRunNo) + 1,
triggerSource="manual" if not Force else "retry",
status="pending",
ruleSetId=int(binding["rule_set_id"]),
ruleVersionId=int(binding["rule_version_id"]),
ruleTypeId=binding["rule_type_id"],
ruleSourceOssUrl=binding["rule_source_oss_url"],
ruleSourceSha256=binding["rule_source_sha256"],
startedAt=datetime.now(),
)
session.add(run)
await session.flush()
document.currentRunId = run.Id
document.processingStatus = "running"
await session.commit()
await session.refresh(run)
try:
Resolver = FileSourceResolver()
Payload = await Resolver.ResolvePayload(documentFile)
except Exception as Error:
raise LeauditException(
StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
f"读取评查文件失败: {Error}",
) from Error
dispatch_leaudit_task(
document_id=DocumentId,
file_content=Payload.fileContent,
filename=Payload.fileName,
upload_info={
"run_id": run.Id,
"rule_version_id": run.ruleVersionId,
"rule_source_oss_url": run.ruleSourceOssUrl,
"source_type": Payload.sourceType,
"source_path": Payload.sourcePath,
},
rules_path=RuleType,
)
await session.refresh(run)
return AuditRunVO(
runId=run.Id,
documentId=run.documentId,
runNo=run.runNo,
status=run.status,
phase=run.phase,
totalScore=float(run.totalScore) if run.totalScore else None,
passedCount=run.passedCount,
failedCount=run.failedCount,
startedAt=run.startedAt,
finishedAt=run.finishedAt,
)
async def GetRunStatus(self, RunId: int) -> AuditRunVO:
"""查询评查运行状态。"""
@@ -52,7 +165,33 @@ class AuditServiceImpl(IAuditService):
run = await session.get(LeauditAuditRun, RunId)
if not run:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查运行记录不存在")
# TODO: 从 leaudit_rule_results 表查询规则级结果
result = await session.execute(
text(
"""
SELECT
rule_id,
rule_name,
risk,
score,
passed,
status,
skip_reason,
confidence,
pass_message,
fail_message,
remediation,
extracted_fields,
field_positions,
rescue_applied,
rescue_passed
FROM leaudit_rule_results
WHERE run_id = :run_id
ORDER BY id ASC
"""
),
{"run_id": RunId},
)
rules = [dict(row) for row in result.mappings().all()]
return AuditResultVO(
runId=run.Id,
totalScore=float(run.totalScore) if run.totalScore else None,
@@ -61,5 +200,5 @@ class AuditServiceImpl(IAuditService):
skippedCount=run.skippedCount or 0,
phase=run.phase,
rescueApplied=run.rescueApplied or False,
rules=[],
rules=rules,
)
@@ -0,0 +1,69 @@
"""OSS 服务实现。"""
from fastapi_common.fastapi_common_storage.oss_client import OssClient
from fastapi_modules.fastapi_leaudit.services.ossService import IOssService
class OssServiceImpl(IOssService):
"""OSS 服务实现。"""
def __init__(self, Client: OssClient | None = None) -> None:
self.Client = Client or OssClient()
async def DownloadBytes(self, Source: str, Bucket: str | None = None) -> bytes:
"""下载对象内容。"""
return self.Client.DownloadBytes(Source=Source, Bucket=Bucket)
async def DownloadToTempFile(
self,
Source: str,
Suffix: str = "",
Prefix: str = "oss-",
Bucket: str | None = None,
) -> str:
"""下载对象到本地临时文件。"""
return self.Client.DownloadToTempFile(
Source=Source,
Suffix=Suffix,
Prefix=Prefix,
Bucket=Bucket,
)
async def UploadBytes(
self,
ObjectKey: str,
Content: bytes,
ContentType: str = "application/octet-stream",
Bucket: str | None = None,
) -> str:
"""上传二进制内容。"""
return self.Client.UploadBytes(
ObjectKey=ObjectKey,
Content=Content,
ContentType=ContentType,
Bucket=Bucket,
)
async def UploadText(
self,
ObjectKey: str,
Content: str,
ContentType: str = "text/plain; charset=utf-8",
Bucket: str | None = None,
) -> str:
"""上传文本内容。"""
return self.Client.UploadText(
ObjectKey=ObjectKey,
Content=Content,
ContentType=ContentType,
Bucket=Bucket,
)
async def ObjectExists(self, Source: str, Bucket: str | None = None) -> bool:
"""判断对象是否存在。"""
return self.Client.ObjectExists(Source=Source, Bucket=Bucket)
async def PresignGetUrl(self, Source: str, Bucket: str | None = None) -> str:
"""生成对象下载签名 URL。"""
return self.Client.PresignGetUrl(Source=Source, Bucket=Bucket)
@@ -0,0 +1,55 @@
"""OSS 服务接口。"""
from abc import ABC, abstractmethod
class IOssService(ABC):
"""OSS 服务接口。"""
@abstractmethod
async def DownloadBytes(self, Source: str, Bucket: str | None = None) -> bytes:
"""下载对象内容。"""
...
@abstractmethod
async def DownloadToTempFile(
self,
Source: str,
Suffix: str = "",
Prefix: str = "oss-",
Bucket: str | None = None,
) -> str:
"""下载对象到本地临时文件。"""
...
@abstractmethod
async def UploadBytes(
self,
ObjectKey: str,
Content: bytes,
ContentType: str = "application/octet-stream",
Bucket: str | None = None,
) -> str:
"""上传二进制内容。"""
...
@abstractmethod
async def UploadText(
self,
ObjectKey: str,
Content: str,
ContentType: str = "text/plain; charset=utf-8",
Bucket: str | None = None,
) -> str:
"""上传文本内容。"""
...
@abstractmethod
async def ObjectExists(self, Source: str, Bucket: str | None = None) -> bool:
"""判断对象是否存在。"""
...
@abstractmethod
async def PresignGetUrl(self, Source: str, Bucket: str | None = None) -> str:
"""生成对象下载签名 URL。"""
...