1774 lines
70 KiB
Python
1774 lines
70 KiB
Python
from __future__ import annotations
|
||
|
||
import asyncio
|
||
import json
|
||
import re
|
||
import time
|
||
import uuid
|
||
from typing import AsyncGenerator
|
||
|
||
import httpx
|
||
from sqlalchemy import text
|
||
|
||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
||
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
||
|
||
from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import (
|
||
RagConversationRenameDTO,
|
||
RagMessageFeedbackDTO,
|
||
RagStopMessageDTO,
|
||
)
|
||
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
|
||
RagAppParametersVO,
|
||
RagChatAppListVO,
|
||
RagChatAppVO,
|
||
RagConversationItemVO,
|
||
RagConversationPageVO,
|
||
RagConversationRenameVO,
|
||
RagMessageItemVO,
|
||
RagMessagePageVO,
|
||
RagOperationResultVO,
|
||
)
|
||
from fastapi_modules.fastapi_leaudit.rag_engine.config import (
|
||
RAG_CONFIG,
|
||
build_openai_chat_completions_url,
|
||
build_openai_embeddings_url,
|
||
)
|
||
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
|
||
from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream
|
||
from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups
|
||
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
|
||
from fastapi_modules.fastapi_leaudit.services.impl.tenantResolver import TenantResolver
|
||
|
||
|
||
# Display name given to a freshly created conversation ("New conversation").
DEFAULT_CONVERSATION_NAME: str = "新对话"

# Values of rag_conversation.title_source. GetConversations falls back to
# "default" when the column is NULL; RenameConversation writes "manual".
DEFAULT_TITLE_SOURCE: str = "default"
# Presumably written by automatic title generation (not visible in this
# chunk) — TODO confirm.
AUTO_TITLE_SOURCE: str = "auto"
MANUAL_TITLE_SOURCE: str = "manual"
|
||
|
||
|
||
class RagChatServiceImpl(IRagChatService):
|
||
# In-memory, process-wide registries. They are class attributes, so every
# RagChatServiceImpl instance in this process shares the same dicts; another
# worker process does not see them.
# NOTE(review): nothing in the visible part of this file removes entries —
# confirm cleanup in _start_message_task to avoid unbounded growth.

# Background answer-generation tasks keyed by task_id (StopMessage cancels them).
_message_tasks: dict[str, asyncio.Task] = {}

# Event dicts recorded per task_id; SendMessage replays them as SSE frames and
# _rebuild_message_content_from_events concatenates their "answer" fields.
_task_events: dict[str, list[dict]] = {}

# Completion flag per task_id; SendMessage stops streaming once it is truthy.
_task_done: dict[str, bool] = {}

# Locks, presumably keyed by task_id (not used in the visible part of this file).
_task_locks: dict[str, asyncio.Lock] = {}

# Title-generation tasks (not used in the visible part of this file; the key is
# presumably a conversation id — TODO confirm).
_title_tasks: dict[str, asyncio.Task] = {}
|
||
|
||
def __init__(self) -> None:
    """Create the service together with its tenant resolver."""
    resolver = TenantResolver()
    # Read by _resolve_tenant_context and _resolve_record_tenant.
    self.TenantResolver = resolver
|
||
|
||
# SQL CASE expression spliced (constant text via f-string, never user input)
# into the app queries of _load_apps and _resolve_app; it expects the
# rag_chat_app table to be aliased as `a`. It yields the tenant display name:
# '公共' ("public") for tenant_code PUBLIC, '省级' ("provincial") for
# PROVINCIAL, otherwise the app's trimmed area, or '未分配地区'
# ("no area assigned") when the area is blank or NULL.
_APP_TENANT_NAME_SQL = (
    "CASE "
    "WHEN NULLIF(BTRIM(a.tenant_code), '') = 'PUBLIC' THEN '公共' "
    "WHEN NULLIF(BTRIM(a.tenant_code), '') = 'PROVINCIAL' THEN '省级' "
    "ELSE COALESCE(NULLIF(BTRIM(a.area), ''), '未分配地区') "
    "END"
)
|
||
|
||
async def GetApps(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagChatAppListVO:
    """List every enabled chat app visible to the caller.

    CurrentUserId is accepted for interface compatibility; visibility is
    decided from the area/role/tenant arguments via ``_load_apps``.
    """
    visible = await self._load_apps(
        UserArea,
        UserRole,
        TenantCode,
        TenantName,
        only_default=False,
    )
    count = len(visible)
    return RagChatAppListVO(data=visible, total=count)
|
||
|
||
async def GetDefaultApp(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagChatAppVO | None:
    """Return the first visible default app, else the first visible app.

    Returns None when the caller can see no enabled app at all.
    """
    # First pass: default apps only; second pass: any enabled app.
    for only_default in (True, False):
        candidates = await self._load_apps(
            UserArea,
            UserRole,
            TenantCode,
            TenantName,
            only_default=only_default,
        )
        if candidates:
            return candidates[0]
    return None
|
||
|
||
async def SendMessage(
    self,
    CurrentUserId: int,
    UserName: str,
    UserArea: str | None,
    UserRole: str | None,
    Query: str,
    ConversationId: str | None,
    AppId: int | None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> AsyncGenerator[bytes, None]:
    """Store a user question, start answer generation and stream it as SSE.

    Persists the user message and an empty ``assistant`` placeholder
    (metadata ``{"status": "running", "task_id": ...}``) in one transaction,
    bumps the conversation's ``updated_at``, and launches the background task
    via ``_start_message_task``. It then yields one ``_format_sse`` frame per
    event appended to ``_task_events[task_id]`` until ``_task_done[task_id]``
    is set.

    When a new conversation was created for this request, a
    ``conversation_created`` event (conversation, message and task ids) is
    emitted first.

    Raises:
        LeauditException: 400 for a blank query, 404 when no chat app is
            available to the caller, 403 from ``_ensure_conversation`` for a
            conversation the caller may not use. Because this is an async
            generator, these surface on the first iteration.
    """
    if not Query.strip():
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "问题不能为空")

    app = await self._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName)
    if not app:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用")

    conversationId = await self._ensure_conversation(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        app_id=app["id"],
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )
    messageId = str(uuid.uuid4())
    taskId = str(uuid.uuid4())

    # _ensure_conversation silently starts a fresh conversation when the
    # supplied id is unknown or soft-deleted. Comparing ids (not only checking
    # for "no id" / "-1") makes sure the client is told about that new
    # conversation instead of continuing to post to a dead id.
    requested_id = (ConversationId or "").strip()
    is_new_conversation = (
        not requested_id
        or requested_id == "-1"
        or conversationId.lower() != requested_id.lower()
    )

    async with GetAsyncSession() as session:
        async with session.begin():
            # The user's question.
            await session.execute(
                text(
                    """
                    INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
                    VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, '{}'::jsonb)
                    """
                ),
                {
                    "message_id": str(uuid.uuid4()),
                    "conversation_id": conversationId,
                    "content": Query,
                },
            )
            # Placeholder answer; the background task fills in the content and
            # StopMessage/GetConversationMessages find the task via task_id.
            await session.execute(
                text(
                    """
                    INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
                    VALUES (:message_id, :conversation_id, 'assistant', '', '[]'::jsonb, CAST(:metadata AS jsonb))
                    """
                ),
                {
                    "message_id": messageId,
                    "conversation_id": conversationId,
                    "metadata": json.dumps({"status": "running", "task_id": taskId}, ensure_ascii=False),
                },
            )
            await session.execute(
                text(
                    "UPDATE rag_conversation SET updated_at = NOW() WHERE conversation_id = :conversation_id"
                ),
                {"conversation_id": conversationId},
            )

    await self._start_message_task(
        task_id=taskId,
        conversation_id=conversationId,
        message_id=messageId,
        query=Query,
        app=app,
    )

    if is_new_conversation:
        yield self._format_sse(
            {
                "event": "conversation_created",
                "conversation_id": conversationId,
                "message_id": messageId,
                "task_id": taskId,
            }
        )

    # Replay the task's event log as it grows, polling every 50 ms; stop once
    # the task has flagged itself done and every recorded event was sent.
    sent = 0
    while True:
        events = self._task_events.get(taskId, [])
        if sent < len(events):
            payload = events[sent]
            sent += 1
            yield self._format_sse(payload)
            continue
        if self._task_done.get(taskId):
            break
        await asyncio.sleep(0.05)
|
||
|
||
async def GetConversations(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    AppId: int | None,
    Page: int,
    PageSize: int,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagConversationPageVO:
    """Return one page of the caller's non-deleted conversations.

    Conversations are ordered by ``last_message_at`` (falling back to
    ``updated_at``) newest first, optionally restricted to *AppId*, and then
    filtered through ``_conversation_accessible`` (tenant/role visibility of
    the owning app). Pagination is OFFSET-based on the unfiltered order, so a
    page may hold fewer than *PageSize* items when some are filtered out.
    """
    tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName)
    async with GetAsyncSession() as session:
        rows = (
            await session.execute(
                text(
                    """
                    SELECT conversation_id, name, introduction, created_at, updated_at
                         , COALESCE(title_source, 'default') AS title_source
                         , COALESCE(EXTRACT(EPOCH FROM last_message_at), 0) AS last_message_at
                    FROM rag_conversation
                    WHERE user_id = :user_id
                      AND deleted_at IS NULL
                      AND (CAST(:app_id AS BIGINT) IS NULL OR app_id = CAST(:app_id AS BIGINT))
                    ORDER BY COALESCE(last_message_at, updated_at) DESC, updated_at DESC
                    OFFSET :offset LIMIT :limit
                    """
                ),
                {
                    "user_id": CurrentUserId,
                    "app_id": AppId,
                    "offset": max(Page - 1, 0) * PageSize,
                    # One extra row only tells whether another page exists.
                    "limit": PageSize + 1,
                },
            )
        ).mappings().all()

        # Only the first PageSize rows belong to this page. The look-ahead row
        # must not be promoted into it when earlier rows are filtered out:
        # it is the first row of the next OFFSET page and would be returned twice.
        has_more = len(rows) > PageSize
        items: list[dict] = []
        for row in rows[:PageSize]:
            record = dict(row)
            if await self._conversation_accessible(
                conversation_id=str(record["conversation_id"]),
                expected_user_id=CurrentUserId,
                tenant_context=tenant_context,
                user_role=UserRole,
                app_id=AppId,
                session=session,
            ):
                items.append(record)

    return RagConversationPageVO(
        data=[
            RagConversationItemVO(
                id=row["conversation_id"],
                name=row["name"],
                introduction=row.get("introduction") or "",
                titleSource=str(row.get("title_source") or DEFAULT_TITLE_SOURCE),
                createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
                updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0,
                lastMessageAt=int(float(row.get("last_message_at") or 0)),
            )
            for row in items
        ],
        hasMore=has_more,
        limit=PageSize,
    )
|
||
|
||
async def GetConversationMessages(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    ConversationId: str,
    Page: int,
    PageSize: int,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagMessagePageVO:
    """Return one page of question/answer pairs for a conversation.

    Messages are read chronologically (``user`` before ``assistant`` on equal
    timestamps). Each ``user`` row is paired with the ``assistant`` row that
    immediately follows it, even when that answer is the first row of the next
    OFFSET page. Any other row is skipped.

    Side effects per answer:
    - Content rebuilt from the in-memory event log of its task is written back
      via ``_update_message_progress`` when it is at least as long as, and
      different from, the stored content.
    - A stale ``running`` status is normalized through
      ``_resolve_persisted_message_status``.

    Raises:
        LeauditException: from ``_ensure_conversation_owner`` (403 when the
            caller may not access the conversation, 404 when it is missing).
    """
    await self._ensure_conversation_owner(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )
    async with GetAsyncSession() as session:
        rows = (
            await session.execute(
                text(
                    """
                    SELECT message_id, role, content, sources, metadata, feedback, created_at
                    FROM rag_message
                    WHERE conversation_id = :conversation_id
                    ORDER BY created_at ASC,
                             CASE role
                                 WHEN 'user' THEN 0
                                 WHEN 'assistant' THEN 1
                                 ELSE 2
                             END ASC,
                             message_id ASC
                    OFFSET :offset LIMIT :limit
                    """
                ),
                {
                    "conversation_id": ConversationId,
                    "offset": max(Page - 1, 0) * PageSize,
                    # PageSize rows plus up to two look-ahead rows: the answer of
                    # a question that ends the page, and a row telling whether
                    # anything follows that answer.
                    "limit": PageSize + 2,
                },
            )
        ).mappings().all()

    items = rows[:PageSize]
    used_lookahead = False
    data: list[RagMessageItemVO] = []
    idx = 0
    while idx < len(items):
        row = items[idx]
        if row["role"] != "user":
            # Orphan non-question row, e.g. the answer the previous page
            # already paired with its last question.
            idx += 1
            continue

        # Look at the next row even if it lies beyond this page, so a question
        # ending the page keeps its answer instead of showing as "running".
        following = rows[idx + 1] if idx + 1 < len(rows) else None
        answer = following if following is not None and following["role"] == "assistant" else None
        if answer is not None and idx + 1 >= len(items):
            used_lookahead = True

        answer_metadata = dict((answer.get("metadata") if answer else None) or {})
        answer_status = str(answer_metadata.get("status") or ("completed" if answer else "running"))
        answer_content = (answer.get("content") if answer else None) or ""

        if answer:
            task_id = str(answer_metadata.get("task_id") or "").strip()
            reconstructed_content = self._rebuild_message_content_from_events(task_id) if task_id else ""
            # Prefer the live event log when it holds at least as much text as
            # the database (the task may not have flushed its progress yet).
            if reconstructed_content and len(reconstructed_content) >= len(answer_content):
                if reconstructed_content != answer_content:
                    await self._update_message_progress(
                        conversation_id=ConversationId,
                        message_id=answer["message_id"],
                        content=reconstructed_content,
                        metadata=answer_metadata,
                    )
                answer_content = reconstructed_content

            normalized_status = await self._resolve_persisted_message_status(
                conversation_id=ConversationId,
                message_id=answer["message_id"],
                content=answer_content,
                metadata=answer_metadata,
            )
            if normalized_status != answer_status:
                answer_status = normalized_status
                answer_metadata["status"] = normalized_status

        data.append(
            RagMessageItemVO(
                id=(answer["message_id"] if answer else row["message_id"]),
                conversationId=ConversationId,
                query=row["content"],
                answer=answer_content if answer else "",
                feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None),
                retrieverResources=(answer.get("sources") if answer else None),
                suggestedQuestions=[
                    str(item)
                    for item in (answer_metadata.get("suggested_questions") or [])
                    if str(item).strip()
                ],
                status=answer_status,
                createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
            )
        )
        idx += 2 if answer else 1

    # A consumed look-ahead answer is not "more"; only a row after it is.
    has_more = len(rows) > PageSize + (1 if used_lookahead else 0)
    return RagMessagePageVO(data=data, hasMore=has_more, limit=PageSize)
|
||
|
||
async def _resolve_persisted_message_status(
    self,
    *,
    conversation_id: str,
    message_id: str,
    content: str,
    metadata: dict,
) -> str:
    """Return the effective status of a stored answer, repairing stale "running".

    A status other than ``running`` (missing counts as ``completed``) is
    returned unchanged. A ``running`` answer whose task is still registered,
    not done and not flagged done stays ``running``. Otherwise the answer is
    finalized as ``completed`` (non-blank content) or ``error`` (blank content,
    with a default error text if none is recorded), and persisted through
    ``_update_message_progress``.
    """
    current = str(metadata.get("status") or "completed")
    if current != "running":
        return current

    task_key = str(metadata.get("task_id") or "").strip()
    if task_key:
        live_task = self._message_tasks.get(task_key)
        finished_flag = self._task_done.get(task_key, False)
        if live_task and not live_task.done() and not finished_flag:
            return "running"

    final_status = "error" if not content.strip() else "completed"
    patched = dict(metadata)
    patched["status"] = final_status
    if final_status == "error" and not patched.get("error"):
        patched["error"] = "生成任务已结束,但未产出有效回答"

    await self._update_message_progress(
        conversation_id=conversation_id,
        message_id=message_id,
        content=content,
        metadata=patched,
    )
    return final_status
|
||
|
||
def _rebuild_message_content_from_events(self, task_id: str) -> str:
    """Concatenate the streamed answer fragments recorded for *task_id*.

    Only events with ``event == "message"`` and a non-empty string ``answer``
    contribute. Returns "" for an empty id or when nothing was recorded.
    """
    if not task_id:
        return ""
    recorded = self._task_events.get(task_id, [])
    fragments = (
        evt.get("answer")
        for evt in recorded
        if evt.get("event") == "message"
    )
    return "".join(piece for piece in fragments if isinstance(piece, str) and piece)
|
||
|
||
async def RenameConversation(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    ConversationId: str,
    Body: RagConversationRenameDTO,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagConversationRenameVO:
    """Rename a conversation the caller may access and mark its title manual.

    The stripped name is stored with ``title_source = 'manual'``. A title
    generation status of ``running`` is switched to ``succeeded`` and any
    stored title-generation error is cleared (presumably so a pending
    auto-title does not override the manual name — the generator is not
    visible here).

    Raises:
        LeauditException: 403/404 from the access check, 400 for a blank name.
    """
    await self._ensure_conversation_owner(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )
    new_name = Body.name.strip()
    if new_name == "":
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "会话名称不能为空")

    rename_sql = text(
        """
        UPDATE rag_conversation
        SET name = :name,
            title_source = 'manual',
            title_generation_status = CASE
                WHEN COALESCE(title_generation_status, 'idle') = 'running' THEN 'succeeded'
                ELSE COALESCE(title_generation_status, 'idle')
            END,
            title_generation_error = NULL,
            updated_at = NOW()
        WHERE conversation_id = :conversation_id
        """
    )
    async with GetAsyncSession() as session, session.begin():
        await session.execute(rename_sql, {"name": new_name, "conversation_id": ConversationId})
    return RagConversationRenameVO(result="success", name=new_name)
|
||
|
||
async def DeleteConversation(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    ConversationId: str,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagOperationResultVO:
    """Soft-delete a conversation the caller may access.

    Sets ``deleted_at`` and ``updated_at`` to NOW(); messages are left in place.

    Raises:
        LeauditException: 403/404 from the access check.
    """
    await self._ensure_conversation_owner(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )
    soft_delete = text(
        "UPDATE rag_conversation SET deleted_at = NOW(), updated_at = NOW() WHERE conversation_id = :conversation_id"
    )
    async with GetAsyncSession() as session, session.begin():
        await session.execute(soft_delete, {"conversation_id": ConversationId})
    return RagOperationResultVO(result="success")
|
||
|
||
async def UpdateFeedback(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    MessageId: str,
    Body: RagMessageFeedbackDTO,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagOperationResultVO:
    """Store ``Body.rating`` as the feedback of a message.

    The message must belong to a non-deleted conversation that the caller
    may access (owner plus tenant/role visibility of its app).

    Raises:
        LeauditException: 404 when the message or its conversation does not
            exist, 403 when the caller may not access the conversation.
    """
    tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName)
    async with GetAsyncSession() as session:
        row = (
            await session.execute(
                text(
                    """
                    SELECT c.user_id, c.conversation_id
                    FROM rag_message m
                    JOIN rag_conversation c ON c.conversation_id = m.conversation_id
                    WHERE m.message_id = :message_id AND c.deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"message_id": MessageId},
            )
        ).mappings().first()
        if not row:
            raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在")
        if not await self._conversation_accessible(
            conversation_id=str(row["conversation_id"]),
            expected_user_id=CurrentUserId,
            tenant_context=tenant_context,
            user_role=UserRole,
            session=session,
        ):
            raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权修改该消息反馈")
        await session.execute(
            text("UPDATE rag_message SET feedback = :feedback WHERE message_id = :message_id"),
            {"feedback": Body.rating, "message_id": MessageId},
        )
        # The reads above already began a transaction implicitly, so the
        # `async with session.begin()` pattern used by the other write paths
        # cannot wrap this UPDATE. Commit explicitly so it is not discarded
        # when the session closes; if GetAsyncSession also commits on exit
        # (not visible here), the second commit is a no-op.
        await session.commit()
    return RagOperationResultVO(result="success")
|
||
|
||
async def StopMessage(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    MessageId: str,
    Body: RagStopMessageDTO | None = None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagOperationResultVO:
    """Cancel the background generation task of an answer message.

    The task to cancel is the one recorded in the message's metadata
    (``task_id``), after checking that the caller may access the message's
    conversation. ``Body`` (and its ``taskId``) is accepted for compatibility
    but not trusted. A different client-supplied id would otherwise let a
    caller cancel another user's task, and ``Body`` may be None. Cancelling a
    finished or unknown task is a no-op that still reports success.

    Raises:
        LeauditException: 404 when the message or its conversation does not
            exist, 403 when the caller may not access the conversation.
    """
    tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName)
    async with GetAsyncSession() as session:
        row = (
            await session.execute(
                text(
                    """
                    SELECT m.metadata, c.user_id, c.conversation_id
                    FROM rag_message m
                    JOIN rag_conversation c ON c.conversation_id = m.conversation_id
                    WHERE m.message_id = :message_id
                      AND c.deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"message_id": MessageId},
            )
        ).mappings().first()
        if not row:
            raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在")
        # Reuse the open session instead of opening a second one for the check.
        if not await self._conversation_accessible(
            conversation_id=str(row["conversation_id"]),
            expected_user_id=CurrentUserId,
            tenant_context=tenant_context,
            user_role=UserRole,
            session=session,
        ):
            raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权停止该消息")

    metadata = row.get("metadata") or {}
    if isinstance(metadata, (str, bytes, bytearray)):
        # Tolerate drivers that return jsonb as text.
        try:
            metadata = json.loads(metadata)
        except ValueError:
            metadata = {}
    if not isinstance(metadata, dict):
        metadata = {}

    task_id = str(metadata.get("task_id") or "").strip()
    task = self._message_tasks.get(task_id) if task_id else None
    if task and not task.done():
        task.cancel()
    return RagOperationResultVO(result="success")
|
||
|
||
async def GetAppParameters(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    AppId: int | None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagAppParametersVO:
    """Return the chat UI parameters of the app resolved for the caller.

    Includes the opening statement and at most six suggested questions. An
    empty ``RagAppParametersVO`` is returned when no app is visible.
    ``suggested_questions`` may come from the database either already decoded
    (a list, e.g. from a jsonb column) or as JSON text; anything else
    yields no suggestions.
    """
    app = await self._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName)
    if not app:
        return RagAppParametersVO()

    raw = app.get("suggested_questions")
    suggested: list = []
    if isinstance(raw, list):
        # Already decoded by the driver; json.loads would reject a list.
        suggested = raw
    elif isinstance(raw, (str, bytes, bytearray)) and raw:
        try:
            parsed = json.loads(raw)
        except ValueError:
            parsed = None
        if isinstance(parsed, list):
            suggested = parsed

    return RagAppParametersVO(
        openingStatement=app.get("opening_statement") or "",
        suggestedQuestions=[str(item) for item in suggested[:6]],
        userInputForm=[],
        fileUpload={"image": {"enabled": False}},
    )
|
||
|
||
async def _load_apps(
    self,
    user_area: str | None,
    user_role: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
    only_default: bool,
) -> list[RagChatAppVO]:
    """Return the enabled chat apps visible to the caller, as VOs.

    Rows are read in ``sort_order`` ASC, ``created_at`` DESC order; with
    *only_default* only rows whose ``is_default`` is TRUE are read. Each row
    is then filtered through ``_app_visible`` using the caller's resolved
    tenant context and role.
    """
    app_query = text(
        f"""
        SELECT a.id, a.name, a.description, a.is_default, a.area,
               COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code,
               {self._APP_TENANT_NAME_SQL} AS tenant_name,
               COALESCE(d.is_public, FALSE) AS dataset_public
        FROM rag_chat_app a
        LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
        WHERE a.deleted_at IS NULL
          AND a.status = 1
          AND (:only_default = FALSE OR a.is_default = TRUE)
        ORDER BY a.sort_order ASC, a.created_at DESC
        """
    )
    async with GetAsyncSession() as session:
        await self._ensure_rag_chat_schema(session)
        fetched = await session.execute(app_query, {"only_default": only_default})
        app_rows = fetched.mappings().all()

    context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
    result: list[RagChatAppVO] = []
    for mapping in app_rows:
        record = dict(mapping)
        visible = await self._app_visible(record, tenant_context=context, user_role=user_role)
        if not visible:
            continue
        result.append(
            RagChatAppVO(
                appId=str(record["id"]),
                appName=record["name"],
                description=record.get("description") or "",
                tenantCode=str(record.get("tenant_code") or ""),
                tenantName=str(record.get("tenant_name") or record.get("area") or ""),
                isDefault=bool(record.get("is_default")),
            )
        )
    return result
|
||
|
||
async def _resolve_app(
    self,
    app_id: int | None,
    user_area: str | None,
    user_role: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
) -> dict | None:
    """Pick the chat app to use for a request, as a raw row dict.

    Resolution order:
    1. The explicitly requested *app_id*, when it is enabled and visible.
    2. The first visible app flagged ``is_default``.
    3. The first visible enabled app.

    Candidates are ordered by ``sort_order`` ASC, ``created_at`` DESC.
    Returns None when the caller can see no enabled app.

    Steps 2 and 3 scan all candidates instead of checking only the first row.
    With several default apps (e.g. one per tenant), that first row may belong
    to another tenant even though a visible app exists. This matches how
    ``_load_apps`` selects apps.
    """
    tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
    base_sql = f"""
        SELECT a.id, a.name, a.description, a.area,
               COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code,
               {self._APP_TENANT_NAME_SQL} AS tenant_name,
               a.dataset_id, a.system_prompt,
               a.llm_model, a.temperature, a.max_tokens, a.opening_statement,
               a.suggested_questions, a.is_default, COALESCE(d.is_public, FALSE) AS dataset_public,
               COALESCE(d.name, '') AS dataset_name
        FROM rag_chat_app a
        LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
        WHERE a.deleted_at IS NULL AND a.status = 1
    """
    async with GetAsyncSession() as session:
        await self._ensure_rag_chat_schema(session)
        if app_id is not None:
            row = (
                await session.execute(
                    text(base_sql + " AND a.id = :app_id LIMIT 1"),
                    {"app_id": app_id},
                )
            ).mappings().first()
            if row:
                requested = dict(row)
                if await self._app_visible(requested, tenant_context=tenant_context, user_role=user_role):
                    return requested
        candidates = (
            await session.execute(
                text(base_sql + " ORDER BY a.sort_order ASC, a.created_at DESC")
            )
        ).mappings().all()

    fallback: dict | None = None
    for row in candidates:
        record = dict(row)
        if not await self._app_visible(record, tenant_context=tenant_context, user_role=user_role):
            continue
        if bool(record.get("is_default")):
            return record
        if fallback is None:
            fallback = record
    return fallback
|
||
|
||
async def _app_visible(self, row: dict, tenant_context: dict, user_role: str | None) -> bool:
    """Decide whether an app row is visible to the caller.

    Visible when any of these holds:
    - the role is global (``_role_is_global``);
    - the app's dataset is public;
    - the app's tenant_code is PUBLIC (case/space-insensitive);
    - otherwise, the row matches the caller's tenant scope
      (``_row_matches_tenant_scope``).
    """
    if self._role_is_global(user_role):
        return True
    public_dataset = bool(row.get("dataset_public"))
    public_tenant = str(row.get("tenant_code") or "").strip().upper() == "PUBLIC"
    return public_dataset or public_tenant or self._row_matches_tenant_scope(
        row_tenant_code=row.get("tenant_code"),
        row_area=row.get("area"),
        tenant_context=tenant_context,
    )
|
||
|
||
async def _resolve_tenant_context(
    self,
    user_area: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
) -> dict[str, str | None]:
    """Resolve the caller's tenant via TenantResolver (source "rag_chat_user").

    Returns a dict with the resolved ``tenant_code``, ``tenant_name`` and
    ``tenant_type``, plus the caller's raw ``area``.
    """
    ctx = await self.TenantResolver.ResolveUserContext(
        Area=user_area,
        TenantCode=tenant_code,
        TenantName=tenant_name,
        Source="rag_chat_user",
    )
    return dict(
        tenant_code=ctx.tenant_code,
        tenant_name=ctx.tenant_name,
        tenant_type=ctx.tenant_type,
        area=user_area,
    )
|
||
|
||
async def _resolve_record_tenant(self, raw_value: str | None):
    """Resolve a tenant from a value stored on a record (source "rag_chat_record")."""
    resolver = self.TenantResolver
    resolved = await resolver.Resolve(RawValue=raw_value, Source="rag_chat_record")
    return resolved
|
||
|
||
async def _ensure_rag_chat_schema(self, session) -> None:
    """Ensure rag_chat_app has the tenant_code column and its partial index.

    The DDL previously ran on every request and was never committed.
    ``ALTER TABLE ... IF NOT EXISTS`` still takes a table lock that blocks
    concurrent readers and writers. It now runs once per process: the DDL is
    committed and a class-level flag is set. Both call sites invoke this right
    after opening their session, before any other statement, so the commit
    does not publish unrelated work.
    """
    owner = type(self)
    if getattr(owner, "_rag_chat_schema_ready", False):
        return
    await session.execute(text("ALTER TABLE rag_chat_app ADD COLUMN IF NOT EXISTS tenant_code VARCHAR(64) NULL"))
    await session.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chat_app_tenant_code ON rag_chat_app(tenant_code) WHERE deleted_at IS NULL"))
    # Persist the DDL (and release its locks) so later sessions can rely on it.
    await session.commit()
    owner._rag_chat_schema_ready = True
|
||
|
||
@staticmethod
def _tenant_context_is_global(tenant_context: dict[str, str | None]) -> bool:
    """True when the resolved tenant code is PUBLIC or PROVINCIAL (case/space-insensitive)."""
    code = tenant_context.get("tenant_code") or ""
    normalized = str(code).strip().upper()
    return normalized == "PUBLIC" or normalized == "PROVINCIAL"
|
||
|
||
@staticmethod
def _role_is_global(user_role: str | None) -> bool:
    """True for roles that bypass tenant scoping: super_admin, provincial_admin."""
    return str(user_role or "").strip() in ("super_admin", "provincial_admin")
|
||
|
||
def _row_matches_tenant_scope(
    self,
    *,
    row_tenant_code: str | None,
    row_area: str | None,
    tenant_context: dict[str, str | None],
) -> bool:
    """Check whether a row belongs to the caller's tenant.

    When the caller has a tenant code, the row's (trimmed) tenant code must
    equal it. Otherwise the row's trimmed area is compared with the caller's
    area. None and "" are treated alike.
    """
    def _norm(value) -> str:
        return str(value or "").strip()

    expected_code = _norm(tenant_context.get("tenant_code"))
    if not expected_code:
        return _norm(row_area) == _norm(tenant_context.get("area"))
    return _norm(row_tenant_code) == expected_code
|
||
|
||
async def _ensure_conversation(
    self,
    user_id: int,
    conversation_id: str | None,
    app_id: int | None,
    user_area: str | None,
    user_role: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
) -> str:
    """Return a conversation id the caller can post to, creating one if needed.

    A non-empty *conversation_id* other than "-1" that refers to an existing,
    non-deleted conversation is returned as stored, after an access check that
    includes *app_id*. If no such conversation exists, a new conversation is
    inserted (named DEFAULT_CONVERSATION_NAME) and its fresh UUID returned.

    Raises:
        LeauditException: 403 when the existing conversation may not be used.
    """
    tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
    wants_existing = bool(conversation_id) and conversation_id != "-1"
    if wants_existing:
        async with GetAsyncSession() as session:
            lookup = await session.execute(
                text(
                    """
                    SELECT conversation_id, user_id
                    FROM rag_conversation
                    WHERE conversation_id = :conversation_id
                      AND deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"conversation_id": conversation_id},
            )
            found = lookup.mappings().first()
            if found is not None:
                existing_id = str(found["conversation_id"])
                allowed = await self._conversation_accessible(
                    conversation_id=existing_id,
                    expected_user_id=user_id,
                    tenant_context=tenant_context,
                    user_role=user_role,
                    app_id=app_id,
                    session=session,
                )
                if not allowed:
                    raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权使用该会话")
                return existing_id

    new_id = str(uuid.uuid4())
    async with GetAsyncSession() as session, session.begin():
        await session.execute(
            text(
                """
                INSERT INTO rag_conversation (conversation_id, user_id, app_id, name, introduction)
                VALUES (:conversation_id, :user_id, :app_id, :name, '')
                """
            ),
            {
                "conversation_id": new_id,
                "user_id": user_id,
                "app_id": app_id,
                "name": DEFAULT_CONVERSATION_NAME,
            },
        )
    return new_id
|
||
|
||
async def _ensure_conversation_owner(
    self,
    *,
    user_id: int,
    conversation_id: str,
    user_area: str | None,
    user_role: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
) -> None:
    """Raise unless the caller may access *conversation_id*.

    Raises:
        LeauditException: 403 when access is denied; 404 (from the access
            check) when the conversation does not exist or is deleted.
    """
    context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
    allowed = await self._conversation_accessible(
        conversation_id=conversation_id,
        expected_user_id=user_id,
        tenant_context=context,
        user_role=user_role,
    )
    if allowed:
        return
    raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该会话")
|
||
|
||
async def _conversation_accessible(
    self,
    *,
    conversation_id: str,
    expected_user_id: int,
    tenant_context: dict[str, str | None],
    user_role: str | None,
    app_id: int | None = None,
    session=None,
) -> bool:
    """Check whether the caller may use a conversation.

    Delegates to ``_conversation_accessible_with_session``, reusing *session*
    when one is given and otherwise opening (and closing) a session of its own.
    """
    check = dict(
        conversation_id=conversation_id,
        expected_user_id=expected_user_id,
        tenant_context=tenant_context,
        user_role=user_role,
        app_id=app_id,
    )
    if session is None:
        async with GetAsyncSession() as own_session:
            return await self._conversation_accessible_with_session(session=own_session, **check)
    return await self._conversation_accessible_with_session(session=session, **check)
|
||
|
||
async def _conversation_accessible_with_session(
    self,
    *,
    session,
    conversation_id: str,
    expected_user_id: int,
    tenant_context: dict[str, str | None],
    user_role: str | None,
    app_id: int | None = None,
) -> bool:
    """Access check for a conversation using an existing session.

    Returns False when:
    - the conversation belongs to another user; or
    - *app_id* is given, the conversation has an app, and the two differ.

    Otherwise visibility of the conversation's app (joined with its dataset,
    both ignoring soft-deleted rows) is decided by ``_app_visible``.

    Raises:
        LeauditException: 404 when the conversation does not exist or is deleted.
    """
    lookup = text(
        """
        SELECT
            c.conversation_id,
            c.user_id,
            c.app_id,
            a.area,
            COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code,
            COALESCE(d.is_public, FALSE) AS dataset_public
        FROM rag_conversation c
        LEFT JOIN rag_chat_app a ON a.id = c.app_id AND a.deleted_at IS NULL
        LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
        WHERE c.conversation_id = :conversation_id
          AND c.deleted_at IS NULL
        LIMIT 1
        """
    )
    fetched = await session.execute(lookup, {"conversation_id": conversation_id})
    record = fetched.mappings().first()
    if record is None:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "会话不存在")

    if int(record["user_id"]) != expected_user_id:
        return False

    stored_app = record.get("app_id")
    if app_id is not None and stored_app is not None and int(stored_app) != int(app_id):
        return False

    visibility_row = {
        "tenant_code": record.get("tenant_code"),
        "area": record.get("area"),
        "dataset_public": record.get("dataset_public"),
    }
    return await self._app_visible(visibility_row, tenant_context=tenant_context, user_role=user_role)
|
||
|
||
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
    """Retrieve up to ``top_k`` context chunks for *query* from a dataset.

    Tries vector search first (embed the query, query the dataset's Chroma
    collection, apply the optional score threshold, then
    ``_hydrate_document_hits``). If that raises or yields no chunks, it falls
    back to ``_keyword_retrieve_context``.

    Returns:
        ``(chunks, dataset_name)``. ``([], "")`` when *dataset_id* is falsy or
        the dataset row is missing or deleted. ``([], dataset_name)`` when both
        strategies fail.
    """
    import logging  # local import keeps this block self-contained

    logger = logging.getLogger(__name__)

    if not dataset_id:
        return [], ""
    async with GetAsyncSession() as session:
        dataset = (
            await session.execute(
                text(
                    """
                    SELECT id, name, collection_name, retrieval_model, embedding_model
                    FROM rag_dataset
                    WHERE id = :dataset_id AND deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"dataset_id": dataset_id},
            )
        ).mappings().first()
    if not dataset:
        return [], ""

    dataset_name = dataset.get("name") or ""
    retrieval_model = dataset.get("retrieval_model") or {}
    top_k = int(retrieval_model.get("top_k") or 5)
    score_threshold: float | None = None
    if retrieval_model.get("score_threshold_enabled"):
        try:
            score_threshold = float(retrieval_model.get("score_threshold"))
        except (TypeError, ValueError):
            score_threshold = None

    try:
        query_embedding = await self._embed_texts([query], dataset.get("embedding_model") or "")
        collection = get_chroma().get_or_create_collection(dataset["collection_name"])
        result = collection.query(
            query_embeddings=query_embedding,
            n_results=max(top_k, 1),
            include=["documents", "metadatas", "distances"],
        )
        # One inner list per query embedding; exactly one was sent. Any of
        # them may be None, hence the `or []` guards.
        ids = (result.get("ids") or [[]])[0] or []
        docs = (result.get("documents") or [[]])[0] or []
        metas = (result.get("metadatas") or [[]])[0] or []
        distances = (result.get("distances") or [[]])[0] or []

        chunks: list[dict] = []
        for idx, doc in enumerate(docs):
            # Chroma reports None for entries stored without metadata.
            meta = (metas[idx] if idx < len(metas) else None) or {}
            raw_distance = distances[idx] if idx < len(distances) else None
            dist = float(raw_distance) if raw_distance is not None else 1.0
            # Map distance (>= 0) to a similarity-like score in (0, 1].
            score = 1.0 / (1.0 + max(dist, 0.0))
            if score_threshold is not None and score < score_threshold:
                continue
            # 0 is a valid chunk index; only fall back to the result position
            # when the metadata has no usable index.
            try:
                chunk_index = int(meta["chunk_index"])
            except (KeyError, TypeError, ValueError):
                chunk_index = idx
            chunks.append(
                {
                    "id": str(ids[idx] if idx < len(ids) else meta.get("id") or idx),
                    "text": doc,
                    "source": meta.get("source") or meta.get("document_name") or dataset_name,
                    "score": score,
                    "chunk_index": chunk_index,
                    "document_name": meta.get("document_name") or meta.get("source") or "",
                    "document_id": meta.get("document_id"),
                    "page": meta.get("page"),
                }
            )
        chunks = await self._hydrate_document_hits(dataset_id, chunks)
        if chunks:
            return chunks[:top_k], dataset_name
    except Exception:
        # Deliberate best effort: any vector-search failure falls through to
        # keyword retrieval, but it is no longer silent.
        logger.warning(
            "Vector retrieval failed for dataset %s; falling back to keyword search",
            dataset_id,
            exc_info=True,
        )

    try:
        chunks = await self._keyword_retrieve_context(
            dataset_id=dataset_id,
            collection_name=str(dataset["collection_name"]),
            dataset_name=str(dataset.get("name") or ""),
            query=query,
            top_k=top_k,
            score_threshold=score_threshold,
        )
        return chunks[:top_k], dataset_name
    except Exception:
        logger.warning("Keyword retrieval failed for dataset %s", dataset_id, exc_info=True)
        return [], dataset_name
|
||
|
||
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
    """Embed ``texts`` through an OpenAI-compatible embeddings endpoint.

    Endpoint, key and default model come from ``RAG_CONFIG`` (``EMBED_URL``,
    ``EMBED_KEY``, ``EMBED_MODEL``), falling back to the chat LLM base URL /
    key and ``"text-embedding-v4"``. Texts are sent in batches of
    ``EMBED_BATCH_SIZE`` (default 10, minimum 1).

    Args:
        texts: Texts to embed; the result has one vector per text, in order.
        model_name: Embedding model override; empty string uses the default.

    Returns:
        The embedding vectors.

    Raises:
        LeauditException: when no endpoint/key is configured, the HTTP call
            fails (non-2xx status or transport error), the body is not JSON,
            or the number of returned vectors does not match the input.
    """
    embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"])
    embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
    embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
    batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
    if not embed_url or not embed_key:
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")
    if not texts:
        return []

    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {embed_key}",
    }
    embeddings: list[list[float]] = []
    async with httpx.AsyncClient(timeout=120.0) as client:
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            try:
                response = await client.post(
                    embed_url,
                    headers=headers,
                    json={"model": embed_model, "input": batch_texts},
                )
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc
            except httpx.RequestError as exc:
                # Timeouts / connection failures get the same domain error as
                # HTTP status failures instead of leaking raw httpx exceptions.
                error_message = str(exc).strip() or type(exc).__name__
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc

            try:
                payload = response.json()
            except ValueError as exc:
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {response.text.strip()[:300]}",
                ) from exc

            rows = payload.get("data") if isinstance(payload, dict) else None
            rows = [row for row in (rows or []) if isinstance(row, dict)]
            # OpenAI-compatible responses tag each vector with the position of
            # its input; order by it when every row provides one.
            if all(isinstance(row.get("index"), int) for row in rows):
                rows.sort(key=lambda row: row["index"])
            batch_embeddings = [row.get("embedding") for row in rows if row.get("embedding")]
            if len(batch_embeddings) != len(batch_texts):
                raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
            embeddings.extend(batch_embeddings)

    if len(embeddings) != len(texts):
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
    return embeddings
|
||
|
||
async def _start_message_task(
    self,
    *,
    task_id: str,
    conversation_id: str,
    message_id: str,
    query: str,
    app: dict,
) -> None:
    """Reset the in-memory buffers for ``task_id`` and spawn its answer task.

    The event list is emptied, the done flag cleared and a per-task lock
    ensured; then ``_run_message_task`` is scheduled with ``asyncio.create_task``
    and the resulting Task is stored in ``_message_tasks[task_id]``.
    """
    self._task_events[task_id] = []
    self._task_done[task_id] = False
    self._task_locks.setdefault(task_id, asyncio.Lock())

    answer_coroutine = self._run_message_task(
        task_id=task_id,
        conversation_id=conversation_id,
        message_id=message_id,
        query=query,
        app=app,
    )
    # Keeping a reference to the Task prevents it from being garbage-collected
    # while it runs (see asyncio.create_task docs).
    self._message_tasks[task_id] = asyncio.create_task(answer_coroutine)
|
||
|
||
async def _run_message_task(
    self,
    *,
    task_id: str,
    conversation_id: str,
    message_id: str,
    query: str,
    app: dict,
) -> None:
    """Generate one assistant answer in the background and persist it.

    Retrieves context for the app's dataset, streams the answer from
    ``generate_stream``, buffers SSE events into ``_task_events[task_id]``
    (via ``_append_task_event``) and writes partial content to the message
    row every 80 new characters or 0.5 s.

    Outcomes written to the message metadata ``status``:
      * ``completed`` - no ``error`` event was streamed. Follow-up questions
        are attached to the held-back ``message_end`` event, the record is
        finalized, and auto-titling is attempted (best effort).
      * ``error`` - the stream emitted an ``error`` event or an exception
        was raised.
      * ``stopped`` - the task was cancelled; ``CancelledError`` is re-raised.

    On exit ``_task_done[task_id]`` is set to True and the task's entries in
    ``_message_tasks`` and ``_task_locks`` are removed.
    """
    context_chunks: list[dict] = []
    dataset_name = ""
    collected_answer = ""
    message_end_payload: dict | None = None
    final_status = "completed"
    error_payload: dict | None = None
    last_persisted_length = 0
    last_persisted_at = time.monotonic()

    try:
        context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), query)
        async for chunk in generate_stream(
            query=query,
            context_chunks=context_chunks,
            conversation_id=conversation_id,
            message_id=message_id,
            system_prompt=app.get("system_prompt") or "",
            model=app.get("llm_model") or "",
            temperature=app.get("temperature"),
            max_tokens=app.get("max_tokens"),
            dataset_name=dataset_name,
            task_id=task_id,
        ):
            data = self._parse_sse_event(chunk)
            if not data:
                continue

            event = data.get("event")
            if event == "message":
                collected_answer += data.get("answer") or ""
                now = time.monotonic()
                # Throttle DB writes: persist after 80 new chars or 0.5 s.
                if len(collected_answer) - last_persisted_length >= 80 or now - last_persisted_at >= 0.5:
                    await self._update_message_progress(
                        conversation_id=conversation_id,
                        message_id=message_id,
                        content=collected_answer,
                        metadata={"status": "running", "task_id": task_id},
                    )
                    last_persisted_length = len(collected_answer)
                    last_persisted_at = now
                await self._append_task_event(task_id, data)
                continue
            if event == "message_end":
                # Held back so follow-up questions can be attached before it is published.
                message_end_payload = data
                continue
            if event == "error":
                final_status = "error"
                error_payload = data
            await self._append_task_event(task_id, data)

        if final_status == "completed":
            try:
                followups = await generate_followups(query, collected_answer)
            except Exception:
                # Follow-up suggestions are optional.
                followups = []

            if message_end_payload:
                end_metadata = message_end_payload.get("metadata")
                if not isinstance(end_metadata, dict):
                    end_metadata = {}
                    message_end_payload["metadata"] = end_metadata
                end_metadata["suggested_questions"] = followups
                await self._append_task_event(task_id, message_end_payload)

            await self._finalize_message_record(
                conversation_id=conversation_id,
                message_id=message_id,
                content=collected_answer,
                sources=self._build_sources(context_chunks, dataset_name),
                metadata={"suggested_questions": followups, "status": "completed", "task_id": task_id},
            )
            try:
                await self._maybe_schedule_auto_title(
                    conversation_id=conversation_id,
                    message_id=message_id,
                    query=query,
                    answer=collected_answer,
                )
            except Exception:
                # Titling is best effort. The answer is already stored as
                # completed, so a failure here must not reach the generic
                # handler below, which would rewrite the message as an error.
                pass
        else:
            await self._finalize_message_record(
                conversation_id=conversation_id,
                message_id=message_id,
                content=collected_answer,
                sources=self._build_sources(context_chunks, dataset_name),
                metadata={
                    "suggested_questions": [],
                    "status": "error",
                    "task_id": task_id,
                    "error": (error_payload or {}).get("message", ""),
                },
            )
    except asyncio.CancelledError:
        await self._append_task_event(
            task_id,
            {
                "event": "error",
                "task_id": task_id,
                "message_id": message_id,
                "code": "message_stopped",
                "message": "用户已停止回答",
            },
        )
        await self._finalize_message_record(
            conversation_id=conversation_id,
            message_id=message_id,
            content=collected_answer,
            sources=self._build_sources(context_chunks, dataset_name),
            metadata={"suggested_questions": [], "status": "stopped", "task_id": task_id},
        )
        raise
    except Exception as exc:
        await self._append_task_event(
            task_id,
            {
                "event": "error",
                "task_id": task_id,
                "message_id": message_id,
                "code": "server_error",
                "message": str(exc),
            },
        )
        await self._finalize_message_record(
            conversation_id=conversation_id,
            message_id=message_id,
            content=collected_answer,
            sources=self._build_sources(context_chunks, dataset_name),
            metadata={"suggested_questions": [], "status": "error", "task_id": task_id, "error": str(exc)},
        )
    finally:
        self._task_done[task_id] = True
        self._message_tasks.pop(task_id, None)
        self._task_locks.pop(task_id, None)
|
||
|
||
async def _append_task_event(self, task_id: str, payload: dict) -> None:
|
||
lock = self._task_locks.setdefault(task_id, asyncio.Lock())
|
||
async with lock:
|
||
self._task_events.setdefault(task_id, []).append(payload)
|
||
|
||
async def _finalize_message_record(
    self,
    *,
    conversation_id: str,
    message_id: str,
    content: str,
    sources: list[dict],
    metadata: dict,
) -> None:
    """Store the final content, sources and metadata of an assistant message.

    In one transaction this updates the assistant ``rag_message`` row matched
    by ``message_id`` + ``conversation_id`` and sets ``updated_at`` /
    ``last_message_at`` to ``NOW()`` on the owning ``rag_conversation``.
    ``sources`` and ``metadata`` are JSON-encoded (non-ASCII kept) and cast
    to ``jsonb``.
    """
    message_sql = text(
        """
        UPDATE rag_message
        SET content = :content,
            sources = CAST(:sources AS jsonb),
            metadata = CAST(:metadata AS jsonb)
        WHERE message_id = :message_id
          AND conversation_id = :conversation_id
          AND role = 'assistant'
        """
    )
    conversation_sql = text(
        """
        UPDATE rag_conversation
        SET updated_at = NOW(),
            last_message_at = NOW()
        WHERE conversation_id = :conversation_id
        """
    )
    message_params = {
        "conversation_id": conversation_id,
        "message_id": message_id,
        "content": content,
        "sources": json.dumps(sources, ensure_ascii=False),
        "metadata": json.dumps(metadata, ensure_ascii=False),
    }

    async with GetAsyncSession() as session, session.begin():
        await session.execute(message_sql, message_params)
        await session.execute(conversation_sql, {"conversation_id": conversation_id})
|
||
|
||
async def _update_message_progress(
    self,
    *,
    conversation_id: str,
    message_id: str,
    content: str,
    metadata: dict,
) -> None:
    """Persist a partial (still streaming) assistant answer.

    Overwrites ``content`` and ``metadata`` (JSON-encoded, cast to ``jsonb``)
    of the assistant ``rag_message`` row matched by ``message_id`` and
    ``conversation_id``. ``sources`` and the conversation row are untouched.
    """
    progress_sql = text(
        """
        UPDATE rag_message
        SET content = :content,
            metadata = CAST(:metadata AS jsonb)
        WHERE message_id = :message_id
          AND conversation_id = :conversation_id
          AND role = 'assistant'
        """
    )
    bind_values = {
        "conversation_id": conversation_id,
        "message_id": message_id,
        "content": content,
        "metadata": json.dumps(metadata, ensure_ascii=False),
    }
    async with GetAsyncSession() as session, session.begin():
        await session.execute(progress_sql, bind_values)
|
||
|
||
async def _maybe_schedule_auto_title(
    self,
    *,
    conversation_id: str,
    message_id: str,
    query: str,
    answer: str,
) -> None:
    """Queue background title generation after a conversation's first answer.

    Returns without doing anything when:
      * the query or answer is blank;
      * the conversation is missing or deleted;
      * its title was set manually;
      * it already has a name other than the default;
      * title generation is pending, running or succeeded;
      * ``message_id`` is not the conversation's recorded first answer.

    Otherwise the row is marked ``pending``, recording ``message_id`` as the
    first answer if none is stored yet. ``_run_auto_title_task`` is then
    started unless a title task for this conversation is still running.
    """
    normalized_query = (query or "").strip()
    normalized_answer = (answer or "").strip()
    if not normalized_query or not normalized_answer:
        return

    async with GetAsyncSession() as session:
        row = (
            await session.execute(
                text(
                    """
                    SELECT conversation_id,
                           name,
                           COALESCE(title_source, 'default') AS title_source,
                           COALESCE(title_generation_status, 'idle') AS title_generation_status,
                           first_answer_message_id
                    FROM rag_conversation
                    WHERE conversation_id = :conversation_id
                      AND deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"conversation_id": conversation_id},
            )
        ).mappings().first()

    if not row:
        return

    title_source = str(row.get("title_source") or DEFAULT_TITLE_SOURCE)
    if title_source == MANUAL_TITLE_SOURCE:
        return

    current_name = str(row.get("name") or "").strip()
    # The auto-title UPDATE only overwrites a NULL, blank or default name.
    # Scheduling for any other name, whatever title_source says, would
    # leave title_generation_status stuck at 'running'.
    if current_name and current_name != DEFAULT_CONVERSATION_NAME:
        return

    current_status = str(row.get("title_generation_status") or "idle")
    if current_status in {"pending", "running", "succeeded"}:
        return

    first_answer_id = row.get("first_answer_message_id")
    if first_answer_id and str(first_answer_id) != message_id:
        return

    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(
                text(
                    """
                    UPDATE rag_conversation
                    SET title_generation_status = 'pending',
                        first_answer_message_id = COALESCE(first_answer_message_id, :message_id),
                        title_generation_error = NULL
                    WHERE conversation_id = :conversation_id
                      AND deleted_at IS NULL
                      AND COALESCE(title_source, 'default') = 'default'
                    """
                ),
                {
                    "conversation_id": conversation_id,
                    "message_id": message_id,
                },
            )

    existing_task = self._title_tasks.get(conversation_id)
    if existing_task and not existing_task.done():
        return

    self._title_tasks[conversation_id] = asyncio.create_task(
        self._run_auto_title_task(
            conversation_id=conversation_id,
            answer_message_id=message_id,
            query=normalized_query,
            answer=normalized_answer,
        )
    )
|
||
|
||
async def _run_auto_title_task(
    self,
    *,
    conversation_id: str,
    answer_message_id: str,
    query: str,
    answer: str,
) -> None:
    """Generate and store an automatic title for a conversation.

    Re-reads the conversation and stops if it is missing or deleted, manually
    titled, or ``answer_message_id`` is not its recorded first answer.
    Otherwise it marks the row ``running`` and asks the LLM for a title,
    falling back to ``_build_fallback_title`` when the sanitized LLM output is
    empty. The title is written only while the conversation still has a
    NULL, blank or default name; the row is then marked ``succeeded``.
    Any ``Exception`` marks the row ``failed`` with the error text truncated
    to 1000 characters. The entry in ``_title_tasks`` is always removed.
    """

    async def _write(statement: str, params: dict) -> None:
        # Each write runs in its own short session and transaction.
        async with GetAsyncSession() as session:
            async with session.begin():
                await session.execute(text(statement), params)

    try:
        async with GetAsyncSession() as session:
            snapshot = (
                await session.execute(
                    text(
                        """
                        SELECT name,
                               COALESCE(title_source, 'default') AS title_source,
                               COALESCE(title_generation_status, 'idle') AS title_generation_status,
                               first_answer_message_id
                        FROM rag_conversation
                        WHERE conversation_id = :conversation_id
                          AND deleted_at IS NULL
                        LIMIT 1
                        """
                    ),
                    {"conversation_id": conversation_id},
                )
            ).mappings().first()

        if not snapshot:
            return
        if str(snapshot.get("title_source") or DEFAULT_TITLE_SOURCE) == MANUAL_TITLE_SOURCE:
            return
        recorded_first_answer = snapshot.get("first_answer_message_id")
        if recorded_first_answer and str(recorded_first_answer) != answer_message_id:
            return

        await _write(
            """
            UPDATE rag_conversation
            SET title_generation_status = 'running',
                title_generation_error = NULL
            WHERE conversation_id = :conversation_id
              AND deleted_at IS NULL
              AND COALESCE(title_source, 'default') = 'default'
            """,
            {"conversation_id": conversation_id},
        )

        llm_title = await self._generate_conversation_title(query=query, answer=answer)
        final_title = self._sanitize_generated_title(llm_title) or self._build_fallback_title(
            query=query, answer=answer
        )
        if not final_title:
            raise ValueError("未生成有效标题")

        await _write(
            """
            UPDATE rag_conversation
            SET name = :name,
                title_source = 'auto',
                title_generation_status = 'succeeded',
                title_generated_at = NOW(),
                title_generation_error = NULL
            WHERE conversation_id = :conversation_id
              AND deleted_at IS NULL
              AND COALESCE(title_source, 'default') = 'default'
              AND (
                  name IS NULL
                  OR BTRIM(name) = ''
                  OR name = :default_name
              )
            """,
            {
                "conversation_id": conversation_id,
                "name": final_title,
                "default_name": DEFAULT_CONVERSATION_NAME,
            },
        )
    except Exception as exc:
        await _write(
            """
            UPDATE rag_conversation
            SET title_generation_status = 'failed',
                title_generation_error = :error
            WHERE conversation_id = :conversation_id
              AND deleted_at IS NULL
              AND COALESCE(title_source, 'default') = 'default'
            """,
            {
                "conversation_id": conversation_id,
                "error": str(exc)[:1000],
            },
        )
    finally:
        self._title_tasks.pop(conversation_id, None)
|
||
|
||
async def _generate_conversation_title(self, *, query: str, answer: str) -> str:
    """Ask the configured chat LLM for a short Chinese conversation title.

    The prompt includes at most the first 500 characters of ``query`` and the
    first 1500 characters of ``answer``. The request goes to the
    chat-completions URL built from ``RAG_CONFIG["LLM_BASE_URL"]`` with
    ``LLM_MODEL`` / ``LLM_API_KEY``, temperature 0.2, at most 80 tokens and
    ``enable_thinking`` disabled. The timeout is 30 s.

    Returns:
        The stripped ``content`` of the first choice (may be empty).

    Raises:
        httpx.HTTPStatusError: on a non-2xx response. Transport errors and
            lookup errors on an unexpected response body also propagate.
    """
    # Real newlines separate the instructions from the inputs; an escaped
    # "\\n" would reach the model as a literal backslash followed by "n".
    prompt = (
        "请基于用户首轮提问和助手首轮回答,生成一个简洁、准确的中文会话标题。"
        "要求:"
        "1. 只输出标题本身;"
        "2. 不要标点结尾;"
        "3. 不要出现“关于”“用户询问”“问题解答”等空话;"
        "4. 优先 12-24 个中文字符,最长不超过 40 个字符。\n"
        f"用户问题:{query[:500]}\n"
        f"助手回答:{answer[:1500]}"
    )
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(
            build_openai_chat_completions_url(RAG_CONFIG["LLM_BASE_URL"]),
            json={
                "model": RAG_CONFIG["LLM_MODEL"],
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.2,
                "max_tokens": 80,
                "chat_template_kwargs": {"enable_thinking": False},
            },
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}",
            },
        )
        resp.raise_for_status()
        content = resp.json()["choices"][0]["message"]["content"]
    return str(content or "").strip()
|
||
|
||
def _sanitize_generated_title(self, value: str) -> str:
|
||
title = str(value or "").strip()
|
||
if not title:
|
||
return ""
|
||
|
||
title = title.replace("\r", " ").replace("\n", " ").strip()
|
||
title = re.sub(r"^```[\w-]*", "", title).strip()
|
||
title = title.replace("```", "").strip()
|
||
title = re.sub(r'^[\"“”\'‘’\[\]()()【】]+', "", title)
|
||
title = re.sub(r'[\"“”\'‘’\[\]()()【】]+$', "", title)
|
||
title = re.sub(r"^(标题|会话标题)[::]\\s*", "", title)
|
||
title = re.sub(r"\\s+", " ", title).strip()
|
||
title = title.rstrip("。!?!?.;;,,::")
|
||
|
||
if title in {"新对话", "会话标题", "标题"}:
|
||
return ""
|
||
|
||
if len(title) > 40:
|
||
title = title[:40].rstrip(",,。;;:: ")
|
||
|
||
return title
|
||
|
||
def _build_fallback_title(self, *, query: str, answer: str) -> str:
|
||
base = query.strip() or answer.strip()
|
||
if not base:
|
||
return ""
|
||
|
||
base = re.sub(r"\\s+", " ", base)
|
||
base = re.sub(r"^[请问帮我一下关于针对结合根据请解释说明一下::,,\\s]+", "", base)
|
||
base = base.rstrip("。!?!?.;;,,::")
|
||
if len(base) > 24:
|
||
base = base[:24].rstrip(",,。;;:: ")
|
||
return base
|
||
|
||
async def _keyword_retrieve_context(
    self,
    *,
    dataset_id: int,
    collection_name: str,
    dataset_name: str,
    query: str,
    top_k: int,
    score_threshold: float | None,
) -> list[dict]:
    """Keyword-matching fallback retrieval over every chunk of a Chroma collection.

    Terms come from ``_build_keyword_terms(query)``; with no terms the
    collection is not read and ``[]`` is returned. Each chunk is scored with
    ``_score_keyword_chunk``. Non-positive scores, and scores below
    ``score_threshold`` when one is given, are dropped. The rest are sorted
    by score (desc) then chunk index (asc). Up to ``max(3 * top_k, top_k)``
    candidates go through ``_hydrate_document_hits``, which can drop chunks
    of disabled documents, and at most ``top_k`` are returned.

    Returns:
        Chunk dicts with ``id``, ``text``, ``source``, ``score``,
        ``chunk_index``, ``document_name``, ``document_id`` and ``page``
        (plus fields added during hydration).
    """
    terms = self._build_keyword_terms(query)
    if not terms:
        # Nothing to match: avoid loading the entire collection.
        return []

    collection = get_chroma().get_or_create_collection(collection_name)
    raw = collection.get(include=["documents", "metadatas"])
    ids = raw.get("ids") or []
    docs = raw.get("documents") or []
    metas = raw.get("metadatas") or []

    scored_chunks: list[dict] = []
    for idx, chunk_id in enumerate(ids):
        doc = (docs[idx] if idx < len(docs) else "") or ""
        meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {}
        score = self._score_keyword_chunk(
            query=query,
            terms=terms,
            content=doc,
            document_name=str(meta.get("document_name") or meta.get("source") or ""),
        )
        if score <= 0:
            continue
        if score_threshold is not None and score < score_threshold:
            continue
        raw_index = meta.get("chunk_index")
        # 0 is a valid chunk index; fall back to the position only when absent.
        chunk_index = idx if raw_index is None or raw_index == "" else int(raw_index)
        scored_chunks.append(
            {
                "id": str(chunk_id),
                "text": doc,
                "source": meta.get("source") or meta.get("document_name") or dataset_name,
                "score": score,
                "chunk_index": chunk_index,
                "document_name": meta.get("document_name") or meta.get("source") or "",
                "document_id": meta.get("document_id"),
                "page": meta.get("page"),
            }
        )

    scored_chunks.sort(key=lambda item: (-float(item.get("score") or 0.0), int(item.get("chunk_index") or 0)))
    # NOTE(review): _hydrate_document_hits increments hit_count for every
    # enabled document among the candidates passed in. Documents that only
    # appear in the extra (beyond top_k) candidates are counted as hits too.
    hydrated = await self._hydrate_document_hits(dataset_id, scored_chunks[: max(top_k * 3, top_k)])
    return hydrated[:top_k]
|
||
|
||
def _build_keyword_terms(self, query: str) -> list[str]:
    """Turn ``query`` into at most 20 unique keyword terms, longest first.

    The normalized query is split into runs of CJK characters, ASCII letters
    and digits, and runs that are filler words are dropped. Every purely CJK
    run also contributes its 2-, 3- and 4-character substrings (only for
    sizes shorter than the run), again skipping filler words. Terms are
    stably sorted by descending length and de-duplicated, keeping the first
    occurrence.
    """
    text_runs = [
        piece.strip()
        for piece in re.findall(r"[\u4e00-\u9fffA-Za-z0-9]+", self._normalize_keyword_query(query))
        if piece.strip()
    ]
    if not text_runs:
        return []

    ignored = {
        "什么",
        "请问",
        "一下",
        "有关",
        "关于",
        "如何",
        "哪些",
        "怎么",
        "是否",
        "规定",
        "办法",
        "条例",
        "法律",
    }
    candidates: list[str] = []
    for run in text_runs:
        if run in ignored:
            continue
        candidates.append(run)
        if not re.fullmatch(r"[\u4e00-\u9fff]+", run):
            continue
        for width in (2, 3, 4):
            if width >= len(run):
                continue
            grams = (run[pos:pos + width] for pos in range(len(run) - width + 1))
            candidates.extend(gram for gram in grams if gram not in ignored)

    longest_first = sorted(candidates, key=len, reverse=True)
    return list(dict.fromkeys(term for term in longest_first if term))[:20]
|
||
|
||
def _normalize_keyword_query(self, query: str) -> str:
|
||
normalized = (query or "").strip().lower()
|
||
patterns = [
|
||
"是什么",
|
||
"什么是",
|
||
"有哪些",
|
||
"有什么",
|
||
"是什么?",
|
||
"是什么?",
|
||
"请问",
|
||
"介绍一下",
|
||
"解释一下",
|
||
"帮我分析",
|
||
"帮我看看",
|
||
]
|
||
for pattern in patterns:
|
||
normalized = normalized.replace(pattern, " ")
|
||
return re.sub(r"\s+", " ", normalized).strip()
|
||
|
||
def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float:
    """Score a chunk by keyword overlap, returning a value in [0, 0.99].

    The searched text is the document name, a newline and the content,
    lower-cased.

    * Returns 0.98 when the whole normalized query occurs in it.
    * Otherwise returns the fraction of term weight found, where each term
      weighs length squared. Each matched term that also appears in the
      document name adds ``min(0.15, 0.03 * len(term))``. The total is capped
      at 0.99 and rounded to 6 decimals.
    * Returns 0.0 when ``terms`` is empty.
    """
    haystack = f"{document_name}\n{content}".lower()

    normalized_query = self._normalize_keyword_query(query)
    if normalized_query and normalized_query in haystack:
        return 0.98

    lowered_name = document_name.lower()
    total_weight = 0.0
    matched_weight = 0.0
    name_bonus = 0.0
    for term in terms:
        needle = term.lower()
        weight = float(max(len(term), 1) ** 2)
        total_weight += weight
        if needle not in haystack:
            continue
        matched_weight += weight
        if needle in lowered_name:
            name_bonus += min(0.15, 0.03 * len(term))

    if total_weight <= 0:
        return 0.0
    return round(min(matched_weight / total_weight + name_bonus, 0.99), 6)
|
||
|
||
def _format_sse(self, payload: dict) -> bytes:
|
||
return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8")
|
||
|
||
def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]:
|
||
return [
|
||
{
|
||
"position": index + 1,
|
||
"dataset_id": str(chunk.get("dataset_id") or ""),
|
||
"dataset_name": dataset_name,
|
||
"document_id": str(chunk.get("document_id") or ""),
|
||
"document_name": chunk.get("document_name") or chunk.get("source", ""),
|
||
"data_source_type": "upload_file",
|
||
"segment_id": chunk.get("id", ""),
|
||
"retriever_from": "rag",
|
||
"score": round(chunk.get("score", 0.0), 4),
|
||
"hit_count": chunk.get("hit_count", 0),
|
||
"word_count": len(chunk.get("text", "")),
|
||
"segment_position": index + 1,
|
||
"index_node_hash": "",
|
||
"content": chunk.get("text", "")[:500],
|
||
"page": None,
|
||
}
|
||
for index, chunk in enumerate(context_chunks)
|
||
]
|
||
|
||
async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict]) -> list[dict]:
    """Attach ``rag_document`` info to chunks, drop disabled docs and count hits.

    Chunks are matched by name (their ``document_name`` or ``source``)
    against ``original_name`` of the non-deleted documents in ``dataset_id``.

    * Matched chunks are updated in place with ``document_id``,
      ``dataset_id``, ``document_name`` and ``hit_count``.
    * Chunks whose document is disabled are removed.
    * Chunks with no matching document are kept unchanged.

    Each distinct matched, enabled document gets ``hit_count + 1`` (and
    ``updated_at = NOW()``) in one UPDATE. When no chunk has a name, the
    input list is returned as-is without touching the database.
    """

    def _name_of(chunk: dict) -> str:
        # Name used to match a chunk to its rag_document row.
        return str(chunk.get("document_name") or chunk.get("source") or "").strip()

    wanted_names = sorted({name for name in map(_name_of, chunks) if name})
    if not wanted_names:
        return chunks

    async with GetAsyncSession() as session:
        result = await session.execute(
            text(
                """
                SELECT id, original_name, enabled, hit_count
                FROM rag_document
                WHERE dataset_id = :dataset_id
                  AND deleted_at IS NULL
                  AND original_name = ANY(:source_names)
                """
            ),
            {
                "dataset_id": dataset_id,
                "source_names": wanted_names,
            },
        )
        documents_by_name = {str(row["original_name"]): row for row in result.mappings().all()}

    kept: list[dict] = []
    hit_ids: set[int] = set()
    for chunk in chunks:
        document = documents_by_name.get(_name_of(chunk))
        if not document:
            kept.append(chunk)
            continue
        if not bool(document.get("enabled")):
            continue
        chunk.update(
            document_id=document["id"],
            dataset_id=dataset_id,
            document_name=document["original_name"],
            hit_count=document.get("hit_count") or 0,
        )
        hit_ids.add(int(document["id"]))
        kept.append(chunk)

    if hit_ids:
        async with GetAsyncSession() as session:
            async with session.begin():
                await session.execute(
                    text(
                        """
                        UPDATE rag_document
                        SET hit_count = hit_count + 1,
                            updated_at = NOW()
                        WHERE id = ANY(:document_ids)
                        """
                    ),
                    {"document_ids": sorted(hit_ids)},
                )

    return kept
|
||
|
||
def _parse_sse_event(self, chunk: str) -> dict | None:
|
||
data_lines: list[str] = []
|
||
for line in chunk.splitlines():
|
||
if line.startswith("data: "):
|
||
data_lines.append(line[6:])
|
||
elif line.startswith("data:"):
|
||
data_lines.append(line[5:].lstrip())
|
||
|
||
if not data_lines:
|
||
return None
|
||
|
||
payload = "\n".join(part for part in data_lines if part.strip()).strip()
|
||
if not payload or payload == "[DONE]":
|
||
return None
|
||
payload = payload.removesuffix("\\n\\n").removesuffix("\\n").strip()
|
||
|
||
try:
|
||
data = json.loads(payload)
|
||
except json.JSONDecodeError:
|
||
return None
|
||
|
||
return data if isinstance(data, dict) else None
|