677 lines
29 KiB
Python
677 lines
29 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import uuid
|
|
from typing import AsyncGenerator
|
|
|
|
import httpx
|
|
from sqlalchemy import text
|
|
|
|
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
|
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
|
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
|
|
|
from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import (
|
|
RagConversationRenameDTO,
|
|
RagMessageFeedbackDTO,
|
|
)
|
|
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
|
|
RagAppParametersVO,
|
|
RagChatAppListVO,
|
|
RagChatAppVO,
|
|
RagConversationItemVO,
|
|
RagConversationPageVO,
|
|
RagConversationRenameVO,
|
|
RagMessageItemVO,
|
|
RagMessagePageVO,
|
|
RagOperationResultVO,
|
|
)
|
|
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
|
|
from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream
|
|
from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups
|
|
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
|
|
|
|
|
|
class RagChatServiceImpl(IRagChatService):
|
|
async def GetApps(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppListVO:
    """Return every chat app visible to the caller plus the total count.

    ``CurrentUserId`` is accepted for interface parity but is not read here;
    visibility is derived from ``UserArea`` and ``UserRole`` by ``_load_apps``.
    """
    visible_apps = await self._load_apps(UserArea, UserRole, only_default=False)
    app_count = len(visible_apps)
    return RagChatAppListVO(data=visible_apps, total=app_count)
|
|
|
|
async def GetDefaultApp(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppVO | None:
    """Return the app the client should open by default.

    The first visible app flagged as default wins. When no default app is
    visible, the first visible app of any kind is returned. ``None`` means
    the caller can see no app at all.
    """
    # Pass 1 looks at default apps only; pass 2 widens to every visible app.
    for default_only in (True, False):
        candidates = await self._load_apps(UserArea, UserRole, only_default=default_only)
        if candidates:
            return candidates[0]
    return None
|
|
|
|
async def SendMessage(
    self,
    CurrentUserId: int,
    UserName: str,
    UserArea: str | None,
    UserRole: str | None,
    Query: str,
    ConversationId: str | None,
    AppId: int | None,
) -> AsyncGenerator[bytes, None]:
    """Answer ``Query`` as a server-sent-event byte stream and persist the exchange.

    Steps, in order:
      1. Validate the query and resolve the chat app.
      2. Reuse or create the conversation, then store the user message.
      3. Retrieve context chunks and relay ``generate_stream`` events to the caller.
      4. Hold back the ``message_end`` event until follow-up questions are
         generated, attach them to its metadata, then emit it.
      5. Store the assistant message and bump the conversation timestamp.

    ``UserName`` is accepted for interface parity but is not read here.

    Raises:
        LeauditException: 400 when the query is blank, 404 when no chat app
            can be resolved. Because this is an async generator, these are
            raised when the stream is first iterated.
    """
    if not Query.strip():
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "问题不能为空")

    app = await self._resolve_app(AppId, UserArea, UserRole)
    if not app:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用")

    conversationId = await self._ensure_conversation(CurrentUserId, ConversationId, app["id"])
    # Id of the assistant message; the user message gets its own id below.
    messageId = str(uuid.uuid4())

    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(
                text(
                    """
                    INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
                    VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, '{}'::jsonb)
                    """
                ),
                {
                    "message_id": str(uuid.uuid4()),
                    "conversation_id": conversationId,
                    "content": Query,
                },
            )

    context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), Query)
    answer_parts: list[str] = []
    held_message_end: dict | None = None

    async for chunk in generate_stream(
        query=Query,
        context_chunks=context_chunks,
        conversation_id=conversationId,
        message_id=messageId,
        system_prompt=app.get("system_prompt") or "",
        model=app.get("llm_model") or "",
        temperature=app.get("temperature"),
        max_tokens=app.get("max_tokens"),
        dataset_name=dataset_name,
    ):
        chunk_bytes = chunk.encode("utf-8")
        data = self._parse_sse_event(chunk)
        event = data.get("event") if data else None

        if event == "message_end":
            # Held back so the follow-up questions can be attached first.
            held_message_end = data
            continue

        if event == "message":
            piece = data.get("answer")
            # A missing or non-string fragment must not abort the stream
            # mid-response; it is relayed to the client but not accumulated.
            if isinstance(piece, str):
                answer_parts.append(piece)

        yield chunk_bytes

    collected_answer = "".join(answer_parts)

    # Follow-up suggestions are best-effort: any failure yields an empty list.
    followups: list[str] = []
    try:
        followups = await generate_followups(Query, collected_answer)
    except Exception:
        followups = []

    if held_message_end:
        metadata = held_message_end.setdefault("metadata", {})
        # When the upstream metadata is not a JSON object the event is
        # forwarded unchanged rather than failing the stream.
        if isinstance(metadata, dict):
            metadata["suggested_questions"] = followups
        yield f"data: {json.dumps(held_message_end, ensure_ascii=False)}\n\n".encode("utf-8")

    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(
                text(
                    """
                    INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
                    VALUES (:message_id, :conversation_id, 'assistant', :content, CAST(:sources AS jsonb), CAST(:metadata AS jsonb))
                    """
                ),
                {
                    "message_id": messageId,
                    "conversation_id": conversationId,
                    "content": collected_answer,
                    "sources": json.dumps(self._build_sources(context_chunks, dataset_name), ensure_ascii=False),
                    "metadata": json.dumps({"suggested_questions": followups}, ensure_ascii=False),
                },
            )
            await session.execute(
                text(
                    "UPDATE rag_conversation SET updated_at = NOW() WHERE conversation_id = :conversation_id"
                ),
                {"conversation_id": conversationId},
            )
|
|
|
|
async def GetConversations(self, CurrentUserId: int, AppId: int | None, Page: int, PageSize: int) -> RagConversationPageVO:
    """Return one page of the caller's non-deleted conversations, newest first.

    ``AppId`` narrows the list to one app when given. One row beyond
    ``PageSize`` is fetched so ``hasMore`` can be computed without a
    separate COUNT query.
    """

    def to_epoch(record, column: str) -> int:
        # Missing timestamps are reported as 0.
        return int(record[column].timestamp()) if record.get(column) else 0

    statement = text(
        """
        SELECT conversation_id, name, introduction, created_at, updated_at
        FROM rag_conversation
        WHERE user_id = :user_id
          AND deleted_at IS NULL
          AND (CAST(:app_id AS BIGINT) IS NULL OR app_id = CAST(:app_id AS BIGINT))
        ORDER BY updated_at DESC
        OFFSET :offset LIMIT :limit
        """
    )
    bind_values = {
        "user_id": CurrentUserId,
        "app_id": AppId,
        "offset": max(Page - 1, 0) * PageSize,
        "limit": PageSize + 1,
    }

    async with GetAsyncSession() as session:
        result = await session.execute(statement, bind_values)
        fetched = result.mappings().all()

    conversations: list[RagConversationItemVO] = []
    for record in fetched[:PageSize]:
        conversations.append(
            RagConversationItemVO(
                id=record["conversation_id"],
                name=record["name"],
                introduction=record.get("introduction") or "",
                createdAt=to_epoch(record, "created_at"),
                updatedAt=to_epoch(record, "updated_at"),
            )
        )

    return RagConversationPageVO(
        data=conversations,
        hasMore=len(fetched) > PageSize,
        limit=PageSize,
    )
|
|
|
|
async def GetConversationMessages(self, CurrentUserId: int, ConversationId: str, Page: int, PageSize: int) -> RagMessagePageVO:
    """Return one page of a conversation as question/answer pairs.

    Rows are read oldest first. Each ``user`` row is paired with the
    ``assistant`` row that immediately follows it, if any. Rows that are not
    ``user`` rows and were not consumed by a pair are skipped.

    One row beyond ``PageSize`` is fetched to compute ``hasMore``. That
    look-ahead row is also used for pairing, so a question that is the last
    row of the page still gets its answer. The next page begins with that
    assistant row and skips it, so without the look-ahead the answer would
    never be returned on any page.

    Raises:
        LeauditException: 404 / 403 from ``_ensure_conversation_owner``.
    """
    await self._ensure_conversation_owner(CurrentUserId, ConversationId)

    async with GetAsyncSession() as session:
        rows = (
            await session.execute(
                text(
                    """
                    SELECT message_id, role, content, sources, metadata, feedback, created_at
                    FROM rag_message
                    WHERE conversation_id = :conversation_id
                    ORDER BY created_at ASC
                    OFFSET :offset LIMIT :limit
                    """
                ),
                {
                    "conversation_id": ConversationId,
                    "offset": max(Page - 1, 0) * PageSize,
                    "limit": PageSize + 1,
                },
            )
        ).mappings().all()

    has_more = len(rows) > PageSize
    items = rows[:PageSize]
    data: list[RagMessageItemVO] = []

    idx = 0
    while idx < len(items):
        row = rows[idx]
        if row["role"] != "user":
            idx += 1
            continue

        # Look ahead in ``rows`` (not ``items``) so the extra fetched row can
        # complete a pair that straddles the page boundary.
        follower = rows[idx + 1] if idx + 1 < len(rows) else None
        answer = follower if follower is not None and follower["role"] == "assistant" else None

        answer_sources = self._parse_json_field(answer.get("sources")) if answer is not None else []
        answer_metadata = self._parse_json_field(answer.get("metadata")) if answer is not None else {}
        suggested_questions = answer_metadata.get("suggested_questions") if isinstance(answer_metadata, dict) else []
        if not isinstance(suggested_questions, list):
            suggested_questions = []

        data.append(
            RagMessageItemVO(
                # The pair is identified by the assistant message id when an answer exists.
                id=(answer["message_id"] if answer is not None else row["message_id"]),
                conversationId=ConversationId,
                query=row["content"],
                answer=answer["content"] if answer is not None else "",
                feedback=({"rating": answer["feedback"]} if answer is not None and answer.get("feedback") else None),
                retrieverResources=answer_sources or None,
                suggestedQuestions=[str(item) for item in suggested_questions],
                createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
            )
        )
        idx += 2 if answer is not None else 1

    return RagMessagePageVO(data=data, hasMore=has_more, limit=PageSize)
|
|
|
|
async def RenameConversation(self, CurrentUserId: int, ConversationId: str, Body: RagConversationRenameDTO) -> RagConversationRenameVO:
    """Rename a conversation owned by the caller and refresh its ``updated_at``.

    Raises:
        LeauditException: 404 / 403 from ``_ensure_conversation_owner``.
    """
    await self._ensure_conversation_owner(CurrentUserId, ConversationId)

    new_name = Body.name
    statement = text(
        "UPDATE rag_conversation SET name = :name, updated_at = NOW() WHERE conversation_id = :conversation_id"
    )
    bind_values = {"name": new_name, "conversation_id": ConversationId}

    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(statement, bind_values)

    return RagConversationRenameVO(result="success", name=new_name)
|
|
|
|
async def DeleteConversation(self, CurrentUserId: int, ConversationId: str) -> RagOperationResultVO:
    """Soft-delete a conversation owned by the caller by stamping ``deleted_at``.

    Raises:
        LeauditException: 404 / 403 from ``_ensure_conversation_owner``.
    """
    await self._ensure_conversation_owner(CurrentUserId, ConversationId)

    statement = text(
        "UPDATE rag_conversation SET deleted_at = NOW(), updated_at = NOW() WHERE conversation_id = :conversation_id"
    )
    bind_values = {"conversation_id": ConversationId}

    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(statement, bind_values)

    return RagOperationResultVO(result="success")
|
|
|
|
async def UpdateFeedback(self, CurrentUserId: int, MessageId: str, Body: RagMessageFeedbackDTO) -> RagOperationResultVO:
    """Store the caller's rating on a message in one of their conversations.

    The ownership check and the UPDATE run inside one explicit
    ``session.begin()`` block, matching every other write in this service.
    The transaction is committed when the block exits normally and rolled
    back when one of the exceptions below is raised.

    Raises:
        LeauditException: 404 when the message does not exist or its
            conversation is deleted; 403 when the conversation belongs to
            another user.
    """
    async with GetAsyncSession() as session:
        # begin() must wrap the SELECT too: once a statement has run on the
        # session a transaction is already in progress, and a later begin()
        # would be rejected.
        async with session.begin():
            owner = (
                await session.execute(
                    text(
                        """
                        SELECT c.user_id
                        FROM rag_message m
                        JOIN rag_conversation c ON c.conversation_id = m.conversation_id
                        WHERE m.message_id = :message_id AND c.deleted_at IS NULL
                        LIMIT 1
                        """
                    ),
                    {"message_id": MessageId},
                )
            ).scalar_one_or_none()

            if owner is None:
                raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在")
            if int(owner) != CurrentUserId:
                raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权修改该消息反馈")

            await session.execute(
                text("UPDATE rag_message SET feedback = :feedback WHERE message_id = :message_id"),
                {"feedback": Body.rating, "message_id": MessageId},
            )

    return RagOperationResultVO(result="success")
|
|
|
|
async def GetAppParameters(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    AppId: int | None,
) -> RagAppParametersVO:
    """Return the UI parameters (opening statement, suggestions) of the resolved app.

    An empty ``RagAppParametersVO`` is returned when no app can be resolved.
    At most six suggested questions are returned. ``CurrentUserId`` is
    accepted for interface parity but is not read here.
    """
    app = await self._resolve_app(AppId, UserArea, UserRole)
    if not app:
        return RagAppParametersVO()

    # The column value may be a JSON string or a list already decoded by the
    # driver. _parse_json_field accepts both, whereas json.loads on a decoded
    # list raises and would silently drop every suggestion.
    suggested = self._parse_json_field(app.get("suggested_questions"))
    if not isinstance(suggested, list):
        suggested = []

    return RagAppParametersVO(
        openingStatement=app.get("opening_statement") or "",
        suggestedQuestions=[str(item) for item in suggested[:6]],
        userInputForm=[],
        fileUpload={"image": {"enabled": False}},
    )
|
|
|
|
async def _load_apps(self, user_area: str | None, user_role: str | None, only_default: bool) -> list[RagChatAppVO]:
    """List enabled, non-deleted chat apps visible to the given area and role.

    An app is visible when the role is ``provincial_admin``, when its area is
    the user's area, ``'省级'`` or empty, or when its dataset is public.
    ``only_default`` restricts the result to apps flagged as default.
    """
    statement = text(
        """
        SELECT a.id, a.name, a.description, a.is_default
        FROM rag_chat_app a
        LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
        WHERE a.deleted_at IS NULL
          AND a.status = 1
          AND (:only_default = FALSE OR a.is_default = TRUE)
          AND (
            :is_provincial = TRUE
            OR a.area IN (:user_area, '省级', '')
            OR COALESCE(d.is_public, FALSE) = TRUE
          )
        ORDER BY a.sort_order ASC, a.created_at DESC
        """
    )
    bind_values = {
        "only_default": only_default,
        "is_provincial": user_role == "provincial_admin",
        "user_area": user_area or "",
    }

    async with GetAsyncSession() as session:
        result = await session.execute(statement, bind_values)
        records = result.mappings().all()

    apps: list[RagChatAppVO] = []
    for record in records:
        apps.append(
            RagChatAppVO(
                appId=str(record["id"]),
                appName=record["name"],
                description=record.get("description") or "",
                isDefault=bool(record.get("is_default")),
            )
        )
    return apps
|
|
|
|
async def _resolve_app(self, app_id: int | None, user_area: str | None, user_role: str | None) -> dict | None:
    """Pick the chat app to use for a request, as a plain dict of its columns.

    Candidates are tried in this order and the first visible one wins:
      1. the app with ``app_id`` (only when ``app_id`` is given),
      2. the first app flagged as default,
      3. the first app of any kind.

    The visibility rule of ``_app_visible`` is applied inside the SQL, so
    each ``LIMIT 1`` picks the first *visible* candidate. Filtering after
    ``LIMIT 1`` returned ``None`` whenever the first-sorted app was not
    visible to the user, even if other apps were.

    Returns ``None`` when the user can see no enabled app.
    """
    visibility_params = {
        "user_area": user_area or "",
        "is_provincial": user_role == "provincial_admin",
    }
    base_sql = """
        SELECT a.id, a.name, a.description, a.area, a.dataset_id, a.system_prompt,
               a.llm_model, a.temperature, a.max_tokens, a.opening_statement,
               a.suggested_questions, a.is_default, COALESCE(d.is_public, FALSE) AS dataset_public,
               COALESCE(d.name, '') AS dataset_name
        FROM rag_chat_app a
        LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
        WHERE a.deleted_at IS NULL AND a.status = 1
          AND (
            :is_provincial = TRUE
            OR COALESCE(a.area, '') IN (:user_area, '省级', '')
            OR COALESCE(d.is_public, FALSE) = TRUE
          )
    """
    ordering = " ORDER BY a.sort_order ASC, a.created_at DESC LIMIT 1"

    # (SQL suffix, bind values) per candidate, in priority order.
    attempts: list[tuple[str, dict]] = []
    if app_id is not None:
        attempts.append((" AND a.id = :app_id LIMIT 1", {**visibility_params, "app_id": app_id}))
    attempts.append((" AND a.is_default = TRUE" + ordering, visibility_params))
    attempts.append((ordering, visibility_params))

    async with GetAsyncSession() as session:
        for suffix, bind_values in attempts:
            row = (await session.execute(text(base_sql + suffix), bind_values)).mappings().first()
            # The Python check repeats the SQL filter as a cheap safety net.
            if row and self._app_visible(row, user_area, user_role):
                return dict(row)
    return None
|
|
|
|
def _app_visible(self, row: dict, user_area: str | None, user_role: str | None) -> bool:
    """Tell whether an app row may be shown to a user with this area and role.

    ``provincial_admin`` sees everything. Other users see apps whose area is
    empty, ``'省级'`` or equal to their own area, plus apps whose dataset is
    public (``dataset_public`` truthy).
    """
    if user_role == "provincial_admin":
        return True

    app_area = row.get("area") or ""
    allowed_areas = ("", "省级", user_area or "")
    if app_area in allowed_areas:
        return True
    return bool(row.get("dataset_public"))
|
|
|
|
def _parse_json_field(self, value):
    """Normalise a JSON column value that may arrive decoded or as text.

    Dicts and lists are returned unchanged. Strings are decoded with
    ``json.loads`` and the decoded value is returned whatever its type.
    ``None``, undecodable strings and any other type yield an empty dict.
    """
    if isinstance(value, (dict, list)):
        return value
    if not isinstance(value, str):
        # Covers None as well as unexpected scalar types.
        return {}
    try:
        decoded = json.loads(value)
    except Exception:
        return {}
    return decoded
|
|
|
|
async def _ensure_conversation(self, user_id: int, conversation_id: str | None, app_id: int | None) -> str:
    """Return the id of a usable conversation, creating one when necessary.

    When ``conversation_id`` is given (and is not the ``"-1"`` placeholder)
    and a non-deleted conversation with that id exists, its id is returned
    after checking that it belongs to ``user_id``. In every other case a new
    conversation with a fresh UUID is inserted and that id is returned.

    Raises:
        LeauditException: 403 when the existing conversation belongs to
            another user.
    """
    wants_existing = bool(conversation_id) and conversation_id != "-1"
    if wants_existing:
        lookup = text(
            """
            SELECT conversation_id, user_id
            FROM rag_conversation
            WHERE conversation_id = :conversation_id
              AND deleted_at IS NULL
            LIMIT 1
            """
        )
        async with GetAsyncSession() as session:
            result = await session.execute(lookup, {"conversation_id": conversation_id})
            existing = result.mappings().first()

        if existing:
            if int(existing["user_id"]) != user_id:
                raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权使用该会话")
            return str(existing["conversation_id"])

    # No id supplied, placeholder id, or unknown id: start a new conversation.
    new_conversation_id = str(uuid.uuid4())
    insert = text(
        """
        INSERT INTO rag_conversation (conversation_id, user_id, app_id, name, introduction)
        VALUES (:conversation_id, :user_id, :app_id, '新对话', '')
        """
    )
    bind_values = {"conversation_id": new_conversation_id, "user_id": user_id, "app_id": app_id}

    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(insert, bind_values)

    return new_conversation_id
|
|
|
|
async def _ensure_conversation_owner(self, user_id: int, conversation_id: str) -> None:
    """Verify that a non-deleted conversation exists and belongs to ``user_id``.

    Raises:
        LeauditException: 404 when no such conversation exists, 403 when it
            belongs to a different user.
    """
    lookup = text(
        "SELECT user_id FROM rag_conversation WHERE conversation_id = :conversation_id AND deleted_at IS NULL LIMIT 1"
    )
    async with GetAsyncSession() as session:
        result = await session.execute(lookup, {"conversation_id": conversation_id})
        owner_id = result.scalar_one_or_none()

    if owner_id is None:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "会话不存在")
    if int(owner_id) != user_id:
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该会话")
|
|
|
|
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
    """Fetch the context chunks most similar to ``query`` from a dataset.

    Returns ``(chunks, dataset_name)``. Retrieval is best-effort: a missing
    dataset yields ``([], "")``, and any failure in the vector store or the
    embedding call is logged and yields ``([], dataset_name)`` so the chat
    can continue without context.

    Each chunk's score is ``1 / (1 + distance)``. Chunks scoring below the
    dataset's ``score_threshold`` are dropped when the threshold is enabled.
    At most ``top_k`` chunks (default 5, minimum 1) are returned.
    """
    import logging

    logger = logging.getLogger(__name__)

    if not dataset_id:
        return [], ""

    async with GetAsyncSession() as session:
        dataset = (
            await session.execute(
                text(
                    """
                    SELECT id, name, collection_name, retrieval_model, embedding_model
                    FROM rag_dataset
                    WHERE id = :dataset_id AND deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"dataset_id": dataset_id},
            )
        ).mappings().first()

    if not dataset:
        return [], ""

    dataset_name = dataset.get("name") or ""

    # The column may arrive decoded (dict) or as JSON text; anything that is
    # not an object falls back to the defaults instead of raising.
    retrieval_model = self._parse_json_field(dataset.get("retrieval_model"))
    if not isinstance(retrieval_model, dict):
        retrieval_model = {}

    try:
        top_k = int(retrieval_model.get("top_k") or 5)
    except (TypeError, ValueError):
        top_k = 5
    # A non-positive top_k would make the final slice empty or negative.
    top_k = max(top_k, 1)

    score_threshold = None
    if retrieval_model.get("score_threshold_enabled"):
        try:
            score_threshold = float(retrieval_model.get("score_threshold"))
        except (TypeError, ValueError):
            score_threshold = None

    try:
        from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
    except Exception:
        logger.warning("Chroma client unavailable; answering without context", exc_info=True)
        return [], dataset_name

    try:
        collection = get_chroma().get_or_create_collection(dataset["collection_name"])
        query_embedding = await self._embed_texts([query], dataset.get("embedding_model") or "")
        result = collection.query(
            query_embeddings=query_embedding,
            n_results=top_k,
            include=["documents", "metadatas", "distances"],
        )
        # Results are nested per query; only one query is sent, so take index 0.
        docs = (result.get("documents") or [[]])[0]
        metas = (result.get("metadatas") or [[]])[0]
        distances = (result.get("distances") or [[]])[0]

        chunks: list[dict] = []
        for idx, doc in enumerate(docs):
            # A missing metadata entry must not discard the whole retrieval.
            meta = (metas[idx] if idx < len(metas) else None) or {}
            dist = distances[idx] if idx < len(distances) else 0.0
            distance = max(0.0, float(dist or 0.0))
            score = 1.0 / (1.0 + distance)
            if score_threshold is not None and score < score_threshold:
                continue
            chunks.append(
                {
                    "id": str(meta.get("id") or idx),
                    "text": doc,
                    "source": meta.get("source") or meta.get("document_name") or dataset_name,
                    "score": score,
                    "chunk_index": idx,
                    "document_name": meta.get("document_name") or meta.get("source") or "",
                }
            )

        chunks = await self._hydrate_document_hits(dataset_id, chunks)
        return chunks[:top_k], dataset_name
    except Exception:
        # Deliberately best-effort, but leave a trace for operators.
        logger.exception("Context retrieval failed for dataset %s", dataset_id)
        return [], dataset_name
|
|
|
|
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
    """Embed ``texts`` through the configured HTTP embedding endpoint.

    Endpoint, key, model and batch size come from ``RAG_CONFIG``. The
    ``EMBED_*`` entries take precedence, with the LLM base URL / API key and
    ``"text-embedding-v4"`` as fallbacks. ``model_name`` overrides the
    configured model when non-empty. Texts are sent in batches and the
    result has exactly one vector per input text.

    Raises:
        LeauditException: 500 when the service is not configured, when the
            request fails (HTTP error status, transport error, or a body
            that is not a JSON object), or when the number of vectors does
            not match the number of texts.
    """
    embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/embeddings"
    embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
    embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
    batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
    if not embed_url or not embed_key:
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")

    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {embed_key}",
    }

    embeddings: list[list[float]] = []
    async with httpx.AsyncClient(timeout=120.0) as client:
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            try:
                response = await client.post(
                    embed_url,
                    headers=request_headers,
                    json={"model": embed_model, "input": batch_texts},
                )
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc
            except httpx.RequestError as exc:
                # Timeouts and connection failures get the same domain error
                # as HTTP error statuses instead of leaking a raw httpx error.
                error_message = str(exc).strip() or type(exc).__name__
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc

            try:
                payload = response.json()
            except ValueError as exc:
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    "向量化服务调用失败: 响应不是合法的 JSON",
                ) from exc

            # A body that is not a JSON object yields no rows and is reported
            # by the count check below.
            rows = (payload.get("data") if isinstance(payload, dict) else None) or []
            batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
            if len(batch_embeddings) != len(batch_texts):
                raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
            embeddings.extend(batch_embeddings)

    if len(embeddings) != len(texts):
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
    return embeddings
|
|
|
|
def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]:
    """Convert retrieved chunks into the source records stored with an answer.

    Positions are 1-based and follow the order of ``context_chunks``. The
    score is rounded to four decimals and the content is cut to the first
    500 characters of the chunk text.
    """
    sources: list[dict] = []
    for position, chunk in enumerate(context_chunks, start=1):
        chunk_text = chunk.get("text", "")
        sources.append(
            {
                "position": position,
                "dataset_id": str(chunk.get("dataset_id") or ""),
                "dataset_name": dataset_name,
                "document_id": str(chunk.get("document_id") or ""),
                "document_name": chunk.get("document_name") or chunk.get("source", ""),
                "data_source_type": "upload_file",
                "segment_id": chunk.get("id", ""),
                "retriever_from": "rag",
                "score": round(chunk.get("score", 0.0), 4),
                "hit_count": chunk.get("hit_count", 0),
                "word_count": len(chunk_text),
                "segment_position": position,
                "index_node_hash": "",
                "content": chunk_text[:500],
                "page": None,
            }
        )
    return sources
|
|
|
|
async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict]) -> list[dict]:
    """Attach document info to chunks, drop disabled documents, count hits.

    Chunks are matched to ``rag_document`` rows of the dataset by name
    (``document_name``, falling back to ``source``). Matched chunks are
    updated in place with ``document_id``, ``dataset_id``, ``document_name``
    and ``hit_count``. Chunks whose document is disabled are removed, and
    chunks without a matching document are kept unchanged. ``hit_count`` of
    every matched, enabled document is then incremented by one.

    When no chunk carries a name the input list is returned as is and the
    database is not touched.
    """

    def name_of(chunk: dict) -> str:
        return str(chunk.get("document_name") or chunk.get("source") or "").strip()

    source_names = sorted({name for name in map(name_of, chunks) if name})
    if not source_names:
        return chunks

    lookup = text(
        """
        SELECT id, original_name, enabled, hit_count
        FROM rag_document
        WHERE dataset_id = :dataset_id
          AND deleted_at IS NULL
          AND original_name = ANY(:source_names)
        """
    )
    async with GetAsyncSession() as session:
        result = await session.execute(
            lookup,
            {"dataset_id": dataset_id, "source_names": source_names},
        )
        records = result.mappings().all()

    document_map = {str(record["original_name"]): record for record in records}

    visible_chunks: list[dict] = []
    hit_document_ids: list[int] = []
    for chunk in chunks:
        document = document_map.get(name_of(chunk))
        if document:
            if not bool(document.get("enabled")):
                # Disabled documents must not contribute context.
                continue
            chunk["document_id"] = document["id"]
            chunk["dataset_id"] = dataset_id
            chunk["document_name"] = document["original_name"]
            chunk["hit_count"] = document.get("hit_count") or 0
            hit_document_ids.append(int(document["id"]))
        visible_chunks.append(chunk)

    if hit_document_ids:
        bump = text(
            """
            UPDATE rag_document
            SET hit_count = hit_count + 1,
                updated_at = NOW()
            WHERE id = ANY(:document_ids)
            """
        )
        unique_ids = sorted(set(hit_document_ids))
        async with GetAsyncSession() as session:
            async with session.begin():
                await session.execute(bump, {"document_ids": unique_ids})

    return visible_chunks
|
|
|
|
def _parse_sse_event(self, chunk: str) -> dict | None:
    """Decode the JSON object carried by the ``data:`` lines of one SSE chunk.

    Returns ``None`` when the chunk has no ``data:`` line, when the payload
    is empty or the ``[DONE]`` sentinel, when it is not valid JSON, or when
    the decoded value is not an object.
    """
    fragments: list[str] = []
    for raw_line in chunk.splitlines():
        if raw_line.startswith("data: "):
            fragments.append(raw_line[len("data: "):])
        elif raw_line.startswith("data:"):
            fragments.append(raw_line[len("data:"):].lstrip())

    if not fragments:
        return None

    # Blank fragments are ignored; the remaining ones form one payload.
    body = "\n".join(filter(str.strip, fragments)).strip()
    if body in ("", "[DONE]"):
        return None

    # Some producers append the event terminator as literal backslash-n text.
    for escaped_tail in ("\\n\\n", "\\n"):
        body = body.removesuffix(escaped_tail)
    body = body.strip()

    try:
        parsed = json.loads(body)
    except json.JSONDecodeError:
        return None

    if isinstance(parsed, dict):
        return parsed
    return None
|