466 lines
19 KiB
Python
466 lines
19 KiB
Python
"""评查服务实现。
|
|
|
|
编排 LeAudit 引擎执行链路:
|
|
文档 → OCR → Extract → Evaluate → Rescue → Persist
|
|
"""
|
|
|
|
from datetime import datetime
|
|
|
|
from fastapi_common.fastapi_common_logger import logger
|
|
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
|
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
|
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
|
from sqlalchemy import select, text
|
|
|
|
from fastapi_modules.fastapi_leaudit.domian.vo.auditVo import (
|
|
AuditArtifactVO,
|
|
AuditFieldResultVO,
|
|
AuditMetricsVO,
|
|
AuditRescueOutcomeVO,
|
|
AuditResultVO,
|
|
AuditRunErrorVO,
|
|
AuditRunVO,
|
|
)
|
|
from fastapi_modules.fastapi_leaudit.leaudit_bridge.tasks import (
|
|
dispatch_leaudit_task,
|
|
resolve_worker_queue,
|
|
)
|
|
from fastapi_modules.fastapi_leaudit.models import (
|
|
LeauditAuditRun,
|
|
LeauditDocument,
|
|
LeauditDocumentFile,
|
|
)
|
|
from fastapi_modules.fastapi_leaudit.services import IAuditService
|
|
|
|
|
|
def _normalize_speed(speed: str | None) -> str:
|
|
"""Normalize front-end speed selection to urgent/normal."""
|
|
normalized = (speed or "").strip().lower()
|
|
if normalized in {"urgent", "high", "fast", "emergency", "紧急"}:
|
|
return "urgent"
|
|
return "normal"
|
|
|
|
|
|
class AuditServiceImpl(IAuditService):
    """Audit service implementation.

    Creates audit runs and dispatches them to a worker queue, and reads back
    run status and the persisted result rows of a run.
    """

    @staticmethod
    def _OptionalFloat(value):
        """Return ``float(value)``, or ``None`` when *value* is ``None``.

        The explicit ``None`` test (rather than a truthiness test) keeps a
        legitimate zero, e.g. a total score of 0, from being reported as
        "no value".
        """
        return None if value is None else float(value)

    @classmethod
    def _ToRunVO(cls, run: LeauditAuditRun) -> AuditRunVO:
        """Build the ``AuditRunVO`` view of a ``LeauditAuditRun`` row."""
        return AuditRunVO(
            runId=run.Id,
            documentId=run.documentId,
            runNo=run.runNo,
            documentFileId=run.documentFileId,
            status=run.status,
            phase=run.phase,
            resultStatus=run.resultStatus,
            ruleSetId=run.ruleSetId,
            ruleVersionId=run.ruleVersionId,
            ruleTypeId=run.ruleTypeId,
            rescueApplied=run.rescueApplied or False,
            # A score of 0 must stay 0.0, not collapse to None.
            totalScore=cls._OptionalFloat(run.totalScore),
            passedCount=run.passedCount,
            failedCount=run.failedCount,
            skippedCount=run.skippedCount,
            startedAt=run.startedAt,
            finishedAt=run.finishedAt,
        )

    @staticmethod
    async def _FetchRows(session, sql: str, RunId: int) -> list:
        """Execute *sql* with ``:run_id`` bound to *RunId*; return all rows as mappings."""
        result = await session.execute(text(sql), {"run_id": RunId})
        return result.mappings().all()

    async def Run(
        self,
        DocumentId: int,
        RuleType: str | None = None,
        Force: bool = False,
        Speed: str = "normal",
    ) -> AuditRunVO:
        """Create an audit run for a document and dispatch it to a worker.

        The run is only created and queued here; it is not executed inside the
        HTTP request.

        Args:
            DocumentId: Primary key of the ``LeauditDocument`` to audit.
            RuleType: Forwarded unchanged to the dispatcher as ``rules_path``.
            Force: When ``False``, an existing run of the document in status
                ``queued``/``running``/``retrying`` is returned instead of
                creating a new one. When ``True``, that check is skipped and
                the trigger source is labelled ``retry`` instead of ``upload``.
            Speed: Front-end speed selection, normalised to urgent/normal and
                appended to the trigger source.

        Returns:
            The view object of the reused active run or of the new run.

        Raises:
            LeauditException: 404 if the document does not exist; 400 if the
                document has no active file version or no usable rule version
                is bound to its type and region; 500 if dispatching fails (the
                run and document are marked ``failed`` first).
        """
        async with GetAsyncSession() as session:
            logger.info(f"触发评查: documentId={DocumentId}, ruleType={RuleType}")
            normalizedSpeed = _normalize_speed(Speed)
            document = await session.get(LeauditDocument, DocumentId)
            if not document:
                raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查文档不存在")

            if not Force:
                # Reuse the newest in-flight run instead of queueing a duplicate.
                activeRunResult = await session.execute(
                    select(LeauditAuditRun)
                    .where(
                        LeauditAuditRun.documentId == DocumentId,
                        LeauditAuditRun.status.in_(("queued", "running", "retrying")),
                    )
                    .order_by(LeauditAuditRun.Id.desc())
                    .limit(1)
                )
                activeRun = activeRunResult.scalar_one_or_none()
                if activeRun:
                    return self._ToRunVO(activeRun)

            # Newest active file version of the document.
            fileResult = await session.execute(
                select(LeauditDocumentFile)
                .where(
                    LeauditDocumentFile.documentId == DocumentId,
                    LeauditDocumentFile.isActive.is_(True),
                )
                .order_by(LeauditDocumentFile.Id.desc())
                .limit(1)
            )
            documentFile = fileResult.scalar_one_or_none()
            if not documentFile:
                raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前文档没有可执行文件版本")

            # Highest run number used so far for this document (0 if none).
            runNoResult = await session.execute(
                select(LeauditAuditRun.runNo)
                .where(LeauditAuditRun.documentId == DocumentId)
                .order_by(LeauditAuditRun.runNo.desc())
                .limit(1)
            )
            latestRunNo = runNoResult.scalar_one_or_none() or 0

            # Highest-priority active binding for the document's type and region.
            bindingResult = await session.execute(
                text(
                    """
                    SELECT
                        rs.id AS rule_set_id,
                        rs.current_version_id AS rule_version_id,
                        rv.oss_url AS rule_source_oss_url,
                        rv.file_sha256 AS rule_source_sha256,
                        rv.metadata_type_id AS rule_type_id
                    FROM leaudit_rule_type_bindings b
                    JOIN leaudit_rule_sets rs ON rs.id = b.rule_set_id
                    LEFT JOIN leaudit_rule_versions rv ON rv.id = rs.current_version_id
                    WHERE b.doc_type_id = :doc_type_id
                      AND b.is_active = true
                      AND b.region = :region
                    ORDER BY b.priority DESC, b.id DESC
                    LIMIT 1
                    """
                ),
                {"doc_type_id": document.typeId, "region": document.region},
            )
            binding = bindingResult.mappings().first()
            if not binding or not binding["rule_set_id"] or not binding["rule_version_id"]:
                raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前文档类型未绑定可用规则版本")

            triggerSource = f"{'retry' if Force else 'upload'}:{normalizedSpeed}"

            run = LeauditAuditRun(
                documentId=DocumentId,
                documentFileId=documentFile.Id,
                runNo=int(latestRunNo) + 1,
                triggerSource=triggerSource,
                status="queued",
                phase="dispatch",
                ruleSetId=int(binding["rule_set_id"]),
                ruleVersionId=int(binding["rule_version_id"]),
                ruleTypeId=binding["rule_type_id"],
                ruleSourceOssUrl=binding["rule_source_oss_url"],
                ruleSourceSha256=binding["rule_source_sha256"],
            )
            session.add(run)
            # Flush so run.Id is available for the document back-reference.
            await session.flush()

            document.currentRunId = run.Id
            document.processingStatus = "queued"
            # Commit before dispatching so the run row exists when it is handed off.
            await session.commit()
            await session.refresh(run)

            try:
                taskId = dispatch_leaudit_task(
                    run_id=run.Id,
                    queue_name=resolve_worker_queue(triggerSource),
                    rules_path=RuleType,
                )
            except Exception as Error:
                # Record the failed hand-off so the run is not left "queued".
                run.status = "failed"
                run.phase = "dispatch"
                run.finishedAt = datetime.now()
                document.processingStatus = "failed"
                await session.commit()
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"投递评查任务失败: {Error}",
                ) from Error

            run.taskId = taskId
            await session.commit()

            await session.refresh(run)
            return self._ToRunVO(run)

    async def GetRunStatus(self, RunId: int) -> AuditRunVO:
        """Return the current state of an audit run.

        Raises:
            LeauditException: 404 if no run with *RunId* exists.
        """
        async with GetAsyncSession() as session:
            run = await session.get(LeauditAuditRun, RunId)
            if not run:
                raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查运行记录不存在")
            return self._ToRunVO(run)

    async def GetResult(self, RunId: int) -> AuditResultVO:
        """Return the full persisted result of an audit run.

        Combines the run row with its rule results, field results, errors,
        rescue outcomes, artifacts (each ordered by ascending id) and the most
        recent metrics row, if any.

        Raises:
            LeauditException: 404 if no run with *RunId* exists.
        """
        toFloat = self._OptionalFloat
        async with GetAsyncSession() as session:
            run = await session.get(LeauditAuditRun, RunId)
            if not run:
                raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查运行记录不存在")

            ruleRows = await self._FetchRows(
                session,
                """
                SELECT
                    rule_id,
                    rule_name,
                    risk,
                    score,
                    passed,
                    status,
                    skip_reason,
                    confidence,
                    pass_message,
                    fail_message,
                    remediation,
                    extracted_fields,
                    field_positions,
                    rescue_applied,
                    rescue_passed
                FROM leaudit_rule_results
                WHERE run_id = :run_id
                ORDER BY id ASC
                """,
                RunId,
            )
            fieldRows = await self._FetchRows(
                session,
                """
                SELECT
                    field_name,
                    field_type,
                    value_text,
                    confidence,
                    grounding_method,
                    fallback_value,
                    raw_value_json,
                    meta_json
                FROM leaudit_field_results
                WHERE run_id = :run_id
                ORDER BY id ASC
                """,
                RunId,
            )
            errorRows = await self._FetchRows(
                session,
                """
                SELECT
                    stage,
                    level,
                    error_code,
                    message,
                    detail_json,
                    created_at
                FROM leaudit_run_errors
                WHERE run_id = :run_id
                ORDER BY id ASC
                """,
                RunId,
            )
            rescueRows = await self._FetchRows(
                session,
                """
                SELECT
                    rule_id,
                    status,
                    diagnosis,
                    diagnosis_confidence,
                    final_status,
                    failure_reason,
                    llm_calls,
                    vlm_calls,
                    duration_ms,
                    requires_human_review,
                    payload
                FROM leaudit_rescue_outcomes
                WHERE run_id = :run_id
                ORDER BY id ASC
                """,
                RunId,
            )
            # Only the newest metrics row is reported.
            metricRows = await self._FetchRows(
                session,
                """
                SELECT
                    ocr_seconds,
                    normalize_seconds,
                    extract_seconds,
                    evaluate_seconds,
                    rescue_seconds,
                    total_seconds,
                    page_count,
                    sub_document_count,
                    field_count,
                    rule_count,
                    llm_call_count,
                    vlm_call_count,
                    rescue_rule_count,
                    artifact_count
                FROM leaudit_run_metrics
                WHERE run_id = :run_id
                ORDER BY id DESC
                LIMIT 1
                """,
                RunId,
            )
            artifactRows = await self._FetchRows(
                session,
                """
                SELECT
                    artifact_type,
                    artifact_role,
                    file_name,
                    file_ext,
                    mime_type,
                    file_size,
                    oss_url,
                    is_persisted
                FROM leaudit_artifacts
                WHERE run_id = :run_id
                ORDER BY id ASC
                """,
                RunId,
            )

            # Rule results are returned as plain dicts keyed by column name.
            rules = [dict(row) for row in ruleRows]
            fields = [
                AuditFieldResultVO(
                    fieldName=row["field_name"],
                    fieldType=row["field_type"],
                    valueText=row["value_text"],
                    confidence=toFloat(row["confidence"]),
                    groundingMethod=row["grounding_method"],
                    fallbackValue=row["fallback_value"],
                    rawValueJson=row["raw_value_json"],
                    metaJson=row["meta_json"],
                )
                for row in fieldRows
            ]
            errors = [
                AuditRunErrorVO(
                    stage=row["stage"],
                    level=row["level"],
                    errorCode=row["error_code"],
                    message=row["message"],
                    detailJson=row["detail_json"],
                    createdAt=row["created_at"],
                )
                for row in errorRows
            ]
            rescueOutcomes = [
                AuditRescueOutcomeVO(
                    ruleId=row["rule_id"],
                    status=row["status"],
                    diagnosis=row["diagnosis"],
                    diagnosisConfidence=toFloat(row["diagnosis_confidence"]),
                    finalStatus=row["final_status"],
                    failureReason=row["failure_reason"],
                    llmCalls=row["llm_calls"],
                    vlmCalls=row["vlm_calls"],
                    durationMs=row["duration_ms"],
                    requiresHumanReview=bool(row["requires_human_review"]),
                    payload=row["payload"],
                )
                for row in rescueRows
            ]
            metricRow = metricRows[0] if metricRows else None
            metrics = (
                AuditMetricsVO(
                    ocrSeconds=toFloat(metricRow["ocr_seconds"]),
                    normalizeSeconds=toFloat(metricRow["normalize_seconds"]),
                    extractSeconds=toFloat(metricRow["extract_seconds"]),
                    evaluateSeconds=toFloat(metricRow["evaluate_seconds"]),
                    rescueSeconds=toFloat(metricRow["rescue_seconds"]),
                    totalSeconds=toFloat(metricRow["total_seconds"]),
                    pageCount=metricRow["page_count"],
                    subDocumentCount=metricRow["sub_document_count"],
                    fieldCount=metricRow["field_count"],
                    ruleCount=metricRow["rule_count"],
                    llmCallCount=metricRow["llm_call_count"],
                    vlmCallCount=metricRow["vlm_call_count"],
                    rescueRuleCount=metricRow["rescue_rule_count"],
                    artifactCount=metricRow["artifact_count"],
                )
                if metricRow is not None
                else None
            )
            artifacts = [
                AuditArtifactVO(
                    artifactType=row["artifact_type"],
                    artifactRole=row["artifact_role"],
                    fileName=row["file_name"],
                    fileExt=row["file_ext"],
                    mimeType=row["mime_type"],
                    fileSize=row["file_size"],
                    ossUrl=row["oss_url"],
                    isPersisted=row["is_persisted"],
                )
                for row in artifactRows
            ]
            return AuditResultVO(
                runId=run.Id,
                documentId=run.documentId,
                documentFileId=run.documentFileId,
                status=run.status,
                # A score of 0 must stay 0.0, not collapse to None.
                totalScore=toFloat(run.totalScore),
                passedCount=run.passedCount or 0,
                failedCount=run.failedCount or 0,
                skippedCount=run.skippedCount or 0,
                phase=run.phase,
                resultStatus=run.resultStatus,
                rescueApplied=run.rescueApplied or False,
                ruleSetId=run.ruleSetId,
                ruleVersionId=run.ruleVersionId,
                startedAt=run.startedAt,
                finishedAt=run.finishedAt,
                rules=rules,
                fields=fields,
                errors=errors,
                rescueOutcomes=rescueOutcomes,
                metrics=metrics,
                artifacts=artifacts,
            )
|