134 lines
4.4 KiB
Python
134 lines
4.4 KiB
Python
"""Govdoc Bridge — 输入文件解析器。
|
||
|
||
从 leaudit_document_files 中定位输入文件,从 OSS 下载到本地临时路径。
|
||
"""
|
||
|
||
from __future__ import annotations
|
||
|
||
import hashlib
|
||
import os
|
||
import tempfile
|
||
from dataclasses import dataclass
|
||
from pathlib import Path
|
||
|
||
from fastapi_common.fastapi_common_logger import logger
|
||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||
from fastapi_common.fastapi_common_storage.oss_client import OssClient
|
||
from sqlalchemy import select
|
||
|
||
from fastapi_modules.fastapi_leaudit.models.leauditDocumentFile import LeauditDocumentFile
|
||
|
||
log = logger
|
||
|
||
|
||
@dataclass(frozen=True)
class InputPayload:
    """Immutable input payload required to run the govdoc engine.

    The first three fields are required; the remaining ones default to
    ``None``. Field order is part of the interface because the generated
    ``__init__`` accepts them positionally in this order.
    """

    # Name of the input file.
    fileName: str

    # Extension of the input file.
    fileExt: str

    # Location of the file on the local filesystem.
    localPath: str

    # SHA-256 digest of the file content, when available.
    sha256: str | None = None

    # Size of the file in bytes, when available.
    fileSize: int | None = None

    # Identifier of the document-file record the payload was built from.
    documentFileId: int | None = None

    # Temporary directory holding the file; the caller must clean it up
    # when the task finishes.
    tempDir: str | None = None
|
||
|
||
|
||
class InputResolver:
    """Resolve the input file for the govdoc engine.

    The input is the ``LeauditDocumentFile`` row whose ``fileRole`` is
    ``"original"``. A local path recorded on the row is preferred when it
    still points at an existing file; otherwise the content is downloaded
    from OSS into a freshly created temporary directory.
    """

    def __init__(self, Oss: OssClient | None = None) -> None:
        """Remember the OSS client used for downloads.

        Args:
            Oss: Client to use. When omitted (or falsy) a default
                ``OssClient()`` is constructed.
        """
        self.Oss = Oss or OssClient()

    async def ResolveForDocument(self, documentId: int) -> InputPayload:
        """Resolve the input payload for the given document.

        Selects the active ``"original"`` file row of the document with the
        highest ``Id`` and delegates to :meth:`ResolveFromRow`.

        Args:
            documentId: Identifier of the document whose input is needed.

        Returns:
            The payload produced by :meth:`ResolveFromRow` for the row found.

        Raises:
            ValueError: If no matching file row exists, or if the row has
                neither a usable local path nor an OSS URL.
        """
        async with GetAsyncSession() as session:
            result = await session.execute(
                select(LeauditDocumentFile)
                .where(
                    LeauditDocumentFile.documentId == documentId,
                    LeauditDocumentFile.fileRole == "original",
                    LeauditDocumentFile.isActive.is_(True),
                )
                # Highest Id first, so limit(1) keeps the newest row.
                .order_by(LeauditDocumentFile.Id.desc())
                .limit(1)
            )
            fileRow = result.scalar_one_or_none()

        if fileRow is None:
            raise ValueError(f"未找到文档 {documentId} 的原始文件记录")

        return await self.ResolveFromRow(fileRow)

    async def ResolveFromRow(self, FileRow: LeauditDocumentFile) -> InputPayload:
        """Build the input payload from a file row.

        Args:
            FileRow: The document-file row describing the input.

        Returns:
            A payload pointing at the row's local file when ``localPath``
            names an existing file (``tempDir`` stays ``None``); otherwise
            the payload returned by the OSS download.

        Raises:
            ValueError: If the row has no usable ``localPath`` and no
                ``ossUrl``.
        """
        # Prefer the local copy, but only if the recorded path still exists.
        if FileRow.localPath:
            LocalPath = Path(FileRow.localPath)
            if LocalPath.is_file():
                return InputPayload(
                    fileName=FileRow.fileName,
                    fileExt=FileRow.fileExt or _ext_from_name(FileRow.fileName),
                    localPath=str(LocalPath),
                    sha256=FileRow.sha256,
                    fileSize=FileRow.fileSize,
                    documentFileId=FileRow.Id,
                )

        # Otherwise fall back to downloading from OSS.
        if FileRow.ossUrl:
            return await self._DownloadFromOss(FileRow)

        raise ValueError(
            f"文件 {FileRow.Id} ({FileRow.fileName}) 既无可用 localPath 也无 ossUrl"
        )

    async def _DownloadFromOss(self, FileRow: LeauditDocumentFile) -> InputPayload:
        """Download the row's file from OSS into a new temporary directory.

        Args:
            FileRow: The document-file row; its ``ossUrl`` is downloaded.

        Returns:
            A payload whose ``sha256`` and ``fileSize`` are computed from the
            downloaded bytes and whose ``tempDir`` is the directory created
            here. The caller is responsible for removing that directory.

        Raises:
            Exception: Whatever ``DownloadBytes`` raises is logged and
                re-raised unchanged.
            OSError: If the downloaded content cannot be written; the
                temporary directory is removed before the error propagates.
        """
        # Local import: shutil is only needed for the failure-path cleanup.
        import shutil

        # NOTE(review): DownloadBytes is called without await, so it
        # presumably blocks the event loop while downloading. Consider
        # asyncio.to_thread if OssClient is safe to use from a worker thread.
        try:
            content = self.Oss.DownloadBytes(FileRow.ossUrl)
        except Exception as e:
            log.error(f"从 OSS 下载文件失败: url={FileRow.ossUrl}, error={e}")
            raise

        ext = FileRow.fileExt or _ext_from_name(FileRow.fileName)
        # Keep a real suffix on disk even if the stored extension has no
        # leading dot; the payload's fileExt is reported unchanged.
        suffix = ext if ext.startswith(".") else f".{ext}"

        # The directory is created only after a successful download, so a
        # failed download leaves nothing behind.
        tempDir = tempfile.mkdtemp(prefix="govdoc_input_")
        localPath = os.path.join(tempDir, f"input_{FileRow.Id}{suffix}")
        try:
            with open(localPath, "wb") as f:
                f.write(content)
        except BaseException:
            # The caller never sees tempDir when we fail here, so it cannot
            # clean up; remove the directory ourselves before re-raising.
            shutil.rmtree(tempDir, ignore_errors=True)
            raise

        computedSha = hashlib.sha256(content).hexdigest()
        # A digest mismatch is only reported; the download is still used.
        if FileRow.sha256 and computedSha != FileRow.sha256:
            log.warning(
                f"文件 SHA256 不匹配: expected={FileRow.sha256}, computed={computedSha}"
            )

        log.info(
            f"从 OSS 下载文件: {FileRow.fileName} → {localPath} ({len(content)} bytes)"
        )

        return InputPayload(
            fileName=FileRow.fileName,
            fileExt=ext,
            localPath=localPath,
            sha256=computedSha,
            fileSize=len(content),
            documentFileId=FileRow.Id,
            tempDir=tempDir,
        )
|
||
|
||
|
||
def _ext_from_name(fileName: str) -> str:
|
||
"""从文件名提取扩展名。"""
|
||
_, ext = os.path.splitext(fileName)
|
||
return ext if ext else ".docx"
|