feat: stabilize RAG chat conversations and auto-title sync
This commit is contained in:
@@ -15,6 +15,7 @@ from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import (
|
||||
RagConversationRenameDTO,
|
||||
RagChatSendMessageDTO,
|
||||
RagMessageFeedbackDTO,
|
||||
RagStopMessageDTO,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import (
|
||||
RagDatasetBatchDocumentDeleteDTO,
|
||||
@@ -479,6 +480,17 @@ class RagChatController(BaseController):
|
||||
headers={"Cache-Control": "no-cache", "X-Accel-Buffering": "no", "Connection": "keep-alive"},
|
||||
)
|
||||
|
||||
@self.router.post("/chat/messages/{MessageId}/stop", response_model=Result[RagOperationResultVO])
async def StopMessage(
    MessageId: str,
    Body: RagStopMessageDTO | None = None,
    payload: dict[str, Any] = Depends(verify_access_token),
):
    # POST /chat/messages/{MessageId}/stop
    # Checks the caller's "chat_use" permission, then hands the stop request
    # to RagChatService.StopMessage and wraps its result in Result.success.
    # Documented with comments instead of a docstring because FastAPI would
    # publish a docstring as the operation description in the OpenAPI schema.
    current_user_id = int(payload["user_id"])
    required_permissions = [self._PERMISSIONS["chat_use"]]
    allowed = await self._check_permission(current_user_id, required_permissions)
    if allowed:
        stop_result = await self.RagChatService.StopMessage(current_user_id, MessageId, Body)
        return Result.success(data=stop_result)
    # Permission denied: answer with an explicit 403 envelope.
    return JSONResponse(
        status_code=403,
        content={"code": 403, "msg": "当前用户没有停止 RAG 对话权限", "data": None},
    )
|
||||
|
||||
@self.router.get("/chat/conversations", response_model=Result[RagConversationPageVO])
|
||||
async def GetConversations(
|
||||
appId: int | None = Query(None, description="聊天应用ID"),
|
||||
|
||||
@@ -13,3 +13,7 @@ class RagConversationRenameDTO(BaseModel):
|
||||
|
||||
class RagMessageFeedbackDTO(BaseModel):
    # Request payload carrying a feedback rating for a message.
    # A comment is used instead of a class docstring so the model's generated
    # schema description stays exactly as before.

    # Feedback value; per the field description: like / dislike / None.
    rating: str | None = Field(default=None, description="反馈: like/dislike/None")
|
||||
|
||||
|
||||
class RagStopMessageDTO(BaseModel):
    # Optional request payload for stopping a message.
    # A comment is used instead of a class docstring so the model's generated
    # schema description stays exactly as before.

    # Identifier of the streaming task (per the field description); optional.
    taskId: str | None = Field(default=None, description="流式任务ID")
|
||||
|
||||
@@ -17,8 +17,10 @@ class RagConversationItemVO(BaseModel):
|
||||
id: str = Field(..., description="会话ID")
|
||||
name: str = Field(..., description="会话名称")
|
||||
introduction: str = Field("", description="会话简介")
|
||||
titleSource: str = Field("default", description="标题来源: default/auto/manual")
|
||||
createdAt: int = Field(0, description="创建时间戳")
|
||||
updatedAt: int = Field(0, description="更新时间戳")
|
||||
lastMessageAt: int = Field(0, description="最后一条消息完成时间戳")
|
||||
|
||||
|
||||
class RagConversationPageVO(BaseModel):
|
||||
@@ -34,6 +36,8 @@ class RagMessageItemVO(BaseModel):
|
||||
answer: str = Field(...)
|
||||
feedback: dict | None = Field(None)
|
||||
retrieverResources: list[dict] | None = Field(None)
|
||||
suggestedQuestions: list[str] = Field(default_factory=list)
|
||||
status: str = Field("completed")
|
||||
createdAt: int = Field(0)
|
||||
|
||||
|
||||
|
||||
@@ -38,6 +38,7 @@ from fastapi_admin.config import (
|
||||
LEAUDIT_LLM_RETRY_BACKOFF_BASE_SECONDS,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.govdoc_engine.llm.cache import LlmCache, make_key
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.config import normalize_openai_base_url
|
||||
|
||||
_log = logging.getLogger(__name__)
|
||||
|
||||
@@ -153,8 +154,9 @@ class LlmClient:
|
||||
"LLM_API_KEY is not configured. Set LLM_API_KEY in platform config."
|
||||
)
|
||||
else:
|
||||
self._client = OpenAI(api_key=key, base_url=base_url or LLM_BASE_URL)
|
||||
self._aclient = AsyncOpenAI(api_key=key, base_url=base_url or LLM_BASE_URL)
|
||||
normalized_base_url = normalize_openai_base_url(base_url or LLM_BASE_URL)
|
||||
self._client = OpenAI(api_key=key, base_url=normalized_base_url)
|
||||
self._aclient = AsyncOpenAI(api_key=key, base_url=normalized_base_url)
|
||||
self.model = model or LLM_MODEL
|
||||
self.timeout = timeout_seconds if timeout_seconds is not None else LEAUDIT_LLM_REQUEST_TIMEOUT
|
||||
self.max_retries = max_retries if max_retries is not None else LEAUDIT_LLM_RETRY_MAX_ATTEMPTS
|
||||
|
||||
@@ -29,6 +29,7 @@ from fastapi_modules.fastapi_leaudit.leaudit_bridge.resilient_clients import (
|
||||
ResilientOpenAICompatibleClient,
|
||||
ResilientQwenVLMClient,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.config import normalize_openai_base_url
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from leaudit.llm.base import BaseLLMClient
|
||||
@@ -68,7 +69,7 @@ def create_ocr_client() -> BaseOCRClient:
|
||||
|
||||
def create_llm_client() -> BaseLLMClient:
|
||||
"""Create a leaudit OpenAICompatibleClient from docauditai's LLM config."""
|
||||
base_url = LLM_BASE_URL
|
||||
base_url = normalize_openai_base_url(LLM_BASE_URL)
|
||||
model = LLM_MODEL
|
||||
api_key = LLM_API_KEY or "no-key"
|
||||
|
||||
@@ -93,7 +94,7 @@ def create_llm_client() -> BaseLLMClient:
|
||||
|
||||
def create_vlm_client() -> BaseVLMClient | None:
|
||||
"""Create a leaudit QwenVLMClient from docauditai's VLM config."""
|
||||
base_url = VLM_BASE_URL
|
||||
base_url = normalize_openai_base_url(VLM_BASE_URL)
|
||||
model = VLM_MODEL
|
||||
api_key = VLM_API_KEY or LLM_API_KEY or "no-key"
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from fastapi_admin.config._settings import llm
|
||||
from fastapi_admin.config._settings import embedding, llm
|
||||
|
||||
|
||||
def _get_str(name: str, default: str = "") -> str:
|
||||
@@ -36,11 +36,23 @@ RAG_CONFIG = {
|
||||
"CHROMA_PORT": _get_int("RAG_CHROMA_PORT", 8010),
|
||||
"CHROMA_TOKEN": _get_str("RAG_CHROMA_TOKEN", ""),
|
||||
"CHROMA_AUTH_HEADER": _get_str("RAG_CHROMA_AUTH_HEADER", "X-Chroma-Token"),
|
||||
"EMBED_URL": _get_str("RAG_EMBED_URL", _get_str("GRAPH_RAG_EMBED_URL", "")),
|
||||
"EMBED_KEY": _get_str("RAG_EMBED_KEY", _get_str("GRAPH_RAG_EMBED_KEY", "")),
|
||||
"EMBED_MODEL": _get_str("RAG_EMBED_MODEL", _get_str("GRAPH_RAG_EMBED_MODEL", "")),
|
||||
"EMBED_DIM": _get_int("RAG_EMBED_DIM", 1024),
|
||||
"EMBED_BATCH_SIZE": _get_int("RAG_EMBED_BATCH_SIZE", 10),
|
||||
"EMBED_URL": _get_str(
|
||||
"RAG_EMBED_URL",
|
||||
_get_str("GRAPH_RAG_EMBED_URL", _get_str("EMBEDDING_BASE_URL", embedding.EMBEDDING_BASE_URL)),
|
||||
),
|
||||
"EMBED_KEY": _get_str(
|
||||
"RAG_EMBED_KEY",
|
||||
_get_str("GRAPH_RAG_EMBED_KEY", _get_str("EMBEDDING_API_KEY", embedding.EMBEDDING_API_KEY)),
|
||||
),
|
||||
"EMBED_MODEL": _get_str(
|
||||
"RAG_EMBED_MODEL",
|
||||
_get_str("GRAPH_RAG_EMBED_MODEL", _get_str("EMBEDDING_MODEL", embedding.EMBEDDING_MODEL)),
|
||||
),
|
||||
"EMBED_DIM": _get_int("RAG_EMBED_DIM", _get_int("EMBEDDING_DIM", embedding.EMBEDDING_DIM)),
|
||||
"EMBED_BATCH_SIZE": _get_int(
|
||||
"RAG_EMBED_BATCH_SIZE",
|
||||
_get_int("EMBEDDING_BATCH_SIZE", embedding.EMBEDDING_BATCH_SIZE),
|
||||
),
|
||||
"RERANKER_URL": _get_str("RAG_RERANKER_URL", _get_str("GRAPH_RAG_RERANKER_URL", "")),
|
||||
"RERANKER_KEY": _get_str("RAG_RERANKER_KEY", _get_str("GRAPH_RAG_RERANKER_KEY", "")),
|
||||
"RERANKER_MODEL": _get_str("RAG_RERANKER_MODEL", _get_str("GRAPH_RAG_RERANKER_MODEL", "")),
|
||||
@@ -58,3 +70,34 @@ RAG_CONFIG = {
|
||||
"HYBRID_SEARCH": _get_bool("RAG_HYBRID_SEARCH", True),
|
||||
"RERANKING": _get_bool("RAG_RERANKING", True),
|
||||
}
|
||||
|
||||
|
||||
def build_openai_chat_completions_url(base_url: str) -> str:
    """Return the chat-completions endpoint URL derived from *base_url*.

    Surrounding whitespace and trailing slashes are ignored.  A value that
    already ends in ``/chat/completions`` is returned as-is (after trimming).
    A value that ends in ``/embeddings`` has that endpoint suffix removed
    first, so a base URL configured with the embeddings endpoint does not
    produce ``.../embeddings/chat/completions``.

    Args:
        base_url: API base URL or full endpoint URL; may be empty or ``None``.

    Returns:
        The endpoint URL.  For an empty input this is the bare relative path
        ``"/chat/completions"``.
    """
    chat_suffix = "/chat/completions"
    embeddings_suffix = "/embeddings"

    normalized = (base_url or "").strip().rstrip("/")
    if normalized.endswith(chat_suffix):
        return normalized
    if normalized.endswith(embeddings_suffix):
        # Drop the sibling endpoint so only the API root remains.
        normalized = normalized[: -len(embeddings_suffix)].rstrip("/")
    return f"{normalized}{chat_suffix}"
|
||||
|
||||
|
||||
def build_openai_embeddings_url(base_url: str) -> str:
    """Return the embeddings endpoint URL derived from *base_url*.

    Surrounding whitespace and trailing slashes are ignored.  A trailing
    ``/chat/completions`` is removed first; any slash left dangling by that
    removal is trimmed as well, so the result never contains ``//embeddings``.
    A value that already ends in ``/embeddings`` is returned unchanged.

    Args:
        base_url: API base URL or full endpoint URL; may be empty or ``None``.

    Returns:
        The endpoint URL.  For an empty input this is the bare relative path
        ``"/embeddings"``.
    """
    chat_suffix = "/chat/completions"
    embeddings_suffix = "/embeddings"

    normalized = (base_url or "").strip().rstrip("/")
    if normalized.endswith(chat_suffix):
        # Re-strip after slicing: the base may have ended in "/" before the suffix.
        normalized = normalized[: -len(chat_suffix)].rstrip("/")
    if normalized.endswith(embeddings_suffix):
        return normalized
    return f"{normalized}{embeddings_suffix}"
|
||||
|
||||
|
||||
def normalize_openai_base_url(base_url: str) -> str:
    """Reduce *base_url* to the API root, without any endpoint suffix.

    Surrounding whitespace and trailing slashes are removed.  If the value
    ends in ``/chat/completions`` or ``/embeddings`` that suffix is dropped,
    and any slash left dangling by the removal is trimmed too, so the result
    never ends in ``/``.

    Args:
        base_url: API base URL or full endpoint URL; may be empty or ``None``.

    Returns:
        The normalized base URL, or ``""`` when the input is empty.
    """
    normalized = (base_url or "").strip().rstrip("/")
    # At most one suffix is removed, checked in the same order as before.
    for suffix in ("/chat/completions", "/embeddings"):
        if normalized.endswith(suffix):
            return normalized[: -len(suffix)].rstrip("/")
    return normalized
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import AsyncGenerator
|
||||
|
||||
import httpx
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_chat_completions_url
|
||||
|
||||
DEFAULT_SYSTEM_PROMPT = """你是烟草行业智慧法务小助手,专注于烟草专卖法规、合同管理、行政处罚等相关法律法规。\n\n回答要求:\n- 先用一句话直接回答,再展开详细说明\n- 多个要点用编号列表\n- 关键法条和数字用 **加粗**\n- 分类信息用表格\n- 层级结构用缩进子列表\n- 不要加标题,直接输出正文"""
|
||||
|
||||
@@ -17,13 +17,14 @@ async def generate_stream(
|
||||
context_chunks: list[dict],
|
||||
conversation_id: str,
|
||||
message_id: str,
|
||||
task_id: str | None = None,
|
||||
system_prompt: str = "",
|
||||
model: str = "",
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
dataset_name: str = "",
|
||||
) -> AsyncGenerator[str, None]:
|
||||
task_id = str(uuid.uuid4())
|
||||
task_id = task_id or str(uuid.uuid4())
|
||||
created_at = int(time.time())
|
||||
_model = model or RAG_CONFIG["LLM_MODEL"]
|
||||
_temp = temperature if temperature is not None else RAG_CONFIG["LLM_TEMPERATURE"]
|
||||
@@ -55,7 +56,7 @@ async def generate_stream(
|
||||
async with httpx.AsyncClient(timeout=RAG_CONFIG["LLM_TIMEOUT"]) as client:
|
||||
async with client.stream(
|
||||
"POST",
|
||||
f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}" + "/chat/completions",
|
||||
build_openai_chat_completions_url(RAG_CONFIG["LLM_BASE_URL"]),
|
||||
json={
|
||||
"model": _model,
|
||||
"messages": messages,
|
||||
|
||||
@@ -4,7 +4,7 @@ import json
|
||||
|
||||
import httpx
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_chat_completions_url
|
||||
|
||||
|
||||
async def generate_followups(query: str, answer: str) -> list[str]:
|
||||
@@ -15,7 +15,7 @@ async def generate_followups(query: str, answer: str) -> list[str]:
|
||||
)
|
||||
async with httpx.AsyncClient(timeout=30.0) as client:
|
||||
resp = await client.post(
|
||||
f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}" + "/chat/completions",
|
||||
build_openai_chat_completions_url(RAG_CONFIG["LLM_BASE_URL"]),
|
||||
json={
|
||||
"model": RAG_CONFIG["LLM_MODEL"],
|
||||
"messages": [{"role": "user", "content": prompt}],
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -36,7 +36,7 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import RagOperationResultVO
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_embeddings_url
|
||||
from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatasetService
|
||||
|
||||
|
||||
@@ -1503,7 +1503,7 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
return chunks
|
||||
|
||||
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
|
||||
embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/embeddings"
|
||||
embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"])
|
||||
embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
|
||||
embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
|
||||
batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import AsyncGenerator
|
||||
from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import (
|
||||
RagConversationRenameDTO,
|
||||
RagMessageFeedbackDTO,
|
||||
RagStopMessageDTO,
|
||||
)
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
|
||||
RagAppParametersVO,
|
||||
@@ -52,6 +53,9 @@ class IRagChatService(ABC):
|
||||
@abstractmethod
|
||||
async def UpdateFeedback(self, CurrentUserId: int, MessageId: str, Body: RagMessageFeedbackDTO) -> RagOperationResultVO: ...
|
||||
|
||||
@abstractmethod
async def StopMessage(
    self,
    CurrentUserId: int,
    MessageId: str,
    Body: RagStopMessageDTO | None = None,
) -> RagOperationResultVO:
    """Abstract stop-message operation; concrete services must override it.

    NOTE(review): the actual stop semantics live in the implementing class
    and are not visible from this interface -- confirm there.
    """
|
||||
|
||||
@abstractmethod
|
||||
async def GetAppParameters(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user