3 Commits

6 changed files with 795 additions and 316 deletions
@@ -24,6 +24,7 @@ from leaudit.ocr.base import BaseOCRClient
from leaudit.ocr.models import OcrResult from leaudit.ocr.models import OcrResult
from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@@ -89,10 +90,12 @@ class LauditPipeline:
ocr_client: BaseOCRClient, ocr_client: BaseOCRClient,
llm_client: BaseLLMClient | None = None, llm_client: BaseLLMClient | None = None,
storage_adapter: StorageAdapter | None = None, storage_adapter: StorageAdapter | None = None,
rag_retriever: RagRetriever | None = None,
) -> None: ) -> None:
self.ocr_client = ocr_client self.ocr_client = ocr_client
self.llm_client = llm_client self.llm_client = llm_client
self.storage = storage_adapter or StorageAdapter() self.storage = storage_adapter or StorageAdapter()
self.rag_retriever = rag_retriever or RagRetriever()
async def run( async def run(
self, self,
@@ -219,6 +222,7 @@ class LauditPipeline:
visual_manifest=visual_manifest, visual_manifest=visual_manifest,
phase=detected_phase, phase=detected_phase,
external_mocks=external_mocks, external_mocks=external_mocks,
retriever=self.rag_retriever,
) )
timing["evaluation"] = round(time.time() - t0, 2) timing["evaluation"] = round(time.time() - t0, 2)
log.info( log.info(
@@ -0,0 +1,470 @@
from __future__ import annotations
import re
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable
import httpx
from sqlalchemy import text
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_embeddings_url
EmbedTexts = Callable[[list[str], str], Awaitable[list[list[float]]] | list[list[float]]]
@dataclass
class RagRetrieveResult:
rag_context: str = ""
rag_resources: list[dict[str, Any]] = field(default_factory=list)
chunks: list[dict[str, Any]] = field(default_factory=list)
dataset_name: str = ""
def as_dict(self) -> dict[str, Any]:
return {
"rag_context": self.rag_context,
"rag_resources": self.rag_resources,
}
class RagRetriever:
def __init__(
self,
*,
chroma_client: Any | None = None,
embed_texts: EmbedTexts | None = None,
hydrate_documents: bool = True,
) -> None:
self._chroma_client = chroma_client
self._embed_texts_override = embed_texts
self._hydrate_documents_enabled = hydrate_documents
async def retrieve(
self,
query: str,
collection_name: str | None = None,
dataset_id: int | None = None,
top_k: int = 5,
source_names: list[str] | None = None,
) -> RagRetrieveResult:
query_text = (query or "").strip()
if not query_text:
return RagRetrieveResult()
top_k = max(1, int(top_k or 5))
dataset = await self._load_dataset(dataset_id) if dataset_id else None
resolved_collection = (collection_name or (dataset or {}).get("collection_name") or "").strip()
if not resolved_collection:
return RagRetrieveResult(dataset_name=str((dataset or {}).get("name") or ""))
retrieval_model = (dataset or {}).get("retrieval_model") or {}
if dataset and not collection_name:
top_k = max(1, int(retrieval_model.get("top_k") or top_k))
score_threshold = self._resolve_score_threshold(retrieval_model)
dataset_name = str((dataset or {}).get("name") or "")
embedding_model = str((dataset or {}).get("embedding_model") or "")
chunks: list[dict[str, Any]] = []
try:
chunks = await self._vector_retrieve(
query=query_text,
collection_name=resolved_collection,
dataset_name=dataset_name,
embedding_model=embedding_model,
top_k=top_k,
score_threshold=score_threshold,
source_names=source_names,
)
except Exception:
chunks = []
if not chunks:
try:
chunks = await self._keyword_retrieve_context(
dataset_id=dataset_id,
collection_name=resolved_collection,
dataset_name=dataset_name,
query=query_text,
top_k=top_k,
score_threshold=score_threshold,
source_names=source_names,
)
except Exception:
chunks = []
if dataset_id and self._hydrate_documents_enabled:
chunks = await self._hydrate_document_hits(dataset_id, chunks)
chunks = chunks[:top_k]
return RagRetrieveResult(
rag_context=self._build_context(chunks),
rag_resources=self.build_sources(chunks, dataset_name),
chunks=chunks,
dataset_name=dataset_name,
)
async def _load_dataset(self, dataset_id: int | None) -> dict[str, Any] | None:
if not dataset_id:
return None
async with GetAsyncSession() as session:
row = (
await session.execute(
text(
"""
SELECT id, name, collection_name, retrieval_model, embedding_model
FROM rag_dataset
WHERE id = :dataset_id AND deleted_at IS NULL
LIMIT 1
"""
),
{"dataset_id": dataset_id},
)
).mappings().first()
return dict(row) if row else None
def _resolve_score_threshold(self, retrieval_model: dict[str, Any]) -> float | None:
if not retrieval_model.get("score_threshold_enabled"):
return None
try:
return float(retrieval_model.get("score_threshold"))
except (TypeError, ValueError):
return None
async def _vector_retrieve(
self,
*,
query: str,
collection_name: str,
dataset_name: str,
embedding_model: str,
top_k: int,
score_threshold: float | None,
source_names: list[str] | None,
) -> list[dict[str, Any]]:
query_embedding = await self._embed_texts([query], embedding_model)
collection = self._get_chroma().get_or_create_collection(collection_name)
result = collection.query(
query_embeddings=query_embedding,
n_results=max(top_k, 1),
include=["documents", "metadatas", "distances"],
)
ids = (result.get("ids") or [[]])[0] if result.get("ids") else []
docs = (result.get("documents") or [[]])[0]
metas = (result.get("metadatas") or [[]])[0]
distances = (result.get("distances") or [[]])[0]
allowed_sources = self._normalize_source_filter(source_names)
chunks: list[dict[str, Any]] = []
for idx, doc in enumerate(docs):
meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {}
document_name = str(meta.get("document_name") or meta.get("source") or "")
if allowed_sources and document_name not in allowed_sources:
continue
dist = float(distances[idx]) if idx < len(distances) and distances[idx] is not None else 1.0
score = 1.0 / (1.0 + max(dist, 0.0))
if score_threshold is not None and score < score_threshold:
continue
chunks.append(
{
"id": str(ids[idx] if idx < len(ids) else meta.get("id") or idx),
"text": doc or "",
"source": meta.get("source") or meta.get("document_name") or dataset_name,
"score": score,
"chunk_index": int(meta.get("chunk_index") or idx),
"document_name": document_name,
"document_id": meta.get("document_id"),
"page": meta.get("page"),
}
)
return chunks
async def _embed_texts(self, texts: list[str], model_name: str = "") -> list[list[float]]:
if self._embed_texts_override is not None:
result = self._embed_texts_override(texts, model_name)
if hasattr(result, "__await__"):
return await result # type: ignore[misc]
return result
embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"])
embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
if not embed_url or not embed_key:
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")
embeddings: list[list[float]] = []
async with httpx.AsyncClient(timeout=120.0) as client:
for start in range(0, len(texts), batch_size):
batch_texts = texts[start:start + batch_size]
try:
response = await client.post(
embed_url,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {embed_key}",
},
json={"model": embed_model, "input": batch_texts},
)
response.raise_for_status()
except httpx.HTTPStatusError as exc:
error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
raise LeauditException(
StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
f"向量化服务调用失败: {error_message[:300]}",
) from exc
payload = response.json()
rows = payload.get("data") or []
batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
if len(batch_embeddings) != len(batch_texts):
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
embeddings.extend(batch_embeddings)
if len(embeddings) != len(texts):
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
return embeddings
async def _keyword_retrieve_context(
self,
*,
dataset_id: int | None,
collection_name: str,
dataset_name: str,
query: str,
top_k: int,
score_threshold: float | None,
source_names: list[str] | None,
) -> list[dict[str, Any]]:
collection = self._get_chroma().get_or_create_collection(collection_name)
raw = collection.get(include=["documents", "metadatas"])
ids = raw.get("ids") or []
docs = raw.get("documents") or []
metas = raw.get("metadatas") or []
allowed_sources = self._normalize_source_filter(source_names)
terms = self._build_keyword_terms(query)
if not terms:
return []
scored_chunks: list[dict[str, Any]] = []
for idx, chunk_id in enumerate(ids):
doc = docs[idx] if idx < len(docs) else ""
meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {}
document_name = str(meta.get("document_name") or meta.get("source") or "")
if allowed_sources and document_name not in allowed_sources:
continue
score = self._score_keyword_chunk(
query=query,
terms=terms,
content=doc or "",
document_name=document_name,
)
if score <= 0:
continue
if score_threshold is not None and score < score_threshold:
continue
scored_chunks.append(
{
"id": str(chunk_id),
"text": doc or "",
"source": meta.get("source") or meta.get("document_name") or dataset_name,
"score": score,
"chunk_index": int(meta.get("chunk_index") or idx),
"document_name": document_name,
"document_id": meta.get("document_id"),
"page": meta.get("page"),
}
)
scored_chunks.sort(key=lambda item: (-float(item.get("score") or 0.0), int(item.get("chunk_index") or 0)))
return scored_chunks[: max(top_k * 3, top_k)]
def _build_keyword_terms(self, query: str) -> list[str]:
normalized = self._normalize_keyword_query(query)
spans = [item.strip() for item in re.findall(r"[\u4e00-\u9fffA-Za-z0-9]+", normalized) if item.strip()]
if not spans:
return []
stop_terms = {
"什么",
"请问",
"一下",
"有关",
"关于",
"如何",
"哪些",
"怎么",
"是否",
"规定",
"办法",
"条例",
"法律",
}
terms: list[str] = []
for span in spans:
if span in stop_terms:
continue
terms.append(span)
if re.fullmatch(r"[\u4e00-\u9fff]+", span):
for size in (2, 3, 4):
if len(span) > size:
for start in range(0, len(span) - size + 1):
token = span[start:start + size]
if token not in stop_terms:
terms.append(token)
unique_terms: list[str] = []
seen: set[str] = set()
for term in sorted(terms, key=len, reverse=True):
if term and term not in seen:
unique_terms.append(term)
seen.add(term)
return unique_terms[:20]
def _normalize_keyword_query(self, query: str) -> str:
normalized = (query or "").strip().lower()
patterns = [
"是什么",
"什么是",
"有哪些",
"有什么",
"是什么?",
"是什么?",
"请问",
"介绍一下",
"解释一下",
"帮我分析",
"帮我看看",
]
for pattern in patterns:
normalized = normalized.replace(pattern, " ")
return re.sub(r"\s+", " ", normalized).strip()
def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float:
haystack = f"{document_name}\n{content}".lower()
if not haystack:
return 0.0
exact_query = self._normalize_keyword_query(query)
if exact_query and exact_query in haystack:
return 0.98
matched_weight = 0.0
total_weight = 0.0
name_bonus = 0.0
for term in terms:
weight = float(max(len(term), 1) ** 2)
total_weight += weight
if term.lower() in haystack:
matched_weight += weight
if term.lower() in document_name.lower():
name_bonus += min(0.15, 0.03 * len(term))
if total_weight <= 0:
return 0.0
score = (matched_weight / total_weight) + name_bonus
return round(min(score, 0.99), 6)
def _build_context(self, chunks: list[dict[str, Any]]) -> str:
lines: list[str] = []
for index, chunk in enumerate(chunks, start=1):
document_name = chunk.get("document_name") or chunk.get("source") or "未知来源"
text_value = str(chunk.get("text") or "").strip()
if not text_value:
continue
lines.append(f"[{index}] 来源:{document_name}\n{text_value}")
return "\n\n".join(lines)
def build_sources(self, context_chunks: list[dict[str, Any]], dataset_name: str = "") -> list[dict[str, Any]]:
return [
{
"position": index + 1,
"dataset_id": str(chunk.get("dataset_id") or ""),
"dataset_name": dataset_name,
"document_id": str(chunk.get("document_id") or ""),
"document_name": chunk.get("document_name") or chunk.get("source", ""),
"data_source_type": "upload_file",
"segment_id": chunk.get("id", ""),
"retriever_from": "rag",
"score": round(float(chunk.get("score") or 0.0), 4),
"hit_count": chunk.get("hit_count", 0),
"word_count": len(chunk.get("text", "")),
"segment_position": index + 1,
"index_node_hash": "",
"content": chunk.get("text", "")[:500],
"page": None,
}
for index, chunk in enumerate(context_chunks)
]
async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]:
source_names = sorted(
{
str(chunk.get("document_name") or chunk.get("source") or "").strip()
for chunk in chunks
if str(chunk.get("document_name") or chunk.get("source") or "").strip()
}
)
if not source_names:
return chunks
async with GetAsyncSession() as session:
rows = (
await session.execute(
text(
"""
SELECT id, original_name, enabled, hit_count
FROM rag_document
WHERE dataset_id = :dataset_id
AND deleted_at IS NULL
AND original_name = ANY(:source_names)
"""
),
{
"dataset_id": dataset_id,
"source_names": source_names,
},
)
).mappings().all()
document_map = {str(row["original_name"]): row for row in rows}
visible_chunks: list[dict[str, Any]] = []
hit_document_ids: list[int] = []
for chunk in chunks:
source_name = str(chunk.get("document_name") or chunk.get("source") or "").strip()
document = document_map.get(source_name)
if document and not bool(document.get("enabled")):
continue
if document:
chunk["document_id"] = document["id"]
chunk["dataset_id"] = dataset_id
chunk["document_name"] = document["original_name"]
chunk["hit_count"] = document.get("hit_count") or 0
hit_document_ids.append(int(document["id"]))
visible_chunks.append(chunk)
if hit_document_ids:
async with GetAsyncSession() as session:
async with session.begin():
await session.execute(
text(
"""
UPDATE rag_document
SET hit_count = hit_count + 1,
updated_at = NOW()
WHERE id = ANY(:document_ids)
"""
),
{"document_ids": sorted(set(hit_document_ids))},
)
return visible_chunks
def _normalize_source_filter(self, source_names: list[str] | None) -> set[str]:
return {str(name).strip() for name in (source_names or []) if str(name).strip()}
def _get_chroma(self) -> Any:
return self._chroma_client or get_chroma()
@@ -33,11 +33,10 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
from fastapi_modules.fastapi_leaudit.rag_engine.config import ( from fastapi_modules.fastapi_leaudit.rag_engine.config import (
RAG_CONFIG, RAG_CONFIG,
build_openai_chat_completions_url, build_openai_chat_completions_url,
build_openai_embeddings_url,
) )
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream
from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
@@ -54,6 +53,9 @@ class RagChatServiceImpl(IRagChatService):
_task_locks: dict[str, asyncio.Lock] = {} _task_locks: dict[str, asyncio.Lock] = {}
_title_tasks: dict[str, asyncio.Task] = {} _title_tasks: dict[str, asyncio.Task] = {}
def __init__(self, retriever: RagRetriever | None = None) -> None:
self.retriever = retriever or RagRetriever()
async def GetApps(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppListVO: async def GetApps(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppListVO:
apps = await self._load_apps(UserArea, UserRole, only_default=False) apps = await self._load_apps(UserArea, UserRole, only_default=False)
return RagChatAppListVO(data=apps, total=len(apps)) return RagChatAppListVO(data=apps, total=len(apps))
@@ -592,121 +594,11 @@ class RagChatServiceImpl(IRagChatService):
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该会话") raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该会话")
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]: async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
if not dataset_id: result = await self.retriever.retrieve(query=query, dataset_id=dataset_id)
return [], "" return result.chunks, result.dataset_name
async with GetAsyncSession() as session:
dataset = (
await session.execute(
text(
"""
SELECT id, name, collection_name, retrieval_model, embedding_model
FROM rag_dataset
WHERE id = :dataset_id AND deleted_at IS NULL
LIMIT 1
"""
),
{"dataset_id": dataset_id},
)
).mappings().first()
if not dataset:
return [], ""
retrieval_model = dataset.get("retrieval_model") or {}
top_k = int(retrieval_model.get("top_k") or 5)
score_threshold = None
if retrieval_model.get("score_threshold_enabled"):
try:
score_threshold = float(retrieval_model.get("score_threshold"))
except (TypeError, ValueError):
score_threshold = None
try:
query_embedding = await self._embed_texts([query], dataset.get("embedding_model") or "")
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
result = collection.query(
query_embeddings=query_embedding,
n_results=max(top_k, 1),
include=["documents", "metadatas", "distances"],
)
ids = (result.get("ids") or [[]])[0] if result.get("ids") else []
docs = (result.get("documents") or [[]])[0]
metas = (result.get("metadatas") or [[]])[0]
distances = (result.get("distances") or [[]])[0]
chunks: list[dict] = []
for idx, doc in enumerate(docs):
meta = metas[idx] if idx < len(metas) else {}
dist = float(distances[idx]) if idx < len(distances) and distances[idx] is not None else 1.0
score = 1.0 / (1.0 + max(dist, 0.0))
if score_threshold is not None and score < score_threshold:
continue
chunks.append(
{
"id": str(ids[idx] if idx < len(ids) else meta.get("id") or idx),
"text": doc,
"source": meta.get("source") or meta.get("document_name") or dataset.get("name") or "",
"score": score,
"chunk_index": int(meta.get("chunk_index") or idx),
"document_name": meta.get("document_name") or meta.get("source") or "",
"document_id": meta.get("document_id"),
"page": meta.get("page"),
}
)
chunks = await self._hydrate_document_hits(dataset_id, chunks)
if chunks:
return chunks[:top_k], dataset.get("name") or ""
except Exception:
pass
try:
chunks = await self._keyword_retrieve_context(
dataset_id=dataset_id,
collection_name=str(dataset["collection_name"]),
dataset_name=str(dataset.get("name") or ""),
query=query,
top_k=top_k,
score_threshold=score_threshold,
)
return chunks[:top_k], dataset.get("name") or ""
except Exception:
return [], dataset.get("name") or ""
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]: async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"]) return await self.retriever._embed_texts(texts, model_name)
embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
if not embed_url or not embed_key:
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")
embeddings: list[list[float]] = []
async with httpx.AsyncClient(timeout=120.0) as client:
for start in range(0, len(texts), batch_size):
batch_texts = texts[start:start + batch_size]
try:
response = await client.post(
embed_url,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {embed_key}",
},
json={"model": embed_model, "input": batch_texts},
)
response.raise_for_status()
except httpx.HTTPStatusError as exc:
error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
raise LeauditException(
StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
f"向量化服务调用失败: {error_message[:300]}",
) from exc
payload = response.json()
rows = payload.get("data") or []
batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
if len(batch_embeddings) != len(batch_texts):
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
embeddings.extend(batch_embeddings)
if len(embeddings) != len(texts):
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
return embeddings
async def _start_message_task( async def _start_message_task(
self, self,
@@ -1219,220 +1111,42 @@ class RagChatServiceImpl(IRagChatService):
top_k: int, top_k: int,
score_threshold: float | None, score_threshold: float | None,
) -> list[dict]: ) -> list[dict]:
collection = get_chroma().get_or_create_collection(collection_name) chunks = await self.retriever._keyword_retrieve_context(
raw = collection.get(include=["documents", "metadatas"]) dataset_id=dataset_id,
ids = raw.get("ids") or [] collection_name=collection_name,
docs = raw.get("documents") or [] dataset_name=dataset_name,
metas = raw.get("metadatas") or [] query=query,
top_k=top_k,
terms = self._build_keyword_terms(query) score_threshold=score_threshold,
if not terms: source_names=None,
return [] )
return chunks[:top_k]
scored_chunks: list[dict] = []
for idx, chunk_id in enumerate(ids):
doc = docs[idx] if idx < len(docs) else ""
meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {}
score = self._score_keyword_chunk(
query=query,
terms=terms,
content=doc or "",
document_name=str(meta.get("document_name") or meta.get("source") or ""),
)
if score <= 0:
continue
if score_threshold is not None and score < score_threshold:
continue
scored_chunks.append(
{
"id": str(chunk_id),
"text": doc or "",
"source": meta.get("source") or meta.get("document_name") or dataset_name,
"score": score,
"chunk_index": int(meta.get("chunk_index") or idx),
"document_name": meta.get("document_name") or meta.get("source") or "",
"document_id": meta.get("document_id"),
"page": meta.get("page"),
}
)
scored_chunks.sort(key=lambda item: (-float(item.get("score") or 0.0), int(item.get("chunk_index") or 0)))
hydrated = await self._hydrate_document_hits(dataset_id, scored_chunks[: max(top_k * 3, top_k)])
return hydrated[:top_k]
def _build_keyword_terms(self, query: str) -> list[str]: def _build_keyword_terms(self, query: str) -> list[str]:
normalized = self._normalize_keyword_query(query) return self.retriever._build_keyword_terms(query)
spans = [item.strip() for item in re.findall(r"[\u4e00-\u9fffA-Za-z0-9]+", normalized) if item.strip()]
if not spans:
return []
stop_terms = {
"什么",
"请问",
"一下",
"有关",
"关于",
"如何",
"哪些",
"怎么",
"是否",
"规定",
"办法",
"条例",
"法律",
}
terms: list[str] = []
for span in spans:
if span in stop_terms:
continue
terms.append(span)
if re.fullmatch(r"[\u4e00-\u9fff]+", span):
for size in (2, 3, 4):
if len(span) > size:
for start in range(0, len(span) - size + 1):
token = span[start:start + size]
if token not in stop_terms:
terms.append(token)
unique_terms: list[str] = []
seen: set[str] = set()
for term in sorted(terms, key=len, reverse=True):
if term and term not in seen:
unique_terms.append(term)
seen.add(term)
return unique_terms[:20]
def _normalize_keyword_query(self, query: str) -> str: def _normalize_keyword_query(self, query: str) -> str:
normalized = (query or "").strip().lower() return self.retriever._normalize_keyword_query(query)
patterns = [
"是什么",
"什么是",
"有哪些",
"有什么",
"是什么?",
"是什么?",
"请问",
"介绍一下",
"解释一下",
"帮我分析",
"帮我看看",
]
for pattern in patterns:
normalized = normalized.replace(pattern, " ")
return re.sub(r"\s+", " ", normalized).strip()
def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float: def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float:
haystack = f"{document_name}\n{content}".lower() return self.retriever._score_keyword_chunk(
if not haystack: query=query,
return 0.0 terms=terms,
content=content,
exact_query = self._normalize_keyword_query(query) document_name=document_name,
if exact_query and exact_query in haystack: )
return 0.98
matched_weight = 0.0
total_weight = 0.0
name_bonus = 0.0
for term in terms:
weight = float(max(len(term), 1) ** 2)
total_weight += weight
if term.lower() in haystack:
matched_weight += weight
if term.lower() in document_name.lower():
name_bonus += min(0.15, 0.03 * len(term))
if total_weight <= 0:
return 0.0
score = (matched_weight / total_weight) + name_bonus
return round(min(score, 0.99), 6)
def _format_sse(self, payload: dict) -> bytes: def _format_sse(self, payload: dict) -> bytes:
return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8") return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8")
def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]: def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]:
return [ build_sources = getattr(self.retriever, "build_sources", None)
{ if callable(build_sources):
"position": index + 1, return build_sources(context_chunks, dataset_name)
"dataset_id": str(chunk.get("dataset_id") or ""), return RagRetriever(hydrate_documents=False).build_sources(context_chunks, dataset_name)
"dataset_name": dataset_name,
"document_id": str(chunk.get("document_id") or ""),
"document_name": chunk.get("document_name") or chunk.get("source", ""),
"data_source_type": "upload_file",
"segment_id": chunk.get("id", ""),
"retriever_from": "rag",
"score": round(chunk.get("score", 0.0), 4),
"hit_count": chunk.get("hit_count", 0),
"word_count": len(chunk.get("text", "")),
"segment_position": index + 1,
"index_node_hash": "",
"content": chunk.get("text", "")[:500],
"page": None,
}
for index, chunk in enumerate(context_chunks)
]
async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict]) -> list[dict]: async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict]) -> list[dict]:
source_names = sorted( return await self.retriever._hydrate_document_hits(dataset_id, chunks)
{
str(chunk.get("document_name") or chunk.get("source") or "").strip()
for chunk in chunks
if str(chunk.get("document_name") or chunk.get("source") or "").strip()
}
)
if not source_names:
return chunks
async with GetAsyncSession() as session:
rows = (
await session.execute(
text(
"""
SELECT id, original_name, enabled, hit_count
FROM rag_document
WHERE dataset_id = :dataset_id
AND deleted_at IS NULL
AND original_name = ANY(:source_names)
"""
),
{
"dataset_id": dataset_id,
"source_names": source_names,
},
)
).mappings().all()
document_map = {str(row["original_name"]): row for row in rows}
visible_chunks: list[dict] = []
hit_document_ids: list[int] = []
for chunk in chunks:
source_name = str(chunk.get("document_name") or chunk.get("source") or "").strip()
document = document_map.get(source_name)
if document and not bool(document.get("enabled")):
continue
if document:
chunk["document_id"] = document["id"]
chunk["dataset_id"] = dataset_id
chunk["document_name"] = document["original_name"]
chunk["hit_count"] = document.get("hit_count") or 0
hit_document_ids.append(int(document["id"]))
visible_chunks.append(chunk)
if hit_document_ids:
async with GetAsyncSession() as session:
async with session.begin():
await session.execute(
text(
"""
UPDATE rag_document
SET hit_count = hit_count + 1,
updated_at = NOW()
WHERE id = ANY(:document_ids)
"""
),
{"document_ids": sorted(set(hit_document_ids))},
)
return visible_chunks
def _parse_sse_event(self, chunk: str) -> dict | None: def _parse_sse_event(self, chunk: str) -> dict | None:
data_lines: list[str] = [] data_lines: list[str] = []
+61
View File
@@ -1392,6 +1392,67 @@ rules:
references_laws: references_laws:
- 《中华人民共和国行政处罚法》第五十九条 - 《中华人民共和国行政处罚法》第五十九条
type: deterministic type: deterministic
- rule_id: JZ-JD-005
name: 案由及裁量标准适用准确性
desc: 结合处罚决定书认定依据、处罚依据、罚款项目和罚款金额,检索案由与裁量标准,判断处罚种类和罚款幅度是否适用准确。
risk: medium
score: 10
scope:
- 处罚决定书
rag:
collection: general_legal_kb
top_k: 5
source_names:
- 广东省烟草专卖行政处罚裁量执行标准-rag.md
- 案由_行政处罚与反走私管理治理办法.md
query_template: |
认定依据:{{处罚决定书.认定依据}}
处罚依据:{{处罚决定书.处罚依据}}
罚款项目:{{处罚决定书.罚款项目}}
罚款基数:{{处罚决定书.罚款基数}}
罚款比例:{{处罚决定书.罚款比例}}
罚款总额:{{处罚决定书.罚款总额}}
问题:检索对应案由、裁量档次、处罚种类和罚款幅度
inject_as: rag_context
resources_as: rag_resources
stages:
- id: '1'
check: required
fields:
- 处罚决定书.认定依据
- 处罚决定书.处罚依据
- 处罚决定书.罚款项目
- 处罚决定书.罚款基数
- 处罚决定书.罚款比例
- 处罚决定书.罚款总额
- id: '2'
check: ai
prompt: |
请结合检索到的法律知识和卷宗处罚决定书字段,判断案由、裁量档次、处罚种类和罚款幅度是否适用准确。
【检索依据】
{{rag_context}}
【处罚决定书字段】
认定依据:{{处罚决定书.认定依据}}
处罚依据:{{处罚决定书.处罚依据}}
罚款项目:{{处罚决定书.罚款项目}}
罚款基数:{{处罚决定书.罚款基数}}
罚款比例:{{处罚决定书.罚款比例}}
罚款总额:{{处罚决定书.罚款总额}}
【判断要求】
1. 判断违法事实对应案由是否准确;
2. 判断处罚依据是否能支撑对应处罚种类;
3. 判断罚款基数、比例、总额是否落在裁量标准允许幅度内;
4. 若检索依据不足以确认,应给出 warn,不要编造依据。
logic: 1 AND 2
messages:
pass: 案由、裁量档次、处罚种类和罚款幅度适用准确。
fail: 案由、裁量档次、处罚种类或罚款幅度可能适用不准确,请核对。
references_laws:
- 《中华人民共和国行政处罚法》第五十九条
type: ai_rule
- group: JZG-SD - group: JZG-SD
rules: rules:
- rule_id: JZ-SD-001 - rule_id: JZ-SD-001
+128
View File
@@ -0,0 +1,128 @@
from __future__ import annotations
import tempfile
import sys
import types
import unittest
from pathlib import Path
from unittest.mock import patch
from leaudit.dsl.schema import Metadata, Rule, RuleAuthoringGroup, RulesFile, Stage
from leaudit.engine.models import EvaluationResult
from leaudit.extraction.bundle import bundle_from_single
from leaudit.extraction.models import ExtractionResult, FieldValue
from leaudit.ocr.models import OcrResult, Page
from fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline import LauditPipeline
class FakeOcrClient:
async def ocr(self, file_path):
return OcrResult(
pages=[Page(page_num=1, text="处罚决定书")],
full_text="处罚决定书",
)
class FakeStorage:
async def update_document_status(self, document_id, status):
return None
async def save_ocr_result(self, document_id, ocr_result):
return None
async def save_extraction_result(self, document_id, extraction_bundle):
return None
async def save_evaluation_results(self, document_id, rules_file, evaluation_result, extraction_bundle):
return None
class FakeRetriever:
pass
def _rules_file() -> RulesFile:
return RulesFile(
metadata=Metadata(
type_id="test.rag_bridge",
name="RAG bridge test",
version="1.0",
last_updated="2026-05-21",
),
rules=[
RuleAuthoringGroup(
group="测试",
rules=[
Rule(
rule_id="R-RAG",
name="RAG 规则",
risk="medium",
score=1,
stages=[Stage(id="1", check="required", field="处罚决定书.处罚依据")],
logic="1",
messages={"pass": "ok", "fail": "missing"},
)
],
)
],
)
class LeauditRagBridgeTest(unittest.IsolatedAsyncioTestCase):
async def test_pipeline_passes_injected_rag_retriever_to_evaluation(self):
captured = {}
retriever = FakeRetriever()
field_value = FieldValue(value="依据", confidence=0.9)
object.__setattr__(field_value, "position", None)
extraction = bundle_from_single(
ExtractionResult(
fields={"处罚决定书.处罚依据": field_value},
source_text="处罚决定书",
)
)
async def fake_dispatch_extract(*args, **kwargs):
return extraction
async def fake_determine_phase(*args, **kwargs):
return "executed"
async def fake_evaluate_extraction(*args, **kwargs):
captured["retriever"] = kwargs.get("retriever")
return EvaluationResult()
class TestPipeline(LauditPipeline):
async def _extract_and_save_case_number(self, document_id, ocr_result):
return None
with tempfile.NamedTemporaryFile() as tmp:
pipeline = TestPipeline(
ocr_client=FakeOcrClient(),
storage_adapter=FakeStorage(),
rag_retriever=retriever,
)
with patch(
"fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.dispatch_extract",
fake_dispatch_extract,
), patch(
"fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.determine_phase",
fake_determine_phase,
), patch(
"fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.evaluate_extraction",
fake_evaluate_extraction,
), patch.dict(
sys.modules,
{
"leaudit.extraction.coordinate_resolver": types.SimpleNamespace(
resolve_bundle_positions=lambda *args, **kwargs: None
)
},
):
await pipeline.run(1, Path(tmp.name), _rules_file())
self.assertIs(captured["retriever"], retriever)
if __name__ == "__main__":
unittest.main()
+102
View File
@@ -0,0 +1,102 @@
from __future__ import annotations
import unittest
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever
class FakeCollection:
def query(self, **kwargs):
return {
"ids": [["seg-1", "seg-2"]],
"documents": [["烟草处罚裁量标准内容", "其他来源内容"]],
"metadatas": [[
{
"document_name": "广东省烟草专卖行政处罚裁量执行标准-rag.md",
"chunk_index": 0,
},
{
"document_name": "其他.md",
"chunk_index": 1,
},
]],
"distances": [[0.0, 0.2]],
}
class FallbackCollection:
def query(self, **kwargs):
raise RuntimeError("vector unavailable")
def get(self, **kwargs):
return {
"ids": ["seg-fallback"],
"documents": ["未在当地烟草专卖批发企业进货,对应裁量档次内容"],
"metadatas": [
{
"document_name": "案由_行政处罚与反走私管理治理办法.md",
"chunk_index": 3,
}
],
}
class FakeChroma:
def __init__(self):
self.collection_name = None
def get_or_create_collection(self, name):
self.collection_name = name
return FakeCollection()
class FallbackChroma:
def get_or_create_collection(self, name):
return FallbackCollection()
class RagRetrieverTest(unittest.IsolatedAsyncioTestCase):
async def test_retrieve_from_collection_filters_sources_and_builds_resources(self):
chroma = FakeChroma()
retriever = RagRetriever(
chroma_client=chroma,
embed_texts=lambda texts, model_name="": [[0.1, 0.2] for _ in texts],
hydrate_documents=False,
)
result = await retriever.retrieve(
query="罚款幅度",
collection_name="general_legal_kb",
top_k=5,
source_names=["广东省烟草专卖行政处罚裁量执行标准-rag.md"],
)
self.assertEqual(chroma.collection_name, "general_legal_kb")
self.assertIn("烟草处罚裁量标准内容", result.rag_context)
self.assertNotIn("其他来源内容", result.rag_context)
self.assertEqual(len(result.rag_resources), 1)
self.assertEqual(
result.rag_resources[0]["document_name"],
"广东省烟草专卖行政处罚裁量执行标准-rag.md",
)
async def test_retrieve_uses_keyword_fallback_when_vector_search_fails(self):
retriever = RagRetriever(
chroma_client=FallbackChroma(),
embed_texts=lambda texts, model_name="": [[0.1, 0.2] for _ in texts],
hydrate_documents=False,
)
result = await retriever.retrieve(
query="未在当地烟草专卖批发企业进货",
collection_name="general_legal_kb",
source_names=["案由_行政处罚与反走私管理治理办法.md"],
)
self.assertIn("对应裁量档次内容", result.rag_context)
self.assertEqual(len(result.chunks), 1)
self.assertEqual(result.rag_resources[0]["segment_id"], "seg-fallback")
if __name__ == "__main__":
unittest.main()