148 lines
5.3 KiB
Python
148 lines
5.3 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
import time
|
|
import uuid
|
|
from typing import AsyncGenerator
|
|
|
|
import httpx
|
|
|
|
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_chat_completions_url
|
|
|
|
# Fallback system prompt, used by generate_stream when the caller passes an
# empty ``system_prompt``. It is written as adjacent string literals, one per
# output line, so every line break in the prompt is visible in the source.
DEFAULT_SYSTEM_PROMPT = (
    "你是烟草行业智慧法务小助手,专注于烟草专卖法规、合同管理、行政处罚等相关法律法规。\n"
    "\n"
    "回答要求:\n"
    "- 先用一句话直接回答,再展开详细说明\n"
    "- 多个要点用编号列表\n"
    "- 关键法条和数字用 **加粗**\n"
    "- 分类信息用表格\n"
    "- 层级结构用缩进子列表\n"
    "- 不要加标题,直接输出正文"
)
|
|
|
|
|
|
def _build_user_content(
    query: str,
    context_chunks: list[dict] | None,
    max_context_chars: int = 8000,
) -> str:
    """Return the user message: the query, prefixed by retrieved context if any.

    Chunks are added in order until the next one would push the combined
    length past ``max_context_chars``; the remaining chunks are left out.
    With no chunks the query is returned unchanged.
    """
    if not context_chunks:
        return query

    parts: list[str] = []
    used = 0
    for chunk in context_chunks:
        # ``or`` also covers keys that are present but None/empty, which
        # would otherwise be rendered as the text "None" in the prompt.
        source = chunk.get("source") or "未知"
        text = chunk.get("text") or ""
        part = f"[来源: {source}]\n{text}"
        if used + len(part) > max_context_chars:
            if not parts:
                # The very first chunk is already over budget: keep a
                # truncated copy rather than sending an empty context.
                parts.append(part[:max_context_chars])
            break
        parts.append(part)
        used += len(part)

    context_text = "\n\n---\n\n".join(parts)
    return f"知识库内容:\n{context_text}\n\n用户问题: {query}"


def _build_retriever_resources(
    context_chunks: list[dict] | None,
    dataset_name: str,
) -> list[dict]:
    """Return one citation entry per retrieved chunk for the ``message_end`` event."""
    resources: list[dict] = []
    for position, chunk in enumerate(context_chunks or [], start=1):
        # Tolerate explicit None values so a sparse chunk cannot raise after
        # the answer has already been streamed.
        text = chunk.get("text") or ""
        resources.append(
            {
                "position": position,
                "dataset_id": "",
                "dataset_name": dataset_name,
                "document_id": "",
                "document_name": chunk.get("source", ""),
                "data_source_type": "upload_file",
                "segment_id": chunk.get("id", ""),
                "retriever_from": "rag",
                "score": round(chunk.get("score") or 0.0, 4),
                "hit_count": 0,
                "word_count": len(text),
                "segment_position": position,
                "index_node_hash": "",
                "content": text[:500],
                "page": None,
            }
        )
    return resources


async def generate_stream(
    query: str,
    context_chunks: list[dict],
    conversation_id: str,
    message_id: str,
    task_id: str | None = None,
    system_prompt: str = "",
    model: str = "",
    temperature: float | None = None,
    max_tokens: int | None = None,
    dataset_name: str = "",
) -> AsyncGenerator[str, None]:
    """Stream a chat-completion answer as SSE ``data:`` lines.

    Posts a streaming chat-completions request built from ``query`` and the
    retrieved ``context_chunks``, and re-emits each content delta as a
    ``message`` event. On success a final ``message_end`` event carries the
    token usage and the retriever resources.

    Args:
        query: The user's question.
        context_chunks: Retrieved chunks; the keys read here are ``source``,
            ``text``, ``id`` and ``score``.
        conversation_id: Echoed in ``message`` and ``message_end`` events.
        message_id: Echoed in every event.
        task_id: Echoed in every event; a random UUID is generated if falsy.
        system_prompt: System message; falls back to ``DEFAULT_SYSTEM_PROMPT``.
        model: Model name; falls back to ``RAG_CONFIG["LLM_MODEL"]``.
        temperature: Sampling temperature; ``None`` falls back to
            ``RAG_CONFIG["LLM_TEMPERATURE"]`` (an explicit ``0`` is kept).
        max_tokens: Completion limit; falsy falls back to
            ``RAG_CONFIG["LLM_MAX_TOKENS"]``.
        dataset_name: Copied into each retriever resource entry.

    Yields:
        SSE lines produced by ``_sse_line``. If anything raises while
        requesting or reading the stream, a single ``error`` event with code
        ``llm_error`` is yielded and the generator ends without
        ``message_end``.
    """
    task_id = task_id or str(uuid.uuid4())
    created_at = int(time.time())
    _model = model or RAG_CONFIG["LLM_MODEL"]
    _temp = temperature if temperature is not None else RAG_CONFIG["LLM_TEMPERATURE"]
    _max_tok = max_tokens or RAG_CONFIG["LLM_MAX_TOKENS"]
    _prompt = system_prompt or DEFAULT_SYSTEM_PROMPT

    max_context_chars = 8000
    messages = [
        {"role": "system", "content": _prompt},
        {
            "role": "user",
            "content": _build_user_content(query, context_chunks, max_context_chars),
        },
    ]

    request_body = {
        "model": _model,
        "messages": messages,
        "temperature": _temp,
        "max_tokens": _max_tok,
        "stream": True,
        # NOTE(review): presumably a server-specific switch that disables
        # the model's "thinking" output - confirm against the LLM backend.
        "chat_template_kwargs": {"enable_thinking": False},
    }
    request_headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}",
    }

    total_tokens = 0
    try:
        async with httpx.AsyncClient(timeout=RAG_CONFIG["LLM_TIMEOUT"]) as client:
            async with client.stream(
                "POST",
                build_openai_chat_completions_url(RAG_CONFIG["LLM_BASE_URL"]),
                json=request_body,
                headers=request_headers,
            ) as resp:
                resp.raise_for_status()
                async for line in resp.aiter_lines():
                    # The space after "data:" is optional in the SSE format,
                    # so match the bare field name and strip afterwards.
                    if not line.startswith("data:"):
                        continue
                    payload = line[5:].strip()
                    if not payload:
                        continue
                    if payload == "[DONE]":
                        break

                    event = json.loads(payload)
                    choices = event.get("choices") or []
                    first_choice = choices[0] if choices and isinstance(choices[0], dict) else {}
                    # ``delta`` or ``content`` may be present but null.
                    delta = first_choice.get("delta") or {}
                    text = delta.get("content") or ""
                    if text:
                        yield _sse_line(
                            {
                                "event": "message",
                                "task_id": task_id,
                                "message_id": message_id,
                                "conversation_id": conversation_id,
                                "answer": text,
                                "created_at": created_at,
                            }
                        )

                    usage = event.get("usage")
                    if usage:
                        total_tokens = usage.get("total_tokens", total_tokens)
    except Exception as exc:
        # Boundary handler: the client is mid-stream, so report the failure
        # as an event instead of letting the exception break the response.
        yield _sse_line(
            {
                "event": "error",
                "task_id": task_id,
                "message_id": message_id,
                "code": "llm_error",
                "message": str(exc),
            }
        )
        return

    # Every chunk passed in is listed, including any that did not fit into
    # the prompt's context budget.
    retriever_resources = _build_retriever_resources(context_chunks, dataset_name)

    yield _sse_line(
        {
            "event": "message_end",
            "task_id": task_id,
            "message_id": message_id,
            "conversation_id": conversation_id,
            "metadata": {
                "usage": {"total_tokens": total_tokens},
                "retriever_resources": retriever_resources,
            },
        }
    )
|
|
|
|
|
|
def _sse_line(data: dict) -> str:
    """Render *data* as one Server-Sent Events ``data:`` frame.

    The dict is JSON-encoded with non-ASCII characters left unescaped, and
    the frame is terminated by a blank line.
    """
    encoded = json.dumps(data, ensure_ascii=False)
    return "".join(("data: ", encoded, "\n\n"))
|