fix: preserve review field page positions in platform
This commit is contained in:
@@ -107,6 +107,7 @@ class NativeRunner:
|
||||
await self.storage.save_extraction_result(
|
||||
document_id,
|
||||
ctx.extraction,
|
||||
ocr_result=ctx.normalized_doc,
|
||||
run_id=run_id,
|
||||
)
|
||||
if ctx.evaluation is not None and ctx.rules_file is not None and ctx.extraction is not None:
|
||||
@@ -115,6 +116,7 @@ class NativeRunner:
|
||||
ctx.rules_file,
|
||||
ctx.evaluation,
|
||||
ctx.extraction,
|
||||
ocr_result=ctx.normalized_doc,
|
||||
run_id=run_id,
|
||||
rule_version_id=result.metadata.rule_version_id,
|
||||
)
|
||||
|
||||
@@ -99,11 +99,13 @@ class StorageAdapter:
|
||||
self,
|
||||
document_id: int,
|
||||
bundle: ExtractionBundle,
|
||||
ocr_result: OcrResult | None = None,
|
||||
*,
|
||||
run_id: int | None = None,
|
||||
) -> None:
|
||||
"""Save extraction result to leaudit_field_results table."""
|
||||
extracted = _bundle_to_extracted(bundle)
|
||||
inferred_positions = _build_inferred_field_positions(bundle, ocr_result)
|
||||
extracted = _bundle_to_extracted(bundle, inferred_positions=inferred_positions)
|
||||
resolved_run_id = await self._ensure_run_id(document_id, run_id)
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
@@ -111,7 +113,8 @@ class StorageAdapter:
|
||||
field_data = extracted.get("fields", {}).get(name, {})
|
||||
raw_value = fv.raw_value if isinstance(fv, FieldValue) else None
|
||||
meta_json = {
|
||||
"position": _field_value_position_payload(fv),
|
||||
"position": (_field_value_position_payload(fv) if isinstance(fv, FieldValue) else None)
|
||||
or inferred_positions.get(name),
|
||||
"reasons": list(fv.reasons or []),
|
||||
"type_name": fv.type_name,
|
||||
} if isinstance(fv, FieldValue) else None
|
||||
@@ -153,6 +156,7 @@ class StorageAdapter:
|
||||
rules_file: RulesFile,
|
||||
evaluation: EvaluationResult,
|
||||
bundle: ExtractionBundle,
|
||||
ocr_result: OcrResult | None = None,
|
||||
*,
|
||||
run_id: int | None = None,
|
||||
rule_version_id: int | None = None,
|
||||
@@ -163,6 +167,7 @@ class StorageAdapter:
|
||||
then inserts fresh rows.
|
||||
"""
|
||||
resolved_run_id = await self._ensure_run_id(document_id, run_id)
|
||||
inferred_positions = _build_inferred_field_positions(bundle, ocr_result)
|
||||
async with GetAsyncSession() as session:
|
||||
# Delete existing results for this document+run
|
||||
await session.execute(
|
||||
@@ -178,7 +183,14 @@ class StorageAdapter:
|
||||
# Insert one row per rule result
|
||||
for rule_result in evaluation.rules:
|
||||
rule = rule_meta.get(rule_result.rule_id)
|
||||
row = _rule_result_to_row(document_id, resolved_run_id, rule_result, rule, bundle)
|
||||
row = _rule_result_to_row(
|
||||
document_id,
|
||||
resolved_run_id,
|
||||
rule_result,
|
||||
rule,
|
||||
bundle,
|
||||
inferred_positions=inferred_positions,
|
||||
)
|
||||
if rule_version_id is not None:
|
||||
row["rule_version_id"] = rule_version_id
|
||||
json_columns = {"stages", "extracted_fields", "field_positions", "remediation", "rule_meta"}
|
||||
@@ -550,8 +562,13 @@ def _ocr_to_dict(ocr: OcrResult) -> dict[str, Any]:
|
||||
return result
|
||||
|
||||
|
||||
def _bundle_to_extracted(bundle: ExtractionBundle) -> dict[str, Any]:
|
||||
def _bundle_to_extracted(
|
||||
bundle: ExtractionBundle,
|
||||
*,
|
||||
inferred_positions: dict[str, dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Convert ExtractionBundle to docauditai's extracted_results format."""
|
||||
inferred_positions = inferred_positions or {}
|
||||
fields: dict[str, Any] = {}
|
||||
for name, fv in bundle.fields.items():
|
||||
if isinstance(fv, FieldValue):
|
||||
@@ -559,7 +576,7 @@ def _bundle_to_extracted(bundle: ExtractionBundle) -> dict[str, Any]:
|
||||
"value": fv.value,
|
||||
"confidence": float(fv.confidence) if fv.confidence else 0.0,
|
||||
}
|
||||
position_payload = _field_value_position_payload(fv)
|
||||
position_payload = _field_value_position_payload(fv) or inferred_positions.get(name)
|
||||
if position_payload is not None:
|
||||
field_data["position"] = position_payload
|
||||
fields[name] = field_data
|
||||
@@ -633,8 +650,11 @@ def _extract_relevant_fields(rule: Any, bundle: ExtractionBundle) -> dict[str, A
|
||||
def _extract_relevant_field_positions(
|
||||
rule: Any,
|
||||
bundle: ExtractionBundle,
|
||||
*,
|
||||
inferred_positions: dict[str, dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Extract position data for fields referenced by a rule's stages."""
|
||||
inferred_positions = inferred_positions or {}
|
||||
positions: dict[str, Any] = {}
|
||||
if not rule or not hasattr(rule, "stages") or not rule.stages:
|
||||
return positions
|
||||
@@ -671,7 +691,7 @@ def _extract_relevant_field_positions(
|
||||
continue
|
||||
fv = bundle.fields.get(f)
|
||||
if fv is not None and isinstance(fv, FieldValue):
|
||||
position_payload = _field_value_position_payload(fv)
|
||||
position_payload = _field_value_position_payload(fv) or inferred_positions.get(f)
|
||||
if position_payload is not None:
|
||||
positions[f] = position_payload
|
||||
return positions
|
||||
@@ -703,12 +723,183 @@ def _field_value_position_payload(fv: FieldValue) -> dict[str, Any] | None:
|
||||
return payload or None
|
||||
|
||||
|
||||
_POSITION_PUNCT_TABLE = str.maketrans({
|
||||
",": ",", "。": ".", ";": ";", ":": ":",
|
||||
"!": "!", "?": "?", "(": "(", ")": ")",
|
||||
"【": "[", "】": "]", "「": '"', "」": '"',
|
||||
"‘": "'", "’": "'", "“": '"', "”": '"',
|
||||
})
|
||||
|
||||
|
||||
def _build_inferred_field_positions(
|
||||
bundle: ExtractionBundle,
|
||||
ocr_result: OcrResult | None,
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""平台侧兜底补齐字段页码,避免 bridge 落库后丢失定位页。"""
|
||||
if ocr_result is None or not ocr_result.pages:
|
||||
return {}
|
||||
|
||||
inferred: dict[str, dict[str, Any]] = {}
|
||||
page_texts = [
|
||||
(int(page.page_num) + 1, str(page.text or ""))
|
||||
for page in ocr_result.pages
|
||||
if str(page.text or "").strip()
|
||||
]
|
||||
if not page_texts:
|
||||
return {}
|
||||
|
||||
for field_name, field_value in bundle.fields.items():
|
||||
if not isinstance(field_value, FieldValue):
|
||||
continue
|
||||
if _field_value_position_payload(field_value) is not None:
|
||||
continue
|
||||
|
||||
value_text = _stringify_position_value(field_value.value)
|
||||
if not value_text:
|
||||
continue
|
||||
|
||||
chunk_match = _infer_position_from_chunks(value_text, ocr_result)
|
||||
if chunk_match is not None:
|
||||
inferred[field_name] = chunk_match
|
||||
continue
|
||||
|
||||
page_num = _infer_page_num_from_page_texts(value_text, page_texts)
|
||||
if page_num is not None:
|
||||
inferred[field_name] = {
|
||||
"pageNum": page_num,
|
||||
"matchMethod": "platform_bridge_page_fallback",
|
||||
}
|
||||
|
||||
return inferred
|
||||
|
||||
|
||||
def _infer_position_from_chunks(
|
||||
value_text: str,
|
||||
ocr_result: OcrResult,
|
||||
) -> dict[str, Any] | None:
|
||||
normalized_value = _normalize_position_text(value_text)
|
||||
if not normalized_value:
|
||||
return None
|
||||
|
||||
for page in ocr_result.pages:
|
||||
for chunk in page.chunks or []:
|
||||
content = ""
|
||||
bbox = None
|
||||
if isinstance(chunk, dict):
|
||||
content = str(chunk.get("content") or "")
|
||||
bbox = chunk.get("bbox")
|
||||
else:
|
||||
content = str(getattr(chunk, "content", "") or "")
|
||||
bbox = getattr(chunk, "bbox", None)
|
||||
|
||||
if not content:
|
||||
continue
|
||||
if normalized_value not in _normalize_position_text(content):
|
||||
continue
|
||||
|
||||
payload: dict[str, Any] = {
|
||||
"pageNum": int(page.page_num) + 1,
|
||||
"matchMethod": "platform_bridge_chunk_fallback",
|
||||
}
|
||||
if bbox:
|
||||
payload["bbox"] = bbox
|
||||
return payload
|
||||
return None
|
||||
|
||||
|
||||
def _infer_page_num_from_page_texts(
|
||||
value_text: str,
|
||||
page_texts: list[tuple[int, str]],
|
||||
) -> int | None:
|
||||
normalized_value = _normalize_position_text(value_text)
|
||||
if not normalized_value:
|
||||
return None
|
||||
|
||||
best_page: int | None = None
|
||||
best_score = 0.0
|
||||
min_length_for_fuzzy = 8
|
||||
|
||||
for page_num, page_text in page_texts:
|
||||
normalized_page = _normalize_position_text(page_text)
|
||||
if not normalized_page:
|
||||
continue
|
||||
if normalized_value in normalized_page:
|
||||
return page_num
|
||||
|
||||
if len(normalized_value) < min_length_for_fuzzy:
|
||||
continue
|
||||
|
||||
score = _partial_similarity(normalized_value, normalized_page)
|
||||
if score > best_score:
|
||||
best_score = score
|
||||
best_page = page_num
|
||||
|
||||
if best_score >= 0.92:
|
||||
return best_page
|
||||
return None
|
||||
|
||||
|
||||
def _normalize_position_text(text: str) -> str:
|
||||
if not text:
|
||||
return ""
|
||||
normalized = re.sub(r"<[^>]+>", "", str(text))
|
||||
normalized = re.sub(r"\s+", "", normalized)
|
||||
return normalized.translate(_POSITION_PUNCT_TABLE)
|
||||
|
||||
|
||||
def _partial_similarity(needle: str, haystack: str) -> float:
|
||||
if not needle or not haystack:
|
||||
return 0.0
|
||||
if len(needle) > len(haystack):
|
||||
needle, haystack = haystack, needle
|
||||
|
||||
window = len(needle)
|
||||
if window <= 0:
|
||||
return 0.0
|
||||
if needle == haystack:
|
||||
return 1.0
|
||||
|
||||
best = 0.0
|
||||
step = max(1, window // 6)
|
||||
stop = max(len(haystack) - window + 1, 1)
|
||||
for start in range(0, stop, step):
|
||||
chunk = haystack[start : start + window]
|
||||
if not chunk:
|
||||
continue
|
||||
matches = sum(1 for left, right in zip(needle, chunk) if left == right)
|
||||
best = max(best, matches / window)
|
||||
if best >= 0.999:
|
||||
return 1.0
|
||||
return best
|
||||
|
||||
|
||||
def _stringify_position_value(raw_value: Any) -> str:
|
||||
if raw_value is None:
|
||||
return ""
|
||||
if isinstance(raw_value, str):
|
||||
return raw_value.strip()
|
||||
if isinstance(raw_value, (int, float, bool)):
|
||||
return str(raw_value)
|
||||
if isinstance(raw_value, dict):
|
||||
for key in ("value", "text", "value_text"):
|
||||
value = raw_value.get(key)
|
||||
if isinstance(value, str) and value.strip():
|
||||
return value.strip()
|
||||
return ""
|
||||
if isinstance(raw_value, list):
|
||||
parts = [_stringify_position_value(item) for item in raw_value]
|
||||
return " ".join(part for part in parts if part).strip()
|
||||
return str(raw_value).strip()
|
||||
|
||||
|
||||
def _rule_result_to_row(
|
||||
document_id: int,
|
||||
run_id: int | None,
|
||||
rule_result: RuleResult,
|
||||
rule: Any | None,
|
||||
bundle: ExtractionBundle,
|
||||
*,
|
||||
inferred_positions: dict[str, dict[str, Any]] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Convert a RuleResult to a leaudit_rule_results row."""
|
||||
passed = rule_result.passed
|
||||
@@ -760,7 +951,11 @@ def _rule_result_to_row(
|
||||
"fail_message": fail_msg,
|
||||
"stages": [s.model_dump(mode="json") for s in (rule_result.stages or [])],
|
||||
"extracted_fields": relevant_fields,
|
||||
"field_positions": _extract_relevant_field_positions(rule, bundle),
|
||||
"field_positions": _extract_relevant_field_positions(
|
||||
rule,
|
||||
bundle,
|
||||
inferred_positions=inferred_positions,
|
||||
),
|
||||
"remediation": remediation,
|
||||
"rule_meta": rule_meta_data,
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user