Files
leaudit-platform-backend/fastapi_modules/fastapi_leaudit/rag_engine/generator.py
T
2026-05-25 15:37:37 +08:00

168 lines
6.5 KiB
Python

from __future__ import annotations
import json
import time
import uuid
from typing import AsyncGenerator
import httpx
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_chat_completions_url
# Fallback system prompt, used by generate_stream when the caller passes an
# empty ``system_prompt``. The text (in Chinese) casts the model as a legal
# assistant for the tobacco industry and lists answer-formatting rules: lead
# with a one-sentence answer, use numbered lists, bold key provisions and
# figures, use tables for categorized information, indent sub-lists, and do
# not add headings.
DEFAULT_SYSTEM_PROMPT = """你是烟草行业智慧法务小助手,专注于烟草专卖法规、合同管理、行政处罚等相关法律法规。\n\n回答要求:\n- 先用一句话直接回答,再展开详细说明\n- 多个要点用编号列表\n- 关键法条和数字用 **加粗**\n- 分类信息用表格\n- 层级结构用缩进子列表\n- 不要加标题,直接输出正文"""
async def generate_stream(
    query: str,
    context_chunks: list[dict],
    conversation_id: str,
    message_id: str,
    task_id: str | None = None,
    system_prompt: str = "",
    model: str = "",
    temperature: float | None = None,
    max_tokens: int | None = None,
    dataset_name: str = "",
) -> AsyncGenerator[str, None]:
    """Stream an LLM answer for ``query`` as Server-Sent Events lines.

    The retrieved chunks are folded into the user message, the request is sent
    to the chat-completions endpoint built from ``RAG_CONFIG["LLM_BASE_URL"]``
    with ``stream=True``, and every non-empty content delta is re-emitted as an
    SSE ``data:`` line.

    Args:
        query: The user's question.
        context_chunks: Retrieved chunks (dicts). Keys read here: ``source``,
            ``text``, ``source_scope``, ``data_source_type``, ``id``,
            ``score``, ``attachment_id``.
        conversation_id: Echoed in ``message`` and ``message_end`` events.
        message_id: Echoed in every event.
        task_id: Echoed in every event; a random UUID4 is generated if falsy.
        system_prompt: System message; ``DEFAULT_SYSTEM_PROMPT`` if empty.
        model: Model name; ``RAG_CONFIG["LLM_MODEL"]`` if empty.
        temperature: Sampling temperature; ``RAG_CONFIG["LLM_TEMPERATURE"]``
            if ``None`` (so ``0`` is honored).
        max_tokens: Completion limit; ``RAG_CONFIG["LLM_MAX_TOKENS"]`` if
            falsy.
        dataset_name: Copied into each ``retriever_resources`` entry.

    Yields:
        SSE-formatted strings carrying one JSON object each:
        ``event == "message"`` per content delta, then either a single
        ``event == "error"`` (after which the generator stops) or a final
        ``event == "message_end"`` with token usage and citation metadata.
    """
    task_id = task_id or str(uuid.uuid4())
    created_at = int(time.time())
    _model = model or RAG_CONFIG["LLM_MODEL"]
    _temp = temperature if temperature is not None else RAG_CONFIG["LLM_TEMPERATURE"]
    _max_tok = max_tokens or RAG_CONFIG["LLM_MAX_TOKENS"]
    _prompt = system_prompt or DEFAULT_SYSTEM_PROMPT

    messages = [
        {"role": "system", "content": _prompt},
        {"role": "user", "content": _build_user_content(query, context_chunks)},
    ]

    total_tokens = 0
    try:
        async with httpx.AsyncClient(timeout=RAG_CONFIG["LLM_TIMEOUT"]) as client:
            async with client.stream(
                "POST",
                build_openai_chat_completions_url(RAG_CONFIG["LLM_BASE_URL"]),
                json={
                    "model": _model,
                    "messages": messages,
                    "temperature": _temp,
                    "max_tokens": _max_tok,
                    "stream": True,
                    # NOTE(review): presumably turns off "thinking" output on
                    # backends that honor chat_template_kwargs — confirm
                    # against the serving stack in use.
                    "chat_template_kwargs": {"enable_thinking": False},
                },
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}",
                },
            ) as resp:
                resp.raise_for_status()
                async for line in resp.aiter_lines():
                    # The SSE format makes the space after "data:" optional.
                    if not line.startswith("data:"):
                        continue
                    payload = line[len("data:"):].strip()
                    if not payload:
                        # An empty data field carries nothing to decode.
                        continue
                    if payload == "[DONE]":
                        break
                    event = json.loads(payload)
                    choices = event.get("choices") or []
                    first_choice = choices[0] if choices and isinstance(choices[0], dict) else {}
                    # "delta" / "content" may be present but null; treat both
                    # the same as absent instead of raising.
                    delta = first_choice.get("delta") or {}
                    text = delta.get("content") or ""
                    if text:
                        yield _sse_line(
                            {
                                "event": "message",
                                "task_id": task_id,
                                "message_id": message_id,
                                "conversation_id": conversation_id,
                                "answer": text,
                                "created_at": created_at,
                            }
                        )
                    usage = event.get("usage")
                    if usage:
                        total_tokens = usage.get("total_tokens", total_tokens)
    except Exception as exc:
        # Boundary handler: any transport, HTTP-status or decoding failure is
        # reported to the client as an SSE error event and ends the stream.
        yield _sse_line(
            {
                "event": "error",
                "task_id": task_id,
                "message_id": message_id,
                "code": "llm_error",
                "message": str(exc),
            }
        )
        return

    yield _sse_line(
        {
            "event": "message_end",
            "task_id": task_id,
            "message_id": message_id,
            "conversation_id": conversation_id,
            "metadata": {
                "usage": {"total_tokens": total_tokens},
                "retriever_resources": _build_retriever_resources(context_chunks, dataset_name),
            },
        }
    )


def _build_user_content(query: str, context_chunks: list[dict], max_context_chars: int = 8000) -> str:
    """Return the user message: the bare query, or the query preceded by labelled context."""
    if not context_chunks:
        return query

    attachment_parts: list[str] = []
    formal_parts: list[str] = []
    total_len = 0
    for chunk in context_chunks:
        scope = str(chunk.get("source_scope") or chunk.get("data_source_type") or "formal_kb")
        is_attachment = scope == "chat_attachment"
        label = "上传文件事实" if is_attachment else "正式知识库依据"
        # Real newlines here: the escaped "\\n" form would put a literal
        # backslash-n into the prompt text.
        part = f"[{label} | 来源: {chunk.get('source') or '未知'}]\n{chunk.get('text') or ''}"
        # Stop at the first chunk that would exceed the budget; later chunks
        # are dropped too. Separator lengths are not counted.
        if total_len + len(part) > max_context_chars:
            break
        (attachment_parts if is_attachment else formal_parts).append(part)
        total_len += len(part)

    context_sections: list[str] = []
    if attachment_parts:
        context_sections.append("【上传文件事实】\n" + "\n\n---\n\n".join(attachment_parts))
    if formal_parts:
        context_sections.append("【正式知识库依据】\n" + "\n\n---\n\n".join(formal_parts))
    context_text = "\n\n======\n\n".join(context_sections)

    return (
        "请严格区分上下文来源:上传文件事实只能用于判断用户材料中的事实;正式知识库依据只能用于法律、处罚、裁量或案例依据。"
        "如果依据不足,请明确说明,不要编造。\n\n"
        f"{context_text}\n\n用户问题: {query}"
    )


def _build_retriever_resources(context_chunks: list[dict], dataset_name: str) -> list[dict]:
    """Return one citation-metadata dict per chunk, including chunks the prompt budget dropped."""
    resources: list[dict] = []
    for position, chunk in enumerate(context_chunks or [], start=1):
        # Explicit nulls must not break the final event after the answer has
        # already been streamed.
        text = chunk.get("text") or ""
        score = chunk.get("score")
        if score is None:
            score = 0.0
        resources.append(
            {
                "position": position,
                "dataset_id": "",
                "dataset_name": dataset_name,
                "document_id": "",
                "document_name": chunk.get("source", ""),
                "data_source_type": chunk.get("data_source_type") or chunk.get("source_scope") or "upload_file",
                "segment_id": chunk.get("id", ""),
                "retriever_from": "rag",
                "score": round(score, 4),
                "hit_count": 0,
                "word_count": len(text),
                "segment_position": position,
                "index_node_hash": "",
                "content": text[:500],
                "page": None,
                "source_scope": chunk.get("source_scope") or "",
                "attachment_id": chunk.get("attachment_id") or "",
            }
        )
    return resources
def _sse_line(data: dict) -> str:
    """Encode ``data`` as a single Server-Sent Events ``data:`` frame.

    The dict is serialized to JSON with non-ASCII characters left unescaped,
    prefixed with ``data: `` and terminated by the blank line that ends an
    SSE event.
    """
    encoded = json.dumps(data, ensure_ascii=False)
    return "data: " + encoded + "\n\n"