40 lines
1.5 KiB
Python
40 lines
1.5 KiB
Python
from __future__ import annotations
|
|
|
|
import json
|
|
|
|
import httpx
|
|
|
|
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
|
|
|
|
|
|
async def generate_followups(query: str, answer: str) -> list[str]:
    """Ask the configured chat-completions LLM for up to three follow-up questions.

    The (Chinese) prompt asks the model to propose three short follow-up
    questions for ``query`` given ``answer`` and to reply with a JSON array of
    strings. Only the first 1200 characters of ``answer`` are sent.

    Args:
        query: The user's original question.
        answer: The answer that was already produced for ``query``.

    Returns:
        At most three non-empty, whitespace-stripped question strings, parsed
        from the model reply by ``_parse_followups``.

    Raises:
        httpx.HTTPStatusError: If the endpoint answers with a 4xx/5xx status.
        httpx.HTTPError: Other transport errors (e.g. the 30 s timeout).
        KeyError / IndexError: If the response JSON lacks
            ``choices[0].message.content``.
    """
    prompt = (
        "基于用户问题和已有回答,生成 3 个适合继续追问的简短问题。"
        "仅返回 JSON 数组字符串,例如 [\"问题1\", \"问题2\"]。\n"
        # Truncate the answer to keep the prompt bounded.
        f"用户问题: {query}\n回答: {answer[:1200]}"
    )
    url = f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/chat/completions"
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(
            url,
            json={
                "model": RAG_CONFIG["LLM_MODEL"],
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.5,
                "max_tokens": 256,
                # Presumably turns off "thinking" output for chat templates that
                # support it (e.g. Qwen-style models) — confirm for the deployment.
                "chat_template_kwargs": {"enable_thinking": False},
            },
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}",
            },
        )
        resp.raise_for_status()
    content = resp.json()["choices"][0]["message"]["content"]
    return _parse_followups(content)


def _parse_followups(content: str | None, limit: int = 3) -> list[str]:
    """Turn a raw LLM reply into at most ``limit`` clean question strings.

    Parsing strategy, in order:

    1. The whole reply parsed as JSON; if it is a list, use its items.
    2. The outermost ``[...]`` span parsed as JSON. This recovers arrays
       wrapped in a Markdown code fence, surrounded by prose, or nested in an
       object such as ``{"questions": [...]}``.
    3. Otherwise, one question per non-empty line. Code-fence lines are
       dropped and leading list markers ("-", "*", "•", "1.", "2)", "3、") are
       removed.

    Args:
        content: The assistant message text. ``None`` is treated as empty.
        limit: Maximum number of questions to return.

    Returns:
        A list of at most ``limit`` non-empty, stripped strings.
    """
    import re

    text = (content or "").strip()

    candidates = [text]
    start, end = text.find("["), text.rfind("]")
    if 0 <= start < end:
        candidates.append(text[start : end + 1])
    for candidate in candidates:
        try:
            parsed = json.loads(candidate)
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            continue
        if isinstance(parsed, list):
            items = (str(item).strip() for item in parsed)
            return [item for item in items if item][:limit]

    # Only strip markers at the start of a line. A "." directly followed by a
    # digit is not a marker, so "1.5倍..." stays intact.
    marker = re.compile(r"^(?:[-*•]+|\d+(?:[.．](?!\d)|[)）、]))\s*")
    questions: list[str] = []
    for raw_line in text.splitlines():
        line = raw_line.strip()
        if line.startswith("```"):
            continue
        line = marker.sub("", line).strip()
        if line:
            questions.append(line)
    return questions[:limit]
|