Files

674 lines
27 KiB
Python

"""评查服务实现。
编排 LeAudit 引擎执行链路:
文档 → OCR → Extract → Evaluate → Rescue → Persist
"""
from datetime import datetime
from typing import Any
from fastapi_common.fastapi_common_logger import logger
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from sqlalchemy import select, text
from fastapi_modules.fastapi_leaudit.domian.vo.auditVo import (
AuditArtifactVO,
AuditFieldResultVO,
AuditMetricsVO,
AuditRescueOutcomeVO,
AuditResultVO,
AuditRunErrorVO,
AuditRunVO,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.tasks import (
dispatch_leaudit_task,
resolve_worker_queue,
)
from fastapi_modules.fastapi_leaudit.models import (
LeauditAuditRun,
LeauditDocument,
LeauditDocumentFile,
)
from fastapi_modules.fastapi_leaudit.services import IAuditService
def _normalize_speed(speed: str | None) -> str:
"""Normalize front-end speed selection to urgent/normal."""
normalized = (speed or "").strip().lower()
if normalized in {"urgent", "high", "fast", "emergency", "紧急"}:
return "urgent"
return "normal"
def _candidate_binding_tenant_codes(tenant_code: str | None) -> list[str]:
"""Return binding resolution order for one document tenant.
PUBLIC is the platform template source; PROVINCIAL remains only as legacy fallback.
"""
normalized = str(tenant_code or "").strip().upper()
candidates: list[str] = []
if normalized and normalized not in {"PROVINCIAL", "PUBLIC"}:
candidates.append(normalized)
candidates.append("PUBLIC")
if normalized != "PUBLIC":
candidates.append("PROVINCIAL")
return list(dict.fromkeys(candidates))
def _pick_effective_binding(bindings: list[dict], tenant_code: str | None) -> dict | None:
"""Pick the effective binding by tenant inheritance order.
Legacy rows without ``tenant_code`` are treated as the loosest fallback.
"""
if not bindings:
return None
binding_map: dict[str, dict] = {}
empty_binding: dict | None = None
for binding in bindings:
normalized = str(binding.get("tenant_code") or "").strip().upper()
if not normalized:
if empty_binding is None:
empty_binding = binding
continue
binding_map.setdefault(normalized, binding)
for candidate in _candidate_binding_tenant_codes(tenant_code):
matched = binding_map.get(candidate)
if matched is not None:
return matched
return empty_binding
class AuditServiceImpl(IAuditService):
"""评查服务实现。"""
def __init__(self) -> None:
self._column_exists_cache: dict[str, bool] = {}
async def _column_exists(self, session, table_name: str, column_name: str) -> bool:
cache_key = f"{table_name}.{column_name}"
cached = self._column_exists_cache.get(cache_key)
if cached is not None:
return cached
exists = bool(
(
await session.execute(
text(
"""
SELECT EXISTS (
SELECT 1
FROM information_schema.columns
WHERE table_schema = current_schema()
AND table_name = :table_name
AND column_name = :column_name
)
"""
),
{"table_name": table_name, "column_name": column_name},
)
).scalar_one()
)
self._column_exists_cache[cache_key] = exists
return exists
async def _resolve_rule_binding_from_group(
self,
session,
group_id: int | None,
tenant_code: str | None = None,
) -> dict | None:
"""按二级分组解析正式规则绑定。"""
if not group_id:
return None
binding_tenant_expr = (
"COALESCE(NULLIF(BTRIM(rgb.tenant_code), ''), 'PROVINCIAL')"
if await self._column_exists(session, "leaudit_rule_group_bindings", "tenant_code")
else "'PROVINCIAL'"
)
binding_scope_expr = (
"COALESCE(NULLIF(BTRIM(rgb.scope_type), ''), 'PROVINCIAL')"
if await self._column_exists(session, "leaudit_rule_group_bindings", "scope_type")
else "'PROVINCIAL'"
)
result = await session.execute(
text(
f"""
SELECT
rgb.id AS binding_id,
{binding_tenant_expr} AS tenant_code,
{binding_scope_expr} AS scope_type,
rs.id AS rule_set_id,
COALESCE(rs.current_version_id, fallback_rv.id) AS rule_version_id,
COALESCE(current_rv.oss_url, fallback_rv.oss_url) AS rule_source_oss_url,
COALESCE(current_rv.file_sha256, fallback_rv.file_sha256) AS rule_source_sha256,
COALESCE(current_rv.metadata_type_id, fallback_rv.metadata_type_id) AS rule_type_id
FROM leaudit_rule_group_bindings rgb
JOIN leaudit_rule_sets rs ON rs.id = rgb.rule_set_id
LEFT JOIN leaudit_rule_versions current_rv ON current_rv.id = rs.current_version_id
LEFT JOIN LATERAL (
SELECT
rv.id,
rv.oss_url,
rv.file_sha256,
rv.metadata_type_id
FROM leaudit_rule_versions rv
WHERE rv.rule_set_id = rs.id
AND rv.status IN ('published', 'rollback')
ORDER BY rv.version_seq DESC, rv.id DESC
LIMIT 1
) fallback_rv ON TRUE
WHERE rgb.group_id = :group_id
AND rgb.is_active = TRUE
AND rgb.deleted_at IS NULL
ORDER BY rgb.priority DESC, rgb.id ASC
"""
),
{"group_id": int(group_id)},
)
return _pick_effective_binding(list(result.mappings().all()), tenant_code)
async def _resolve_unique_group_binding_by_doc_type(
self,
session,
doc_type_id: int | None,
tenant_code: str | None = None,
) -> dict | None:
"""当文档尚未落 group_id 时,按文档类型唯一子组兜底解析正式绑定。"""
if not doc_type_id:
return None
group_row = (
await session.execute(
text(
"""
SELECT CASE WHEN COUNT(*) = 1 THEN MIN(id) END AS group_id
FROM leaudit_evaluation_point_groups
WHERE document_type_id = :doc_type_id
AND deleted_at IS NULL
AND is_enabled = TRUE
AND COALESCE(pid, 0) <> 0
"""
),
{"doc_type_id": int(doc_type_id)},
)
).mappings().first()
resolved_group_id = int(group_row["group_id"]) if group_row and group_row.get("group_id") is not None else None
return await self._resolve_rule_binding_from_group(session, resolved_group_id, tenant_code)
async def _persist_run_tenant_snapshot(
self,
session,
run_id: int,
*,
tenant_code: str | None,
scope_type: str | None,
group_id: int | None,
rule_binding_id: int | None,
) -> None:
updates: list[str] = []
params: dict[str, Any] = {"run_id": run_id}
optional_values = {
"tenant_code": tenant_code,
"scope_type_snapshot": scope_type,
"group_id_snapshot": group_id,
"rule_binding_id_snapshot": rule_binding_id,
}
for column_name, value in optional_values.items():
if await self._column_exists(session, "leaudit_audit_runs", column_name):
updates.append(f"{column_name} = :{column_name}")
params[column_name] = value
if not updates:
return
await session.execute(
text(
f"""
UPDATE leaudit_audit_runs
SET {", ".join(updates)}
WHERE id = :run_id
"""
),
params,
)
async def Run(
    self,
    DocumentId: int,
    RuleType: str | None = None,
    Force: bool = False,
    Speed: str = "normal",
    TriggerUserId: int | None = None,
) -> AuditRunVO:
    """Create an audit run for a document and dispatch it to a worker.

    Nothing is evaluated inside this call: it records a ``queued`` run,
    points the document at it and hands the run id to
    ``dispatch_leaudit_task``.

    Steps, in order:

    1. Load the document; unless ``Force`` is set, return the newest run
       that is still ``queued``/``running``/``retrying`` instead of creating
       another one.
    2. Pick the newest active ``primary`` file of the document.
    3. Resolve the rule binding from the document's group, falling back to
       the single sub-group of the document type.
    4. Insert the run, store the tenant snapshot, commit, then dispatch and
       store the returned task id on the run.

    Args:
        DocumentId: Primary key of the document to audit.
        RuleType: Forwarded unchanged to the worker as ``rules_path``.
        Force: Skip the active-run short-circuit; also switches the
            trigger-source prefix from ``upload`` to ``retry``.
        Speed: Front-end speed label, normalised to ``urgent``/``normal``
            and appended to the trigger source.
        TriggerUserId: Stored on the run as the triggering user.

    Returns:
        The already-active run or the newly queued one, as ``AuditRunVO``.

    Raises:
        LeauditException: 404 when the document does not exist; 400 when it
            has no active primary file or no usable rule binding; 500 when
            dispatching fails (the run and the document are marked
            ``failed`` and committed first).
    """

    def _as_run_vo(record: LeauditAuditRun) -> AuditRunVO:
        # Single place for the ORM -> VO mapping used by both return paths.
        return AuditRunVO(
            runId=record.Id,
            documentId=record.documentId,
            runNo=record.runNo,
            documentFileId=record.documentFileId,
            status=record.status,
            phase=record.phase,
            resultStatus=record.resultStatus,
            ruleSetId=record.ruleSetId,
            ruleVersionId=record.ruleVersionId,
            ruleTypeId=record.ruleTypeId,
            rescueApplied=record.rescueApplied or False,
            # ``is not None`` rather than truthiness: a total score of 0 is
            # a real result and must not be reported as missing.
            totalScore=float(record.totalScore) if record.totalScore is not None else None,
            passedCount=record.passedCount,
            failedCount=record.failedCount,
            skippedCount=record.skippedCount,
            startedAt=record.startedAt,
            finishedAt=record.finishedAt,
        )

    async with GetAsyncSession() as session:
        logger.info(f"触发评查: documentId={DocumentId}, ruleType={RuleType}, triggerUserId={TriggerUserId}")
        normalizedSpeed = _normalize_speed(Speed)

        # NOTE(review): this DDL runs on every trigger request.  It is kept
        # because later statements may rely on the column existing, but a
        # schema change like this normally belongs in a migration - confirm
        # and move it out of the request path.
        await session.execute(
            text(
                """
                ALTER TABLE leaudit_documents
                ADD COLUMN IF NOT EXISTS group_id BIGINT NULL
                REFERENCES leaudit_evaluation_point_groups(id)
                """
            )
        )

        document = await session.get(LeauditDocument, DocumentId)
        if not document:
            raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查文档不存在")

        if not Force:
            activeRunResult = await session.execute(
                select(LeauditAuditRun)
                .where(
                    LeauditAuditRun.documentId == DocumentId,
                    LeauditAuditRun.status.in_(("queued", "running", "retrying")),
                )
                .order_by(LeauditAuditRun.Id.desc())
                .limit(1)
            )
            activeRun = activeRunResult.scalar_one_or_none()
            if activeRun:
                # A run is already in flight; report it instead of queuing
                # a duplicate.
                return _as_run_vo(activeRun)

        fileResult = await session.execute(
            select(LeauditDocumentFile)
            .where(
                LeauditDocumentFile.documentId == DocumentId,
                LeauditDocumentFile.isActive.is_(True),
                LeauditDocumentFile.fileRole == "primary",
            )
            .order_by(LeauditDocumentFile.Id.desc())
            .limit(1)
        )
        documentFile = fileResult.scalar_one_or_none()
        if not documentFile:
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前文档没有可执行文件版本")

        # NOTE(review): the next run number is "latest + 1" with no lock
        # visible here; two concurrent triggers could compute the same
        # value unless a database constraint prevents it - confirm.
        runNoResult = await session.execute(
            select(LeauditAuditRun.runNo)
            .where(LeauditAuditRun.documentId == DocumentId)
            .order_by(LeauditAuditRun.runNo.desc())
            .limit(1)
        )
        latestRunNo = runNoResult.scalar_one_or_none() or 0

        document_tenant_code = str(getattr(document, "tenantCode", None) or "").strip() or None
        document_group_id = getattr(document, "groupId", None)

        binding = await self._resolve_rule_binding_from_group(
            session,
            document_group_id,
            document_tenant_code,
        )
        if binding is None:
            binding = await self._resolve_unique_group_binding_by_doc_type(
                session,
                getattr(document, "typeId", None),
                document_tenant_code,
            )
            if binding and document_group_id is None:
                logger.info("文档未显式记录 group_id,已按文档类型唯一子组解析正式规则绑定")

        if not binding or not binding["rule_set_id"] or not binding["rule_version_id"]:
            # One check, two messages: the wording depends on whether the
            # document carries an explicit group.
            if document_group_id:
                bindingMessage = "当前子类型未绑定可执行规则集,请先检查二级分组规则配置"
            else:
                bindingMessage = "当前文档未绑定可执行规则集,请先检查二级分组规则配置"
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, bindingMessage)

        triggerSource = f"{'retry' if Force else 'upload'}:{normalizedSpeed}"
        run = LeauditAuditRun(
            documentId=DocumentId,
            documentFileId=documentFile.Id,
            runNo=int(latestRunNo) + 1,
            triggerSource=triggerSource,
            triggerUserId=TriggerUserId,
            status="queued",
            phase="dispatch",
            ruleSetId=int(binding["rule_set_id"]),
            ruleVersionId=int(binding["rule_version_id"]),
            ruleTypeId=binding["rule_type_id"],
            ruleSourceOssUrl=binding["rule_source_oss_url"],
            ruleSourceSha256=binding["rule_source_sha256"],
        )
        session.add(run)
        # Flush so the run has an id before the snapshot update targets it.
        await session.flush()

        await self._persist_run_tenant_snapshot(
            session,
            run.Id,
            tenant_code=document_tenant_code,
            scope_type=str(binding.get("scope_type") or "").strip() or None,
            group_id=document_group_id,
            rule_binding_id=int(binding["binding_id"]) if binding.get("binding_id") is not None else None,
        )

        document.currentRunId = run.Id
        document.processingStatus = "queued"
        # Commit before dispatching so the run row is stored before the
        # task is handed off.
        await session.commit()
        await session.refresh(run)

        try:
            taskId = dispatch_leaudit_task(
                run_id=run.Id,
                queue_name=resolve_worker_queue(triggerSource),
                rules_path=RuleType,
            )
        except Exception as dispatchError:
            # Record the failure on both rows before surfacing it.
            run.status = "failed"
            run.phase = "dispatch"
            run.finishedAt = datetime.now()
            document.processingStatus = "failed"
            await session.commit()
            raise LeauditException(
                StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                f"投递评查任务失败: {dispatchError}",
            ) from dispatchError

        run.taskId = taskId
        await session.commit()
        await session.refresh(run)
        return _as_run_vo(run)
async def GetRunStatus(self, RunId: int) -> AuditRunVO:
    """Return the current state of one audit run.

    Args:
        RunId: Primary key of the run.

    Returns:
        The run's columns mapped onto an ``AuditRunVO``.

    Raises:
        LeauditException: 404 when no run with ``RunId`` exists.
    """
    async with GetAsyncSession() as session:
        run = await session.get(LeauditAuditRun, RunId)
        if not run:
            raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查运行记录不存在")
        return AuditRunVO(
            runId=run.Id,
            documentId=run.documentId,
            runNo=run.runNo,
            documentFileId=run.documentFileId,
            status=run.status,
            phase=run.phase,
            resultStatus=run.resultStatus,
            ruleSetId=run.ruleSetId,
            ruleVersionId=run.ruleVersionId,
            ruleTypeId=run.ruleTypeId,
            rescueApplied=run.rescueApplied or False,
            # ``is not None`` rather than truthiness: a total score of 0 is
            # a real result and must not be reported as missing.
            totalScore=float(run.totalScore) if run.totalScore is not None else None,
            passedCount=run.passedCount,
            failedCount=run.failedCount,
            skippedCount=run.skippedCount,
            startedAt=run.startedAt,
            finishedAt=run.finishedAt,
        )
async def GetResult(self, RunId: int) -> AuditResultVO:
    """Assemble the full result of one audit run.

    Besides the run's own columns, the result carries the rows stored for
    the run in ``leaudit_rule_results`` (as plain dicts),
    ``leaudit_field_results``, ``leaudit_run_errors``,
    ``leaudit_rescue_outcomes`` and ``leaudit_artifacts`` (each ordered by
    ``id`` ascending), plus the newest ``leaudit_run_metrics`` row, if any.

    Args:
        RunId: Primary key of the run.

    Returns:
        The populated ``AuditResultVO``; ``metrics`` is ``None`` when the run
        has no metrics row, and missing counters are reported as ``0``.

    Raises:
        LeauditException: 404 when no run with ``RunId`` exists.
    """

    def _to_float(value: Any) -> float | None:
        # Nullable numeric -> float, keeping 0 as 0.0 instead of None.
        return None if value is None else float(value)

    async with GetAsyncSession() as session:
        run = await session.get(LeauditAuditRun, RunId)
        if not run:
            raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查运行记录不存在")

        async def _select(sql: str):
            # Every query below is filtered by the same run id.
            result = await session.execute(text(sql), {"run_id": RunId})
            return result.mappings()

        ruleRows = (
            await _select(
                """
                SELECT
                    rule_id,
                    rule_name,
                    risk,
                    score,
                    passed,
                    status,
                    skip_reason,
                    confidence,
                    pass_message,
                    fail_message,
                    remediation,
                    extracted_fields,
                    field_positions,
                    rescue_applied,
                    rescue_passed
                FROM leaudit_rule_results
                WHERE run_id = :run_id
                ORDER BY id ASC
                """
            )
        ).all()
        fieldRows = (
            await _select(
                """
                SELECT
                    field_name,
                    field_type,
                    value_text,
                    confidence,
                    grounding_method,
                    fallback_value,
                    raw_value_json,
                    meta_json
                FROM leaudit_field_results
                WHERE run_id = :run_id
                ORDER BY id ASC
                """
            )
        ).all()
        errorRows = (
            await _select(
                """
                SELECT
                    stage,
                    level,
                    error_code,
                    message,
                    detail_json,
                    created_at
                FROM leaudit_run_errors
                WHERE run_id = :run_id
                ORDER BY id ASC
                """
            )
        ).all()
        rescueRows = (
            await _select(
                """
                SELECT
                    rule_id,
                    status,
                    diagnosis,
                    diagnosis_confidence,
                    final_status,
                    failure_reason,
                    llm_calls,
                    vlm_calls,
                    duration_ms,
                    requires_human_review,
                    payload
                FROM leaudit_rescue_outcomes
                WHERE run_id = :run_id
                ORDER BY id ASC
                """
            )
        ).all()
        metricRow = (
            await _select(
                """
                SELECT
                    ocr_seconds,
                    normalize_seconds,
                    extract_seconds,
                    evaluate_seconds,
                    rescue_seconds,
                    total_seconds,
                    page_count,
                    sub_document_count,
                    field_count,
                    rule_count,
                    llm_call_count,
                    vlm_call_count,
                    rescue_rule_count,
                    artifact_count
                FROM leaudit_run_metrics
                WHERE run_id = :run_id
                ORDER BY id DESC
                LIMIT 1
                """
            )
        ).first()
        artifactRows = (
            await _select(
                """
                SELECT
                    artifact_type,
                    artifact_role,
                    file_name,
                    file_ext,
                    mime_type,
                    file_size,
                    oss_url,
                    is_persisted
                FROM leaudit_artifacts
                WHERE run_id = :run_id
                ORDER BY id ASC
                """
            )
        ).all()

        # Rule results are passed through as plain dicts, not a VO.
        rules = [dict(row) for row in ruleRows]
        fields = [
            AuditFieldResultVO(
                fieldName=row["field_name"],
                fieldType=row["field_type"],
                valueText=row["value_text"],
                confidence=_to_float(row["confidence"]),
                groundingMethod=row["grounding_method"],
                fallbackValue=row["fallback_value"],
                rawValueJson=row["raw_value_json"],
                metaJson=row["meta_json"],
            )
            for row in fieldRows
        ]
        errors = [
            AuditRunErrorVO(
                stage=row["stage"],
                level=row["level"],
                errorCode=row["error_code"],
                message=row["message"],
                detailJson=row["detail_json"],
                createdAt=row["created_at"],
            )
            for row in errorRows
        ]
        rescueOutcomes = [
            AuditRescueOutcomeVO(
                ruleId=row["rule_id"],
                status=row["status"],
                diagnosis=row["diagnosis"],
                diagnosisConfidence=_to_float(row["diagnosis_confidence"]),
                finalStatus=row["final_status"],
                failureReason=row["failure_reason"],
                llmCalls=row["llm_calls"],
                vlmCalls=row["vlm_calls"],
                durationMs=row["duration_ms"],
                requiresHumanReview=bool(row["requires_human_review"]),
                payload=row["payload"],
            )
            for row in rescueRows
        ]

        metrics = None
        if metricRow is not None:
            metrics = AuditMetricsVO(
                ocrSeconds=_to_float(metricRow["ocr_seconds"]),
                normalizeSeconds=_to_float(metricRow["normalize_seconds"]),
                extractSeconds=_to_float(metricRow["extract_seconds"]),
                evaluateSeconds=_to_float(metricRow["evaluate_seconds"]),
                rescueSeconds=_to_float(metricRow["rescue_seconds"]),
                totalSeconds=_to_float(metricRow["total_seconds"]),
                pageCount=metricRow["page_count"],
                subDocumentCount=metricRow["sub_document_count"],
                fieldCount=metricRow["field_count"],
                ruleCount=metricRow["rule_count"],
                llmCallCount=metricRow["llm_call_count"],
                vlmCallCount=metricRow["vlm_call_count"],
                rescueRuleCount=metricRow["rescue_rule_count"],
                artifactCount=metricRow["artifact_count"],
            )

        artifacts = [
            AuditArtifactVO(
                artifactType=row["artifact_type"],
                artifactRole=row["artifact_role"],
                fileName=row["file_name"],
                fileExt=row["file_ext"],
                mimeType=row["mime_type"],
                fileSize=row["file_size"],
                ossUrl=row["oss_url"],
                isPersisted=row["is_persisted"],
            )
            for row in artifactRows
        ]

        return AuditResultVO(
            runId=run.Id,
            documentId=run.documentId,
            documentFileId=run.documentFileId,
            status=run.status,
            # A total score of 0 stays 0.0; only a missing score is None.
            totalScore=_to_float(run.totalScore),
            passedCount=run.passedCount or 0,
            failedCount=run.failedCount or 0,
            skippedCount=run.skippedCount or 0,
            phase=run.phase,
            resultStatus=run.resultStatus,
            rescueApplied=run.rescueApplied or False,
            ruleSetId=run.ruleSetId,
            ruleVersionId=run.ruleVersionId,
            startedAt=run.startedAt,
            finishedAt=run.finishedAt,
            rules=rules,
            fields=fields,
            errors=errors,
            rescueOutcomes=rescueOutcomes,
            metrics=metrics,
            artifacts=artifacts,
        )