Files
2026-05-25 15:37:37 +08:00

1778 lines
70 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
"""RAG 聊天服务实现。"""
from __future__ import annotations
import asyncio
import json
import re
import time
import uuid
from typing import AsyncGenerator
import httpx
from sqlalchemy import text
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import (
RagConversationRenameDTO,
RagMessageFeedbackDTO,
RagStopMessageDTO,
)
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
RagAppParametersVO,
RagChatAppListVO,
RagChatAppVO,
RagConversationItemVO,
RagConversationPageVO,
RagConversationRenameVO,
RagMessageItemVO,
RagMessagePageVO,
RagOperationResultVO,
)
from fastapi_modules.fastapi_leaudit.rag_engine.config import (
RAG_CONFIG,
build_openai_chat_completions_url,
)
from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream
from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
from fastapi_modules.fastapi_leaudit.services.impl.tenantResolver import TenantResolver
# Name stored for a conversation row when it is first inserted.
DEFAULT_CONVERSATION_NAME = "新对话"
# Values for a conversation's title_source. "default" is the fallback used
# when the stored value is empty.
# NOTE(review): "auto" and "manual" presumably mark generated vs. user-renamed
# titles; confirm against the title-generation code.
DEFAULT_TITLE_SOURCE = "default"
AUTO_TITLE_SOURCE = "auto"
MANUAL_TITLE_SOURCE = "manual"
class RagChatServiceImpl(IRagChatService):
"""RAG chat service implementation."""
# The registries below are class attributes, so they are shared by every
# instance of the service within the process.
# task_id -> asyncio.Task running the generation for one assistant message.
_message_tasks: dict[str, asyncio.Task] = {}
# task_id -> ordered list of event payloads streamed to the client.
_task_events: dict[str, list[dict]] = {}
# task_id -> flag read by the streaming loop to know when to stop polling.
_task_done: dict[str, bool] = {}
# task_id -> lock created when the task is started.
_task_locks: dict[str, asyncio.Lock] = {}
# NOTE(review): presumably conversation-title generation tasks; confirm key type.
_title_tasks: dict[str, asyncio.Task] = {}
# Set once the rag_chat_app.tenant_code column has been verified or created.
_chat_schema_checked = False
# Serialises the schema check between coroutines of this process.
_chat_schema_lock = asyncio.Lock()
def __init__(self, retriever: RagRetriever | None = None) -> None:
    """Create the service with a tenant resolver and a retriever.

    Args:
        retriever: Retriever to use; when falsy a new ``RagRetriever`` is built.
    """
    self.TenantResolver = TenantResolver()
    if retriever:
        self.retriever = retriever
    else:
        self.retriever = RagRetriever()
# SQL CASE expression interpolated into the app queries (table alias "a").
# It yields a display tenant name: a fixed label for the PUBLIC and PROVINCIAL
# tenant codes, otherwise the trimmed area, otherwise a "no area assigned" label.
_APP_TENANT_NAME_SQL = (
"CASE "
"WHEN NULLIF(BTRIM(a.tenant_code), '') = 'PUBLIC' THEN '公共' "
"WHEN NULLIF(BTRIM(a.tenant_code), '') = 'PROVINCIAL' THEN '省级' "
"ELSE COALESCE(NULLIF(BTRIM(a.area), ''), '未分配地区') "
"END"
)
async def GetApps(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagChatAppListVO:
    """Return every chat app visible to the caller together with the count."""
    visible_apps = await self._load_apps(
        UserArea,
        UserRole,
        TenantCode,
        TenantName,
        only_default=False,
    )
    app_count = len(visible_apps)
    return RagChatAppListVO(data=visible_apps, total=app_count)
async def GetDefaultApp(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagChatAppVO | None:
    """Return the first visible default app, else the first visible app, else None."""
    # First pass looks only at apps flagged as default; the second pass
    # widens to every visible app.
    for default_only in (True, False):
        candidates = await self._load_apps(
            UserArea,
            UserRole,
            TenantCode,
            TenantName,
            only_default=default_only,
        )
        if candidates:
            return candidates[0]
    return None
async def SendMessage(
    self,
    CurrentUserId: int,
    UserName: str,
    UserArea: str | None,
    UserRole: str | None,
    Query: str,
    ConversationId: str | None,
    AppId: int | None,
    AttachmentId: str | None = None,
    AttachmentIds: list[str] | None = None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> AsyncGenerator[bytes, None]:
    """Persist a user question, start the answer task and stream its events.

    The user message and a placeholder assistant message (status "running")
    are inserted in one transaction, the generation task is started, and the
    events recorded for that task are yielded through ``self._format_sse``
    until the task is flagged as done.

    Raises:
        LeauditException: 400 when ``Query`` is blank, 404 when no chat app
            can be resolved; access errors from ``_ensure_conversation``
            propagate unchanged.
    """
    if not Query.strip():
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "问题不能为空")
    app = await self._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName)
    if not app:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用")
    conversationId = await self._ensure_conversation(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        app_id=app["id"],
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )
    messageId = str(uuid.uuid4())
    taskId = str(uuid.uuid4())
    # _ensure_conversation also creates a conversation when the supplied id
    # does not match a live row. Comparing the resolved id with the request
    # makes sure the client is told about the new id in that case as well.
    is_new_conversation = (
        not ConversationId
        or ConversationId == "-1"
        or conversationId != ConversationId
    )
    active_attachment_ids = self._normalize_attachment_ids(
        attachment_id=AttachmentId,
        attachment_ids=AttachmentIds,
    )
    if not active_attachment_ids:
        # Nothing attached to this request: fall back to the attachments
        # currently active for the conversation.
        active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
            CurrentUserId=CurrentUserId,
            TenantCode=TenantCode,
            UserArea=UserArea,
            ConversationId=conversationId,
        )
    attachment_records = await self._load_message_attachment_records(
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=conversationId,
        AttachmentIds=active_attachment_ids,
    )
    user_message_files = self._build_user_message_files(attachment_records)
    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(
                text(
                    """
                    INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
                    VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, CAST(:metadata AS jsonb))
                    """
                ),
                {
                    "message_id": str(uuid.uuid4()),
                    "conversation_id": conversationId,
                    "content": Query,
                    "metadata": json.dumps({"message_files": user_message_files}, ensure_ascii=False),
                },
            )
            await session.execute(
                text(
                    """
                    INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
                    VALUES (:message_id, :conversation_id, 'assistant', '', '[]'::jsonb, CAST(:metadata AS jsonb))
                    """
                ),
                {
                    "message_id": messageId,
                    "conversation_id": conversationId,
                    "metadata": json.dumps({"status": "running", "task_id": taskId}, ensure_ascii=False),
                },
            )
            await session.execute(
                text(
                    "UPDATE rag_conversation SET updated_at = NOW() WHERE conversation_id = :conversation_id"
                ),
                {"conversation_id": conversationId},
            )
    await self._start_message_task(
        task_id=taskId,
        conversation_id=conversationId,
        message_id=messageId,
        query=Query,
        app=app,
        current_user_id=CurrentUserId,
        tenant_code=TenantCode,
        user_area=UserArea,
        attachment_id=AttachmentId,
        attachment_ids=active_attachment_ids,
    )
    if is_new_conversation:
        # Sent before any task event so the client learns the ids first.
        yield self._format_sse(
            {
                "event": "conversation_created",
                "conversation_id": conversationId,
                "message_id": messageId,
                "task_id": taskId,
            }
        )
    next_event = 0
    while True:
        events = self._task_events.get(taskId, [])
        if next_event < len(events):
            payload = events[next_event]
            next_event += 1
            yield self._format_sse(payload)
            continue
        # Stop only after every recorded event has been drained.
        if self._task_done.get(taskId):
            break
        await asyncio.sleep(0.05)
async def GetConversations(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    AppId: int | None,
    Page: int,
    PageSize: int,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagConversationPageVO:
    """Return one page of the caller's non-deleted conversations.

    ``PageSize + 1`` rows are fetched; the extra row only answers "is there
    another page". Access filtering is applied to the rows of this page
    alone, so a page may hold fewer than ``PageSize`` items, but ``hasMore``
    stays accurate and no row is reported on two pages.
    """
    tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName)
    async with GetAsyncSession() as session:
        rows = (
            await session.execute(
                text(
                    """
                    SELECT conversation_id, name, introduction, created_at, updated_at
                    , COALESCE(title_source, 'default') AS title_source
                    , COALESCE(EXTRACT(EPOCH FROM last_message_at), 0) AS last_message_at
                    FROM rag_conversation
                    WHERE user_id = :user_id
                    AND deleted_at IS NULL
                    AND (CAST(:app_id AS BIGINT) IS NULL OR app_id = CAST(:app_id AS BIGINT))
                    ORDER BY COALESCE(last_message_at, updated_at) DESC, updated_at DESC
                    OFFSET :offset LIMIT :limit
                    """
                ),
                {
                    "user_id": CurrentUserId,
                    "app_id": AppId,
                    "offset": max(Page - 1, 0) * PageSize,
                    "limit": PageSize + 1,
                },
            )
        ).mappings().all()
        # Decide on the raw result: filtering first would hide the look-ahead
        # row's meaning whenever an earlier row is inaccessible.
        has_more = len(rows) > PageSize
        items: list[dict] = []
        for row in rows[:PageSize]:
            record = dict(row)
            if await self._conversation_accessible(
                conversation_id=str(record["conversation_id"]),
                expected_user_id=CurrentUserId,
                tenant_context=tenant_context,
                user_role=UserRole,
                app_id=AppId,
                session=session,
            ):
                items.append(record)
    return RagConversationPageVO(
        data=[
            RagConversationItemVO(
                id=row["conversation_id"],
                name=row["name"],
                introduction=row.get("introduction") or "",
                titleSource=str(row.get("title_source") or DEFAULT_TITLE_SOURCE),
                createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
                updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0,
                lastMessageAt=int(float(row.get("last_message_at") or 0)),
            )
            for row in items
        ],
        hasMore=has_more,
        limit=PageSize,
    )
async def GetConversationMessages(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    ConversationId: str,
    Page: int,
    PageSize: int,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagMessagePageVO:
    """Return one page of question/answer items for a conversation.

    Rows are read in chronological order and each user row is merged with
    the assistant row that directly follows it. The answer is looked up in
    the full fetched result, which includes the look-ahead row, so a pair
    that straddles the page boundary keeps its answer. The orphaned
    assistant row that then opens the next page is skipped as before.

    While merging, the stored answer may be replaced by text rebuilt from
    in-memory task events, and a stale "running" status is normalised
    through ``_resolve_persisted_message_status``.

    Raises:
        LeauditException: when the caller may not access the conversation.
    """
    await self._ensure_conversation_owner(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )
    async with GetAsyncSession() as session:
        rows = (
            await session.execute(
                text(
                    """
                    SELECT message_id, role, content, sources, metadata, feedback, created_at
                    FROM rag_message
                    WHERE conversation_id = :conversation_id
                    ORDER BY created_at ASC,
                    CASE role
                    WHEN 'user' THEN 0
                    WHEN 'assistant' THEN 1
                    ELSE 2
                    END ASC,
                    message_id ASC
                    OFFSET :offset LIMIT :limit
                    """
                ),
                {
                    "conversation_id": ConversationId,
                    "offset": max(Page - 1, 0) * PageSize,
                    "limit": PageSize + 1,
                },
            )
        ).mappings().all()
    has_more = len(rows) > PageSize
    page_length = min(len(rows), PageSize)
    data: list[RagMessageItemVO] = []
    idx = 0
    while idx < page_length:
        row = rows[idx]
        if row["role"] != "user":
            # Assistant or other rows without a preceding user row on this page.
            idx += 1
            continue
        # Index against rows, not the page slice, so the look-ahead row can
        # complete a pair cut by the page boundary.
        answer = rows[idx + 1] if idx + 1 < len(rows) and rows[idx + 1]["role"] == "assistant" else None
        user_metadata = dict(row.get("metadata") or {})
        answer_metadata = dict((answer.get("metadata") if answer else None) or {})
        answer_status = str(answer_metadata.get("status") or ("completed" if answer else "running"))
        answer_content = (answer.get("content") if answer else None) or ""
        if answer:
            task_id = str(answer_metadata.get("task_id") or "").strip()
            reconstructed_content = self._rebuild_message_content_from_events(task_id) if task_id else ""
            # Prefer the in-memory text when it is at least as long as the
            # stored one, and write it back if the two differ.
            if reconstructed_content and len(reconstructed_content) >= len(answer_content):
                if reconstructed_content != answer_content:
                    await self._update_message_progress(
                        conversation_id=ConversationId,
                        message_id=answer["message_id"],
                        content=reconstructed_content,
                        metadata=answer_metadata,
                    )
                answer_content = reconstructed_content
            normalized_status = await self._resolve_persisted_message_status(
                conversation_id=ConversationId,
                message_id=answer["message_id"],
                content=answer_content,
                metadata=answer_metadata,
            )
            if normalized_status != answer_status:
                answer_status = normalized_status
                answer_metadata["status"] = normalized_status
        data.append(
            RagMessageItemVO(
                id=(answer["message_id"] if answer else row["message_id"]),
                conversationId=ConversationId,
                query=row["content"],
                answer=answer_content if answer else "",
                messageFiles=[item for item in (user_metadata.get("message_files") or []) if isinstance(item, dict)],
                feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None),
                retrieverResources=(answer.get("sources") if answer else None),
                suggestedQuestions=[str(item) for item in (answer_metadata.get("suggested_questions") or []) if str(item).strip()],
                status=answer_status,
                createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
            )
        )
        idx += 2 if answer else 1
    return RagMessagePageVO(data=data, hasMore=has_more, limit=PageSize)
async def _resolve_persisted_message_status(
    self,
    *,
    conversation_id: str,
    message_id: str,
    content: str,
    metadata: dict,
) -> str:
    """Return the effective status of a stored assistant message.

    A status other than "running" is returned untouched. A "running"
    message whose task is still alive stays "running". Otherwise the
    message is finalised as "completed" when it has non-blank content or
    "error" when it has none, and that result is written back through
    ``_update_message_progress``.
    """
    stored_status = str(metadata.get("status") or "completed")
    if stored_status != "running":
        return stored_status

    task_id = str(metadata.get("task_id") or "").strip()
    if task_id:
        task = self._message_tasks.get(task_id)
        flagged_done = self._task_done.get(task_id, False)
    else:
        task = None
        flagged_done = False
    if task and not (task.done() or flagged_done):
        return "running"

    final_status = "completed" if content.strip() else "error"
    patched_metadata = dict(metadata)
    patched_metadata["status"] = final_status
    needs_error_text = final_status == "error" and not patched_metadata.get("error")
    if needs_error_text:
        patched_metadata["error"] = "生成任务已结束,但未产出有效回答"
    await self._update_message_progress(
        conversation_id=conversation_id,
        message_id=message_id,
        content=content,
        metadata=patched_metadata,
    )
    return final_status
def _rebuild_message_content_from_events(self, task_id: str) -> str:
if not task_id:
return ""
chunks: list[str] = []
for event in self._task_events.get(task_id, []):
if event.get("event") != "message":
continue
answer = event.get("answer")
if isinstance(answer, str) and answer:
chunks.append(answer)
return "".join(chunks)
async def RenameConversation(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    ConversationId: str,
    Body: RagConversationRenameDTO,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagConversationRenameVO:
    """Rename a conversation and mark its title as manually set.

    Raises:
        LeauditException: 400 when the trimmed name is empty; access errors
            from ``_ensure_conversation_owner`` propagate unchanged.
    """
    await self._ensure_conversation_owner(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )
    final_name = Body.name.strip()
    if not final_name:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "会话名称不能为空")

    # A title generation recorded as 'running' is closed as 'succeeded';
    # any other stored state is kept as it is.
    rename_sql = text(
        """
        UPDATE rag_conversation
        SET name = :name,
        title_source = 'manual',
        title_generation_status = CASE
        WHEN COALESCE(title_generation_status, 'idle') = 'running' THEN 'succeeded'
        ELSE COALESCE(title_generation_status, 'idle')
        END,
        title_generation_error = NULL,
        updated_at = NOW()
        WHERE conversation_id = :conversation_id
        """
    )
    rename_params = {"name": final_name, "conversation_id": ConversationId}
    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(rename_sql, rename_params)
    return RagConversationRenameVO(result="success", name=final_name)
async def DeleteConversation(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    ConversationId: str,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagOperationResultVO:
    """Soft-delete a conversation by stamping its ``deleted_at`` column.

    Raises:
        LeauditException: access errors from ``_ensure_conversation_owner``.
    """
    await self._ensure_conversation_owner(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )
    soft_delete_sql = text(
        "UPDATE rag_conversation SET deleted_at = NOW(), updated_at = NOW() WHERE conversation_id = :conversation_id"
    )
    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(soft_delete_sql, {"conversation_id": ConversationId})
    return RagOperationResultVO(result="success")
async def UpdateFeedback(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    MessageId: str,
    Body: RagMessageFeedbackDTO,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagOperationResultVO:
    """Store the caller's rating on a message.

    The lookup, the access check and the UPDATE run inside one explicit
    ``session.begin()`` block, like every other write in this service, so
    the rating is committed when the block exits and rolled back if any
    step raises.

    Raises:
        LeauditException: 404 when the message (or its live conversation)
            does not exist, 403 when the caller may not access it.
    """
    tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName)
    async with GetAsyncSession() as session:
        # Opened before the first statement: beginning a transaction after
        # the SELECT has already started one implicitly would fail.
        async with session.begin():
            row = (
                await session.execute(
                    text(
                        """
                        SELECT c.user_id, c.conversation_id
                        FROM rag_message m
                        JOIN rag_conversation c ON c.conversation_id = m.conversation_id
                        WHERE m.message_id = :message_id AND c.deleted_at IS NULL
                        LIMIT 1
                        """
                    ),
                    {"message_id": MessageId},
                )
            ).mappings().first()
            if not row:
                raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在")
            if not await self._conversation_accessible(
                conversation_id=str(row["conversation_id"]),
                expected_user_id=CurrentUserId,
                tenant_context=tenant_context,
                user_role=UserRole,
                session=session,
            ):
                raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权修改该消息反馈")
            await session.execute(
                text("UPDATE rag_message SET feedback = :feedback WHERE message_id = :message_id"),
                {"feedback": Body.rating, "message_id": MessageId},
            )
    return RagOperationResultVO(result="success")
async def StopMessage(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    MessageId: str,
    Body: RagStopMessageDTO | None = None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagOperationResultVO:
    """Cancel the generation task behind a message, if it is still running.

    The task id comes from ``Body.taskId`` when a body with a truthy id is
    given, otherwise from the message's stored metadata. A missing or
    already finished task is not an error; the call still reports success.

    Raises:
        LeauditException: 404 when the message (or its live conversation)
            does not exist, 403 when the caller may not access it.
    """
    tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName)
    async with GetAsyncSession() as session:
        row = (
            await session.execute(
                text(
                    """
                    SELECT m.metadata, c.user_id, c.conversation_id
                    FROM rag_message m
                    JOIN rag_conversation c ON c.conversation_id = m.conversation_id
                    WHERE m.message_id = :message_id
                    AND c.deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"message_id": MessageId},
            )
        ).mappings().first()
        if not row:
            raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在")
        # Reuse the open session for the access check instead of opening a
        # second one while this one is still held.
        if not await self._conversation_accessible(
            conversation_id=str(row["conversation_id"]),
            expected_user_id=CurrentUserId,
            tenant_context=tenant_context,
            user_role=UserRole,
            session=session,
        ):
            raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权停止该消息")
    metadata = row.get("metadata") or {}
    # Body is optional (defaults to None), so guard before reading taskId.
    requested_task_id = Body.taskId if Body is not None else None
    task_id = str(requested_task_id or metadata.get("task_id") or "").strip()
    task = self._message_tasks.get(task_id) if task_id else None
    if task and not task.done():
        task.cancel()
    return RagOperationResultVO(result="success")
async def GetAppParameters(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    AppId: int | None,
    TenantCode: str | None = None,
    TenantName: str | None = None,
) -> RagAppParametersVO:
    """Return the opening statement and suggested questions of a chat app.

    An empty ``RagAppParametersVO`` is returned when no app can be
    resolved. ``suggested_questions`` is accepted either as a JSON text or
    as an already decoded list; anything else yields no suggestions. At
    most six suggestions are returned.
    """
    app = await self._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName)
    if not app:
        return RagAppParametersVO()
    raw_suggested = app.get("suggested_questions")
    if isinstance(raw_suggested, list):
        # Already decoded, e.g. by a JSON-aware database driver.
        suggested = raw_suggested
    else:
        try:
            suggested = json.loads(raw_suggested or "[]")
        except (TypeError, ValueError):
            # Malformed JSON or an unsupported value type: no suggestions.
            suggested = []
        if not isinstance(suggested, list):
            suggested = []
    return RagAppParametersVO(
        openingStatement=app.get("opening_statement") or "",
        suggestedQuestions=[str(item) for item in suggested[:6]],
        userInputForm=[],
        fileUpload={"image": {"enabled": False}},
    )
async def _load_apps(
    self,
    user_area: str | None,
    user_role: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
    only_default: bool,
) -> list[RagChatAppVO]:
    """Load the enabled, non-deleted chat apps the caller is allowed to see.

    Args:
        only_default: when true, only apps flagged ``is_default`` are read.

    Returns:
        The visible apps ordered by ``sort_order`` then newest first.
    """
    async with GetAsyncSession() as session:
        await self._ensure_rag_chat_schema(session)
        listing_sql = text(
            f"""
            SELECT a.id, a.name, a.description, a.is_default, a.area,
            COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code,
            {self._APP_TENANT_NAME_SQL} AS tenant_name,
            COALESCE(d.is_public, FALSE) AS dataset_public
            FROM rag_chat_app a
            LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
            WHERE a.deleted_at IS NULL
            AND a.status = 1
            AND (:only_default = FALSE OR a.is_default = TRUE)
            ORDER BY a.sort_order ASC, a.created_at DESC
            """
        )
        result = await session.execute(listing_sql, {"only_default": only_default})
        rows = result.mappings().all()
        tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
        visible: list[RagChatAppVO] = []
        for record in map(dict, rows):
            is_visible = await self._app_visible(
                record,
                tenant_context=tenant_context,
                user_role=user_role,
            )
            if is_visible:
                visible.append(
                    RagChatAppVO(
                        appId=str(record["id"]),
                        appName=record["name"],
                        description=record.get("description") or "",
                        tenantCode=str(record.get("tenant_code") or ""),
                        tenantName=str(record.get("tenant_name") or record.get("area") or ""),
                        isDefault=bool(record.get("is_default")),
                    )
                )
        return visible
async def _resolve_app(
    self,
    app_id: int | None,
    user_area: str | None,
    user_role: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
) -> dict | None:
    """Pick the chat app to use for a request.

    Candidates are tried in this order and the first one visible to the
    caller wins: the app with ``app_id`` (when given), the default apps,
    then all enabled apps. Within the last two groups every row is checked
    in ``sort_order`` / newest-first order, so an invisible first row no
    longer hides visible apps that follow it.

    Returns:
        The app row as a dict, or ``None`` when no enabled app is visible.
    """
    tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
    async with GetAsyncSession() as session:
        await self._ensure_rag_chat_schema(session)
        params = {
            "app_id": app_id,
        }
        base_sql = (
            f"""
            SELECT a.id, a.name, a.description, a.area,
            COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code,
            {self._APP_TENANT_NAME_SQL} AS tenant_name,
            a.dataset_id, a.system_prompt,
            a.llm_model, a.temperature, a.max_tokens, a.opening_statement,
            a.suggested_questions, a.is_default, COALESCE(d.is_public, FALSE) AS dataset_public,
            COALESCE(d.name, '') AS dataset_name
            FROM rag_chat_app a
            LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
            WHERE a.deleted_at IS NULL AND a.status = 1
            """
        )
        ordering = " ORDER BY a.sort_order ASC, a.created_at DESC"
        suffixes: list[str] = []
        if app_id is not None:
            suffixes.append(" AND a.id = :app_id LIMIT 1")
        # No LIMIT on the fallbacks: visibility is decided in Python, so
        # the rows after the first one must stay available.
        suffixes.append(" AND a.is_default = TRUE" + ordering)
        suffixes.append(ordering)
        for suffix in suffixes:
            rows = (
                await session.execute(text(base_sql + suffix), params)
            ).mappings().all()
            for row in rows:
                record = dict(row)
                if await self._app_visible(record, tenant_context=tenant_context, user_role=user_role):
                    return record
    return None
async def _app_visible(self, row: dict, tenant_context: dict, user_role: str | None) -> bool:
if self._role_is_global(user_role):
return True
if bool(row.get("dataset_public")):
return True
if str(row.get("tenant_code") or "").strip().upper() == "PUBLIC":
return True
return self._row_matches_tenant_scope(
row_tenant_code=row.get("tenant_code"),
row_area=row.get("area"),
tenant_context=tenant_context,
)
async def _resolve_tenant_context(
    self,
    user_area: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
) -> dict[str, str | None]:
    """Resolve the caller's tenant through ``TenantResolver`` into a plain dict.

    The dict carries the resolved tenant code, name and type plus the
    caller's original ``user_area`` under the key ``area``.
    """
    resolved = await self.TenantResolver.ResolveUserContext(
        Area=user_area,
        TenantCode=tenant_code,
        TenantName=tenant_name,
        Source="rag_chat_user",
    )
    context: dict[str, str | None] = {}
    context["tenant_code"] = resolved.tenant_code
    context["tenant_name"] = resolved.tenant_name
    context["tenant_type"] = resolved.tenant_type
    context["area"] = user_area
    return context
async def _resolve_record_tenant(self, raw_value: str | None):
    """Resolve a raw tenant value of a stored record via ``TenantResolver.Resolve``."""
    resolved_tenant = await self.TenantResolver.Resolve(
        RawValue=raw_value,
        Source="rag_chat_record",
    )
    return resolved_tenant
async def _ensure_rag_chat_schema(self, session) -> None:
    """Make sure ``rag_chat_app.tenant_code`` exists, once per process.

    The class-level flag short-circuits later calls and the class-level
    lock serialises the check between coroutines of this process. The lock
    does not cover other worker processes, so the DDL itself is written
    with ``IF NOT EXISTS``: a second process that also saw the column as
    missing gets a no-op instead of a duplicate-column error.

    Args:
        session: open database session used for the check and the DDL.
    """
    if self.__class__._chat_schema_checked:
        return
    async with self.__class__._chat_schema_lock:
        # Re-check: another coroutine may have finished while we waited.
        if self.__class__._chat_schema_checked:
            return
        exists = (
            await session.execute(
                text(
                    """
                    SELECT 1
                    FROM information_schema.columns
                    WHERE table_schema = current_schema()
                    AND table_name = 'rag_chat_app'
                    AND column_name = 'tenant_code'
                    """
                )
            )
        ).scalar_one_or_none()
        if exists:
            self.__class__._chat_schema_checked = True
            return
        # Bound the wait for the table lock so a busy table fails fast.
        await session.execute(text("SET LOCAL lock_timeout = '1000ms'"))
        await session.execute(text("ALTER TABLE rag_chat_app ADD COLUMN IF NOT EXISTS tenant_code VARCHAR(64) NULL"))
        await session.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chat_app_tenant_code ON rag_chat_app(tenant_code) WHERE deleted_at IS NULL"))
        # NOTE(review): no commit is issued here; confirm the caller's
        # session commits, otherwise the DDL is rolled back while the flag
        # below stays set for the rest of the process.
        self.__class__._chat_schema_checked = True
@staticmethod
def _tenant_context_is_global(tenant_context: dict[str, str | None]) -> bool:
tenant_code = str(tenant_context.get("tenant_code") or "").strip().upper()
return tenant_code in {"PUBLIC", "PROVINCIAL"}
@staticmethod
def _role_is_global(user_role: str | None) -> bool:
normalized = str(user_role or "").strip()
return normalized in {"super_admin", "provincial_admin"}
def _row_matches_tenant_scope(
self,
*,
row_tenant_code: str | None,
row_area: str | None,
tenant_context: dict[str, str | None],
) -> bool:
user_tenant_code = str(tenant_context.get("tenant_code") or "").strip()
if user_tenant_code:
return str(row_tenant_code or "").strip() == user_tenant_code
return str(row_area or "").strip() == str(tenant_context.get("area") or "").strip()
async def _ensure_conversation(
    self,
    user_id: int,
    conversation_id: str | None,
    app_id: int | None,
    user_area: str | None,
    user_role: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
) -> str:
    """Return the id of a usable conversation, creating one when needed.

    A supplied id (other than ``"-1"``) that matches a non-deleted row is
    returned after the access check. In every other case, including a
    supplied id that matches nothing, a new conversation is inserted.

    Raises:
        LeauditException: 403 when the existing conversation is not
            accessible to the caller.
    """
    tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
    wants_existing = bool(conversation_id) and conversation_id != "-1"
    if wants_existing:
        lookup_sql = text(
            """
            SELECT conversation_id, user_id
            FROM rag_conversation
            WHERE conversation_id = :conversation_id
            AND deleted_at IS NULL
            LIMIT 1
            """
        )
        async with GetAsyncSession() as session:
            lookup = await session.execute(lookup_sql, {"conversation_id": conversation_id})
            existing = lookup.mappings().first()
            if existing:
                found_id = str(existing["conversation_id"])
                allowed = await self._conversation_accessible(
                    conversation_id=found_id,
                    expected_user_id=user_id,
                    tenant_context=tenant_context,
                    user_role=user_role,
                    app_id=app_id,
                    session=session,
                )
                if not allowed:
                    raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权使用该会话")
                return found_id

    new_id = str(uuid.uuid4())
    insert_sql = text(
        """
        INSERT INTO rag_conversation (conversation_id, user_id, app_id, name, introduction)
        VALUES (:conversation_id, :user_id, :app_id, :name, '')
        """
    )
    insert_params = {
        "conversation_id": new_id,
        "user_id": user_id,
        "app_id": app_id,
        "name": DEFAULT_CONVERSATION_NAME,
    }
    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(insert_sql, insert_params)
    return new_id
async def _ensure_conversation_owner(
    self,
    *,
    user_id: int,
    conversation_id: str,
    user_area: str | None,
    user_role: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
) -> None:
    """Raise unless the caller may access the conversation.

    Raises:
        LeauditException: 403 when the access check fails; errors raised by
            ``_conversation_accessible`` itself propagate unchanged.
    """
    tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
    allowed = await self._conversation_accessible(
        conversation_id=conversation_id,
        expected_user_id=user_id,
        tenant_context=tenant_context,
        user_role=user_role,
    )
    if allowed:
        return
    raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该会话")
async def _conversation_accessible(
    self,
    *,
    conversation_id: str,
    expected_user_id: int,
    tenant_context: dict[str, str | None],
    user_role: str | None,
    app_id: int | None = None,
    session=None,
) -> bool:
    """Run the conversation access check on the given or a temporary session.

    When ``session`` is ``None`` a session is opened just for this check.
    """
    check_args = {
        "conversation_id": conversation_id,
        "expected_user_id": expected_user_id,
        "tenant_context": tenant_context,
        "user_role": user_role,
        "app_id": app_id,
    }
    if session is None:
        async with GetAsyncSession() as owned_session:
            return await self._conversation_accessible_with_session(
                session=owned_session,
                **check_args,
            )
    return await self._conversation_accessible_with_session(session=session, **check_args)
async def _conversation_accessible_with_session(
    self,
    *,
    session,
    conversation_id: str,
    expected_user_id: int,
    tenant_context: dict[str, str | None],
    user_role: str | None,
    app_id: int | None = None,
) -> bool:
    """Check whether a user may access a conversation.

    Access needs all of: the conversation belongs to ``expected_user_id``;
    when both ``app_id`` and the stored app id are set, they are equal; and
    the conversation's app is visible to the caller.

    Raises:
        LeauditException: 404 when no non-deleted conversation has this id.
    """
    lookup_sql = text(
        """
        SELECT
        c.conversation_id,
        c.user_id,
        c.app_id,
        a.area,
        COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code,
        COALESCE(d.is_public, FALSE) AS dataset_public
        FROM rag_conversation c
        LEFT JOIN rag_chat_app a ON a.id = c.app_id AND a.deleted_at IS NULL
        LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
        WHERE c.conversation_id = :conversation_id
        AND c.deleted_at IS NULL
        LIMIT 1
        """
    )
    result = await session.execute(lookup_sql, {"conversation_id": conversation_id})
    record = result.mappings().first()
    if not record:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "会话不存在")
    if int(record["user_id"]) != expected_user_id:
        return False
    stored_app_id = record.get("app_id")
    if app_id is not None and stored_app_id is not None:
        if int(stored_app_id) != int(app_id):
            return False
    # Only these three columns take part in the visibility decision.
    visibility_row = {
        column: record.get(column)
        for column in ("tenant_code", "area", "dataset_public")
    }
    return await self._app_visible(visibility_row, tenant_context=tenant_context, user_role=user_role)
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
    """Retrieve chunks for a query and label each with its source.

    Chunks keep their own ``source_scope`` / ``data_source_type`` when those
    are truthy; otherwise both default to ``"formal_kb"``.

    Returns:
        The labelled chunks and the dataset name reported by the retriever.
    """
    result = await self.retriever.retrieve(query=query, dataset_id=dataset_id)
    labelled: list[dict] = []
    for chunk in result.chunks:
        entry = {**chunk}
        entry["source_scope"] = chunk.get("source_scope") or "formal_kb"
        entry["data_source_type"] = chunk.get("data_source_type") or "formal_kb"
        labelled.append(entry)
    return labelled, result.dataset_name
async def _resolve_attachment_id_for_conversation(
    self,
    *,
    CurrentUserId: int,
    TenantCode: str | None,
    UserArea: str | None,
    ConversationId: str,
) -> str | None:
    """Ask the attachment service for the conversation's active attachment id."""
    # Imported inside the function, as in the rest of this class.
    from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl

    attachment_service = RagChatAttachmentServiceImpl()
    active_id = await attachment_service.ResolveActiveAttachmentIdForConversation(
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
    )
    return active_id
async def _resolve_attachment_ids_for_conversation(
    self,
    *,
    CurrentUserId: int,
    TenantCode: str | None,
    UserArea: str | None,
    ConversationId: str,
) -> list[str]:
    """Return the attachment ids currently active for a conversation.

    Uses the attachment service's multi-id resolver when the service has
    one; otherwise wraps the single-id resolver's result in a list, which
    is empty when that result is falsy.
    """
    from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl

    attachment_service = RagChatAttachmentServiceImpl()
    lookup_args = {
        "CurrentUserId": CurrentUserId,
        "TenantCode": TenantCode,
        "UserArea": UserArea,
        "ConversationId": ConversationId,
    }
    if hasattr(attachment_service, "ResolveActiveAttachmentIdsForConversation"):
        return await attachment_service.ResolveActiveAttachmentIdsForConversation(**lookup_args)
    single_id = await attachment_service.ResolveActiveAttachmentIdForConversation(**lookup_args)
    if single_id:
        return [single_id]
    return []
async def _retrieve_attachment_context(
    self,
    *,
    CurrentUserId: int,
    TenantCode: str | None,
    UserArea: str | None,
    ConversationId: str,
    AttachmentId: str,
    Query: str,
) -> tuple[list[dict], str]:
    """Retrieve context for a query from one attachment, with ``TopK=5``."""
    from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl

    attachment_service = RagChatAttachmentServiceImpl()
    attachment_context = await attachment_service.RetrieveAttachmentContext(
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
        AttachmentId=AttachmentId,
        Query=Query,
        TopK=5,
    )
    return attachment_context
async def _load_message_attachment_records(
    self,
    *,
    CurrentUserId: int,
    TenantCode: str | None,
    UserArea: str | None,
    ConversationId: str,
    AttachmentIds: list[str],
) -> list[dict]:
    """Validate each attachment for this chat and collect the records.

    Returns an empty list without touching the attachment service when
    ``AttachmentIds`` is empty. Ids are validated one at a time, in order.
    """
    if not AttachmentIds:
        return []
    from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl

    attachment_service = RagChatAttachmentServiceImpl()
    validated = [
        await attachment_service.ValidateAttachmentForChat(
            CurrentUserId=CurrentUserId,
            TenantCode=TenantCode,
            UserArea=UserArea,
            ConversationId=ConversationId,
            AttachmentId=one_id,
        )
        for one_id in AttachmentIds
    ]
    return validated
def _build_user_message_files(self, attachment_records: list[dict]) -> list[dict]:
files: list[dict] = []
for record in attachment_records:
attachment_id = str(record.get("attachment_id") or "").strip()
if not attachment_id:
continue
file_name = str(record.get("original_name") or record.get("filename") or "上传文件")
content_type = str(record.get("content_type") or "")
file_type = "image" if content_type.startswith("image/") or re.search(r"\.(png|jpe?g|webp|bmp|tiff?)$", file_name, re.I) else "file"
files.append(
{
"id": attachment_id,
"upload_file_id": attachment_id,
"name": file_name,
"fileName": file_name,
"type": file_type,
"transfer_method": "local_file",
"contentType": content_type or None,
"fileSize": int(record.get("file_size") or 0),
"belongs_to": "user",
"usage": "temporary_attachment",
}
)
return files
def _build_formal_kb_query(self, *, query: str, attachment_chunks: list[dict]) -> str:
if not attachment_chunks:
return query
facts: list[str] = []
for chunk in attachment_chunks[:3]:
text_value = str(chunk.get("text") or "").strip()
if text_value:
facts.append(text_value[:500])
if not facts:
return query
return (
f"{query}\n\n"
"用户上传文档中检索到的相关事实:\n"
+ "\n".join(f"- {item}" for item in facts)
+ "\n\n请检索这些事实对应的法律责任、处罚依据、裁量规则或案例。"
)
def _merge_context_chunks(self, *, attachment_chunks: list[dict], formal_chunks: list[dict]) -> list[dict]:
merged: list[dict] = []
for chunk in attachment_chunks:
merged.append(
{
**chunk,
"source_scope": "chat_attachment",
"data_source_type": "chat_attachment",
}
)
for chunk in formal_chunks:
merged.append(
{
**chunk,
"source_scope": chunk.get("source_scope") or "formal_kb",
"data_source_type": chunk.get("data_source_type") or "formal_kb",
}
)
return merged
def _normalize_attachment_ids(self, *, attachment_id: str | None, attachment_ids: list[str] | None) -> list[str]:
normalized: list[str] = []
for raw in [*(attachment_ids or []), attachment_id]:
value = str(raw or "").strip()
if value and value not in normalized:
normalized.append(value)
return normalized
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
    """Embed *texts* with *model_name* by delegating to ``self.retriever``."""
    embed = self.retriever._embed_texts
    vectors = await embed(texts, model_name)
    return vectors
async def _start_message_task(
    self,
    *,
    task_id: str,
    conversation_id: str,
    message_id: str,
    query: str,
    app: dict,
    current_user_id: int | None = None,
    tenant_code: str | None = None,
    user_area: str | None = None,
    attachment_id: str | None = None,
    attachment_ids: list[str] | None = None,
) -> None:
    """Reset the per-task bookkeeping and launch the answer job in the background.

    The event buffer and done flag for ``task_id`` are reset, a lock is
    registered if none exists yet, and ``_run_message_task`` is scheduled with
    the same arguments. The created task is stored in ``_message_tasks``.
    """
    self._task_events[task_id] = []
    self._task_done[task_id] = False
    self._task_locks.setdefault(task_id, asyncio.Lock())
    job_arguments = {
        "task_id": task_id,
        "conversation_id": conversation_id,
        "message_id": message_id,
        "query": query,
        "app": app,
        "current_user_id": current_user_id,
        "tenant_code": tenant_code,
        "user_area": user_area,
        "attachment_id": attachment_id,
        "attachment_ids": attachment_ids,
    }
    background_job = asyncio.create_task(self._run_message_task(**job_arguments))
    self._message_tasks[task_id] = background_job
async def _run_message_task(
    self,
    *,
    task_id: str,
    conversation_id: str,
    message_id: str,
    query: str,
    app: dict,
    current_user_id: int | None = None,
    tenant_code: str | None = None,
    user_area: str | None = None,
    attachment_id: str | None = None,
    attachment_ids: list[str] | None = None,
) -> None:
    """Produce one assistant answer end to end and record its outcome.

    Steps:
        1. Resolve the attachment ids (explicit ones first; otherwise looked up
           for the conversation when a user id is known) and retrieve their
           chunks.
        2. Retrieve formal knowledge-base chunks with a query enriched by the
           attachment facts, then merge both chunk lists.
        3. Stream the generation, buffering every parsed event under
           ``task_id`` and periodically persisting the partial answer.
        4. Finalize the message row with status ``completed``, ``error`` or
           ``stopped``.

    Raises:
        asyncio.CancelledError: Re-raised after the ``stopped`` state has been
            recorded. Other exceptions are recorded as an error event and an
            ``error`` message row instead of being re-raised.
    """
    context_chunks: list[dict] = []
    dataset_name = ""
    collected_answer = ""
    message_end_payload: dict | None = None
    final_status = "completed"
    error_payload: dict | None = None
    last_persisted_length = 0
    last_persisted_at = time.monotonic()
    try:
        attachment_chunks: list[dict] = []
        attachment_names: list[str] = []
        active_attachment_ids = self._normalize_attachment_ids(
            attachment_id=attachment_id,
            attachment_ids=attachment_ids,
        )
        # No explicit ids: fall back to the attachments bound to the conversation.
        if not active_attachment_ids and current_user_id is not None:
            active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
                CurrentUserId=current_user_id,
                TenantCode=tenant_code,
                UserArea=user_area,
                ConversationId=conversation_id,
            )
        if active_attachment_ids:
            if current_user_id is None:
                raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件缺少用户上下文")
            for active_attachment_id in active_attachment_ids:
                next_chunks, next_attachment_name = await self._retrieve_attachment_context(
                    CurrentUserId=current_user_id,
                    TenantCode=tenant_code,
                    UserArea=user_area,
                    ConversationId=conversation_id,
                    AttachmentId=active_attachment_id,
                    Query=query,
                )
                attachment_chunks.extend(next_chunks)
                if next_attachment_name:
                    attachment_names.append(next_attachment_name)
        legal_query = self._build_formal_kb_query(query=query, attachment_chunks=attachment_chunks)
        formal_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), legal_query)
        context_chunks = self._merge_context_chunks(
            attachment_chunks=attachment_chunks,
            formal_chunks=formal_chunks,
        )
        generation_dataset_name = dataset_name
        # De-duplicate names in order and join them with a visible separator;
        # joining on "" fused several file names into one unreadable token.
        attachment_name = "、".join(dict.fromkeys(attachment_names))
        if attachment_name and dataset_name:
            generation_dataset_name = f"{attachment_name} + {dataset_name}"
        elif attachment_name:
            generation_dataset_name = attachment_name
        async for chunk in generate_stream(
            query=query,
            context_chunks=context_chunks,
            conversation_id=conversation_id,
            message_id=message_id,
            system_prompt=app.get("system_prompt") or "",
            model=app.get("llm_model") or "",
            temperature=app.get("temperature"),
            max_tokens=app.get("max_tokens"),
            dataset_name=generation_dataset_name,
            task_id=task_id,
        ):
            data = self._parse_sse_event(chunk)
            if not data:
                continue
            event = data.get("event")
            if event == "message":
                collected_answer += data.get("answer", "")
                now = time.monotonic()
                # Throttle DB writes: persist after 80 new characters or 0.5 s.
                if len(collected_answer) - last_persisted_length >= 80 or now - last_persisted_at >= 0.5:
                    await self._update_message_progress(
                        conversation_id=conversation_id,
                        message_id=message_id,
                        content=collected_answer,
                        metadata={"status": "running", "task_id": task_id},
                    )
                    last_persisted_length = len(collected_answer)
                    last_persisted_at = now
                await self._append_task_event(task_id, data)
                continue
            if event == "message_end":
                # Held back so follow-ups and sources can be attached below.
                message_end_payload = data
                continue
            if event == "error":
                final_status = "error"
                error_payload = data
                await self._append_task_event(task_id, data)
                continue
            await self._append_task_event(task_id, data)
        if final_status == "completed":
            followups: list[str] = []
            try:
                followups = await generate_followups(query, collected_answer)
            except Exception:
                # Best effort: a follow-up failure must not fail the answer.
                followups = []
            sources = self._build_sources(context_chunks, dataset_name)
            if message_end_payload:
                message_end_metadata = message_end_payload.setdefault("metadata", {})
                message_end_metadata["suggested_questions"] = followups
                message_end_metadata["retriever_resources"] = sources
                await self._append_task_event(task_id, message_end_payload)
            await self._finalize_message_record(
                conversation_id=conversation_id,
                message_id=message_id,
                content=collected_answer,
                sources=sources,
                metadata={"suggested_questions": followups, "status": "completed", "task_id": task_id},
            )
            await self._maybe_schedule_auto_title(
                conversation_id=conversation_id,
                message_id=message_id,
                query=query,
                answer=collected_answer,
            )
        else:
            await self._finalize_message_record(
                conversation_id=conversation_id,
                message_id=message_id,
                content=collected_answer,
                sources=self._build_sources(context_chunks, dataset_name),
                metadata={
                    "suggested_questions": [],
                    "status": "error",
                    "task_id": task_id,
                    "error": (error_payload or {}).get("message", ""),
                },
            )
    except asyncio.CancelledError:
        final_status = "stopped"
        await self._append_task_event(
            task_id,
            {
                "event": "error",
                "task_id": task_id,
                "message_id": message_id,
                "code": "message_stopped",
                "message": "用户已停止回答",
            },
        )
        await self._finalize_message_record(
            conversation_id=conversation_id,
            message_id=message_id,
            content=collected_answer,
            sources=self._build_sources(context_chunks, dataset_name),
            metadata={"suggested_questions": [], "status": "stopped", "task_id": task_id},
        )
        # Propagate the cancellation so the task ends up in the cancelled state.
        raise
    except Exception as exc:
        final_status = "error"
        await self._append_task_event(
            task_id,
            {
                "event": "error",
                "task_id": task_id,
                "message_id": message_id,
                "code": "server_error",
                "message": str(exc),
            },
        )
        await self._finalize_message_record(
            conversation_id=conversation_id,
            message_id=message_id,
            content=collected_answer,
            sources=self._build_sources(context_chunks, dataset_name),
            metadata={"suggested_questions": [], "status": "error", "task_id": task_id, "error": str(exc)},
        )
    finally:
        # Mark the task finished and drop its handle and lock; the buffered
        # events in _task_events are left in place here.
        self._task_done[task_id] = True
        self._message_tasks.pop(task_id, None)
        self._task_locks.pop(task_id, None)
async def _append_task_event(self, task_id: str, payload: dict) -> None:
    """Append *payload* to the event buffer of *task_id* while holding its lock.

    The lock and the buffer are both created on first use.
    """
    guard = self._task_locks.setdefault(task_id, asyncio.Lock())
    async with guard:
        buffer = self._task_events.setdefault(task_id, [])
        buffer.append(payload)
async def _finalize_message_record(
    self,
    *,
    conversation_id: str,
    message_id: str,
    content: str,
    sources: list[dict],
    metadata: dict,
) -> None:
    """Write the final assistant message and refresh the conversation timestamps.

    Both UPDATE statements run inside one ``session.begin()`` block.
    ``sources`` and ``metadata`` are serialized to JSON and cast to ``jsonb``
    by the statement.
    """
    message_sql = (
        "\n"
        "UPDATE rag_message\n"
        "SET content = :content,\n"
        "sources = CAST(:sources AS jsonb),\n"
        "metadata = CAST(:metadata AS jsonb)\n"
        "WHERE message_id = :message_id\n"
        "AND conversation_id = :conversation_id\n"
        "AND role = 'assistant'\n"
    )
    conversation_sql = (
        "\n"
        "UPDATE rag_conversation\n"
        "SET updated_at = NOW(),\n"
        "last_message_at = NOW()\n"
        "WHERE conversation_id = :conversation_id\n"
    )
    async with GetAsyncSession() as db:
        async with db.begin():
            message_statement = text(message_sql)
            message_params = {
                "conversation_id": conversation_id,
                "message_id": message_id,
                "content": content,
                "sources": json.dumps(sources, ensure_ascii=False),
                "metadata": json.dumps(metadata, ensure_ascii=False),
            }
            await db.execute(message_statement, message_params)
            conversation_statement = text(conversation_sql)
            await db.execute(conversation_statement, {"conversation_id": conversation_id})
async def _update_message_progress(
    self,
    *,
    conversation_id: str,
    message_id: str,
    content: str,
    metadata: dict,
) -> None:
    """Persist the partial answer text and metadata of an assistant message.

    Only ``content`` and ``metadata`` are updated; ``sources`` and the
    conversation row are not touched by this statement.
    """
    progress_sql = (
        "\n"
        "UPDATE rag_message\n"
        "SET content = :content,\n"
        "metadata = CAST(:metadata AS jsonb)\n"
        "WHERE message_id = :message_id\n"
        "AND conversation_id = :conversation_id\n"
        "AND role = 'assistant'\n"
    )
    async with GetAsyncSession() as db:
        async with db.begin():
            statement = text(progress_sql)
            params = {
                "conversation_id": conversation_id,
                "message_id": message_id,
                "content": content,
                "metadata": json.dumps(metadata, ensure_ascii=False),
            }
            await db.execute(statement, params)
async def _maybe_schedule_auto_title(
    self,
    *,
    conversation_id: str,
    message_id: str,
    query: str,
    answer: str,
) -> None:
    """Schedule background title generation for a conversation when eligible.

    Nothing is scheduled when the query or answer is blank, the conversation
    row is missing, the title was set manually, a non-default title is already
    present, generation is pending / running / succeeded, or another message
    is already recorded as the first answer. Otherwise the row is marked
    ``pending`` and, unless a title task is still running for this
    conversation, ``_run_auto_title_task`` is started.
    """
    question_text = (query or "").strip()
    answer_text = (answer or "").strip()
    if not (question_text and answer_text):
        return

    select_sql = (
        "\n"
        "SELECT conversation_id,\n"
        "name,\n"
        "COALESCE(title_source, 'default') AS title_source,\n"
        "COALESCE(title_generation_status, 'idle') AS title_generation_status,\n"
        "first_answer_message_id\n"
        "FROM rag_conversation\n"
        "WHERE conversation_id = :conversation_id\n"
        "AND deleted_at IS NULL\n"
        "LIMIT 1\n"
    )
    mark_pending_sql = (
        "\n"
        "UPDATE rag_conversation\n"
        "SET title_generation_status = 'pending',\n"
        "first_answer_message_id = COALESCE(first_answer_message_id, :message_id),\n"
        "title_generation_error = NULL\n"
        "WHERE conversation_id = :conversation_id\n"
        "AND deleted_at IS NULL\n"
        "AND COALESCE(title_source, 'default') = 'default'\n"
    )

    async with GetAsyncSession() as db:
        result = await db.execute(text(select_sql), {"conversation_id": conversation_id})
        record = result.mappings().first()
    if not record:
        return

    source_kind = str(record.get("title_source") or DEFAULT_TITLE_SOURCE)
    if source_kind == MANUAL_TITLE_SOURCE:
        return
    existing_name = str(record.get("name") or "").strip()
    has_custom_name = bool(existing_name) and existing_name != DEFAULT_CONVERSATION_NAME
    if has_custom_name and source_kind != DEFAULT_TITLE_SOURCE:
        return
    generation_state = str(record.get("title_generation_status") or "idle")
    if generation_state in {"pending", "running", "succeeded"}:
        return
    first_answer_id = record.get("first_answer_message_id")
    # Only the message recorded as the first answer may drive the title.
    if first_answer_id and str(first_answer_id) != message_id:
        return

    async with GetAsyncSession() as db:
        async with db.begin():
            await db.execute(
                text(mark_pending_sql),
                {
                    "conversation_id": conversation_id,
                    "message_id": message_id,
                },
            )

    running_task = self._title_tasks.get(conversation_id)
    if running_task and not running_task.done():
        return
    title_job = asyncio.create_task(
        self._run_auto_title_task(
            conversation_id=conversation_id,
            answer_message_id=message_id,
            query=question_text,
            answer=answer_text,
        )
    )
    self._title_tasks[conversation_id] = title_job
async def _run_auto_title_task(
    self,
    *,
    conversation_id: str,
    answer_message_id: str,
    query: str,
    answer: str,
) -> None:
    """Generate and store an automatic conversation title.

    The conversation is re-read first; the task stops silently when the row is
    gone, the title is manual, or a different message is recorded as the first
    answer. Otherwise the row is marked ``running``, a title is generated
    (falling back to ``_build_fallback_title`` when the sanitized result is
    empty) and written with ``title_source = 'auto'``. Any exception is stored
    on the row with status ``failed``. The entry in ``_title_tasks`` is always
    removed at the end.
    """
    select_sql = (
        "\n"
        "SELECT name,\n"
        "COALESCE(title_source, 'default') AS title_source,\n"
        "COALESCE(title_generation_status, 'idle') AS title_generation_status,\n"
        "first_answer_message_id\n"
        "FROM rag_conversation\n"
        "WHERE conversation_id = :conversation_id\n"
        "AND deleted_at IS NULL\n"
        "LIMIT 1\n"
    )
    mark_running_sql = (
        "\n"
        "UPDATE rag_conversation\n"
        "SET title_generation_status = 'running',\n"
        "title_generation_error = NULL\n"
        "WHERE conversation_id = :conversation_id\n"
        "AND deleted_at IS NULL\n"
        "AND COALESCE(title_source, 'default') = 'default'\n"
    )
    store_title_sql = (
        "\n"
        "UPDATE rag_conversation\n"
        "SET name = :name,\n"
        "title_source = 'auto',\n"
        "title_generation_status = 'succeeded',\n"
        "title_generated_at = NOW(),\n"
        "title_generation_error = NULL\n"
        "WHERE conversation_id = :conversation_id\n"
        "AND deleted_at IS NULL\n"
        "AND COALESCE(title_source, 'default') = 'default'\n"
        "AND (\n"
        "name IS NULL\n"
        "OR BTRIM(name) = ''\n"
        "OR name = :default_name\n"
        ")\n"
    )
    mark_failed_sql = (
        "\n"
        "UPDATE rag_conversation\n"
        "SET title_generation_status = 'failed',\n"
        "title_generation_error = :error\n"
        "WHERE conversation_id = :conversation_id\n"
        "AND deleted_at IS NULL\n"
        "AND COALESCE(title_source, 'default') = 'default'\n"
    )
    try:
        async with GetAsyncSession() as db:
            result = await db.execute(text(select_sql), {"conversation_id": conversation_id})
            record = result.mappings().first()
        if not record:
            return
        if str(record.get("title_source") or DEFAULT_TITLE_SOURCE) == MANUAL_TITLE_SOURCE:
            return
        first_answer_id = record.get("first_answer_message_id")
        if first_answer_id and str(first_answer_id) != answer_message_id:
            return

        async with GetAsyncSession() as db:
            async with db.begin():
                await db.execute(text(mark_running_sql), {"conversation_id": conversation_id})

        raw_title = await self._generate_conversation_title(query=query, answer=answer)
        final_title = self._sanitize_generated_title(raw_title)
        if not final_title:
            final_title = self._build_fallback_title(query=query, answer=answer)
            if not final_title:
                raise ValueError("未生成有效标题")

        async with GetAsyncSession() as db:
            async with db.begin():
                await db.execute(
                    text(store_title_sql),
                    {
                        "conversation_id": conversation_id,
                        "name": final_title,
                        "default_name": DEFAULT_CONVERSATION_NAME,
                    },
                )
    except Exception as exc:
        async with GetAsyncSession() as db:
            async with db.begin():
                await db.execute(
                    text(mark_failed_sql),
                    {
                        "conversation_id": conversation_id,
                        # The stored error text is capped at 1000 characters.
                        "error": str(exc)[:1000],
                    },
                )
    finally:
        self._title_tasks.pop(conversation_id, None)
async def _generate_conversation_title(self, *, query: str, answer: str) -> str:
    """Ask the configured LLM for a short Chinese conversation title.

    The prompt embeds at most 500 characters of the query and 1500 characters
    of the answer, and is posted to the chat-completions URL built from
    ``RAG_CONFIG``.

    Returns:
        The stripped ``content`` of the first choice ("" when it is empty).

    Raises:
        httpx.HTTPStatusError: When the endpoint answers with an error status
            (via ``raise_for_status``).
        KeyError / IndexError: When the response body lacks the expected
            ``choices[0].message.content`` structure.
    """
    # Real line breaks separate the instructions, the question and the answer.
    # The previous "\\n" literals sent a backslash followed by "n" to the model.
    prompt = (
        "请基于用户首轮提问和助手首轮回答,生成一个简洁、准确的中文会话标题。"
        "要求:"
        "1. 只输出标题本身;"
        "2. 不要标点结尾;"
        "3. 不要出现“关于”“用户询问”“问题解答”等空话;"
        "4. 优先 12-24 个中文字符,最长不超过 40 个字符。\n"
        f"用户问题:{query[:500]}\n"
        f"助手回答:{answer[:1500]}"
    )
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(
            build_openai_chat_completions_url(RAG_CONFIG["LLM_BASE_URL"]),
            json={
                "model": RAG_CONFIG["LLM_MODEL"],
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.2,
                "max_tokens": 80,
                "chat_template_kwargs": {"enable_thinking": False},
            },
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}",
            },
        )
        resp.raise_for_status()
        content = resp.json()["choices"][0]["message"]["content"]
        return str(content or "").strip()
def _sanitize_generated_title(self, value: str) -> str:
    """Clean a raw LLM title into a single-line title of at most 40 characters.

    Line breaks become spaces, Markdown code fences are removed, surrounding
    quotes and brackets are stripped, a leading "标题:" / "会话标题:" label is
    dropped, whitespace runs collapse to one space and trailing punctuation is
    removed. Placeholder titles ("新对话", "会话标题", "标题") and empty input
    yield "".

    Args:
        value: Raw model output; ``None`` is treated as empty.

    Returns:
        The cleaned title, or "" when nothing usable remains.
    """
    title = str(value or "").strip()
    if not title:
        return ""
    # ASCII and full-width quote/bracket characters that may wrap a title.
    wrappers = "\"“”'‘’[]()()【】"
    title = title.replace("\r", " ").replace("\n", " ").strip()
    title = re.sub(r"^```[\w-]*", "", title).strip()
    title = title.replace("```", "").strip()
    title = title.strip(wrappers).strip()
    # r"\s" (not r"\\s"): the doubled backslash required a literal backslash
    # after the colon, so the label was never removed.
    title = re.sub(r"^(?:会话标题|标题)[::]\s*", "", title)
    # The label may itself precede a quoted title, so strip wrappers again.
    title = title.strip(wrappers)
    title = re.sub(r"\s+", " ", title).strip()
    title = title.rstrip("。!?!?.;;,,::")
    if title in {"新对话", "会话标题", "标题"}:
        return ""
    if len(title) > 40:
        title = title[:40].rstrip(",,。;;:: ")
    return title
def _build_fallback_title(self, *, query: str, answer: str) -> str:
    """Derive a title of at most 24 characters from the query, else the answer.

    Whitespace runs collapse to one space, leading filler phrases (请问, 帮我,
    关于, ...) and separators are removed, and trailing punctuation is
    stripped.

    Returns:
        The derived title, or "" when both inputs are blank or nothing remains
        after cleaning.
    """
    base = (query or "").strip() or (answer or "").strip()
    if not base:
        return ""
    base = re.sub(r"\s+", " ", base)
    # Phrases are matched as whole alternatives. The previous character class
    # removed single characters ("合同" -> "同") and, through r"\\s", leading
    # "s" and backslash characters. "请问" must precede "请" so it wins.
    base = re.sub(
        r"^(?:请问|请|帮我|一下|关于|针对|结合|根据|解释|说明|[::,,\s])+",
        "",
        base,
    )
    base = base.rstrip("。!?!?.;;,,::")
    if len(base) > 24:
        base = base[:24].rstrip(",,。;;:: ")
    return base
async def _keyword_retrieve_context(
    self,
    *,
    dataset_id: int,
    collection_name: str,
    dataset_name: str,
    query: str,
    top_k: int,
    score_threshold: float | None,
) -> list[dict]:
    """Run the retriever's keyword search and cap the result at ``top_k`` hits.

    All arguments are forwarded unchanged; ``source_names`` is always passed
    as ``None``.
    """
    search_arguments = {
        "dataset_id": dataset_id,
        "collection_name": collection_name,
        "dataset_name": dataset_name,
        "query": query,
        "top_k": top_k,
        "score_threshold": score_threshold,
        "source_names": None,
    }
    hits = await self.retriever._keyword_retrieve_context(**search_arguments)
    return hits[:top_k]
def _build_keyword_terms(self, query: str) -> list[str]:
    """Return the keyword terms for *query* as computed by ``self.retriever``."""
    terms = self.retriever._build_keyword_terms(query)
    return terms
def _normalize_keyword_query(self, query: str) -> str:
    """Return *query* normalized for keyword search by ``self.retriever``."""
    normalized = self.retriever._normalize_keyword_query(query)
    return normalized
def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float:
    """Score one chunk against the query by delegating to ``self.retriever``."""
    scoring_arguments = {
        "query": query,
        "terms": terms,
        "content": content,
        "document_name": document_name,
    }
    return self.retriever._score_keyword_chunk(**scoring_arguments)
def _format_sse(self, payload: dict) -> bytes:
    """Encode *payload* as one UTF-8 server-sent-event ``data:`` frame.

    Non-ASCII characters are emitted as-is (``ensure_ascii=False``) and the
    frame ends with a blank line.
    """
    body = json.dumps(payload, ensure_ascii=False)
    frame = "data: " + body + "\n\n"
    return frame.encode("utf-8")
def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]:
    """Build the source list for *context_chunks*.

    The ``build_sources`` callable of ``self.retriever`` is used when present;
    otherwise a fresh ``RagRetriever(hydrate_documents=False)`` supplies it.
    """
    builder = getattr(self.retriever, "build_sources", None)
    if not callable(builder):
        builder = RagRetriever(hydrate_documents=False).build_sources
    return builder(context_chunks, dataset_name)
async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict]) -> list[dict]:
    """Hydrate *chunks* of *dataset_id* by delegating to ``self.retriever``."""
    hydrate = self.retriever._hydrate_document_hits
    hydrated = await hydrate(dataset_id, chunks)
    return hydrated
def _parse_sse_event(self, chunk: str) -> dict | None:
    """Extract the JSON object carried by the ``data:`` lines of an SSE chunk.

    Returns:
        The decoded dict, or ``None`` when the chunk has no ``data:`` line, the
        payload is empty or ``[DONE]``, the JSON is invalid, or the decoded
        value is not a dict.
    """
    fragments: list[str] = []
    for raw_line in chunk.splitlines():
        if not raw_line.startswith("data:"):
            continue
        remainder = raw_line[5:]
        # "data: x" drops exactly one space; "data:x" drops any leading blanks.
        fragments.append(remainder[1:] if remainder.startswith(" ") else remainder.lstrip())
    if not fragments:
        return None
    body = "\n".join(filter(str.strip, fragments)).strip()
    if body in ("", "[DONE]"):
        return None
    # Removes trailing *literal* backslash-n sequences (two characters each).
    body = body.removesuffix("\\n\\n").removesuffix("\\n").strip()
    try:
        decoded = json.loads(body)
    except json.JSONDecodeError:
        return None
    if isinstance(decoded, dict):
        return decoded
    return None