492 lines
20 KiB
Python
492 lines
20 KiB
Python
"""RAG retriever: Chroma vector search with a keyword-matching fallback."""
|
|
|
|
from __future__ import annotations

import logging
import re
from dataclasses import dataclass, field
from typing import Any, Awaitable, Callable

import httpx
from sqlalchemy import text

from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_embeddings_url
|
|
|
|
|
|
# Pluggable embedding function: called as fn(texts, model_name) and expected to
# return one embedding vector per input text. It may be synchronous or return an
# awaitable; RagRetriever._embed_texts awaits the result when it has __await__.
EmbedTexts = Callable[[list[str], str], Awaitable[list[list[float]]] | list[list[float]]]
|
|
|
|
|
|
@dataclass
class RagRetrieveResult:
    """Outcome of one RAG retrieval.

    Every field defaults to an empty value, so ``RagRetrieveResult()`` is a
    valid "nothing found" result.
    """

    # Prompt-ready text assembled from the retrieved chunks.
    rag_context: str = ""
    # Source/citation descriptors, one per retrieved chunk.
    rag_resources: list[dict[str, Any]] = field(default_factory=list)
    # The raw chunk dicts produced by the retriever.
    chunks: list[dict[str, Any]] = field(default_factory=list)
    # Display name of the dataset the chunks came from ("" when unknown).
    dataset_name: str = ""

    def as_dict(self) -> dict[str, Any]:
        """Return the fields that are injected into the rule engine.

        Only ``rag_context`` and ``rag_resources`` are exported. The values
        are this instance's own objects, not copies.
        """
        exported_fields = ("rag_context", "rag_resources")
        return {name: getattr(self, name) for name in exported_fields}
|
|
|
|
|
|
class RagRetriever:
    """Retrieve knowledge-base chunks for a question and format them for prompts.

    Strategy used by :meth:`retrieve`:

    1. Vector similarity search against a Chroma collection.
    2. If that yields nothing (or raises), a keyword scan over the same
       collection.
    3. When a ``dataset_id`` is given and hydration is enabled, hits are joined
       with the ``rag_document`` table: chunks of disabled documents are
       dropped and ``hit_count`` is incremented for the documents returned.
    """

    _logger = logging.getLogger(__name__)

    # Generic question words and document-type nouns that carry little signal
    # for keyword matching; they are never used as search terms.
    _KEYWORD_STOP_TERMS = frozenset(
        {
            "什么",
            "请问",
            "一下",
            "有关",
            "关于",
            "如何",
            "哪些",
            "怎么",
            "是否",
            "规定",
            "办法",
            "条例",
            "法律",
        }
    )

    # Conversational filler phrases blanked out of a query before keyword
    # matching, applied in this order with ``str.replace``.
    _QUERY_FILLER_PATTERNS = (
        "是什么",
        "什么是",
        "有哪些",
        "有什么",
        "是什么?",
        "是什么?",
        "请问",
        "介绍一下",
        "解释一下",
        "帮我分析",
        "帮我看看",
    )

    def __init__(
        self,
        *,
        chroma_client: Any | None = None,
        embed_texts: EmbedTexts | None = None,
        hydrate_documents: bool = True,
    ) -> None:
        """Initialise the retriever.

        Args:
            chroma_client: Chroma client to use. When ``None`` (or falsy), the
                client returned by ``get_chroma()`` is used on every call.
            embed_texts: Optional embedding function (sync or async). When
                ``None``, the configured OpenAI-compatible embeddings endpoint
                is called over HTTP.
            hydrate_documents: Whether to join hits with ``rag_document`` when
                a ``dataset_id`` is supplied to :meth:`retrieve`.
        """
        self._chroma_client = chroma_client
        self._embed_texts_override = embed_texts
        self._hydrate_documents_enabled = hydrate_documents

    async def retrieve(
        self,
        query: str,
        collection_name: str | None = None,
        dataset_id: int | None = None,
        top_k: int = 5,
        source_names: list[str] | None = None,
    ) -> RagRetrieveResult:
        """Retrieve RAG context for ``query``.

        Args:
            query: The user question. Blank input yields an empty result.
            collection_name: Explicit Chroma collection. When omitted, the
                dataset's ``collection_name`` is used.
            dataset_id: Optional ``rag_dataset`` id. It supplies the
                collection, the embedding model, the score threshold and,
                when no explicit ``collection_name`` is given, ``top_k``. It
                also enables document hydration.
            top_k: Maximum number of chunks to return (falsy -> 5, minimum 1).
            source_names: Optional whitelist of document names.

        Returns:
            A :class:`RagRetrieveResult`. It is empty when the query is blank
            or no collection can be resolved.
        """
        query_text = (query or "").strip()
        if not query_text:
            return RagRetrieveResult()

        top_k = max(1, int(top_k or 5))
        dataset = await self._load_dataset(dataset_id) if dataset_id else None
        dataset_info = dataset or {}
        dataset_name = str(dataset_info.get("name") or "")

        resolved_collection = (collection_name or dataset_info.get("collection_name") or "").strip()
        if not resolved_collection:
            return RagRetrieveResult(dataset_name=dataset_name)

        retrieval_model = dataset_info.get("retrieval_model") or {}
        if dataset and not collection_name:
            # The dataset's own retrieval settings win when its collection is used.
            top_k = max(1, int(retrieval_model.get("top_k") or top_k))
        score_threshold = self._resolve_score_threshold(retrieval_model)
        embedding_model = str(dataset_info.get("embedding_model") or "")

        chunks: list[dict[str, Any]] = []
        try:
            chunks = await self._vector_retrieve(
                query=query_text,
                collection_name=resolved_collection,
                dataset_name=dataset_name,
                embedding_model=embedding_model,
                top_k=top_k,
                score_threshold=score_threshold,
                source_names=source_names,
            )
        except Exception:
            # Best effort: a failing vector search (e.g. embedding service
            # unavailable) falls back to keyword search, but is no longer silent.
            self._logger.warning(
                "RAG vector retrieval failed for collection %s; falling back to keyword search",
                resolved_collection,
                exc_info=True,
            )
            chunks = []

        if not chunks:
            try:
                chunks = await self._keyword_retrieve_context(
                    dataset_id=dataset_id,
                    collection_name=resolved_collection,
                    dataset_name=dataset_name,
                    query=query_text,
                    top_k=top_k,
                    score_threshold=score_threshold,
                    source_names=source_names,
                )
            except Exception:
                self._logger.warning(
                    "RAG keyword retrieval failed for collection %s; returning no context",
                    resolved_collection,
                    exc_info=True,
                )
                chunks = []

        if dataset_id and self._hydrate_documents_enabled:
            # Pass the limit so hit counts are only bumped for chunks actually returned.
            chunks = await self._hydrate_document_hits(dataset_id, chunks, limit=top_k)

        chunks = chunks[:top_k]
        return RagRetrieveResult(
            rag_context=self._build_context(chunks),
            rag_resources=self.build_sources(chunks, dataset_name),
            chunks=chunks,
            dataset_name=dataset_name,
        )

    async def _load_dataset(self, dataset_id: int | None) -> dict[str, Any] | None:
        """Load a non-deleted ``rag_dataset`` row as a dict, or ``None`` if absent."""
        if not dataset_id:
            return None
        async with GetAsyncSession() as session:
            row = (
                await session.execute(
                    text(
                        """
                        SELECT id, name, collection_name, retrieval_model, embedding_model
                        FROM rag_dataset
                        WHERE id = :dataset_id AND deleted_at IS NULL
                        LIMIT 1
                        """
                    ),
                    {"dataset_id": dataset_id},
                )
            ).mappings().first()
        return dict(row) if row else None

    def _resolve_score_threshold(self, retrieval_model: dict[str, Any]) -> float | None:
        """Return the score threshold, or ``None`` when it is disabled or invalid."""
        if not retrieval_model.get("score_threshold_enabled"):
            return None
        try:
            return float(retrieval_model.get("score_threshold"))
        except (TypeError, ValueError):
            return None

    async def _vector_retrieve(
        self,
        *,
        query: str,
        collection_name: str,
        dataset_name: str,
        embedding_model: str,
        top_k: int,
        score_threshold: float | None,
        source_names: list[str] | None,
    ) -> list[dict[str, Any]]:
        """Run a Chroma similarity query and convert the hits to chunk dicts.

        Distances are mapped to a similarity ``1 / (1 + d)`` in ``(0, 1]``.
        A missing distance counts as 1.0, which gives a score of 0.5. Hits
        below ``score_threshold`` or outside ``source_names`` are dropped.
        """
        query_embedding = await self._embed_texts([query], embedding_model)
        collection = self._get_chroma().get_or_create_collection(collection_name)
        result = collection.query(
            query_embeddings=query_embedding,
            n_results=max(top_k, 1),
            include=["documents", "metadatas", "distances"],
        )
        # Chroma returns one inner list per query embedding; only one was sent.
        ids = self._first_query_row(result, "ids")
        docs = self._first_query_row(result, "documents")
        metas = self._first_query_row(result, "metadatas")
        distances = self._first_query_row(result, "distances")
        allowed_sources = self._normalize_source_filter(source_names)

        chunks: list[dict[str, Any]] = []
        for idx, doc in enumerate(docs):
            meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {}
            document_name = str(meta.get("document_name") or meta.get("source") or "")
            if allowed_sources and document_name not in allowed_sources:
                continue
            raw_distance = distances[idx] if idx < len(distances) else None
            distance = float(raw_distance) if raw_distance is not None else 1.0
            score = 1.0 / (1.0 + max(distance, 0.0))
            if score_threshold is not None and score < score_threshold:
                continue
            chunk_id = ids[idx] if idx < len(ids) else (meta.get("id") or idx)
            chunks.append(
                self._make_chunk(
                    chunk_id=str(chunk_id),
                    content=doc or "",
                    meta=meta,
                    score=score,
                    position=idx,
                    dataset_name=dataset_name,
                    document_name=document_name,
                )
            )
        return chunks

    def _first_query_row(self, result: dict[str, Any], key: str) -> list[Any]:
        """Return the first inner list of a Chroma query result field (``[]`` if missing)."""
        return (result.get(key) or [[]])[0] or []

    async def _embed_texts(self, texts: list[str], model_name: str = "") -> list[list[float]]:
        """Embed ``texts``, returning one vector per text in input order.

        The injected ``embed_texts`` override is used when present. Otherwise
        the texts are sent in batches to the OpenAI-compatible embeddings
        endpoint configured in ``RAG_CONFIG``.

        Raises:
            LeauditException: if no endpoint/key is configured, the request
                fails (HTTP status or transport error), or the number of
                returned vectors does not match the input.
        """
        if self._embed_texts_override is not None:
            result = self._embed_texts_override(texts, model_name)
            if hasattr(result, "__await__"):
                return await result  # type: ignore[misc]
            return result

        embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"])
        embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
        embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
        batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
        if not embed_url or not embed_key:
            raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")

        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {embed_key}",
        }
        embeddings: list[list[float]] = []
        async with httpx.AsyncClient(timeout=120.0) as client:
            for start in range(0, len(texts), batch_size):
                batch_texts = texts[start:start + batch_size]
                try:
                    response = await client.post(
                        embed_url,
                        headers=headers,
                        json={"model": embed_model, "input": batch_texts},
                    )
                    response.raise_for_status()
                except httpx.HTTPStatusError as exc:
                    error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
                    raise LeauditException(
                        StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                        f"向量化服务调用失败: {error_message[:300]}",
                    ) from exc
                except httpx.RequestError as exc:
                    # Connection/timeout errors used to escape as raw httpx exceptions.
                    error_message = str(exc).strip() or type(exc).__name__
                    raise LeauditException(
                        StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                        f"向量化服务调用失败: {error_message[:300]}",
                    ) from exc

                payload = response.json()
                rows = (payload.get("data") if isinstance(payload, dict) else None) or []
                batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
                if len(batch_embeddings) != len(batch_texts):
                    raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
                embeddings.extend(batch_embeddings)

        if len(embeddings) != len(texts):
            raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
        return embeddings

    async def _keyword_retrieve_context(
        self,
        *,
        dataset_id: int | None,
        collection_name: str,
        dataset_name: str,
        query: str,
        top_k: int,
        score_threshold: float | None,
        source_names: list[str] | None,
    ) -> list[dict[str, Any]]:
        """Score every chunk of the collection by keyword overlap with ``query``.

        This is a full scan: ``collection.get`` loads all documents. It returns
        up to ``top_k * 3`` chunks, sorted by score (descending) and then by
        ``chunk_index``. The extra chunks leave room for hydration to drop
        disabled documents. ``dataset_id`` is currently unused.
        """
        # Compute terms first: an empty term list means no chunk can score > 0,
        # so skip loading the whole collection.
        terms = self._build_keyword_terms(query)
        if not terms:
            return []
        normalized_query = self._normalize_keyword_query(query)

        collection = self._get_chroma().get_or_create_collection(collection_name)
        raw = collection.get(include=["documents", "metadatas"])
        ids = raw.get("ids") or []
        docs = raw.get("documents") or []
        metas = raw.get("metadatas") or []
        allowed_sources = self._normalize_source_filter(source_names)

        scored_chunks: list[dict[str, Any]] = []
        for idx, chunk_id in enumerate(ids):
            content = (docs[idx] if idx < len(docs) else "") or ""
            meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {}
            document_name = str(meta.get("document_name") or meta.get("source") or "")
            if allowed_sources and document_name not in allowed_sources:
                continue
            score = self._score_keyword_chunk(
                query=query,
                terms=terms,
                content=content,
                document_name=document_name,
                normalized_query=normalized_query,
            )
            if score <= 0:
                continue
            if score_threshold is not None and score < score_threshold:
                continue
            scored_chunks.append(
                self._make_chunk(
                    chunk_id=str(chunk_id),
                    content=content,
                    meta=meta,
                    score=score,
                    position=idx,
                    dataset_name=dataset_name,
                    document_name=document_name,
                )
            )

        scored_chunks.sort(key=lambda item: (-float(item.get("score") or 0.0), int(item.get("chunk_index") or 0)))
        return scored_chunks[: max(top_k * 3, top_k)]

    def _make_chunk(
        self,
        *,
        chunk_id: str,
        content: str,
        meta: dict[str, Any],
        score: float,
        position: int,
        dataset_name: str,
        document_name: str,
    ) -> dict[str, Any]:
        """Build the chunk dict shared by vector and keyword retrieval.

        ``position`` (result rank or collection offset) is used for
        ``chunk_index`` only when the metadata carries no usable value.
        """
        return {
            "id": chunk_id,
            "text": content,
            "source": meta.get("source") or meta.get("document_name") or dataset_name,
            "score": score,
            "chunk_index": self._coerce_chunk_index(meta.get("chunk_index"), position),
            "document_name": document_name,
            "document_id": meta.get("document_id"),
            "page": meta.get("page"),
            "source_scope": meta.get("source_scope"),
            "attachment_id": meta.get("attachment_id"),
            "conversation_id": meta.get("conversation_id"),
            "tenant_code": meta.get("tenant_code"),
            "user_id": meta.get("user_id"),
        }

    def _coerce_chunk_index(self, value: Any, fallback: int) -> int:
        """Convert a metadata ``chunk_index`` to int; use ``fallback`` if missing or invalid.

        A stored index of ``0`` is kept. The previous ``value or fallback``
        form replaced it with the fallback.
        """
        if value is None or value == "":
            return fallback
        try:
            return int(value)
        except (TypeError, ValueError):
            return fallback

    def _build_keyword_terms(self, query: str) -> list[str]:
        """Extract up to 20 search terms from ``query``, longest first.

        Terms are the alphanumeric/CJK runs of the normalised query, minus
        stop terms. Pure-CJK runs additionally contribute their 2-, 3- and
        4-character substrings.
        """
        normalized = self._normalize_keyword_query(query)
        spans = [item.strip() for item in re.findall(r"[\u4e00-\u9fffA-Za-z0-9]+", normalized) if item.strip()]
        if not spans:
            return []

        stop_terms = self._KEYWORD_STOP_TERMS
        terms: list[str] = []
        for span in spans:
            if span in stop_terms:
                continue
            terms.append(span)
            if re.fullmatch(r"[\u4e00-\u9fff]+", span):
                # CJK has no word boundaries, so use character n-grams as sub-terms.
                for size in (2, 3, 4):
                    if len(span) > size:
                        for start in range(0, len(span) - size + 1):
                            token = span[start:start + size]
                            if token not in stop_terms:
                                terms.append(token)

        unique_terms: list[str] = []
        seen: set[str] = set()
        for term in sorted(terms, key=len, reverse=True):
            if term and term not in seen:
                unique_terms.append(term)
                seen.add(term)
        return unique_terms[:20]

    def _normalize_keyword_query(self, query: str) -> str:
        """Lower-case ``query``, blank out filler phrases and collapse whitespace."""
        normalized = (query or "").strip().lower()
        for pattern in self._QUERY_FILLER_PATTERNS:
            normalized = normalized.replace(pattern, " ")
        return re.sub(r"\s+", " ", normalized).strip()

    def _score_keyword_chunk(
        self,
        *,
        query: str,
        terms: list[str],
        content: str,
        document_name: str,
        normalized_query: str | None = None,
    ) -> float:
        """Score a chunk in ``[0, 0.99]`` by keyword overlap.

        If the whole normalised query appears in the document name or content,
        the score is 0.98. Otherwise it is the fraction of term weight matched
        (weight = len(term) ** 2), plus a bonus for each term found in the
        document name, capped at 0.99.

        ``normalized_query`` lets callers that score many chunks pass a
        precomputed ``_normalize_keyword_query(query)``.
        """
        if not document_name and not content:
            return 0.0
        haystack = f"{document_name}\n{content}".lower()
        name_lower = document_name.lower()

        exact_query = normalized_query if normalized_query is not None else self._normalize_keyword_query(query)
        if exact_query and exact_query in haystack:
            return 0.98

        matched_weight = 0.0
        total_weight = 0.0
        name_bonus = 0.0
        for term in terms:
            term_lower = term.lower()
            # Quadratic weight: longer (more specific) terms dominate the score.
            weight = float(max(len(term), 1) ** 2)
            total_weight += weight
            if term_lower in haystack:
                matched_weight += weight
            if term_lower in name_lower:
                name_bonus += min(0.15, 0.03 * len(term))

        if total_weight <= 0:
            return 0.0
        score = (matched_weight / total_weight) + name_bonus
        return round(min(score, 0.99), 6)

    def _build_context(self, chunks: list[dict[str, Any]]) -> str:
        """Render chunks as numbered ``[n] 来源:<name>`` sections for a prompt.

        Chunks with empty text are skipped but still consume their number, so
        ``[n]`` lines up with ``position`` in :meth:`build_sources`.
        """
        lines: list[str] = []
        for index, chunk in enumerate(chunks, start=1):
            document_name = chunk.get("document_name") or chunk.get("source") or "未知来源"
            text_value = str(chunk.get("text") or "").strip()
            if not text_value:
                continue
            lines.append(f"[{index}] 来源:{document_name}\n{text_value}")
        return "\n\n".join(lines)

    def build_sources(self, context_chunks: list[dict[str, Any]], dataset_name: str = "") -> list[dict[str, Any]]:
        """Convert chunks into source/citation descriptors, one per chunk.

        Args:
            context_chunks: Chunks as produced by this retriever.
            dataset_name: Fallback dataset name for chunks that carry none.

        Returns:
            One dict per chunk, with 1-based ``position``. ``content`` is the
            chunk text truncated to 500 characters, and ``page`` comes from
            the chunk metadata.
        """
        return [
            {
                "position": index + 1,
                "dataset_id": str(chunk.get("dataset_id") or ""),
                "dataset_name": chunk.get("dataset_name") or dataset_name,
                "document_id": str(chunk.get("document_id") or ""),
                "document_name": chunk.get("document_name") or chunk.get("source", ""),
                "data_source_type": chunk.get("data_source_type") or chunk.get("source_scope") or "upload_file",
                "segment_id": chunk.get("id", ""),
                "retriever_from": "rag",
                "score": round(float(chunk.get("score") or 0.0), 4),
                "hit_count": chunk.get("hit_count", 0),
                "word_count": len(chunk.get("text", "")),
                "segment_position": index + 1,
                "index_node_hash": "",
                "content": chunk.get("text", "")[:500],
                # Page metadata was previously dropped (always None).
                "page": chunk.get("page"),
                "source_scope": chunk.get("source_scope") or "",
                "attachment_id": chunk.get("attachment_id") or "",
            }
            for index, chunk in enumerate(context_chunks)
        ]

    def _chunk_source_name(self, chunk: dict[str, Any]) -> str:
        """Return the document name a chunk is attributed to ("" if none)."""
        return str(chunk.get("document_name") or chunk.get("source") or "").strip()

    async def _hydrate_document_hits(
        self,
        dataset_id: int,
        chunks: list[dict[str, Any]],
        limit: int | None = None,
    ) -> list[dict[str, Any]]:
        """Join chunks with ``rag_document`` rows of ``dataset_id``, matched by name.

        Chunks of disabled documents are dropped. Matched chunks get
        ``document_id``, ``dataset_id``, ``document_name`` and ``hit_count``
        (the value before this call's increment). ``hit_count`` is then
        incremented once per distinct document among the returned chunks.
        Chunks with no matching row are kept unchanged.

        Args:
            dataset_id: Dataset the documents belong to.
            chunks: Candidate chunks; matched ones are mutated in place.
            limit: Maximum number of chunks to return. Chunks beyond it are
                neither returned nor counted as hits.
        """
        source_names = sorted({name for name in map(self._chunk_source_name, chunks) if name})
        if not source_names:
            return chunks if limit is None else chunks[:limit]

        async with GetAsyncSession() as session:
            rows = (
                await session.execute(
                    text(
                        """
                        SELECT id, original_name, enabled, hit_count
                        FROM rag_document
                        WHERE dataset_id = :dataset_id
                          AND deleted_at IS NULL
                          AND original_name = ANY(:source_names)
                        """
                    ),
                    {
                        "dataset_id": dataset_id,
                        "source_names": source_names,
                    },
                )
            ).mappings().all()

        document_map = {str(row["original_name"]): row for row in rows}
        visible_chunks: list[dict[str, Any]] = []
        hit_document_ids: set[int] = set()
        for chunk in chunks:
            if limit is not None and len(visible_chunks) >= limit:
                break
            document = document_map.get(self._chunk_source_name(chunk))
            if document is not None:
                if not bool(document.get("enabled")):
                    continue
                chunk["document_id"] = document["id"]
                chunk["dataset_id"] = dataset_id
                chunk["document_name"] = document["original_name"]
                chunk["hit_count"] = document.get("hit_count") or 0
                hit_document_ids.add(int(document["id"]))
            visible_chunks.append(chunk)

        if hit_document_ids:
            async with GetAsyncSession() as session:
                async with session.begin():
                    await session.execute(
                        text(
                            """
                            UPDATE rag_document
                            SET hit_count = hit_count + 1,
                                updated_at = NOW()
                            WHERE id = ANY(:document_ids)
                            """
                        ),
                        {"document_ids": sorted(hit_document_ids)},
                    )

        return visible_chunks

    def _normalize_source_filter(self, source_names: list[str] | None) -> set[str]:
        """Return the stripped, non-empty source names as a set (empty set = no filter)."""
        return {str(name).strip() for name in (source_names or []) if str(name).strip()}

    def _get_chroma(self) -> Any:
        """Return the injected Chroma client, or the one from ``get_chroma()``."""
        return self._chroma_client or get_chroma()
|