feat(rag): add shared retriever for audit pipeline

This commit is contained in:
2026-05-21 15:54:47 +08:00
parent 4847bccdec
commit 6ce1a290ab
6 changed files with 795 additions and 316 deletions
@@ -33,11 +33,10 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
from fastapi_modules.fastapi_leaudit.rag_engine.config import (
RAG_CONFIG,
build_openai_chat_completions_url,
build_openai_embeddings_url,
)
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream
from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
@@ -54,6 +53,9 @@ class RagChatServiceImpl(IRagChatService):
_task_locks: dict[str, asyncio.Lock] = {}
_title_tasks: dict[str, asyncio.Task] = {}
def __init__(self, retriever: RagRetriever | None = None) -> None:
self.retriever = retriever or RagRetriever()
async def GetApps(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppListVO:
apps = await self._load_apps(UserArea, UserRole, only_default=False)
return RagChatAppListVO(data=apps, total=len(apps))
@@ -592,121 +594,11 @@ class RagChatServiceImpl(IRagChatService):
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该会话")
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
if not dataset_id:
return [], ""
async with GetAsyncSession() as session:
dataset = (
await session.execute(
text(
"""
SELECT id, name, collection_name, retrieval_model, embedding_model
FROM rag_dataset
WHERE id = :dataset_id AND deleted_at IS NULL
LIMIT 1
"""
),
{"dataset_id": dataset_id},
)
).mappings().first()
if not dataset:
return [], ""
retrieval_model = dataset.get("retrieval_model") or {}
top_k = int(retrieval_model.get("top_k") or 5)
score_threshold = None
if retrieval_model.get("score_threshold_enabled"):
try:
score_threshold = float(retrieval_model.get("score_threshold"))
except (TypeError, ValueError):
score_threshold = None
try:
query_embedding = await self._embed_texts([query], dataset.get("embedding_model") or "")
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
result = collection.query(
query_embeddings=query_embedding,
n_results=max(top_k, 1),
include=["documents", "metadatas", "distances"],
)
ids = (result.get("ids") or [[]])[0] if result.get("ids") else []
docs = (result.get("documents") or [[]])[0]
metas = (result.get("metadatas") or [[]])[0]
distances = (result.get("distances") or [[]])[0]
chunks: list[dict] = []
for idx, doc in enumerate(docs):
meta = metas[idx] if idx < len(metas) else {}
dist = float(distances[idx]) if idx < len(distances) and distances[idx] is not None else 1.0
score = 1.0 / (1.0 + max(dist, 0.0))
if score_threshold is not None and score < score_threshold:
continue
chunks.append(
{
"id": str(ids[idx] if idx < len(ids) else meta.get("id") or idx),
"text": doc,
"source": meta.get("source") or meta.get("document_name") or dataset.get("name") or "",
"score": score,
"chunk_index": int(meta.get("chunk_index") or idx),
"document_name": meta.get("document_name") or meta.get("source") or "",
"document_id": meta.get("document_id"),
"page": meta.get("page"),
}
)
chunks = await self._hydrate_document_hits(dataset_id, chunks)
if chunks:
return chunks[:top_k], dataset.get("name") or ""
except Exception:
pass
try:
chunks = await self._keyword_retrieve_context(
dataset_id=dataset_id,
collection_name=str(dataset["collection_name"]),
dataset_name=str(dataset.get("name") or ""),
query=query,
top_k=top_k,
score_threshold=score_threshold,
)
return chunks[:top_k], dataset.get("name") or ""
except Exception:
return [], dataset.get("name") or ""
result = await self.retriever.retrieve(query=query, dataset_id=dataset_id)
return result.chunks, result.dataset_name
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"])
embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
if not embed_url or not embed_key:
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")
embeddings: list[list[float]] = []
async with httpx.AsyncClient(timeout=120.0) as client:
for start in range(0, len(texts), batch_size):
batch_texts = texts[start:start + batch_size]
try:
response = await client.post(
embed_url,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {embed_key}",
},
json={"model": embed_model, "input": batch_texts},
)
response.raise_for_status()
except httpx.HTTPStatusError as exc:
error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
raise LeauditException(
StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
f"向量化服务调用失败: {error_message[:300]}",
) from exc
payload = response.json()
rows = payload.get("data") or []
batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
if len(batch_embeddings) != len(batch_texts):
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
embeddings.extend(batch_embeddings)
if len(embeddings) != len(texts):
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
return embeddings
return await self.retriever._embed_texts(texts, model_name)
async def _start_message_task(
self,
@@ -1219,220 +1111,42 @@ class RagChatServiceImpl(IRagChatService):
top_k: int,
score_threshold: float | None,
) -> list[dict]:
collection = get_chroma().get_or_create_collection(collection_name)
raw = collection.get(include=["documents", "metadatas"])
ids = raw.get("ids") or []
docs = raw.get("documents") or []
metas = raw.get("metadatas") or []
terms = self._build_keyword_terms(query)
if not terms:
return []
scored_chunks: list[dict] = []
for idx, chunk_id in enumerate(ids):
doc = docs[idx] if idx < len(docs) else ""
meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {}
score = self._score_keyword_chunk(
query=query,
terms=terms,
content=doc or "",
document_name=str(meta.get("document_name") or meta.get("source") or ""),
)
if score <= 0:
continue
if score_threshold is not None and score < score_threshold:
continue
scored_chunks.append(
{
"id": str(chunk_id),
"text": doc or "",
"source": meta.get("source") or meta.get("document_name") or dataset_name,
"score": score,
"chunk_index": int(meta.get("chunk_index") or idx),
"document_name": meta.get("document_name") or meta.get("source") or "",
"document_id": meta.get("document_id"),
"page": meta.get("page"),
}
)
scored_chunks.sort(key=lambda item: (-float(item.get("score") or 0.0), int(item.get("chunk_index") or 0)))
hydrated = await self._hydrate_document_hits(dataset_id, scored_chunks[: max(top_k * 3, top_k)])
return hydrated[:top_k]
chunks = await self.retriever._keyword_retrieve_context(
dataset_id=dataset_id,
collection_name=collection_name,
dataset_name=dataset_name,
query=query,
top_k=top_k,
score_threshold=score_threshold,
source_names=None,
)
return chunks[:top_k]
def _build_keyword_terms(self, query: str) -> list[str]:
normalized = self._normalize_keyword_query(query)
spans = [item.strip() for item in re.findall(r"[\u4e00-\u9fffA-Za-z0-9]+", normalized) if item.strip()]
if not spans:
return []
stop_terms = {
"什么",
"请问",
"一下",
"有关",
"关于",
"如何",
"哪些",
"怎么",
"是否",
"规定",
"办法",
"条例",
"法律",
}
terms: list[str] = []
for span in spans:
if span in stop_terms:
continue
terms.append(span)
if re.fullmatch(r"[\u4e00-\u9fff]+", span):
for size in (2, 3, 4):
if len(span) > size:
for start in range(0, len(span) - size + 1):
token = span[start:start + size]
if token not in stop_terms:
terms.append(token)
unique_terms: list[str] = []
seen: set[str] = set()
for term in sorted(terms, key=len, reverse=True):
if term and term not in seen:
unique_terms.append(term)
seen.add(term)
return unique_terms[:20]
return self.retriever._build_keyword_terms(query)
def _normalize_keyword_query(self, query: str) -> str:
normalized = (query or "").strip().lower()
patterns = [
"是什么",
"什么是",
"有哪些",
"有什么",
"是什么?",
"是什么?",
"请问",
"介绍一下",
"解释一下",
"帮我分析",
"帮我看看",
]
for pattern in patterns:
normalized = normalized.replace(pattern, " ")
return re.sub(r"\s+", " ", normalized).strip()
return self.retriever._normalize_keyword_query(query)
def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float:
haystack = f"{document_name}\n{content}".lower()
if not haystack:
return 0.0
exact_query = self._normalize_keyword_query(query)
if exact_query and exact_query in haystack:
return 0.98
matched_weight = 0.0
total_weight = 0.0
name_bonus = 0.0
for term in terms:
weight = float(max(len(term), 1) ** 2)
total_weight += weight
if term.lower() in haystack:
matched_weight += weight
if term.lower() in document_name.lower():
name_bonus += min(0.15, 0.03 * len(term))
if total_weight <= 0:
return 0.0
score = (matched_weight / total_weight) + name_bonus
return round(min(score, 0.99), 6)
return self.retriever._score_keyword_chunk(
query=query,
terms=terms,
content=content,
document_name=document_name,
)
def _format_sse(self, payload: dict) -> bytes:
return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8")
def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]:
return [
{
"position": index + 1,
"dataset_id": str(chunk.get("dataset_id") or ""),
"dataset_name": dataset_name,
"document_id": str(chunk.get("document_id") or ""),
"document_name": chunk.get("document_name") or chunk.get("source", ""),
"data_source_type": "upload_file",
"segment_id": chunk.get("id", ""),
"retriever_from": "rag",
"score": round(chunk.get("score", 0.0), 4),
"hit_count": chunk.get("hit_count", 0),
"word_count": len(chunk.get("text", "")),
"segment_position": index + 1,
"index_node_hash": "",
"content": chunk.get("text", "")[:500],
"page": None,
}
for index, chunk in enumerate(context_chunks)
]
build_sources = getattr(self.retriever, "build_sources", None)
if callable(build_sources):
return build_sources(context_chunks, dataset_name)
return RagRetriever(hydrate_documents=False).build_sources(context_chunks, dataset_name)
async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict]) -> list[dict]:
source_names = sorted(
{
str(chunk.get("document_name") or chunk.get("source") or "").strip()
for chunk in chunks
if str(chunk.get("document_name") or chunk.get("source") or "").strip()
}
)
if not source_names:
return chunks
async with GetAsyncSession() as session:
rows = (
await session.execute(
text(
"""
SELECT id, original_name, enabled, hit_count
FROM rag_document
WHERE dataset_id = :dataset_id
AND deleted_at IS NULL
AND original_name = ANY(:source_names)
"""
),
{
"dataset_id": dataset_id,
"source_names": source_names,
},
)
).mappings().all()
document_map = {str(row["original_name"]): row for row in rows}
visible_chunks: list[dict] = []
hit_document_ids: list[int] = []
for chunk in chunks:
source_name = str(chunk.get("document_name") or chunk.get("source") or "").strip()
document = document_map.get(source_name)
if document and not bool(document.get("enabled")):
continue
if document:
chunk["document_id"] = document["id"]
chunk["dataset_id"] = dataset_id
chunk["document_name"] = document["original_name"]
chunk["hit_count"] = document.get("hit_count") or 0
hit_document_ids.append(int(document["id"]))
visible_chunks.append(chunk)
if hit_document_ids:
async with GetAsyncSession() as session:
async with session.begin():
await session.execute(
text(
"""
UPDATE rag_document
SET hit_count = hit_count + 1,
updated_at = NOW()
WHERE id = ANY(:document_ids)
"""
),
{"document_ids": sorted(set(hit_document_ids))},
)
return visible_chunks
return await self.retriever._hydrate_document_hits(dataset_id, chunks)
def _parse_sse_event(self, chunk: str) -> dict | None:
data_lines: list[str] = []