from __future__ import annotations import asyncio import json import re import time import uuid from typing import AsyncGenerator import httpx from sqlalchemy import text from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import ( RagConversationRenameDTO, RagMessageFeedbackDTO, RagStopMessageDTO, ) from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import ( RagAppParametersVO, RagChatAppListVO, RagChatAppVO, RagConversationItemVO, RagConversationPageVO, RagConversationRenameVO, RagMessageItemVO, RagMessagePageVO, RagOperationResultVO, ) from fastapi_modules.fastapi_leaudit.rag_engine.config import ( RAG_CONFIG, build_openai_chat_completions_url, build_openai_embeddings_url, ) from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService DEFAULT_CONVERSATION_NAME = "新对话" DEFAULT_TITLE_SOURCE = "default" AUTO_TITLE_SOURCE = "auto" MANUAL_TITLE_SOURCE = "manual" class RagChatServiceImpl(IRagChatService): _message_tasks: dict[str, asyncio.Task] = {} _task_events: dict[str, list[dict]] = {} _task_done: dict[str, bool] = {} _task_locks: dict[str, asyncio.Lock] = {} _title_tasks: dict[str, asyncio.Task] = {} async def GetApps(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppListVO: apps = await self._load_apps(UserArea, UserRole, only_default=False) return RagChatAppListVO(data=apps, total=len(apps)) async def GetDefaultApp(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppVO | None: apps = await self._load_apps(UserArea, UserRole, only_default=True) if apps: return apps[0] all_apps = await self._load_apps(UserArea, UserRole, only_default=False) return all_apps[0] if all_apps else None async def SendMessage( self, CurrentUserId: int, UserName: str, UserArea: str | None, UserRole: str | None, Query: str, ConversationId: str | None, AppId: int | None, ) -> AsyncGenerator[bytes, None]: if not Query.strip(): raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "问题不能为空") app = await self._resolve_app(AppId, UserArea, UserRole) if not app: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用") conversationId = await self._ensure_conversation(CurrentUserId, ConversationId, app["id"]) messageId = str(uuid.uuid4()) taskId = str(uuid.uuid4()) is_new_conversation = not ConversationId or ConversationId == "-1" async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata) VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, '{}'::jsonb) """ ), { "message_id": str(uuid.uuid4()), "conversation_id": conversationId, "content": Query, }, ) await session.execute( text( """ INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata) VALUES (:message_id, :conversation_id, 'assistant', '', '[]'::jsonb, CAST(:metadata AS jsonb)) """ ), { "message_id": messageId, "conversation_id": conversationId, "metadata": json.dumps({"status": "running", "task_id": taskId}, ensure_ascii=False), }, ) await session.execute( text( "UPDATE rag_conversation SET updated_at = NOW() WHERE conversation_id = :conversation_id" ), {"conversation_id": conversationId}, ) await self._start_message_task( task_id=taskId, conversation_id=conversationId, message_id=messageId, query=Query, app=app, ) event_index = 0 initial_events: list[dict] = [] if is_new_conversation: initial_events.append( { "event": "conversation_created", "conversation_id": conversationId, "message_id": messageId, "task_id": taskId, } ) while True: if event_index < len(initial_events): payload = initial_events[event_index] event_index += 1 yield self._format_sse(payload) continue events = self._task_events.get(taskId, []) if event_index - len(initial_events) < len(events): payload = events[event_index - len(initial_events)] event_index += 1 yield self._format_sse(payload) continue if self._task_done.get(taskId): break await asyncio.sleep(0.05) async def GetConversations(self, CurrentUserId: int, AppId: int | None, Page: int, PageSize: int) -> RagConversationPageVO: async with GetAsyncSession() as session: rows = ( await session.execute( text( """ SELECT conversation_id, name, introduction, created_at, updated_at , COALESCE(title_source, 'default') AS title_source , COALESCE(EXTRACT(EPOCH FROM last_message_at), 0) AS last_message_at FROM rag_conversation WHERE user_id = :user_id AND deleted_at IS NULL AND (CAST(:app_id AS BIGINT) IS NULL OR app_id = CAST(:app_id AS BIGINT)) ORDER BY COALESCE(last_message_at, updated_at) DESC, updated_at DESC OFFSET :offset LIMIT :limit """ ), { "user_id": CurrentUserId, "app_id": AppId, "offset": max(Page - 1, 0) * PageSize, "limit": PageSize + 1, }, ) ).mappings().all() has_more = len(rows) > PageSize items = rows[:PageSize] return RagConversationPageVO( data=[ RagConversationItemVO( id=row["conversation_id"], name=row["name"], introduction=row.get("introduction") or "", titleSource=str(row.get("title_source") or DEFAULT_TITLE_SOURCE), createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, lastMessageAt=int(float(row.get("last_message_at") or 0)), ) for row in items ], hasMore=has_more, limit=PageSize, ) async def GetConversationMessages(self, CurrentUserId: int, ConversationId: str, Page: int, PageSize: int) -> RagMessagePageVO: await self._ensure_conversation_owner(CurrentUserId, ConversationId) async with GetAsyncSession() as session: rows = ( await session.execute( text( """ SELECT message_id, role, content, sources, metadata, feedback, created_at FROM rag_message WHERE conversation_id = :conversation_id ORDER BY created_at ASC, CASE role WHEN 'user' THEN 0 WHEN 'assistant' THEN 1 ELSE 2 END ASC, message_id ASC OFFSET :offset LIMIT :limit """ ), { "conversation_id": ConversationId, "offset": max(Page - 1, 0) * PageSize, "limit": PageSize + 1, }, ) ).mappings().all() has_more = len(rows) > PageSize items = rows[:PageSize] data: list[RagMessageItemVO] = [] idx = 0 while idx < len(items): row = items[idx] if row["role"] == "user": answer = items[idx + 1] if idx + 1 < len(items) and items[idx + 1]["role"] == "assistant" else None answer_metadata = dict((answer.get("metadata") if answer else None) or {}) answer_status = str(answer_metadata.get("status") or ("completed" if answer else "running")) answer_content = (answer.get("content") if answer else None) or "" if answer: task_id = str(answer_metadata.get("task_id") or "").strip() reconstructed_content = self._rebuild_message_content_from_events(task_id) if task_id else "" if reconstructed_content and len(reconstructed_content) >= len(answer_content): if reconstructed_content != answer_content: await self._update_message_progress( conversation_id=ConversationId, message_id=answer["message_id"], content=reconstructed_content, metadata=answer_metadata, ) answer_content = reconstructed_content normalized_status = await self._resolve_persisted_message_status( conversation_id=ConversationId, message_id=answer["message_id"], content=answer_content, metadata=answer_metadata, ) if normalized_status != answer_status: answer_status = normalized_status answer_metadata["status"] = normalized_status data.append( RagMessageItemVO( id=(answer["message_id"] if answer else row["message_id"]), conversationId=ConversationId, query=row["content"], answer=answer_content if answer else "", feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None), retrieverResources=(answer.get("sources") if answer else None), suggestedQuestions=[str(item) for item in (answer_metadata.get("suggested_questions") or []) if str(item).strip()], status=answer_status, createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, ) ) idx += 2 if answer else 1 else: idx += 1 return RagMessagePageVO(data=data, hasMore=has_more, limit=PageSize) async def _resolve_persisted_message_status( self, *, conversation_id: str, message_id: str, content: str, metadata: dict, ) -> str: status = str(metadata.get("status") or "completed") if status != "running": return status task_id = str(metadata.get("task_id") or "").strip() task = self._message_tasks.get(task_id) if task_id else None task_done = self._task_done.get(task_id, False) if task_id else False if task and not task.done() and not task_done: return "running" normalized_status = "completed" if content.strip() else "error" normalized_metadata = { **metadata, "status": normalized_status, } if normalized_status == "error" and not normalized_metadata.get("error"): normalized_metadata["error"] = "生成任务已结束,但未产出有效回答" await self._update_message_progress( conversation_id=conversation_id, message_id=message_id, content=content, metadata=normalized_metadata, ) return normalized_status def _rebuild_message_content_from_events(self, task_id: str) -> str: if not task_id: return "" chunks: list[str] = [] for event in self._task_events.get(task_id, []): if event.get("event") != "message": continue answer = event.get("answer") if isinstance(answer, str) and answer: chunks.append(answer) return "".join(chunks) async def RenameConversation(self, CurrentUserId: int, ConversationId: str, Body: RagConversationRenameDTO) -> RagConversationRenameVO: await self._ensure_conversation_owner(CurrentUserId, ConversationId) final_name = Body.name.strip() if not final_name: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "会话名称不能为空") async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_conversation SET name = :name, title_source = 'manual', title_generation_status = CASE WHEN COALESCE(title_generation_status, 'idle') = 'running' THEN 'succeeded' ELSE COALESCE(title_generation_status, 'idle') END, title_generation_error = NULL, updated_at = NOW() WHERE conversation_id = :conversation_id """ ), {"name": final_name, "conversation_id": ConversationId}, ) return RagConversationRenameVO(result="success", name=final_name) async def DeleteConversation(self, CurrentUserId: int, ConversationId: str) -> RagOperationResultVO: await self._ensure_conversation_owner(CurrentUserId, ConversationId) async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( "UPDATE rag_conversation SET deleted_at = NOW(), updated_at = NOW() WHERE conversation_id = :conversation_id" ), {"conversation_id": ConversationId}, ) return RagOperationResultVO(result="success") async def UpdateFeedback(self, CurrentUserId: int, MessageId: str, Body: RagMessageFeedbackDTO) -> RagOperationResultVO: async with GetAsyncSession() as session: owner = ( await session.execute( text( """ SELECT c.user_id FROM rag_message m JOIN rag_conversation c ON c.conversation_id = m.conversation_id WHERE m.message_id = :message_id AND c.deleted_at IS NULL LIMIT 1 """ ), {"message_id": MessageId}, ) ).scalar_one_or_none() if owner is None: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在") if int(owner) != CurrentUserId: raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权修改该消息反馈") await session.execute( text("UPDATE rag_message SET feedback = :feedback WHERE message_id = :message_id"), {"feedback": Body.rating, "message_id": MessageId}, ) return RagOperationResultVO(result="success") async def StopMessage(self, CurrentUserId: int, MessageId: str, Body: RagStopMessageDTO | None = None) -> RagOperationResultVO: async with GetAsyncSession() as session: row = ( await session.execute( text( """ SELECT m.metadata, c.user_id FROM rag_message m JOIN rag_conversation c ON c.conversation_id = m.conversation_id WHERE m.message_id = :message_id AND c.deleted_at IS NULL LIMIT 1 """ ), {"message_id": MessageId}, ) ).mappings().first() if not row: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在") if int(row["user_id"]) != CurrentUserId: raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权停止该消息") metadata = row.get("metadata") or {} task_id = str(Body.taskId or metadata.get("task_id") or "").strip() task = self._message_tasks.get(task_id) if task_id else None if task and not task.done(): task.cancel() return RagOperationResultVO(result="success") async def GetAppParameters( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, AppId: int | None, ) -> RagAppParametersVO: app = await self._resolve_app(AppId, UserArea, UserRole) if not app: return RagAppParametersVO() try: suggested = json.loads(app.get("suggested_questions") or "[]") if not isinstance(suggested, list): suggested = [] except Exception: suggested = [] return RagAppParametersVO( openingStatement=app.get("opening_statement") or "", suggestedQuestions=[str(item) for item in suggested[:6]], userInputForm=[], fileUpload={"image": {"enabled": False}}, ) async def _load_apps(self, user_area: str | None, user_role: str | None, only_default: bool) -> list[RagChatAppVO]: async with GetAsyncSession() as session: sql = ( """ SELECT a.id, a.name, a.description, a.is_default FROM rag_chat_app a LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL WHERE a.deleted_at IS NULL AND a.status = 1 AND (:only_default = FALSE OR a.is_default = TRUE) AND ( :is_provincial = TRUE OR a.area IN (:user_area, '省级', '') OR COALESCE(d.is_public, FALSE) = TRUE ) ORDER BY a.sort_order ASC, a.created_at DESC """ ) rows = ( await session.execute( text(sql), { "only_default": only_default, "is_provincial": user_role == "provincial_admin", "user_area": user_area or "", }, ) ).mappings().all() return [ RagChatAppVO( appId=str(row["id"]), appName=row["name"], description=row.get("description") or "", isDefault=bool(row.get("is_default")), ) for row in rows ] async def _resolve_app(self, app_id: int | None, user_area: str | None, user_role: str | None) -> dict | None: async with GetAsyncSession() as session: params = { "app_id": app_id, "user_area": user_area or "", "is_provincial": user_role == "provincial_admin", } base_sql = ( """ SELECT a.id, a.name, a.description, a.area, a.dataset_id, a.system_prompt, a.llm_model, a.temperature, a.max_tokens, a.opening_statement, a.suggested_questions, a.is_default, COALESCE(d.is_public, FALSE) AS dataset_public, COALESCE(d.name, '') AS dataset_name FROM rag_chat_app a LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL WHERE a.deleted_at IS NULL AND a.status = 1 """ ) if app_id is not None: row = ( await session.execute( text(base_sql + " AND a.id = :app_id LIMIT 1"), params, ) ).mappings().first() if row and self._app_visible(row, user_area, user_role): return dict(row) row = ( await session.execute( text(base_sql + " AND a.is_default = TRUE ORDER BY a.sort_order ASC, a.created_at DESC LIMIT 1"), params, ) ).mappings().first() if row and self._app_visible(row, user_area, user_role): return dict(row) row = ( await session.execute( text(base_sql + " ORDER BY a.sort_order ASC, a.created_at DESC LIMIT 1"), params, ) ).mappings().first() return dict(row) if row and self._app_visible(row, user_area, user_role) else None def _app_visible(self, row: dict, user_area: str | None, user_role: str | None) -> bool: if user_role == "provincial_admin": return True area = row.get("area") or "" return area in ("", "省级", user_area or "") or bool(row.get("dataset_public")) async def _ensure_conversation(self, user_id: int, conversation_id: str | None, app_id: int | None) -> str: if conversation_id and conversation_id != "-1": async with GetAsyncSession() as session: row = ( await session.execute( text( """ SELECT conversation_id, user_id FROM rag_conversation WHERE conversation_id = :conversation_id AND deleted_at IS NULL LIMIT 1 """ ), {"conversation_id": conversation_id}, ) ).mappings().first() if row: if int(row["user_id"]) != user_id: raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权使用该会话") return str(row["conversation_id"]) conversation_id = str(uuid.uuid4()) async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ INSERT INTO rag_conversation (conversation_id, user_id, app_id, name, introduction) VALUES (:conversation_id, :user_id, :app_id, :name, '') """ ), { "conversation_id": conversation_id, "user_id": user_id, "app_id": app_id, "name": DEFAULT_CONVERSATION_NAME, }, ) return conversation_id async def _ensure_conversation_owner(self, user_id: int, conversation_id: str) -> None: async with GetAsyncSession() as session: owner = ( await session.execute( text( "SELECT user_id FROM rag_conversation WHERE conversation_id = :conversation_id AND deleted_at IS NULL LIMIT 1" ), {"conversation_id": conversation_id}, ) ).scalar_one_or_none() if owner is None: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "会话不存在") if int(owner) != user_id: raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该会话") async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]: if not dataset_id: return [], "" async with GetAsyncSession() as session: dataset = ( await session.execute( text( """ SELECT id, name, collection_name, retrieval_model, embedding_model FROM rag_dataset WHERE id = :dataset_id AND deleted_at IS NULL LIMIT 1 """ ), {"dataset_id": dataset_id}, ) ).mappings().first() if not dataset: return [], "" retrieval_model = dataset.get("retrieval_model") or {} top_k = int(retrieval_model.get("top_k") or 5) score_threshold = None if retrieval_model.get("score_threshold_enabled"): try: score_threshold = float(retrieval_model.get("score_threshold")) except (TypeError, ValueError): score_threshold = None try: query_embedding = await self._embed_texts([query], dataset.get("embedding_model") or "") collection = get_chroma().get_or_create_collection(dataset["collection_name"]) result = collection.query( query_embeddings=query_embedding, n_results=max(top_k, 1), include=["documents", "metadatas", "distances"], ) ids = (result.get("ids") or [[]])[0] if result.get("ids") else [] docs = (result.get("documents") or [[]])[0] metas = (result.get("metadatas") or [[]])[0] distances = (result.get("distances") or [[]])[0] chunks: list[dict] = [] for idx, doc in enumerate(docs): meta = metas[idx] if idx < len(metas) else {} dist = float(distances[idx]) if idx < len(distances) and distances[idx] is not None else 1.0 score = 1.0 / (1.0 + max(dist, 0.0)) if score_threshold is not None and score < score_threshold: continue chunks.append( { "id": str(ids[idx] if idx < len(ids) else meta.get("id") or idx), "text": doc, "source": meta.get("source") or meta.get("document_name") or dataset.get("name") or "", "score": score, "chunk_index": int(meta.get("chunk_index") or idx), "document_name": meta.get("document_name") or meta.get("source") or "", "document_id": meta.get("document_id"), "page": meta.get("page"), } ) chunks = await self._hydrate_document_hits(dataset_id, chunks) if chunks: return chunks[:top_k], dataset.get("name") or "" except Exception: pass try: chunks = await self._keyword_retrieve_context( dataset_id=dataset_id, collection_name=str(dataset["collection_name"]), dataset_name=str(dataset.get("name") or ""), query=query, top_k=top_k, score_threshold=score_threshold, ) return chunks[:top_k], dataset.get("name") or "" except Exception: return [], dataset.get("name") or "" async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]: embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"]) embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"] embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4" batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10)) if not embed_url or not embed_key: raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务") embeddings: list[list[float]] = [] async with httpx.AsyncClient(timeout=120.0) as client: for start in range(0, len(texts), batch_size): batch_texts = texts[start:start + batch_size] try: response = await client.post( embed_url, headers={ "Content-Type": "application/json", "Authorization": f"Bearer {embed_key}", }, json={"model": embed_model, "input": batch_texts}, ) response.raise_for_status() except httpx.HTTPStatusError as exc: error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}" raise LeauditException( StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, f"向量化服务调用失败: {error_message[:300]}", ) from exc payload = response.json() rows = payload.get("data") or [] batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")] if len(batch_embeddings) != len(batch_texts): raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") embeddings.extend(batch_embeddings) if len(embeddings) != len(texts): raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") return embeddings async def _start_message_task( self, *, task_id: str, conversation_id: str, message_id: str, query: str, app: dict, ) -> None: self._task_events[task_id] = [] self._task_done[task_id] = False self._task_locks.setdefault(task_id, asyncio.Lock()) task = asyncio.create_task( self._run_message_task( task_id=task_id, conversation_id=conversation_id, message_id=message_id, query=query, app=app, ) ) self._message_tasks[task_id] = task async def _run_message_task( self, *, task_id: str, conversation_id: str, message_id: str, query: str, app: dict, ) -> None: context_chunks: list[dict] = [] dataset_name = "" collected_answer = "" message_end_payload: dict | None = None final_status = "completed" error_payload: dict | None = None last_persisted_length = 0 last_persisted_at = time.monotonic() try: context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), query) async for chunk in generate_stream( query=query, context_chunks=context_chunks, conversation_id=conversation_id, message_id=message_id, system_prompt=app.get("system_prompt") or "", model=app.get("llm_model") or "", temperature=app.get("temperature"), max_tokens=app.get("max_tokens"), dataset_name=dataset_name, task_id=task_id, ): data = self._parse_sse_event(chunk) if not data: continue event = data.get("event") if event == "message": collected_answer += data.get("answer", "") now = time.monotonic() if len(collected_answer) - last_persisted_length >= 80 or now - last_persisted_at >= 0.5: await self._update_message_progress( conversation_id=conversation_id, message_id=message_id, content=collected_answer, metadata={"status": "running", "task_id": task_id}, ) last_persisted_length = len(collected_answer) last_persisted_at = now await self._append_task_event(task_id, data) continue if event == "message_end": message_end_payload = data continue if event == "error": final_status = "error" error_payload = data await self._append_task_event(task_id, data) continue await self._append_task_event(task_id, data) if final_status == "completed": followups: list[str] = [] try: followups = await generate_followups(query, collected_answer) except Exception: followups = [] if message_end_payload: message_end_payload.setdefault("metadata", {})["suggested_questions"] = followups await self._append_task_event(task_id, message_end_payload) await self._finalize_message_record( conversation_id=conversation_id, message_id=message_id, content=collected_answer, sources=self._build_sources(context_chunks, dataset_name), metadata={"suggested_questions": followups, "status": "completed", "task_id": task_id}, ) await self._maybe_schedule_auto_title( conversation_id=conversation_id, message_id=message_id, query=query, answer=collected_answer, ) else: await self._finalize_message_record( conversation_id=conversation_id, message_id=message_id, content=collected_answer, sources=self._build_sources(context_chunks, dataset_name), metadata={ "suggested_questions": [], "status": "error", "task_id": task_id, "error": (error_payload or {}).get("message", ""), }, ) except asyncio.CancelledError: final_status = "stopped" await self._append_task_event( task_id, { "event": "error", "task_id": task_id, "message_id": message_id, "code": "message_stopped", "message": "用户已停止回答", }, ) await self._finalize_message_record( conversation_id=conversation_id, message_id=message_id, content=collected_answer, sources=self._build_sources(context_chunks, dataset_name), metadata={"suggested_questions": [], "status": "stopped", "task_id": task_id}, ) raise except Exception as exc: final_status = "error" await self._append_task_event( task_id, { "event": "error", "task_id": task_id, "message_id": message_id, "code": "server_error", "message": str(exc), }, ) await self._finalize_message_record( conversation_id=conversation_id, message_id=message_id, content=collected_answer, sources=self._build_sources(context_chunks, dataset_name), metadata={"suggested_questions": [], "status": "error", "task_id": task_id, "error": str(exc)}, ) finally: self._task_done[task_id] = True self._message_tasks.pop(task_id, None) self._task_locks.pop(task_id, None) async def _append_task_event(self, task_id: str, payload: dict) -> None: lock = self._task_locks.setdefault(task_id, asyncio.Lock()) async with lock: self._task_events.setdefault(task_id, []).append(payload) async def _finalize_message_record( self, *, conversation_id: str, message_id: str, content: str, sources: list[dict], metadata: dict, ) -> None: async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_message SET content = :content, sources = CAST(:sources AS jsonb), metadata = CAST(:metadata AS jsonb) WHERE message_id = :message_id AND conversation_id = :conversation_id AND role = 'assistant' """ ), { "conversation_id": conversation_id, "message_id": message_id, "content": content, "sources": json.dumps(sources, ensure_ascii=False), "metadata": json.dumps(metadata, ensure_ascii=False), }, ) await session.execute( text( """ UPDATE rag_conversation SET updated_at = NOW(), last_message_at = NOW() WHERE conversation_id = :conversation_id """ ), {"conversation_id": conversation_id}, ) async def _update_message_progress( self, *, conversation_id: str, message_id: str, content: str, metadata: dict, ) -> None: async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_message SET content = :content, metadata = CAST(:metadata AS jsonb) WHERE message_id = :message_id AND conversation_id = :conversation_id AND role = 'assistant' """ ), { "conversation_id": conversation_id, "message_id": message_id, "content": content, "metadata": json.dumps(metadata, ensure_ascii=False), }, ) async def _maybe_schedule_auto_title( self, *, conversation_id: str, message_id: str, query: str, answer: str, ) -> None: normalized_query = (query or "").strip() normalized_answer = (answer or "").strip() if not normalized_query or not normalized_answer: return async with GetAsyncSession() as session: row = ( await session.execute( text( """ SELECT conversation_id, name, COALESCE(title_source, 'default') AS title_source, COALESCE(title_generation_status, 'idle') AS title_generation_status, first_answer_message_id FROM rag_conversation WHERE conversation_id = :conversation_id AND deleted_at IS NULL LIMIT 1 """ ), {"conversation_id": conversation_id}, ) ).mappings().first() if not row: return title_source = str(row.get("title_source") or DEFAULT_TITLE_SOURCE) if title_source == MANUAL_TITLE_SOURCE: return current_name = str(row.get("name") or "").strip() if current_name and current_name != DEFAULT_CONVERSATION_NAME and title_source != DEFAULT_TITLE_SOURCE: return current_status = str(row.get("title_generation_status") or "idle") if current_status in {"pending", "running", "succeeded"}: return if row.get("first_answer_message_id") and str(row.get("first_answer_message_id")) != message_id: return async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_conversation SET title_generation_status = 'pending', first_answer_message_id = COALESCE(first_answer_message_id, :message_id), title_generation_error = NULL WHERE conversation_id = :conversation_id AND deleted_at IS NULL AND COALESCE(title_source, 'default') = 'default' """ ), { "conversation_id": conversation_id, "message_id": message_id, }, ) existing_task = self._title_tasks.get(conversation_id) if existing_task and not existing_task.done(): return task = asyncio.create_task( self._run_auto_title_task( conversation_id=conversation_id, answer_message_id=message_id, query=normalized_query, answer=normalized_answer, ) ) self._title_tasks[conversation_id] = task async def _run_auto_title_task( self, *, conversation_id: str, answer_message_id: str, query: str, answer: str, ) -> None: try: async with GetAsyncSession() as session: row = ( await session.execute( text( """ SELECT name, COALESCE(title_source, 'default') AS title_source, COALESCE(title_generation_status, 'idle') AS title_generation_status, first_answer_message_id FROM rag_conversation WHERE conversation_id = :conversation_id AND deleted_at IS NULL LIMIT 1 """ ), {"conversation_id": conversation_id}, ) ).mappings().first() if not row: return if str(row.get("title_source") or DEFAULT_TITLE_SOURCE) == MANUAL_TITLE_SOURCE: return if row.get("first_answer_message_id") and str(row.get("first_answer_message_id")) != answer_message_id: return async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_conversation SET title_generation_status = 'running', title_generation_error = NULL WHERE conversation_id = :conversation_id AND deleted_at IS NULL AND COALESCE(title_source, 'default') = 'default' """ ), {"conversation_id": conversation_id}, ) generated_title = await self._generate_conversation_title(query=query, answer=answer) cleaned_title = self._sanitize_generated_title(generated_title) if not cleaned_title: cleaned_title = self._build_fallback_title(query=query, answer=answer) if not cleaned_title: raise ValueError("未生成有效标题") async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_conversation SET name = :name, title_source = 'auto', title_generation_status = 'succeeded', title_generated_at = NOW(), title_generation_error = NULL WHERE conversation_id = :conversation_id AND deleted_at IS NULL AND COALESCE(title_source, 'default') = 'default' AND ( name IS NULL OR BTRIM(name) = '' OR name = :default_name ) """ ), { "conversation_id": conversation_id, "name": cleaned_title, "default_name": DEFAULT_CONVERSATION_NAME, }, ) except Exception as exc: async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_conversation SET title_generation_status = 'failed', title_generation_error = :error WHERE conversation_id = :conversation_id AND deleted_at IS NULL AND COALESCE(title_source, 'default') = 'default' """ ), { "conversation_id": conversation_id, "error": str(exc)[:1000], }, ) finally: self._title_tasks.pop(conversation_id, None) async def _generate_conversation_title(self, *, query: str, answer: str) -> str: prompt = ( "请基于用户首轮提问和助手首轮回答,生成一个简洁、准确的中文会话标题。" "要求:" "1. 只输出标题本身;" "2. 不要标点结尾;" "3. 不要出现“关于”“用户询问”“问题解答”等空话;" "4. 优先 12-24 个中文字符,最长不超过 40 个字符。\\n" f"用户问题:{query[:500]}\\n" f"助手回答:{answer[:1500]}" ) async with httpx.AsyncClient(timeout=30.0) as client: resp = await client.post( build_openai_chat_completions_url(RAG_CONFIG["LLM_BASE_URL"]), json={ "model": RAG_CONFIG["LLM_MODEL"], "messages": [{"role": "user", "content": prompt}], "temperature": 0.2, "max_tokens": 80, "chat_template_kwargs": {"enable_thinking": False}, }, headers={ "Content-Type": "application/json", "Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}", }, ) resp.raise_for_status() content = resp.json()["choices"][0]["message"]["content"] return str(content or "").strip() def _sanitize_generated_title(self, value: str) -> str: title = str(value or "").strip() if not title: return "" title = title.replace("\r", " ").replace("\n", " ").strip() title = re.sub(r"^```[\w-]*", "", title).strip() title = title.replace("```", "").strip() title = re.sub(r'^[\"“”\'‘’\[\]()()【】]+', "", title) title = re.sub(r'[\"“”\'‘’\[\]()()【】]+$', "", title) title = re.sub(r"^(标题|会话标题)[::]\\s*", "", title) title = re.sub(r"\\s+", " ", title).strip() title = title.rstrip("。!?!?.;;,,::") if title in {"新对话", "会话标题", "标题"}: return "" if len(title) > 40: title = title[:40].rstrip(",,。;;:: ") return title def _build_fallback_title(self, *, query: str, answer: str) -> str: base = query.strip() or answer.strip() if not base: return "" base = re.sub(r"\\s+", " ", base) base = re.sub(r"^[请问帮我一下关于针对结合根据请解释说明一下::,,\\s]+", "", base) base = base.rstrip("。!?!?.;;,,::") if len(base) > 24: base = base[:24].rstrip(",,。;;:: ") return base async def _keyword_retrieve_context( self, *, dataset_id: int, collection_name: str, dataset_name: str, query: str, top_k: int, score_threshold: float | None, ) -> list[dict]: collection = get_chroma().get_or_create_collection(collection_name) raw = collection.get(include=["documents", "metadatas"]) ids = raw.get("ids") or [] docs = raw.get("documents") or [] metas = raw.get("metadatas") or [] terms = self._build_keyword_terms(query) if not terms: return [] scored_chunks: list[dict] = [] for idx, chunk_id in enumerate(ids): doc = docs[idx] if idx < len(docs) else "" meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {} score = self._score_keyword_chunk( query=query, terms=terms, content=doc or "", document_name=str(meta.get("document_name") or meta.get("source") or ""), ) if score <= 0: continue if score_threshold is not None and score < score_threshold: continue scored_chunks.append( { "id": str(chunk_id), "text": doc or "", "source": meta.get("source") or meta.get("document_name") or dataset_name, "score": score, "chunk_index": int(meta.get("chunk_index") or idx), "document_name": meta.get("document_name") or meta.get("source") or "", "document_id": meta.get("document_id"), "page": meta.get("page"), } ) scored_chunks.sort(key=lambda item: (-float(item.get("score") or 0.0), int(item.get("chunk_index") or 0))) hydrated = await self._hydrate_document_hits(dataset_id, scored_chunks[: max(top_k * 3, top_k)]) return hydrated[:top_k] def _build_keyword_terms(self, query: str) -> list[str]: normalized = self._normalize_keyword_query(query) spans = [item.strip() for item in re.findall(r"[\u4e00-\u9fffA-Za-z0-9]+", normalized) if item.strip()] if not spans: return [] stop_terms = { "什么", "请问", "一下", "有关", "关于", "如何", "哪些", "怎么", "是否", "规定", "办法", "条例", "法律", } terms: list[str] = [] for span in spans: if span in stop_terms: continue terms.append(span) if re.fullmatch(r"[\u4e00-\u9fff]+", span): for size in (2, 3, 4): if len(span) > size: for start in range(0, len(span) - size + 1): token = span[start:start + size] if token not in stop_terms: terms.append(token) unique_terms: list[str] = [] seen: set[str] = set() for term in sorted(terms, key=len, reverse=True): if term and term not in seen: unique_terms.append(term) seen.add(term) return unique_terms[:20] def _normalize_keyword_query(self, query: str) -> str: normalized = (query or "").strip().lower() patterns = [ "是什么", "什么是", "有哪些", "有什么", "是什么?", "是什么?", "请问", "介绍一下", "解释一下", "帮我分析", "帮我看看", ] for pattern in patterns: normalized = normalized.replace(pattern, " ") return re.sub(r"\s+", " ", normalized).strip() def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float: haystack = f"{document_name}\n{content}".lower() if not haystack: return 0.0 exact_query = self._normalize_keyword_query(query) if exact_query and exact_query in haystack: return 0.98 matched_weight = 0.0 total_weight = 0.0 name_bonus = 0.0 for term in terms: weight = float(max(len(term), 1) ** 2) total_weight += weight if term.lower() in haystack: matched_weight += weight if term.lower() in document_name.lower(): name_bonus += min(0.15, 0.03 * len(term)) if total_weight <= 0: return 0.0 score = (matched_weight / total_weight) + name_bonus return round(min(score, 0.99), 6) def _format_sse(self, payload: dict) -> bytes: return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8") def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]: return [ { "position": index + 1, "dataset_id": str(chunk.get("dataset_id") or ""), "dataset_name": dataset_name, "document_id": str(chunk.get("document_id") or ""), "document_name": chunk.get("document_name") or chunk.get("source", ""), "data_source_type": "upload_file", "segment_id": chunk.get("id", ""), "retriever_from": "rag", "score": round(chunk.get("score", 0.0), 4), "hit_count": chunk.get("hit_count", 0), "word_count": len(chunk.get("text", "")), "segment_position": index + 1, "index_node_hash": "", "content": chunk.get("text", "")[:500], "page": None, } for index, chunk in enumerate(context_chunks) ] async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict]) -> list[dict]: source_names = sorted( { str(chunk.get("document_name") or chunk.get("source") or "").strip() for chunk in chunks if str(chunk.get("document_name") or chunk.get("source") or "").strip() } ) if not source_names: return chunks async with GetAsyncSession() as session: rows = ( await session.execute( text( """ SELECT id, original_name, enabled, hit_count FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL AND original_name = ANY(:source_names) """ ), { "dataset_id": dataset_id, "source_names": source_names, }, ) ).mappings().all() document_map = {str(row["original_name"]): row for row in rows} visible_chunks: list[dict] = [] hit_document_ids: list[int] = [] for chunk in chunks: source_name = str(chunk.get("document_name") or chunk.get("source") or "").strip() document = document_map.get(source_name) if document and not bool(document.get("enabled")): continue if document: chunk["document_id"] = document["id"] chunk["dataset_id"] = dataset_id chunk["document_name"] = document["original_name"] chunk["hit_count"] = document.get("hit_count") or 0 hit_document_ids.append(int(document["id"])) visible_chunks.append(chunk) if hit_document_ids: async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_document SET hit_count = hit_count + 1, updated_at = NOW() WHERE id = ANY(:document_ids) """ ), {"document_ids": sorted(set(hit_document_ids))}, ) return visible_chunks def _parse_sse_event(self, chunk: str) -> dict | None: data_lines: list[str] = [] for line in chunk.splitlines(): if line.startswith("data: "): data_lines.append(line[6:]) elif line.startswith("data:"): data_lines.append(line[5:].lstrip()) if not data_lines: return None payload = "\n".join(part for part in data_lines if part.strip()).strip() if not payload or payload == "[DONE]": return None payload = payload.removesuffix("\\n\\n").removesuffix("\\n").strip() try: data = json.loads(payload) except json.JSONDecodeError: return None return data if isinstance(data, dict) else None