"""RAG 聊天服务实现。""" from __future__ import annotations import asyncio import json import re import time import uuid from typing import AsyncGenerator import httpx from sqlalchemy import text from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import ( RagConversationRenameDTO, RagMessageFeedbackDTO, RagStopMessageDTO, ) from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import ( RagAppParametersVO, RagChatAppListVO, RagChatAppVO, RagConversationItemVO, RagConversationPageVO, RagConversationRenameVO, RagMessageItemVO, RagMessagePageVO, RagOperationResultVO, ) from fastapi_modules.fastapi_leaudit.rag_engine.config import ( RAG_CONFIG, build_openai_chat_completions_url, ) from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService from fastapi_modules.fastapi_leaudit.services.impl.tenantResolver import TenantResolver DEFAULT_CONVERSATION_NAME = "新对话" DEFAULT_TITLE_SOURCE = "default" AUTO_TITLE_SOURCE = "auto" MANUAL_TITLE_SOURCE = "manual" class RagChatServiceImpl(IRagChatService): """RAG 聊天服务实现。""" _message_tasks: dict[str, asyncio.Task] = {} _task_events: dict[str, list[dict]] = {} _task_done: dict[str, bool] = {} _task_locks: dict[str, asyncio.Lock] = {} _title_tasks: dict[str, asyncio.Task] = {} _chat_schema_checked = False _chat_schema_lock = asyncio.Lock() def __init__(self, retriever: RagRetriever | None = None) -> None: self.TenantResolver = TenantResolver() self.retriever = retriever or RagRetriever() _APP_TENANT_NAME_SQL = ( "CASE " "WHEN NULLIF(BTRIM(a.tenant_code), '') = 'PUBLIC' THEN '公共' " "WHEN NULLIF(BTRIM(a.tenant_code), '') = 'PROVINCIAL' THEN '省级' " "ELSE COALESCE(NULLIF(BTRIM(a.area), ''), '未分配地区') " "END" ) async def GetApps( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None = None, TenantName: str | None = None, ) -> RagChatAppListVO: apps = await self._load_apps(UserArea, UserRole, TenantCode, TenantName, only_default=False) return RagChatAppListVO(data=apps, total=len(apps)) async def GetDefaultApp( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None = None, TenantName: str | None = None, ) -> RagChatAppVO | None: apps = await self._load_apps(UserArea, UserRole, TenantCode, TenantName, only_default=True) if apps: return apps[0] all_apps = await self._load_apps(UserArea, UserRole, TenantCode, TenantName, only_default=False) return all_apps[0] if all_apps else None async def SendMessage( self, CurrentUserId: int, UserName: str, UserArea: str | None, UserRole: str | None, Query: str, ConversationId: str | None, AppId: int | None, TenantCode: str | None = None, TenantName: str | None = None, ) -> AsyncGenerator[bytes, None]: if not Query.strip(): raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "问题不能为空") app = await self._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName) if not app: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用") conversationId = await self._ensure_conversation( user_id=CurrentUserId, conversation_id=ConversationId, app_id=app["id"], user_area=UserArea, user_role=UserRole, tenant_code=TenantCode, tenant_name=TenantName, ) messageId = str(uuid.uuid4()) taskId = str(uuid.uuid4()) is_new_conversation = not ConversationId or ConversationId == "-1" async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata) VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, '{}'::jsonb) """ ), { "message_id": str(uuid.uuid4()), "conversation_id": conversationId, "content": Query, }, ) await session.execute( text( """ INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata) VALUES (:message_id, :conversation_id, 'assistant', '', '[]'::jsonb, CAST(:metadata AS jsonb)) """ ), { "message_id": messageId, "conversation_id": conversationId, "metadata": json.dumps({"status": "running", "task_id": taskId}, ensure_ascii=False), }, ) await session.execute( text( "UPDATE rag_conversation SET updated_at = NOW() WHERE conversation_id = :conversation_id" ), {"conversation_id": conversationId}, ) await self._start_message_task( task_id=taskId, conversation_id=conversationId, message_id=messageId, query=Query, app=app, ) event_index = 0 initial_events: list[dict] = [] if is_new_conversation: initial_events.append( { "event": "conversation_created", "conversation_id": conversationId, "message_id": messageId, "task_id": taskId, } ) while True: if event_index < len(initial_events): payload = initial_events[event_index] event_index += 1 yield self._format_sse(payload) continue events = self._task_events.get(taskId, []) if event_index - len(initial_events) < len(events): payload = events[event_index - len(initial_events)] event_index += 1 yield self._format_sse(payload) continue if self._task_done.get(taskId): break await asyncio.sleep(0.05) async def GetConversations( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, AppId: int | None, Page: int, PageSize: int, TenantCode: str | None = None, TenantName: str | None = None, ) -> RagConversationPageVO: tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName) async with GetAsyncSession() as session: rows = ( await session.execute( text( """ SELECT conversation_id, name, introduction, created_at, updated_at , COALESCE(title_source, 'default') AS title_source , COALESCE(EXTRACT(EPOCH FROM last_message_at), 0) AS last_message_at FROM rag_conversation WHERE user_id = :user_id AND deleted_at IS NULL AND (CAST(:app_id AS BIGINT) IS NULL OR app_id = CAST(:app_id AS BIGINT)) ORDER BY COALESCE(last_message_at, updated_at) DESC, updated_at DESC OFFSET :offset LIMIT :limit """ ), { "user_id": CurrentUserId, "app_id": AppId, "offset": max(Page - 1, 0) * PageSize, "limit": PageSize + 1, }, ) ).mappings().all() filtered_rows: list[dict] = [] for row in rows: record = dict(row) if await self._conversation_accessible( conversation_id=str(record["conversation_id"]), expected_user_id=CurrentUserId, tenant_context=tenant_context, user_role=UserRole, app_id=AppId, session=session, ): filtered_rows.append(record) has_more = len(filtered_rows) > PageSize items = filtered_rows[:PageSize] return RagConversationPageVO( data=[ RagConversationItemVO( id=row["conversation_id"], name=row["name"], introduction=row.get("introduction") or "", titleSource=str(row.get("title_source") or DEFAULT_TITLE_SOURCE), createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, lastMessageAt=int(float(row.get("last_message_at") or 0)), ) for row in items ], hasMore=has_more, limit=PageSize, ) async def GetConversationMessages( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, ConversationId: str, Page: int, PageSize: int, TenantCode: str | None = None, TenantName: str | None = None, ) -> RagMessagePageVO: await self._ensure_conversation_owner( user_id=CurrentUserId, conversation_id=ConversationId, user_area=UserArea, user_role=UserRole, tenant_code=TenantCode, tenant_name=TenantName, ) async with GetAsyncSession() as session: rows = ( await session.execute( text( """ SELECT message_id, role, content, sources, metadata, feedback, created_at FROM rag_message WHERE conversation_id = :conversation_id ORDER BY created_at ASC, CASE role WHEN 'user' THEN 0 WHEN 'assistant' THEN 1 ELSE 2 END ASC, message_id ASC OFFSET :offset LIMIT :limit """ ), { "conversation_id": ConversationId, "offset": max(Page - 1, 0) * PageSize, "limit": PageSize + 1, }, ) ).mappings().all() has_more = len(rows) > PageSize items = rows[:PageSize] data: list[RagMessageItemVO] = [] idx = 0 while idx < len(items): row = items[idx] if row["role"] == "user": answer = items[idx + 1] if idx + 1 < len(items) and items[idx + 1]["role"] == "assistant" else None answer_metadata = dict((answer.get("metadata") if answer else None) or {}) answer_status = str(answer_metadata.get("status") or ("completed" if answer else "running")) answer_content = (answer.get("content") if answer else None) or "" if answer: task_id = str(answer_metadata.get("task_id") or "").strip() reconstructed_content = self._rebuild_message_content_from_events(task_id) if task_id else "" if reconstructed_content and len(reconstructed_content) >= len(answer_content): if reconstructed_content != answer_content: await self._update_message_progress( conversation_id=ConversationId, message_id=answer["message_id"], content=reconstructed_content, metadata=answer_metadata, ) answer_content = reconstructed_content normalized_status = await self._resolve_persisted_message_status( conversation_id=ConversationId, message_id=answer["message_id"], content=answer_content, metadata=answer_metadata, ) if normalized_status != answer_status: answer_status = normalized_status answer_metadata["status"] = normalized_status data.append( RagMessageItemVO( id=(answer["message_id"] if answer else row["message_id"]), conversationId=ConversationId, query=row["content"], answer=answer_content if answer else "", feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None), retrieverResources=(answer.get("sources") if answer else None), suggestedQuestions=[str(item) for item in (answer_metadata.get("suggested_questions") or []) if str(item).strip()], status=answer_status, createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, ) ) idx += 2 if answer else 1 else: idx += 1 return RagMessagePageVO(data=data, hasMore=has_more, limit=PageSize) async def _resolve_persisted_message_status( self, *, conversation_id: str, message_id: str, content: str, metadata: dict, ) -> str: status = str(metadata.get("status") or "completed") if status != "running": return status task_id = str(metadata.get("task_id") or "").strip() task = self._message_tasks.get(task_id) if task_id else None task_done = self._task_done.get(task_id, False) if task_id else False if task and not task.done() and not task_done: return "running" normalized_status = "completed" if content.strip() else "error" normalized_metadata = { **metadata, "status": normalized_status, } if normalized_status == "error" and not normalized_metadata.get("error"): normalized_metadata["error"] = "生成任务已结束,但未产出有效回答" await self._update_message_progress( conversation_id=conversation_id, message_id=message_id, content=content, metadata=normalized_metadata, ) return normalized_status def _rebuild_message_content_from_events(self, task_id: str) -> str: if not task_id: return "" chunks: list[str] = [] for event in self._task_events.get(task_id, []): if event.get("event") != "message": continue answer = event.get("answer") if isinstance(answer, str) and answer: chunks.append(answer) return "".join(chunks) async def RenameConversation( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, ConversationId: str, Body: RagConversationRenameDTO, TenantCode: str | None = None, TenantName: str | None = None, ) -> RagConversationRenameVO: await self._ensure_conversation_owner( user_id=CurrentUserId, conversation_id=ConversationId, user_area=UserArea, user_role=UserRole, tenant_code=TenantCode, tenant_name=TenantName, ) final_name = Body.name.strip() if not final_name: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "会话名称不能为空") async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_conversation SET name = :name, title_source = 'manual', title_generation_status = CASE WHEN COALESCE(title_generation_status, 'idle') = 'running' THEN 'succeeded' ELSE COALESCE(title_generation_status, 'idle') END, title_generation_error = NULL, updated_at = NOW() WHERE conversation_id = :conversation_id """ ), {"name": final_name, "conversation_id": ConversationId}, ) return RagConversationRenameVO(result="success", name=final_name) async def DeleteConversation( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, ConversationId: str, TenantCode: str | None = None, TenantName: str | None = None, ) -> RagOperationResultVO: await self._ensure_conversation_owner( user_id=CurrentUserId, conversation_id=ConversationId, user_area=UserArea, user_role=UserRole, tenant_code=TenantCode, tenant_name=TenantName, ) async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( "UPDATE rag_conversation SET deleted_at = NOW(), updated_at = NOW() WHERE conversation_id = :conversation_id" ), {"conversation_id": ConversationId}, ) return RagOperationResultVO(result="success") async def UpdateFeedback( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, MessageId: str, Body: RagMessageFeedbackDTO, TenantCode: str | None = None, TenantName: str | None = None, ) -> RagOperationResultVO: tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName) async with GetAsyncSession() as session: row = ( await session.execute( text( """ SELECT c.user_id, c.conversation_id FROM rag_message m JOIN rag_conversation c ON c.conversation_id = m.conversation_id WHERE m.message_id = :message_id AND c.deleted_at IS NULL LIMIT 1 """ ), {"message_id": MessageId}, ) ).mappings().first() if not row: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在") if not await self._conversation_accessible( conversation_id=str(row["conversation_id"]), expected_user_id=CurrentUserId, tenant_context=tenant_context, user_role=UserRole, session=session, ): raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权修改该消息反馈") await session.execute( text("UPDATE rag_message SET feedback = :feedback WHERE message_id = :message_id"), {"feedback": Body.rating, "message_id": MessageId}, ) return RagOperationResultVO(result="success") async def StopMessage( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, MessageId: str, Body: RagStopMessageDTO | None = None, TenantCode: str | None = None, TenantName: str | None = None, ) -> RagOperationResultVO: tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName) async with GetAsyncSession() as session: row = ( await session.execute( text( """ SELECT m.metadata, c.user_id, c.conversation_id FROM rag_message m JOIN rag_conversation c ON c.conversation_id = m.conversation_id WHERE m.message_id = :message_id AND c.deleted_at IS NULL LIMIT 1 """ ), {"message_id": MessageId}, ) ).mappings().first() if not row: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在") if not await self._conversation_accessible( conversation_id=str(row["conversation_id"]), expected_user_id=CurrentUserId, tenant_context=tenant_context, user_role=UserRole, ): raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权停止该消息") metadata = row.get("metadata") or {} task_id = str(Body.taskId or metadata.get("task_id") or "").strip() task = self._message_tasks.get(task_id) if task_id else None if task and not task.done(): task.cancel() return RagOperationResultVO(result="success") async def GetAppParameters( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, AppId: int | None, TenantCode: str | None = None, TenantName: str | None = None, ) -> RagAppParametersVO: app = await self._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName) if not app: return RagAppParametersVO() try: suggested = json.loads(app.get("suggested_questions") or "[]") if not isinstance(suggested, list): suggested = [] except Exception: suggested = [] return RagAppParametersVO( openingStatement=app.get("opening_statement") or "", suggestedQuestions=[str(item) for item in suggested[:6]], userInputForm=[], fileUpload={"image": {"enabled": False}}, ) async def _load_apps( self, user_area: str | None, user_role: str | None, tenant_code: str | None, tenant_name: str | None, only_default: bool, ) -> list[RagChatAppVO]: async with GetAsyncSession() as session: await self._ensure_rag_chat_schema(session) sql = ( f""" SELECT a.id, a.name, a.description, a.is_default, a.area, COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code, {self._APP_TENANT_NAME_SQL} AS tenant_name, COALESCE(d.is_public, FALSE) AS dataset_public FROM rag_chat_app a LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL WHERE a.deleted_at IS NULL AND a.status = 1 AND (:only_default = FALSE OR a.is_default = TRUE) ORDER BY a.sort_order ASC, a.created_at DESC """ ) rows = ( await session.execute( text(sql), { "only_default": only_default, }, ) ).mappings().all() tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name) data: list[RagChatAppVO] = [] for row in rows: record = dict(row) if not await self._app_visible(record, tenant_context=tenant_context, user_role=user_role): continue data.append( RagChatAppVO( appId=str(record["id"]), appName=record["name"], description=record.get("description") or "", tenantCode=str(record.get("tenant_code") or ""), tenantName=str(record.get("tenant_name") or record.get("area") or ""), isDefault=bool(record.get("is_default")), ) ) return data async def _resolve_app( self, app_id: int | None, user_area: str | None, user_role: str | None, tenant_code: str | None, tenant_name: str | None, ) -> dict | None: tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name) async with GetAsyncSession() as session: await self._ensure_rag_chat_schema(session) params = { "app_id": app_id, } base_sql = ( f""" SELECT a.id, a.name, a.description, a.area, COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code, {self._APP_TENANT_NAME_SQL} AS tenant_name, a.dataset_id, a.system_prompt, a.llm_model, a.temperature, a.max_tokens, a.opening_statement, a.suggested_questions, a.is_default, COALESCE(d.is_public, FALSE) AS dataset_public, COALESCE(d.name, '') AS dataset_name FROM rag_chat_app a LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL WHERE a.deleted_at IS NULL AND a.status = 1 """ ) if app_id is not None: row = ( await session.execute( text(base_sql + " AND a.id = :app_id LIMIT 1"), params, ) ).mappings().first() if row and await self._app_visible(dict(row), tenant_context=tenant_context, user_role=user_role): return dict(row) row = ( await session.execute( text(base_sql + " AND a.is_default = TRUE ORDER BY a.sort_order ASC, a.created_at DESC LIMIT 1"), params, ) ).mappings().first() if row and await self._app_visible(dict(row), tenant_context=tenant_context, user_role=user_role): return dict(row) row = ( await session.execute( text(base_sql + " ORDER BY a.sort_order ASC, a.created_at DESC LIMIT 1"), params, ) ).mappings().first() return dict(row) if row and await self._app_visible(dict(row), tenant_context=tenant_context, user_role=user_role) else None async def _app_visible(self, row: dict, tenant_context: dict, user_role: str | None) -> bool: if self._role_is_global(user_role): return True if bool(row.get("dataset_public")): return True if str(row.get("tenant_code") or "").strip().upper() == "PUBLIC": return True return self._row_matches_tenant_scope( row_tenant_code=row.get("tenant_code"), row_area=row.get("area"), tenant_context=tenant_context, ) async def _resolve_tenant_context( self, user_area: str | None, tenant_code: str | None, tenant_name: str | None, ) -> dict[str, str | None]: resolved = await self.TenantResolver.ResolveUserContext( Area=user_area, TenantCode=tenant_code, TenantName=tenant_name, Source="rag_chat_user", ) return { "tenant_code": resolved.tenant_code, "tenant_name": resolved.tenant_name, "tenant_type": resolved.tenant_type, "area": user_area, } async def _resolve_record_tenant(self, raw_value: str | None): return await self.TenantResolver.Resolve( RawValue=raw_value, Source="rag_chat_record", ) async def _ensure_rag_chat_schema(self, session) -> None: if self.__class__._chat_schema_checked: return async with self.__class__._chat_schema_lock: if self.__class__._chat_schema_checked: return exists = ( await session.execute( text( """ SELECT 1 FROM information_schema.columns WHERE table_schema = current_schema() AND table_name = 'rag_chat_app' AND column_name = 'tenant_code' """ ) ) ).scalar_one_or_none() if exists: self.__class__._chat_schema_checked = True return await session.execute(text("SET LOCAL lock_timeout = '1000ms'")) await session.execute(text("ALTER TABLE rag_chat_app ADD COLUMN tenant_code VARCHAR(64) NULL")) await session.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chat_app_tenant_code ON rag_chat_app(tenant_code) WHERE deleted_at IS NULL")) self.__class__._chat_schema_checked = True @staticmethod def _tenant_context_is_global(tenant_context: dict[str, str | None]) -> bool: tenant_code = str(tenant_context.get("tenant_code") or "").strip().upper() return tenant_code in {"PUBLIC", "PROVINCIAL"} @staticmethod def _role_is_global(user_role: str | None) -> bool: normalized = str(user_role or "").strip() return normalized in {"super_admin", "provincial_admin"} def _row_matches_tenant_scope( self, *, row_tenant_code: str | None, row_area: str | None, tenant_context: dict[str, str | None], ) -> bool: user_tenant_code = str(tenant_context.get("tenant_code") or "").strip() if user_tenant_code: return str(row_tenant_code or "").strip() == user_tenant_code return str(row_area or "").strip() == str(tenant_context.get("area") or "").strip() async def _ensure_conversation( self, user_id: int, conversation_id: str | None, app_id: int | None, user_area: str | None, user_role: str | None, tenant_code: str | None, tenant_name: str | None, ) -> str: tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name) if conversation_id and conversation_id != "-1": async with GetAsyncSession() as session: row = ( await session.execute( text( """ SELECT conversation_id, user_id FROM rag_conversation WHERE conversation_id = :conversation_id AND deleted_at IS NULL LIMIT 1 """ ), {"conversation_id": conversation_id}, ) ).mappings().first() if row: if not await self._conversation_accessible( conversation_id=str(row["conversation_id"]), expected_user_id=user_id, tenant_context=tenant_context, user_role=user_role, app_id=app_id, session=session, ): raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权使用该会话") return str(row["conversation_id"]) conversation_id = str(uuid.uuid4()) async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ INSERT INTO rag_conversation (conversation_id, user_id, app_id, name, introduction) VALUES (:conversation_id, :user_id, :app_id, :name, '') """ ), { "conversation_id": conversation_id, "user_id": user_id, "app_id": app_id, "name": DEFAULT_CONVERSATION_NAME, }, ) return conversation_id async def _ensure_conversation_owner( self, *, user_id: int, conversation_id: str, user_area: str | None, user_role: str | None, tenant_code: str | None, tenant_name: str | None, ) -> None: tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name) if not await self._conversation_accessible( conversation_id=conversation_id, expected_user_id=user_id, tenant_context=tenant_context, user_role=user_role, ): raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该会话") async def _conversation_accessible( self, *, conversation_id: str, expected_user_id: int, tenant_context: dict[str, str | None], user_role: str | None, app_id: int | None = None, session=None, ) -> bool: if session is not None: return await self._conversation_accessible_with_session( session=session, conversation_id=conversation_id, expected_user_id=expected_user_id, tenant_context=tenant_context, user_role=user_role, app_id=app_id, ) async with GetAsyncSession() as owned_session: return await self._conversation_accessible_with_session( session=owned_session, conversation_id=conversation_id, expected_user_id=expected_user_id, tenant_context=tenant_context, user_role=user_role, app_id=app_id, ) async def _conversation_accessible_with_session( self, *, session, conversation_id: str, expected_user_id: int, tenant_context: dict[str, str | None], user_role: str | None, app_id: int | None = None, ) -> bool: row = ( await session.execute( text( """ SELECT c.conversation_id, c.user_id, c.app_id, a.area, COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code, COALESCE(d.is_public, FALSE) AS dataset_public FROM rag_conversation c LEFT JOIN rag_chat_app a ON a.id = c.app_id AND a.deleted_at IS NULL LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL WHERE c.conversation_id = :conversation_id AND c.deleted_at IS NULL LIMIT 1 """ ), {"conversation_id": conversation_id}, ) ).mappings().first() if not row: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "会话不存在") if int(row["user_id"]) != expected_user_id: return False if app_id is not None and row.get("app_id") is not None and int(row["app_id"]) != int(app_id): return False app_row = { "tenant_code": row.get("tenant_code"), "area": row.get("area"), "dataset_public": row.get("dataset_public"), } return await self._app_visible(app_row, tenant_context=tenant_context, user_role=user_role) async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]: result = await self.retriever.retrieve(query=query, dataset_id=dataset_id) return result.chunks, result.dataset_name async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]: return await self.retriever._embed_texts(texts, model_name) async def _start_message_task( self, *, task_id: str, conversation_id: str, message_id: str, query: str, app: dict, ) -> None: self._task_events[task_id] = [] self._task_done[task_id] = False self._task_locks.setdefault(task_id, asyncio.Lock()) task = asyncio.create_task( self._run_message_task( task_id=task_id, conversation_id=conversation_id, message_id=message_id, query=query, app=app, ) ) self._message_tasks[task_id] = task async def _run_message_task( self, *, task_id: str, conversation_id: str, message_id: str, query: str, app: dict, ) -> None: context_chunks: list[dict] = [] dataset_name = "" collected_answer = "" message_end_payload: dict | None = None final_status = "completed" error_payload: dict | None = None last_persisted_length = 0 last_persisted_at = time.monotonic() try: context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), query) async for chunk in generate_stream( query=query, context_chunks=context_chunks, conversation_id=conversation_id, message_id=message_id, system_prompt=app.get("system_prompt") or "", model=app.get("llm_model") or "", temperature=app.get("temperature"), max_tokens=app.get("max_tokens"), dataset_name=dataset_name, task_id=task_id, ): data = self._parse_sse_event(chunk) if not data: continue event = data.get("event") if event == "message": collected_answer += data.get("answer", "") now = time.monotonic() if len(collected_answer) - last_persisted_length >= 80 or now - last_persisted_at >= 0.5: await self._update_message_progress( conversation_id=conversation_id, message_id=message_id, content=collected_answer, metadata={"status": "running", "task_id": task_id}, ) last_persisted_length = len(collected_answer) last_persisted_at = now await self._append_task_event(task_id, data) continue if event == "message_end": message_end_payload = data continue if event == "error": final_status = "error" error_payload = data await self._append_task_event(task_id, data) continue await self._append_task_event(task_id, data) if final_status == "completed": followups: list[str] = [] try: followups = await generate_followups(query, collected_answer) except Exception: followups = [] sources = self._build_sources(context_chunks, dataset_name) if message_end_payload: message_end_metadata = message_end_payload.setdefault("metadata", {}) message_end_metadata["suggested_questions"] = followups message_end_metadata["retriever_resources"] = sources await self._append_task_event(task_id, message_end_payload) await self._finalize_message_record( conversation_id=conversation_id, message_id=message_id, content=collected_answer, sources=sources, metadata={"suggested_questions": followups, "status": "completed", "task_id": task_id}, ) await self._maybe_schedule_auto_title( conversation_id=conversation_id, message_id=message_id, query=query, answer=collected_answer, ) else: await self._finalize_message_record( conversation_id=conversation_id, message_id=message_id, content=collected_answer, sources=self._build_sources(context_chunks, dataset_name), metadata={ "suggested_questions": [], "status": "error", "task_id": task_id, "error": (error_payload or {}).get("message", ""), }, ) except asyncio.CancelledError: final_status = "stopped" await self._append_task_event( task_id, { "event": "error", "task_id": task_id, "message_id": message_id, "code": "message_stopped", "message": "用户已停止回答", }, ) await self._finalize_message_record( conversation_id=conversation_id, message_id=message_id, content=collected_answer, sources=self._build_sources(context_chunks, dataset_name), metadata={"suggested_questions": [], "status": "stopped", "task_id": task_id}, ) raise except Exception as exc: final_status = "error" await self._append_task_event( task_id, { "event": "error", "task_id": task_id, "message_id": message_id, "code": "server_error", "message": str(exc), }, ) await self._finalize_message_record( conversation_id=conversation_id, message_id=message_id, content=collected_answer, sources=self._build_sources(context_chunks, dataset_name), metadata={"suggested_questions": [], "status": "error", "task_id": task_id, "error": str(exc)}, ) finally: self._task_done[task_id] = True self._message_tasks.pop(task_id, None) self._task_locks.pop(task_id, None) async def _append_task_event(self, task_id: str, payload: dict) -> None: lock = self._task_locks.setdefault(task_id, asyncio.Lock()) async with lock: self._task_events.setdefault(task_id, []).append(payload) async def _finalize_message_record( self, *, conversation_id: str, message_id: str, content: str, sources: list[dict], metadata: dict, ) -> None: async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_message SET content = :content, sources = CAST(:sources AS jsonb), metadata = CAST(:metadata AS jsonb) WHERE message_id = :message_id AND conversation_id = :conversation_id AND role = 'assistant' """ ), { "conversation_id": conversation_id, "message_id": message_id, "content": content, "sources": json.dumps(sources, ensure_ascii=False), "metadata": json.dumps(metadata, ensure_ascii=False), }, ) await session.execute( text( """ UPDATE rag_conversation SET updated_at = NOW(), last_message_at = NOW() WHERE conversation_id = :conversation_id """ ), {"conversation_id": conversation_id}, ) async def _update_message_progress( self, *, conversation_id: str, message_id: str, content: str, metadata: dict, ) -> None: async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_message SET content = :content, metadata = CAST(:metadata AS jsonb) WHERE message_id = :message_id AND conversation_id = :conversation_id AND role = 'assistant' """ ), { "conversation_id": conversation_id, "message_id": message_id, "content": content, "metadata": json.dumps(metadata, ensure_ascii=False), }, ) async def _maybe_schedule_auto_title( self, *, conversation_id: str, message_id: str, query: str, answer: str, ) -> None: normalized_query = (query or "").strip() normalized_answer = (answer or "").strip() if not normalized_query or not normalized_answer: return async with GetAsyncSession() as session: row = ( await session.execute( text( """ SELECT conversation_id, name, COALESCE(title_source, 'default') AS title_source, COALESCE(title_generation_status, 'idle') AS title_generation_status, first_answer_message_id FROM rag_conversation WHERE conversation_id = :conversation_id AND deleted_at IS NULL LIMIT 1 """ ), {"conversation_id": conversation_id}, ) ).mappings().first() if not row: return title_source = str(row.get("title_source") or DEFAULT_TITLE_SOURCE) if title_source == MANUAL_TITLE_SOURCE: return current_name = str(row.get("name") or "").strip() if current_name and current_name != DEFAULT_CONVERSATION_NAME and title_source != DEFAULT_TITLE_SOURCE: return current_status = str(row.get("title_generation_status") or "idle") if current_status in {"pending", "running", "succeeded"}: return if row.get("first_answer_message_id") and str(row.get("first_answer_message_id")) != message_id: return async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_conversation SET title_generation_status = 'pending', first_answer_message_id = COALESCE(first_answer_message_id, :message_id), title_generation_error = NULL WHERE conversation_id = :conversation_id AND deleted_at IS NULL AND COALESCE(title_source, 'default') = 'default' """ ), { "conversation_id": conversation_id, "message_id": message_id, }, ) existing_task = self._title_tasks.get(conversation_id) if existing_task and not existing_task.done(): return task = asyncio.create_task( self._run_auto_title_task( conversation_id=conversation_id, answer_message_id=message_id, query=normalized_query, answer=normalized_answer, ) ) self._title_tasks[conversation_id] = task async def _run_auto_title_task( self, *, conversation_id: str, answer_message_id: str, query: str, answer: str, ) -> None: try: async with GetAsyncSession() as session: row = ( await session.execute( text( """ SELECT name, COALESCE(title_source, 'default') AS title_source, COALESCE(title_generation_status, 'idle') AS title_generation_status, first_answer_message_id FROM rag_conversation WHERE conversation_id = :conversation_id AND deleted_at IS NULL LIMIT 1 """ ), {"conversation_id": conversation_id}, ) ).mappings().first() if not row: return if str(row.get("title_source") or DEFAULT_TITLE_SOURCE) == MANUAL_TITLE_SOURCE: return if row.get("first_answer_message_id") and str(row.get("first_answer_message_id")) != answer_message_id: return async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_conversation SET title_generation_status = 'running', title_generation_error = NULL WHERE conversation_id = :conversation_id AND deleted_at IS NULL AND COALESCE(title_source, 'default') = 'default' """ ), {"conversation_id": conversation_id}, ) generated_title = await self._generate_conversation_title(query=query, answer=answer) cleaned_title = self._sanitize_generated_title(generated_title) if not cleaned_title: cleaned_title = self._build_fallback_title(query=query, answer=answer) if not cleaned_title: raise ValueError("未生成有效标题") async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_conversation SET name = :name, title_source = 'auto', title_generation_status = 'succeeded', title_generated_at = NOW(), title_generation_error = NULL WHERE conversation_id = :conversation_id AND deleted_at IS NULL AND COALESCE(title_source, 'default') = 'default' AND ( name IS NULL OR BTRIM(name) = '' OR name = :default_name ) """ ), { "conversation_id": conversation_id, "name": cleaned_title, "default_name": DEFAULT_CONVERSATION_NAME, }, ) except Exception as exc: async with GetAsyncSession() as session: async with session.begin(): await session.execute( text( """ UPDATE rag_conversation SET title_generation_status = 'failed', title_generation_error = :error WHERE conversation_id = :conversation_id AND deleted_at IS NULL AND COALESCE(title_source, 'default') = 'default' """ ), { "conversation_id": conversation_id, "error": str(exc)[:1000], }, ) finally: self._title_tasks.pop(conversation_id, None) async def _generate_conversation_title(self, *, query: str, answer: str) -> str: prompt = ( "请基于用户首轮提问和助手首轮回答,生成一个简洁、准确的中文会话标题。" "要求:" "1. 只输出标题本身;" "2. 不要标点结尾;" "3. 不要出现“关于”“用户询问”“问题解答”等空话;" "4. 优先 12-24 个中文字符,最长不超过 40 个字符。\\n" f"用户问题:{query[:500]}\\n" f"助手回答:{answer[:1500]}" ) async with httpx.AsyncClient(timeout=30.0) as client: resp = await client.post( build_openai_chat_completions_url(RAG_CONFIG["LLM_BASE_URL"]), json={ "model": RAG_CONFIG["LLM_MODEL"], "messages": [{"role": "user", "content": prompt}], "temperature": 0.2, "max_tokens": 80, "chat_template_kwargs": {"enable_thinking": False}, }, headers={ "Content-Type": "application/json", "Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}", }, ) resp.raise_for_status() content = resp.json()["choices"][0]["message"]["content"] return str(content or "").strip() def _sanitize_generated_title(self, value: str) -> str: title = str(value or "").strip() if not title: return "" title = title.replace("\r", " ").replace("\n", " ").strip() title = re.sub(r"^```[\w-]*", "", title).strip() title = title.replace("```", "").strip() title = re.sub(r'^[\"“”\'‘’\[\]()()【】]+', "", title) title = re.sub(r'[\"“”\'‘’\[\]()()【】]+$', "", title) title = re.sub(r"^(标题|会话标题)[::]\\s*", "", title) title = re.sub(r"\\s+", " ", title).strip() title = title.rstrip("。!?!?.;;,,::") if title in {"新对话", "会话标题", "标题"}: return "" if len(title) > 40: title = title[:40].rstrip(",,。;;:: ") return title def _build_fallback_title(self, *, query: str, answer: str) -> str: base = query.strip() or answer.strip() if not base: return "" base = re.sub(r"\\s+", " ", base) base = re.sub(r"^[请问帮我一下关于针对结合根据请解释说明一下::,,\\s]+", "", base) base = base.rstrip("。!?!?.;;,,::") if len(base) > 24: base = base[:24].rstrip(",,。;;:: ") return base async def _keyword_retrieve_context( self, *, dataset_id: int, collection_name: str, dataset_name: str, query: str, top_k: int, score_threshold: float | None, ) -> list[dict]: chunks = await self.retriever._keyword_retrieve_context( dataset_id=dataset_id, collection_name=collection_name, dataset_name=dataset_name, query=query, top_k=top_k, score_threshold=score_threshold, source_names=None, ) return chunks[:top_k] def _build_keyword_terms(self, query: str) -> list[str]: return self.retriever._build_keyword_terms(query) def _normalize_keyword_query(self, query: str) -> str: return self.retriever._normalize_keyword_query(query) def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float: return self.retriever._score_keyword_chunk( query=query, terms=terms, content=content, document_name=document_name, ) def _format_sse(self, payload: dict) -> bytes: return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8") def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]: build_sources = getattr(self.retriever, "build_sources", None) if callable(build_sources): return build_sources(context_chunks, dataset_name) return RagRetriever(hydrate_documents=False).build_sources(context_chunks, dataset_name) async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict]) -> list[dict]: return await self.retriever._hydrate_document_hits(dataset_id, chunks) def _parse_sse_event(self, chunk: str) -> dict | None: data_lines: list[str] = [] for line in chunk.splitlines(): if line.startswith("data: "): data_lines.append(line[6:]) elif line.startswith("data:"): data_lines.append(line[5:].lstrip()) if not data_lines: return None payload = "\n".join(part for part in data_lines if part.strip()).strip() if not payload or payload == "[DONE]": return None payload = payload.removesuffix("\\n\\n").removesuffix("\\n").strip() try: data = json.loads(payload) except json.JSONDecodeError: return None return data if isinstance(data, dict) else None