feat: add rag backend and review access fixes
This commit is contained in:
@@ -0,0 +1,144 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = """你是烟草行业智慧法务小助手,专注于烟草专卖法规、合同管理、行政处罚等相关法律法规。\n\n回答要求:\n- 先用一句话直接回答,再展开详细说明\n- 多个要点用编号列表\n- 关键法条和数字用 **加粗**\n- 分类信息用表格\n- 层级结构用缩进子列表\n- 不要加标题,直接输出正文"""
|
||||
|
||||
|
||||
async def generate_stream(
|
||||
query: str,
|
||||
context_chunks: list[dict],
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
system_prompt: str = "",
|
||||
model: str = "",
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
dataset_name: str = "",
|
||||
) -> AsyncGenerator[str, None]:
|
||||
task_id = str(uuid.uuid4())
|
||||
created_at = int(time.time())
|
||||
_model = model or RAG_CONFIG["LLM_MODEL"]
|
||||
_temp = temperature if temperature is not None else RAG_CONFIG["LLM_TEMPERATURE"]
|
||||
_max_tok = max_tokens or RAG_CONFIG["LLM_MAX_TOKENS"]
|
||||
_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
|
||||
|
||||
max_context_chars = 8000
|
||||
if context_chunks:
|
||||
parts: list[str] = []
|
||||
total_len = 0
|
||||
for chunk in context_chunks:
|
||||
part = f"[来源: {chunk.get('source', '未知')}]\\n{chunk.get('text', '')}"
|
||||
if total_len + len(part) > max_context_chars:
|
||||
break
|
||||
parts.append(part)
|
||||
total_len += len(part)
|
||||
context_text = "\\n\\n---\\n\\n".join(parts)
|
||||
user_content = f"知识库内容:\\n{context_text}\\n\\n用户问题: {query}"
|
||||
else:
|
||||
user_content = query
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": _prompt},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
|
||||
total_tokens = 0
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=RAG_CONFIG["LLM_TIMEOUT"]) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}" + "/chat/completions",
|
||||
json={
|
||||
"model": _model,
|
||||
"messages": messages,
|
||||
"temperature": _temp,
|
||||
"max_tokens": _max_tok,
|
||||
"stream": True,
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}",
|
||||
},
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
async for line in resp.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
payload = line[6:].strip()
|
||||
if payload == "[DONE]":
|
||||
break
|
||||
chunk = json.loads(payload)
|
||||
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
||||
text = delta.get("content", "")
|
||||
if text:
|
||||
yield _sse_line(
|
||||
{
|
||||
"event": "message",
|
||||
"task_id": task_id,
|
||||
"message_id": message_id,
|
||||
"conversation_id": conversation_id,
|
||||
"answer": text,
|
||||
"created_at": created_at,
|
||||
}
|
||||
)
|
||||
usage = chunk.get("usage")
|
||||
if usage:
|
||||
total_tokens = usage.get("total_tokens", total_tokens)
|
||||
except Exception as exc:
|
||||
yield _sse_line(
|
||||
{
|
||||
"event": "error",
|
||||
"task_id": task_id,
|
||||
"message_id": message_id,
|
||||
"code": "llm_error",
|
||||
"message": str(exc),
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
retriever_resources = [
|
||||
{
|
||||
"position": i + 1,
|
||||
"dataset_id": "",
|
||||
"dataset_name": dataset_name,
|
||||
"document_id": "",
|
||||
"document_name": chunk.get("source", ""),
|
||||
"data_source_type": "upload_file",
|
||||
"segment_id": chunk.get("id", ""),
|
||||
"retriever_from": "rag",
|
||||
"score": round(chunk.get("score", 0.0), 4),
|
||||
"hit_count": 0,
|
||||
"word_count": len(chunk.get("text", "")),
|
||||
"segment_position": i + 1,
|
||||
"index_node_hash": "",
|
||||
"content": chunk.get("text", "")[:500],
|
||||
"page": None,
|
||||
}
|
||||
for i, chunk in enumerate(context_chunks)
|
||||
]
|
||||
|
||||
yield _sse_line(
|
||||
{
|
||||
"event": "message_end",
|
||||
"task_id": task_id,
|
||||
"message_id": message_id,
|
||||
"conversation_id": conversation_id,
|
||||
"metadata": {
|
||||
"usage": {"total_tokens": total_tokens},
|
||||
"retriever_resources": retriever_resources,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _sse_line(data: dict) -> str:
|
||||
return f"data: {json.dumps(data, ensure_ascii=False)}\\n\\n"
|
||||
Reference in New Issue
Block a user