feat: add rag backend and review access fixes
This commit is contained in:
@@ -0,0 +1 @@
|
||||
"""RAG 聊天内核兼容层。"""
|
||||
@@ -0,0 +1,40 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
|
||||
|
||||
_instance: Any | None = None
|
||||
|
||||
|
||||
def init_chroma() -> Any:
|
||||
global _instance
|
||||
if _instance is not None:
|
||||
return _instance
|
||||
|
||||
import chromadb # lazy import to avoid hard failure before feature is enabled
|
||||
import chromadb.config
|
||||
|
||||
host = RAG_CONFIG["CHROMA_HOST"]
|
||||
if host:
|
||||
token = RAG_CONFIG.get("CHROMA_TOKEN", "")
|
||||
header = RAG_CONFIG.get("CHROMA_AUTH_HEADER", "X-Chroma-Token")
|
||||
settings = (
|
||||
chromadb.config.Settings(
|
||||
chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
|
||||
chroma_client_auth_credentials=token,
|
||||
chroma_auth_token_transport_header=header,
|
||||
)
|
||||
if token
|
||||
else chromadb.config.Settings()
|
||||
)
|
||||
_instance = chromadb.HttpClient(host=host, port=RAG_CONFIG["CHROMA_PORT"], settings=settings)
|
||||
else:
|
||||
_instance = chromadb.PersistentClient(path=RAG_CONFIG["CHROMA_PERSIST_DIR"])
|
||||
return _instance
|
||||
|
||||
|
||||
def get_chroma() -> Any:
|
||||
if _instance is None:
|
||||
return init_chroma()
|
||||
return _instance
|
||||
@@ -0,0 +1,60 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi_admin.config._settings import llm
|
||||
|
||||
|
||||
def _get_str(name: str, default: str = "") -> str:
|
||||
import os
|
||||
return os.getenv(name, default)
|
||||
|
||||
|
||||
def _get_bool(name: str, default: bool = False) -> bool:
|
||||
import os
|
||||
return os.getenv(name, str(default).lower()).lower() == "true"
|
||||
|
||||
|
||||
def _get_int(name: str, default: int) -> int:
|
||||
import os
|
||||
try:
|
||||
return int(os.getenv(name, str(default)))
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
|
||||
def _get_float(name: str, default: float) -> float:
|
||||
import os
|
||||
try:
|
||||
return float(os.getenv(name, str(default)))
|
||||
except ValueError:
|
||||
return default
|
||||
|
||||
|
||||
RAG_CONFIG = {
|
||||
"USE_SELF_HOSTED": True,
|
||||
"CHROMA_PERSIST_DIR": _get_str("RAG_CHROMA_PERSIST_DIR", ".chromadb_rag"),
|
||||
"CHROMA_HOST": _get_str("RAG_CHROMA_HOST", ""),
|
||||
"CHROMA_PORT": _get_int("RAG_CHROMA_PORT", 8010),
|
||||
"CHROMA_TOKEN": _get_str("RAG_CHROMA_TOKEN", ""),
|
||||
"CHROMA_AUTH_HEADER": _get_str("RAG_CHROMA_AUTH_HEADER", "X-Chroma-Token"),
|
||||
"EMBED_URL": _get_str("RAG_EMBED_URL", _get_str("GRAPH_RAG_EMBED_URL", "")),
|
||||
"EMBED_KEY": _get_str("RAG_EMBED_KEY", _get_str("GRAPH_RAG_EMBED_KEY", "")),
|
||||
"EMBED_MODEL": _get_str("RAG_EMBED_MODEL", _get_str("GRAPH_RAG_EMBED_MODEL", "")),
|
||||
"EMBED_DIM": _get_int("RAG_EMBED_DIM", 1024),
|
||||
"EMBED_BATCH_SIZE": _get_int("RAG_EMBED_BATCH_SIZE", 10),
|
||||
"RERANKER_URL": _get_str("RAG_RERANKER_URL", _get_str("GRAPH_RAG_RERANKER_URL", "")),
|
||||
"RERANKER_KEY": _get_str("RAG_RERANKER_KEY", _get_str("GRAPH_RAG_RERANKER_KEY", "")),
|
||||
"RERANKER_MODEL": _get_str("RAG_RERANKER_MODEL", _get_str("GRAPH_RAG_RERANKER_MODEL", "")),
|
||||
"LLM_BASE_URL": _get_str("LLM_BASE_URL", llm.LLM_BASE_URL),
|
||||
"LLM_MODEL": _get_str("LLM_MODEL", llm.LLM_MODEL),
|
||||
"LLM_API_KEY": _get_str("LLM_API_KEY", llm.LLM_API_KEY),
|
||||
"VECTOR_TOP_K": _get_int("RAG_VECTOR_TOP_K", 15),
|
||||
"RERANK_TOP_K": _get_int("RAG_RERANK_TOP_K", 5),
|
||||
"BM25_TOP_K": _get_int("RAG_BM25_TOP_K", 15),
|
||||
"RRF_K": _get_int("RAG_RRF_K", 60),
|
||||
"LLM_TEMPERATURE": _get_float("RAG_LLM_TEMPERATURE", 0.3),
|
||||
"LLM_MAX_TOKENS": _get_int("RAG_LLM_MAX_TOKENS", 2048),
|
||||
"LLM_TIMEOUT": _get_int("RAG_LLM_TIMEOUT", 120),
|
||||
"QUERY_REWRITING": _get_bool("RAG_QUERY_REWRITING", False),
|
||||
"HYBRID_SEARCH": _get_bool("RAG_HYBRID_SEARCH", True),
|
||||
"RERANKING": _get_bool("RAG_RERANKING", True),
|
||||
}
|
||||
@@ -0,0 +1,144 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
import uuid
|
||||
from typing import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = """你是烟草行业智慧法务小助手,专注于烟草专卖法规、合同管理、行政处罚等相关法律法规。\n\n回答要求:\n- 先用一句话直接回答,再展开详细说明\n- 多个要点用编号列表\n- 关键法条和数字用 **加粗**\n- 分类信息用表格\n- 层级结构用缩进子列表\n- 不要加标题,直接输出正文"""
|
||||
|
||||
|
||||
async def generate_stream(
|
||||
query: str,
|
||||
context_chunks: list[dict],
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
system_prompt: str = "",
|
||||
model: str = "",
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
dataset_name: str = "",
|
||||
) -> AsyncGenerator[str, None]:
|
||||
task_id = str(uuid.uuid4())
|
||||
created_at = int(time.time())
|
||||
_model = model or RAG_CONFIG["LLM_MODEL"]
|
||||
_temp = temperature if temperature is not None else RAG_CONFIG["LLM_TEMPERATURE"]
|
||||
_max_tok = max_tokens or RAG_CONFIG["LLM_MAX_TOKENS"]
|
||||
_prompt = system_prompt or DEFAULT_SYSTEM_PROMPT
|
||||
|
||||
max_context_chars = 8000
|
||||
if context_chunks:
|
||||
parts: list[str] = []
|
||||
total_len = 0
|
||||
for chunk in context_chunks:
|
||||
part = f"[来源: {chunk.get('source', '未知')}]\\n{chunk.get('text', '')}"
|
||||
if total_len + len(part) > max_context_chars:
|
||||
break
|
||||
parts.append(part)
|
||||
total_len += len(part)
|
||||
context_text = "\\n\\n---\\n\\n".join(parts)
|
||||
user_content = f"知识库内容:\\n{context_text}\\n\\n用户问题: {query}"
|
||||
else:
|
||||
user_content = query
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": _prompt},
|
||||
{"role": "user", "content": user_content},
|
||||
]
|
||||
|
||||
total_tokens = 0
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=RAG_CONFIG["LLM_TIMEOUT"]) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}" + "/chat/completions",
|
||||
json={
|
||||
"model": _model,
|
||||
"messages": messages,
|
||||
"temperature": _temp,
|
||||
"max_tokens": _max_tok,
|
||||
"stream": True,
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}",
|
||||
},
|
||||
) as resp:
|
||||
resp.raise_for_status()
|
||||
async for line in resp.aiter_lines():
|
||||
if not line.startswith("data: "):
|
||||
continue
|
||||
payload = line[6:].strip()
|
||||
if payload == "[DONE]":
|
||||
break
|
||||
chunk = json.loads(payload)
|
||||
delta = chunk.get("choices", [{}])[0].get("delta", {})
|
||||
text = delta.get("content", "")
|
||||
if text:
|
||||
yield _sse_line(
|
||||
{
|
||||
"event": "message",
|
||||
"task_id": task_id,
|
||||
"message_id": message_id,
|
||||
"conversation_id": conversation_id,
|
||||
"answer": text,
|
||||
"created_at": created_at,
|
||||
}
|
||||
)
|
||||
usage = chunk.get("usage")
|
||||
if usage:
|
||||
total_tokens = usage.get("total_tokens", total_tokens)
|
||||
except Exception as exc:
|
||||
yield _sse_line(
|
||||
{
|
||||
"event": "error",
|
||||
"task_id": task_id,
|
||||
"message_id": message_id,
|
||||
"code": "llm_error",
|
||||
"message": str(exc),
|
||||
}
|
||||
)
|
||||
return
|
||||
|
||||
retriever_resources = [
|
||||
{
|
||||
"position": i + 1,
|
||||
"dataset_id": "",
|
||||
"dataset_name": dataset_name,
|
||||
"document_id": "",
|
||||
"document_name": chunk.get("source", ""),
|
||||
"data_source_type": "upload_file",
|
||||
"segment_id": chunk.get("id", ""),
|
||||
"retriever_from": "rag",
|
||||
"score": round(chunk.get("score", 0.0), 4),
|
||||
"hit_count": 0,
|
||||
"word_count": len(chunk.get("text", "")),
|
||||
"segment_position": i + 1,
|
||||
"index_node_hash": "",
|
||||
"content": chunk.get("text", "")[:500],
|
||||
"page": None,
|
||||
}
|
||||
for i, chunk in enumerate(context_chunks)
|
||||
]
|
||||
|
||||
yield _sse_line(
|
||||
{
|
||||
"event": "message_end",
|
||||
"task_id": task_id,
|
||||
"message_id": message_id,
|
||||
"conversation_id": conversation_id,
|
||||
"metadata": {
|
||||
"usage": {"total_tokens": total_tokens},
|
||||
"retriever_resources": retriever_resources,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _sse_line(data: dict) -> str:
|
||||
return f"data: {json.dumps(data, ensure_ascii=False)}\\n\\n"
|
||||
@@ -0,0 +1,39 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
|
||||
|
||||
|
||||
async def generate_followups(query: str, answer: str) -> list[str]:
|
||||
prompt = (
|
||||
"基于用户问题和已有回答,生成 3 个适合继续追问的简短问题。"
|
||||
"仅返回 JSON 数组字符串,例如 [\"问题1\", \"问题2\"]。\\n"
|
||||
f"用户问题: {query}\\n回答: {answer[:1200]}"
|
||||
)
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}" + "/chat/completions",
|
||||
json={
|
||||
"model": RAG_CONFIG["LLM_MODEL"],
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
"temperature": 0.5,
|
||||
"max_tokens": 256,
|
||||
"chat_template_kwargs": {"enable_thinking": False},
|
||||
},
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}",
|
||||
},
|
||||
)
|
||||
resp.raise_for_status()
|
||||
content = resp.json()["choices"][0]["message"]["content"]
|
||||
try:
|
||||
parsed = json.loads(content)
|
||||
if isinstance(parsed, list):
|
||||
return [str(item).strip() for item in parsed if str(item).strip()][:3]
|
||||
except Exception:
|
||||
pass
|
||||
return [line.strip("- 1234567890.\t") for line in content.splitlines() if line.strip()][:3]
|
||||
Reference in New Issue
Block a user