feat: add rag retriever bridge
This commit is contained in:
@@ -24,6 +24,7 @@ from leaudit.ocr.base import BaseOCRClient
|
||||
from leaudit.ocr.models import OcrResult
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
@@ -89,10 +90,12 @@ class LauditPipeline:
|
||||
ocr_client: BaseOCRClient,
|
||||
llm_client: BaseLLMClient | None = None,
|
||||
storage_adapter: StorageAdapter | None = None,
|
||||
rag_retriever: RagRetriever | None = None,
|
||||
) -> None:
|
||||
self.ocr_client = ocr_client
|
||||
self.llm_client = llm_client
|
||||
self.storage = storage_adapter or StorageAdapter()
|
||||
self.rag_retriever = rag_retriever or RagRetriever()
|
||||
|
||||
async def run(
|
||||
self,
|
||||
@@ -219,6 +222,7 @@ class LauditPipeline:
|
||||
visual_manifest=visual_manifest,
|
||||
phase=detected_phase,
|
||||
external_mocks=external_mocks,
|
||||
retriever=self.rag_retriever,
|
||||
)
|
||||
timing["evaluation"] = round(time.time() - t0, 2)
|
||||
log.info(
|
||||
|
||||
@@ -0,0 +1,479 @@
|
||||
"""RAG 检索器。"""
|
||||
|
||||
from __future__ import annotations

import logging
import re
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable

import httpx
from sqlalchemy import text

from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_embeddings_url
|
||||
|
||||
|
||||
EmbedTexts = Callable[[list[str], str], Awaitable[list[list[float]]] | list[list[float]]]
|
||||
|
||||
|
||||
@dataclass
class RagRetrieveResult:
    """Outcome of one RAG retrieval call."""

    # Prompt-ready context text assembled from the selected chunks.
    rag_context: str = ""
    # Citation records, one per selected chunk.
    rag_resources: list[dict[str, Any]] = field(default_factory=list)
    # Raw chunk dicts that the context and citations were built from.
    chunks: list[dict[str, Any]] = field(default_factory=list)
    # Name of the dataset that was searched ("" when none was resolved).
    dataset_name: str = ""

    def as_dict(self) -> dict[str, Any]:
        """Return the subset of fields meant for injection into the rule engine.

        Only ``rag_context`` and ``rag_resources`` are exposed; ``chunks`` and
        ``dataset_name`` are intentionally left out of the mapping.
        """
        payload: dict[str, Any] = {}
        payload["rag_context"] = self.rag_context
        payload["rag_resources"] = self.rag_resources
        return payload
|
||||
|
||||
|
||||
class RagRetriever:
    """Retrieve RAG context from a Chroma collection.

    A query is first answered by vector similarity search; when that yields
    nothing (or fails) a keyword-overlap scan of the same collection is used
    instead. Hits can optionally be cross-checked against the ``rag_document``
    table so that disabled documents are dropped and hit counters are bumped.
    """

    _LOGGER = logging.getLogger(__name__)

    # Question words and generic legal nouns that are never used as keywords.
    _STOP_TERMS = frozenset(
        {
            "什么",
            "请问",
            "一下",
            "有关",
            "关于",
            "如何",
            "哪些",
            "怎么",
            "是否",
            "规定",
            "办法",
            "条例",
            "法律",
        }
    )

    # Conversational filler removed from a query before keyword matching.
    # The question-mark variants MUST come before the bare "是什么": patterns
    # are applied in order, and stripping "是什么" first would leave a stray
    # "?" behind that defeats the exact-match scoring path.
    _FILLER_PATTERNS = (
        "是什么?",
        "是什么？",
        "是什么",
        "什么是",
        "有哪些",
        "有什么",
        "请问",
        "介绍一下",
        "解释一下",
        "帮我分析",
        "帮我看看",
    )

    def __init__(
        self,
        *,
        chroma_client: Any | None = None,
        embed_texts: EmbedTexts | None = None,
        hydrate_documents: bool = True,
    ) -> None:
        """Initialise the retriever.

        Args:
            chroma_client: Client to use instead of the shared ``get_chroma()`` one.
            embed_texts: Sync or async callable ``(texts, model_name)`` that
                replaces the built-in HTTP embedding call.
            hydrate_documents: When true, hits are enriched from / filtered by
                the ``rag_document`` table whenever a ``dataset_id`` is given.
        """
        self._chroma_client = chroma_client
        self._embed_texts_override = embed_texts
        self._hydrate_documents_enabled = hydrate_documents

    async def retrieve(
        self,
        query: str,
        collection_name: str | None = None,
        dataset_id: int | None = None,
        top_k: int = 5,
        source_names: list[str] | None = None,
    ) -> RagRetrieveResult:
        """Retrieve RAG context for a question.

        Args:
            query: The user question; blank input yields an empty result.
            collection_name: Explicit Chroma collection; overrides the dataset's.
            dataset_id: ``rag_dataset`` row that supplies collection name,
                retrieval settings and embedding model.
            top_k: Maximum number of chunks to return. The dataset's own
                ``top_k`` wins when the collection comes from the dataset.
            source_names: Optional whitelist of document names.

        Returns:
            A ``RagRetrieveResult``; empty when nothing could be resolved or found.
        """
        query_text = (query or "").strip()
        if not query_text:
            return RagRetrieveResult()

        top_k = max(1, int(top_k or 5))
        dataset = await self._load_dataset(dataset_id) if dataset_id else None
        dataset_info = dataset or {}
        dataset_name = str(dataset_info.get("name") or "")

        resolved_collection = (collection_name or dataset_info.get("collection_name") or "").strip()
        if not resolved_collection:
            return RagRetrieveResult(dataset_name=dataset_name)

        retrieval_model = dataset_info.get("retrieval_model") or {}
        if dataset and not collection_name:
            top_k = max(1, int(retrieval_model.get("top_k") or top_k))
        score_threshold = self._resolve_score_threshold(retrieval_model)
        embedding_model = str(dataset_info.get("embedding_model") or "")

        chunks: list[dict[str, Any]] = []
        try:
            chunks = await self._vector_retrieve(
                query=query_text,
                collection_name=resolved_collection,
                dataset_name=dataset_name,
                embedding_model=embedding_model,
                top_k=top_k,
                score_threshold=score_threshold,
                source_names=source_names,
            )
        except Exception:
            # Best-effort by design: fall through to keyword search, but leave
            # a trace so embedding/Chroma outages are not invisible.
            self._LOGGER.warning(
                "RAG vector retrieval failed for collection %s; falling back to keyword search",
                resolved_collection,
                exc_info=True,
            )
            chunks = []

        if not chunks:
            try:
                chunks = await self._keyword_retrieve_context(
                    dataset_id=dataset_id,
                    collection_name=resolved_collection,
                    dataset_name=dataset_name,
                    query=query_text,
                    top_k=top_k,
                    score_threshold=score_threshold,
                    source_names=source_names,
                )
            except Exception:
                self._LOGGER.warning(
                    "RAG keyword retrieval failed for collection %s",
                    resolved_collection,
                    exc_info=True,
                )
                chunks = []

        if dataset_id and self._hydrate_documents_enabled:
            chunks = await self._hydrate_document_hits(dataset_id, chunks)

        chunks = chunks[:top_k]
        return RagRetrieveResult(
            rag_context=self._build_context(chunks),
            rag_resources=self.build_sources(chunks, dataset_name),
            chunks=chunks,
            dataset_name=dataset_name,
        )

    async def _load_dataset(self, dataset_id: int | None) -> dict[str, Any] | None:
        """Load the non-deleted ``rag_dataset`` row as a dict, or ``None`` if absent."""
        if not dataset_id:
            return None
        async with GetAsyncSession() as session:
            row = (
                await session.execute(
                    text(
                        """
                        SELECT id, name, collection_name, retrieval_model, embedding_model
                        FROM rag_dataset
                        WHERE id = :dataset_id AND deleted_at IS NULL
                        LIMIT 1
                        """
                    ),
                    {"dataset_id": dataset_id},
                )
            ).mappings().first()
        return dict(row) if row else None

    def _resolve_score_threshold(self, retrieval_model: dict[str, Any]) -> float | None:
        """Return the configured score threshold, or ``None`` when disabled or invalid."""
        if not retrieval_model.get("score_threshold_enabled"):
            return None
        try:
            return float(retrieval_model.get("score_threshold"))
        except (TypeError, ValueError):
            return None

    async def _vector_retrieve(
        self,
        *,
        query: str,
        collection_name: str,
        dataset_name: str,
        embedding_model: str,
        top_k: int,
        score_threshold: float | None,
        source_names: list[str] | None,
    ) -> list[dict[str, Any]]:
        """Embed the query and return matching chunks from the Chroma collection.

        The source whitelist and score threshold are applied after the query,
        so fewer than ``top_k`` chunks may come back.
        """
        query_embedding = await self._embed_texts([query], embedding_model)
        collection = self._get_chroma().get_or_create_collection(collection_name)
        result = collection.query(
            query_embeddings=query_embedding,
            n_results=max(top_k, 1),
            include=["documents", "metadatas", "distances"],
        )
        # Results are nested one list per query; only one query was sent.
        ids = (result.get("ids") or [[]])[0] if result.get("ids") else []
        docs = (result.get("documents") or [[]])[0]
        metas = (result.get("metadatas") or [[]])[0]
        distances = (result.get("distances") or [[]])[0]
        allowed_sources = self._normalize_source_filter(source_names)

        chunks: list[dict[str, Any]] = []
        for idx, doc in enumerate(docs):
            meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {}
            document_name = str(meta.get("document_name") or meta.get("source") or "")
            if allowed_sources and document_name not in allowed_sources:
                continue
            dist = float(distances[idx]) if idx < len(distances) and distances[idx] is not None else 1.0
            # Map a non-negative distance onto (0, 1]; smaller distance -> higher score.
            score = 1.0 / (1.0 + max(dist, 0.0))
            if score_threshold is not None and score < score_threshold:
                continue
            chunks.append(
                {
                    "id": str(ids[idx] if idx < len(ids) else meta.get("id") or idx),
                    "text": doc or "",
                    "source": meta.get("source") or meta.get("document_name") or dataset_name,
                    "score": score,
                    "chunk_index": int(meta.get("chunk_index") or idx),
                    "document_name": document_name,
                    "document_id": meta.get("document_id"),
                    "page": meta.get("page"),
                }
            )
        return chunks

    async def _embed_texts(self, texts: list[str], model_name: str = "") -> list[list[float]]:
        """Return one embedding per input text.

        Uses the injected ``embed_texts`` callable when present (sync or
        async); otherwise POSTs batches to the configured embeddings endpoint.

        Raises:
            LeauditException: When no endpoint/key is configured, the endpoint
                answers with an HTTP error status, or the number of returned
                embeddings does not match the number of inputs.
        """
        if self._embed_texts_override is not None:
            result = self._embed_texts_override(texts, model_name)
            if hasattr(result, "__await__"):
                return await result  # type: ignore[misc]
            return result

        embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"])
        embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
        embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
        batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
        if not embed_url or not embed_key:
            raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")

        embeddings: list[list[float]] = []
        async with httpx.AsyncClient(timeout=120.0) as client:
            for start in range(0, len(texts), batch_size):
                batch_texts = texts[start:start + batch_size]
                try:
                    response = await client.post(
                        embed_url,
                        headers={
                            "Content-Type": "application/json",
                            "Authorization": f"Bearer {embed_key}",
                        },
                        json={"model": embed_model, "input": batch_texts},
                    )
                    response.raise_for_status()
                except httpx.HTTPStatusError as exc:
                    error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
                    raise LeauditException(
                        StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                        f"向量化服务调用失败: {error_message[:300]}",
                    ) from exc

                payload = response.json()
                rows = payload.get("data") or []
                batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
                if len(batch_embeddings) != len(batch_texts):
                    raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
                embeddings.extend(batch_embeddings)

        if len(embeddings) != len(texts):
            raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
        return embeddings

    async def _keyword_retrieve_context(
        self,
        *,
        dataset_id: int | None,
        collection_name: str,
        dataset_name: str,
        query: str,
        top_k: int,
        score_threshold: float | None,
        source_names: list[str] | None,
    ) -> list[dict[str, Any]]:
        """Score every chunk of the collection by keyword overlap with the query.

        Returns up to ``3 * top_k`` candidates, best score first (ties broken
        by ``chunk_index``). ``dataset_id`` is accepted but not used here.
        """
        collection = self._get_chroma().get_or_create_collection(collection_name)
        # NOTE: this reads the whole collection; every stored chunk is scored.
        raw = collection.get(include=["documents", "metadatas"])
        ids = raw.get("ids") or []
        docs = raw.get("documents") or []
        metas = raw.get("metadatas") or []
        allowed_sources = self._normalize_source_filter(source_names)

        terms = self._build_keyword_terms(query)
        if not terms:
            return []

        scored_chunks: list[dict[str, Any]] = []
        for idx, chunk_id in enumerate(ids):
            doc = docs[idx] if idx < len(docs) else ""
            meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {}
            document_name = str(meta.get("document_name") or meta.get("source") or "")
            if allowed_sources and document_name not in allowed_sources:
                continue
            score = self._score_keyword_chunk(
                query=query,
                terms=terms,
                content=doc or "",
                document_name=document_name,
            )
            if score <= 0:
                continue
            if score_threshold is not None and score < score_threshold:
                continue
            scored_chunks.append(
                {
                    "id": str(chunk_id),
                    "text": doc or "",
                    "source": meta.get("source") or meta.get("document_name") or dataset_name,
                    "score": score,
                    "chunk_index": int(meta.get("chunk_index") or idx),
                    "document_name": document_name,
                    "document_id": meta.get("document_id"),
                    "page": meta.get("page"),
                }
            )

        scored_chunks.sort(key=lambda item: (-float(item.get("score") or 0.0), int(item.get("chunk_index") or 0)))
        # Over-fetch so later filtering/truncation by the caller still has candidates.
        return scored_chunks[: max(top_k * 3, top_k)]

    def _build_keyword_terms(self, query: str) -> list[str]:
        """Derive up to 20 distinct search terms from the query, longest first.

        Each alphanumeric/CJK span is a term; pure-CJK spans additionally
        contribute their 2-, 3- and 4-character substrings. Stop words are
        dropped.
        """
        normalized = self._normalize_keyword_query(query)
        spans = [item.strip() for item in re.findall(r"[\u4e00-\u9fffA-Za-z0-9]+", normalized) if item.strip()]
        if not spans:
            return []

        stop_terms = self._STOP_TERMS
        terms: list[str] = []
        for span in spans:
            if span in stop_terms:
                continue
            terms.append(span)
            if not re.fullmatch(r"[\u4e00-\u9fff]+", span):
                continue
            # CJK text has no word separators, so add character n-grams.
            for size in (2, 3, 4):
                if len(span) <= size:
                    continue
                for start in range(0, len(span) - size + 1):
                    token = span[start:start + size]
                    if token not in stop_terms:
                        terms.append(token)

        unique_terms: list[str] = []
        seen: set[str] = set()
        # sorted() is stable, so equal-length terms keep their discovery order.
        for term in sorted(terms, key=len, reverse=True):
            if term and term not in seen:
                unique_terms.append(term)
                seen.add(term)
        return unique_terms[:20]

    def _normalize_keyword_query(self, query: str) -> str:
        """Lower-case the query, strip filler phrases and collapse whitespace."""
        normalized = (query or "").strip().lower()
        for pattern in self._FILLER_PATTERNS:
            normalized = normalized.replace(pattern, " ")
        return re.sub(r"\s+", " ", normalized).strip()

    def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float:
        """Score a chunk against the query terms; the result lies in [0, 0.99].

        A chunk containing the whole normalised query scores 0.98. Otherwise
        the score is the length-squared-weighted share of matched terms plus a
        small bonus for terms that appear in the document name.
        """
        if not document_name and not content:
            return 0.0
        lowered_name = document_name.lower()
        haystack = f"{document_name}\n{content}".lower()

        exact_query = self._normalize_keyword_query(query)
        if exact_query and exact_query in haystack:
            return 0.98

        matched_weight = 0.0
        total_weight = 0.0
        name_bonus = 0.0
        for term in terms:
            lowered_term = term.lower()
            # Longer terms are more specific, so weight them quadratically.
            weight = float(max(len(term), 1) ** 2)
            total_weight += weight
            if lowered_term in haystack:
                matched_weight += weight
            if lowered_term in lowered_name:
                name_bonus += min(0.15, 0.03 * len(term))

        if total_weight <= 0:
            return 0.0
        score = (matched_weight / total_weight) + name_bonus
        return round(min(score, 0.99), 6)

    def _build_context(self, chunks: list[dict[str, Any]]) -> str:
        """Join chunk texts into numbered, source-labelled context paragraphs.

        Chunks with blank text are skipped but still consume their number, so
        numbering stays aligned with positions in ``chunks``.
        """
        lines: list[str] = []
        for index, chunk in enumerate(chunks, start=1):
            document_name = chunk.get("document_name") or chunk.get("source") or "未知来源"
            text_value = str(chunk.get("text") or "").strip()
            if not text_value:
                continue
            lines.append(f"[{index}] 来源:{document_name}\n{text_value}")
        return "\n\n".join(lines)

    def build_sources(self, context_chunks: list[dict[str, Any]], dataset_name: str = "") -> list[dict[str, Any]]:
        """Build one citation record per chunk, in order.

        A chunk whose ``text`` is missing or ``None`` is reported with a word
        count of 0 and empty content instead of raising.
        """
        sources: list[dict[str, Any]] = []
        for index, chunk in enumerate(context_chunks):
            text_value = str(chunk.get("text") or "")
            sources.append(
                {
                    "position": index + 1,
                    "dataset_id": str(chunk.get("dataset_id") or ""),
                    "dataset_name": dataset_name,
                    "document_id": str(chunk.get("document_id") or ""),
                    "document_name": chunk.get("document_name") or chunk.get("source", ""),
                    "data_source_type": "upload_file",
                    "segment_id": chunk.get("id", ""),
                    "retriever_from": "rag",
                    "score": round(float(chunk.get("score") or 0.0), 4),
                    "hit_count": chunk.get("hit_count", 0),
                    "word_count": len(text_value),
                    "segment_position": index + 1,
                    "index_node_hash": "",
                    "content": text_value[:500],
                    # NOTE(review): chunks carry a "page" value but it is not
                    # forwarded here — confirm whether consumers expect None.
                    "page": None,
                }
            )
        return sources

    async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Enrich chunks from ``rag_document`` and drop those of disabled documents.

        Chunks are matched to documents by name within the dataset. Matched
        chunks are updated in place with ``document_id``, ``dataset_id``,
        ``document_name`` and ``hit_count``; unmatched chunks pass through
        untouched. Every matched, enabled document gets ``hit_count`` + 1.
        """
        source_names = sorted(
            {
                str(chunk.get("document_name") or chunk.get("source") or "").strip()
                for chunk in chunks
                if str(chunk.get("document_name") or chunk.get("source") or "").strip()
            }
        )
        if not source_names:
            return chunks

        async with GetAsyncSession() as session:
            rows = (
                await session.execute(
                    text(
                        """
                        SELECT id, original_name, enabled, hit_count
                        FROM rag_document
                        WHERE dataset_id = :dataset_id
                        AND deleted_at IS NULL
                        AND original_name = ANY(:source_names)
                        """
                    ),
                    {
                        "dataset_id": dataset_id,
                        "source_names": source_names,
                    },
                )
            ).mappings().all()

        document_map = {str(row["original_name"]): row for row in rows}
        visible_chunks: list[dict[str, Any]] = []
        hit_document_ids: list[int] = []
        for chunk in chunks:
            source_name = str(chunk.get("document_name") or chunk.get("source") or "").strip()
            document = document_map.get(source_name)
            if document and not bool(document.get("enabled")):
                continue
            if document:
                chunk["document_id"] = document["id"]
                chunk["dataset_id"] = dataset_id
                chunk["document_name"] = document["original_name"]
                chunk["hit_count"] = document.get("hit_count") or 0
                hit_document_ids.append(int(document["id"]))
            visible_chunks.append(chunk)

        if hit_document_ids:
            async with GetAsyncSession() as session:
                async with session.begin():
                    await session.execute(
                        text(
                            """
                            UPDATE rag_document
                            SET hit_count = hit_count + 1,
                            updated_at = NOW()
                            WHERE id = ANY(:document_ids)
                            """
                        ),
                        {"document_ids": sorted(set(hit_document_ids))},
                    )

        return visible_chunks

    def _normalize_source_filter(self, source_names: list[str] | None) -> set[str]:
        """Return the stripped, non-blank source names as a set (empty = no filter)."""
        return {str(name).strip() for name in (source_names or []) if str(name).strip()}

    def _get_chroma(self) -> Any:
        """Return the injected Chroma client, or the shared one when none was given."""
        # Compare against None: an injected client must be honoured even if
        # the object happens to evaluate as falsy.
        if self._chroma_client is not None:
            return self._chroma_client
        return get_chroma()
|
||||
@@ -2838,7 +2838,7 @@ class DocumentServiceImpl(IDocumentService):
|
||||
warning_text = None
|
||||
if issue_count > 0 and summary_status:
|
||||
pages = await self._loadPageQualityIssuePages(Session, int(row["run_id"]))
|
||||
warning_text = self._buildPageQualityWarningText(pages, summary_status)
|
||||
warning_text = await self._buildPageQualityWarningText(Session, int(row["run_id"]), summary_status)
|
||||
result[document_id] = {
|
||||
"pageQualityRunId": int(row["run_id"]),
|
||||
"pageQualityRunStatus": str(row["run_status"] or "") or None,
|
||||
@@ -2871,15 +2871,39 @@ class DocumentServiceImpl(IDocumentService):
|
||||
).mappings().all()
|
||||
return [int(row["page_num"]) for row in rows if row["page_num"] is not None]
|
||||
|
||||
def _buildPageQualityWarningText(self, Pages: list[int], SummaryStatus: str | None) -> str | None:
|
||||
async def _buildPageQualityWarningText(self, Session, RunId: int, SummaryStatus: str | None) -> str | None:
|
||||
"""组装页级模糊预警文案。"""
|
||||
if not Pages or not SummaryStatus:
|
||||
if not RunId or not SummaryStatus:
|
||||
return None
|
||||
pages_text = "、".join(f"第{page}页" for page in Pages[:10])
|
||||
suffix = "建议重拍" if SummaryStatus == "reject" else "疑似模糊"
|
||||
if len(Pages) > 10:
|
||||
if not await self._tableExists(Session, "leaudit_page_quality_results"):
|
||||
return None
|
||||
rows = (
|
||||
await Session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT quality_status, ARRAY_AGG(DISTINCT page_num ORDER BY page_num ASC) AS pages
|
||||
FROM leaudit_page_quality_results
|
||||
WHERE run_id = :run_id
|
||||
AND quality_status IN ('review', 'reject')
|
||||
GROUP BY quality_status
|
||||
"""
|
||||
),
|
||||
{"run_id": RunId},
|
||||
)
|
||||
).mappings().all()
|
||||
parts: list[str] = []
|
||||
status_order = {"reject": 0, "review": 1}
|
||||
status_suffix = {"reject": "建议重拍", "review": "疑似模糊"}
|
||||
for row in sorted(rows, key=lambda item: status_order.get(str(item["quality_status"] or ""), 9)):
|
||||
status = str(row["quality_status"] or "")
|
||||
pages = [int(page) for page in row["pages"] or [] if page is not None]
|
||||
if not pages or status not in status_suffix:
|
||||
continue
|
||||
pages_text = "、".join(f"第{page}页" for page in pages[:10])
|
||||
if len(pages) > 10:
|
||||
pages_text = f"{pages_text}等"
|
||||
return f"{pages_text}{suffix}"
|
||||
parts.append(f"{pages_text}{status_suffix[status]}")
|
||||
return ";".join(parts) if parts else None
|
||||
|
||||
async def _getDocumentDetail(
|
||||
self,
|
||||
|
||||
@@ -35,11 +35,10 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.config import (
|
||||
RAG_CONFIG,
|
||||
build_openai_chat_completions_url,
|
||||
build_openai_embeddings_url,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever
|
||||
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.tenantResolver import TenantResolver
|
||||
|
||||
@@ -61,8 +60,9 @@ class RagChatServiceImpl(IRagChatService):
|
||||
_chat_schema_checked = False
|
||||
_chat_schema_lock = asyncio.Lock()
|
||||
|
||||
def __init__(self) -> None:
|
||||
def __init__(self, retriever: RagRetriever | None = None) -> None:
|
||||
self.TenantResolver = TenantResolver()
|
||||
self.retriever = retriever or RagRetriever()
|
||||
|
||||
_APP_TENANT_NAME_SQL = (
|
||||
"CASE "
|
||||
@@ -939,121 +939,11 @@ class RagChatServiceImpl(IRagChatService):
|
||||
return await self._app_visible(app_row, tenant_context=tenant_context, user_role=user_role)
|
||||
|
||||
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
|
||||
if not dataset_id:
|
||||
return [], ""
|
||||
async with GetAsyncSession() as session:
|
||||
dataset = (
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id, name, collection_name, retrieval_model, embedding_model
|
||||
FROM rag_dataset
|
||||
WHERE id = :dataset_id AND deleted_at IS NULL
|
||||
LIMIT 1
|
||||
"""
|
||||
),
|
||||
{"dataset_id": dataset_id},
|
||||
)
|
||||
).mappings().first()
|
||||
if not dataset:
|
||||
return [], ""
|
||||
retrieval_model = dataset.get("retrieval_model") or {}
|
||||
top_k = int(retrieval_model.get("top_k") or 5)
|
||||
score_threshold = None
|
||||
if retrieval_model.get("score_threshold_enabled"):
|
||||
try:
|
||||
score_threshold = float(retrieval_model.get("score_threshold"))
|
||||
except (TypeError, ValueError):
|
||||
score_threshold = None
|
||||
try:
|
||||
query_embedding = await self._embed_texts([query], dataset.get("embedding_model") or "")
|
||||
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
|
||||
result = collection.query(
|
||||
query_embeddings=query_embedding,
|
||||
n_results=max(top_k, 1),
|
||||
include=["documents", "metadatas", "distances"],
|
||||
)
|
||||
ids = (result.get("ids") or [[]])[0] if result.get("ids") else []
|
||||
docs = (result.get("documents") or [[]])[0]
|
||||
metas = (result.get("metadatas") or [[]])[0]
|
||||
distances = (result.get("distances") or [[]])[0]
|
||||
chunks: list[dict] = []
|
||||
for idx, doc in enumerate(docs):
|
||||
meta = metas[idx] if idx < len(metas) else {}
|
||||
dist = float(distances[idx]) if idx < len(distances) and distances[idx] is not None else 1.0
|
||||
score = 1.0 / (1.0 + max(dist, 0.0))
|
||||
if score_threshold is not None and score < score_threshold:
|
||||
continue
|
||||
chunks.append(
|
||||
{
|
||||
"id": str(ids[idx] if idx < len(ids) else meta.get("id") or idx),
|
||||
"text": doc,
|
||||
"source": meta.get("source") or meta.get("document_name") or dataset.get("name") or "",
|
||||
"score": score,
|
||||
"chunk_index": int(meta.get("chunk_index") or idx),
|
||||
"document_name": meta.get("document_name") or meta.get("source") or "",
|
||||
"document_id": meta.get("document_id"),
|
||||
"page": meta.get("page"),
|
||||
}
|
||||
)
|
||||
chunks = await self._hydrate_document_hits(dataset_id, chunks)
|
||||
if chunks:
|
||||
return chunks[:top_k], dataset.get("name") or ""
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
try:
|
||||
chunks = await self._keyword_retrieve_context(
|
||||
dataset_id=dataset_id,
|
||||
collection_name=str(dataset["collection_name"]),
|
||||
dataset_name=str(dataset.get("name") or ""),
|
||||
query=query,
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
return chunks[:top_k], dataset.get("name") or ""
|
||||
except Exception:
|
||||
return [], dataset.get("name") or ""
|
||||
result = await self.retriever.retrieve(query=query, dataset_id=dataset_id)
|
||||
return result.chunks, result.dataset_name
|
||||
|
||||
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
|
||||
embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"])
|
||||
embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
|
||||
embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
|
||||
batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
|
||||
if not embed_url or not embed_key:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")
|
||||
|
||||
embeddings: list[list[float]] = []
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
for start in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[start:start + batch_size]
|
||||
try:
|
||||
response = await client.post(
|
||||
embed_url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {embed_key}",
|
||||
},
|
||||
json={"model": embed_model, "input": batch_texts},
|
||||
)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
|
||||
raise LeauditException(
|
||||
StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
f"向量化服务调用失败: {error_message[:300]}",
|
||||
) from exc
|
||||
|
||||
payload = response.json()
|
||||
rows = payload.get("data") or []
|
||||
batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
|
||||
if len(batch_embeddings) != len(batch_texts):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
|
||||
embeddings.extend(batch_embeddings)
|
||||
|
||||
if len(embeddings) != len(texts):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
|
||||
return embeddings
|
||||
return await self.retriever._embed_texts(texts, model_name)
|
||||
|
||||
async def _start_message_task(
|
||||
self,
|
||||
@@ -1569,220 +1459,42 @@ class RagChatServiceImpl(IRagChatService):
|
||||
top_k: int,
|
||||
score_threshold: float | None,
|
||||
) -> list[dict]:
|
||||
collection = get_chroma().get_or_create_collection(collection_name)
|
||||
raw = collection.get(include=["documents", "metadatas"])
|
||||
ids = raw.get("ids") or []
|
||||
docs = raw.get("documents") or []
|
||||
metas = raw.get("metadatas") or []
|
||||
|
||||
terms = self._build_keyword_terms(query)
|
||||
if not terms:
|
||||
return []
|
||||
|
||||
scored_chunks: list[dict] = []
|
||||
for idx, chunk_id in enumerate(ids):
|
||||
doc = docs[idx] if idx < len(docs) else ""
|
||||
meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {}
|
||||
score = self._score_keyword_chunk(
|
||||
chunks = await self.retriever._keyword_retrieve_context(
|
||||
dataset_id=dataset_id,
|
||||
collection_name=collection_name,
|
||||
dataset_name=dataset_name,
|
||||
query=query,
|
||||
terms=terms,
|
||||
content=doc or "",
|
||||
document_name=str(meta.get("document_name") or meta.get("source") or ""),
|
||||
top_k=top_k,
|
||||
score_threshold=score_threshold,
|
||||
source_names=None,
|
||||
)
|
||||
if score <= 0:
|
||||
continue
|
||||
if score_threshold is not None and score < score_threshold:
|
||||
continue
|
||||
scored_chunks.append(
|
||||
{
|
||||
"id": str(chunk_id),
|
||||
"text": doc or "",
|
||||
"source": meta.get("source") or meta.get("document_name") or dataset_name,
|
||||
"score": score,
|
||||
"chunk_index": int(meta.get("chunk_index") or idx),
|
||||
"document_name": meta.get("document_name") or meta.get("source") or "",
|
||||
"document_id": meta.get("document_id"),
|
||||
"page": meta.get("page"),
|
||||
}
|
||||
)
|
||||
|
||||
scored_chunks.sort(key=lambda item: (-float(item.get("score") or 0.0), int(item.get("chunk_index") or 0)))
|
||||
hydrated = await self._hydrate_document_hits(dataset_id, scored_chunks[: max(top_k * 3, top_k)])
|
||||
return hydrated[:top_k]
|
||||
return chunks[:top_k]
|
||||
|
||||
def _build_keyword_terms(self, query: str) -> list[str]:
|
||||
normalized = self._normalize_keyword_query(query)
|
||||
spans = [item.strip() for item in re.findall(r"[\u4e00-\u9fffA-Za-z0-9]+", normalized) if item.strip()]
|
||||
if not spans:
|
||||
return []
|
||||
|
||||
stop_terms = {
|
||||
"什么",
|
||||
"请问",
|
||||
"一下",
|
||||
"有关",
|
||||
"关于",
|
||||
"如何",
|
||||
"哪些",
|
||||
"怎么",
|
||||
"是否",
|
||||
"规定",
|
||||
"办法",
|
||||
"条例",
|
||||
"法律",
|
||||
}
|
||||
terms: list[str] = []
|
||||
for span in spans:
|
||||
if span in stop_terms:
|
||||
continue
|
||||
terms.append(span)
|
||||
if re.fullmatch(r"[\u4e00-\u9fff]+", span):
|
||||
for size in (2, 3, 4):
|
||||
if len(span) > size:
|
||||
for start in range(0, len(span) - size + 1):
|
||||
token = span[start:start + size]
|
||||
if token not in stop_terms:
|
||||
terms.append(token)
|
||||
|
||||
unique_terms: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for term in sorted(terms, key=len, reverse=True):
|
||||
if term and term not in seen:
|
||||
unique_terms.append(term)
|
||||
seen.add(term)
|
||||
return unique_terms[:20]
|
||||
return self.retriever._build_keyword_terms(query)
|
||||
|
||||
def _normalize_keyword_query(self, query: str) -> str:
|
||||
normalized = (query or "").strip().lower()
|
||||
patterns = [
|
||||
"是什么",
|
||||
"什么是",
|
||||
"有哪些",
|
||||
"有什么",
|
||||
"是什么?",
|
||||
"是什么?",
|
||||
"请问",
|
||||
"介绍一下",
|
||||
"解释一下",
|
||||
"帮我分析",
|
||||
"帮我看看",
|
||||
]
|
||||
for pattern in patterns:
|
||||
normalized = normalized.replace(pattern, " ")
|
||||
return re.sub(r"\s+", " ", normalized).strip()
|
||||
return self.retriever._normalize_keyword_query(query)
|
||||
|
||||
def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float:
|
||||
haystack = f"{document_name}\n{content}".lower()
|
||||
if not haystack:
|
||||
return 0.0
|
||||
|
||||
exact_query = self._normalize_keyword_query(query)
|
||||
if exact_query and exact_query in haystack:
|
||||
return 0.98
|
||||
|
||||
matched_weight = 0.0
|
||||
total_weight = 0.0
|
||||
name_bonus = 0.0
|
||||
for term in terms:
|
||||
weight = float(max(len(term), 1) ** 2)
|
||||
total_weight += weight
|
||||
if term.lower() in haystack:
|
||||
matched_weight += weight
|
||||
if term.lower() in document_name.lower():
|
||||
name_bonus += min(0.15, 0.03 * len(term))
|
||||
|
||||
if total_weight <= 0:
|
||||
return 0.0
|
||||
score = (matched_weight / total_weight) + name_bonus
|
||||
return round(min(score, 0.99), 6)
|
||||
return self.retriever._score_keyword_chunk(
|
||||
query=query,
|
||||
terms=terms,
|
||||
content=content,
|
||||
document_name=document_name,
|
||||
)
|
||||
|
||||
def _format_sse(self, payload: dict) -> bytes:
|
||||
return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8")
|
||||
|
||||
def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]:
|
||||
return [
|
||||
{
|
||||
"position": index + 1,
|
||||
"dataset_id": str(chunk.get("dataset_id") or ""),
|
||||
"dataset_name": dataset_name,
|
||||
"document_id": str(chunk.get("document_id") or ""),
|
||||
"document_name": chunk.get("document_name") or chunk.get("source", ""),
|
||||
"data_source_type": "upload_file",
|
||||
"segment_id": chunk.get("id", ""),
|
||||
"retriever_from": "rag",
|
||||
"score": round(chunk.get("score", 0.0), 4),
|
||||
"hit_count": chunk.get("hit_count", 0),
|
||||
"word_count": len(chunk.get("text", "")),
|
||||
"segment_position": index + 1,
|
||||
"index_node_hash": "",
|
||||
"content": chunk.get("text", "")[:500],
|
||||
"page": None,
|
||||
}
|
||||
for index, chunk in enumerate(context_chunks)
|
||||
]
|
||||
build_sources = getattr(self.retriever, "build_sources", None)
|
||||
if callable(build_sources):
|
||||
return build_sources(context_chunks, dataset_name)
|
||||
return RagRetriever(hydrate_documents=False).build_sources(context_chunks, dataset_name)
|
||||
|
||||
async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict]) -> list[dict]:
    """Attach document metadata to chunks and bump document hit counters.

    Chunk source names (``document_name``, falling back to ``source``) are
    looked up in ``rag_document`` for the given dataset. Chunks whose
    document row is not enabled are dropped; chunks with a matching enabled
    row are updated in place with ``document_id``, ``dataset_id``,
    ``document_name`` and ``hit_count``; chunks without a matching row are
    kept unchanged. Every matched document then has ``hit_count``
    incremented by one in a second session.

    Args:
        dataset_id: Dataset whose documents are consulted.
        chunks: Retrieved chunk dicts; matched ones are mutated in place.

    Returns:
        The visible chunks in input order, or ``chunks`` itself when no
        chunk carries a source name.

    NOTE(review): the source also carried a trailing
    ``return await self.retriever._hydrate_document_hits(dataset_id, chunks)``
    after the final return. It could never execute and is omitted here;
    confirm which implementation is the intended one.
    """

    def _name_of(item: dict) -> str:
        # Single place for the "document_name or source" fallback.
        return str(item.get("document_name") or item.get("source") or "").strip()

    wanted_names = sorted({name for name in map(_name_of, chunks) if name})
    if not wanted_names:
        return chunks

    async with GetAsyncSession() as session:
        lookup = await session.execute(
            text(
                """
                SELECT id, original_name, enabled, hit_count
                FROM rag_document
                WHERE dataset_id = :dataset_id
                  AND deleted_at IS NULL
                  AND original_name = ANY(:source_names)
                """
            ),
            {
                "dataset_id": dataset_id,
                "source_names": wanted_names,
            },
        )
        rows = lookup.mappings().all()

    rows_by_name = {str(row["original_name"]): row for row in rows}
    kept_chunks: list[dict] = []
    touched_ids: list[int] = []
    for chunk in chunks:
        record = rows_by_name.get(_name_of(chunk))
        if record:
            if not bool(record.get("enabled")):
                # Disabled documents must not surface in results.
                continue
            chunk["document_id"] = record["id"]
            chunk["dataset_id"] = dataset_id
            chunk["document_name"] = record["original_name"]
            chunk["hit_count"] = record.get("hit_count") or 0
            touched_ids.append(int(record["id"]))
        kept_chunks.append(chunk)

    if touched_ids:
        async with GetAsyncSession() as session:
            async with session.begin():
                await session.execute(
                    text(
                        """
                        UPDATE rag_document
                        SET hit_count = hit_count + 1,
                            updated_at = NOW()
                        WHERE id = ANY(:document_ids)
                        """
                    ),
                    # Several chunks may share a document; count it once.
                    {"document_ids": sorted(set(touched_ids))},
                )

    return kept_chunks
|
||||
|
||||
def _parse_sse_event(self, chunk: str) -> dict | None:
|
||||
data_lines: list[str] = []
|
||||
|
||||
+1
-1
Submodule legal-platform-frontend updated: 28f1054238...1abbbe6b4e
@@ -1392,6 +1392,67 @@ rules:
|
||||
references_laws:
|
||||
- 《中华人民共和国行政处罚法》第五十九条
|
||||
type: deterministic
|
||||
- rule_id: JZ-JD-005
|
||||
name: 案由及裁量标准适用准确性
|
||||
desc: 结合处罚决定书认定依据、处罚依据、罚款项目和罚款金额,检索案由与裁量标准,判断处罚种类和罚款幅度是否适用准确。
|
||||
risk: medium
|
||||
score: 10
|
||||
scope:
|
||||
- 处罚决定书
|
||||
rag:
|
||||
collection: general_legal_kb
|
||||
top_k: 5
|
||||
source_names:
|
||||
- 广东省烟草专卖行政处罚裁量执行标准-rag.md
|
||||
- 案由_行政处罚与反走私管理治理办法.md
|
||||
query_template: |
|
||||
认定依据:{{处罚决定书.认定依据}}
|
||||
处罚依据:{{处罚决定书.处罚依据}}
|
||||
罚款项目:{{处罚决定书.罚款项目}}
|
||||
罚款基数:{{处罚决定书.罚款基数}}
|
||||
罚款比例:{{处罚决定书.罚款比例}}
|
||||
罚款总额:{{处罚决定书.罚款总额}}
|
||||
问题:检索对应案由、裁量档次、处罚种类和罚款幅度
|
||||
inject_as: rag_context
|
||||
resources_as: rag_resources
|
||||
stages:
|
||||
- id: '1'
|
||||
check: required
|
||||
fields:
|
||||
- 处罚决定书.认定依据
|
||||
- 处罚决定书.处罚依据
|
||||
- 处罚决定书.罚款项目
|
||||
- 处罚决定书.罚款基数
|
||||
- 处罚决定书.罚款比例
|
||||
- 处罚决定书.罚款总额
|
||||
- id: '2'
|
||||
check: ai
|
||||
prompt: |
|
||||
请结合检索到的法律知识和卷宗处罚决定书字段,判断案由、裁量档次、处罚种类和罚款幅度是否适用准确。
|
||||
|
||||
【检索依据】
|
||||
{{rag_context}}
|
||||
|
||||
【处罚决定书字段】
|
||||
认定依据:{{处罚决定书.认定依据}}
|
||||
处罚依据:{{处罚决定书.处罚依据}}
|
||||
罚款项目:{{处罚决定书.罚款项目}}
|
||||
罚款基数:{{处罚决定书.罚款基数}}
|
||||
罚款比例:{{处罚决定书.罚款比例}}
|
||||
罚款总额:{{处罚决定书.罚款总额}}
|
||||
|
||||
【判断要求】
|
||||
1. 判断违法事实对应案由是否准确;
|
||||
2. 判断处罚依据是否能支撑对应处罚种类;
|
||||
3. 判断罚款基数、比例、总额是否落在裁量标准允许幅度内;
|
||||
4. 若检索依据不足以确认,应给出 warn,不要编造依据。
|
||||
logic: 1 AND 2
|
||||
messages:
|
||||
pass: 案由、裁量档次、处罚种类和罚款幅度适用准确。
|
||||
fail: 案由、裁量档次、处罚种类或罚款幅度可能适用不准确,请核对。
|
||||
references_laws:
|
||||
- 《中华人民共和国行政处罚法》第五十九条
|
||||
type: ai_rule
|
||||
- group: JZG-SD
|
||||
rules:
|
||||
- rule_id: JZ-SD-001
|
||||
|
||||
@@ -0,0 +1,128 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import tempfile
|
||||
import sys
|
||||
import types
|
||||
import unittest
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from leaudit.dsl.schema import Metadata, Rule, RuleAuthoringGroup, RulesFile, Stage
|
||||
from leaudit.engine.models import EvaluationResult
|
||||
from leaudit.extraction.bundle import bundle_from_single
|
||||
from leaudit.extraction.models import ExtractionResult, FieldValue
|
||||
from leaudit.ocr.models import OcrResult, Page
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline import LauditPipeline
|
||||
|
||||
|
||||
class FakeOcrClient:
    """OCR stand-in that always yields the same single-page result."""

    async def ocr(self, file_path):
        """Ignore ``file_path`` and return a fixed one-page ``OcrResult``."""
        only_page = Page(page_num=1, text="处罚决定书")
        return OcrResult(pages=[only_page], full_text="处罚决定书")
|
||||
|
||||
|
||||
class FakeStorage:
    """Storage stand-in whose persistence hooks all do nothing."""

    async def update_document_status(self, document_id, status):
        """Accept a status update and discard it."""

    async def save_ocr_result(self, document_id, ocr_result):
        """Accept an OCR result and discard it."""

    async def save_extraction_result(self, document_id, extraction_bundle):
        """Accept an extraction bundle and discard it."""

    async def save_evaluation_results(self, document_id, rules_file, evaluation_result, extraction_bundle):
        """Accept evaluation results and discard them."""
|
||||
|
||||
|
||||
class FakeRetriever:
    """Empty marker object; the test only checks it by identity."""
|
||||
|
||||
|
||||
def _rules_file() -> RulesFile:
    """Build a minimal ``RulesFile`` with one group holding one required-field rule."""
    metadata = Metadata(
        type_id="test.rag_bridge",
        name="RAG bridge test",
        version="1.0",
        last_updated="2026-05-21",
    )
    required_stage = Stage(id="1", check="required", field="处罚决定书.处罚依据")
    rag_rule = Rule(
        rule_id="R-RAG",
        name="RAG 规则",
        risk="medium",
        score=1,
        stages=[required_stage],
        logic="1",
        messages={"pass": "ok", "fail": "missing"},
    )
    group = RuleAuthoringGroup(group="测试", rules=[rag_rule])
    return RulesFile(metadata=metadata, rules=[group])
|
||||
|
||||
|
||||
class LeauditRagBridgeTest(unittest.IsolatedAsyncioTestCase):
    """Checks that ``LauditPipeline`` forwards its injected RAG retriever."""

    async def test_pipeline_passes_injected_rag_retriever_to_evaluation(self):
        """The retriever given to the pipeline must reach ``evaluate_extraction``."""
        captured = {}
        retriever = FakeRetriever()
        field_value = FieldValue(value="依据", confidence=0.9)
        # NOTE(review): object.__setattr__ presumably bypasses a restricted
        # __setattr__ on FieldValue so `position` can be set - confirm.
        object.__setattr__(field_value, "position", None)
        extraction = bundle_from_single(
            ExtractionResult(
                fields={"处罚决定书.处罚依据": field_value},
                source_text="处罚决定书",
            )
        )

        async def fake_dispatch_extract(*args, **kwargs):
            return extraction

        async def fake_determine_phase(*args, **kwargs):
            return "executed"

        async def fake_evaluate_extraction(*args, **kwargs):
            # Record what the pipeline passed as `retriever`.
            captured["retriever"] = kwargs.get("retriever")
            return EvaluationResult()

        class TestPipeline(LauditPipeline):
            async def _extract_and_save_case_number(self, document_id, ocr_result):
                return None

        with tempfile.NamedTemporaryFile() as tmp:
            pipeline = TestPipeline(
                ocr_client=FakeOcrClient(),
                storage_adapter=FakeStorage(),
                rag_retriever=retriever,
            )
            with patch(
                "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.dispatch_extract",
                fake_dispatch_extract,
            ), patch(
                "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.determine_phase",
                fake_determine_phase,
            ), patch(
                "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.evaluate_extraction",
                fake_evaluate_extraction,
            ), patch.dict(
                sys.modules,
                {
                    # Swap in a module stand-in whose resolver is a no-op.
                    "leaudit.extraction.coordinate_resolver": types.SimpleNamespace(
                        resolve_bundle_positions=lambda *args, **kwargs: None
                    )
                },
            ):
                await pipeline.run(1, Path(tmp.name), _rules_file())

        # Fail with a readable message instead of a bare KeyError when the
        # evaluation stub was never reached.
        self.assertIn("retriever", captured, "evaluate_extraction was never called")
        self.assertIs(captured["retriever"], retriever)
|
||||
|
||||
|
||||
# Run the tests in this module when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
@@ -0,0 +1,102 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import unittest
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever
|
||||
|
||||
|
||||
class FakeCollection:
    """Collection stand-in whose vector query returns two fixed hits."""

    def query(self, **kwargs):
        """Ignore all arguments and return a canned two-segment result."""
        first_meta = {
            "document_name": "广东省烟草专卖行政处罚裁量执行标准-rag.md",
            "chunk_index": 0,
        }
        second_meta = {
            "document_name": "其他.md",
            "chunk_index": 1,
        }
        canned = {}
        canned["ids"] = [["seg-1", "seg-2"]]
        canned["documents"] = [["烟草处罚裁量标准内容", "其他来源内容"]]
        canned["metadatas"] = [[first_meta, second_meta]]
        canned["distances"] = [[0.0, 0.2]]
        return canned
|
||||
|
||||
|
||||
class FallbackCollection:
    """Collection stand-in whose vector query fails but whose ``get`` works."""

    def query(self, **kwargs):
        """Always raise, simulating an unavailable vector search."""
        raise RuntimeError("vector unavailable")

    def get(self, **kwargs):
        """Ignore all arguments and return one canned segment."""
        metadata = {
            "document_name": "案由_行政处罚与反走私管理治理办法.md",
            "chunk_index": 3,
        }
        return dict(
            ids=["seg-fallback"],
            documents=["未在当地烟草专卖批发企业进货,对应裁量档次内容"],
            metadatas=[metadata],
        )
|
||||
|
||||
|
||||
class FakeChroma:
    """Chroma client stand-in that records the requested collection name."""

    def __init__(self):
        # Name passed to the most recent get_or_create_collection call.
        self.collection_name = None

    def get_or_create_collection(self, name):
        """Remember ``name`` and hand back a fresh ``FakeCollection``."""
        self.collection_name = name
        collection = FakeCollection()
        return collection
|
||||
|
||||
|
||||
class FallbackChroma:
    """Chroma client stand-in that always serves a ``FallbackCollection``."""

    def get_or_create_collection(self, name):
        """Ignore ``name`` and hand back a fresh ``FallbackCollection``."""
        collection = FallbackCollection()
        return collection
|
||||
|
||||
|
||||
class RagRetrieverTest(unittest.IsolatedAsyncioTestCase):
    """Exercises ``RagRetriever.retrieve`` against in-memory Chroma stand-ins."""

    @staticmethod
    def _fake_embed(texts, model_name=""):
        """Return the same two-dimensional vector for every input text."""
        return [[0.1, 0.2] for _ in texts]

    async def test_retrieve_from_collection_filters_sources_and_builds_resources(self):
        """Only hits from the requested source name appear in the result."""
        chroma = FakeChroma()
        retriever = RagRetriever(
            chroma_client=chroma,
            embed_texts=self._fake_embed,
            hydrate_documents=False,
        )
        wanted_source = "广东省烟草专卖行政处罚裁量执行标准-rag.md"

        result = await retriever.retrieve(
            query="罚款幅度",
            collection_name="general_legal_kb",
            top_k=5,
            source_names=[wanted_source],
        )

        self.assertEqual(chroma.collection_name, "general_legal_kb")
        self.assertIn("烟草处罚裁量标准内容", result.rag_context)
        self.assertNotIn("其他来源内容", result.rag_context)
        self.assertEqual(len(result.rag_resources), 1)
        self.assertEqual(result.rag_resources[0]["document_name"], wanted_source)

    async def test_retrieve_uses_keyword_fallback_when_vector_search_fails(self):
        """A failing vector query still yields the segment served by ``get``."""
        retriever = RagRetriever(
            chroma_client=FallbackChroma(),
            embed_texts=self._fake_embed,
            hydrate_documents=False,
        )

        result = await retriever.retrieve(
            query="未在当地烟草专卖批发企业进货",
            collection_name="general_legal_kb",
            source_names=["案由_行政处罚与反走私管理治理办法.md"],
        )

        self.assertIn("对应裁量档次内容", result.rag_context)
        self.assertEqual(len(result.chunks), 1)
        self.assertEqual(result.rag_resources[0]["segment_id"], "seg-fallback")
|
||||
|
||||
|
||||
# Run the tests in this module when it is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
|
||||
Reference in New Issue
Block a user