243 lines
9.2 KiB
Python
243 lines
9.2 KiB
Python
"""规则评估引擎:跑一条规则的多 stage。"""
|
||
|
||
from __future__ import annotations
|
||
import asyncio
|
||
import uuid
|
||
from dataclasses import dataclass, field
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.models import Document, Finding, Location
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.dsl.schema import Rule
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import get_check # noqa: F401 (确保注册)
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks.base import CheckContext, CheckResult, CheckHit
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks.ai_check import AiCheck
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.selector import select_paragraphs
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.parser.entities import SemanticEntity
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.llm.client import LlmClient
|
||
|
||
# 触发所有 check 类的 @register
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import required as _r # noqa: F401
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import font as _f # noqa: F401
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import regex_check as _rc # noqa: F401
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import confused_pair as _cp # noqa: F401
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import forbid as _fb # noqa: F401
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import wenzhong as _wz # noqa: F401
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import hierarchy as _h # noqa: F401
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import punctuation as _p # noqa: F401
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import cross_role as _cr # noqa: F401
|
||
from fastapi_modules.fastapi_leaudit.govdoc_engine.engine.checks import ai_check as _ai # noqa: F401
|
||
|
||
|
||
@dataclass
class RuleOutcome:
    """Outcome of evaluating a single rule, including a possible skip.

    ``findings`` holds what the rule reported; it stays empty when the rule
    passed or was skipped.  ``skipped`` / ``skip_reason`` record that the rule
    (or one of its stages) was skipped instead of being evaluated.
    """

    # The rule this outcome was produced for.
    rule: Rule
    # Findings reported by the rule; each instance gets its own fresh list.
    findings: list[Finding] = field(default_factory=list)
    # Whether the rule ended up skipped rather than evaluated.
    skipped: bool = field(default=False)
    # Human-readable reason for the skip; empty string when not skipped.
    skip_reason: str = field(default="")
|
||
|
||
|
||
class RuleRunner:
    """Evaluate DSL rules against a parsed document, one stage at a time.

    For each rule the runner first selects the paragraphs it applies to (via
    ``rule.target`` or ``rule.applies_to``), then runs ``rule.stages`` in
    order.  The first failing stage decides the rule: its hits become the
    rule's findings and evaluation of that rule stops.  Skipped stages mark the
    outcome as skipped unless a later stage fails.

    A synchronous path (``run_rule`` / ``evaluate`` / ``run_all``) and an
    asynchronous path (``run_rule_async`` / ``evaluate_async`` /
    ``run_all_async``) are provided.  Both share the same per-stage
    bookkeeping, so they produce the same outcomes for the same checks.
    """

    def __init__(self, llm_client: LlmClient | None = None):
        # Handed to AiCheck for stages whose check is "ai"; may be None.
        self.llm = llm_client

    # -- context assembly -----------------------------------------------

    def _resolve_target(
        self,
        rule: Rule,
        doc: Document,
        entities: dict[str, SemanticEntity | None],
    ) -> tuple[list, SemanticEntity | None, RuleOutcome | None]:
        """Select the paragraphs a rule applies to.

        If ``rule.target`` is set, the paragraphs come from that semantic
        entity; otherwise ``rule.applies_to`` is passed to ``select_paragraphs``.

        Returns:
            ``(paragraphs, target_entity, early_outcome)``.  When
            ``early_outcome`` is not None the target entity was missing and
            ``rule.on_missing`` already decided the result; the caller should
            return it directly.
        """
        if rule.target:
            target_entity = entities.get(rule.target)
            if target_entity is None:
                return [], None, self._handle_missing(rule)
            # Indices outside the document are dropped silently.
            paragraphs = [
                doc.paragraphs[i]
                for i in target_entity.paragraph_indices
                if 0 <= i < len(doc.paragraphs)
            ]
            return paragraphs, target_entity, None

        # applies_to channel (scan across multiple paragraphs).
        return select_paragraphs(doc, rule.applies_to), None, None

    def _handle_missing(self, rule: Rule) -> RuleOutcome:
        """Build the outcome for a rule whose target entity was not found.

        ``rule.on_missing`` decides the result:

        * ``"pass"``: an empty (passing) outcome.
        * ``"skip"``: a skipped outcome carrying the reason.
        * ``"fail"``: one ``"error"`` finding at paragraph index -1.
        * anything else: one ``"warning"`` finding at paragraph index -1.
        """
        mode = rule.on_missing
        if mode == "pass":
            return RuleOutcome(rule=rule)

        reason = f"目标实体「{rule.target}」未识别到"
        if mode == "skip":
            return RuleOutcome(rule=rule, skipped=True, skip_reason=reason)

        severity = "error" if mode == "fail" else "warning"
        finding = Finding(
            finding_id=self._new_finding_id(),
            rule_id=rule.rule_id,
            rule_name=rule.name,
            severity=severity,
            category=rule.category,
            # -1: the finding is not tied to any paragraph.
            location=Location(paragraph_index=-1),
            message=reason,
            suggestion=rule.messages.fail or "",
            evidence="",
            # Fixed confidence for synthesized "target missing" findings.
            confidence=0.9,
        )
        return RuleOutcome(rule=rule, findings=[finding])

    @staticmethod
    def _new_finding_id() -> str:
        """Return a short random finding id of the form ``F-xxxxxxxx``."""
        return f"F-{uuid.uuid4().hex[:8]}"

    @staticmethod
    def _merge_skip(outcome: RuleOutcome, result: CheckResult) -> None:
        """Mark ``outcome`` as skipped, keeping the first recorded reason."""
        if not outcome.skip_reason:
            outcome.skip_reason = result.skip_reason or "stage skipped"
        outcome.skipped = True

    # -- shared per-stage helpers ----------------------------------------

    def _make_check(self, stage):
        """Instantiate the check implementation for one stage."""
        if stage.check == "ai":
            # AI checks need this runner's LLM client, so they are built here
            # rather than default-constructed from the registry.
            return AiCheck(llm_client=self.llm)
        return get_check(stage.check)()

    @staticmethod
    def _build_context(
        rule: Rule,
        doc: Document,
        stage,
        paragraphs: list,
        entities: dict[str, SemanticEntity | None],
        target: SemanticEntity | None,
    ) -> CheckContext:
        """Bundle everything a check needs for one stage into a CheckContext."""
        return CheckContext(
            document=doc,
            paragraphs=paragraphs,
            stage=stage,
            entities=entities,
            target=target,
            rule_id=rule.rule_id,
        )

    def _apply_stage_result(
        self,
        outcome: RuleOutcome,
        rule: Rule,
        result: CheckResult,
    ) -> bool:
        """Fold one stage's result into ``outcome``.

        Returns:
            True when the stage failed and therefore decided the rule (the
            caller must stop running further stages); False otherwise.
        """
        if result.skipped:
            self._merge_skip(outcome, result)
            return False
        if result.passed:
            return False
        # A failing stage decides the rule: its hits become the findings and
        # any "skipped" marker from earlier stages is cleared.
        # NOTE(review): a failing result with no hits produces no finding at
        # all -- confirm every check reports at least one hit on failure.
        outcome.findings = [self._hit_to_finding(rule, h) for h in result.hits]
        outcome.skipped = False
        outcome.skip_reason = ""
        return True

    # -- synchronous path ------------------------------------------------

    def run_rule(
        self,
        rule: Rule,
        doc: Document,
        entities: dict[str, SemanticEntity | None] | None = None,
    ) -> RuleOutcome:
        """Run every stage of ``rule`` against ``doc`` synchronously.

        AI stages use ``AiCheck.run`` (the blocking variant).
        """
        entities = entities or {}
        paragraphs, target, early = self._resolve_target(rule, doc, entities)
        if early is not None:
            return early

        outcome = RuleOutcome(rule=rule)
        for stage in rule.stages:
            check = self._make_check(stage)
            ctx = self._build_context(rule, doc, stage, paragraphs, entities, target)
            result: CheckResult = check.run(ctx)
            if self._apply_stage_result(outcome, rule, result):
                break
        return outcome

    def run_all(
        self,
        rules: list[Rule],
        doc: Document,
        entities: dict[str, SemanticEntity | None] | None = None,
    ) -> list[Finding]:
        """Run all ``rules`` synchronously and return their findings, flattened."""
        flat, _ = self.evaluate(rules, doc, entities)
        return flat

    def evaluate(
        self,
        rules: list[Rule],
        doc: Document,
        entities: dict[str, SemanticEntity | None] | None = None,
    ) -> tuple[list[Finding], list[RuleOutcome]]:
        """Run all ``rules`` in order.

        Returns:
            ``(flat_findings, outcomes)``, where ``outcomes`` is parallel to
            ``rules`` and ``flat_findings`` concatenates their findings in
            rule order.
        """
        outcomes = [self.run_rule(r, doc, entities) for r in rules]
        flat = [f for o in outcomes for f in o.findings]
        return flat, outcomes

    # -- asynchronous path -----------------------------------------------

    async def run_rule_async(
        self,
        rule: Rule,
        doc: Document,
        entities: dict[str, SemanticEntity | None] | None = None,
    ) -> RuleOutcome:
        """Async counterpart of ``run_rule``.

        AI stages are awaited via ``AiCheck.run_async``; all other checks run
        inline via their synchronous ``run`` method.
        """
        entities = entities or {}
        paragraphs, target, early = self._resolve_target(rule, doc, entities)
        if early is not None:
            return early

        outcome = RuleOutcome(rule=rule)
        for stage in rule.stages:
            ctx = self._build_context(rule, doc, stage, paragraphs, entities, target)
            check = self._make_check(stage)
            if stage.check == "ai":
                result = await check.run_async(ctx)
            else:
                result = check.run(ctx)
            if self._apply_stage_result(outcome, rule, result):
                break
        return outcome

    async def run_all_async(
        self,
        rules: list[Rule],
        doc: Document,
        entities: dict[str, SemanticEntity | None] | None = None,
        *,
        max_concurrency: int | None = None,
    ) -> list[Finding]:
        """Async counterpart of ``run_all``; see ``evaluate_async`` for options."""
        flat, _ = await self.evaluate_async(
            rules, doc, entities, max_concurrency=max_concurrency
        )
        return flat

    async def evaluate_async(
        self,
        rules: list[Rule],
        doc: Document,
        entities: dict[str, SemanticEntity | None] | None = None,
        *,
        max_concurrency: int | None = None,
    ) -> tuple[list[Finding], list[RuleOutcome]]:
        """Run all ``rules`` concurrently with ``asyncio.gather``.

        Args:
            max_concurrency: optional upper bound on how many rules are
                evaluated at the same time (e.g. to limit simultaneous LLM
                calls).  ``None`` (default) runs all rules at once.

        Returns:
            ``(flat_findings, outcomes)`` in the same order as ``rules``.

        Raises:
            ValueError: if ``max_concurrency`` is given and is less than 1.
        """
        if max_concurrency is not None and max_concurrency < 1:
            raise ValueError("max_concurrency must be >= 1")

        if max_concurrency is None:
            coros = [self.run_rule_async(r, doc, entities) for r in rules]
        else:
            sem = asyncio.Semaphore(max_concurrency)

            async def _bounded(r: Rule) -> RuleOutcome:
                async with sem:
                    return await self.run_rule_async(r, doc, entities)

            coros = [_bounded(r) for r in rules]

        # gather preserves input order, so outcomes stay parallel to rules.
        outcomes = list(await asyncio.gather(*coros))
        flat = [f for o in outcomes for f in o.findings]
        return flat, outcomes

    def _hit_to_finding(self, rule: Rule, hit: CheckHit) -> Finding:
        """Convert one check hit into a Finding carrying the rule's metadata."""
        para = hit.paragraph
        has_para = para is not None
        loc = Location(
            paragraph_index=para.index if has_para else -1,
            role=para.role if has_para else None,
            char_start=hit.char_start,
            char_end=hit.char_end,
            context=para.text if has_para else "",
        )
        return Finding(
            finding_id=self._new_finding_id(),
            rule_id=rule.rule_id,
            rule_name=rule.name,
            severity=rule.severity,
            category=rule.category,
            location=loc,
            actual=hit.actual or {},
            expected=hit.expected or {},
            # Never None: fall back to the rule's fail message, then "".
            message=hit.message or rule.messages.fail or "",
            suggestion=rule.messages.fail or "",
            # NOTE(review): evidence repeats the rule's fail message rather than
            # anything taken from the hit -- confirm this is intended.
            evidence=rule.messages.fail or "",
            confidence=hit.confidence,
        )
|