feat(rag): add temporary chat attachments

This commit is contained in:
wren
2026-05-25 15:37:37 +08:00
parent 0f385c9839
commit 75c077da77
16 changed files with 2257 additions and 16 deletions
@@ -106,6 +106,8 @@ class RagChatServiceImpl(IRagChatService):
Query: str,
ConversationId: str | None,
AppId: int | None,
AttachmentId: str | None = None,
AttachmentIds: list[str] | None = None,
TenantCode: str | None = None,
TenantName: str | None = None,
) -> AsyncGenerator[bytes, None]:
@@ -128,6 +130,25 @@ class RagChatServiceImpl(IRagChatService):
messageId = str(uuid.uuid4())
taskId = str(uuid.uuid4())
is_new_conversation = not ConversationId or ConversationId == "-1"
active_attachment_ids = self._normalize_attachment_ids(
attachment_id=AttachmentId,
attachment_ids=AttachmentIds,
)
if not active_attachment_ids:
active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=conversationId,
)
attachment_records = await self._load_message_attachment_records(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=conversationId,
AttachmentIds=active_attachment_ids,
)
user_message_files = self._build_user_message_files(attachment_records)
async with GetAsyncSession() as session:
async with session.begin():
@@ -135,13 +156,14 @@ class RagChatServiceImpl(IRagChatService):
text(
"""
INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, '{}'::jsonb)
VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, CAST(:metadata AS jsonb))
"""
),
{
"message_id": str(uuid.uuid4()),
"conversation_id": conversationId,
"content": Query,
"metadata": json.dumps({"message_files": user_message_files}, ensure_ascii=False),
},
)
await session.execute(
@@ -170,6 +192,11 @@ class RagChatServiceImpl(IRagChatService):
message_id=messageId,
query=Query,
app=app,
current_user_id=CurrentUserId,
tenant_code=TenantCode,
user_area=UserArea,
attachment_id=AttachmentId,
attachment_ids=active_attachment_ids,
)
event_index = 0
@@ -322,6 +349,7 @@ class RagChatServiceImpl(IRagChatService):
row = items[idx]
if row["role"] == "user":
answer = items[idx + 1] if idx + 1 < len(items) and items[idx + 1]["role"] == "assistant" else None
user_metadata = dict(row.get("metadata") or {})
answer_metadata = dict((answer.get("metadata") if answer else None) or {})
answer_status = str(answer_metadata.get("status") or ("completed" if answer else "running"))
answer_content = (answer.get("content") if answer else None) or ""
@@ -355,6 +383,7 @@ class RagChatServiceImpl(IRagChatService):
conversationId=ConversationId,
query=row["content"],
answer=answer_content if answer else "",
messageFiles=[item for item in (user_metadata.get("message_files") or []) if isinstance(item, dict)],
feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None),
retrieverResources=(answer.get("sources") if answer else None),
suggestedQuestions=[str(item) for item in (answer_metadata.get("suggested_questions") or []) if str(item).strip()],
@@ -940,7 +969,181 @@ class RagChatServiceImpl(IRagChatService):
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
result = await self.retriever.retrieve(query=query, dataset_id=dataset_id)
return result.chunks, result.dataset_name
chunks = [
{
**chunk,
"source_scope": chunk.get("source_scope") or "formal_kb",
"data_source_type": chunk.get("data_source_type") or "formal_kb",
}
for chunk in result.chunks
]
return chunks, result.dataset_name
async def _resolve_attachment_id_for_conversation(
self,
*,
CurrentUserId: int,
TenantCode: str | None,
UserArea: str | None,
ConversationId: str,
) -> str | None:
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
service = RagChatAttachmentServiceImpl()
return await service.ResolveActiveAttachmentIdForConversation(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
)
async def _resolve_attachment_ids_for_conversation(
self,
*,
CurrentUserId: int,
TenantCode: str | None,
UserArea: str | None,
ConversationId: str,
) -> list[str]:
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
service = RagChatAttachmentServiceImpl()
if hasattr(service, "ResolveActiveAttachmentIdsForConversation"):
return await service.ResolveActiveAttachmentIdsForConversation(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
)
attachment_id = await service.ResolveActiveAttachmentIdForConversation(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
)
return [attachment_id] if attachment_id else []
async def _retrieve_attachment_context(
self,
*,
CurrentUserId: int,
TenantCode: str | None,
UserArea: str | None,
ConversationId: str,
AttachmentId: str,
Query: str,
) -> tuple[list[dict], str]:
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
service = RagChatAttachmentServiceImpl()
return await service.RetrieveAttachmentContext(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
AttachmentId=AttachmentId,
Query=Query,
TopK=5,
)
async def _load_message_attachment_records(
self,
*,
CurrentUserId: int,
TenantCode: str | None,
UserArea: str | None,
ConversationId: str,
AttachmentIds: list[str],
) -> list[dict]:
if not AttachmentIds:
return []
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
service = RagChatAttachmentServiceImpl()
records: list[dict] = []
for attachment_id in AttachmentIds:
record = await service.ValidateAttachmentForChat(
CurrentUserId=CurrentUserId,
TenantCode=TenantCode,
UserArea=UserArea,
ConversationId=ConversationId,
AttachmentId=attachment_id,
)
records.append(record)
return records
def _build_user_message_files(self, attachment_records: list[dict]) -> list[dict]:
files: list[dict] = []
for record in attachment_records:
attachment_id = str(record.get("attachment_id") or "").strip()
if not attachment_id:
continue
file_name = str(record.get("original_name") or record.get("filename") or "上传文件")
content_type = str(record.get("content_type") or "")
file_type = "image" if content_type.startswith("image/") or re.search(r"\.(png|jpe?g|webp|bmp|tiff?)$", file_name, re.I) else "file"
files.append(
{
"id": attachment_id,
"upload_file_id": attachment_id,
"name": file_name,
"fileName": file_name,
"type": file_type,
"transfer_method": "local_file",
"contentType": content_type or None,
"fileSize": int(record.get("file_size") or 0),
"belongs_to": "user",
"usage": "temporary_attachment",
}
)
return files
def _build_formal_kb_query(self, *, query: str, attachment_chunks: list[dict]) -> str:
if not attachment_chunks:
return query
facts: list[str] = []
for chunk in attachment_chunks[:3]:
text_value = str(chunk.get("text") or "").strip()
if text_value:
facts.append(text_value[:500])
if not facts:
return query
return (
f"{query}\n\n"
"用户上传文档中检索到的相关事实:\n"
+ "\n".join(f"- {item}" for item in facts)
+ "\n\n请检索这些事实对应的法律责任、处罚依据、裁量规则或案例。"
)
def _merge_context_chunks(self, *, attachment_chunks: list[dict], formal_chunks: list[dict]) -> list[dict]:
merged: list[dict] = []
for chunk in attachment_chunks:
merged.append(
{
**chunk,
"source_scope": "chat_attachment",
"data_source_type": "chat_attachment",
}
)
for chunk in formal_chunks:
merged.append(
{
**chunk,
"source_scope": chunk.get("source_scope") or "formal_kb",
"data_source_type": chunk.get("data_source_type") or "formal_kb",
}
)
return merged
def _normalize_attachment_ids(self, *, attachment_id: str | None, attachment_ids: list[str] | None) -> list[str]:
normalized: list[str] = []
for raw in [*(attachment_ids or []), attachment_id]:
value = str(raw or "").strip()
if value and value not in normalized:
normalized.append(value)
return normalized
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
return await self.retriever._embed_texts(texts, model_name)
@@ -953,6 +1156,11 @@ class RagChatServiceImpl(IRagChatService):
message_id: str,
query: str,
app: dict,
current_user_id: int | None = None,
tenant_code: str | None = None,
user_area: str | None = None,
attachment_id: str | None = None,
attachment_ids: list[str] | None = None,
) -> None:
self._task_events[task_id] = []
self._task_done[task_id] = False
@@ -964,6 +1172,11 @@ class RagChatServiceImpl(IRagChatService):
message_id=message_id,
query=query,
app=app,
current_user_id=current_user_id,
tenant_code=tenant_code,
user_area=user_area,
attachment_id=attachment_id,
attachment_ids=attachment_ids,
)
)
self._message_tasks[task_id] = task
@@ -976,6 +1189,11 @@ class RagChatServiceImpl(IRagChatService):
message_id: str,
query: str,
app: dict,
current_user_id: int | None = None,
tenant_code: str | None = None,
user_area: str | None = None,
attachment_id: str | None = None,
attachment_ids: list[str] | None = None,
) -> None:
context_chunks: list[dict] = []
dataset_name = ""
@@ -987,7 +1205,46 @@ class RagChatServiceImpl(IRagChatService):
last_persisted_at = time.monotonic()
try:
context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), query)
attachment_chunks: list[dict] = []
attachment_names: list[str] = []
active_attachment_ids = self._normalize_attachment_ids(
attachment_id=attachment_id,
attachment_ids=attachment_ids,
)
if not active_attachment_ids and current_user_id is not None:
active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
CurrentUserId=current_user_id,
TenantCode=tenant_code,
UserArea=user_area,
ConversationId=conversation_id,
)
if active_attachment_ids:
if current_user_id is None:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件缺少用户上下文")
for active_attachment_id in active_attachment_ids:
next_chunks, next_attachment_name = await self._retrieve_attachment_context(
CurrentUserId=current_user_id,
TenantCode=tenant_code,
UserArea=user_area,
ConversationId=conversation_id,
AttachmentId=active_attachment_id,
Query=query,
)
attachment_chunks.extend(next_chunks)
if next_attachment_name:
attachment_names.append(next_attachment_name)
legal_query = self._build_formal_kb_query(query=query, attachment_chunks=attachment_chunks)
formal_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), legal_query)
context_chunks = self._merge_context_chunks(
attachment_chunks=attachment_chunks,
formal_chunks=formal_chunks,
)
generation_dataset_name = dataset_name
attachment_name = "".join(dict.fromkeys(attachment_names))
if attachment_name and dataset_name:
generation_dataset_name = f"{attachment_name} + {dataset_name}"
elif attachment_name:
generation_dataset_name = attachment_name
async for chunk in generate_stream(
query=query,
context_chunks=context_chunks,
@@ -997,7 +1254,7 @@ class RagChatServiceImpl(IRagChatService):
model=app.get("llm_model") or "",
temperature=app.get("temperature"),
max_tokens=app.get("max_tokens"),
dataset_name=dataset_name,
dataset_name=generation_dataset_name,
task_id=task_id,
):
data = self._parse_sse_event(chunk)