"""RAG 检索器。""" from __future__ import annotations import re from dataclasses import dataclass, field from typing import Any, Awaitable, Callable import httpx from sqlalchemy import text from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_embeddings_url EmbedTexts = Callable[[list[str], str], Awaitable[list[list[float]]] | list[list[float]]] @dataclass class RagRetrieveResult: """RAG 检索结果。""" rag_context: str = "" rag_resources: list[dict[str, Any]] = field(default_factory=list) chunks: list[dict[str, Any]] = field(default_factory=list) dataset_name: str = "" def as_dict(self) -> dict[str, Any]: """转换为规则引擎可注入的字典。""" return { "rag_context": self.rag_context, "rag_resources": self.rag_resources, } class RagRetriever: """RAG 检索器。""" def __init__( self, *, chroma_client: Any | None = None, embed_texts: EmbedTexts | None = None, hydrate_documents: bool = True, ) -> None: """初始化 RAG 检索器。""" self._chroma_client = chroma_client self._embed_texts_override = embed_texts self._hydrate_documents_enabled = hydrate_documents async def retrieve( self, query: str, collection_name: str | None = None, dataset_id: int | None = None, top_k: int = 5, source_names: list[str] | None = None, ) -> RagRetrieveResult: """根据问题检索 RAG 上下文。""" query_text = (query or "").strip() if not query_text: return RagRetrieveResult() top_k = max(1, int(top_k or 5)) dataset = await self._load_dataset(dataset_id) if dataset_id else None resolved_collection = (collection_name or (dataset or {}).get("collection_name") or "").strip() if not resolved_collection: return RagRetrieveResult(dataset_name=str((dataset or {}).get("name") or "")) retrieval_model = (dataset or {}).get("retrieval_model") or {} if dataset and not collection_name: top_k = max(1, int(retrieval_model.get("top_k") or top_k)) score_threshold = self._resolve_score_threshold(retrieval_model) dataset_name = str((dataset or {}).get("name") or "") embedding_model = str((dataset or {}).get("embedding_model") or "") chunks: list[dict[str, Any]] = [] try: chunks = await self._vector_retrieve( query=query_text, collection_name=resolved_collection, dataset_name=dataset_name, embedding_model=embedding_model, top_k=top_k, score_threshold=score_threshold, source_names=source_names, ) except Exception: chunks = [] if not chunks: try: chunks = await self._keyword_retrieve_context( dataset_id=dataset_id, collection_name=resolved_collection, dataset_name=dataset_name, query=query_text, top_k=top_k, score_threshold=score_threshold, source_names=source_names, ) except Exception: chunks = [] if dataset_id and self._hydrate_documents_enabled: chunks = await self._hydrate_document_hits(dataset_id, chunks) chunks = chunks[:top_k] return RagRetrieveResult( rag_context=self._build_context(chunks), rag_resources=self.build_sources(chunks, dataset_name), chunks=chunks, dataset_name=dataset_name, ) async def _load_dataset(self, dataset_id: int | None) -> dict[str, Any] | None: if not dataset_id: return None async with GetAsyncSession() as session: row = ( await session.execute( text( """ SELECT id, name, collection_name, retrieval_model, embedding_model FROM rag_dataset WHERE id = :dataset_id AND deleted_at IS NULL LIMIT 1 """ ), {"dataset_id": dataset_id}, ) ).mappings().first() return dict(row) if row else None def _resolve_score_threshold(self, retrieval_model: dict[str, Any]) -> float | None: if not retrieval_model.get("score_threshold_enabled"): return None try: return float(retrieval_model.get("score_threshold")) except (TypeError, ValueError): return None async def _vector_retrieve( self, *, query: str, collection_name: str, dataset_name: str, embedding_model: str, top_k: int, score_threshold: float | None, source_names: list[str] | None, ) -> list[dict[str, Any]]: query_embedding = await self._embed_texts([query], embedding_model) collection = self._get_chroma().get_or_create_collection(collection_name) result = collection.query( query_embeddings=query_embedding, n_results=max(top_k, 1), include=["documents", "metadatas", "distances"], ) ids = (result.get("ids") or [[]])[0] if result.get("ids") else [] docs = (result.get("documents") or [[]])[0] metas = (result.get("metadatas") or [[]])[0] distances = (result.get("distances") or [[]])[0] allowed_sources = self._normalize_source_filter(source_names) chunks: list[dict[str, Any]] = [] for idx, doc in enumerate(docs): meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {} document_name = str(meta.get("document_name") or meta.get("source") or "") if allowed_sources and document_name not in allowed_sources: continue dist = float(distances[idx]) if idx < len(distances) and distances[idx] is not None else 1.0 score = 1.0 / (1.0 + max(dist, 0.0)) if score_threshold is not None and score < score_threshold: continue chunks.append( { "id": str(ids[idx] if idx < len(ids) else meta.get("id") or idx), "text": doc or "", "source": meta.get("source") or meta.get("document_name") or dataset_name, "score": score, "chunk_index": int(meta.get("chunk_index") or idx), "document_name": document_name, "document_id": meta.get("document_id"), "page": meta.get("page"), "source_scope": meta.get("source_scope"), "attachment_id": meta.get("attachment_id"), "conversation_id": meta.get("conversation_id"), "tenant_code": meta.get("tenant_code"), "user_id": meta.get("user_id"), } ) return chunks async def _embed_texts(self, texts: list[str], model_name: str = "") -> list[list[float]]: if self._embed_texts_override is not None: result = self._embed_texts_override(texts, model_name) if hasattr(result, "__await__"): return await result # type: ignore[misc] return result embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"]) embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"] embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4" batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10)) if not embed_url or not embed_key: raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务") embeddings: list[list[float]] = [] async with httpx.AsyncClient(timeout=120.0) as client: for start in range(0, len(texts), batch_size): batch_texts = texts[start:start + batch_size] try: response = await client.post( embed_url, headers={ "Content-Type": "application/json", "Authorization": f"Bearer {embed_key}", }, json={"model": embed_model, "input": batch_texts}, ) response.raise_for_status() except httpx.HTTPStatusError as exc: error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}" raise LeauditException( StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, f"向量化服务调用失败: {error_message[:300]}", ) from exc payload = response.json() rows = payload.get("data") or [] batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")] if len(batch_embeddings) != len(batch_texts): raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") embeddings.extend(batch_embeddings) if len(embeddings) != len(texts): raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") return embeddings async def _keyword_retrieve_context( self, *, dataset_id: int | None, collection_name: str, dataset_name: str, query: str, top_k: int, score_threshold: float | None, source_names: list[str] | None, ) -> list[dict[str, Any]]: collection = self._get_chroma().get_or_create_collection(collection_name) raw = collection.get(include=["documents", "metadatas"]) ids = raw.get("ids") or [] docs = raw.get("documents") or [] metas = raw.get("metadatas") or [] allowed_sources = self._normalize_source_filter(source_names) terms = self._build_keyword_terms(query) if not terms: return [] scored_chunks: list[dict[str, Any]] = [] for idx, chunk_id in enumerate(ids): doc = docs[idx] if idx < len(docs) else "" meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {} document_name = str(meta.get("document_name") or meta.get("source") or "") if allowed_sources and document_name not in allowed_sources: continue score = self._score_keyword_chunk( query=query, terms=terms, content=doc or "", document_name=document_name, ) if score <= 0: continue if score_threshold is not None and score < score_threshold: continue scored_chunks.append( { "id": str(chunk_id), "text": doc or "", "source": meta.get("source") or meta.get("document_name") or dataset_name, "score": score, "chunk_index": int(meta.get("chunk_index") or idx), "document_name": document_name, "document_id": meta.get("document_id"), "page": meta.get("page"), "source_scope": meta.get("source_scope"), "attachment_id": meta.get("attachment_id"), "conversation_id": meta.get("conversation_id"), "tenant_code": meta.get("tenant_code"), "user_id": meta.get("user_id"), } ) scored_chunks.sort(key=lambda item: (-float(item.get("score") or 0.0), int(item.get("chunk_index") or 0))) return scored_chunks[: max(top_k * 3, top_k)] def _build_keyword_terms(self, query: str) -> list[str]: normalized = self._normalize_keyword_query(query) spans = [item.strip() for item in re.findall(r"[\u4e00-\u9fffA-Za-z0-9]+", normalized) if item.strip()] if not spans: return [] stop_terms = { "什么", "请问", "一下", "有关", "关于", "如何", "哪些", "怎么", "是否", "规定", "办法", "条例", "法律", } terms: list[str] = [] for span in spans: if span in stop_terms: continue terms.append(span) if re.fullmatch(r"[\u4e00-\u9fff]+", span): for size in (2, 3, 4): if len(span) > size: for start in range(0, len(span) - size + 1): token = span[start:start + size] if token not in stop_terms: terms.append(token) unique_terms: list[str] = [] seen: set[str] = set() for term in sorted(terms, key=len, reverse=True): if term and term not in seen: unique_terms.append(term) seen.add(term) return unique_terms[:20] def _normalize_keyword_query(self, query: str) -> str: normalized = (query or "").strip().lower() patterns = [ "是什么", "什么是", "有哪些", "有什么", "是什么?", "是什么?", "请问", "介绍一下", "解释一下", "帮我分析", "帮我看看", ] for pattern in patterns: normalized = normalized.replace(pattern, " ") return re.sub(r"\s+", " ", normalized).strip() def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float: haystack = f"{document_name}\n{content}".lower() if not haystack: return 0.0 exact_query = self._normalize_keyword_query(query) if exact_query and exact_query in haystack: return 0.98 matched_weight = 0.0 total_weight = 0.0 name_bonus = 0.0 for term in terms: weight = float(max(len(term), 1) ** 2) total_weight += weight if term.lower() in haystack: matched_weight += weight if term.lower() in document_name.lower(): name_bonus += min(0.15, 0.03 * len(term)) if total_weight <= 0: return 0.0 score = (matched_weight / total_weight) + name_bonus return round(min(score, 0.99), 6) def _build_context(self, chunks: list[dict[str, Any]]) -> str: lines: list[str] = [] for index, chunk in enumerate(chunks, start=1): document_name = chunk.get("document_name") or chunk.get("source") or "未知来源" text_value = str(chunk.get("text") or "").strip() if not text_value: continue lines.append(f"[{index}] 来源:{document_name}\n{text_value}") return "\n\n".join(lines) def build_sources(self, context_chunks: list[dict[str, Any]], dataset_name: str = "") -> list[dict[str, Any]]: return [ { "position": index + 1, "dataset_id": str(chunk.get("dataset_id") or ""), "dataset_name": chunk.get("dataset_name") or dataset_name, "document_id": str(chunk.get("document_id") or ""), "document_name": chunk.get("document_name") or chunk.get("source", ""), "data_source_type": chunk.get("data_source_type") or chunk.get("source_scope") or "upload_file", "segment_id": chunk.get("id", ""), "retriever_from": "rag", "score": round(float(chunk.get("score") or 0.0), 4), "hit_count": chunk.get("hit_count", 0), "word_count": len(chunk.get("text", "")), "segment_position": index + 1, "index_node_hash": "", "content": chunk.get("text", "")[:500], "page": None, "source_scope": chunk.get("source_scope") or "", "attachment_id": chunk.get("attachment_id") or "", } for index, chunk in enumerate(context_chunks) ] async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]: source_names = sorted( { str(chunk.get("document_name") or chunk.get("source") or "").strip() for chunk in chunks if str(chunk.get("document_name") or chunk.get("source") or "").strip() } ) if not source_names: return chunks async with GetAsyncSession() as session: rows = ( await session.execute( text( """ SELECT id, original_name, enabled, hit_count FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL AND original_name = ANY(:source_names) """ ), { "dataset_id": dataset_id, "source_names": source_names, }, ) ).mappings().all() document_map = {str(row["original_name"]): row for row in rows} visible_chunks: list[dict[str, Any]] = [] hit_document_ids: list[int] = [] for chunk in chunks: source_name = str(chunk.get("document_name") or chunk.get("source") or "").strip() document = document_map.get(source_name) if document and not bool(document.get("enabled")): continue if document: chunk["document_id"] = document["id"] chunk["dataset_id"] = dataset_id chunk["document_name"] = document["original_name"] chunk["hit_count"] = document.get("hit_count") or 0 hit_document_ids.append(int(document["id"])) visible_chunks.append(chunk) if hit_document_ids: async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_document SET hit_count = hit_count + 1, updated_at = NOW() WHERE id = ANY(:document_ids) """ ), {"document_ids": sorted(set(hit_document_ids))}, ) return visible_chunks def _normalize_source_filter(self, source_names: list[str] | None) -> set[str]: return {str(name).strip() for name in (source_names or []) if str(name).strip()} def _get_chroma(self) -> Any: return self._chroma_client or get_chroma()