feat: add rag backend and review access fixes

This commit is contained in:
wren
2026-05-08 10:58:24 +08:00
parent 1c84209f38
commit 9c86bf59e5
32 changed files with 3877 additions and 23 deletions
@@ -0,0 +1 @@
"""RAG 聊天内核兼容层。"""
@@ -0,0 +1,40 @@
from __future__ import annotations
from typing import Any
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
_instance: Any | None = None
def init_chroma() -> Any:
global _instance
if _instance is not None:
return _instance
import chromadb # lazy import to avoid hard failure before feature is enabled
import chromadb.config
host = RAG_CONFIG["CHROMA_HOST"]
if host:
token = RAG_CONFIG.get("CHROMA_TOKEN", "")
header = RAG_CONFIG.get("CHROMA_AUTH_HEADER", "X-Chroma-Token")
settings = (
chromadb.config.Settings(
chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
chroma_client_auth_credentials=token,
chroma_auth_token_transport_header=header,
)
if token
else chromadb.config.Settings()
)
_instance = chromadb.HttpClient(host=host, port=RAG_CONFIG["CHROMA_PORT"], settings=settings)
else:
_instance = chromadb.PersistentClient(path=RAG_CONFIG["CHROMA_PERSIST_DIR"])
return _instance
def get_chroma() -> Any:
if _instance is None:
return init_chroma()
return _instance
@@ -0,0 +1,60 @@
from __future__ import annotations
from fastapi_admin.config._settings import llm
def _get_str(name: str, default: str = "") -> str:
import os
return os.getenv(name, default)
def _get_bool(name: str, default: bool = False) -> bool:
import os
return os.getenv(name, str(default).lower()).lower() == "true"
def _get_int(name: str, default: int) -> int:
import os
try:
return int(os.getenv(name, str(default)))
except ValueError:
return default
def _get_float(name: str, default: float) -> float:
import os
try:
return float(os.getenv(name, str(default)))
except ValueError:
return default
RAG_CONFIG = {
"USE_SELF_HOSTED": True,
"CHROMA_PERSIST_DIR": _get_str("RAG_CHROMA_PERSIST_DIR", ".chromadb_rag"),
"CHROMA_HOST": _get_str("RAG_CHROMA_HOST", ""),
"CHROMA_PORT": _get_int("RAG_CHROMA_PORT", 8010),
"CHROMA_TOKEN": _get_str("RAG_CHROMA_TOKEN", ""),
"CHROMA_AUTH_HEADER": _get_str("RAG_CHROMA_AUTH_HEADER", "X-Chroma-Token"),
"EMBED_URL": _get_str("RAG_EMBED_URL", _get_str("GRAPH_RAG_EMBED_URL", "")),
"EMBED_KEY": _get_str("RAG_EMBED_KEY", _get_str("GRAPH_RAG_EMBED_KEY", "")),
"EMBED_MODEL": _get_str("RAG_EMBED_MODEL", _get_str("GRAPH_RAG_EMBED_MODEL", "")),
"EMBED_DIM": _get_int("RAG_EMBED_DIM", 1024),
"EMBED_BATCH_SIZE": _get_int("RAG_EMBED_BATCH_SIZE", 10),
"RERANKER_URL": _get_str("RAG_RERANKER_URL", _get_str("GRAPH_RAG_RERANKER_URL", "")),
"RERANKER_KEY": _get_str("RAG_RERANKER_KEY", _get_str("GRAPH_RAG_RERANKER_KEY", "")),
"RERANKER_MODEL": _get_str("RAG_RERANKER_MODEL", _get_str("GRAPH_RAG_RERANKER_MODEL", "")),
"LLM_BASE_URL": _get_str("LLM_BASE_URL", llm.LLM_BASE_URL),
"LLM_MODEL": _get_str("LLM_MODEL", llm.LLM_MODEL),
"LLM_API_KEY": _get_str("LLM_API_KEY", llm.LLM_API_KEY),
"VECTOR_TOP_K": _get_int("RAG_VECTOR_TOP_K", 15),
"RERANK_TOP_K": _get_int("RAG_RERANK_TOP_K", 5),
"BM25_TOP_K": _get_int("RAG_BM25_TOP_K", 15),
"RRF_K": _get_int("RAG_RRF_K", 60),
"LLM_TEMPERATURE": _get_float("RAG_LLM_TEMPERATURE", 0.3),
"LLM_MAX_TOKENS": _get_int("RAG_LLM_MAX_TOKENS", 2048),
"LLM_TIMEOUT": _get_int("RAG_LLM_TIMEOUT", 120),
"QUERY_REWRITING": _get_bool("RAG_QUERY_REWRITING", False),
"HYBRID_SEARCH": _get_bool("RAG_HYBRID_SEARCH", True),
"RERANKING": _get_bool("RAG_RERANKING", True),
}
@@ -0,0 +1,144 @@
from __future__ import annotations
import json
import time
import uuid
from typing import AsyncGenerator
import httpx
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
DEFAULT_SYSTEM_PROMPT = """你是烟草行业智慧法务小助手,专注于烟草专卖法规、合同管理、行政处罚等相关法律法规。\n\n回答要求:\n- 先用一句话直接回答,再展开详细说明\n- 多个要点用编号列表\n- 关键法条和数字用 **加粗**\n- 分类信息用表格\n- 层级结构用缩进子列表\n- 不要加标题,直接输出正文"""
async def generate_stream(
query: str,
context_chunks: list[dict],
conversation_id: str,
message_id: str,
system_prompt: str = "",
model: str = "",
temperature: float | None = None,
max_tokens: int | None = None,
dataset_name: str = "",
) -> AsyncGenerator[str, None]:
task_id = str(uuid.uuid4())
created_at = int(time.time())
_model = model or RAG_CONFIG["LLM_MODEL"]
_temp = temperature if temperature is not None else RAG_CONFIG["LLM_TEMPERATURE"]
_max_tok = max_tokens or RAG_CONFIG["LLM_MAX_TOKENS"]
_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
max_context_chars = 8000
if context_chunks:
parts: list[str] = []
total_len = 0
for chunk in context_chunks:
part = f"[来源: {chunk.get('source', '未知')}]\\n{chunk.get('text', '')}"
if total_len + len(part) > max_context_chars:
break
parts.append(part)
total_len += len(part)
context_text = "\\n\\n---\\n\\n".join(parts)
user_content = f"知识库内容:\\n{context_text}\\n\\n用户问题: {query}"
else:
user_content = query
messages = [
{"role": "system", "content": _prompt},
{"role": "user", "content": user_content},
]
total_tokens = 0
try:
async with httpx.AsyncClient(timeout=RAG_CONFIG["LLM_TIMEOUT"]) as client:
async with client.stream(
"POST",
f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}" + "/chat/completions",
json={
"model": _model,
"messages": messages,
"temperature": _temp,
"max_tokens": _max_tok,
"stream": True,
"chat_template_kwargs": {"enable_thinking": False},
},
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}",
},
) as resp:
resp.raise_for_status()
async for line in resp.aiter_lines():
if not line.startswith("data: "):
continue
payload = line[6:].strip()
if payload == "[DONE]":
break
chunk = json.loads(payload)
delta = chunk.get("choices", [{}])[0].get("delta", {})
text = delta.get("content", "")
if text:
yield _sse_line(
{
"event": "message",
"task_id": task_id,
"message_id": message_id,
"conversation_id": conversation_id,
"answer": text,
"created_at": created_at,
}
)
usage = chunk.get("usage")
if usage:
total_tokens = usage.get("total_tokens", total_tokens)
except Exception as exc:
yield _sse_line(
{
"event": "error",
"task_id": task_id,
"message_id": message_id,
"code": "llm_error",
"message": str(exc),
}
)
return
retriever_resources = [
{
"position": i + 1,
"dataset_id": "",
"dataset_name": dataset_name,
"document_id": "",
"document_name": chunk.get("source", ""),
"data_source_type": "upload_file",
"segment_id": chunk.get("id", ""),
"retriever_from": "rag",
"score": round(chunk.get("score", 0.0), 4),
"hit_count": 0,
"word_count": len(chunk.get("text", "")),
"segment_position": i + 1,
"index_node_hash": "",
"content": chunk.get("text", "")[:500],
"page": None,
}
for i, chunk in enumerate(context_chunks)
]
yield _sse_line(
{
"event": "message_end",
"task_id": task_id,
"message_id": message_id,
"conversation_id": conversation_id,
"metadata": {
"usage": {"total_tokens": total_tokens},
"retriever_resources": retriever_resources,
},
}
)
def _sse_line(data: dict) -> str:
return f"data: {json.dumps(data, ensure_ascii=False)}\\n\\n"
@@ -0,0 +1,39 @@
from __future__ import annotations
import json
import httpx
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
async def generate_followups(query: str, answer: str) -> list[str]:
prompt = (
"基于用户问题和已有回答,生成 3 个适合继续追问的简短问题。"
"仅返回 JSON 数组字符串,例如 [\"问题1\", \"问题2\"]。\\n"
f"用户问题: {query}\\n回答: {answer[:1200]}"
)
async with httpx.AsyncClient(timeout=30.0) as client:
resp = await client.post(
f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}" + "/chat/completions",
json={
"model": RAG_CONFIG["LLM_MODEL"],
"messages": [{"role": "user", "content": prompt}],
"temperature": 0.5,
"max_tokens": 256,
"chat_template_kwargs": {"enable_thinking": False},
},
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}",
},
)
resp.raise_for_status()
content = resp.json()["choices"][0]["message"]["content"]
try:
parsed = json.loads(content)
if isinstance(parsed, list):
return [str(item).strip() for item in parsed if str(item).strip()][:3]
except Exception:
pass
return [line.strip("- 1234567890.\t") for line in content.splitlines() if line.strip()][:3]