Files
leaudit-platform-backend/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py
T

1802 lines
71 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
from __future__ import annotations
import asyncio
import json
import re
import time
import uuid
from typing import AsyncGenerator
import httpx
from sqlalchemy import text
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import (
RagConversationRenameDTO,
RagMessageFeedbackDTO,
RagStopMessageDTO,
)
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
RagAppParametersVO,
RagChatAppListVO,
RagChatAppVO,
RagConversationItemVO,
RagConversationPageVO,
RagConversationRenameVO,
RagMessageItemVO,
RagMessagePageVO,
RagOperationResultVO,
)
from fastapi_modules.fastapi_leaudit.rag_engine.config import (
RAG_CONFIG,
build_openai_chat_completions_url,
build_openai_embeddings_url,
)
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream
from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
from fastapi_modules.fastapi_leaudit.services.impl.tenantResolver import TenantResolver
# Placeholder name given to a freshly created conversation ("新对话" = "New conversation").
DEFAULT_CONVERSATION_NAME = "新对话"
# Values of rag_conversation.title_source, describing where the current title came from.
DEFAULT_TITLE_SOURCE = "default"  # still the placeholder; also used when the column is NULL
AUTO_TITLE_SOURCE = "auto"  # presumably written by the auto-title task (not visible here) — TODO confirm
MANUAL_TITLE_SOURCE = "manual"  # the user renamed the conversation (RenameConversation writes 'manual')
class RagChatServiceImpl(IRagChatService):
    """RAG chat service: chat apps, conversations, messages and streamed answer generation.

    Answers are generated by background asyncio tasks; their progress is shared through
    the class-level dictionaries below, keyed by task id. Because these are class
    attributes they are shared by every instance in this process (not across processes).
    """

    # task_id -> background asyncio.Task producing the answer (registered in _start_message_task).
    _message_tasks: dict[str, asyncio.Task] = {}
    # task_id -> SSE payload dicts emitted so far; SendMessage replays them to the client.
    _task_events: dict[str, list[dict]] = {}
    # task_id -> whether the task has finished emitting events; SendMessage stops polling when True.
    _task_done: dict[str, bool] = {}
    # task_id -> per-task lock created in _start_message_task (its use is not visible in this chunk).
    _task_locks: dict[str, asyncio.Lock] = {}
    # Title-generation tasks; presumably keyed by conversation id — TODO confirm.
    _title_tasks: dict[str, asyncio.Task] = {}
    # One-shot guard so _ensure_rag_chat_schema inspects/migrates rag_chat_app once per process.
    _chat_schema_checked = False
    # Serializes the schema check among coroutines of this process.
    _chat_schema_lock = asyncio.Lock()
def __init__(self) -> None:
    """Create the service with its own TenantResolver used for tenant-scope checks."""
    resolver = TenantResolver()
    self.TenantResolver = resolver
# SQL CASE expression (expects rag_chat_app aliased as `a`) yielding a display name for an
# app's tenant: tenant_code 'PUBLIC' -> '公共' (public), 'PROVINCIAL' -> '省级' (provincial),
# otherwise the app's area, or '未分配地区' ("unassigned area") when the area is blank.
_APP_TENANT_NAME_SQL = (
    "CASE "
    "WHEN NULLIF(BTRIM(a.tenant_code), '') = 'PUBLIC' THEN '公共' "
    "WHEN NULLIF(BTRIM(a.tenant_code), '') = 'PROVINCIAL' THEN '省级' "
    "ELSE COALESCE(NULLIF(BTRIM(a.area), ''), '未分配地区') "
    "END"
)
async def GetApps(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagChatAppListVO:
    """Return every active chat app visible to the caller, together with their count.

    ``CurrentUserId`` is part of the service interface but is not used here.
    """
    visible_apps = await self._load_apps(
        UserArea, UserRole, TenantCode, TenantName, only_default=False
    )
    app_count = len(visible_apps)
    return RagChatAppListVO(data=visible_apps, total=app_count)
async def GetDefaultApp(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagChatAppVO | None:
    """Return the caller's default chat app, falling back to the first visible app.

    ``_load_apps`` returns rows in the same order with or without ``only_default``, and
    ``only_default`` merely keeps rows flagged ``is_default``. Loading once and picking
    the first ``isDefault`` app in Python therefore gives the same answer as querying
    twice, with one database round-trip instead of two.

    Returns:
        The first visible default app, else the first visible app, else None.
    """
    apps = await self._load_apps(UserArea, UserRole, TenantCode, TenantName, only_default=False)
    if not apps:
        return None
    return next((app for app in apps if app.isDefault), apps[0])
async def SendMessage(
    self,
    CurrentUserId: int,
    UserName: str,
    UserArea: str | None,
    UserRole: str | None,
    Query: str,
    ConversationId: str | None,
    AppId: int | None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> AsyncGenerator[bytes, None]:
    """Persist a question, start answer generation and stream its events as SSE frames.

    Steps:
    - validate the query;
    - resolve the chat app;
    - reuse or create the conversation;
    - insert the user message plus a placeholder assistant message (status "running");
    - start the background generation task;
    - replay that task's buffered events until the task is marked done.

    A ``conversation_created`` event (carrying the conversation, message and task ids)
    is emitted first whenever a new conversation was created for this request.

    Raises:
        LeauditException: 400 for an empty query, 404 when no chat app is available,
            403/404 from the conversation access checks.
    """
    if not Query.strip():
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "问题不能为空")
    app = await self._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName)
    if not app:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用")
    conversationId = await self._ensure_conversation(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        app_id=app["id"],
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )

    def _normalize_id(value: str | None) -> str:
        # Loose comparison (case, hyphens, whitespace): a UUID echoed back in another
        # textual form must not be mistaken for a newly created conversation.
        return str(value or "").strip().lower().replace("-", "")

    # _ensure_conversation creates a new conversation not only for a missing/"-1" id but
    # also when the supplied id no longer exists, so derive "new" from its result; the
    # client has to receive the new id in every one of those cases.
    is_new_conversation = _normalize_id(conversationId) != _normalize_id(ConversationId)
    messageId = str(uuid.uuid4())
    taskId = str(uuid.uuid4())
    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(
                text(
                    """
                    INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
                    VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, '{}'::jsonb)
                    """
                ),
                {
                    "message_id": str(uuid.uuid4()),
                    "conversation_id": conversationId,
                    "content": Query,
                },
            )
            # Placeholder answer row; the background task fills content/status in later.
            await session.execute(
                text(
                    """
                    INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
                    VALUES (:message_id, :conversation_id, 'assistant', '', '[]'::jsonb, CAST(:metadata AS jsonb))
                    """
                ),
                {
                    "message_id": messageId,
                    "conversation_id": conversationId,
                    "metadata": json.dumps({"status": "running", "task_id": taskId}, ensure_ascii=False),
                },
            )
            await session.execute(
                text(
                    "UPDATE rag_conversation SET updated_at = NOW() WHERE conversation_id = :conversation_id"
                ),
                {"conversation_id": conversationId},
            )
    await self._start_message_task(
        task_id=taskId,
        conversation_id=conversationId,
        message_id=messageId,
        query=Query,
        app=app,
    )
    event_index = 0
    initial_events: list[dict] = []
    if is_new_conversation:
        initial_events.append(
            {
                "event": "conversation_created",
                "conversation_id": conversationId,
                "message_id": messageId,
                "task_id": taskId,
            }
        )
    # Replay locally generated events first, then the task's buffered events; poll
    # every 50 ms until everything has been sent and the task is flagged done.
    while True:
        if event_index < len(initial_events):
            payload = initial_events[event_index]
            event_index += 1
            yield self._format_sse(payload)
            continue
        events = self._task_events.get(taskId, [])
        if event_index - len(initial_events) < len(events):
            payload = events[event_index - len(initial_events)]
            event_index += 1
            yield self._format_sse(payload)
            continue
        if self._task_done.get(taskId):
            break
        await asyncio.sleep(0.05)
async def GetConversations(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    AppId: int | None,
    Page: int,
    PageSize: int,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagConversationPageVO:
    """Return one page of the caller's non-deleted conversations, most recent first.

    ``Page`` is 1-based; values below 1 are treated as 1. One row beyond ``PageSize``
    is fetched only to decide ``hasMore``.

    The page's rows are then filtered through ``_conversation_accessible`` (app/tenant
    visibility). A page can therefore hold fewer than ``PageSize`` items while
    ``hasMore`` is still True.
    """
    tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName)
    async with GetAsyncSession() as session:
        rows = (
            await session.execute(
                text(
                    """
                    SELECT conversation_id, name, introduction, created_at, updated_at
                    , COALESCE(title_source, 'default') AS title_source
                    , COALESCE(EXTRACT(EPOCH FROM last_message_at), 0) AS last_message_at
                    FROM rag_conversation
                    WHERE user_id = :user_id
                    AND deleted_at IS NULL
                    AND (CAST(:app_id AS BIGINT) IS NULL OR app_id = CAST(:app_id AS BIGINT))
                    ORDER BY COALESCE(last_message_at, updated_at) DESC, updated_at DESC
                    OFFSET :offset LIMIT :limit
                    """
                ),
                {
                    "user_id": CurrentUserId,
                    "app_id": AppId,
                    "offset": max(Page - 1, 0) * PageSize,
                    "limit": PageSize + 1,
                },
            )
        ).mappings().all()
        # Decide hasMore on the raw rows and filter only this page's OFFSET window.
        # Filtering the look-ahead row too would cause two problems. It would slip into
        # this page whenever an earlier row is dropped, and then show up again on the
        # next page. It would also flip hasMore to False while rows still remain.
        has_more = len(rows) > PageSize
        items: list[dict] = []
        for row in rows[:PageSize]:
            record = dict(row)
            if await self._conversation_accessible(
                conversation_id=str(record["conversation_id"]),
                expected_user_id=CurrentUserId,
                tenant_context=tenant_context,
                user_role=UserRole,
                app_id=AppId,
                session=session,
            ):
                items.append(record)
    return RagConversationPageVO(
        data=[
            RagConversationItemVO(
                id=row["conversation_id"],
                name=row["name"],
                introduction=row.get("introduction") or "",
                titleSource=str(row.get("title_source") or DEFAULT_TITLE_SOURCE),
                createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
                updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0,
                lastMessageAt=int(float(row.get("last_message_at") or 0)),
            )
            for row in items
        ],
        hasMore=has_more,
        limit=PageSize,
    )
async def GetConversationMessages(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    ConversationId: str,
    Page: int,
    PageSize: int,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagMessagePageVO:
    """Return one page of a conversation as question/answer pairs, oldest first.

    Messages are paged by row (``Page`` is 1-based). Each ``user`` row is paired with the
    ``assistant`` row that directly follows it.

    When the answer's task id is known, the stored text is reconciled with the in-memory
    event buffer: if the buffer holds at least as much text, it wins and is persisted.
    A stale "running" status is normalized via ``_resolve_persisted_message_status``.

    Raises:
        LeauditException: 403/404 from the conversation access check.
    """
    await self._ensure_conversation_owner(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )
    async with GetAsyncSession() as session:
        rows = (
            await session.execute(
                text(
                    """
                    SELECT message_id, role, content, sources, metadata, feedback, created_at
                    FROM rag_message
                    WHERE conversation_id = :conversation_id
                    ORDER BY created_at ASC,
                    CASE role
                    WHEN 'user' THEN 0
                    WHEN 'assistant' THEN 1
                    ELSE 2
                    END ASC,
                    message_id ASC
                    OFFSET :offset LIMIT :limit
                    """
                ),
                {
                    "conversation_id": ConversationId,
                    "offset": max(Page - 1, 0) * PageSize,
                    "limit": PageSize + 1,
                },
            )
        ).mappings().all()
    has_more = len(rows) > PageSize
    items = rows[:PageSize]
    data: list[RagMessageItemVO] = []
    idx = 0
    while idx < len(items):
        row = items[idx]
        if row["role"] != "user":
            # Answers are consumed together with their question. A leading assistant row
            # (such as an answer already paired at the end of the previous page) is skipped.
            idx += 1
            continue
        # Pair against `rows`, not `items`: `rows` carries one look-ahead row. A question
        # that ends the page therefore still gets its answer instead of showing up as
        # unanswered and "running", and the next page skips that answer row.
        next_idx = idx + 1
        answer = rows[next_idx] if next_idx < len(rows) and rows[next_idx]["role"] == "assistant" else None
        answer_metadata = dict((answer.get("metadata") if answer else None) or {})
        answer_status = str(answer_metadata.get("status") or ("completed" if answer else "running"))
        answer_content = (answer.get("content") if answer else None) or ""
        if answer:
            task_id = str(answer_metadata.get("task_id") or "").strip()
            reconstructed_content = self._rebuild_message_content_from_events(task_id) if task_id else ""
            if reconstructed_content and len(reconstructed_content) >= len(answer_content):
                if reconstructed_content != answer_content:
                    await self._update_message_progress(
                        conversation_id=ConversationId,
                        message_id=answer["message_id"],
                        content=reconstructed_content,
                        metadata=answer_metadata,
                    )
                answer_content = reconstructed_content
            normalized_status = await self._resolve_persisted_message_status(
                conversation_id=ConversationId,
                message_id=answer["message_id"],
                content=answer_content,
                metadata=answer_metadata,
            )
            if normalized_status != answer_status:
                answer_status = normalized_status
                answer_metadata["status"] = normalized_status
        data.append(
            RagMessageItemVO(
                id=(answer["message_id"] if answer else row["message_id"]),
                conversationId=ConversationId,
                query=row["content"],
                answer=answer_content if answer else "",
                feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None),
                retrieverResources=(answer.get("sources") if answer else None),
                suggestedQuestions=[str(item) for item in (answer_metadata.get("suggested_questions") or []) if str(item).strip()],
                status=answer_status,
                createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
            )
        )
        idx += 2 if answer else 1
    return RagMessagePageVO(data=data, hasMore=has_more, limit=PageSize)
async def _resolve_persisted_message_status(
self,
*,
conversation_id: str,
message_id: str,
content: str,
metadata: dict,
) -> str:
status = str(metadata.get("status") or "completed")
if status != "running":
return status
task_id = str(metadata.get("task_id") or "").strip()
task = self._message_tasks.get(task_id) if task_id else None
task_done = self._task_done.get(task_id, False) if task_id else False
if task and not task.done() and not task_done:
return "running"
normalized_status = "completed" if content.strip() else "error"
normalized_metadata = {
**metadata,
"status": normalized_status,
}
if normalized_status == "error" and not normalized_metadata.get("error"):
normalized_metadata["error"] = "生成任务已结束,但未产出有效回答"
await self._update_message_progress(
conversation_id=conversation_id,
message_id=message_id,
content=content,
metadata=normalized_metadata,
)
return normalized_status
def _rebuild_message_content_from_events(self, task_id: str) -> str:
if not task_id:
return ""
chunks: list[str] = []
for event in self._task_events.get(task_id, []):
if event.get("event") != "message":
continue
answer = event.get("answer")
if isinstance(answer, str) and answer:
chunks.append(answer)
return "".join(chunks)
async def RenameConversation(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    ConversationId: str,
    Body: RagConversationRenameDTO,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagConversationRenameVO:
    """Rename a conversation the caller may access and mark its title as manual.

    The new name is whitespace-trimmed. A title-generation status of 'running' becomes
    'succeeded'; other statuses are kept, with NULL treated as 'idle'. Any stored
    title-generation error is cleared.

    Raises:
        LeauditException: 403/404 from the access check, 400 for a blank name.
    """
    await self._ensure_conversation_owner(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )
    new_name = Body.name.strip()
    if not new_name:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "会话名称不能为空")
    rename_stmt = text(
        """
        UPDATE rag_conversation
        SET name = :name,
        title_source = 'manual',
        title_generation_status = CASE
        WHEN COALESCE(title_generation_status, 'idle') = 'running' THEN 'succeeded'
        ELSE COALESCE(title_generation_status, 'idle')
        END,
        title_generation_error = NULL,
        updated_at = NOW()
        WHERE conversation_id = :conversation_id
        """
    )
    async with GetAsyncSession() as session, session.begin():
        await session.execute(rename_stmt, {"name": new_name, "conversation_id": ConversationId})
    return RagConversationRenameVO(result="success", name=new_name)
async def DeleteConversation(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    ConversationId: str,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagOperationResultVO:
    """Soft-delete a conversation the caller may access by stamping deleted_at/updated_at.

    Raises:
        LeauditException: 403/404 from the access check.
    """
    await self._ensure_conversation_owner(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )
    soft_delete = text(
        "UPDATE rag_conversation SET deleted_at = NOW(), updated_at = NOW() WHERE conversation_id = :conversation_id"
    )
    async with GetAsyncSession() as session, session.begin():
        await session.execute(soft_delete, {"conversation_id": ConversationId})
    return RagOperationResultVO(result="success")
async def UpdateFeedback(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    MessageId: str,
    Body: RagMessageFeedbackDTO,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagOperationResultVO:
    """Store ``Body.rating`` as the feedback of a message the caller may access.

    Raises:
        LeauditException: 404 when the message does not exist (or its conversation is
            deleted), 403 when the caller may not access the conversation.
    """
    tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName)
    async with GetAsyncSession() as session:
        row = (
            await session.execute(
                text(
                    """
                    SELECT c.user_id, c.conversation_id
                    FROM rag_message m
                    JOIN rag_conversation c ON c.conversation_id = m.conversation_id
                    WHERE m.message_id = :message_id AND c.deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"message_id": MessageId},
            )
        ).mappings().first()
        if not row:
            raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在")
        if not await self._conversation_accessible(
            conversation_id=str(row["conversation_id"]),
            expected_user_id=CurrentUserId,
            tenant_context=tenant_context,
            user_role=UserRole,
            session=session,
        ):
            raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权修改该消息反馈")
        await session.execute(
            text("UPDATE rag_message SET feedback = :feedback WHERE message_id = :message_id"),
            {"feedback": Body.rating, "message_id": MessageId},
        )
        # Nothing else commits this UPDATE. The session was opened without
        # `session.begin()`, which the other write paths in this service use, so commit
        # explicitly; otherwise the change may be rolled back when the session closes.
        await session.commit()
    return RagOperationResultVO(result="success")
async def StopMessage(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    MessageId: str,
    Body: RagStopMessageDTO | None = None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagOperationResultVO:
    """Cancel the in-process generation task of a message the caller may access.

    Only the task id recorded in the message's own metadata is cancelled. ``Body`` is
    accepted for API compatibility and may be omitted (None), but its ``taskId`` is not
    trusted. The task registry (``_message_tasks``) is shared by every user of this
    process, so honouring a client-supplied id would let a caller cancel a task they
    have no access to.

    The call succeeds as a no-op when no running task is found.

    Raises:
        LeauditException: 404 when the message does not exist (or its conversation is
            deleted), 403 when the caller may not access the conversation.
    """
    tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName)
    async with GetAsyncSession() as session:
        row = (
            await session.execute(
                text(
                    """
                    SELECT m.metadata, c.user_id, c.conversation_id
                    FROM rag_message m
                    JOIN rag_conversation c ON c.conversation_id = m.conversation_id
                    WHERE m.message_id = :message_id
                    AND c.deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"message_id": MessageId},
            )
        ).mappings().first()
        if not row:
            raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在")
        # Reuse the already-open session instead of opening a second one for the check.
        if not await self._conversation_accessible(
            conversation_id=str(row["conversation_id"]),
            expected_user_id=CurrentUserId,
            tenant_context=tenant_context,
            user_role=UserRole,
            session=session,
        ):
            raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权停止该消息")
    metadata = row.get("metadata") or {}
    task_id = str(metadata.get("task_id") or "").strip()
    task = self._message_tasks.get(task_id) if task_id else None
    if task and not task.done():
        task.cancel()
    return RagOperationResultVO(result="success")
async def GetAppParameters(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    AppId: int | None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagAppParametersVO:
    """Return the chat UI parameters of the resolved app.

    The parameters are the opening statement, up to six suggested questions, an empty
    input form, and image upload disabled. Empty parameters are returned when no app is
    available to the caller.

    ``suggested_questions`` may arrive either as JSON text or as an already-decoded list
    (e.g. when the column is JSON/JSONB and the driver decodes it). Anything that is not
    a list is treated as "no suggestions".
    """
    app = await self._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName)
    if not app:
        return RagAppParametersVO()
    raw_questions = app.get("suggested_questions")
    if isinstance(raw_questions, list):
        suggested = raw_questions
    else:
        try:
            suggested = json.loads(raw_questions or "[]")
        except (TypeError, ValueError):
            # Unparseable text or an unexpected type: fall back to no suggestions.
            suggested = []
        if not isinstance(suggested, list):
            suggested = []
    return RagAppParametersVO(
        openingStatement=app.get("opening_statement") or "",
        suggestedQuestions=[str(item) for item in suggested[:6]],
        userInputForm=[],
        fileUpload={"image": {"enabled": False}},
    )
async def _load_apps(
    self,
    user_area: str | None,
    user_role: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
    only_default: bool,
) -> list[RagChatAppVO]:
    """Load active chat apps the caller may see, optionally only those marked default.

    Rows are ordered by sort_order, then newest first. Each row is kept only if
    ``_app_visible`` accepts it for the caller's resolved tenant context, and kept rows
    are mapped to ``RagChatAppVO``.
    """
    async with GetAsyncSession() as session:
        await self._ensure_rag_chat_schema(session)
        query = f"""
            SELECT a.id, a.name, a.description, a.is_default, a.area,
            COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code,
            {self._APP_TENANT_NAME_SQL} AS tenant_name,
            COALESCE(d.is_public, FALSE) AS dataset_public
            FROM rag_chat_app a
            LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
            WHERE a.deleted_at IS NULL
            AND a.status = 1
            AND (:only_default = FALSE OR a.is_default = TRUE)
            ORDER BY a.sort_order ASC, a.created_at DESC
        """
        result = await session.execute(text(query), {"only_default": only_default})
        records = [dict(mapping) for mapping in result.mappings().all()]
    tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
    visible = [
        record
        for record in records
        if await self._app_visible(record, tenant_context=tenant_context, user_role=user_role)
    ]
    return [
        RagChatAppVO(
            appId=str(record["id"]),
            appName=record["name"],
            description=record.get("description") or "",
            tenantCode=str(record.get("tenant_code") or ""),
            tenantName=str(record.get("tenant_name") or record.get("area") or ""),
            isDefault=bool(record.get("is_default")),
        )
        for record in visible
    ]
async def _resolve_app(
    self,
    app_id: int | None,
    user_area: str | None,
    user_role: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
) -> dict | None:
    """Pick the chat app to serve a request, as a full app row dict.

    Preference order:
    1. The explicitly requested ``app_id``, when it is active and visible to the caller.
    2. The first visible default app.
    3. The first visible active app.
    Steps 2 and 3 use sort_order, then newest first.

    Returns None when the caller can see no active app at all.
    """
    tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
    async with GetAsyncSession() as session:
        await self._ensure_rag_chat_schema(session)
        base_sql = f"""
            SELECT a.id, a.name, a.description, a.area,
            COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code,
            {self._APP_TENANT_NAME_SQL} AS tenant_name,
            a.dataset_id, a.system_prompt,
            a.llm_model, a.temperature, a.max_tokens, a.opening_statement,
            a.suggested_questions, a.is_default, COALESCE(d.is_public, FALSE) AS dataset_public,
            COALESCE(d.name, '') AS dataset_name
            FROM rag_chat_app a
            LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
            WHERE a.deleted_at IS NULL AND a.status = 1
        """
        if app_id is not None:
            row = (
                await session.execute(
                    text(base_sql + " AND a.id = :app_id LIMIT 1"),
                    {"app_id": app_id},
                )
            ).mappings().first()
            if row:
                record = dict(row)
                if await self._app_visible(record, tenant_context=tenant_context, user_role=user_role):
                    return record
        # Scan every candidate instead of checking only the first row. With per-tenant
        # apps, the first default (or the first app overall) often belongs to another
        # tenant, and checking just that row returned None although a visible app existed.
        # Default apps sort first, so the first visible row is the first visible default
        # if there is one, otherwise the first visible non-default app.
        candidates = (
            await session.execute(
                text(
                    base_sql
                    + " ORDER BY (a.is_default IS TRUE) DESC, a.sort_order ASC, a.created_at DESC"
                )
            )
        ).mappings().all()
    for row in candidates:
        record = dict(row)
        if await self._app_visible(record, tenant_context=tenant_context, user_role=user_role):
            return record
    return None
async def _app_visible(self, row: dict, tenant_context: dict, user_role: str | None) -> bool:
if self._role_is_global(user_role):
return True
if bool(row.get("dataset_public")):
return True
if str(row.get("tenant_code") or "").strip().upper() == "PUBLIC":
return True
return self._row_matches_tenant_scope(
row_tenant_code=row.get("tenant_code"),
row_area=row.get("area"),
tenant_context=tenant_context,
)
async def _resolve_tenant_context(
self,
user_area: str | None,
tenant_code: str | None,
tenant_name: str | None,
) -> dict[str, str | None]:
resolved = await self.TenantResolver.ResolveUserContext(
Area=user_area,
TenantCode=tenant_code,
TenantName=tenant_name,
Source="rag_chat_user",
)
return {
"tenant_code": resolved.tenant_code,
"tenant_name": resolved.tenant_name,
"tenant_type": resolved.tenant_type,
"area": user_area,
}
async def _resolve_record_tenant(self, raw_value: str | None):
return await self.TenantResolver.Resolve(
RawValue=raw_value,
Source="rag_chat_record",
)
async def _ensure_rag_chat_schema(self, session) -> None:
    """Ensure rag_chat_app has the tenant_code column and its index, once per process.

    A class-level flag is checked before and after taking a class-level asyncio lock, so
    concurrent coroutines of this process run the check/migration at most once.

    Args:
        session: an open async database session. When the column has to be added, the
            session's current transaction is committed so the DDL persists.
    """
    cls = self.__class__
    if cls._chat_schema_checked:
        return
    async with cls._chat_schema_lock:
        if cls._chat_schema_checked:
            return
        exists = (
            await session.execute(
                text(
                    """
                    SELECT 1
                    FROM information_schema.columns
                    WHERE table_schema = current_schema()
                    AND table_name = 'rag_chat_app'
                    AND column_name = 'tenant_code'
                    """
                )
            )
        ).scalar_one_or_none()
        if not exists:
            # Fail fast instead of waiting indefinitely for the table lock ALTER needs.
            await session.execute(text("SET LOCAL lock_timeout = '1000ms'"))
            # IF NOT EXISTS: the asyncio lock is per process, so another worker process
            # may add the column between our check and this statement.
            await session.execute(text("ALTER TABLE rag_chat_app ADD COLUMN IF NOT EXISTS tenant_code VARCHAR(64) NULL"))
            await session.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chat_app_tenant_code ON rag_chat_app(tenant_code) WHERE deleted_at IS NULL"))
            # The visible callers (_load_apps, _resolve_app) never commit this session.
            # Without a commit the DDL would be rolled back on close while the flag below
            # claims the column exists, and later queries on a.tenant_code would fail.
            await session.commit()
        cls._chat_schema_checked = True
@staticmethod
def _tenant_context_is_global(tenant_context: dict[str, str | None]) -> bool:
    """Return True when the context's tenant code is PUBLIC or PROVINCIAL.

    Surrounding whitespace and letter case are ignored. A missing or empty tenant code
    is never global.
    """
    raw_code = tenant_context.get("tenant_code") or ""
    return str(raw_code).strip().upper() in ("PUBLIC", "PROVINCIAL")
@staticmethod
def _role_is_global(user_role: str | None) -> bool:
    """Return True for roles that may see every tenant's apps and conversations.

    Only ``super_admin`` and ``provincial_admin`` qualify. Surrounding whitespace is
    ignored, but the match is case-sensitive.
    """
    global_roles = ("super_admin", "provincial_admin")
    return str(user_role or "").strip() in global_roles
def _row_matches_tenant_scope(
    self,
    *,
    row_tenant_code: str | None,
    row_area: str | None,
    tenant_context: dict[str, str | None],
) -> bool:
    """Return whether a record falls inside the caller's tenant scope.

    When the caller has a tenant code, the record's tenant code must equal it. Otherwise
    the record's area must equal the caller's area. All values are compared as strings
    with surrounding whitespace removed, and None counts as "".
    """
    def _clean(value) -> str:
        return str(value or "").strip()

    caller_code = _clean(tenant_context.get("tenant_code"))
    if not caller_code:
        return _clean(row_area) == _clean(tenant_context.get("area"))
    return _clean(row_tenant_code) == caller_code
async def _ensure_conversation(
    self,
    user_id: int,
    conversation_id: str | None,
    app_id: int | None,
    user_area: str | None,
    user_role: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
) -> str:
    """Return the id of the conversation to post into, creating one when needed.

    A non-deleted conversation is reused when ``conversation_id`` is given (and is not
    "-1") and the caller passes the access check. Otherwise a new conversation named
    ``DEFAULT_CONVERSATION_NAME`` is inserted and its fresh UUID returned. An unknown id
    therefore leads to a new conversation rather than an error.

    Raises:
        LeauditException: 403 when the conversation exists but the caller may not use it.
    """
    tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
    if conversation_id and conversation_id != "-1":
        lookup = text(
            """
            SELECT conversation_id, user_id
            FROM rag_conversation
            WHERE conversation_id = :conversation_id
            AND deleted_at IS NULL
            LIMIT 1
            """
        )
        async with GetAsyncSession() as session:
            existing = (
                await session.execute(lookup, {"conversation_id": conversation_id})
            ).mappings().first()
            if existing:
                existing_id = str(existing["conversation_id"])
                permitted = await self._conversation_accessible(
                    conversation_id=existing_id,
                    expected_user_id=user_id,
                    tenant_context=tenant_context,
                    user_role=user_role,
                    app_id=app_id,
                    session=session,
                )
                if not permitted:
                    raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权使用该会话")
                return existing_id
    created_id = str(uuid.uuid4())
    insert_stmt = text(
        """
        INSERT INTO rag_conversation (conversation_id, user_id, app_id, name, introduction)
        VALUES (:conversation_id, :user_id, :app_id, :name, '')
        """
    )
    async with GetAsyncSession() as session, session.begin():
        await session.execute(
            insert_stmt,
            {
                "conversation_id": created_id,
                "user_id": user_id,
                "app_id": app_id,
                "name": DEFAULT_CONVERSATION_NAME,
            },
        )
    return created_id
async def _ensure_conversation_owner(
self,
*,
user_id: int,
conversation_id: str,
user_area: str | None,
user_role: str | None,
tenant_code: str | None,
tenant_name: str | None,
) -> None:
tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
if not await self._conversation_accessible(
conversation_id=conversation_id,
expected_user_id=user_id,
tenant_context=tenant_context,
user_role=user_role,
):
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该会话")
async def _conversation_accessible(
self,
*,
conversation_id: str,
expected_user_id: int,
tenant_context: dict[str, str | None],
user_role: str | None,
app_id: int | None = None,
session=None,
) -> bool:
if session is not None:
return await self._conversation_accessible_with_session(
session=session,
conversation_id=conversation_id,
expected_user_id=expected_user_id,
tenant_context=tenant_context,
user_role=user_role,
app_id=app_id,
)
async with GetAsyncSession() as owned_session:
return await self._conversation_accessible_with_session(
session=owned_session,
conversation_id=conversation_id,
expected_user_id=expected_user_id,
tenant_context=tenant_context,
user_role=user_role,
app_id=app_id,
)
async def _conversation_accessible_with_session(
    self,
    *,
    session,
    conversation_id: str,
    expected_user_id: int,
    tenant_context: dict[str, str | None],
    user_role: str | None,
    app_id: int | None = None,
) -> bool:
    """Check, on ``session``, whether the caller may access ``conversation_id``.

    Returns False in either of these cases:
    - the conversation belongs to another user;
    - ``app_id`` is given and the conversation is bound to a different app.
    Otherwise the conversation's app row (tenant code, area, dataset public flag) is
    judged by ``_app_visible``.

    Raises:
        LeauditException: 404 when the conversation does not exist or is soft-deleted.
    """
    result = await session.execute(
        text(
            """
            SELECT
            c.conversation_id,
            c.user_id,
            c.app_id,
            a.area,
            COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code,
            COALESCE(d.is_public, FALSE) AS dataset_public
            FROM rag_conversation c
            LEFT JOIN rag_chat_app a ON a.id = c.app_id AND a.deleted_at IS NULL
            LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
            WHERE c.conversation_id = :conversation_id
            AND c.deleted_at IS NULL
            LIMIT 1
            """
        ),
        {"conversation_id": conversation_id},
    )
    record = result.mappings().first()
    if not record:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "会话不存在")
    if int(record["user_id"]) != expected_user_id:
        return False
    bound_app = record.get("app_id")
    if app_id is not None and bound_app is not None and int(bound_app) != int(app_id):
        return False
    app_view = {
        "tenant_code": record.get("tenant_code"),
        "area": record.get("area"),
        "dataset_public": record.get("dataset_public"),
    }
    return await self._app_visible(app_view, tenant_context=tenant_context, user_role=user_role)
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
    """Retrieve context chunks for ``query`` from a dataset.

    Vector search is tried first: the query is embedded and the dataset's Chroma
    collection is queried. If that fails or yields no chunks, keyword retrieval is used
    as a fallback. Settings come from the dataset's ``retrieval_model``: ``top_k``
    (default 5) and an optional score threshold.

    Returns:
        ``(chunks, dataset_name)`` with at most ``top_k`` chunks. The result is
        ``([], "")`` when there is no dataset id or the dataset is missing, and
        ``([], dataset_name)`` when both strategies fail.
    """
    import logging

    logger = logging.getLogger(__name__)
    if not dataset_id:
        return [], ""
    async with GetAsyncSession() as session:
        dataset = (
            await session.execute(
                text(
                    """
                    SELECT id, name, collection_name, retrieval_model, embedding_model
                    FROM rag_dataset
                    WHERE id = :dataset_id AND deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"dataset_id": dataset_id},
            )
        ).mappings().first()
    if not dataset:
        return [], ""
    retrieval_model = dataset.get("retrieval_model") or {}
    top_k = int(retrieval_model.get("top_k") or 5)
    score_threshold = None
    if retrieval_model.get("score_threshold_enabled"):
        try:
            score_threshold = float(retrieval_model.get("score_threshold"))
        except (TypeError, ValueError):
            score_threshold = None
    try:
        query_embedding = await self._embed_texts([query], dataset.get("embedding_model") or "")
        collection = get_chroma().get_or_create_collection(dataset["collection_name"])
        result = collection.query(
            query_embeddings=query_embedding,
            n_results=max(top_k, 1),
            include=["documents", "metadatas", "distances"],
        )
        # Each field is a list with one entry per query; guard against None entries.
        ids = ((result.get("ids") or [[]])[0] or []) if result.get("ids") else []
        docs = (result.get("documents") or [[]])[0] or []
        metas = (result.get("metadatas") or [[]])[0] or []
        distances = (result.get("distances") or [[]])[0] or []
        chunks: list[dict] = []
        for idx, doc in enumerate(docs):
            # A hit may have no metadata (None); treat it as empty.
            meta = (metas[idx] if idx < len(metas) else None) or {}
            dist = float(distances[idx]) if idx < len(distances) and distances[idx] is not None else 1.0
            # Map the distance to a (0, 1] score: smaller distance -> higher score.
            score = 1.0 / (1.0 + max(dist, 0.0))
            if score_threshold is not None and score < score_threshold:
                continue
            raw_chunk_index = meta.get("chunk_index")
            chunks.append(
                {
                    "id": str(ids[idx] if idx < len(ids) else meta.get("id") or idx),
                    "text": doc,
                    "source": meta.get("source") or meta.get("document_name") or dataset.get("name") or "",
                    "score": score,
                    # 0 is a valid chunk index; fall back to the hit position only when absent.
                    "chunk_index": idx if raw_chunk_index in (None, "") else int(raw_chunk_index),
                    "document_name": meta.get("document_name") or meta.get("source") or "",
                    "document_id": meta.get("document_id"),
                    "page": meta.get("page"),
                }
            )
        chunks = await self._hydrate_document_hits(dataset_id, chunks)
        if chunks:
            return chunks[:top_k], dataset.get("name") or ""
    except Exception:
        # Best effort: fall back to keyword retrieval, but keep the cause visible.
        logger.warning(
            "Vector retrieval failed for dataset %s; falling back to keyword retrieval",
            dataset_id,
            exc_info=True,
        )
    try:
        chunks = await self._keyword_retrieve_context(
            dataset_id=dataset_id,
            collection_name=str(dataset["collection_name"]),
            dataset_name=str(dataset.get("name") or ""),
            query=query,
            top_k=top_k,
            score_threshold=score_threshold,
        )
        return chunks[:top_k], dataset.get("name") or ""
    except Exception:
        logger.warning("Keyword retrieval failed for dataset %s", dataset_id, exc_info=True)
        return [], dataset.get("name") or ""
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
    """Embed ``texts`` through an OpenAI-compatible embeddings endpoint, in batches.

    Configuration comes from RAG_CONFIG: EMBED_URL, EMBED_KEY, EMBED_MODEL and
    EMBED_BATCH_SIZE (default 10). The URL and key fall back to the LLM settings, and
    the model falls back to "text-embedding-v4". A non-empty ``model_name`` overrides
    the configured model.

    Returns:
        One embedding per input text, in input order.

    Raises:
        LeauditException: (500) in any of these cases:
            - no embedding service is configured;
            - the service is unreachable or returns an HTTP error;
            - the response is not valid JSON;
            - the number of embeddings does not match the number of inputs.
    """
    embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"])
    embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
    embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
    batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
    if not embed_url or not embed_key:
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")
    embeddings: list[list[float]] = []
    async with httpx.AsyncClient(timeout=120.0) as client:
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            try:
                response = await client.post(
                    embed_url,
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {embed_key}",
                    },
                    json={"model": embed_model, "input": batch_texts},
                )
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc
            except httpx.RequestError as exc:
                # Connection failures and timeouts: report them like HTTP errors instead
                # of leaking a raw transport exception to callers.
                error_message = str(exc).strip() or type(exc).__name__
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc
            try:
                payload = response.json()
            except ValueError as exc:
                error_message = response.text.strip() or "invalid JSON response"
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc
            raw_rows = payload.get("data") if isinstance(payload, dict) else None
            rows = [row for row in (raw_rows or []) if isinstance(row, dict)]
            # OpenAI-style items carry the input position in `index`; sort by it so the
            # embeddings line up with `batch_texts` even if the service reorders them.
            if all(isinstance(row.get("index"), int) for row in rows):
                rows.sort(key=lambda row: row["index"])
            batch_embeddings = [row.get("embedding") for row in rows if row.get("embedding")]
            if len(batch_embeddings) != len(batch_texts):
                raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
            embeddings.extend(batch_embeddings)
    # Every batch was length-checked above, so len(embeddings) == len(texts) here.
    return embeddings
async def _start_message_task(
self,
*,
task_id: str,
conversation_id: str,
message_id: str,
query: str,
app: dict,
) -> None:
self._task_events[task_id] = []
self._task_done[task_id] = False
self._task_locks.setdefault(task_id, asyncio.Lock())
task = asyncio.create_task(
self._run_message_task(
task_id=task_id,
conversation_id=conversation_id,
message_id=message_id,
query=query,
app=app,
)
)
self._message_tasks[task_id] = task
    async def _run_message_task(
        self,
        *,
        task_id: str,
        conversation_id: str,
        message_id: str,
        query: str,
        app: dict,
    ) -> None:
        """Generate one assistant answer in the background and persist its outcome.

        Retrieves context for ``app["dataset_id"]``, then streams
        ``generate_stream`` output and parses each SSE chunk. Parsed events are
        buffered in ``_task_events[task_id]``, and partial answers are written to
        the assistant row periodically. The row is finalized with status
        ``completed``, ``error`` or ``stopped``; ``stopped`` is used on
        cancellation, which is re-raised.

        In every case ``_task_done[task_id]`` is set and the task handle and lock
        are dropped. ``_task_events[task_id]`` is not removed here.

        Args:
            task_id: Key for the event buffer, done flag, lock and task handle.
            conversation_id: Conversation that owns the assistant message.
            message_id: Assistant message row to update.
            query: User question used for retrieval, generation and follow-ups.
            app: App config; reads ``dataset_id``, ``system_prompt``,
                ``llm_model``, ``temperature`` and ``max_tokens``.
        """
        context_chunks: list[dict] = []
        dataset_name = ""
        collected_answer = ""
        message_end_payload: dict | None = None
        final_status = "completed"
        error_payload: dict | None = None
        # Throttle state for the intermediate progress writes below.
        last_persisted_length = 0
        last_persisted_at = time.monotonic()
        try:
            context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), query)
            async for chunk in generate_stream(
                query=query,
                context_chunks=context_chunks,
                conversation_id=conversation_id,
                message_id=message_id,
                system_prompt=app.get("system_prompt") or "",
                model=app.get("llm_model") or "",
                temperature=app.get("temperature"),
                max_tokens=app.get("max_tokens"),
                dataset_name=dataset_name,
                task_id=task_id,
            ):
                data = self._parse_sse_event(chunk)
                if not data:
                    continue
                event = data.get("event")
                if event == "message":
                    collected_answer += data.get("answer", "")
                    now = time.monotonic()
                    # Persist partial output once 80+ new characters or 0.5s have accumulated.
                    if len(collected_answer) - last_persisted_length >= 80 or now - last_persisted_at >= 0.5:
                        await self._update_message_progress(
                            conversation_id=conversation_id,
                            message_id=message_id,
                            content=collected_answer,
                            metadata={"status": "running", "task_id": task_id},
                        )
                        last_persisted_length = len(collected_answer)
                        last_persisted_at = now
                    await self._append_task_event(task_id, data)
                    continue
                if event == "message_end":
                    # Held back so suggested questions can be attached before it is buffered.
                    message_end_payload = data
                    continue
                if event == "error":
                    # Generator-reported error: forwarded to listeners now; the row is
                    # finalized as "error" after the stream ends.
                    final_status = "error"
                    error_payload = data
                    await self._append_task_event(task_id, data)
                    continue
                await self._append_task_event(task_id, data)
            if final_status == "completed":
                followups: list[str] = []
                # Follow-up suggestions are best-effort; any failure yields an empty list.
                try:
                    followups = await generate_followups(query, collected_answer)
                except Exception:
                    followups = []
                # NOTE(review): if the stream produced no message_end event, none is buffered.
                if message_end_payload:
                    message_end_payload.setdefault("metadata", {})["suggested_questions"] = followups
                    await self._append_task_event(task_id, message_end_payload)
                await self._finalize_message_record(
                    conversation_id=conversation_id,
                    message_id=message_id,
                    content=collected_answer,
                    sources=self._build_sources(context_chunks, dataset_name),
                    metadata={"suggested_questions": followups, "status": "completed", "task_id": task_id},
                )
                await self._maybe_schedule_auto_title(
                    conversation_id=conversation_id,
                    message_id=message_id,
                    query=query,
                    answer=collected_answer,
                )
            else:
                await self._finalize_message_record(
                    conversation_id=conversation_id,
                    message_id=message_id,
                    content=collected_answer,
                    sources=self._build_sources(context_chunks, dataset_name),
                    metadata={
                        "suggested_questions": [],
                        "status": "error",
                        "task_id": task_id,
                        "error": (error_payload or {}).get("message", ""),
                    },
                )
        except asyncio.CancelledError:
            # Task cancelled (stop requested): notify listeners, persist the partial
            # answer as "stopped", then re-raise so the task ends as cancelled.
            # NOTE(review): final_status is not read after this assignment.
            final_status = "stopped"
            await self._append_task_event(
                task_id,
                {
                    "event": "error",
                    "task_id": task_id,
                    "message_id": message_id,
                    "code": "message_stopped",
                    "message": "用户已停止回答",
                },
            )
            await self._finalize_message_record(
                conversation_id=conversation_id,
                message_id=message_id,
                content=collected_answer,
                sources=self._build_sources(context_chunks, dataset_name),
                metadata={"suggested_questions": [], "status": "stopped", "task_id": task_id},
            )
            raise
        except Exception as exc:
            # Unexpected failure: surface a server_error event and persist the
            # partial answer with status "error".
            final_status = "error"
            await self._append_task_event(
                task_id,
                {
                    "event": "error",
                    "task_id": task_id,
                    "message_id": message_id,
                    "code": "server_error",
                    "message": str(exc),
                },
            )
            await self._finalize_message_record(
                conversation_id=conversation_id,
                message_id=message_id,
                content=collected_answer,
                sources=self._build_sources(context_chunks, dataset_name),
                metadata={"suggested_questions": [], "status": "error", "task_id": task_id, "error": str(exc)},
            )
        finally:
            # Mark the task finished and drop its handle and lock. The buffered
            # events stay in _task_events; presumably a consumer removes them (verify).
            self._task_done[task_id] = True
            self._message_tasks.pop(task_id, None)
            self._task_locks.pop(task_id, None)
async def _append_task_event(self, task_id: str, payload: dict) -> None:
lock = self._task_locks.setdefault(task_id, asyncio.Lock())
async with lock:
self._task_events.setdefault(task_id, []).append(payload)
async def _finalize_message_record(
self,
*,
conversation_id: str,
message_id: str,
content: str,
sources: list[dict],
metadata: dict,
) -> None:
async with GetAsyncSession() as session:
async with session.begin():
await session.execute(
text(
"""
UPDATE rag_message
SET content = :content,
sources = CAST(:sources AS jsonb),
metadata = CAST(:metadata AS jsonb)
WHERE message_id = :message_id
AND conversation_id = :conversation_id
AND role = 'assistant'
"""
),
{
"conversation_id": conversation_id,
"message_id": message_id,
"content": content,
"sources": json.dumps(sources, ensure_ascii=False),
"metadata": json.dumps(metadata, ensure_ascii=False),
},
)
await session.execute(
text(
"""
UPDATE rag_conversation
SET updated_at = NOW(),
last_message_at = NOW()
WHERE conversation_id = :conversation_id
"""
),
{"conversation_id": conversation_id},
)
async def _update_message_progress(
self,
*,
conversation_id: str,
message_id: str,
content: str,
metadata: dict,
) -> None:
async with GetAsyncSession() as session:
async with session.begin():
await session.execute(
text(
"""
UPDATE rag_message
SET content = :content,
metadata = CAST(:metadata AS jsonb)
WHERE message_id = :message_id
AND conversation_id = :conversation_id
AND role = 'assistant'
"""
),
{
"conversation_id": conversation_id,
"message_id": message_id,
"content": content,
"metadata": json.dumps(metadata, ensure_ascii=False),
},
)
async def _maybe_schedule_auto_title(
self,
*,
conversation_id: str,
message_id: str,
query: str,
answer: str,
) -> None:
normalized_query = (query or "").strip()
normalized_answer = (answer or "").strip()
if not normalized_query or not normalized_answer:
return
async with GetAsyncSession() as session:
row = (
await session.execute(
text(
"""
SELECT conversation_id,
name,
COALESCE(title_source, 'default') AS title_source,
COALESCE(title_generation_status, 'idle') AS title_generation_status,
first_answer_message_id
FROM rag_conversation
WHERE conversation_id = :conversation_id
AND deleted_at IS NULL
LIMIT 1
"""
),
{"conversation_id": conversation_id},
)
).mappings().first()
if not row:
return
title_source = str(row.get("title_source") or DEFAULT_TITLE_SOURCE)
if title_source == MANUAL_TITLE_SOURCE:
return
current_name = str(row.get("name") or "").strip()
if current_name and current_name != DEFAULT_CONVERSATION_NAME and title_source != DEFAULT_TITLE_SOURCE:
return
current_status = str(row.get("title_generation_status") or "idle")
if current_status in {"pending", "running", "succeeded"}:
return
if row.get("first_answer_message_id") and str(row.get("first_answer_message_id")) != message_id:
return
async with GetAsyncSession() as session:
async with session.begin():
await session.execute(
text(
"""
UPDATE rag_conversation
SET title_generation_status = 'pending',
first_answer_message_id = COALESCE(first_answer_message_id, :message_id),
title_generation_error = NULL
WHERE conversation_id = :conversation_id
AND deleted_at IS NULL
AND COALESCE(title_source, 'default') = 'default'
"""
),
{
"conversation_id": conversation_id,
"message_id": message_id,
},
)
existing_task = self._title_tasks.get(conversation_id)
if existing_task and not existing_task.done():
return
task = asyncio.create_task(
self._run_auto_title_task(
conversation_id=conversation_id,
answer_message_id=message_id,
query=normalized_query,
answer=normalized_answer,
)
)
self._title_tasks[conversation_id] = task
async def _run_auto_title_task(
self,
*,
conversation_id: str,
answer_message_id: str,
query: str,
answer: str,
) -> None:
try:
async with GetAsyncSession() as session:
row = (
await session.execute(
text(
"""
SELECT name,
COALESCE(title_source, 'default') AS title_source,
COALESCE(title_generation_status, 'idle') AS title_generation_status,
first_answer_message_id
FROM rag_conversation
WHERE conversation_id = :conversation_id
AND deleted_at IS NULL
LIMIT 1
"""
),
{"conversation_id": conversation_id},
)
).mappings().first()
if not row:
return
if str(row.get("title_source") or DEFAULT_TITLE_SOURCE) == MANUAL_TITLE_SOURCE:
return
if row.get("first_answer_message_id") and str(row.get("first_answer_message_id")) != answer_message_id:
return
async with GetAsyncSession() as session:
async with session.begin():
await session.execute(
text(
"""
UPDATE rag_conversation
SET title_generation_status = 'running',
title_generation_error = NULL
WHERE conversation_id = :conversation_id
AND deleted_at IS NULL
AND COALESCE(title_source, 'default') = 'default'
"""
),
{"conversation_id": conversation_id},
)
generated_title = await self._generate_conversation_title(query=query, answer=answer)
cleaned_title = self._sanitize_generated_title(generated_title)
if not cleaned_title:
cleaned_title = self._build_fallback_title(query=query, answer=answer)
if not cleaned_title:
raise ValueError("未生成有效标题")
async with GetAsyncSession() as session:
async with session.begin():
await session.execute(
text(
"""
UPDATE rag_conversation
SET name = :name,
title_source = 'auto',
title_generation_status = 'succeeded',
title_generated_at = NOW(),
title_generation_error = NULL
WHERE conversation_id = :conversation_id
AND deleted_at IS NULL
AND COALESCE(title_source, 'default') = 'default'
AND (
name IS NULL
OR BTRIM(name) = ''
OR name = :default_name
)
"""
),
{
"conversation_id": conversation_id,
"name": cleaned_title,
"default_name": DEFAULT_CONVERSATION_NAME,
},
)
except Exception as exc:
async with GetAsyncSession() as session:
async with session.begin():
await session.execute(
text(
"""
UPDATE rag_conversation
SET title_generation_status = 'failed',
title_generation_error = :error
WHERE conversation_id = :conversation_id
AND deleted_at IS NULL
AND COALESCE(title_source, 'default') = 'default'
"""
),
{
"conversation_id": conversation_id,
"error": str(exc)[:1000],
},
)
finally:
self._title_tasks.pop(conversation_id, None)
async def _generate_conversation_title(self, *, query: str, answer: str) -> str:
prompt = (
"请基于用户首轮提问和助手首轮回答,生成一个简洁、准确的中文会话标题。"
"要求:"
"1. 只输出标题本身;"
"2. 不要标点结尾;"
"3. 不要出现“关于”“用户询问”“问题解答”等空话;"
"4. 优先 12-24 个中文字符,最长不超过 40 个字符。\\n"
f"用户问题:{query[:500]}\\n"
f"助手回答:{answer[:1500]}"
)
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(
build_openai_chat_completions_url(RAG_CONFIG["LLM_BASE_URL"]),
json={
"model": RAG_CONFIG["LLM_MODEL"],
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.2,
"max_tokens": 80,
"chat_template_kwargs": {"enable_thinking": False},
},
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}",
},
)
resp.raise_for_status()
content = resp.json()["choices"][0]["message"]["content"]
return str(content or "").strip()
def _sanitize_generated_title(self, value: str) -> str:
title = str(value or "").strip()
if not title:
return ""
title = title.replace("\r", " ").replace("\n", " ").strip()
title = re.sub(r"^```[\w-]*", "", title).strip()
title = title.replace("```", "").strip()
title = re.sub(r'^[\"“”\'‘’\[\]()()【】]+', "", title)
title = re.sub(r'[\"“”\'‘’\[\]()()【】]+$', "", title)
title = re.sub(r"^(标题|会话标题)[:]\\s*", "", title)
title = re.sub(r"\\s+", " ", title).strip()
title = title.rstrip("。!?!?.;,:")
if title in {"新对话", "会话标题", "标题"}:
return ""
if len(title) > 40:
title = title[:40].rstrip(",。;;: ")
return title
def _build_fallback_title(self, *, query: str, answer: str) -> str:
base = query.strip() or answer.strip()
if not base:
return ""
base = re.sub(r"\\s+", " ", base)
base = re.sub(r"^[请问帮我一下关于针对结合根据请解释说明一下::,,\\s]+", "", base)
base = base.rstrip("。!?!?.;,:")
if len(base) > 24:
base = base[:24].rstrip(",。;;: ")
return base
async def _keyword_retrieve_context(
self,
*,
dataset_id: int,
collection_name: str,
dataset_name: str,
query: str,
top_k: int,
score_threshold: float | None,
) -> list[dict]:
collection = get_chroma().get_or_create_collection(collection_name)
raw = collection.get(include=["documents", "metadatas"])
ids = raw.get("ids") or []
docs = raw.get("documents") or []
metas = raw.get("metadatas") or []
terms = self._build_keyword_terms(query)
if not terms:
return []
scored_chunks: list[dict] = []
for idx, chunk_id in enumerate(ids):
doc = docs[idx] if idx < len(docs) else ""
meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {}
score = self._score_keyword_chunk(
query=query,
terms=terms,
content=doc or "",
document_name=str(meta.get("document_name") or meta.get("source") or ""),
)
if score <= 0:
continue
if score_threshold is not None and score < score_threshold:
continue
scored_chunks.append(
{
"id": str(chunk_id),
"text": doc or "",
"source": meta.get("source") or meta.get("document_name") or dataset_name,
"score": score,
"chunk_index": int(meta.get("chunk_index") or idx),
"document_name": meta.get("document_name") or meta.get("source") or "",
"document_id": meta.get("document_id"),
"page": meta.get("page"),
}
)
scored_chunks.sort(key=lambda item: (-float(item.get("score") or 0.0), int(item.get("chunk_index") or 0)))
hydrated = await self._hydrate_document_hits(dataset_id, scored_chunks[: max(top_k * 3, top_k)])
return hydrated[:top_k]
def _build_keyword_terms(self, query: str) -> list[str]:
normalized = self._normalize_keyword_query(query)
spans = [item.strip() for item in re.findall(r"[\u4e00-\u9fffA-Za-z0-9]+", normalized) if item.strip()]
if not spans:
return []
stop_terms = {
"什么",
"请问",
"一下",
"有关",
"关于",
"如何",
"哪些",
"怎么",
"是否",
"规定",
"办法",
"条例",
"法律",
}
terms: list[str] = []
for span in spans:
if span in stop_terms:
continue
terms.append(span)
if re.fullmatch(r"[\u4e00-\u9fff]+", span):
for size in (2, 3, 4):
if len(span) > size:
for start in range(0, len(span) - size + 1):
token = span[start:start + size]
if token not in stop_terms:
terms.append(token)
unique_terms: list[str] = []
seen: set[str] = set()
for term in sorted(terms, key=len, reverse=True):
if term and term not in seen:
unique_terms.append(term)
seen.add(term)
return unique_terms[:20]
def _normalize_keyword_query(self, query: str) -> str:
normalized = (query or "").strip().lower()
patterns = [
"是什么",
"什么是",
"有哪些",
"有什么",
"是什么?",
"是什么?",
"请问",
"介绍一下",
"解释一下",
"帮我分析",
"帮我看看",
]
for pattern in patterns:
normalized = normalized.replace(pattern, " ")
return re.sub(r"\s+", " ", normalized).strip()
def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float:
haystack = f"{document_name}\n{content}".lower()
if not haystack:
return 0.0
exact_query = self._normalize_keyword_query(query)
if exact_query and exact_query in haystack:
return 0.98
matched_weight = 0.0
total_weight = 0.0
name_bonus = 0.0
for term in terms:
weight = float(max(len(term), 1) ** 2)
total_weight += weight
if term.lower() in haystack:
matched_weight += weight
if term.lower() in document_name.lower():
name_bonus += min(0.15, 0.03 * len(term))
if total_weight <= 0:
return 0.0
score = (matched_weight / total_weight) + name_bonus
return round(min(score, 0.99), 6)
def _format_sse(self, payload: dict) -> bytes:
return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8")
def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]:
return [
{
"position": index + 1,
"dataset_id": str(chunk.get("dataset_id") or ""),
"dataset_name": dataset_name,
"document_id": str(chunk.get("document_id") or ""),
"document_name": chunk.get("document_name") or chunk.get("source", ""),
"data_source_type": "upload_file",
"segment_id": chunk.get("id", ""),
"retriever_from": "rag",
"score": round(chunk.get("score", 0.0), 4),
"hit_count": chunk.get("hit_count", 0),
"word_count": len(chunk.get("text", "")),
"segment_position": index + 1,
"index_node_hash": "",
"content": chunk.get("text", "")[:500],
"page": None,
}
for index, chunk in enumerate(context_chunks)
]
async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict]) -> list[dict]:
source_names = sorted(
{
str(chunk.get("document_name") or chunk.get("source") or "").strip()
for chunk in chunks
if str(chunk.get("document_name") or chunk.get("source") or "").strip()
}
)
if not source_names:
return chunks
async with GetAsyncSession() as session:
rows = (
await session.execute(
text(
"""
SELECT id, original_name, enabled, hit_count
FROM rag_document
WHERE dataset_id = :dataset_id
AND deleted_at IS NULL
AND original_name = ANY(:source_names)
"""
),
{
"dataset_id": dataset_id,
"source_names": source_names,
},
)
).mappings().all()
document_map = {str(row["original_name"]): row for row in rows}
visible_chunks: list[dict] = []
hit_document_ids: list[int] = []
for chunk in chunks:
source_name = str(chunk.get("document_name") or chunk.get("source") or "").strip()
document = document_map.get(source_name)
if document and not bool(document.get("enabled")):
continue
if document:
chunk["document_id"] = document["id"]
chunk["dataset_id"] = dataset_id
chunk["document_name"] = document["original_name"]
chunk["hit_count"] = document.get("hit_count") or 0
hit_document_ids.append(int(document["id"]))
visible_chunks.append(chunk)
if hit_document_ids:
async with GetAsyncSession() as session:
async with session.begin():
await session.execute(
text(
"""
UPDATE rag_document
SET hit_count = hit_count + 1,
updated_at = NOW()
WHERE id = ANY(:document_ids)
"""
),
{"document_ids": sorted(set(hit_document_ids))},
)
return visible_chunks
def _parse_sse_event(self, chunk: str) -> dict | None:
data_lines: list[str] = []
for line in chunk.splitlines():
if line.startswith("data: "):
data_lines.append(line[6:])
elif line.startswith("data:"):
data_lines.append(line[5:].lstrip())
if not data_lines:
return None
payload = "\n".join(part for part in data_lines if part.strip()).strip()
if not payload or payload == "[DONE]":
return None
payload = payload.removesuffix("\\n\\n").removesuffix("\\n").strip()
try:
data = json.loads(payload)
except json.JSONDecodeError:
return None
return data if isinstance(data, dict) else None