142 lines
3.9 KiB
Python
142 lines
3.9 KiB
Python
"""规则文件的 Pydantic schema。"""
|
|
|
|
from __future__ import annotations
|
|
from typing import Any, Literal
|
|
from pydantic import BaseModel, Field, model_validator
|
|
|
|
|
|
# Closed set of string identifiers accepted wherever a value is annotated as
# ``CheckType``. The order of the entries is kept exactly as originally listed.
CheckType = Literal[
    "required",
    "font",
    "style_match",
    "line_spacing",
    "attachment_marker_style",
    "regex_require",
    "regex_forbid",
    "confused_pair",
    "forbid_phrase",
    "forbid_chars",
    "punctuation",
    "wenzhong_whitelist",
    "hierarchy",
    "cross_role",
    "ai",
]
|
|
|
|
|
|
class AppliesTo(BaseModel):
    """Selector used by the ``applies_to`` channel of a rule.

    All three fields are optional and default to ``None``; this model does
    not itself require any of them to be set.
    """

    # NOTE(review): presumably ``role`` names a single paragraph role and
    # ``roles`` several; how the two combine is decided by the consumer of
    # this model, not here -- confirm against the rule engine.
    role: str | None = None
    roles: list[str] | None = None
    # NOTE(review): whether this index is 0- or 1-based is not visible in
    # this file -- confirm against the caller.
    paragraph_index: int | None = None
|
|
|
|
|
|
class RuleStage(BaseModel):
    """One stage of a rule: a check type plus optional parameters.

    Only ``check`` is required. Every other declared field defaults to
    ``None``. Keys that are not declared here are accepted rather than
    rejected, because the model is configured with ``extra="allow"``.
    """

    id: str | None = None
    check: CheckType

    # Optional parameters. Which of them a given ``check`` value actually
    # reads is decided by the code that executes stages, not by this schema
    # -- TODO confirm per check type.
    when: str | None = None
    field: str | None = None
    expect: dict[str, Any] | None = None
    pattern: str | None = None
    chars: list[str] | None = None
    pairs: list[dict[str, Any]] | None = None
    phrases: list[str] | None = None
    rules: list[dict[str, Any]] | None = None
    expected_order: list[dict[str, Any]] | None = None
    forbid_patterns: list[str] | None = None
    prompt: str | None = None
    format: str | None = None

    # Undeclared keys are kept instead of raising a validation error.
    model_config = {"extra": "allow"}
|
|
|
|
|
|
class Messages(BaseModel):
    """Message texts of a rule; both fields default to the empty string."""

    # ``pass`` is a reserved word in Python, so the attribute is named
    # ``pass_msg`` and the input key ``pass`` is mapped onto it via the alias.
    pass_msg: str = Field(alias="pass", default="")
    fail: str = ""

    # Allow populating ``pass_msg`` by its field name as well as by its alias.
    model_config = {"populate_by_name": True}
|
|
|
|
|
|
class Rule(BaseModel):
    """A single rule: identity, severity, targeting, stages and messages.

    Validation fails when neither ``target`` nor ``applies_to`` is given.
    """

    rule_id: str
    name: str
    severity: Literal["error", "warning", "info"] = "warning"
    category: str
    score: int | None = None

    # Two targeting channels: ``target`` (recommended, binds to a semantic
    # entity) or ``applies_to`` (legacy, selects paragraphs by role).
    # The validator below only requires that at least one of them is set;
    # it does not reject a rule that sets both.
    applies_to: AppliesTo | None = None
    target: str | None = None

    # NOTE(review): presumably the outcome to report when the targeted
    # entity is absent -- confirm against the rule engine.
    on_missing: Literal["pass", "warn", "fail", "skip"] = "skip"
    # NOTE(review): expression syntax is not defined in this file.
    activate_if: str | None = None

    stages: list[RuleStage]
    messages: Messages

    @model_validator(mode="after")
    def _check_at_least_one_target(self) -> "Rule":
        """Ensure the rule declares ``target`` or ``applies_to``.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: if both ``applies_to`` and ``target`` are ``None``.
        """
        if self.applies_to is None and self.target is None:
            raise ValueError(
                f"Rule {self.rule_id}: 必须声明 target 或 applies_to 之一"
            )
        return self
|
|
|
|
|
|
class RuleGroup(BaseModel):
    """A named group of rules; both fields are required."""

    group: str
    rules: list[Rule]
|
|
|
|
|
|
# Names of the 8 built-in semantic entities
# (kept in sync with entity_builder.BUILTIN_ENTITY_NAMES).
_BUILTIN_ENTITY_NAMES = frozenset({
    "title",
    "doc_number",
    "recipient",
    "date",
    "signature",
    "attachments",
    "wenzhong",
    "issuer",
})
|
|
|
|
|
|
class EntitySpec(BaseModel):
    """Declare a user-defined semantic entity.

    Built-in entities are produced automatically by code and do not need to
    appear in the YAML.
    """

    name: str
    type: Literal["string", "number", "list"] = "string"
    description: str = ""

    @model_validator(mode="after")
    def _no_clash_with_builtin(self) -> "EntitySpec":
        """Reject a declaration whose name collides with a built-in entity.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: if ``name`` is in ``_BUILTIN_ENTITY_NAMES``.
        """
        if self.name in _BUILTIN_ENTITY_NAMES:
            # Only the first literal interpolates a value; the second is a
            # plain string, so it carries no ``f`` prefix.
            raise ValueError(
                f"entity '{self.name}' 与内置实体重名,"
                "去掉该声明即可(内置实体自动产出)"
            )
        return self
|
|
|
|
|
|
class ExtractSpec(BaseModel):
    """Container for entity declarations; defaults to an empty list."""

    # ``default_factory`` gives every instance its own fresh list.
    entities: list[EntitySpec] = Field(default_factory=list)
|
|
|
|
|
|
class RuleSetMetadata(BaseModel):
    """Descriptive header of a rule set.

    ``type_id``, ``name`` and ``version`` are required; ``source`` and
    ``description`` are optional and default to ``None``.
    """

    type_id: str
    name: str
    version: str
    source: str | None = None
    description: str | None = None
|
|
|
|
|
|
class RuleSet(BaseModel):
    """Top-level model of a rule file: metadata, extract spec, rule groups.

    Validation fails if the same ``rule_id`` occurs more than once across
    all groups.
    """

    metadata: RuleSetMetadata
    # Defaults to an ``ExtractSpec`` built with no arguments.
    extract: ExtractSpec = Field(default_factory=ExtractSpec)
    rules: list[RuleGroup]

    @model_validator(mode="after")
    def _check_unique_ids(self) -> "RuleSet":
        """Ensure every ``rule_id`` is unique across all groups.

        Returns:
            The validated instance, unchanged.

        Raises:
            ValueError: on the first ``rule_id`` that was already seen.
        """
        seen: set[str] = set()
        # ``all_rules`` yields rules group by group in declaration order,
        # the same traversal order as a nested loop over groups and rules.
        for rule in self.all_rules():
            if rule.rule_id in seen:
                raise ValueError(f"duplicate rule_id: {rule.rule_id}")
            seen.add(rule.rule_id)
        return self

    def all_rules(self) -> list[Rule]:
        """Return the rules of all groups as one flat list, in group order."""
        return [r for g in self.rules for r in g.rules]
|
|
|
|
|
|
class FontCheck(BaseModel):
    """Optional font expectations; every field defaults to ``None``."""

    # NOTE(review): presumably font names for East Asian and ASCII text
    # respectively -- confirm against the code that reads this model.
    eastasia: str | None = None
    ascii: str | None = None
    # NOTE(review): presumably a font size in points, judging by the name.
    size_pt: float | None = None
|
|
|
|
|
|
class RegexForbidCheck(BaseModel):
    """A required ``pattern`` string with an optional ``message``."""

    # Stored as a plain string; this model does not compile or validate it.
    pattern: str
    message: str | None = None
|