feat: restore rag dataset management and linkage

This commit is contained in:
wren
2026-05-11 17:21:33 +08:00
parent da2bb8310d
commit dcc0f3c30d
6 changed files with 2208 additions and 46 deletions
@@ -80,7 +80,7 @@ class RagChatServiceImpl(IRagChatService):
context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), Query)
collected_answer = ""
held_message_end: bytes | None = None
held_message_end: dict | None = None
async for chunk in generate_stream(
query=Query,
@@ -94,17 +94,21 @@ class RagChatServiceImpl(IRagChatService):
dataset_name=dataset_name,
):
chunk_bytes = chunk.encode("utf-8")
for line in chunk.strip().split("\n"):
if not line.startswith("data: "):
continue
data = json.loads(line[6:])
if data.get("event") == "message":
collected_answer += data.get("answer", "")
elif data.get("event") == "message_end":
held_message_end = chunk_bytes
continue
if held_message_end is None:
data = self._parse_sse_event(chunk)
if not data:
yield chunk_bytes
continue
if data.get("event") == "message":
collected_answer += data.get("answer", "")
yield chunk_bytes
continue
if data.get("event") == "message_end":
held_message_end = data
continue
yield chunk_bytes
followups: list[str] = []
try:
@@ -114,15 +118,10 @@ class RagChatServiceImpl(IRagChatService):
if held_message_end:
try:
for line in held_message_end.decode("utf-8").strip().split("\n"):
if not line.startswith("data: "):
continue
end_data = json.loads(line[6:])
if end_data.get("event") == "message_end":
end_data.setdefault("metadata", {})["suggested_questions"] = followups
yield f"data: {json.dumps(end_data, ensure_ascii=False)}\\n\\n".encode("utf-8")
held_message_end.setdefault("metadata", {})["suggested_questions"] = followups
yield f"data: {json.dumps(held_message_end, ensure_ascii=False)}\n\n".encode("utf-8")
except Exception:
yield held_message_end
yield f"data: {json.dumps(held_message_end, ensure_ascii=False)}\n\n".encode("utf-8")
async with GetAsyncSession() as session:
async with session.begin():
@@ -158,7 +157,7 @@ class RagChatServiceImpl(IRagChatService):
FROM rag_conversation
WHERE user_id = :user_id
AND deleted_at IS NULL
AND (:app_id IS NULL OR app_id = :app_id)
AND (CAST(:app_id AS BIGINT) IS NULL OR app_id = CAST(:app_id AS BIGINT))
ORDER BY updated_at DESC
OFFSET :offset LIMIT :limit
"""
@@ -277,11 +276,10 @@ class RagChatServiceImpl(IRagChatService):
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在")
if int(owner) != CurrentUserId:
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权修改该消息反馈")
async with session.begin():
await session.execute(
text("UPDATE rag_message SET feedback = :feedback WHERE message_id = :message_id"),
{"feedback": Body.rating, "message_id": MessageId},
)
await session.execute(
text("UPDATE rag_message SET feedback = :feedback WHERE message_id = :message_id"),
{"feedback": Body.rating, "message_id": MessageId},
)
return RagOperationResultVO(result="success")
async def GetAppParameters(
@@ -587,3 +585,26 @@ class RagChatServiceImpl(IRagChatService):
)
return visible_chunks
def _parse_sse_event(self, chunk: str) -> dict | None:
data_lines: list[str] = []
for line in chunk.splitlines():
if line.startswith("data: "):
data_lines.append(line[6:])
elif line.startswith("data:"):
data_lines.append(line[5:].lstrip())
if not data_lines:
return None
payload = "\n".join(part for part in data_lines if part.strip()).strip()
if not payload or payload == "[DONE]":
return None
payload = payload.removesuffix("\\n\\n").removesuffix("\\n").strip()
try:
data = json.loads(payload)
except json.JSONDecodeError:
return None
return data if isinstance(data, dict) else None