feat(rag): add temporary chat attachments
This commit is contained in:
@@ -106,6 +106,8 @@ class RagChatServiceImpl(IRagChatService):
|
||||
Query: str,
|
||||
ConversationId: str | None,
|
||||
AppId: int | None,
|
||||
AttachmentId: str | None = None,
|
||||
AttachmentIds: list[str] | None = None,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
@@ -128,6 +130,25 @@ class RagChatServiceImpl(IRagChatService):
|
||||
messageId = str(uuid.uuid4())
|
||||
taskId = str(uuid.uuid4())
|
||||
is_new_conversation = not ConversationId or ConversationId == "-1"
|
||||
active_attachment_ids = self._normalize_attachment_ids(
|
||||
attachment_id=AttachmentId,
|
||||
attachment_ids=AttachmentIds,
|
||||
)
|
||||
if not active_attachment_ids:
|
||||
active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=conversationId,
|
||||
)
|
||||
attachment_records = await self._load_message_attachment_records(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=conversationId,
|
||||
AttachmentIds=active_attachment_ids,
|
||||
)
|
||||
user_message_files = self._build_user_message_files(attachment_records)
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
async with session.begin():
|
||||
@@ -135,13 +156,14 @@ class RagChatServiceImpl(IRagChatService):
|
||||
text(
|
||||
"""
|
||||
INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
|
||||
VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, '{}'::jsonb)
|
||||
VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, CAST(:metadata AS jsonb))
|
||||
"""
|
||||
),
|
||||
{
|
||||
"message_id": str(uuid.uuid4()),
|
||||
"conversation_id": conversationId,
|
||||
"content": Query,
|
||||
"metadata": json.dumps({"message_files": user_message_files}, ensure_ascii=False),
|
||||
},
|
||||
)
|
||||
await session.execute(
|
||||
@@ -170,6 +192,11 @@ class RagChatServiceImpl(IRagChatService):
|
||||
message_id=messageId,
|
||||
query=Query,
|
||||
app=app,
|
||||
current_user_id=CurrentUserId,
|
||||
tenant_code=TenantCode,
|
||||
user_area=UserArea,
|
||||
attachment_id=AttachmentId,
|
||||
attachment_ids=active_attachment_ids,
|
||||
)
|
||||
|
||||
event_index = 0
|
||||
@@ -322,6 +349,7 @@ class RagChatServiceImpl(IRagChatService):
|
||||
row = items[idx]
|
||||
if row["role"] == "user":
|
||||
answer = items[idx + 1] if idx + 1 < len(items) and items[idx + 1]["role"] == "assistant" else None
|
||||
user_metadata = dict(row.get("metadata") or {})
|
||||
answer_metadata = dict((answer.get("metadata") if answer else None) or {})
|
||||
answer_status = str(answer_metadata.get("status") or ("completed" if answer else "running"))
|
||||
answer_content = (answer.get("content") if answer else None) or ""
|
||||
@@ -355,6 +383,7 @@ class RagChatServiceImpl(IRagChatService):
|
||||
conversationId=ConversationId,
|
||||
query=row["content"],
|
||||
answer=answer_content if answer else "",
|
||||
messageFiles=[item for item in (user_metadata.get("message_files") or []) if isinstance(item, dict)],
|
||||
feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None),
|
||||
retrieverResources=(answer.get("sources") if answer else None),
|
||||
suggestedQuestions=[str(item) for item in (answer_metadata.get("suggested_questions") or []) if str(item).strip()],
|
||||
@@ -940,7 +969,181 @@ class RagChatServiceImpl(IRagChatService):
|
||||
|
||||
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
|
||||
result = await self.retriever.retrieve(query=query, dataset_id=dataset_id)
|
||||
return result.chunks, result.dataset_name
|
||||
chunks = [
|
||||
{
|
||||
**chunk,
|
||||
"source_scope": chunk.get("source_scope") or "formal_kb",
|
||||
"data_source_type": chunk.get("data_source_type") or "formal_kb",
|
||||
}
|
||||
for chunk in result.chunks
|
||||
]
|
||||
return chunks, result.dataset_name
|
||||
|
||||
async def _resolve_attachment_id_for_conversation(
    self,
    *,
    CurrentUserId: int,
    TenantCode: str | None,
    UserArea: str | None,
    ConversationId: str,
) -> str | None:
    """Look up the single active attachment id of a conversation.

    Pure delegate: builds a ``RagChatAttachmentServiceImpl`` and returns
    whatever its ``ResolveActiveAttachmentIdForConversation`` yields for the
    given user / tenant / area / conversation.
    """
    # Imported at call time, as in the original (presumably to avoid an
    # import cycle between the two service modules -- TODO confirm).
    from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl

    attachment_service = RagChatAttachmentServiceImpl()
    lookup_scope = {
        "CurrentUserId": CurrentUserId,
        "TenantCode": TenantCode,
        "UserArea": UserArea,
        "ConversationId": ConversationId,
    }
    return await attachment_service.ResolveActiveAttachmentIdForConversation(**lookup_scope)
|
||||
|
||||
async def _resolve_attachment_ids_for_conversation(
    self,
    *,
    CurrentUserId: int,
    TenantCode: str | None,
    UserArea: str | None,
    ConversationId: str,
) -> list[str]:
    """Return the active attachment ids of a conversation.

    When the attachment service object exposes
    ``ResolveActiveAttachmentIdsForConversation`` its result is returned
    as-is. Otherwise the single-id ``ResolveActiveAttachmentIdForConversation``
    is used and a truthy id is wrapped in a one-element list (``[]`` if falsy).
    """
    from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl

    attachment_service = RagChatAttachmentServiceImpl()

    # Fallback first: service builds without the plural resolver.
    if not hasattr(attachment_service, "ResolveActiveAttachmentIdsForConversation"):
        single_id = await attachment_service.ResolveActiveAttachmentIdForConversation(
            CurrentUserId=CurrentUserId,
            TenantCode=TenantCode,
            UserArea=UserArea,
            ConversationId=ConversationId,
        )
        if single_id:
            return [single_id]
        return []

    return await attachment_service.ResolveActiveAttachmentIdsForConversation(
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
    )
|
||||
|
||||
async def _retrieve_attachment_context(
    self,
    *,
    CurrentUserId: int,
    TenantCode: str | None,
    UserArea: str | None,
    ConversationId: str,
    AttachmentId: str,
    Query: str,
) -> tuple[list[dict], str]:
    """Retrieve context for ``Query`` from one chat attachment.

    Delegates to ``RagChatAttachmentServiceImpl.RetrieveAttachmentContext``
    with a fixed ``TopK`` of 5 and returns its result unchanged. The second
    tuple element is consumed by the caller as the attachment's display name.
    """
    from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl

    attachment_service = RagChatAttachmentServiceImpl()
    request = {
        "CurrentUserId": CurrentUserId,
        "TenantCode": TenantCode,
        "UserArea": UserArea,
        "ConversationId": ConversationId,
        "AttachmentId": AttachmentId,
        "Query": Query,
        "TopK": 5,
    }
    return await attachment_service.RetrieveAttachmentContext(**request)
|
||||
|
||||
async def _load_message_attachment_records(
|
||||
self,
|
||||
*,
|
||||
CurrentUserId: int,
|
||||
TenantCode: str | None,
|
||||
UserArea: str | None,
|
||||
ConversationId: str,
|
||||
AttachmentIds: list[str],
|
||||
) -> list[dict]:
|
||||
if not AttachmentIds:
|
||||
return []
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.ragChatAttachmentServiceImpl import RagChatAttachmentServiceImpl
|
||||
|
||||
service = RagChatAttachmentServiceImpl()
|
||||
records: list[dict] = []
|
||||
for attachment_id in AttachmentIds:
|
||||
record = await service.ValidateAttachmentForChat(
|
||||
CurrentUserId=CurrentUserId,
|
||||
TenantCode=TenantCode,
|
||||
UserArea=UserArea,
|
||||
ConversationId=ConversationId,
|
||||
AttachmentId=attachment_id,
|
||||
)
|
||||
records.append(record)
|
||||
return records
|
||||
|
||||
def _build_user_message_files(self, attachment_records: list[dict]) -> list[dict]:
    """Convert validated attachment records into ``message_files`` entries.

    Args:
        attachment_records: Mapping-like records; the keys read are
            ``attachment_id``, ``original_name``/``filename``,
            ``content_type`` and ``file_size``.

    Returns:
        One dict per record that has a non-blank ``attachment_id``, in input
        order. ``type`` is ``"image"`` when the content type starts with
        ``image/`` (case-insensitively) or the file name ends in a known image
        extension, otherwise ``"file"``.

    Entries that cannot be described never abort the whole message: records
    without a ``get`` method (e.g. ``None``) are skipped, and an unparsable
    ``file_size`` is reported as ``0``.
    """
    files: list[dict] = []
    for record in attachment_records or []:
        # Skip None / non-mapping entries instead of raising AttributeError.
        if not callable(getattr(record, "get", None)):
            continue
        attachment_id = str(record.get("attachment_id") or "").strip()
        if not attachment_id:
            continue

        file_name = str(record.get("original_name") or record.get("filename") or "上传文件")
        content_type = str(record.get("content_type") or "")
        # MIME type names are case-insensitive, so compare in lower case.
        is_image = content_type.lower().startswith("image/") or bool(
            re.search(r"\.(png|jpe?g|webp|bmp|tiff?)$", file_name, re.I)
        )
        try:
            file_size = int(record.get("file_size") or 0)
        except (TypeError, ValueError):
            # A malformed size must not block persisting the user message.
            file_size = 0

        files.append(
            {
                "id": attachment_id,
                "upload_file_id": attachment_id,
                "name": file_name,
                "fileName": file_name,
                "type": "image" if is_image else "file",
                "transfer_method": "local_file",
                "contentType": content_type or None,
                "fileSize": file_size,
                "belongs_to": "user",
                "usage": "temporary_attachment",
            }
        )
    return files
|
||||
|
||||
def _build_formal_kb_query(
    self,
    *,
    query: str,
    attachment_chunks: list[dict],
    max_facts: int = 3,
    max_fact_chars: int = 500,
) -> str:
    """Expand ``query`` with facts found in the user's attachments.

    Args:
        query: The user's original question.
        attachment_chunks: Chunks retrieved from chat attachments; only each
            chunk's ``text`` value is read.
        max_facts: Maximum number of facts appended (default 3).
        max_fact_chars: Maximum characters kept per fact (default 500).

    Returns:
        ``query`` unchanged when no chunk carries non-blank text; otherwise
        the query followed by a bulleted fact list and a fixed instruction
        asking for the related liabilities, penalty bases, discretion rules
        or cases.
    """
    facts: list[str] = []
    for chunk in attachment_chunks or []:
        if len(facts) >= max_facts:
            break
        text_value = str(chunk.get("text") or "").strip()
        # Blank chunks are skipped *before* counting, so they no longer use
        # up one of the ``max_facts`` slots.
        if text_value:
            facts.append(text_value[:max_fact_chars])
    if not facts:
        return query

    return (
        f"{query}\n\n"
        "用户上传文档中检索到的相关事实：\n"
        + "\n".join(f"- {item}" for item in facts)
        + "\n\n请检索这些事实对应的法律责任、处罚依据、裁量规则或案例。"
    )
|
||||
|
||||
def _merge_context_chunks(self, *, attachment_chunks: list[dict], formal_chunks: list[dict]) -> list[dict]:
    """Concatenate attachment and formal chunks, labelling where each came from.

    Attachment chunks come first and always get ``source_scope`` and
    ``data_source_type`` set to ``"chat_attachment"``. Formal chunks follow
    and keep their own truthy labels, defaulting to ``"formal_kb"``. Inputs
    are not mutated; every output chunk is a shallow copy.
    """
    from_attachments = [
        {
            **piece,
            "source_scope": "chat_attachment",
            "data_source_type": "chat_attachment",
        }
        for piece in attachment_chunks
    ]
    from_formal_kb = [
        {
            **piece,
            "source_scope": piece.get("source_scope") or "formal_kb",
            "data_source_type": piece.get("data_source_type") or "formal_kb",
        }
        for piece in formal_chunks
    ]
    return from_attachments + from_formal_kb
|
||||
|
||||
def _normalize_attachment_ids(self, *, attachment_id: str | None, attachment_ids: list[str] | None) -> list[str]:
|
||||
normalized: list[str] = []
|
||||
for raw in [*(attachment_ids or []), attachment_id]:
|
||||
value = str(raw or "").strip()
|
||||
if value and value not in normalized:
|
||||
normalized.append(value)
|
||||
return normalized
|
||||
|
||||
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
    """Embed ``texts`` with ``model_name`` by delegating to the retriever.

    Returns whatever ``self.retriever._embed_texts`` produces, unchanged.
    """
    vectors = await self.retriever._embed_texts(texts, model_name)
    return vectors
|
||||
@@ -953,6 +1156,11 @@ class RagChatServiceImpl(IRagChatService):
|
||||
message_id: str,
|
||||
query: str,
|
||||
app: dict,
|
||||
current_user_id: int | None = None,
|
||||
tenant_code: str | None = None,
|
||||
user_area: str | None = None,
|
||||
attachment_id: str | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
self._task_events[task_id] = []
|
||||
self._task_done[task_id] = False
|
||||
@@ -964,6 +1172,11 @@ class RagChatServiceImpl(IRagChatService):
|
||||
message_id=message_id,
|
||||
query=query,
|
||||
app=app,
|
||||
current_user_id=current_user_id,
|
||||
tenant_code=tenant_code,
|
||||
user_area=user_area,
|
||||
attachment_id=attachment_id,
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
)
|
||||
self._message_tasks[task_id] = task
|
||||
@@ -976,6 +1189,11 @@ class RagChatServiceImpl(IRagChatService):
|
||||
message_id: str,
|
||||
query: str,
|
||||
app: dict,
|
||||
current_user_id: int | None = None,
|
||||
tenant_code: str | None = None,
|
||||
user_area: str | None = None,
|
||||
attachment_id: str | None = None,
|
||||
attachment_ids: list[str] | None = None,
|
||||
) -> None:
|
||||
context_chunks: list[dict] = []
|
||||
dataset_name = ""
|
||||
@@ -987,7 +1205,46 @@ class RagChatServiceImpl(IRagChatService):
|
||||
last_persisted_at = time.monotonic()
|
||||
|
||||
try:
|
||||
context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), query)
|
||||
attachment_chunks: list[dict] = []
|
||||
attachment_names: list[str] = []
|
||||
active_attachment_ids = self._normalize_attachment_ids(
|
||||
attachment_id=attachment_id,
|
||||
attachment_ids=attachment_ids,
|
||||
)
|
||||
if not active_attachment_ids and current_user_id is not None:
|
||||
active_attachment_ids = await self._resolve_attachment_ids_for_conversation(
|
||||
CurrentUserId=current_user_id,
|
||||
TenantCode=tenant_code,
|
||||
UserArea=user_area,
|
||||
ConversationId=conversation_id,
|
||||
)
|
||||
if active_attachment_ids:
|
||||
if current_user_id is None:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件缺少用户上下文")
|
||||
for active_attachment_id in active_attachment_ids:
|
||||
next_chunks, next_attachment_name = await self._retrieve_attachment_context(
|
||||
CurrentUserId=current_user_id,
|
||||
TenantCode=tenant_code,
|
||||
UserArea=user_area,
|
||||
ConversationId=conversation_id,
|
||||
AttachmentId=active_attachment_id,
|
||||
Query=query,
|
||||
)
|
||||
attachment_chunks.extend(next_chunks)
|
||||
if next_attachment_name:
|
||||
attachment_names.append(next_attachment_name)
|
||||
legal_query = self._build_formal_kb_query(query=query, attachment_chunks=attachment_chunks)
|
||||
formal_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), legal_query)
|
||||
context_chunks = self._merge_context_chunks(
|
||||
attachment_chunks=attachment_chunks,
|
||||
formal_chunks=formal_chunks,
|
||||
)
|
||||
generation_dataset_name = dataset_name
|
||||
attachment_name = "、".join(dict.fromkeys(attachment_names))
|
||||
if attachment_name and dataset_name:
|
||||
generation_dataset_name = f"{attachment_name} + {dataset_name}"
|
||||
elif attachment_name:
|
||||
generation_dataset_name = attachment_name
|
||||
async for chunk in generate_stream(
|
||||
query=query,
|
||||
context_chunks=context_chunks,
|
||||
@@ -997,7 +1254,7 @@ class RagChatServiceImpl(IRagChatService):
|
||||
model=app.get("llm_model") or "",
|
||||
temperature=app.get("temperature"),
|
||||
max_tokens=app.get("max_tokens"),
|
||||
dataset_name=dataset_name,
|
||||
dataset_name=generation_dataset_name,
|
||||
task_id=task_id,
|
||||
):
|
||||
data = self._parse_sse_event(chunk)
|
||||
|
||||
Reference in New Issue
Block a user