Files
leaudit-platform-backend/fastapi_modules/fastapi_leaudit/govdoc_engine/engine/runner.py
T

243 lines
9.2 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""规则评估引擎:跑一条规则的多 stage。"""
from __future__ import annotations
import asyncio
import uuid
from dataclasses import dataclass, field
from fastapi_modules.fastapi_leaudit.govdoc_engine.models import Document, Finding, Location
from fastapi_modules.fastapi_leaudit.govdoc_engine.dsl.schema import Rule
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import get_check # noqa: F401 (确保注册)
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks.base import CheckContext, CheckResult, CheckHit
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks.ai_check import AiCheck
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.selector import select_paragraphs
from fastapi_modules.fastapi_leaudit.govdoc_engine.parser.entities import SemanticEntity
from fastapi_modules.fastapi_leaudit.govdoc_engine.llm.client import LlmClient
# 触发所有 check 类的 @register
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import required as _r # noqa: F401
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import font as _f # noqa: F401
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import regex_check as _rc # noqa: F401
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import confused_pair as _cp # noqa: F401
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import forbid as _fb # noqa: F401
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import wenzhong as _wz # noqa: F401
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import hierarchy as _h # noqa: F401
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import punctuation as _p # noqa: F401
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import cross_role as _cr # noqa: F401
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import ai_check as _ai # noqa: F401
@dataclass
class RuleOutcome:
    """Result of evaluating a single rule, including its skipped state.

    ``findings`` lists the problems reported for the rule (empty when it
    passed). ``skipped`` / ``skip_reason`` record that the rule, or one of
    its stages, could not be evaluated.
    """

    # The rule this outcome belongs to.
    rule: Rule
    # A fresh list per instance; never shared between outcomes.
    findings: list[Finding] = field(default_factory=lambda: [])
    skipped: bool = field(default=False)
    # Empty string unless the outcome is skipped.
    skip_reason: str = field(default="")
class RuleRunner:
    """Run rules (each made of one or more check stages) against a document.

    For every rule the runner first resolves the paragraphs it applies to:
    either the ``rule.target`` entity or the ``rule.applies_to`` selector.
    It then runs ``rule.stages`` in order:

    * a skipped stage marks the outcome as skipped, and evaluation continues;
    * the first failing stage turns its hits into findings, clears any
      skipped state, and stops evaluation of that rule;
    * if no stage fails, the outcome carries no findings.

    The synchronous path (``run_rule`` / ``evaluate``) and the asynchronous
    path (``run_rule_async`` / ``evaluate_async``) share the same per-stage
    helpers, so they fold results identically. The only difference is that
    ``"ai"`` stages are awaited through ``AiCheck.run_async`` on the async
    path.
    """

    def __init__(self, llm_client: LlmClient | None = None):
        """Create a runner.

        Args:
            llm_client: Client handed to ``AiCheck`` for ``"ai"`` stages.
                May be ``None``; how a missing client is handled is up to
                ``AiCheck``.
        """
        self.llm = llm_client

    # -- context assembly -----------------------------------------------

    def _resolve_target(
        self,
        rule: Rule,
        doc: Document,
        entities: dict[str, SemanticEntity | None],
    ) -> tuple[list, SemanticEntity | None, RuleOutcome | None]:
        """Select the paragraphs a rule operates on.

        When ``rule.target`` is set, the named entity is looked up in
        ``entities`` and its ``paragraph_indices`` are mapped to paragraphs
        of ``doc``; out-of-range indices are dropped. A missing entity is
        handled according to ``rule.on_missing`` (see ``_handle_missing``).
        Otherwise the paragraphs come from
        ``select_paragraphs(doc, rule.applies_to)``.

        Returns:
            ``(paragraphs, target_entity, early_outcome)``. When
            ``early_outcome`` is not ``None``, the caller must return it
            unchanged and skip stage evaluation.
        """
        if not rule.target:
            # applies_to channel: multi-paragraph scan, no single target entity.
            return select_paragraphs(doc, rule.applies_to), None, None

        target_entity = entities.get(rule.target)
        if target_entity is None:
            return [], None, self._handle_missing(rule)

        total = len(doc.paragraphs)
        paragraphs = [
            doc.paragraphs[i]
            for i in target_entity.paragraph_indices
            if 0 <= i < total
        ]
        return paragraphs, target_entity, None

    @staticmethod
    def _new_finding_id() -> str:
        """Return a short random finding id of the form ``F-xxxxxxxx``."""
        return f"F-{uuid.uuid4().hex[:8]}"

    def _handle_missing(self, rule: Rule) -> RuleOutcome:
        """Build the outcome for a rule whose target entity was not found.

        ``rule.on_missing`` selects the behaviour:

        * ``"pass"``: empty outcome (the rule counts as passed);
        * ``"skip"``: outcome marked skipped, with an explanatory reason;
        * ``"fail"``: one ``"error"`` finding;
        * any other value: one ``"warning"`` finding.

        Findings produced here have no paragraph location
        (``paragraph_index=-1``).
        """
        mode = rule.on_missing
        if mode == "pass":
            return RuleOutcome(rule=rule)
        reason = f"目标实体「{rule.target}」未识别到"
        if mode == "skip":
            return RuleOutcome(rule=rule, skipped=True, skip_reason=reason)
        severity = "error" if mode == "fail" else "warning"
        finding = Finding(
            finding_id=self._new_finding_id(),
            rule_id=rule.rule_id,
            rule_name=rule.name,
            severity=severity,
            category=rule.category,
            location=Location(paragraph_index=-1),
            message=reason,
            suggestion=rule.messages.fail or "",
            evidence="",
            confidence=0.9,
        )
        return RuleOutcome(rule=rule, findings=[finding])

    @staticmethod
    def _merge_skip(outcome: RuleOutcome, result: CheckResult) -> None:
        """Mark ``outcome`` as skipped.

        The first skip reason wins. Later skipped stages do not overwrite it.
        """
        if not outcome.skip_reason:
            outcome.skip_reason = result.skip_reason or "stage skipped"
        outcome.skipped = True

    # -- per-stage helpers shared by the sync and async paths -----------

    def _make_check(self, stage):
        """Instantiate the check for ``stage``.

        ``"ai"`` stages always get an ``AiCheck`` wired to this runner's LLM
        client. Every other check name is resolved through ``get_check``.
        """
        if stage.check == "ai":
            return AiCheck(llm_client=self.llm)
        return get_check(stage.check)()

    @staticmethod
    def _build_context(rule, stage, doc, paragraphs, entities, target) -> CheckContext:
        """Assemble the ``CheckContext`` passed to one stage's check."""
        return CheckContext(
            document=doc,
            paragraphs=paragraphs,
            stage=stage,
            entities=entities,
            target=target,
            rule_id=rule.rule_id,
        )

    def _apply_result(self, outcome: RuleOutcome, rule: Rule, result: CheckResult) -> bool:
        """Fold one stage's ``result`` into ``outcome``.

        Returns:
            ``True`` when evaluation of the rule must stop because the stage
            failed; ``False`` when it was skipped or passed.
        """
        if result.skipped:
            self._merge_skip(outcome, result)
            return False
        if result.passed:
            return False
        # NOTE(review): a failing result with no hits yields an outcome with
        # no findings and skipped=False, which looks like a pass to callers.
        # Confirm every check returns at least one hit when it fails.
        outcome.findings = [self._hit_to_finding(rule, h) for h in result.hits]
        # A failing stage gives the rule a definite verdict, overriding any
        # earlier skip.
        outcome.skipped = False
        outcome.skip_reason = ""
        return True

    @staticmethod
    def _flatten(outcomes: list[RuleOutcome]) -> list[Finding]:
        """Concatenate the findings of ``outcomes``, preserving order."""
        return [finding for outcome in outcomes for finding in outcome.findings]

    # -- sync path ------------------------------------------------------

    def run_rule(
        self,
        rule: Rule,
        doc: Document,
        entities: dict[str, SemanticEntity | None] | None = None,
    ) -> RuleOutcome:
        """Evaluate one rule synchronously; see the class docstring for stage semantics."""
        entities = entities or {}
        paragraphs, target, early = self._resolve_target(rule, doc, entities)
        if early is not None:
            return early

        outcome = RuleOutcome(rule=rule)
        for stage in rule.stages:
            check = self._make_check(stage)
            ctx = self._build_context(rule, stage, doc, paragraphs, entities, target)
            result: CheckResult = check.run(ctx)
            if self._apply_result(outcome, rule, result):
                break
        return outcome

    def run_all(
        self,
        rules: list[Rule],
        doc: Document,
        entities: dict[str, SemanticEntity | None] | None = None,
    ) -> list[Finding]:
        """Evaluate ``rules`` and return only the flattened findings."""
        flat, _ = self.evaluate(rules, doc, entities)
        return flat

    def evaluate(
        self,
        rules: list[Rule],
        doc: Document,
        entities: dict[str, SemanticEntity | None] | None = None,
    ) -> tuple[list[Finding], list[RuleOutcome]]:
        """Evaluate ``rules`` in order.

        Returns:
            ``(findings, outcomes)``: all findings concatenated in rule order,
            plus one ``RuleOutcome`` per rule.
        """
        outcomes = [self.run_rule(r, doc, entities) for r in rules]
        return self._flatten(outcomes), outcomes

    # -- async path -----------------------------------------------------

    async def run_rule_async(
        self,
        rule: Rule,
        doc: Document,
        entities: dict[str, SemanticEntity | None] | None = None,
    ) -> RuleOutcome:
        """Evaluate one rule, awaiting ``"ai"`` stages through ``AiCheck.run_async``.

        Non-AI checks are called synchronously, inline on the event loop.
        """
        entities = entities or {}
        paragraphs, target, early = self._resolve_target(rule, doc, entities)
        if early is not None:
            return early

        outcome = RuleOutcome(rule=rule)
        for stage in rule.stages:
            check = self._make_check(stage)
            ctx = self._build_context(rule, stage, doc, paragraphs, entities, target)
            if stage.check == "ai":
                result = await check.run_async(ctx)
            else:
                result = check.run(ctx)
            if self._apply_result(outcome, rule, result):
                break
        return outcome

    async def run_all_async(
        self,
        rules: list[Rule],
        doc: Document,
        entities: dict[str, SemanticEntity | None] | None = None,
    ) -> list[Finding]:
        """Async counterpart of ``run_all``."""
        flat, _ = await self.evaluate_async(rules, doc, entities)
        return flat

    async def evaluate_async(
        self,
        rules: list[Rule],
        doc: Document,
        entities: dict[str, SemanticEntity | None] | None = None,
    ) -> tuple[list[Finding], list[RuleOutcome]]:
        """Evaluate all ``rules`` concurrently with ``asyncio.gather``.

        Results keep the order of ``rules``. Because ``return_exceptions`` is
        not set, the first exception raised by any rule propagates to the
        caller.
        """
        outcomes = list(
            await asyncio.gather(
                *(self.run_rule_async(r, doc, entities) for r in rules)
            )
        )
        return self._flatten(outcomes), outcomes

    def _hit_to_finding(self, rule: Rule, hit: CheckHit) -> Finding:
        """Convert one ``CheckHit`` into a ``Finding`` for ``rule``.

        A hit without a paragraph gets ``paragraph_index=-1``, ``role=None``
        and an empty context. The message falls back to
        ``rule.messages.fail``, then to ``""``.
        """
        para = hit.paragraph
        # Explicit None test: a paragraph object must not lose its location
        # just because it happens to be falsy.
        has_para = para is not None
        loc = Location(
            paragraph_index=para.index if has_para else -1,
            role=para.role if has_para else None,
            char_start=hit.char_start,
            char_end=hit.char_end,
            context=para.text if has_para else "",
        )
        fail_text = rule.messages.fail or ""
        return Finding(
            finding_id=self._new_finding_id(),
            rule_id=rule.rule_id,
            rule_name=rule.name,
            severity=rule.severity,
            category=rule.category,
            location=loc,
            actual=hit.actual or {},
            expected=hit.expected or {},
            message=hit.message or fail_text,
            suggestion=fail_text,
            # NOTE(review): evidence mirrors the rule's fail message rather
            # than anything hit-specific. Confirm this is intended.
            evidence=fail_text,
            confidence=hit.confidence,
        )