diff --git a/fastapi_modules/fastapi_leaudit/domian/vo/ragChatVo.py b/fastapi_modules/fastapi_leaudit/domian/vo/ragChatVo.py index ddf8022..e020a16 100644 --- a/fastapi_modules/fastapi_leaudit/domian/vo/ragChatVo.py +++ b/fastapi_modules/fastapi_leaudit/domian/vo/ragChatVo.py @@ -34,6 +34,7 @@ class RagMessageItemVO(BaseModel): answer: str = Field(...) feedback: dict | None = Field(None) retrieverResources: list[dict] | None = Field(None) + suggestedQuestions: list[str] = Field(default_factory=list) createdAt: int = Field(0) diff --git a/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py b/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py index f97faf1..33cf0bb 100644 --- a/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py +++ b/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py @@ -4,6 +4,7 @@ import json import uuid from typing import AsyncGenerator +import httpx from sqlalchemy import text from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession @@ -25,6 +26,7 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import ( RagMessagePageVO, RagOperationResultVO, ) +from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService @@ -194,7 +196,7 @@ class RagChatServiceImpl(IRagChatService): await session.execute( text( """ - SELECT message_id, role, content, sources, feedback, created_at + SELECT message_id, role, content, sources, metadata, feedback, created_at FROM rag_message WHERE conversation_id = :conversation_id ORDER BY created_at ASC @@ -216,6 +218,11 @@ class RagChatServiceImpl(IRagChatService): row = items[idx] if row["role"] == "user": answer = items[idx + 1] if idx + 1 < len(items) and items[idx + 1]["role"] == "assistant" else None + answer_sources = self._parse_json_field(answer.get("sources")) if answer else [] + answer_metadata = self._parse_json_field(answer.get("metadata")) if answer else {} + suggested_questions = answer_metadata.get("suggested_questions") if isinstance(answer_metadata, dict) else [] + if not isinstance(suggested_questions, list): + suggested_questions = [] data.append( RagMessageItemVO( id=(answer["message_id"] if answer else row["message_id"]), @@ -223,7 +230,8 @@ class RagChatServiceImpl(IRagChatService): query=row["content"], answer=answer["content"] if answer else "", feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None), - retrieverResources=(answer.get("sources") if answer else None), + retrieverResources=answer_sources or None, + suggestedQuestions=[str(item) for item in suggested_questions], createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, ) ) @@ -392,6 +400,18 @@ class RagChatServiceImpl(IRagChatService): area = row.get("area") or "" return area in ("", "省级", user_area or "") or bool(row.get("dataset_public")) + def _parse_json_field(self, value): + if value is None: + return {} + if isinstance(value, (dict, list)): + return value + if isinstance(value, str): + try: + return json.loads(value) + except Exception: + return {} + return {} + async def _ensure_conversation(self, user_id: int, conversation_id: str | None, app_id: int | None) -> str: if conversation_id and conversation_id != "-1": async with GetAsyncSession() as session: @@ -450,7 +470,7 @@ class RagChatServiceImpl(IRagChatService): await session.execute( text( """ - SELECT id, name, collection_name, retrieval_model + SELECT id, name, collection_name, retrieval_model, embedding_model FROM rag_dataset WHERE id = :dataset_id AND deleted_at IS NULL LIMIT 1 @@ -475,7 +495,12 @@ class RagChatServiceImpl(IRagChatService): return [], dataset.get("name") or "" try: collection = get_chroma().get_or_create_collection(dataset["collection_name"]) - result = collection.query(query_texts=[query], n_results=max(top_k, 1)) + query_embedding = await self._embed_texts([query], dataset.get("embedding_model") or "") + result = collection.query( + query_embeddings=query_embedding, + n_results=max(top_k, 1), + include=["documents", "metadatas", "distances"], + ) docs = (result.get("documents") or [[]])[0] metas = (result.get("metadatas") or [[]])[0] distances = (result.get("distances") or [[]])[0] @@ -483,7 +508,8 @@ class RagChatServiceImpl(IRagChatService): for idx, doc in enumerate(docs): meta = metas[idx] if idx < len(metas) else {} dist = distances[idx] if idx < len(distances) else 0.0 - score = 1 - float(dist or 0.0) + distance = max(0.0, float(dist or 0.0)) + score = 1.0 / (1.0 + distance) if score_threshold is not None and score < score_threshold: continue chunks.append( @@ -501,6 +527,46 @@ class RagChatServiceImpl(IRagChatService): except Exception: return [], dataset.get("name") or "" + async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]: + embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/embeddings" + embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"] + embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4" + batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10)) + if not embed_url or not embed_key: + raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务") + + embeddings: list[list[float]] = [] + async with httpx.AsyncClient(timeout=120.0) as client: + for start in range(0, len(texts), batch_size): + batch_texts = texts[start:start + batch_size] + try: + response = await client.post( + embed_url, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {embed_key}", + }, + json={"model": embed_model, "input": batch_texts}, + ) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}" + raise LeauditException( + StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, + f"向量化服务调用失败: {error_message[:300]}", + ) from exc + + payload = response.json() + rows = payload.get("data") or [] + batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")] + if len(batch_embeddings) != len(batch_texts): + raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") + embeddings.extend(batch_embeddings) + + if len(embeddings) != len(texts): + raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") + return embeddings + def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]: return [ { diff --git a/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py b/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py index 893290d..360988e 100644 --- a/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py +++ b/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py @@ -1186,7 +1186,7 @@ class RagDatasetServiceImpl(IRagDatasetService): content = documents[index] if index < len(documents) else "" metadata = metadatas[index] if index < len(metadatas) and isinstance(metadatas[index], dict) else {} distance = float(distances[index]) if index < len(distances) and distances[index] is not None else 1.0 - score = max(0.0, min(1.0, 1.0 - distance)) + score = max(0.0, min(1.0, 1.0 / (1.0 + max(0.0, distance)))) if score_threshold_enabled and score_threshold is not None and score < score_threshold: continue