feat(rag): add temporary chat attachments

This commit is contained in:
wren
2026-05-25 15:37:37 +08:00
parent 0f385c9839
commit 75c077da77
16 changed files with 2257 additions and 16 deletions
@@ -33,16 +33,34 @@ async def generate_stream(
max_context_chars = 8000
if context_chunks:
parts: list[str] = []
attachment_parts: list[str] = []
formal_parts: list[str] = []
total_len = 0
for chunk in context_chunks:
part = f"[来源: {chunk.get('source', '未知')}]\\n{chunk.get('text', '')}"
scope = str(chunk.get("source_scope") or chunk.get("data_source_type") or "formal_kb")
if scope == "chat_attachment":
label = "上传文件事实"
else:
label = "正式知识库依据"
part = f"[{label} | 来源: {chunk.get('source', '未知')}]\\n{chunk.get('text', '')}"
if total_len + len(part) > max_context_chars:
break
parts.append(part)
if scope == "chat_attachment":
attachment_parts.append(part)
else:
formal_parts.append(part)
total_len += len(part)
context_text = "\\n\\n---\\n\\n".join(parts)
user_content = f"知识库内容:\\n{context_text}\\n\\n用户问题: {query}"
context_sections: list[str] = []
if attachment_parts:
context_sections.append("【上传文件事实】\n" + "\\n\\n---\\n\\n".join(attachment_parts))
if formal_parts:
context_sections.append("【正式知识库依据】\n" + "\\n\\n---\\n\\n".join(formal_parts))
context_text = "\\n\\n======\\n\\n".join(context_sections)
user_content = (
"请严格区分上下文来源:上传文件事实只能用于判断用户材料中的事实;正式知识库依据只能用于法律、处罚、裁量或案例依据。"
"如果依据不足,请明确说明,不要编造。\n\n"
f"{context_text}\\n\\n用户问题: {query}"
)
else:
user_content = query
@@ -115,7 +133,7 @@ async def generate_stream(
"dataset_name": dataset_name,
"document_id": "",
"document_name": chunk.get("source", ""),
"data_source_type": "upload_file",
"data_source_type": chunk.get("data_source_type") or chunk.get("source_scope") or "upload_file",
"segment_id": chunk.get("id", ""),
"retriever_from": "rag",
"score": round(chunk.get("score", 0.0), 4),
@@ -125,6 +143,8 @@ async def generate_stream(
"index_node_hash": "",
"content": chunk.get("text", "")[:500],
"page": None,
"source_scope": chunk.get("source_scope") or "",
"attachment_id": chunk.get("attachment_id") or "",
}
for i, chunk in enumerate(context_chunks)
]
@@ -183,10 +183,15 @@ class RagRetriever:
"source": meta.get("source") or meta.get("document_name") or dataset_name,
"score": score,
"chunk_index": int(meta.get("chunk_index") or idx),
"document_name": document_name,
"document_id": meta.get("document_id"),
"page": meta.get("page"),
}
"document_name": document_name,
"document_id": meta.get("document_id"),
"page": meta.get("page"),
"source_scope": meta.get("source_scope"),
"attachment_id": meta.get("attachment_id"),
"conversation_id": meta.get("conversation_id"),
"tenant_code": meta.get("tenant_code"),
"user_id": meta.get("user_id"),
}
)
return chunks
@@ -285,6 +290,11 @@ class RagRetriever:
"document_name": document_name,
"document_id": meta.get("document_id"),
"page": meta.get("page"),
"source_scope": meta.get("source_scope"),
"attachment_id": meta.get("attachment_id"),
"conversation_id": meta.get("conversation_id"),
"tenant_code": meta.get("tenant_code"),
"user_id": meta.get("user_id"),
}
)
@@ -392,10 +402,10 @@ class RagRetriever:
{
"position": index + 1,
"dataset_id": str(chunk.get("dataset_id") or ""),
"dataset_name": dataset_name,
"dataset_name": chunk.get("dataset_name") or dataset_name,
"document_id": str(chunk.get("document_id") or ""),
"document_name": chunk.get("document_name") or chunk.get("source", ""),
"data_source_type": "upload_file",
"data_source_type": chunk.get("data_source_type") or chunk.get("source_scope") or "upload_file",
"segment_id": chunk.get("id", ""),
"retriever_from": "rag",
"score": round(float(chunk.get("score") or 0.0), 4),
@@ -405,6 +415,8 @@ class RagRetriever:
"index_node_hash": "",
"content": chunk.get("text", "")[:500],
"page": None,
"source_scope": chunk.get("source_scope") or "",
"attachment_id": chunk.get("attachment_id") or "",
}
for index, chunk in enumerate(context_chunks)
]