"""RAG chat service implementation: chat apps, conversations, messages and streamed answers."""

from __future__ import annotations

import json
import logging
import uuid
from collections.abc import Mapping
from typing import AsyncGenerator

from sqlalchemy import text

from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import (
    RagConversationRenameDTO,
    RagMessageFeedbackDTO,
)
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
    RagAppParametersVO,
    RagChatAppListVO,
    RagChatAppVO,
    RagConversationItemVO,
    RagConversationPageVO,
    RagConversationRenameVO,
    RagMessageItemVO,
    RagMessagePageVO,
    RagOperationResultVO,
)
from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream
from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService

logger = logging.getLogger(__name__)

# Role that sees every chat app regardless of area or dataset visibility.
_PROVINCIAL_ADMIN_ROLE = "provincial_admin"

# Shared ordering so app listing and app fallback agree on which app comes "first".
_APP_ORDER_BY = " ORDER BY a.sort_order ASC, a.created_at DESC"

# Maximum number of suggested questions exposed by GetAppParameters.
_MAX_SUGGESTED_QUESTIONS = 6


def _coerce_json(value: object, default: object) -> object:
    """Return ``value`` decoded from JSON text, or unchanged when it is already decoded.

    The column types are not visible here, so both JSON text and already-decoded
    Python values are accepted. ``None``, ``""`` and undecodable text yield ``default``.
    """
    if value is None or value == "":
        return default
    if isinstance(value, (str, bytes, bytearray)):
        try:
            return json.loads(value)
        except ValueError:
            return default
    return value


def _epoch_seconds(value: object) -> int:
    """Return ``value.timestamp()`` truncated to whole seconds, or 0 when ``value`` is falsy."""
    return int(value.timestamp()) if value else 0


class RagChatServiceImpl(IRagChatService):
    """Database-backed implementation of ``IRagChatService``."""

    # ------------------------------------------------------------------ apps

    async def GetApps(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppListVO:
        """List every enabled chat app visible to a user with ``UserArea`` / ``UserRole``.

        ``CurrentUserId`` is accepted for interface compatibility but not used for filtering.
        """
        apps = await self._load_apps(UserArea, UserRole, only_default=False)
        return RagChatAppListVO(data=apps, total=len(apps))

    async def GetDefaultApp(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppVO | None:
        """Return the first visible default app, else the first visible app, else ``None``."""
        defaults = await self._load_apps(UserArea, UserRole, only_default=True)
        if defaults:
            return defaults[0]
        apps = await self._load_apps(UserArea, UserRole, only_default=False)
        return apps[0] if apps else None

    # ------------------------------------------------------------------ chat

    async def SendMessage(
        self,
        CurrentUserId: int,
        UserName: str,
        UserArea: str | None,
        UserRole: str | None,
        Query: str,
        ConversationId: str | None,
        AppId: int | None,
    ) -> AsyncGenerator[bytes, None]:
        """Answer ``Query`` with retrieval-augmented generation, streamed as UTF-8 SSE chunks.

        Steps:
        1. Validate the query and resolve the app and conversation.
        2. Persist the user message.
        3. Retrieve context from the app's dataset.
        4. Relay the generator's stream, holding back its ``message_end`` event.
        5. Attach follow-up suggestions to that event and send it.
        6. Persist the assistant message.

        Because this is an async generator, the errors below surface when the caller
        starts iterating, not when the method is called.

        Raises:
            LeauditException: 400 for a blank query; 404 when the user can see no app;
                403 (from ``_ensure_conversation``) for another user's conversation.
        """
        if not Query.strip():
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "问题不能为空")
        app = await self._resolve_app(AppId, UserArea, UserRole)
        if not app:
            raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用")

        conversationId = await self._ensure_conversation(CurrentUserId, ConversationId, app["id"])
        # Id of the assistant reply: passed to the generator and used when persisting it.
        messageId = str(uuid.uuid4())

        # The question is committed in its own transaction before generation starts,
        # so it is stored even if the stream later fails.
        async with GetAsyncSession() as session:
            async with session.begin():
                await session.execute(
                    text(
                        """
                        INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
                        VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, '{}'::jsonb)
                        """
                    ),
                    {
                        "message_id": str(uuid.uuid4()),
                        "conversation_id": conversationId,
                        "content": Query,
                    },
                )

        context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), Query)

        answer_parts: list[str] = []
        held_message_end: dict | None = None
        async for chunk in generate_stream(
            query=Query,
            context_chunks=context_chunks,
            conversation_id=conversationId,
            message_id=messageId,
            system_prompt=app.get("system_prompt") or "",
            model=app.get("llm_model") or "",
            temperature=app.get("temperature"),
            max_tokens=app.get("max_tokens"),
            dataset_name=dataset_name,
        ):
            data = self._parse_sse_event(chunk)
            event = data.get("event") if data else None
            if event == "message":
                answer_parts.append(data.get("answer") or "")
            elif event == "message_end":
                # Held back so follow-up questions can be attached before it is sent.
                held_message_end = data
                continue
            yield chunk.encode("utf-8")
        collected_answer = "".join(answer_parts)

        try:
            followups: list[str] = await generate_followups(Query, collected_answer)
        except Exception:
            # Suggestions are optional; the answer itself has already been streamed.
            logger.warning("Failed to generate follow-up questions", exc_info=True)
            followups = []

        if held_message_end:
            metadata = held_message_end.setdefault("metadata", {})
            if isinstance(metadata, dict):
                metadata["suggested_questions"] = followups
            yield f"data: {json.dumps(held_message_end, ensure_ascii=False)}\n\n".encode("utf-8")

        async with GetAsyncSession() as session:
            async with session.begin():
                await session.execute(
                    text(
                        """
                        INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
                        VALUES (:message_id, :conversation_id, 'assistant', :content, CAST(:sources AS jsonb), CAST(:metadata AS jsonb))
                        """
                    ),
                    {
                        "message_id": messageId,
                        "conversation_id": conversationId,
                        "content": collected_answer,
                        "sources": json.dumps(self._build_sources(context_chunks, dataset_name), ensure_ascii=False),
                        "metadata": json.dumps({"suggested_questions": followups}, ensure_ascii=False),
                    },
                )
                await session.execute(
                    text(
                        "UPDATE rag_conversation SET updated_at = NOW() WHERE conversation_id = :conversation_id"
                    ),
                    {"conversation_id": conversationId},
                )

    # --------------------------------------------------------- conversations

    async def GetConversations(self, CurrentUserId: int, AppId: int | None, Page: int, PageSize: int) -> RagConversationPageVO:
        """Return one page of the user's non-deleted conversations, most recently updated first.

        ``AppId`` of ``None`` means all apps. ``Page`` is 1-based (values below 1 act as 1).
        One extra row is fetched to compute ``hasMore``.
        """
        async with GetAsyncSession() as session:
            rows = (
                await session.execute(
                    text(
                        """
                        SELECT conversation_id, name, introduction, created_at, updated_at
                        FROM rag_conversation
                        WHERE user_id = :user_id
                          AND deleted_at IS NULL
                          AND (CAST(:app_id AS BIGINT) IS NULL OR app_id = CAST(:app_id AS BIGINT))
                        ORDER BY updated_at DESC
                        OFFSET :offset LIMIT :limit
                        """
                    ),
                    {
                        "user_id": CurrentUserId,
                        "app_id": AppId,
                        "offset": max(Page - 1, 0) * PageSize,
                        "limit": PageSize + 1,
                    },
                )
            ).mappings().all()
        has_more = len(rows) > PageSize
        return RagConversationPageVO(
            data=[
                RagConversationItemVO(
                    id=row["conversation_id"],
                    name=row["name"],
                    introduction=row.get("introduction") or "",
                    createdAt=_epoch_seconds(row.get("created_at")),
                    updatedAt=_epoch_seconds(row.get("updated_at")),
                )
                for row in rows[:PageSize]
            ],
            hasMore=has_more,
            limit=PageSize,
        )

    async def GetConversationMessages(self, CurrentUserId: int, ConversationId: str, Page: int, PageSize: int) -> RagMessagePageVO:
        """Return one page of question/answer pairs from a conversation the user owns.

        Paging is over raw ``rag_message`` rows in creation order (``Page`` is 1-based).
        Each user row is paired with the assistant row immediately after it.

        Pairing may look one row past the page, so a reply is not lost when the page
        boundary falls between a question and its answer. That reply then opens the
        next page as an orphan assistant row and is skipped there. Any other assistant
        row without a preceding question is skipped the same way.

        Raises:
            LeauditException: 404/403 from ``_ensure_conversation_owner``.
        """
        await self._ensure_conversation_owner(CurrentUserId, ConversationId)
        async with GetAsyncSession() as session:
            rows = (
                await session.execute(
                    text(
                        """
                        SELECT message_id, role, content, sources, feedback, created_at
                        FROM rag_message
                        WHERE conversation_id = :conversation_id
                        ORDER BY created_at ASC
                        OFFSET :offset LIMIT :limit
                        """
                    ),
                    {
                        "conversation_id": ConversationId,
                        "offset": max(Page - 1, 0) * PageSize,
                        # Two extra rows: one to pair a reply that crosses the page
                        # boundary, one more to tell whether anything follows it.
                        "limit": PageSize + 2,
                    },
                )
            ).mappings().all()

        page_end = min(len(rows), max(PageSize, 0))
        data: list[RagMessageItemVO] = []
        idx = 0
        while idx < page_end:
            row = rows[idx]
            if row["role"] != "user":
                idx += 1
                continue
            following = rows[idx + 1] if idx + 1 < len(rows) else None
            answer = following if following is not None and following["role"] == "assistant" else None
            data.append(
                RagMessageItemVO(
                    # The assistant message id is the one feedback is attached to.
                    id=(answer["message_id"] if answer else row["message_id"]),
                    conversationId=ConversationId,
                    query=row["content"],
                    answer=answer["content"] if answer else "",
                    feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None),
                    retrieverResources=(answer.get("sources") if answer else None),
                    createdAt=_epoch_seconds(row.get("created_at")),
                )
            )
            idx += 2 if answer else 1
        # ``idx`` is the number of rows consumed (possibly one past the page).
        has_more = len(rows) > idx
        return RagMessagePageVO(data=data, hasMore=has_more, limit=PageSize)

    async def RenameConversation(self, CurrentUserId: int, ConversationId: str, Body: RagConversationRenameDTO) -> RagConversationRenameVO:
        """Rename a conversation owned by the user and bump its ``updated_at``.

        Raises:
            LeauditException: 404/403 from ``_ensure_conversation_owner``.
        """
        await self._ensure_conversation_owner(CurrentUserId, ConversationId)
        async with GetAsyncSession() as session:
            async with session.begin():
                await session.execute(
                    text(
                        "UPDATE rag_conversation SET name = :name, updated_at = NOW() WHERE conversation_id = :conversation_id"
                    ),
                    {"name": Body.name, "conversation_id": ConversationId},
                )
        return RagConversationRenameVO(result="success", name=Body.name)

    async def DeleteConversation(self, CurrentUserId: int, ConversationId: str) -> RagOperationResultVO:
        """Soft-delete a conversation owned by the user by setting ``deleted_at``.

        Raises:
            LeauditException: 404/403 from ``_ensure_conversation_owner``.
        """
        await self._ensure_conversation_owner(CurrentUserId, ConversationId)
        async with GetAsyncSession() as session:
            async with session.begin():
                await session.execute(
                    text(
                        "UPDATE rag_conversation SET deleted_at = NOW(), updated_at = NOW() WHERE conversation_id = :conversation_id"
                    ),
                    {"conversation_id": ConversationId},
                )
        return RagOperationResultVO(result="success")

    async def UpdateFeedback(self, CurrentUserId: int, MessageId: str, Body: RagMessageFeedbackDTO) -> RagOperationResultVO:
        """Set ``feedback`` on a message whose (non-deleted) conversation the user owns.

        The ownership check and the update run in one transaction, which is committed
        on success, like every other write in this service.

        Raises:
            LeauditException: 404 if the message is unknown or its conversation is
                deleted; 403 if the conversation belongs to another user.
        """
        async with GetAsyncSession() as session:
            async with session.begin():
                owner = (
                    await session.execute(
                        text(
                            """
                            SELECT c.user_id
                            FROM rag_message m
                            JOIN rag_conversation c ON c.conversation_id = m.conversation_id
                            WHERE m.message_id = :message_id AND c.deleted_at IS NULL
                            LIMIT 1
                            """
                        ),
                        {"message_id": MessageId},
                    )
                ).scalar_one_or_none()
                if owner is None:
                    raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在")
                if int(owner) != CurrentUserId:
                    raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权修改该消息反馈")
                await session.execute(
                    text("UPDATE rag_message SET feedback = :feedback WHERE message_id = :message_id"),
                    {"feedback": Body.rating, "message_id": MessageId},
                )
        return RagOperationResultVO(result="success")

    async def GetAppParameters(
        self,
        CurrentUserId: int,
        UserArea: str | None,
        UserRole: str | None,
        AppId: int | None,
    ) -> RagAppParametersVO:
        """Return the opening statement and up to six suggested questions of the resolved app.

        Returns an empty ``RagAppParametersVO()`` when no app is visible. Suggested
        questions may be stored as JSON text or arrive already decoded; anything that
        is not a list is ignored.
        """
        app = await self._resolve_app(AppId, UserArea, UserRole)
        if not app:
            return RagAppParametersVO()
        suggested = _coerce_json(app.get("suggested_questions"), [])
        if not isinstance(suggested, list):
            suggested = []
        return RagAppParametersVO(
            openingStatement=app.get("opening_statement") or "",
            suggestedQuestions=[str(item) for item in suggested[:_MAX_SUGGESTED_QUESTIONS]],
            userInputForm=[],
            fileUpload={"image": {"enabled": False}},
        )

    # --------------------------------------------------------------- helpers

    async def _load_apps(self, user_area: str | None, user_role: str | None, only_default: bool) -> list[RagChatAppVO]:
        """Load enabled, non-deleted apps visible to the user, in display order.

        The SQL visibility rule mirrors ``_app_visible``:
        - provincial admins see everything;
        - otherwise an app is visible if its area (NULL treated as ``''``) is the
          user's area, ``'省级'`` or ``''``;
        - or if its dataset is public.
        """
        sql = (
            """
            SELECT a.id, a.name, a.description, a.is_default
            FROM rag_chat_app a
            LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
            WHERE a.deleted_at IS NULL
              AND a.status = 1
              AND (:only_default = FALSE OR a.is_default = TRUE)
              AND (
                    :is_provincial = TRUE
                    OR COALESCE(a.area, '') IN (:user_area, '省级', '')
                    OR COALESCE(d.is_public, FALSE) = TRUE
              )
            """
            + _APP_ORDER_BY
        )
        async with GetAsyncSession() as session:
            rows = (
                await session.execute(
                    text(sql),
                    {
                        "only_default": only_default,
                        "is_provincial": user_role == _PROVINCIAL_ADMIN_ROLE,
                        "user_area": user_area or "",
                    },
                )
            ).mappings().all()
        return [
            RagChatAppVO(
                appId=str(row["id"]),
                appName=row["name"],
                description=row.get("description") or "",
                isDefault=bool(row.get("is_default")),
            )
            for row in rows
        ]

    async def _resolve_app(self, app_id: int | None, user_area: str | None, user_role: str | None) -> dict | None:
        """Pick the app to serve a request, returning its row (plus dataset info) as a dict.

        Preference order:
        1. the requested ``app_id``, if visible to the user;
        2. the first visible default app;
        3. the first visible app of any kind.

        Every candidate is checked, so a hidden app earlier in the order no longer
        masks a visible one later. Returns ``None`` when the user can see no app.
        """
        base_sql = (
            """
            SELECT a.id, a.name, a.description, a.area, a.dataset_id, a.system_prompt, a.llm_model,
                   a.temperature, a.max_tokens, a.opening_statement, a.suggested_questions, a.is_default,
                   COALESCE(d.is_public, FALSE) AS dataset_public,
                   COALESCE(d.name, '') AS dataset_name
            FROM rag_chat_app a
            LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
            WHERE a.deleted_at IS NULL
              AND a.status = 1
            """
        )
        async with GetAsyncSession() as session:
            if app_id is not None:
                row = (
                    await session.execute(
                        text(base_sql + " AND a.id = :app_id LIMIT 1"),
                        {"app_id": app_id},
                    )
                ).mappings().first()
                if row and self._app_visible(row, user_area, user_role):
                    return dict(row)
            for extra_filter in (" AND a.is_default = TRUE", ""):
                rows = (
                    await session.execute(text(base_sql + extra_filter + _APP_ORDER_BY))
                ).mappings().all()
                for row in rows:
                    if self._app_visible(row, user_area, user_role):
                        return dict(row)
        return None

    def _app_visible(self, row: Mapping[str, object], user_area: str | None, user_role: str | None) -> bool:
        """Return whether an app row is visible to a user (same rule as the SQL in ``_load_apps``)."""
        if user_role == _PROVINCIAL_ADMIN_ROLE:
            return True
        area = row.get("area") or ""
        return area in ("", "省级", user_area or "") or bool(row.get("dataset_public"))

    async def _ensure_conversation(self, user_id: int, conversation_id: str | None, app_id: int | None) -> str:
        """Return the id of the conversation to use, creating a new one when needed.

        An existing, non-deleted conversation is reused if ``conversation_id`` is given
        and is not the ``"-1"`` sentinel. If it is absent, ``"-1"``, unknown or deleted,
        a fresh conversation named ``'新对话'`` is inserted for ``user_id`` and ``app_id``.

        Raises:
            LeauditException: 403 if the existing conversation belongs to another user.
        """
        if conversation_id and conversation_id != "-1":
            async with GetAsyncSession() as session:
                row = (
                    await session.execute(
                        text(
                            """
                            SELECT conversation_id, user_id
                            FROM rag_conversation
                            WHERE conversation_id = :conversation_id AND deleted_at IS NULL
                            LIMIT 1
                            """
                        ),
                        {"conversation_id": conversation_id},
                    )
                ).mappings().first()
            if row:
                if int(row["user_id"]) != user_id:
                    raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权使用该会话")
                return str(row["conversation_id"])

        new_conversation_id = str(uuid.uuid4())
        async with GetAsyncSession() as session:
            async with session.begin():
                await session.execute(
                    text(
                        """
                        INSERT INTO rag_conversation (conversation_id, user_id, app_id, name, introduction)
                        VALUES (:conversation_id, :user_id, :app_id, '新对话', '')
                        """
                    ),
                    {"conversation_id": new_conversation_id, "user_id": user_id, "app_id": app_id},
                )
        return new_conversation_id

    async def _ensure_conversation_owner(self, user_id: int, conversation_id: str) -> None:
        """Verify that a non-deleted conversation exists and belongs to ``user_id``.

        Raises:
            LeauditException: 404 if missing or deleted; 403 if owned by another user.
        """
        async with GetAsyncSession() as session:
            owner = (
                await session.execute(
                    text(
                        "SELECT user_id FROM rag_conversation WHERE conversation_id = :conversation_id AND deleted_at IS NULL LIMIT 1"
                    ),
                    {"conversation_id": conversation_id},
                )
            ).scalar_one_or_none()
        if owner is None:
            raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "会话不存在")
        if int(owner) != user_id:
            raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该会话")

    async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
        """Retrieve up to ``top_k`` context chunks for ``query`` from the dataset's Chroma collection.

        Returns ``(chunks, dataset_name)``. Retrieval is best effort:
        - no dataset id, or an unknown/deleted dataset, yields ``([], "")``;
        - any vector-store failure yields ``([], dataset_name)`` and is logged.

        ``score`` is ``1 - distance`` as reported by Chroma. When
        ``score_threshold_enabled`` is set, chunks scoring below the threshold are dropped.
        """
        if not dataset_id:
            return [], ""
        async with GetAsyncSession() as session:
            dataset = (
                await session.execute(
                    text(
                        """
                        SELECT id, name, collection_name, retrieval_model
                        FROM rag_dataset
                        WHERE id = :dataset_id AND deleted_at IS NULL
                        LIMIT 1
                        """
                    ),
                    {"dataset_id": dataset_id},
                )
            ).mappings().first()
        if not dataset:
            return [], ""
        dataset_name = dataset.get("name") or ""

        retrieval_model = _coerce_json(dataset.get("retrieval_model"), {})
        if not isinstance(retrieval_model, dict):
            retrieval_model = {}
        try:
            top_k = int(retrieval_model.get("top_k") or 5)
        except (TypeError, ValueError):
            top_k = 5
        score_threshold = None
        if retrieval_model.get("score_threshold_enabled"):
            try:
                score_threshold = float(retrieval_model.get("score_threshold"))
            except (TypeError, ValueError):
                score_threshold = None

        try:
            # Imported lazily so a missing/broken vector-store setup degrades to "no context".
            from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
        except Exception:
            logger.warning("Chroma client unavailable; answering without retrieval context", exc_info=True)
            return [], dataset_name

        try:
            collection = get_chroma().get_or_create_collection(dataset["collection_name"])
            result = collection.query(query_texts=[query], n_results=max(top_k, 1))
            docs = (result.get("documents") or [[]])[0]
            metas = (result.get("metadatas") or [[]])[0]
            distances = (result.get("distances") or [[]])[0]
            chunks: list[dict] = []
            for idx, doc in enumerate(docs):
                # A per-document metadata entry may be None; treat it as empty.
                meta = (metas[idx] if idx < len(metas) else None) or {}
                dist = distances[idx] if idx < len(distances) else 0.0
                score = 1 - float(dist or 0.0)
                if score_threshold is not None and score < score_threshold:
                    continue
                chunks.append(
                    {
                        "id": str(meta.get("id") or idx),
                        "text": doc,
                        "source": meta.get("source") or meta.get("document_name") or dataset_name,
                        "score": score,
                        "chunk_index": idx,
                        "document_name": meta.get("document_name") or meta.get("source") or "",
                    }
                )
            chunks = await self._hydrate_document_hits(dataset_id, chunks)
            return chunks[:top_k], dataset_name
        except Exception:
            logger.warning("Context retrieval failed for dataset %s", dataset_id, exc_info=True)
            return [], dataset_name

    def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]:
        """Convert retrieved chunks into the ``sources`` records stored with the assistant message.

        Positions are 1-based and ``content`` is truncated to 500 characters.
        """
        return [
            {
                "position": index + 1,
                "dataset_id": str(chunk.get("dataset_id") or ""),
                "dataset_name": dataset_name,
                "document_id": str(chunk.get("document_id") or ""),
                "document_name": chunk.get("document_name") or chunk.get("source", ""),
                "data_source_type": "upload_file",
                "segment_id": chunk.get("id", ""),
                "retriever_from": "rag",
                "score": round(chunk.get("score", 0.0), 4),
                "hit_count": chunk.get("hit_count", 0),
                "word_count": len(chunk.get("text", "")),
                "segment_position": index + 1,
                "index_node_hash": "",
                "content": chunk.get("text", "")[:500],
                "page": None,
            }
            for index, chunk in enumerate(context_chunks)
        ]

    async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict]) -> list[dict]:
        """Attach ``rag_document`` info to chunks, drop disabled documents, and bump hit counts.

        Chunks are matched to documents of ``dataset_id`` by name: the chunk's
        ``document_name`` (or ``source``) against ``original_name``.
        - Matched chunks gain ``document_id``, ``dataset_id``, ``document_name`` and
          ``hit_count`` (the value read before this call's increment).
        - Chunks of disabled documents are removed.
        - Unmatched chunks are kept unchanged.

        Each matched, enabled document's ``hit_count`` is incremented once per call,
        in its own committed transaction. Chunk dicts are modified in place.
        """

        def source_name(chunk: dict) -> str:
            return str(chunk.get("document_name") or chunk.get("source") or "").strip()

        source_names = sorted({name for name in map(source_name, chunks) if name})
        if not source_names:
            return chunks

        async with GetAsyncSession() as session:
            rows = (
                await session.execute(
                    text(
                        """
                        SELECT id, original_name, enabled, hit_count
                        FROM rag_document
                        WHERE dataset_id = :dataset_id
                          AND deleted_at IS NULL
                          AND original_name = ANY(:source_names)
                        """
                    ),
                    {
                        "dataset_id": dataset_id,
                        "source_names": source_names,
                    },
                )
            ).mappings().all()

        document_map = {str(row["original_name"]): row for row in rows}
        visible_chunks: list[dict] = []
        hit_document_ids: set[int] = set()
        for chunk in chunks:
            document = document_map.get(source_name(chunk))
            if document is None:
                visible_chunks.append(chunk)
                continue
            if not bool(document.get("enabled")):
                continue
            chunk["document_id"] = document["id"]
            chunk["dataset_id"] = dataset_id
            chunk["document_name"] = document["original_name"]
            chunk["hit_count"] = document.get("hit_count") or 0
            hit_document_ids.add(int(document["id"]))
            visible_chunks.append(chunk)

        if hit_document_ids:
            async with GetAsyncSession() as session:
                async with session.begin():
                    await session.execute(
                        text(
                            """
                            UPDATE rag_document
                            SET hit_count = hit_count + 1,
                                updated_at = NOW()
                            WHERE id = ANY(:document_ids)
                            """
                        ),
                        {"document_ids": sorted(hit_document_ids)},
                    )
        return visible_chunks

    def _parse_sse_event(self, chunk: str) -> dict | None:
        """Decode the JSON object carried by the ``data:`` lines of one SSE chunk.

        Multiple ``data:`` lines are joined with newlines, skipping blank ones.
        Returns ``None`` when:
        - the chunk has no ``data:`` payload;
        - the payload is ``[DONE]``;
        - the payload is not valid JSON;
        - it decodes to something other than an object.
        """
        data_lines: list[str] = []
        for line in chunk.splitlines():
            if line.startswith("data: "):
                data_lines.append(line[6:])
            elif line.startswith("data:"):
                data_lines.append(line[5:].lstrip())
        if not data_lines:
            return None
        payload = "\n".join(part for part in data_lines if part.strip()).strip()
        if not payload or payload == "[DONE]":
            return None
        # Drop trailing literal "\n" escape sequences (backslash followed by "n", not newlines).
        payload = payload.removesuffix("\\n\\n").removesuffix("\\n").strip()
        try:
            data = json.loads(payload)
        except json.JSONDecodeError:
            return None
        return data if isinstance(data, dict) else None