feat: add rag backend and review access fixes

This commit is contained in:
wren
2026-05-08 10:58:24 +08:00
parent 1c84209f38
commit 9c86bf59e5
32 changed files with 3877 additions and 23 deletions
@@ -0,0 +1,173 @@
from __future__ import annotations
from typing import Any
from fastapi import Depends, Query
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi_common.fastapi_common_security.security import verify_access_token
from fastapi_common.fastapi_common_web.controller import BaseController
from fastapi_common.fastapi_common_web.domain.responses import Result
from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import (
RagConversationRenameDTO,
RagChatSendMessageDTO,
RagMessageFeedbackDTO,
)
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
RagAppParametersVO,
RagChatAppListVO,
RagChatAppVO,
RagConversationPageVO,
RagConversationRenameVO,
RagMessagePageVO,
RagOperationResultVO,
)
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import RagDatasetPageVO
from fastapi_modules.fastapi_leaudit.services.impl.permissionServiceImpl import PermissionServiceImpl
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
from fastapi_modules.fastapi_leaudit.services.permissionService import IPermissionService
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatasetService
class RagChatController(BaseController):
    # Registers the RAG chat routes under the "/v3/rag" prefix. Every route
    # takes the caller payload from ``verify_access_token`` and asks the
    # permission service whether the caller holds at least one of the required
    # permission keys; a failed check yields a 403 ``JSONResponse`` instead of
    # the normal ``Result`` envelope.

    # Action name -> permission key handed to the permission service.
    _PERMISSIONS = {
        "chat_use": "rag:chat:use",
        "conversation_read": "rag:conversation:read",
        "conversation_update": "rag:conversation:update",
        "conversation_delete": "rag:conversation:delete",
        "message_feedback": "rag:message:feedback",
        "app_read": "rag:app:read",
        "dataset_read": "rag:dataset:read",
    }

    def __init__(self):
        super().__init__(prefix="/v3/rag", tags=["RAG 聊天"])
        self.RagChatService: IRagChatService = RagChatServiceImpl()
        self.RagDatasetService: IRagDatasetService = RagDatasetServiceImpl()
        self.PermissionService: IPermissionService = PermissionServiceImpl()

        # NOTE: route handlers deliberately carry no docstrings; FastAPI would
        # publish them as OpenAPI descriptions.

        def deny(message: str) -> JSONResponse:
            # Shared 403 envelope returned when a permission check fails.
            return JSONResponse(status_code=403, content={"code": 403, "msg": message, "data": None})

        def scope_of(payload: dict[str, Any]) -> dict[str, Any]:
            # Caller identity keyword arguments shared by the app/dataset service calls.
            return {
                "CurrentUserId": int(payload["user_id"]),
                "UserArea": payload.get("area"),
                "UserRole": payload.get("user_role"),
            }

        @self.router.get("/apps", response_model=Result[RagChatAppListVO])
        async def GetApps(payload: dict[str, Any] = Depends(verify_access_token)):
            required = [self._PERMISSIONS["app_read"]]
            if not await self._check_permission(int(payload["user_id"]), required):
                return deny("当前用户没有查看聊天应用权限")
            data = await self.RagChatService.GetApps(**scope_of(payload))
            return Result.success(data=data)

        @self.router.get("/apps/default", response_model=Result[RagChatAppVO | None])
        async def GetDefaultApp(payload: dict[str, Any] = Depends(verify_access_token)):
            required = [self._PERMISSIONS["app_read"]]
            if not await self._check_permission(int(payload["user_id"]), required):
                return deny("当前用户没有查看默认聊天应用权限")
            data = await self.RagChatService.GetDefaultApp(**scope_of(payload))
            return Result.success(data=data)

        @self.router.get("/datasets/my", response_model=Result[RagDatasetPageVO])
        async def GetMyDatasets(payload: dict[str, Any] = Depends(verify_access_token)):
            required = [self._PERMISSIONS["dataset_read"]]
            if not await self._check_permission(int(payload["user_id"]), required):
                return deny("当前用户没有查看知识库权限")
            data = await self.RagDatasetService.GetMyDatasets(**scope_of(payload))
            return Result.success(data=data)

        @self.router.get("/chat/parameters", response_model=Result[RagAppParametersVO])
        async def GetAppParameters(
            appId: int | None = Query(None, description="聊天应用ID"),
            payload: dict[str, Any] = Depends(verify_access_token),
        ):
            # Either permission is sufficient here (any-of check).
            required = [self._PERMISSIONS["chat_use"], self._PERMISSIONS["app_read"]]
            if not await self._check_permission(int(payload["user_id"]), required):
                return deny("当前用户没有查看聊天配置权限")
            data = await self.RagChatService.GetAppParameters(AppId=appId, **scope_of(payload))
            return Result.success(data=data)

        @self.router.post("/chat/messages")
        async def SendMessage(Body: RagChatSendMessageDTO, payload: dict[str, Any] = Depends(verify_access_token)):
            required = [self._PERMISSIONS["chat_use"]]
            if not await self._check_permission(int(payload["user_id"]), required):
                return deny("当前用户没有使用 RAG 对话权限")
            # Fall back to the stringified user id when the token has no username.
            user_name = payload.get("username") or str(payload.get("user_id"))
            stream = self.RagChatService.SendMessage(
                UserName=user_name,
                Query=Body.query,
                ConversationId=Body.conversationId,
                AppId=Body.appId,
                **scope_of(payload),
            )
            sse_headers = {"Cache-Control": "no-cache", "X-Accel-Buffering": "no", "Connection": "keep-alive"}
            return StreamingResponse(stream, media_type="text/event-stream", headers=sse_headers)

        @self.router.get("/chat/conversations", response_model=Result[RagConversationPageVO])
        async def GetConversations(
            appId: int | None = Query(None, description="聊天应用ID"),
            page: int = Query(1, ge=1),
            pageSize: int = Query(20, ge=1, le=100),
            payload: dict[str, Any] = Depends(verify_access_token),
        ):
            user_id = int(payload["user_id"])
            if not await self._check_permission(user_id, [self._PERMISSIONS["conversation_read"]]):
                return deny("当前用户没有查看聊天会话权限")
            data = await self.RagChatService.GetConversations(user_id, appId, page, pageSize)
            return Result.success(data=data)

        @self.router.get("/chat/conversations/{ConversationId}/messages", response_model=Result[RagMessagePageVO])
        async def GetConversationMessages(
            ConversationId: str,
            page: int = Query(1, ge=1),
            pageSize: int = Query(20, ge=1, le=100),
            payload: dict[str, Any] = Depends(verify_access_token),
        ):
            user_id = int(payload["user_id"])
            if not await self._check_permission(user_id, [self._PERMISSIONS["conversation_read"]]):
                return deny("当前用户没有查看聊天消息权限")
            data = await self.RagChatService.GetConversationMessages(user_id, ConversationId, page, pageSize)
            return Result.success(data=data)

        @self.router.patch("/chat/conversations/{ConversationId}", response_model=Result[RagConversationRenameVO])
        async def RenameConversation(
            ConversationId: str,
            Body: RagConversationRenameDTO,
            payload: dict[str, Any] = Depends(verify_access_token),
        ):
            user_id = int(payload["user_id"])
            if not await self._check_permission(user_id, [self._PERMISSIONS["conversation_update"]]):
                return deny("当前用户没有修改聊天会话权限")
            data = await self.RagChatService.RenameConversation(user_id, ConversationId, Body)
            return Result.success(data=data)

        @self.router.delete("/chat/conversations/{ConversationId}", response_model=Result[RagOperationResultVO])
        async def DeleteConversation(ConversationId: str, payload: dict[str, Any] = Depends(verify_access_token)):
            user_id = int(payload["user_id"])
            if not await self._check_permission(user_id, [self._PERMISSIONS["conversation_delete"]]):
                return deny("当前用户没有删除聊天会话权限")
            data = await self.RagChatService.DeleteConversation(user_id, ConversationId)
            return Result.success(data=data)

        @self.router.post("/chat/messages/{MessageId}/feedback", response_model=Result[RagOperationResultVO])
        async def UpdateFeedback(
            MessageId: str,
            Body: RagMessageFeedbackDTO,
            payload: dict[str, Any] = Depends(verify_access_token),
        ):
            user_id = int(payload["user_id"])
            if not await self._check_permission(user_id, [self._PERMISSIONS["message_feedback"]]):
                return deny("当前用户没有反馈聊天消息权限")
            data = await self.RagChatService.UpdateFeedback(user_id, MessageId, Body)
            return Result.success(data=data)

    async def _check_permission(self, user_id: int, permission_keys: list[str]) -> bool:
        """Return the permission service's any-of verdict for *user_id* and *permission_keys*."""
        verdict = await self.PermissionService.HasAnyPermission(UserId=user_id, PermissionKeys=permission_keys)
        return verdict
@@ -0,0 +1,15 @@
from pydantic import BaseModel, Field
class RagChatSendMessageDTO(BaseModel):
    # Request body for sending one chat message.
    # query: the user's question; required, at least one character.
    query: str = Field(..., min_length=1, description="用户问题")
    # conversationId: optional id of an existing conversation (defaults to None).
    conversationId: str | None = Field(None, description="会话ID")
    # appId: optional chat application id (defaults to None).
    appId: int | None = Field(None, description="聊天应用ID")
class RagConversationRenameDTO(BaseModel):
    # Request body for renaming a conversation.
    # name: the new conversation name; required, 1 to 500 characters.
    name: str = Field(..., min_length=1, max_length=500, description="新会话名称")
class RagMessageFeedbackDTO(BaseModel):
    # Request body for rating a message.
    # rating: optional free-form string; the description lists like/dislike/None,
    # but no validation restricts the value here.
    rating: str | None = Field(None, description="反馈: like/dislike/None")
@@ -0,0 +1,59 @@
from pydantic import BaseModel, Field
class RagChatAppVO(BaseModel):
    # View object describing one chat application.
    # NOTE(review): appId is a string here while the request DTO uses an int appId — confirm the conversion happens in the service.
    appId: str = Field(..., description="应用ID")
    appName: str = Field(..., description="应用名称")
    # description defaults to an empty string.
    description: str = Field("", description="应用描述")
    # isDefault defaults to False.
    isDefault: bool = Field(False, description="是否默认应用")
class RagChatAppListVO(BaseModel):
    # List wrapper for chat applications: the items plus a total count.
    data: list[RagChatAppVO] = Field(default_factory=list)
    total: int = Field(0)
class RagConversationItemVO(BaseModel):
    # View object describing one conversation.
    id: str = Field(..., description="会话ID")
    name: str = Field(..., description="会话名称")
    # introduction defaults to an empty string.
    introduction: str = Field("", description="会话简介")
    # createdAt / updatedAt are integer timestamps defaulting to 0
    # (unit not visible here — presumably epoch seconds; TODO confirm).
    createdAt: int = Field(0, description="创建时间戳")
    updatedAt: int = Field(0, description="更新时间戳")
class RagConversationPageVO(BaseModel):
    # One page of conversations.
    data: list[RagConversationItemVO] = Field(default_factory=list)
    # hasMore defaults to False; limit defaults to 20.
    hasMore: bool = Field(False)
    limit: int = Field(20)
class RagMessageItemVO(BaseModel):
    # View object for one message entry carrying both the query and its answer.
    id: str = Field(...)
    conversationId: str = Field(...)
    query: str = Field(...)
    answer: str = Field(...)
    # feedback and retrieverResources are optional, schema-free dict payloads.
    feedback: dict | None = Field(None)
    retrieverResources: list[dict] | None = Field(None)
    # createdAt is an integer timestamp defaulting to 0.
    createdAt: int = Field(0)
class RagMessagePageVO(BaseModel):
    # One page of messages.
    data: list[RagMessageItemVO] = Field(default_factory=list)
    # hasMore defaults to False; limit defaults to 20.
    hasMore: bool = Field(False)
    limit: int = Field(20)
class RagConversationRenameVO(BaseModel):
    # Result of a rename operation: a result marker (default "success") and the name.
    result: str = Field("success")
    name: str = Field(...)
class RagOperationResultVO(BaseModel):
    # Generic operation result; the marker defaults to "success".
    result: str = Field("success")
class RagAppParametersVO(BaseModel):
    # Chat application parameters exposed to the client.
    # openingStatement defaults to an empty string.
    openingStatement: str = Field("", description="开场白")
    suggestedQuestions: list[str] = Field(default_factory=list)
    userInputForm: list[dict] = Field(default_factory=list)
    # fileUpload defaults to a fresh dict with image upload disabled.
    fileUpload: dict = Field(default_factory=lambda: {"image": {"enabled": False}})
@@ -0,0 +1,18 @@
from pydantic import BaseModel, Field
class RagDatasetItemVO(BaseModel):
    # View object describing one dataset (knowledge base).
    id: int = Field(...)
    name: str = Field(...)
    # description and area default to empty strings.
    description: str = Field("")
    area: str = Field("")
    # Both flags default to False.
    isPublic: bool = Field(False)
    isDefault: bool = Field(False)
    # Counters default to 0.
    documentCount: int = Field(0)
    totalChunks: int = Field(0)
    # status defaults to 1 (meaning of the codes is not visible here).
    status: int = Field(1)
class RagDatasetPageVO(BaseModel):
    # List wrapper for datasets: the items plus a total count.
    data: list[RagDatasetItemVO] = Field(default_factory=list)
    total: int = Field(0)
@@ -8,6 +8,11 @@ from fastapi_modules.fastapi_leaudit.models.leauditCrossReviewTask import Leaudi
from fastapi_modules.fastapi_leaudit.models.leauditCrossReviewTaskDocument import LeauditCrossReviewTaskDocument
from fastapi_modules.fastapi_leaudit.models.leauditCrossReviewTaskMember import LeauditCrossReviewTaskMember
from fastapi_modules.fastapi_leaudit.models.leauditCrossReviewVote import LeauditCrossReviewVote
from fastapi_modules.fastapi_leaudit.models.leauditRagDataset import LeauditRagDataset
from fastapi_modules.fastapi_leaudit.models.leauditRagDocument import LeauditRagDocument
from fastapi_modules.fastapi_leaudit.models.leauditRagChatApp import LeauditRagChatApp
from fastapi_modules.fastapi_leaudit.models.leauditRagConversation import LeauditRagConversation
from fastapi_modules.fastapi_leaudit.models.leauditRagMessage import LeauditRagMessage
__all__ = [
"LeauditDocument",
@@ -18,4 +23,9 @@ __all__ = [
"LeauditCrossReviewTaskDocument",
"LeauditCrossReviewProposal",
"LeauditCrossReviewVote",
"LeauditRagDataset",
"LeauditRagDocument",
"LeauditRagChatApp",
"LeauditRagConversation",
"LeauditRagMessage",
]
@@ -0,0 +1,27 @@
from __future__ import annotations
from sqlalchemy import BigInteger, Boolean, Float, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from fastapi_common.fastapi_common_web.models import BaseModel
class LeauditRagChatApp(BaseModel):
    """ORM mapping for the ``rag_chat_app`` table (chat application configuration)."""

    __tablename__ = "rag_chat_app"

    # Auto-incrementing primary key stored in column "id".
    Id: Mapped[int] = mapped_column("id", BigInteger, primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String(255))
    description: Mapped[str] = mapped_column(Text, default="")
    area: Mapped[str] = mapped_column(String(50), default="")
    # Nullable dataset id; no ForeignKey constraint is declared on this column.
    datasetId: Mapped[int | None] = mapped_column("dataset_id", BigInteger)
    # Generation settings; empty strings are the column defaults.
    systemPrompt: Mapped[str] = mapped_column("system_prompt", Text, default="")
    llmModel: Mapped[str] = mapped_column("llm_model", String(100), default="")
    temperature: Mapped[float] = mapped_column(Float, default=0.3)
    maxTokens: Mapped[int] = mapped_column("max_tokens", Integer, default=2048)
    openingStatement: Mapped[str] = mapped_column("opening_statement", Text, default="")
    # Stored as text; the default "[]" suggests a JSON-encoded list — TODO confirm against the reader.
    suggestedQuestions: Mapped[str] = mapped_column("suggested_questions", Text, default="[]")
    isDefault: Mapped[bool] = mapped_column("is_default", Boolean, default=False)
    sortOrder: Mapped[int] = mapped_column("sort_order", Integer, default=0)
    # Defaults to 1 (meaning of the codes is not visible here).
    status: Mapped[int] = mapped_column(Integer, default=1)
    createdBy: Mapped[int | None] = mapped_column("created_by", BigInteger)
    updatedBy: Mapped[int | None] = mapped_column("updated_by", BigInteger)
@@ -0,0 +1,17 @@
from __future__ import annotations
from sqlalchemy import BigInteger, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from fastapi_common.fastapi_common_web.models import BaseModel
class LeauditRagConversation(BaseModel):
    """ORM mapping for the ``rag_conversation`` table."""

    __tablename__ = "rag_conversation"

    # Auto-incrementing primary key stored in column "id".
    Id: Mapped[int] = mapped_column("id", BigInteger, primary_key=True, autoincrement=True)
    # Unique external conversation identifier (up to 100 characters).
    conversationId: Mapped[str] = mapped_column("conversation_id", String(100), unique=True)
    userId: Mapped[int] = mapped_column("user_id", BigInteger)
    # Nullable chat application id; no ForeignKey constraint is declared.
    appId: Mapped[int | None] = mapped_column("app_id", BigInteger)
    # Display name; defaults to the literal "新对话".
    name: Mapped[str] = mapped_column(String(500), default="新对话")
    introduction: Mapped[str] = mapped_column(Text, default="")
@@ -0,0 +1,30 @@
from __future__ import annotations
from sqlalchemy import BigInteger, Boolean, Integer, String, Text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from fastapi_common.fastapi_common_web.models import BaseModel
class LeauditRagDataset(BaseModel):
    """ORM mapping for the ``rag_dataset`` table (knowledge-base datasets)."""

    __tablename__ = "rag_dataset"

    # Auto-incrementing primary key stored in column "id".
    Id: Mapped[int] = mapped_column("id", BigInteger, primary_key=True, autoincrement=True)
    name: Mapped[str] = mapped_column(String(255), comment="知识库名称")
    description: Mapped[str] = mapped_column(Text, default="", comment="知识库描述")
    area: Mapped[str] = mapped_column(String(50), default="", comment="地区")
    isPublic: Mapped[bool] = mapped_column("is_public", Boolean, default=False)
    isDefault: Mapped[bool] = mapped_column("is_default", Boolean, default=False)
    # Unique collection name (presumably the vector-store collection — TODO confirm).
    collectionName: Mapped[str] = mapped_column("collection_name", String(100), unique=True)
    # Embedding and chunking settings with their column defaults.
    embeddingModel: Mapped[str] = mapped_column("embedding_model", String(100), default="text-embedding-v4")
    embeddingDim: Mapped[int] = mapped_column("embedding_dim", Integer, default=1024)
    chunkMaxSize: Mapped[int] = mapped_column("chunk_max_size", Integer, default=800)
    chunkMinSize: Mapped[int] = mapped_column("chunk_min_size", Integer, default=20)
    # Counters default to 0.
    documentCount: Mapped[int] = mapped_column("document_count", Integer, default=0)
    totalChunks: Mapped[int] = mapped_column("total_chunks", Integer, default=0)
    # PostgreSQL JSONB column; ``default=dict`` gives each new row its own empty dict.
    retrievalModel: Mapped[dict] = mapped_column("retrieval_model", JSONB, default=dict)
    sortOrder: Mapped[int] = mapped_column("sort_order", Integer, default=0)
    # Defaults to 1 (meaning of the codes is not visible here).
    status: Mapped[int] = mapped_column(Integer, default=1)
    createdBy: Mapped[int | None] = mapped_column("created_by", BigInteger)
    updatedBy: Mapped[int | None] = mapped_column("updated_by", BigInteger)
@@ -0,0 +1,24 @@
from __future__ import annotations
from sqlalchemy import BigInteger, Boolean, Integer, String, Text
from sqlalchemy.orm import Mapped, mapped_column
from fastapi_common.fastapi_common_web.models import BaseModel
class LeauditRagDocument(BaseModel):
    """ORM mapping for the ``rag_document`` table (documents belonging to a dataset)."""

    __tablename__ = "rag_document"

    # Auto-incrementing primary key stored in column "id".
    Id: Mapped[int] = mapped_column("id", BigInteger, primary_key=True, autoincrement=True)
    # Owning dataset id; no ForeignKey constraint is declared.
    datasetId: Mapped[int] = mapped_column("dataset_id", BigInteger)
    filename: Mapped[str] = mapped_column(String(500))
    originalName: Mapped[str] = mapped_column("original_name", String(500))
    # Storage path of the file (column "minio_path").
    minioPath: Mapped[str] = mapped_column("minio_path", String(1000))
    fileType: Mapped[str] = mapped_column("file_type", String(20))
    fileSize: Mapped[int] = mapped_column("file_size", BigInteger, default=0)
    chunkCount: Mapped[int] = mapped_column("chunk_count", Integer, default=0)
    # Indexing state; defaults to "pending". Other values are not visible here.
    indexingStatus: Mapped[str] = mapped_column("indexing_status", String(20), default="pending")
    # Nullable error text for indexing.
    indexingError: Mapped[str | None] = mapped_column("indexing_error", Text)
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)
    hitCount: Mapped[int] = mapped_column("hit_count", Integer, default=0)
    createdBy: Mapped[int | None] = mapped_column("created_by", BigInteger)
@@ -0,0 +1,20 @@
from __future__ import annotations
from sqlalchemy import BigInteger, String, Text
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy.orm import Mapped, mapped_column
from fastapi_common.fastapi_common_web.models import BaseModel
class LeauditRagMessage(BaseModel):
    """ORM mapping for the ``rag_message`` table (individual chat messages)."""

    __tablename__ = "rag_message"

    # Auto-incrementing primary key stored in column "id".
    Id: Mapped[int] = mapped_column("id", BigInteger, primary_key=True, autoincrement=True)
    # Unique external message identifier (up to 100 characters).
    messageId: Mapped[str] = mapped_column("message_id", String(100), unique=True)
    # Conversation identifier; no ForeignKey constraint is declared.
    conversationId: Mapped[str] = mapped_column("conversation_id", String(100))
    role: Mapped[str] = mapped_column(String(20))
    content: Mapped[str] = mapped_column(Text, default="")
    # PostgreSQL JSONB columns; callable defaults give each row its own container.
    sources: Mapped[list] = mapped_column(JSONB, default=list)
    # Attribute is named metadataJson because ``metadata`` is reserved on SQLAlchemy
    # declarative classes; the database column is still called "metadata".
    metadataJson: Mapped[dict] = mapped_column("metadata", JSONB, default=dict)
    # Nullable feedback marker, at most 20 characters.
    feedback: Mapped[str | None] = mapped_column(String(20))
@@ -0,0 +1 @@
"""RAG 聊天内核兼容层。"""
@@ -0,0 +1,40 @@
from __future__ import annotations
from typing import Any
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
# Module-level cache for the Chroma client; stays None until init_chroma() creates it.
_instance: Any | None = None
def init_chroma() -> Any:
    """Create the module-wide Chroma client on first use and return it.

    With ``CHROMA_HOST`` configured an HTTP client is built (with token
    authentication settings when ``CHROMA_TOKEN`` is non-empty); otherwise a
    persistent on-disk client rooted at ``CHROMA_PERSIST_DIR`` is used. Later
    calls return the cached client unchanged.
    """
    global _instance
    if _instance is None:
        # Imported lazily so the module loads even when chromadb is not installed.
        import chromadb
        import chromadb.config

        remote_host = RAG_CONFIG["CHROMA_HOST"]
        if not remote_host:
            _instance = chromadb.PersistentClient(path=RAG_CONFIG["CHROMA_PERSIST_DIR"])
        else:
            auth_token = RAG_CONFIG.get("CHROMA_TOKEN", "")
            auth_header = RAG_CONFIG.get("CHROMA_AUTH_HEADER", "X-Chroma-Token")
            if auth_token:
                client_settings = chromadb.config.Settings(
                    chroma_client_auth_provider="chromadb.auth.token_authn.TokenAuthClientProvider",
                    chroma_client_auth_credentials=auth_token,
                    chroma_auth_token_transport_header=auth_header,
                )
            else:
                client_settings = chromadb.config.Settings()
            _instance = chromadb.HttpClient(
                host=remote_host,
                port=RAG_CONFIG["CHROMA_PORT"],
                settings=client_settings,
            )
    return _instance
def get_chroma() -> Any:
    """Return the cached Chroma client, creating it through ``init_chroma`` when absent."""
    return _instance if _instance is not None else init_chroma()
@@ -0,0 +1,60 @@
from __future__ import annotations
from fastapi_admin.config._settings import llm
def _get_str(name: str, default: str = "") -> str:
    """Return environment variable *name*, or *default* when it is not set."""
    import os

    value = os.environ.get(name)
    return default if value is None else value
def _get_bool(name: str, default: bool = False) -> bool:
    """Read environment variable *name* as a boolean flag.

    Returns *default* when the variable is not set. A set value is compared
    case-insensitively, ignoring surrounding whitespace, against the common
    truthy spellings ``1``, ``true``, ``yes`` and ``on``; anything else
    (including an empty string) is False. Previously only the exact text
    ``true`` was accepted, so ``FLAG=1`` or ``FLAG=" true "`` silently
    disabled a feature.
    """
    import os

    raw = os.getenv(name)
    if raw is None:
        return bool(default)
    return raw.strip().lower() in {"1", "true", "yes", "on"}
def _get_int(name: str, default: int) -> int:
    """Read environment variable *name* as an int, returning *default* when unset or unparsable."""
    import os

    raw = os.environ.get(name)
    if raw is None:
        raw = str(default)
    try:
        parsed = int(raw)
    except ValueError:
        parsed = default
    return parsed
def _get_float(name: str, default: float) -> float:
    """Read environment variable *name* as a float, returning *default* when unset or unparsable."""
    import os

    raw = os.environ.get(name)
    if raw is None:
        raw = str(default)
    try:
        parsed = float(raw)
    except ValueError:
        parsed = default
    return parsed
# RAG engine settings, resolved once at import time from environment variables.
# Later changes to the environment are not picked up by this dict.
RAG_CONFIG = {
    "USE_SELF_HOSTED": True,
    # Chroma vector store: a non-empty CHROMA_HOST selects the HTTP client,
    # otherwise the on-disk directory is used (see init_chroma in the chroma module).
    "CHROMA_PERSIST_DIR": _get_str("RAG_CHROMA_PERSIST_DIR", ".chromadb_rag"),
    "CHROMA_HOST": _get_str("RAG_CHROMA_HOST", ""),
    "CHROMA_PORT": _get_int("RAG_CHROMA_PORT", 8010),
    "CHROMA_TOKEN": _get_str("RAG_CHROMA_TOKEN", ""),
    "CHROMA_AUTH_HEADER": _get_str("RAG_CHROMA_AUTH_HEADER", "X-Chroma-Token"),
    # Embedding service: RAG_* variables win, GRAPH_RAG_* variables are the fallback.
    "EMBED_URL": _get_str("RAG_EMBED_URL", _get_str("GRAPH_RAG_EMBED_URL", "")),
    "EMBED_KEY": _get_str("RAG_EMBED_KEY", _get_str("GRAPH_RAG_EMBED_KEY", "")),
    "EMBED_MODEL": _get_str("RAG_EMBED_MODEL", _get_str("GRAPH_RAG_EMBED_MODEL", "")),
    "EMBED_DIM": _get_int("RAG_EMBED_DIM", 1024),
    "EMBED_BATCH_SIZE": _get_int("RAG_EMBED_BATCH_SIZE", 10),
    # Reranker service: same RAG_* then GRAPH_RAG_* fallback chain.
    "RERANKER_URL": _get_str("RAG_RERANKER_URL", _get_str("GRAPH_RAG_RERANKER_URL", "")),
    "RERANKER_KEY": _get_str("RAG_RERANKER_KEY", _get_str("GRAPH_RAG_RERANKER_KEY", "")),
    "RERANKER_MODEL": _get_str("RAG_RERANKER_MODEL", _get_str("GRAPH_RAG_RERANKER_MODEL", "")),
    # LLM endpoint: the environment overrides the values from the shared llm settings object.
    "LLM_BASE_URL": _get_str("LLM_BASE_URL", llm.LLM_BASE_URL),
    "LLM_MODEL": _get_str("LLM_MODEL", llm.LLM_MODEL),
    "LLM_API_KEY": _get_str("LLM_API_KEY", llm.LLM_API_KEY),
    # Retrieval tuning knobs.
    "VECTOR_TOP_K": _get_int("RAG_VECTOR_TOP_K", 15),
    "RERANK_TOP_K": _get_int("RAG_RERANK_TOP_K", 5),
    "BM25_TOP_K": _get_int("RAG_BM25_TOP_K", 15),
    "RRF_K": _get_int("RAG_RRF_K", 60),
    # Generation defaults; LLM_TIMEOUT is passed to httpx as the client timeout.
    "LLM_TEMPERATURE": _get_float("RAG_LLM_TEMPERATURE", 0.3),
    "LLM_MAX_TOKENS": _get_int("RAG_LLM_MAX_TOKENS", 2048),
    "LLM_TIMEOUT": _get_int("RAG_LLM_TIMEOUT", 120),
    # Feature switches.
    "QUERY_REWRITING": _get_bool("RAG_QUERY_REWRITING", False),
    "HYBRID_SEARCH": _get_bool("RAG_HYBRID_SEARCH", True),
    "RERANKING": _get_bool("RAG_RERANKING", True),
}
@@ -0,0 +1,144 @@
from __future__ import annotations
import json
import time
import uuid
from typing import AsyncGenerator
import httpx
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
# System prompt used by generate_stream when the caller passes an empty system_prompt.
DEFAULT_SYSTEM_PROMPT = """你是烟草行业智慧法务小助手,专注于烟草专卖法规、合同管理、行政处罚等相关法律法规。\n\n回答要求:\n- 先用一句话直接回答,再展开详细说明\n- 多个要点用编号列表\n- 关键法条和数字用 **加粗**\n- 分类信息用表格\n- 层级结构用缩进子列表\n- 不要加标题,直接输出正文"""
async def generate_stream(
    query: str,
    context_chunks: list[dict],
    conversation_id: str,
    message_id: str,
    system_prompt: str = "",
    model: str = "",
    temperature: float | None = None,
    max_tokens: int | None = None,
    dataset_name: str = "",
) -> AsyncGenerator[str, None]:
    """Stream an LLM answer for *query* as Server-Sent Events frames.

    Builds a two-message chat request (system prompt plus the user question,
    prefixed by the retrieved context when any is given), posts it to the
    ``/chat/completions`` endpoint under ``LLM_BASE_URL`` with ``stream=True``
    and re-emits every content delta as a ``message`` event.

    Args:
        query: The user's question.
        context_chunks: Retrieved chunks; the keys read are ``source``,
            ``text``, ``id`` and ``score``.
        conversation_id: Echoed in every emitted event.
        message_id: Echoed in every emitted event.
        system_prompt: Overrides ``DEFAULT_SYSTEM_PROMPT`` when non-empty.
        model: Overrides ``RAG_CONFIG["LLM_MODEL"]`` when non-empty.
        temperature: Overrides the configured temperature when not None.
        max_tokens: Overrides the configured limit when truthy.
        dataset_name: Copied into every retriever resource entry.

    Yields:
        SSE frames produced by ``_sse_line``: zero or more ``message`` events,
        then either one ``error`` event (after which the stream stops) or one
        ``message_end`` event carrying token usage and retriever resources.
    """
    task_id = str(uuid.uuid4())
    created_at = int(time.time())
    _model = model or RAG_CONFIG["LLM_MODEL"]
    _temp = temperature if temperature is not None else RAG_CONFIG["LLM_TEMPERATURE"]
    _max_tok = max_tokens or RAG_CONFIG["LLM_MAX_TOKENS"]
    _prompt = system_prompt or DEFAULT_SYSTEM_PROMPT

    # Cap the amount of retrieved text placed in the prompt.
    max_context_chars = 8000
    if context_chunks:
        parts: list[str] = []
        total_len = 0
        for chunk in context_chunks:
            # Real newlines: the previous "\\n" put a literal backslash-n into the prompt.
            part = f"[来源: {chunk.get('source', '未知')}]\n{chunk.get('text', '')}"
            if total_len + len(part) > max_context_chars:
                break
            parts.append(part)
            total_len += len(part)
        context_text = "\n\n---\n\n".join(parts)
        user_content = f"知识库内容:\n{context_text}\n\n用户问题: {query}"
    else:
        user_content = query

    messages = [
        {"role": "system", "content": _prompt},
        {"role": "user", "content": user_content},
    ]
    total_tokens = 0
    try:
        async with httpx.AsyncClient(timeout=RAG_CONFIG["LLM_TIMEOUT"]) as client:
            async with client.stream(
                "POST",
                f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}" + "/chat/completions",
                json={
                    "model": _model,
                    "messages": messages,
                    "temperature": _temp,
                    "max_tokens": _max_tok,
                    "stream": True,
                    "chat_template_kwargs": {"enable_thinking": False},
                },
                headers={
                    "Content-Type": "application/json",
                    "Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}",
                },
            ) as resp:
                resp.raise_for_status()
                async for line in resp.aiter_lines():
                    if not line.startswith("data: "):
                        continue
                    payload = line[6:].strip()
                    if payload == "[DONE]":
                        break
                    event = json.loads(payload)
                    # Usage-only chunks may carry an empty ``choices`` list and
                    # role-only chunks a null delta/content; none of them may
                    # abort the stream.
                    choices = event.get("choices") or [{}]
                    delta = choices[0].get("delta") or {}
                    text = delta.get("content") or ""
                    if text:
                        yield _sse_line(
                            {
                                "event": "message",
                                "task_id": task_id,
                                "message_id": message_id,
                                "conversation_id": conversation_id,
                                "answer": text,
                                "created_at": created_at,
                            }
                        )
                    usage = event.get("usage")
                    if usage:
                        total_tokens = usage.get("total_tokens", total_tokens)
    except Exception as exc:
        # Deliberate catch-all at the stream boundary: the failure is reported
        # to the client as an SSE error event instead of breaking the response.
        # NOTE(review): str(exc) is sent to the client and may expose upstream details.
        yield _sse_line(
            {
                "event": "error",
                "task_id": task_id,
                "message_id": message_id,
                "code": "llm_error",
                "message": str(exc),
            }
        )
        return

    retriever_resources = [
        {
            "position": i + 1,
            "dataset_id": "",
            "dataset_name": dataset_name,
            "document_id": "",
            "document_name": chunk.get("source", ""),
            "data_source_type": "upload_file",
            "segment_id": chunk.get("id", ""),
            "retriever_from": "rag",
            # ``or`` guards against an explicit None score/text in a chunk.
            "score": round(chunk.get("score") or 0.0, 4),
            "hit_count": 0,
            "word_count": len(chunk.get("text") or ""),
            "segment_position": i + 1,
            "index_node_hash": "",
            "content": (chunk.get("text") or "")[:500],
            "page": None,
        }
        for i, chunk in enumerate(context_chunks)
    ]
    yield _sse_line(
        {
            "event": "message_end",
            "task_id": task_id,
            "message_id": message_id,
            "conversation_id": conversation_id,
            "metadata": {
                "usage": {"total_tokens": total_tokens},
                "retriever_resources": retriever_resources,
            },
        }
    )
def _sse_line(data: dict) -> str:
    """Serialize *data* as one Server-Sent Events ``data:`` frame.

    The frame ends with a real blank line (two newline characters), which is
    what delimits events in an event stream. The previous ``"\\n\\n"`` emitted
    literal backslash characters, so frames were never terminated and the JSON
    payload was followed by stray text.
    """
    return f"data: {json.dumps(data, ensure_ascii=False)}\n\n"
@@ -0,0 +1,39 @@
from __future__ import annotations
import json
import httpx
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
async def generate_followups(query: str, answer: str) -> list[str]:
    """Ask the LLM for up to three short follow-up questions.

    Posts a single non-streaming chat completion request built from *query*
    and the first 1200 characters of *answer*. The reply is parsed as a JSON
    array when possible; otherwise each non-empty line is used after list
    markers (dashes, digits, dots, spaces, tabs) are stripped from both ends.

    Returns:
        At most three non-empty question strings.

    Raises:
        httpx.HTTPStatusError: When the endpoint answers with an error status.
    """
    prompt = (
        "基于用户问题和已有回答,生成 3 个适合继续追问的简短问题。"
        # Real newlines: the previous "\\n" put a literal backslash-n into the prompt.
        "仅返回 JSON 数组字符串,例如 [\"问题1\", \"问题2\"]。\n"
        f"用户问题: {query}\n回答: {answer[:1200]}"
    )
    async with httpx.AsyncClient(timeout=30.0) as client:
        resp = await client.post(
            f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}" + "/chat/completions",
            json={
                "model": RAG_CONFIG["LLM_MODEL"],
                "messages": [{"role": "user", "content": prompt}],
                "temperature": 0.5,
                "max_tokens": 256,
                "chat_template_kwargs": {"enable_thinking": False},
            },
            headers={
                "Content-Type": "application/json",
                "Authorization": f"Bearer {RAG_CONFIG['LLM_API_KEY']}",
            },
        )
        resp.raise_for_status()
        # A null content must not crash the line-based fallback below.
        content = resp.json()["choices"][0]["message"]["content"] or ""

    try:
        parsed = json.loads(content)
    except (json.JSONDecodeError, TypeError):
        # Not valid JSON (or not text at all): fall through to line parsing.
        parsed = None
    if isinstance(parsed, list):
        return [str(item).strip() for item in parsed if str(item).strip()][:3]

    # Drop lines that become empty once the list markers are stripped.
    candidates = (line.strip("- 1234567890.\t") for line in content.splitlines())
    return [candidate for candidate in candidates if candidate][:3]
@@ -14,6 +14,8 @@ from fastapi_modules.fastapi_leaudit.services.rbacAdminService import IRbacAdmin
from fastapi_modules.fastapi_leaudit.services.rbacService import IRbacService
from fastapi_modules.fastapi_leaudit.services.ruleConfigService import IRuleConfigService
from fastapi_modules.fastapi_leaudit.services.ruleService import IRuleService
from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatasetService
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
__all__ = [
"IAuditService",
@@ -30,4 +32,6 @@ __all__ = [
"IRbacService",
"IRuleConfigService",
"IRuleService",
"IRagDatasetService",
"IRagChatService",
]
@@ -618,6 +618,15 @@ class DocumentServiceImpl(IDocumentService):
currentUser = await self._getCurrentUserContext(CurrentUserId)
documentColumns = await self._loadDocumentColumns(Session)
detail = await self._getDocumentDetail(Session, DocumentId, CurrentUserId, currentUser, documentColumns)
if not detail and await self._hasCrossReviewDocumentAccess(Session, DocumentId, CurrentUserId):
detail = await self._getDocumentDetail(
Session,
DocumentId,
CurrentUserId,
currentUser,
documentColumns,
BypassScopeCheck=True,
)
if not detail:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在或无权访问")
@@ -679,6 +688,15 @@ class DocumentServiceImpl(IDocumentService):
currentUser = await self._getCurrentUserContext(CurrentUserId)
documentColumns = await self._loadDocumentColumns(Session)
detail = await self._getDocumentDetail(Session, documentId, CurrentUserId, currentUser, documentColumns)
if not detail and await self._hasCrossReviewDocumentAccess(Session, documentId, CurrentUserId):
detail = await self._getDocumentDetail(
Session,
documentId,
CurrentUserId,
currentUser,
documentColumns,
BypassScopeCheck=True,
)
if not detail:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在或无权访问")
@@ -742,6 +760,15 @@ class DocumentServiceImpl(IDocumentService):
currentUser = await self._getCurrentUserContext(CurrentUserId)
documentColumns = await self._loadDocumentColumns(Session)
detail = await self._getDocumentDetail(Session, DocumentId, CurrentUserId, currentUser, documentColumns)
if not detail and await self._hasCrossReviewDocumentAccess(Session, DocumentId, CurrentUserId):
detail = await self._getDocumentDetail(
Session,
DocumentId,
CurrentUserId,
currentUser,
documentColumns,
BypassScopeCheck=True,
)
if not detail:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在或无权访问")
@@ -1601,19 +1628,21 @@ class DocumentServiceImpl(IDocumentService):
CurrentUserId: int,
CurrentUser: dict[str, Any],
DocumentColumns: set[str],
BypassScopeCheck: bool = False,
) -> DocumentDetailVO | None:
"""查询单文档详情,并附带历史版本。"""
params: dict[str, object] = {"id": DocumentId}
filters = ["d.id = :id", "d.deleted_at IS NULL", "f.is_active = true", "f.file_role = 'primary'"]
filters.extend(
self._buildDocumentScopeFilters(
CurrentUserId=CurrentUserId,
CurrentUser=CurrentUser,
Params=params,
DocumentAlias="d",
FileAlias="f",
if not BypassScopeCheck:
filters.extend(
self._buildDocumentScopeFilters(
CurrentUserId=CurrentUserId,
CurrentUser=CurrentUser,
Params=params,
DocumentAlias="d",
FileAlias="f",
)
)
)
whereClause = " AND ".join(filters)
groupIdSelectExpr = "d.group_id" if "group_id" in DocumentColumns else "NULL::bigint"
@@ -1832,6 +1861,38 @@ class DocumentServiceImpl(IDocumentService):
attachments=attachments,
)
async def _hasCrossReviewDocumentAccess(self, Session, DocumentId: int, CurrentUserId: int) -> bool:
    """Tell whether the user may access the document as a cross-review task member.

    Returns False as soon as one of the three cross-review tables is missing;
    otherwise looks for a non-deleted task that contains the document and has
    the user as a non-deleted member.
    """
    requiredTables = (
        "leaudit_cross_review_task_documents",
        "leaudit_cross_review_task_members",
        "leaudit_cross_review_tasks",
    )
    for tableName in requiredTables:
        if not await self._tableExists(Session, tableName):
            return False

    statement = text(
        """
SELECT 1
FROM leaudit_cross_review_task_documents td
JOIN leaudit_cross_review_task_members tm
ON tm.task_id = td.task_id
JOIN leaudit_cross_review_tasks t
ON t.id = td.task_id
WHERE td.document_id = :document_id
AND tm.user_id = :user_id
AND td.delete_time IS NULL
AND tm.delete_time IS NULL
AND t.delete_time IS NULL
LIMIT 1
"""
    )
    bindings = {"document_id": DocumentId, "user_id": CurrentUserId}
    result = await Session.execute(statement, bindings)
    matched = result.first()
    return bool(matched)
def _buildDocumentScopeFilters(
self,
CurrentUserId: int,
@@ -0,0 +1,589 @@
from __future__ import annotations
import json
import uuid
from typing import AsyncGenerator
from sqlalchemy import text
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import (
RagConversationRenameDTO,
RagMessageFeedbackDTO,
)
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
RagAppParametersVO,
RagChatAppListVO,
RagChatAppVO,
RagConversationItemVO,
RagConversationPageVO,
RagConversationRenameVO,
RagMessageItemVO,
RagMessagePageVO,
RagOperationResultVO,
)
from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream
from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
class RagChatServiceImpl(IRagChatService):
async def GetApps(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppListVO:
    """Return the apps loaded by ``_load_apps`` (``only_default=False``) with their count."""
    loadedApps = await self._load_apps(UserArea, UserRole, only_default=False)
    appCount = len(loadedApps)
    return RagChatAppListVO(data=loadedApps, total=appCount)
async def GetDefaultApp(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppVO | None:
    """Return the first default app, else the first app of any kind, else None.

    The full app list is only loaded when the ``only_default=True`` lookup
    comes back empty.
    """
    defaultApps = await self._load_apps(UserArea, UserRole, only_default=True)
    if not defaultApps:
        fallbackApps = await self._load_apps(UserArea, UserRole, only_default=False)
        if not fallbackApps:
            return None
        return fallbackApps[0]
    return defaultApps[0]
async def SendMessage(
    self,
    CurrentUserId: int,
    UserName: str,
    UserArea: str | None,
    UserRole: str | None,
    Query: str,
    ConversationId: str | None,
    AppId: int | None,
) -> AsyncGenerator[bytes, None]:
    """Answer ``Query`` as a stream of server-sent-event frames.

    Steps: resolve the chat app, reuse or create the conversation, persist
    the user message, retrieve context, relay the generator output, append
    follow-up suggestions to the final ``message_end`` event, then persist
    the assistant message and touch the conversation.

    ``UserName`` is part of the interface but is not read here.

    Yields:
        UTF-8 encoded SSE frames. The chunk carrying ``message_end`` is held
        back until the follow-up questions are known and is then re-emitted
        line by line with ``metadata.suggested_questions`` filled in.

    Raises:
        LeauditException: 400 when the query is blank, 404 when no chat app
            can be resolved; errors raised by ``_ensure_conversation``
            propagate unchanged.
    """
    if not Query.strip():
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "问题不能为空")
    app = await self._resolve_app(AppId, UserArea, UserRole)
    if not app:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用")
    conversationId = await self._ensure_conversation(CurrentUserId, ConversationId, app["id"])
    # messageId identifies the assistant reply; the user row gets its own id.
    messageId = str(uuid.uuid4())

    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(
                text(
                    """
                    INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
                    VALUES (:message_id, :conversation_id, 'user', :content, '[]'::jsonb, '{}'::jsonb)
                    """
                ),
                {
                    "message_id": str(uuid.uuid4()),
                    "conversation_id": conversationId,
                    "content": Query,
                },
            )

    context_chunks, dataset_name = await self._retrieve_context(app.get("dataset_id"), Query)
    answer_parts: list[str] = []
    held_message_end: bytes | None = None
    async for chunk in generate_stream(
        query=Query,
        context_chunks=context_chunks,
        conversation_id=conversationId,
        message_id=messageId,
        system_prompt=app.get("system_prompt") or "",
        model=app.get("llm_model") or "",
        temperature=app.get("temperature"),
        max_tokens=app.get("max_tokens"),
        dataset_name=dataset_name,
    ):
        chunk_bytes = chunk.encode("utf-8")
        for line in chunk.strip().split("\n"):
            if not line.startswith("data: "):
                continue
            try:
                data = json.loads(line[6:])
            except json.JSONDecodeError:
                # A non-JSON data line must not abort the stream; the raw
                # chunk is still relayed below.
                continue
            if not isinstance(data, dict):
                continue
            event = data.get("event")
            if event == "message":
                answer_parts.append(data.get("answer") or "")
            elif event == "message_end":
                held_message_end = chunk_bytes
        # Once message_end has been seen nothing more is relayed directly.
        if held_message_end is None:
            yield chunk_bytes
    collected_answer = "".join(answer_parts)

    # Follow-up generation is best effort: any failure means "no suggestions".
    try:
        followups: list[str] = await generate_followups(Query, collected_answer)
    except Exception:
        followups = []

    if held_message_end:
        # Build all frames before yielding, so the fallback below can never
        # duplicate lines that were already sent to the client.
        try:
            frames: list[bytes] = []
            for line in held_message_end.decode("utf-8").strip().split("\n"):
                if not line.startswith("data: "):
                    continue
                end_data = json.loads(line[6:])
                if end_data.get("event") == "message_end":
                    end_data.setdefault("metadata", {})["suggested_questions"] = followups
                # An SSE event ends with a real blank line, i.e. two newline
                # characters, not the literal text "\n\n".
                frames.append(f"data: {json.dumps(end_data, ensure_ascii=False)}\n\n".encode("utf-8"))
        except Exception:
            frames = [held_message_end]
        for frame in frames:
            yield frame

    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(
                text(
                    """
                    INSERT INTO rag_message (message_id, conversation_id, role, content, sources, metadata)
                    VALUES (:message_id, :conversation_id, 'assistant', :content, CAST(:sources AS jsonb), CAST(:metadata AS jsonb))
                    """
                ),
                {
                    "message_id": messageId,
                    "conversation_id": conversationId,
                    "content": collected_answer,
                    "sources": json.dumps(self._build_sources(context_chunks, dataset_name), ensure_ascii=False),
                    "metadata": json.dumps({"suggested_questions": followups}, ensure_ascii=False),
                },
            )
            await session.execute(
                text(
                    "UPDATE rag_conversation SET updated_at = NOW() WHERE conversation_id = :conversation_id"
                ),
                {"conversation_id": conversationId},
            )
async def GetConversations(self, CurrentUserId: int, AppId: int | None, Page: int, PageSize: int) -> RagConversationPageVO:
    """Return one page of the caller's non-deleted conversations, newest first.

    One row more than ``PageSize`` is requested so that ``hasMore`` can be
    derived without a separate count query. ``AppId`` narrows the list to a
    single app when given. Timestamps are reported as epoch seconds, ``0``
    when missing.
    """

    def _epoch(moment) -> int:
        return int(moment.timestamp()) if moment else 0

    statement = text(
        """
        SELECT conversation_id, name, introduction, created_at, updated_at
        FROM rag_conversation
        WHERE user_id = :user_id
        AND deleted_at IS NULL
        AND (:app_id IS NULL OR app_id = :app_id)
        ORDER BY updated_at DESC
        OFFSET :offset LIMIT :limit
        """
    )
    bind_values = {
        "user_id": CurrentUserId,
        "app_id": AppId,
        "offset": max(Page - 1, 0) * PageSize,
        "limit": PageSize + 1,
    }
    async with GetAsyncSession() as session:
        result = await session.execute(statement, bind_values)
        fetched = result.mappings().all()

    page_rows = fetched[:PageSize]
    entries = [
        RagConversationItemVO(
            id=record["conversation_id"],
            name=record["name"],
            introduction=record.get("introduction") or "",
            createdAt=_epoch(record.get("created_at")),
            updatedAt=_epoch(record.get("updated_at")),
        )
        for record in page_rows
    ]
    return RagConversationPageVO(
        data=entries,
        hasMore=len(fetched) > PageSize,
        limit=PageSize,
    )
async def GetConversationMessages(self, CurrentUserId: int, ConversationId: str, Page: int, PageSize: int) -> RagMessagePageVO:
    """Return one page of a conversation as question/answer pairs.

    Rows are paged in ``created_at`` order and each ``user`` row is merged
    with the ``assistant`` row that directly follows it. A question without
    an answer is returned with an empty ``answer``. Assistant rows that do
    not follow a question are skipped.

    One row more than ``PageSize`` is fetched. It drives ``hasMore`` and it
    is also used as the answer of a question that happens to be the last row
    of the page. Without that look-ahead the pair would be split across two
    pages: the question would show no answer, and the answer, being the first
    row of the next page, would be dropped.

    Raises:
        LeauditException: propagated from ``_ensure_conversation_owner`` when
            the conversation is missing or belongs to another user.
    """
    await self._ensure_conversation_owner(CurrentUserId, ConversationId)
    async with GetAsyncSession() as session:
        rows = (
            await session.execute(
                text(
                    """
                    SELECT message_id, role, content, sources, feedback, created_at
                    FROM rag_message
                    WHERE conversation_id = :conversation_id
                    ORDER BY created_at ASC
                    OFFSET :offset LIMIT :limit
                    """
                ),
                {
                    "conversation_id": ConversationId,
                    "offset": max(Page - 1, 0) * PageSize,
                    "limit": PageSize + 1,
                },
            )
        ).mappings().all()
    has_more = len(rows) > PageSize
    page_len = min(len(rows), PageSize)
    data: list[RagMessageItemVO] = []
    idx = 0
    while idx < page_len:
        row = rows[idx]
        if row["role"] != "user":
            idx += 1
            continue
        # Look ahead in `rows`, not in the truncated page, so the extra
        # fetched row can still complete the last question of the page.
        follower = rows[idx + 1] if idx + 1 < len(rows) else None
        answer = follower if follower is not None and follower["role"] == "assistant" else None
        data.append(
            RagMessageItemVO(
                id=(answer["message_id"] if answer else row["message_id"]),
                conversationId=ConversationId,
                query=row["content"],
                answer=answer["content"] if answer else "",
                feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None),
                retrieverResources=(answer.get("sources") if answer else None),
                createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
            )
        )
        idx += 2 if answer else 1
    return RagMessagePageVO(data=data, hasMore=has_more, limit=PageSize)
async def RenameConversation(self, CurrentUserId: int, ConversationId: str, Body: RagConversationRenameDTO) -> RagConversationRenameVO:
    """Rename a conversation owned by the caller.

    Ownership is verified first (that check raises when the conversation is
    missing or belongs to someone else); the new name and a refreshed
    ``updated_at`` are then written in one transaction.
    """
    await self._ensure_conversation_owner(CurrentUserId, ConversationId)
    rename_stmt = text(
        "UPDATE rag_conversation SET name = :name, updated_at = NOW() WHERE conversation_id = :conversation_id"
    )
    bind_values = {"name": Body.name, "conversation_id": ConversationId}
    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(rename_stmt, bind_values)
    return RagConversationRenameVO(result="success", name=Body.name)
async def DeleteConversation(self, CurrentUserId: int, ConversationId: str) -> RagOperationResultVO:
    """Soft-delete a conversation owned by the caller.

    The row is kept; only ``deleted_at`` (and ``updated_at``) are stamped.
    The ownership check raises when the conversation is missing or belongs
    to another user.
    """
    await self._ensure_conversation_owner(CurrentUserId, ConversationId)
    soft_delete_stmt = text(
        "UPDATE rag_conversation SET deleted_at = NOW(), updated_at = NOW() WHERE conversation_id = :conversation_id"
    )
    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(soft_delete_stmt, {"conversation_id": ConversationId})
    return RagOperationResultVO(result="success")
async def UpdateFeedback(self, CurrentUserId: int, MessageId: str, Body: RagMessageFeedbackDTO) -> RagOperationResultVO:
    """Store the caller's rating on a message of one of their conversations.

    The ownership lookup and the update run inside one explicit transaction.
    Executing the lookup first and calling ``session.begin()`` afterwards
    would fail, because the first ``execute`` already auto-begins a
    transaction on the session.

    Raises:
        LeauditException: 404 when the message does not exist or its
            conversation is deleted; 403 when the conversation belongs to
            another user. In both cases nothing is written.
    """
    async with GetAsyncSession() as session:
        async with session.begin():
            owner = (
                await session.execute(
                    text(
                        """
                        SELECT c.user_id
                        FROM rag_message m
                        JOIN rag_conversation c ON c.conversation_id = m.conversation_id
                        WHERE m.message_id = :message_id AND c.deleted_at IS NULL
                        LIMIT 1
                        """
                    ),
                    {"message_id": MessageId},
                )
            ).scalar_one_or_none()
            if owner is None:
                raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在")
            if int(owner) != CurrentUserId:
                raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权修改该消息反馈")
            await session.execute(
                text("UPDATE rag_message SET feedback = :feedback WHERE message_id = :message_id"),
                {"feedback": Body.rating, "message_id": MessageId},
            )
    return RagOperationResultVO(result="success")
async def GetAppParameters(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    AppId: int | None,
) -> RagAppParametersVO:
    """Return the opening statement and suggested questions of the resolved app.

    ``suggested_questions`` is accepted both as a JSON-encoded string and as
    an already decoded list (database drivers may decode JSON columns).
    Anything else, including malformed JSON, yields no suggestions. At most
    six suggestions are returned, each coerced to ``str``.

    Returns:
        A default ``RagAppParametersVO`` when no app can be resolved.
    """
    app = await self._resolve_app(AppId, UserArea, UserRole)
    if not app:
        return RagAppParametersVO()

    raw = app.get("suggested_questions")
    if isinstance(raw, (str, bytes, bytearray)):
        try:
            raw = json.loads(raw) if raw else []
        except ValueError:
            # JSONDecodeError and UnicodeDecodeError are both ValueError.
            raw = []
    suggested = raw if isinstance(raw, list) else []

    return RagAppParametersVO(
        openingStatement=app.get("opening_statement") or "",
        suggestedQuestions=[str(item) for item in suggested[:6]],
        userInputForm=[],
        fileUpload={"image": {"enabled": False}},
    )
async def _load_apps(self, user_area: str | None, user_role: str | None, only_default: bool) -> list[RagChatAppVO]:
    """Load the enabled, non-deleted chat apps visible to a user.

    Visibility is decided in SQL: the ``provincial_admin`` role sees
    everything; other users see apps whose area is their own area, '省级' or
    empty, plus apps whose dataset is public. ``only_default`` restricts the
    result to apps flagged as default. Rows come back ordered by
    ``sort_order`` and then newest first.
    """
    statement = text(
        """
        SELECT a.id, a.name, a.description, a.is_default
        FROM rag_chat_app a
        LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
        WHERE a.deleted_at IS NULL
        AND a.status = 1
        AND (:only_default = FALSE OR a.is_default = TRUE)
        AND (
        :is_provincial = TRUE
        OR a.area IN (:user_area, '省级', '')
        OR COALESCE(d.is_public, FALSE) = TRUE
        )
        ORDER BY a.sort_order ASC, a.created_at DESC
        """
    )
    bind_values = {
        "only_default": only_default,
        "is_provincial": user_role == "provincial_admin",
        "user_area": user_area or "",
    }
    async with GetAsyncSession() as session:
        result = await session.execute(statement, bind_values)
        records = result.mappings().all()

    apps: list[RagChatAppVO] = []
    for record in records:
        apps.append(
            RagChatAppVO(
                appId=str(record["id"]),
                appName=record["name"],
                description=record.get("description") or "",
                isDefault=bool(record.get("is_default")),
            )
        )
    return apps
async def _resolve_app(self, app_id: int | None, user_area: str | None, user_role: str | None) -> dict | None:
    """Resolve the chat app a request should run against.

    Candidates are tried in this order and the first one the user may see
    (per ``_app_visible``) wins:

    1. the app with ``app_id``, when one is given;
    2. apps flagged as default, by ``sort_order`` then newest first;
    3. any enabled app, in the same order.

    Every candidate of a stage is checked, not only the top-ranked one, so
    an app the user cannot see never hides one they can.

    Returns:
        The app row as a plain ``dict``, or ``None`` when the user can see
        no enabled app.
    """
    base_sql = """
        SELECT a.id, a.name, a.description, a.area, a.dataset_id, a.system_prompt,
        a.llm_model, a.temperature, a.max_tokens, a.opening_statement,
        a.suggested_questions, a.is_default, COALESCE(d.is_public, FALSE) AS dataset_public,
        COALESCE(d.name, '') AS dataset_name
        FROM rag_chat_app a
        LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
        WHERE a.deleted_at IS NULL AND a.status = 1
        """
    ordering = " ORDER BY a.sort_order ASC, a.created_at DESC"

    # Each stage is (sql, bind values); only the values a stage uses are bound.
    stages: list[tuple[str, dict]] = []
    if app_id is not None:
        stages.append((base_sql + " AND a.id = :app_id LIMIT 1", {"app_id": app_id}))
    stages.append((base_sql + " AND a.is_default = TRUE" + ordering, {}))
    stages.append((base_sql + ordering, {}))

    async with GetAsyncSession() as session:
        for sql, bind_values in stages:
            rows = (await session.execute(text(sql), bind_values)).mappings().all()
            for row in rows:
                if self._app_visible(row, user_area, user_role):
                    return dict(row)
    return None
def _app_visible(self, row: dict, user_area: str | None, user_role: str | None) -> bool:
    """Tell whether an app row may be shown to the given user.

    The ``provincial_admin`` role sees every app. Anyone else sees an app
    when its area is empty, '省级' or equal to the user's area, or when the
    row's ``dataset_public`` flag is truthy.
    """
    if user_role != "provincial_admin":
        app_area = row.get("area") or ""
        allowed_areas = ("", "省级", user_area or "")
        if app_area not in allowed_areas:
            # Outside the user's areas: only a public dataset grants access.
            return bool(row.get("dataset_public"))
    return True
async def _ensure_conversation(self, user_id: int, conversation_id: str | None, app_id: int | None) -> str:
    """Return the id of the conversation a message should be appended to.

    A supplied id (anything other than empty or ``"-1"``) is looked up among
    non-deleted conversations and reused when it belongs to ``user_id``. When
    no id is supplied, or the id is unknown, a new conversation named '新对话'
    is inserted under a fresh UUID.

    Raises:
        LeauditException: 403 when the conversation exists but belongs to
            another user.
    """
    wants_existing = bool(conversation_id) and conversation_id != "-1"
    if wants_existing:
        lookup = text(
            """
            SELECT conversation_id, user_id
            FROM rag_conversation
            WHERE conversation_id = :conversation_id
            AND deleted_at IS NULL
            LIMIT 1
            """
        )
        async with GetAsyncSession() as session:
            found = (await session.execute(lookup, {"conversation_id": conversation_id})).mappings().first()
        if found:
            if int(found["user_id"]) != user_id:
                raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权使用该会话")
            return str(found["conversation_id"])

    # No usable conversation: start a new one.
    new_id = str(uuid.uuid4())
    insert = text(
        """
        INSERT INTO rag_conversation (conversation_id, user_id, app_id, name, introduction)
        VALUES (:conversation_id, :user_id, :app_id, '新对话', '')
        """
    )
    async with GetAsyncSession() as session:
        async with session.begin():
            await session.execute(
                insert,
                {"conversation_id": new_id, "user_id": user_id, "app_id": app_id},
            )
    return new_id
async def _ensure_conversation_owner(self, user_id: int, conversation_id: str) -> None:
    """Check that a non-deleted conversation exists and belongs to ``user_id``.

    Raises:
        LeauditException: 404 when no such conversation exists, 403 when it
            belongs to a different user.
    """
    lookup = text(
        "SELECT user_id FROM rag_conversation WHERE conversation_id = :conversation_id AND deleted_at IS NULL LIMIT 1"
    )
    async with GetAsyncSession() as session:
        result = await session.execute(lookup, {"conversation_id": conversation_id})
        owner_id = result.scalar_one_or_none()
    if owner_id is None:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "会话不存在")
    if int(owner_id) != user_id:
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该会话")
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
    """Fetch the context chunks for ``query`` from a dataset's vector collection.

    Retrieval is best effort: when the dataset is missing, the vector client
    cannot be imported, or the query fails, an empty chunk list is returned
    (failures are logged) so the chat can still answer without context.

    Settings read from ``retrieval_model``:
        * ``top_k`` - number of chunks; defaults to 5 when absent or not a
          number and is never lower than 1.
        * ``score_threshold`` - applied only when ``score_threshold_enabled``
          is truthy; chunks scoring below it are dropped.

    The score is computed as ``1 - distance``.

    Returns:
        ``(chunks, dataset_name)``; the name is ``""`` when no dataset was
        found.
    """
    import logging  # local import keeps this block self-contained

    if not dataset_id:
        return [], ""
    logger = logging.getLogger(__name__)

    async with GetAsyncSession() as session:
        dataset = (
            await session.execute(
                text(
                    """
                    SELECT id, name, collection_name, retrieval_model
                    FROM rag_dataset
                    WHERE id = :dataset_id AND deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"dataset_id": dataset_id},
            )
        ).mappings().first()
    if not dataset:
        return [], ""
    dataset_name = dataset.get("name") or ""

    # The column may arrive decoded (dict) or as raw JSON text.
    retrieval_model = dataset.get("retrieval_model") or {}
    if isinstance(retrieval_model, str):
        try:
            retrieval_model = json.loads(retrieval_model)
        except ValueError:
            retrieval_model = {}
    if not isinstance(retrieval_model, dict):
        retrieval_model = {}

    try:
        # Clamp to 1 so a zero/negative setting cannot become a negative slice.
        top_k = max(int(retrieval_model.get("top_k") or 5), 1)
    except (TypeError, ValueError):
        top_k = 5

    score_threshold = None
    if retrieval_model.get("score_threshold_enabled"):
        try:
            score_threshold = float(retrieval_model.get("score_threshold"))
        except (TypeError, ValueError):
            score_threshold = None

    try:
        from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
    except Exception:
        logger.warning("RAG vector client unavailable; answering without context", exc_info=True)
        return [], dataset_name

    try:
        collection = get_chroma().get_or_create_collection(dataset["collection_name"])
        result = collection.query(query_texts=[query], n_results=top_k)
        docs = (result.get("documents") or [[]])[0]
        metas = (result.get("metadatas") or [[]])[0]
        distances = (result.get("distances") or [[]])[0]
        chunks: list[dict] = []
        for idx, doc in enumerate(docs):
            # A metadata entry can be None; treat it as empty rather than
            # letting one chunk discard the whole result.
            meta = (metas[idx] if idx < len(metas) else None) or {}
            dist = distances[idx] if idx < len(distances) else 0.0
            score = 1 - float(dist or 0.0)
            if score_threshold is not None and score < score_threshold:
                continue
            chunks.append(
                {
                    "id": str(meta.get("id") or idx),
                    "text": doc,
                    "source": meta.get("source") or meta.get("document_name") or dataset_name,
                    "score": score,
                    "chunk_index": idx,
                    "document_name": meta.get("document_name") or meta.get("source") or "",
                }
            )
        chunks = await self._hydrate_document_hits(dataset_id, chunks)
        return chunks[:top_k], dataset_name
    except Exception:
        logger.exception("RAG retrieval failed for dataset %s", dataset_id)
        return [], dataset_name
def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]:
    """Turn retrieved chunks into the source records stored with an answer.

    One record is produced per chunk, numbered from 1 in input order. The
    chunk text is reported in full through ``word_count`` (its length) but
    truncated to 500 characters in ``content``; the score is rounded to four
    decimals.
    """
    sources: list[dict] = []
    for ordinal, chunk in enumerate(context_chunks, start=1):
        body = chunk.get("text", "")
        display_name = chunk.get("document_name") or chunk.get("source", "")
        sources.append(
            {
                "position": ordinal,
                "dataset_id": str(chunk.get("dataset_id") or ""),
                "dataset_name": dataset_name,
                "document_id": str(chunk.get("document_id") or ""),
                "document_name": display_name,
                "data_source_type": "upload_file",
                "segment_id": chunk.get("id", ""),
                "retriever_from": "rag",
                "score": round(chunk.get("score", 0.0), 4),
                "hit_count": chunk.get("hit_count", 0),
                "word_count": len(body),
                "segment_position": ordinal,
                "index_node_hash": "",
                "content": body[:500],
                "page": None,
            }
        )
    return sources
async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict]) -> list[dict]:
    """Attach document records to chunks and count the hits.

    Chunks are matched to ``rag_document`` rows of the dataset by their
    document name (falling back to ``source``). Chunks of a disabled document
    are dropped; chunks with a matching enabled document are enriched in
    place with ``document_id``, ``dataset_id``, ``document_name`` and
    ``hit_count``; chunks without a matching document pass through
    unchanged. Each matched document's ``hit_count`` is incremented once,
    however many of its chunks were hit.

    When no chunk carries a usable name the input list itself is returned
    and the database is not touched.
    """

    def _name_of(chunk: dict) -> str:
        return str(chunk.get("document_name") or chunk.get("source") or "").strip()

    wanted_names = sorted({name for name in map(_name_of, chunks) if name})
    if not wanted_names:
        return chunks

    lookup = text(
        """
        SELECT id, original_name, enabled, hit_count
        FROM rag_document
        WHERE dataset_id = :dataset_id
        AND deleted_at IS NULL
        AND original_name = ANY(:source_names)
        """
    )
    async with GetAsyncSession() as session:
        result = await session.execute(
            lookup,
            {
                "dataset_id": dataset_id,
                "source_names": wanted_names,
            },
        )
        records = result.mappings().all()

    by_name: dict = {}
    for record in records:
        by_name[str(record["original_name"])] = record

    kept: list[dict] = []
    touched_ids: set[int] = set()
    for chunk in chunks:
        document = by_name.get(_name_of(chunk))
        if document:
            if not bool(document.get("enabled")):
                # Disabled documents must not surface in answers.
                continue
            chunk["document_id"] = document["id"]
            chunk["dataset_id"] = dataset_id
            chunk["document_name"] = document["original_name"]
            chunk["hit_count"] = document.get("hit_count") or 0
            touched_ids.add(int(document["id"]))
        kept.append(chunk)

    if touched_ids:
        bump = text(
            """
            UPDATE rag_document
            SET hit_count = hit_count + 1,
            updated_at = NOW()
            WHERE id = ANY(:document_ids)
            """
        )
        async with GetAsyncSession() as session:
            async with session.begin():
                await session.execute(bump, {"document_ids": sorted(touched_ids)})
    return kept
@@ -0,0 +1,52 @@
from __future__ import annotations
from sqlalchemy import text
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import RagDatasetItemVO, RagDatasetPageVO
from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatasetService
class RagDatasetServiceImpl(IRagDatasetService):
    """Database-backed implementation of the dataset listing service."""

    async def GetMyDatasets(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagDatasetPageVO:
        """List the enabled, non-deleted datasets visible to the caller.

        The ``provincial_admin`` role sees every dataset; other users see
        datasets whose area is their own area, '省级' or empty, plus public
        ones. Rows are ordered by ``sort_order`` and then newest first.
        ``CurrentUserId`` is part of the interface but is not read here.
        """
        statement = text(
            """
            SELECT id, name, description, area, is_public, is_default, document_count, total_chunks, status
            FROM rag_dataset
            WHERE deleted_at IS NULL
            AND status = 1
            AND (
            :is_provincial = TRUE
            OR area IN (:user_area, '省级', '')
            OR is_public = TRUE
            )
            ORDER BY sort_order ASC, created_at DESC
            """
        )
        bind_values = {
            "is_provincial": UserRole == "provincial_admin",
            "user_area": UserArea or "",
        }
        async with GetAsyncSession() as session:
            result = await session.execute(statement, bind_values)
            records = result.mappings().all()

        datasets = []
        for record in records:
            datasets.append(
                RagDatasetItemVO(
                    id=record["id"],
                    name=record["name"],
                    description=record.get("description") or "",
                    area=record.get("area") or "",
                    isPublic=bool(record.get("is_public")),
                    isDefault=bool(record.get("is_default")),
                    documentCount=record.get("document_count") or 0,
                    totalChunks=record.get("total_chunks") or 0,
                    status=record.get("status") or 1,
                )
            )
        return RagDatasetPageVO(data=datasets, total=len(records))
@@ -223,6 +223,13 @@ class RbacAdminServiceImpl(IRbacAdminService):
{"permission_key": "rbac:user_roles:write", "display_name": "分配用户角色", "module": "rbac", "resource": "user_roles", "action": "write", "api_method": "POST", "api_path": "/api/v3/rbac/users/{user_id}/roles", "route_path": "/role-permissions"},
{"permission_key": "rbac:role_routes:write", "display_name": "配置角色菜单", "module": "rbac", "resource": "role_routes", "action": "write", "api_method": "PUT", "api_path": "/api/rbac/roles/{role_id}/routes", "route_path": "/role-permissions"},
{"permission_key": "rbac:role_permissions:write", "display_name": "配置角色权限", "module": "rbac", "resource": "role_permissions", "action": "write", "api_method": "POST", "api_path": "/api/v3/rbac/role-permissions", "route_path": "/role-permissions"},
{"permission_key": "rag:app:read", "display_name": "查看 RAG 应用", "module": "rag", "resource": "app", "action": "read", "api_method": "GET", "api_path": "/api/v3/rag/apps", "route_path": "/chat-with-llm"},
{"permission_key": "rag:chat:use", "display_name": "使用 RAG 对话", "module": "rag", "resource": "chat", "action": "use", "api_method": "POST", "api_path": "/api/v3/rag/chat/messages", "route_path": "/chat-with-llm"},
{"permission_key": "rag:conversation:read", "display_name": "查看 RAG 会话", "module": "rag", "resource": "conversation", "action": "read", "api_method": "GET", "api_path": "/api/v3/rag/chat/conversations", "route_path": "/chat-with-llm"},
{"permission_key": "rag:conversation:update", "display_name": "重命名 RAG 会话", "module": "rag", "resource": "conversation", "action": "update", "api_method": "PATCH", "api_path": "/api/v3/rag/chat/conversations/{ConversationId}", "route_path": "/chat-with-llm"},
{"permission_key": "rag:conversation:delete", "display_name": "删除 RAG 会话", "module": "rag", "resource": "conversation", "action": "delete", "api_method": "DELETE", "api_path": "/api/v3/rag/chat/conversations/{ConversationId}", "route_path": "/chat-with-llm"},
{"permission_key": "rag:message:feedback", "display_name": "反馈 RAG 消息", "module": "rag", "resource": "message", "action": "feedback", "api_method": "POST", "api_path": "/api/v3/rag/chat/messages/{MessageId}/feedback", "route_path": "/chat-with-llm"},
{"permission_key": "rag:dataset:read", "display_name": "查看 RAG 知识库", "module": "rag", "resource": "dataset", "action": "read", "api_method": "GET", "api_path": "/api/v3/rag/datasets/my", "route_path": "/chat-with-llm"},
]
async def ListRoles(self, CurrentUserId: int, Page: int, PageSize: int, RoleKey: str | None, RoleName: str | None, IncludeSystem: bool) -> RoleListVO:
@@ -0,0 +1,62 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import AsyncGenerator
from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import (
RagConversationRenameDTO,
RagMessageFeedbackDTO,
)
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
RagAppParametersVO,
RagChatAppListVO,
RagChatAppVO,
RagConversationPageVO,
RagConversationRenameVO,
RagMessagePageVO,
RagOperationResultVO,
)
class IRagChatService(ABC):
    """Contract of the RAG chat service: apps, messaging, conversations, feedback."""

    @abstractmethod
    async def GetApps(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppListVO:
        """Return the chat apps visible to the caller."""
        ...

    @abstractmethod
    async def GetDefaultApp(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppVO | None:
        """Return the app a new chat should start with, or ``None``."""
        ...

    # Declared with plain ``def`` on purpose: implementations are async
    # generators, and calling one returns the generator directly. An
    # ``async def`` stub without ``yield`` would instead describe a coroutine
    # that has to be awaited to obtain the generator.
    @abstractmethod
    def SendMessage(
        self,
        CurrentUserId: int,
        UserName: str,
        UserArea: str | None,
        UserRole: str | None,
        Query: str,
        ConversationId: str | None,
        AppId: int | None,
    ) -> AsyncGenerator[bytes, None]:
        """Return an async generator that streams the answer to ``Query`` as bytes."""
        ...

    @abstractmethod
    async def GetConversations(self, CurrentUserId: int, AppId: int | None, Page: int, PageSize: int) -> RagConversationPageVO:
        """Return one page of the caller's conversations."""
        ...

    @abstractmethod
    async def GetConversationMessages(self, CurrentUserId: int, ConversationId: str, Page: int, PageSize: int) -> RagMessagePageVO:
        """Return one page of messages of a conversation."""
        ...

    @abstractmethod
    async def RenameConversation(self, CurrentUserId: int, ConversationId: str, Body: RagConversationRenameDTO) -> RagConversationRenameVO:
        """Rename a conversation."""
        ...

    @abstractmethod
    async def DeleteConversation(self, CurrentUserId: int, ConversationId: str) -> RagOperationResultVO:
        """Delete a conversation."""
        ...

    @abstractmethod
    async def UpdateFeedback(self, CurrentUserId: int, MessageId: str, Body: RagMessageFeedbackDTO) -> RagOperationResultVO:
        """Record the caller's feedback on a message."""
        ...

    @abstractmethod
    async def GetAppParameters(
        self,
        CurrentUserId: int,
        UserArea: str | None,
        UserRole: str | None,
        AppId: int | None,
    ) -> RagAppParametersVO:
        """Return the presentation parameters of a chat app."""
        ...
@@ -0,0 +1,10 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import RagDatasetPageVO
class IRagDatasetService(ABC):
    """Contract of the RAG dataset listing service."""

    @abstractmethod
    async def GetMyDatasets(
        self,
        CurrentUserId: int,
        UserArea: str | None,
        UserRole: str | None,
    ) -> RagDatasetPageVO:
        """Return the datasets visible to the caller."""
        ...