"""Audit service implementation.

Orchestrates the LeAudit engine pipeline:
document → OCR → Extract → Evaluate → Rescue → Persist
"""

from datetime import datetime
from typing import Any

from fastapi_common.fastapi_common_logger import logger
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from sqlalchemy import select, text

from fastapi_modules.fastapi_leaudit.domian.vo.auditVo import (
    AuditArtifactVO,
    AuditFieldResultVO,
    AuditMetricsVO,
    AuditRescueOutcomeVO,
    AuditResultVO,
    AuditRunErrorVO,
    AuditRunVO,
)
from fastapi_modules.fastapi_leaudit.leaudit_bridge.tasks import (
    dispatch_leaudit_task,
    resolve_worker_queue,
)
from fastapi_modules.fastapi_leaudit.models import (
    LeauditAuditRun,
    LeauditDocument,
    LeauditDocumentFile,
)
from fastapi_modules.fastapi_leaudit.services import IAuditService


def _normalize_speed(speed: str | None) -> str:
    """Normalize the front-end speed selection to ``"urgent"`` or ``"normal"``.

    Matching is case-insensitive and ignores surrounding whitespace; any
    value that is not a recognised urgent alias (including ``None``) maps
    to ``"normal"``.
    """
    normalized = (speed or "").strip().lower()
    if normalized in {"urgent", "high", "fast", "emergency", "紧急"}:
        return "urgent"
    return "normal"


def _candidate_binding_tenant_codes(tenant_code: str | None) -> list[str]:
    """Return the binding resolution order for one document tenant.

    PUBLIC is the platform template source; PROVINCIAL remains only as a
    legacy fallback. A concrete tenant code is tried first, then PUBLIC,
    then PROVINCIAL (PROVINCIAL is omitted when the tenant itself is PUBLIC).
    """
    normalized = str(tenant_code or "").strip().upper()
    candidates: list[str] = []
    if normalized and normalized not in {"PROVINCIAL", "PUBLIC"}:
        candidates.append(normalized)
    candidates.append("PUBLIC")
    if normalized != "PUBLIC":
        candidates.append("PROVINCIAL")
    # dict.fromkeys de-duplicates while preserving the priority order.
    return list(dict.fromkeys(candidates))


def _pick_effective_binding(bindings: list[dict], tenant_code: str | None) -> dict | None:
    """Pick the effective binding by tenant inheritance order.

    For each tenant code only the first binding in ``bindings`` is
    considered, so the caller's ordering decides ties. Legacy rows without
    ``tenant_code`` are treated as the loosest fallback and are returned
    only when no candidate tenant matches.

    Returns ``None`` when ``bindings`` is empty or nothing is eligible.
    """
    if not bindings:
        return None

    binding_map: dict[str, dict] = {}
    empty_binding: dict | None = None
    for binding in bindings:
        normalized = str(binding.get("tenant_code") or "").strip().upper()
        if not normalized:
            if empty_binding is None:
                empty_binding = binding
            continue
        binding_map.setdefault(normalized, binding)

    for candidate in _candidate_binding_tenant_codes(tenant_code):
        matched = binding_map.get(candidate)
        if matched is not None:
            return matched
    return empty_binding


def _optional_float(value: Any) -> float | None:
    """Convert ``value`` to ``float`` while preserving ``None``.

    Uses an explicit ``is not None`` test so that a legitimate zero
    (e.g. a total score of 0) is not mistaken for "no value".
    """
    return float(value) if value is not None else None


def _build_run_vo(run: LeauditAuditRun) -> AuditRunVO:
    """Map a ``LeauditAuditRun`` row onto the ``AuditRunVO`` returned to callers."""
    return AuditRunVO(
        runId=run.Id,
        documentId=run.documentId,
        runNo=run.runNo,
        documentFileId=run.documentFileId,
        status=run.status,
        phase=run.phase,
        resultStatus=run.resultStatus,
        ruleSetId=run.ruleSetId,
        ruleVersionId=run.ruleVersionId,
        ruleTypeId=run.ruleTypeId,
        rescueApplied=run.rescueApplied or False,
        totalScore=_optional_float(run.totalScore),
        passedCount=run.passedCount,
        failedCount=run.failedCount,
        skippedCount=run.skippedCount,
        startedAt=run.startedAt,
        finishedAt=run.finishedAt,
    )


class AuditServiceImpl(IAuditService):
    """Audit service implementation."""

    def __init__(self) -> None:
        # Per-instance memo of "<table>.<column>" -> whether the column exists.
        self._column_exists_cache: dict[str, bool] = {}

    async def _column_exists(self, session, table_name: str, column_name: str) -> bool:
        """Return whether ``table_name.column_name`` exists in the current schema.

        The answer is looked up in ``information_schema.columns`` once and
        then memoised on this instance.
        """
        cache_key = f"{table_name}.{column_name}"
        cached = self._column_exists_cache.get(cache_key)
        if cached is not None:
            return cached

        exists = bool(
            (
                await session.execute(
                    text(
                        """
                        SELECT EXISTS (
                            SELECT 1
                            FROM information_schema.columns
                            WHERE table_schema = current_schema()
                              AND table_name = :table_name
                              AND column_name = :column_name
                        )
                        """
                    ),
                    {"table_name": table_name, "column_name": column_name},
                )
            ).scalar_one()
        )
        self._column_exists_cache[cache_key] = exists
        return exists

    async def _resolve_rule_binding_from_group(
        self,
        session,
        group_id: int | None,
        tenant_code: str | None = None,
    ) -> dict | None:
        """Resolve the effective rule binding for a second-level group.

        Loads every active, non-deleted binding of ``group_id`` ordered by
        priority and lets ``_pick_effective_binding`` choose one according
        to tenant inheritance. The rule version is the rule set's current
        version, falling back to its newest ``published``/``rollback``
        version. Returns ``None`` when ``group_id`` is falsy or no binding
        is eligible.
        """
        if not group_id:
            return None

        # tenant_code / scope_type are optional columns; when a column is
        # absent (or blank) the binding is treated as PROVINCIAL. Both
        # expressions are built from constants only, never from user input.
        binding_tenant_expr = (
            "COALESCE(NULLIF(BTRIM(rgb.tenant_code), ''), 'PROVINCIAL')"
            if await self._column_exists(session, "leaudit_rule_group_bindings", "tenant_code")
            else "'PROVINCIAL'"
        )
        binding_scope_expr = (
            "COALESCE(NULLIF(BTRIM(rgb.scope_type), ''), 'PROVINCIAL')"
            if await self._column_exists(session, "leaudit_rule_group_bindings", "scope_type")
            else "'PROVINCIAL'"
        )

        result = await session.execute(
            text(
                f"""
                SELECT
                    rgb.id AS binding_id,
                    {binding_tenant_expr} AS tenant_code,
                    {binding_scope_expr} AS scope_type,
                    rs.id AS rule_set_id,
                    COALESCE(rs.current_version_id, fallback_rv.id) AS rule_version_id,
                    COALESCE(current_rv.oss_url, fallback_rv.oss_url) AS rule_source_oss_url,
                    COALESCE(current_rv.file_sha256, fallback_rv.file_sha256) AS rule_source_sha256,
                    COALESCE(current_rv.metadata_type_id, fallback_rv.metadata_type_id) AS rule_type_id
                FROM leaudit_rule_group_bindings rgb
                JOIN leaudit_rule_sets rs ON rs.id = rgb.rule_set_id
                LEFT JOIN leaudit_rule_versions current_rv ON current_rv.id = rs.current_version_id
                LEFT JOIN LATERAL (
                    SELECT rv.id, rv.oss_url, rv.file_sha256, rv.metadata_type_id
                    FROM leaudit_rule_versions rv
                    WHERE rv.rule_set_id = rs.id
                      AND rv.status IN ('published', 'rollback')
                    ORDER BY rv.version_seq DESC, rv.id DESC
                    LIMIT 1
                ) fallback_rv ON TRUE
                WHERE rgb.group_id = :group_id
                  AND rgb.is_active = TRUE
                  AND rgb.deleted_at IS NULL
                ORDER BY rgb.priority DESC, rgb.id ASC
                """
            ),
            {"group_id": int(group_id)},
        )
        return _pick_effective_binding(list(result.mappings().all()), tenant_code)

    async def _resolve_unique_group_binding_by_doc_type(
        self,
        session,
        doc_type_id: int | None,
        tenant_code: str | None = None,
    ) -> dict | None:
        """Fallback binding resolution for documents without a ``group_id``.

        When the document type has exactly one enabled, non-deleted child
        group (``pid`` not 0/NULL), resolve the binding through that group;
        otherwise return ``None``.
        """
        if not doc_type_id:
            return None

        group_row = (
            await session.execute(
                text(
                    """
                    SELECT CASE WHEN COUNT(*) = 1 THEN MIN(id) END AS group_id
                    FROM leaudit_evaluation_point_groups
                    WHERE document_type_id = :doc_type_id
                      AND deleted_at IS NULL
                      AND is_enabled = TRUE
                      AND COALESCE(pid, 0) <> 0
                    """
                ),
                {"doc_type_id": int(doc_type_id)},
            )
        ).mappings().first()

        resolved_group_id = (
            int(group_row["group_id"])
            if group_row and group_row.get("group_id") is not None
            else None
        )
        return await self._resolve_rule_binding_from_group(session, resolved_group_id, tenant_code)

    async def _persist_run_tenant_snapshot(
        self,
        session,
        run_id: int,
        *,
        tenant_code: str | None,
        scope_type: str | None,
        group_id: int | None,
        rule_binding_id: int | None,
    ) -> None:
        """Write tenant/scope/group/binding snapshot columns onto a run row.

        Each snapshot column is optional in the schema; only the columns
        that exist on ``leaudit_audit_runs`` are updated, and nothing is
        executed when none of them exist.
        """
        updates: list[str] = []
        params: dict[str, Any] = {"run_id": run_id}
        optional_values = {
            "tenant_code": tenant_code,
            "scope_type_snapshot": scope_type,
            "group_id_snapshot": group_id,
            "rule_binding_id_snapshot": rule_binding_id,
        }
        # Column names come from the constant dict above, so interpolating
        # them into the statement is safe; values stay bound parameters.
        for column_name, value in optional_values.items():
            if await self._column_exists(session, "leaudit_audit_runs", column_name):
                updates.append(f"{column_name} = :{column_name}")
                params[column_name] = value

        if not updates:
            return

        await session.execute(
            text(
                f"""
                UPDATE leaudit_audit_runs
                SET {", ".join(updates)}
                WHERE id = :run_id
                """
            ),
            params,
        )

    async def Run(
        self,
        DocumentId: int,
        RuleType: str | None = None,
        Force: bool = False,
        Speed: str = "normal",
        TriggerUserId: int | None = None,
    ) -> AuditRunVO:
        """Trigger an audit for a document.

        Only creates the run row and dispatches it to a worker; nothing is
        executed synchronously inside the HTTP request.

        Args:
            DocumentId: Id of the document to audit.
            RuleType: Forwarded to the worker task as ``rules_path``.
            Force: When false, an existing queued/running/retrying run of
                the document is returned instead of creating a new one.
            Speed: Front-end speed selection, normalised to urgent/normal.
            TriggerUserId: Id of the user who triggered the run, if any.

        Returns:
            The newly created run, or the already active run.

        Raises:
            LeauditException: 404 when the document does not exist; 400
                when it has no active primary file or no executable rule
                binding; 500 when dispatching the worker task fails.
        """
        async with GetAsyncSession() as session:
            logger.info(f"触发评查: documentId={DocumentId}, ruleType={RuleType}, triggerUserId={TriggerUserId}")
            normalizedSpeed = _normalize_speed(Speed)

            # ALTER TABLE takes an ACCESS EXCLUSIVE lock even when
            # IF NOT EXISTS turns it into a no-op, and that lock is held
            # until this transaction commits. Only issue the DDL when the
            # column is actually missing.
            if not await self._column_exists(session, "leaudit_documents", "group_id"):
                await session.execute(
                    text(
                        """
                        ALTER TABLE leaudit_documents
                        ADD COLUMN IF NOT EXISTS group_id BIGINT NULL
                        REFERENCES leaudit_evaluation_point_groups(id)
                        """
                    )
                )
                # Forget the cached "missing" answer rather than caching
                # "present": if this transaction rolls back the column is
                # still absent, so the next call has to re-check.
                self._column_exists_cache.pop("leaudit_documents.group_id", None)

            document = await session.get(LeauditDocument, DocumentId)
            if not document:
                raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查文档不存在")

            if not Force:
                activeRunResult = await session.execute(
                    select(LeauditAuditRun)
                    .where(
                        LeauditAuditRun.documentId == DocumentId,
                        LeauditAuditRun.status.in_(("queued", "running", "retrying")),
                    )
                    .order_by(LeauditAuditRun.Id.desc())
                    .limit(1)
                )
                activeRun = activeRunResult.scalar_one_or_none()
                if activeRun:
                    return _build_run_vo(activeRun)

            fileResult = await session.execute(
                select(LeauditDocumentFile)
                .where(
                    LeauditDocumentFile.documentId == DocumentId,
                    LeauditDocumentFile.isActive.is_(True),
                    LeauditDocumentFile.fileRole == "primary",
                )
                .order_by(LeauditDocumentFile.Id.desc())
                .limit(1)
            )
            documentFile = fileResult.scalar_one_or_none()
            if not documentFile:
                raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前文档没有可执行文件版本")

            runNoResult = await session.execute(
                select(LeauditAuditRun.runNo)
                .where(LeauditAuditRun.documentId == DocumentId)
                .order_by(LeauditAuditRun.runNo.desc())
                .limit(1)
            )
            latestRunNo = runNoResult.scalar_one_or_none() or 0

            document_tenant_code = str(getattr(document, "tenantCode", None) or "").strip() or None
            document_group_id = getattr(document, "groupId", None)

            binding = await self._resolve_rule_binding_from_group(
                session,
                document_group_id,
                document_tenant_code,
            )
            if binding is None:
                binding = await self._resolve_unique_group_binding_by_doc_type(
                    session,
                    getattr(document, "typeId", None),
                    document_tenant_code,
                )
                if binding and document_group_id is None:
                    logger.info("文档未显式记录 group_id,已按文档类型唯一子组解析正式规则绑定")

            if not binding or not binding["rule_set_id"] or not binding["rule_version_id"]:
                # Same 400 either way; the message says whether the
                # document's own sub-type group or the document as a whole
                # lacks an executable rule set.
                message = (
                    "当前子类型未绑定可执行规则集,请先检查二级分组规则配置"
                    if document_group_id
                    else "当前文档未绑定可执行规则集,请先检查二级分组规则配置"
                )
                raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, message)

            triggerSource = f"{'retry' if Force else 'upload'}:{normalizedSpeed}"
            run = LeauditAuditRun(
                documentId=DocumentId,
                documentFileId=documentFile.Id,
                runNo=int(latestRunNo) + 1,
                triggerSource=triggerSource,
                triggerUserId=TriggerUserId,
                status="queued",
                phase="dispatch",
                ruleSetId=int(binding["rule_set_id"]),
                ruleVersionId=int(binding["rule_version_id"]),
                ruleTypeId=binding["rule_type_id"],
                ruleSourceOssUrl=binding["rule_source_oss_url"],
                ruleSourceSha256=binding["rule_source_sha256"],
            )
            session.add(run)
            # Flush so run.Id is assigned before it is referenced below.
            await session.flush()

            await self._persist_run_tenant_snapshot(
                session,
                run.Id,
                tenant_code=document_tenant_code,
                scope_type=str(binding.get("scope_type") or "").strip() or None,
                group_id=document_group_id,
                rule_binding_id=int(binding["binding_id"]) if binding.get("binding_id") is not None else None,
            )

            document.currentRunId = run.Id
            document.processingStatus = "queued"
            # Commit before dispatching so the run row is persisted by the
            # time the task is handed to the worker.
            await session.commit()
            await session.refresh(run)

            try:
                taskId = dispatch_leaudit_task(
                    run_id=run.Id,
                    queue_name=resolve_worker_queue(triggerSource),
                    rules_path=RuleType,
                )
            except Exception as Error:
                # Boundary handler: record the failed dispatch on the run
                # and the document, then surface it as a 500.
                run.status = "failed"
                run.phase = "dispatch"
                run.finishedAt = datetime.now()
                document.processingStatus = "failed"
                await session.commit()
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"投递评查任务失败: {Error}",
                ) from Error

            run.taskId = taskId
            await session.commit()
            await session.refresh(run)
            return _build_run_vo(run)

    async def GetRunStatus(self, RunId: int) -> AuditRunVO:
        """Return the current status of an audit run.

        Raises:
            LeauditException: 404 when the run does not exist.
        """
        async with GetAsyncSession() as session:
            run = await session.get(LeauditAuditRun, RunId)
            if not run:
                raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查运行记录不存在")
            return _build_run_vo(run)

    async def GetResult(self, RunId: int) -> AuditResultVO:
        """Return the full result of an audit run.

        Aggregates the run row with its rule results, field results, run
        errors, rescue outcomes, latest metrics row and artifacts.

        Raises:
            LeauditException: 404 when the run does not exist.
        """
        async with GetAsyncSession() as session:
            run = await session.get(LeauditAuditRun, RunId)
            if not run:
                raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "评查运行记录不存在")

            ruleResult = await session.execute(
                text(
                    """
                    SELECT
                        rule_id, rule_name, risk, score, passed, status, skip_reason,
                        confidence, pass_message, fail_message, remediation,
                        extracted_fields, field_positions, rescue_applied, rescue_passed
                    FROM leaudit_rule_results
                    WHERE run_id = :run_id
                    ORDER BY id ASC
                    """
                ),
                {"run_id": RunId},
            )
            fieldResult = await session.execute(
                text(
                    """
                    SELECT
                        field_name, field_type, value_text, confidence, grounding_method,
                        fallback_value, raw_value_json, meta_json
                    FROM leaudit_field_results
                    WHERE run_id = :run_id
                    ORDER BY id ASC
                    """
                ),
                {"run_id": RunId},
            )
            errorResult = await session.execute(
                text(
                    """
                    SELECT stage, level, error_code, message, detail_json, created_at
                    FROM leaudit_run_errors
                    WHERE run_id = :run_id
                    ORDER BY id ASC
                    """
                ),
                {"run_id": RunId},
            )
            rescueResult = await session.execute(
                text(
                    """
                    SELECT
                        rule_id, status, diagnosis, diagnosis_confidence, final_status,
                        failure_reason, llm_calls, vlm_calls, duration_ms,
                        requires_human_review, payload
                    FROM leaudit_rescue_outcomes
                    WHERE run_id = :run_id
                    ORDER BY id ASC
                    """
                ),
                {"run_id": RunId},
            )
            # Only the most recent metrics row of the run is reported.
            metricResult = await session.execute(
                text(
                    """
                    SELECT
                        ocr_seconds, normalize_seconds, extract_seconds, evaluate_seconds,
                        rescue_seconds, total_seconds, page_count, sub_document_count,
                        field_count, rule_count, llm_call_count, vlm_call_count,
                        rescue_rule_count, artifact_count
                    FROM leaudit_run_metrics
                    WHERE run_id = :run_id
                    ORDER BY id DESC
                    LIMIT 1
                    """
                ),
                {"run_id": RunId},
            )
            artifactResult = await session.execute(
                text(
                    """
                    SELECT
                        artifact_type, artifact_role, file_name, file_ext, mime_type,
                        file_size, oss_url, is_persisted
                    FROM leaudit_artifacts
                    WHERE run_id = :run_id
                    ORDER BY id ASC
                    """
                ),
                {"run_id": RunId},
            )

            # Rule results are passed through as plain dicts.
            rules = [dict(row) for row in ruleResult.mappings().all()]
            fields = [
                AuditFieldResultVO(
                    fieldName=row["field_name"],
                    fieldType=row["field_type"],
                    valueText=row["value_text"],
                    confidence=_optional_float(row["confidence"]),
                    groundingMethod=row["grounding_method"],
                    fallbackValue=row["fallback_value"],
                    rawValueJson=row["raw_value_json"],
                    metaJson=row["meta_json"],
                )
                for row in fieldResult.mappings().all()
            ]
            errors = [
                AuditRunErrorVO(
                    stage=row["stage"],
                    level=row["level"],
                    errorCode=row["error_code"],
                    message=row["message"],
                    detailJson=row["detail_json"],
                    createdAt=row["created_at"],
                )
                for row in errorResult.mappings().all()
            ]
            rescueOutcomes = [
                AuditRescueOutcomeVO(
                    ruleId=row["rule_id"],
                    status=row["status"],
                    diagnosis=row["diagnosis"],
                    diagnosisConfidence=_optional_float(row["diagnosis_confidence"]),
                    finalStatus=row["final_status"],
                    failureReason=row["failure_reason"],
                    llmCalls=row["llm_calls"],
                    vlmCalls=row["vlm_calls"],
                    durationMs=row["duration_ms"],
                    requiresHumanReview=bool(row["requires_human_review"]),
                    payload=row["payload"],
                )
                for row in rescueResult.mappings().all()
            ]

            metricRow = metricResult.mappings().first()
            metrics = (
                AuditMetricsVO(
                    ocrSeconds=_optional_float(metricRow["ocr_seconds"]),
                    normalizeSeconds=_optional_float(metricRow["normalize_seconds"]),
                    extractSeconds=_optional_float(metricRow["extract_seconds"]),
                    evaluateSeconds=_optional_float(metricRow["evaluate_seconds"]),
                    rescueSeconds=_optional_float(metricRow["rescue_seconds"]),
                    totalSeconds=_optional_float(metricRow["total_seconds"]),
                    pageCount=metricRow["page_count"],
                    subDocumentCount=metricRow["sub_document_count"],
                    fieldCount=metricRow["field_count"],
                    ruleCount=metricRow["rule_count"],
                    llmCallCount=metricRow["llm_call_count"],
                    vlmCallCount=metricRow["vlm_call_count"],
                    rescueRuleCount=metricRow["rescue_rule_count"],
                    artifactCount=metricRow["artifact_count"],
                )
                if metricRow
                else None
            )
            artifacts = [
                AuditArtifactVO(
                    artifactType=row["artifact_type"],
                    artifactRole=row["artifact_role"],
                    fileName=row["file_name"],
                    fileExt=row["file_ext"],
                    mimeType=row["mime_type"],
                    fileSize=row["file_size"],
                    ossUrl=row["oss_url"],
                    isPersisted=row["is_persisted"],
                )
                for row in artifactResult.mappings().all()
            ]

            return AuditResultVO(
                runId=run.Id,
                documentId=run.documentId,
                documentFileId=run.documentFileId,
                status=run.status,
                totalScore=_optional_float(run.totalScore),
                passedCount=run.passedCount or 0,
                failedCount=run.failedCount or 0,
                skippedCount=run.skippedCount or 0,
                phase=run.phase,
                resultStatus=run.resultStatus,
                rescueApplied=run.rescueApplied or False,
                ruleSetId=run.ruleSetId,
                ruleVersionId=run.ruleVersionId,
                startedAt=run.startedAt,
                finishedAt=run.finishedAt,
                rules=rules,
                fields=fields,
                errors=errors,
                rescueOutcomes=rescueOutcomes,
                metrics=metrics,
                artifacts=artifacts,
            )