feat: add tenant-scoped rule and permission management
This commit is contained in:
@@ -39,6 +39,7 @@ from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream
|
||||
from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups
|
||||
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
|
||||
from fastapi_modules.fastapi_leaudit.services.impl.tenantResolver import TenantResolver
|
||||
|
||||
|
||||
DEFAULT_CONVERSATION_NAME = "新对话"
|
||||
@@ -54,15 +55,40 @@ class RagChatServiceImpl(IRagChatService):
|
||||
_task_locks: dict[str, asyncio.Lock] = {}
|
||||
_title_tasks: dict[str, asyncio.Task] = {}
|
||||
|
||||
async def GetApps(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppListVO:
|
||||
apps = await self._load_apps(UserArea, UserRole, only_default=False)
|
||||
def __init__(self) -> None:
|
||||
self.TenantResolver = TenantResolver()
|
||||
|
||||
_APP_TENANT_NAME_SQL = (
|
||||
"CASE "
|
||||
"WHEN NULLIF(BTRIM(a.tenant_code), '') = 'PUBLIC' THEN '公共' "
|
||||
"WHEN NULLIF(BTRIM(a.tenant_code), '') = 'PROVINCIAL' THEN '省级' "
|
||||
"ELSE COALESCE(NULLIF(BTRIM(a.area), ''), '未分配地区') "
|
||||
"END"
|
||||
)
|
||||
|
||||
async def GetApps(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> RagChatAppListVO:
|
||||
apps = await self._load_apps(UserArea, UserRole, TenantCode, TenantName, only_default=False)
|
||||
return RagChatAppListVO(data=apps, total=len(apps))
|
||||
|
||||
async def GetDefaultApp(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppVO | None:
|
||||
apps = await self._load_apps(UserArea, UserRole, only_default=True)
|
||||
async def GetDefaultApp(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> RagChatAppVO | None:
|
||||
apps = await self._load_apps(UserArea, UserRole, TenantCode, TenantName, only_default=True)
|
||||
if apps:
|
||||
return apps[0]
|
||||
all_apps = await self._load_apps(UserArea, UserRole, only_default=False)
|
||||
all_apps = await self._load_apps(UserArea, UserRole, TenantCode, TenantName, only_default=False)
|
||||
return all_apps[0] if all_apps else None
|
||||
|
||||
async def SendMessage(
|
||||
@@ -74,15 +100,25 @@ class RagChatServiceImpl(IRagChatService):
|
||||
Query: str,
|
||||
ConversationId: str | None,
|
||||
AppId: int | None,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> AsyncGenerator[bytes, None]:
|
||||
if not Query.strip():
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "问题不能为空")
|
||||
|
||||
app = await self._resolve_app(AppId, UserArea, UserRole)
|
||||
app = await self._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName)
|
||||
if not app:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用")
|
||||
|
||||
conversationId = await self._ensure_conversation(CurrentUserId, ConversationId, app["id"])
|
||||
conversationId = await self._ensure_conversation(
|
||||
user_id=CurrentUserId,
|
||||
conversation_id=ConversationId,
|
||||
app_id=app["id"],
|
||||
user_area=UserArea,
|
||||
user_role=UserRole,
|
||||
tenant_code=TenantCode,
|
||||
tenant_name=TenantName,
|
||||
)
|
||||
messageId = str(uuid.uuid4())
|
||||
taskId = str(uuid.uuid4())
|
||||
is_new_conversation = not ConversationId or ConversationId == "-1"
|
||||
@@ -161,7 +197,18 @@ class RagChatServiceImpl(IRagChatService):
|
||||
|
||||
await asyncio.sleep(0.05)
|
||||
|
||||
async def GetConversations(self, CurrentUserId: int, AppId: int | None, Page: int, PageSize: int) -> RagConversationPageVO:
|
||||
async def GetConversations(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
AppId: int | None,
|
||||
Page: int,
|
||||
PageSize: int,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> RagConversationPageVO:
|
||||
tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName)
|
||||
async with GetAsyncSession() as session:
|
||||
rows = (
|
||||
await session.execute(
|
||||
@@ -186,8 +233,20 @@ class RagChatServiceImpl(IRagChatService):
|
||||
},
|
||||
)
|
||||
).mappings().all()
|
||||
has_more = len(rows) > PageSize
|
||||
items = rows[:PageSize]
|
||||
filtered_rows: list[dict] = []
|
||||
for row in rows:
|
||||
record = dict(row)
|
||||
if await self._conversation_accessible(
|
||||
conversation_id=str(record["conversation_id"]),
|
||||
expected_user_id=CurrentUserId,
|
||||
tenant_context=tenant_context,
|
||||
user_role=UserRole,
|
||||
app_id=AppId,
|
||||
session=session,
|
||||
):
|
||||
filtered_rows.append(record)
|
||||
has_more = len(filtered_rows) > PageSize
|
||||
items = filtered_rows[:PageSize]
|
||||
return RagConversationPageVO(
|
||||
data=[
|
||||
RagConversationItemVO(
|
||||
@@ -205,8 +264,25 @@ class RagChatServiceImpl(IRagChatService):
|
||||
limit=PageSize,
|
||||
)
|
||||
|
||||
async def GetConversationMessages(self, CurrentUserId: int, ConversationId: str, Page: int, PageSize: int) -> RagMessagePageVO:
|
||||
await self._ensure_conversation_owner(CurrentUserId, ConversationId)
|
||||
async def GetConversationMessages(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
ConversationId: str,
|
||||
Page: int,
|
||||
PageSize: int,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> RagMessagePageVO:
|
||||
await self._ensure_conversation_owner(
|
||||
user_id=CurrentUserId,
|
||||
conversation_id=ConversationId,
|
||||
user_area=UserArea,
|
||||
user_role=UserRole,
|
||||
tenant_code=TenantCode,
|
||||
tenant_name=TenantName,
|
||||
)
|
||||
async with GetAsyncSession() as session:
|
||||
rows = (
|
||||
await session.execute(
|
||||
@@ -333,8 +409,24 @@ class RagChatServiceImpl(IRagChatService):
|
||||
chunks.append(answer)
|
||||
return "".join(chunks)
|
||||
|
||||
async def RenameConversation(self, CurrentUserId: int, ConversationId: str, Body: RagConversationRenameDTO) -> RagConversationRenameVO:
|
||||
await self._ensure_conversation_owner(CurrentUserId, ConversationId)
|
||||
async def RenameConversation(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
ConversationId: str,
|
||||
Body: RagConversationRenameDTO,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> RagConversationRenameVO:
|
||||
await self._ensure_conversation_owner(
|
||||
user_id=CurrentUserId,
|
||||
conversation_id=ConversationId,
|
||||
user_area=UserArea,
|
||||
user_role=UserRole,
|
||||
tenant_code=TenantCode,
|
||||
tenant_name=TenantName,
|
||||
)
|
||||
final_name = Body.name.strip()
|
||||
if not final_name:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "会话名称不能为空")
|
||||
@@ -359,8 +451,23 @@ class RagChatServiceImpl(IRagChatService):
|
||||
)
|
||||
return RagConversationRenameVO(result="success", name=final_name)
|
||||
|
||||
async def DeleteConversation(self, CurrentUserId: int, ConversationId: str) -> RagOperationResultVO:
|
||||
await self._ensure_conversation_owner(CurrentUserId, ConversationId)
|
||||
async def DeleteConversation(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
ConversationId: str,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> RagOperationResultVO:
|
||||
await self._ensure_conversation_owner(
|
||||
user_id=CurrentUserId,
|
||||
conversation_id=ConversationId,
|
||||
user_area=UserArea,
|
||||
user_role=UserRole,
|
||||
tenant_code=TenantCode,
|
||||
tenant_name=TenantName,
|
||||
)
|
||||
async with GetAsyncSession() as session:
|
||||
async with session.begin():
|
||||
await session.execute(
|
||||
@@ -371,13 +478,23 @@ class RagChatServiceImpl(IRagChatService):
|
||||
)
|
||||
return RagOperationResultVO(result="success")
|
||||
|
||||
async def UpdateFeedback(self, CurrentUserId: int, MessageId: str, Body: RagMessageFeedbackDTO) -> RagOperationResultVO:
|
||||
async def UpdateFeedback(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
MessageId: str,
|
||||
Body: RagMessageFeedbackDTO,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> RagOperationResultVO:
|
||||
tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName)
|
||||
async with GetAsyncSession() as session:
|
||||
owner = (
|
||||
row = (
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT c.user_id
|
||||
SELECT c.user_id, c.conversation_id
|
||||
FROM rag_message m
|
||||
JOIN rag_conversation c ON c.conversation_id = m.conversation_id
|
||||
WHERE m.message_id = :message_id AND c.deleted_at IS NULL
|
||||
@@ -386,10 +503,16 @@ class RagChatServiceImpl(IRagChatService):
|
||||
),
|
||||
{"message_id": MessageId},
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if owner is None:
|
||||
).mappings().first()
|
||||
if not row:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在")
|
||||
if int(owner) != CurrentUserId:
|
||||
if not await self._conversation_accessible(
|
||||
conversation_id=str(row["conversation_id"]),
|
||||
expected_user_id=CurrentUserId,
|
||||
tenant_context=tenant_context,
|
||||
user_role=UserRole,
|
||||
session=session,
|
||||
):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权修改该消息反馈")
|
||||
await session.execute(
|
||||
text("UPDATE rag_message SET feedback = :feedback WHERE message_id = :message_id"),
|
||||
@@ -397,13 +520,23 @@ class RagChatServiceImpl(IRagChatService):
|
||||
)
|
||||
return RagOperationResultVO(result="success")
|
||||
|
||||
async def StopMessage(self, CurrentUserId: int, MessageId: str, Body: RagStopMessageDTO | None = None) -> RagOperationResultVO:
|
||||
async def StopMessage(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
MessageId: str,
|
||||
Body: RagStopMessageDTO | None = None,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> RagOperationResultVO:
|
||||
tenant_context = await self._resolve_tenant_context(UserArea, TenantCode, TenantName)
|
||||
async with GetAsyncSession() as session:
|
||||
row = (
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT m.metadata, c.user_id
|
||||
SELECT m.metadata, c.user_id, c.conversation_id
|
||||
FROM rag_message m
|
||||
JOIN rag_conversation c ON c.conversation_id = m.conversation_id
|
||||
WHERE m.message_id = :message_id
|
||||
@@ -416,7 +549,12 @@ class RagChatServiceImpl(IRagChatService):
|
||||
).mappings().first()
|
||||
if not row:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "消息不存在")
|
||||
if int(row["user_id"]) != CurrentUserId:
|
||||
if not await self._conversation_accessible(
|
||||
conversation_id=str(row["conversation_id"]),
|
||||
expected_user_id=CurrentUserId,
|
||||
tenant_context=tenant_context,
|
||||
user_role=UserRole,
|
||||
):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权停止该消息")
|
||||
|
||||
metadata = row.get("metadata") or {}
|
||||
@@ -432,8 +570,10 @@ class RagChatServiceImpl(IRagChatService):
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
AppId: int | None,
|
||||
TenantCode: str | None = None,
|
||||
TenantName: str | None = None,
|
||||
) -> RagAppParametersVO:
|
||||
app = await self._resolve_app(AppId, UserArea, UserRole)
|
||||
app = await self._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName)
|
||||
if not app:
|
||||
return RagAppParametersVO()
|
||||
try:
|
||||
@@ -449,21 +589,27 @@ class RagChatServiceImpl(IRagChatService):
|
||||
fileUpload={"image": {"enabled": False}},
|
||||
)
|
||||
|
||||
async def _load_apps(self, user_area: str | None, user_role: str | None, only_default: bool) -> list[RagChatAppVO]:
|
||||
async def _load_apps(
|
||||
self,
|
||||
user_area: str | None,
|
||||
user_role: str | None,
|
||||
tenant_code: str | None,
|
||||
tenant_name: str | None,
|
||||
only_default: bool,
|
||||
) -> list[RagChatAppVO]:
|
||||
async with GetAsyncSession() as session:
|
||||
await self._ensure_rag_chat_schema(session)
|
||||
sql = (
|
||||
"""
|
||||
SELECT a.id, a.name, a.description, a.is_default
|
||||
f"""
|
||||
SELECT a.id, a.name, a.description, a.is_default, a.area,
|
||||
COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code,
|
||||
{self._APP_TENANT_NAME_SQL} AS tenant_name,
|
||||
COALESCE(d.is_public, FALSE) AS dataset_public
|
||||
FROM rag_chat_app a
|
||||
LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
|
||||
WHERE a.deleted_at IS NULL
|
||||
AND a.status = 1
|
||||
AND (:only_default = FALSE OR a.is_default = TRUE)
|
||||
AND (
|
||||
:is_provincial = TRUE
|
||||
OR a.area IN (:user_area, '省级', '')
|
||||
OR COALESCE(d.is_public, FALSE) = TRUE
|
||||
)
|
||||
ORDER BY a.sort_order ASC, a.created_at DESC
|
||||
"""
|
||||
)
|
||||
@@ -472,31 +618,47 @@ class RagChatServiceImpl(IRagChatService):
|
||||
text(sql),
|
||||
{
|
||||
"only_default": only_default,
|
||||
"is_provincial": user_role == "provincial_admin",
|
||||
"user_area": user_area or "",
|
||||
},
|
||||
)
|
||||
).mappings().all()
|
||||
return [
|
||||
RagChatAppVO(
|
||||
appId=str(row["id"]),
|
||||
appName=row["name"],
|
||||
description=row.get("description") or "",
|
||||
isDefault=bool(row.get("is_default")),
|
||||
tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
|
||||
data: list[RagChatAppVO] = []
|
||||
for row in rows:
|
||||
record = dict(row)
|
||||
if not await self._app_visible(record, tenant_context=tenant_context, user_role=user_role):
|
||||
continue
|
||||
data.append(
|
||||
RagChatAppVO(
|
||||
appId=str(record["id"]),
|
||||
appName=record["name"],
|
||||
description=record.get("description") or "",
|
||||
tenantCode=str(record.get("tenant_code") or ""),
|
||||
tenantName=str(record.get("tenant_name") or record.get("area") or ""),
|
||||
isDefault=bool(record.get("is_default")),
|
||||
)
|
||||
)
|
||||
for row in rows
|
||||
]
|
||||
return data
|
||||
|
||||
async def _resolve_app(self, app_id: int | None, user_area: str | None, user_role: str | None) -> dict | None:
|
||||
async def _resolve_app(
|
||||
self,
|
||||
app_id: int | None,
|
||||
user_area: str | None,
|
||||
user_role: str | None,
|
||||
tenant_code: str | None,
|
||||
tenant_name: str | None,
|
||||
) -> dict | None:
|
||||
tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
|
||||
async with GetAsyncSession() as session:
|
||||
await self._ensure_rag_chat_schema(session)
|
||||
params = {
|
||||
"app_id": app_id,
|
||||
"user_area": user_area or "",
|
||||
"is_provincial": user_role == "provincial_admin",
|
||||
}
|
||||
base_sql = (
|
||||
"""
|
||||
SELECT a.id, a.name, a.description, a.area, a.dataset_id, a.system_prompt,
|
||||
f"""
|
||||
SELECT a.id, a.name, a.description, a.area,
|
||||
COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code,
|
||||
{self._APP_TENANT_NAME_SQL} AS tenant_name,
|
||||
a.dataset_id, a.system_prompt,
|
||||
a.llm_model, a.temperature, a.max_tokens, a.opening_statement,
|
||||
a.suggested_questions, a.is_default, COALESCE(d.is_public, FALSE) AS dataset_public,
|
||||
COALESCE(d.name, '') AS dataset_name
|
||||
@@ -512,7 +674,7 @@ class RagChatServiceImpl(IRagChatService):
|
||||
params,
|
||||
)
|
||||
).mappings().first()
|
||||
if row and self._app_visible(row, user_area, user_role):
|
||||
if row and await self._app_visible(dict(row), tenant_context=tenant_context, user_role=user_role):
|
||||
return dict(row)
|
||||
row = (
|
||||
await session.execute(
|
||||
@@ -520,7 +682,7 @@ class RagChatServiceImpl(IRagChatService):
|
||||
params,
|
||||
)
|
||||
).mappings().first()
|
||||
if row and self._app_visible(row, user_area, user_role):
|
||||
if row and await self._app_visible(dict(row), tenant_context=tenant_context, user_role=user_role):
|
||||
return dict(row)
|
||||
row = (
|
||||
await session.execute(
|
||||
@@ -528,15 +690,83 @@ class RagChatServiceImpl(IRagChatService):
|
||||
params,
|
||||
)
|
||||
).mappings().first()
|
||||
return dict(row) if row and self._app_visible(row, user_area, user_role) else None
|
||||
return dict(row) if row and await self._app_visible(dict(row), tenant_context=tenant_context, user_role=user_role) else None
|
||||
|
||||
def _app_visible(self, row: dict, user_area: str | None, user_role: str | None) -> bool:
|
||||
if user_role == "provincial_admin":
|
||||
async def _app_visible(self, row: dict, tenant_context: dict, user_role: str | None) -> bool:
|
||||
if self._role_is_global(user_role):
|
||||
return True
|
||||
area = row.get("area") or ""
|
||||
return area in ("", "省级", user_area or "") or bool(row.get("dataset_public"))
|
||||
if bool(row.get("dataset_public")):
|
||||
return True
|
||||
if str(row.get("tenant_code") or "").strip().upper() == "PUBLIC":
|
||||
return True
|
||||
return self._row_matches_tenant_scope(
|
||||
row_tenant_code=row.get("tenant_code"),
|
||||
row_area=row.get("area"),
|
||||
tenant_context=tenant_context,
|
||||
)
|
||||
|
||||
async def _ensure_conversation(self, user_id: int, conversation_id: str | None, app_id: int | None) -> str:
|
||||
async def _resolve_tenant_context(
|
||||
self,
|
||||
user_area: str | None,
|
||||
tenant_code: str | None,
|
||||
tenant_name: str | None,
|
||||
) -> dict[str, str | None]:
|
||||
resolved = await self.TenantResolver.ResolveUserContext(
|
||||
Area=user_area,
|
||||
TenantCode=tenant_code,
|
||||
TenantName=tenant_name,
|
||||
Source="rag_chat_user",
|
||||
)
|
||||
return {
|
||||
"tenant_code": resolved.tenant_code,
|
||||
"tenant_name": resolved.tenant_name,
|
||||
"tenant_type": resolved.tenant_type,
|
||||
"area": user_area,
|
||||
}
|
||||
|
||||
async def _resolve_record_tenant(self, raw_value: str | None):
|
||||
return await self.TenantResolver.Resolve(
|
||||
RawValue=raw_value,
|
||||
Source="rag_chat_record",
|
||||
)
|
||||
|
||||
async def _ensure_rag_chat_schema(self, session) -> None:
|
||||
await session.execute(text("ALTER TABLE rag_chat_app ADD COLUMN IF NOT EXISTS tenant_code VARCHAR(64) NULL"))
|
||||
await session.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chat_app_tenant_code ON rag_chat_app(tenant_code) WHERE deleted_at IS NULL"))
|
||||
|
||||
@staticmethod
|
||||
def _tenant_context_is_global(tenant_context: dict[str, str | None]) -> bool:
|
||||
tenant_code = str(tenant_context.get("tenant_code") or "").strip().upper()
|
||||
return tenant_code in {"PUBLIC", "PROVINCIAL"}
|
||||
|
||||
@staticmethod
|
||||
def _role_is_global(user_role: str | None) -> bool:
|
||||
normalized = str(user_role or "").strip()
|
||||
return normalized in {"super_admin", "provincial_admin"}
|
||||
|
||||
def _row_matches_tenant_scope(
|
||||
self,
|
||||
*,
|
||||
row_tenant_code: str | None,
|
||||
row_area: str | None,
|
||||
tenant_context: dict[str, str | None],
|
||||
) -> bool:
|
||||
user_tenant_code = str(tenant_context.get("tenant_code") or "").strip()
|
||||
if user_tenant_code:
|
||||
return str(row_tenant_code or "").strip() == user_tenant_code
|
||||
return str(row_area or "").strip() == str(tenant_context.get("area") or "").strip()
|
||||
|
||||
async def _ensure_conversation(
|
||||
self,
|
||||
user_id: int,
|
||||
conversation_id: str | None,
|
||||
app_id: int | None,
|
||||
user_area: str | None,
|
||||
user_role: str | None,
|
||||
tenant_code: str | None,
|
||||
tenant_name: str | None,
|
||||
) -> str:
|
||||
tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
|
||||
if conversation_id and conversation_id != "-1":
|
||||
async with GetAsyncSession() as session:
|
||||
row = (
|
||||
@@ -554,7 +784,14 @@ class RagChatServiceImpl(IRagChatService):
|
||||
)
|
||||
).mappings().first()
|
||||
if row:
|
||||
if int(row["user_id"]) != user_id:
|
||||
if not await self._conversation_accessible(
|
||||
conversation_id=str(row["conversation_id"]),
|
||||
expected_user_id=user_id,
|
||||
tenant_context=tenant_context,
|
||||
user_role=user_role,
|
||||
app_id=app_id,
|
||||
session=session,
|
||||
):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权使用该会话")
|
||||
return str(row["conversation_id"])
|
||||
conversation_id = str(uuid.uuid4())
|
||||
@@ -576,21 +813,99 @@ class RagChatServiceImpl(IRagChatService):
|
||||
)
|
||||
return conversation_id
|
||||
|
||||
async def _ensure_conversation_owner(self, user_id: int, conversation_id: str) -> None:
|
||||
async with GetAsyncSession() as session:
|
||||
owner = (
|
||||
await session.execute(
|
||||
text(
|
||||
"SELECT user_id FROM rag_conversation WHERE conversation_id = :conversation_id AND deleted_at IS NULL LIMIT 1"
|
||||
),
|
||||
{"conversation_id": conversation_id},
|
||||
)
|
||||
).scalar_one_or_none()
|
||||
if owner is None:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "会话不存在")
|
||||
if int(owner) != user_id:
|
||||
async def _ensure_conversation_owner(
|
||||
self,
|
||||
*,
|
||||
user_id: int,
|
||||
conversation_id: str,
|
||||
user_area: str | None,
|
||||
user_role: str | None,
|
||||
tenant_code: str | None,
|
||||
tenant_name: str | None,
|
||||
) -> None:
|
||||
tenant_context = await self._resolve_tenant_context(user_area, tenant_code, tenant_name)
|
||||
if not await self._conversation_accessible(
|
||||
conversation_id=conversation_id,
|
||||
expected_user_id=user_id,
|
||||
tenant_context=tenant_context,
|
||||
user_role=user_role,
|
||||
):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该会话")
|
||||
|
||||
async def _conversation_accessible(
|
||||
self,
|
||||
*,
|
||||
conversation_id: str,
|
||||
expected_user_id: int,
|
||||
tenant_context: dict[str, str | None],
|
||||
user_role: str | None,
|
||||
app_id: int | None = None,
|
||||
session=None,
|
||||
) -> bool:
|
||||
if session is not None:
|
||||
return await self._conversation_accessible_with_session(
|
||||
session=session,
|
||||
conversation_id=conversation_id,
|
||||
expected_user_id=expected_user_id,
|
||||
tenant_context=tenant_context,
|
||||
user_role=user_role,
|
||||
app_id=app_id,
|
||||
)
|
||||
async with GetAsyncSession() as owned_session:
|
||||
return await self._conversation_accessible_with_session(
|
||||
session=owned_session,
|
||||
conversation_id=conversation_id,
|
||||
expected_user_id=expected_user_id,
|
||||
tenant_context=tenant_context,
|
||||
user_role=user_role,
|
||||
app_id=app_id,
|
||||
)
|
||||
|
||||
async def _conversation_accessible_with_session(
|
||||
self,
|
||||
*,
|
||||
session,
|
||||
conversation_id: str,
|
||||
expected_user_id: int,
|
||||
tenant_context: dict[str, str | None],
|
||||
user_role: str | None,
|
||||
app_id: int | None = None,
|
||||
) -> bool:
|
||||
row = (
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT
|
||||
c.conversation_id,
|
||||
c.user_id,
|
||||
c.app_id,
|
||||
a.area,
|
||||
COALESCE(NULLIF(BTRIM(a.tenant_code), ''), NULL) AS tenant_code,
|
||||
COALESCE(d.is_public, FALSE) AS dataset_public
|
||||
FROM rag_conversation c
|
||||
LEFT JOIN rag_chat_app a ON a.id = c.app_id AND a.deleted_at IS NULL
|
||||
LEFT JOIN rag_dataset d ON d.id = a.dataset_id AND d.deleted_at IS NULL
|
||||
WHERE c.conversation_id = :conversation_id
|
||||
AND c.deleted_at IS NULL
|
||||
LIMIT 1
|
||||
"""
|
||||
),
|
||||
{"conversation_id": conversation_id},
|
||||
)
|
||||
).mappings().first()
|
||||
if not row:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "会话不存在")
|
||||
if int(row["user_id"]) != expected_user_id:
|
||||
return False
|
||||
if app_id is not None and row.get("app_id") is not None and int(row["app_id"]) != int(app_id):
|
||||
return False
|
||||
app_row = {
|
||||
"tenant_code": row.get("tenant_code"),
|
||||
"area": row.get("area"),
|
||||
"dataset_public": row.get("dataset_public"),
|
||||
}
|
||||
return await self._app_visible(app_row, tenant_context=tenant_context, user_role=user_role)
|
||||
|
||||
async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]:
|
||||
if not dataset_id:
|
||||
return [], ""
|
||||
|
||||
Reference in New Issue
Block a user