Files
leaudit-platform-backend/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py
T
2026-05-22 15:36:13 +08:00

2149 lines
88 KiB
Python

"""RAG 知识库服务实现。"""
from __future__ import annotations
import asyncio
import json
import mimetypes
import os
import re
import tempfile
import uuid
from datetime import datetime
from pathlib import Path
import httpx
from sqlalchemy import text
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_storage.oss_client import OssClient
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
RagDatasetBatchDeleteResultVO,
RagDatasetDetailVO,
RagDatasetDocumentItemVO,
RagDatasetDocumentPageVO,
RagDatasetItemVO,
RagDatasetPageVO,
RagDatasetRetrieveDocumentVO,
RagDatasetRetrieveRecordVO,
RagDatasetRetrieveResponseVO,
RagDatasetRetrieveSegmentVO,
RagDatasetSegmentItemVO,
RagDatasetSegmentPageVO,
RagDatasetUploadDocumentVO,
)
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import RagOperationResultVO
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_embeddings_url
from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatasetService
from fastapi_modules.fastapi_leaudit.services.impl.tenantResolver import TenantResolver
class RagDatasetServiceImpl(IRagDatasetService):
"""RAG 知识库服务实现。"""
# Document indexing states that count as "still in progress". Not referenced in
# this part of the class; presumably consulted by document endpoints elsewhere — TODO confirm.
_ACTIVE_INDEXING_STATUSES = {"waiting", "parsing", "cleaning", "splitting", "indexing"}
# Document indexing states in which a document may be deleted (presumably checked by
# delete endpoints outside this view — verify).
_DELETABLE_DOCUMENT_STATUSES = {"completed", "error", "paused"}
# LEFT JOIN fragment aliased as "a". For every dataset it picks exactly one enabled
# (status = 1), non-deleted chat app via PostgreSQL DISTINCT ON: the default app
# first, then the lowest sort_order, then the lowest id. The outer query must alias
# rag_dataset as "d".
_APP_LINK_SQL = """
LEFT JOIN (
SELECT DISTINCT ON (dataset_id)
dataset_id,
id,
name,
is_default
FROM rag_chat_app
WHERE deleted_at IS NULL
AND status = 1
ORDER BY dataset_id, is_default DESC, sort_order ASC, id ASC
) a ON a.dataset_id = d.id
"""
# Class-level flag: once the tenant_code columns are confirmed/created, later
# calls to _ensure_rag_tenant_schema return immediately.
_tenant_schema_checked = False
# Class-level asyncio lock so concurrent coroutines run the schema check only once.
_tenant_schema_lock = asyncio.Lock()
def __init__(self) -> None:
    """Build the service together with its own TenantResolver instance."""
    resolver = TenantResolver()
    self.TenantResolver = resolver
# SQL expression (requires rag_dataset aliased as "d") that derives a display
# tenant name: '公共' for tenant_code PUBLIC, '省级' for PROVINCIAL, otherwise the
# trimmed area, or '未分配地区' when the area is NULL/blank.
_DATASET_TENANT_NAME_SQL = (
"CASE "
"WHEN NULLIF(BTRIM(d.tenant_code), '') = 'PUBLIC' THEN '公共' "
"WHEN NULLIF(BTRIM(d.tenant_code), '') = 'PROVINCIAL' THEN '省级' "
"ELSE COALESCE(NULLIF(BTRIM(d.area), ''), '未分配地区') "
"END"
)
async def GetAdminDatasets(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    Area: str | None,
    TenantFilterCode: str | None,
    OnlyEnabled: bool | None,
    Page: int,
    PageSize: int,
) -> RagDatasetPageVO:
    """Return one page of non-deleted datasets for the admin console.

    Global roles (``_role_is_global``) see every tenant and may narrow the list
    with ``Area`` / ``TenantFilterCode`` (comma-separated values). Every other
    role is pinned to its own tenant code, or -- when its resolved context has no
    tenant code -- to its own area, the same fallback that
    ``_assert_manage_area_scope`` applies to writes. ``OnlyEnabled`` additionally
    filters on ``status`` (1 = enabled, 0 = disabled).

    Returns:
        RagDatasetPageVO holding the requested page and the total match count.

    Raises:
        LeauditException: 403 when a tenant-scoped user requests another tenant,
            or when a non-global user has no tenant/area configured.
    """
    tenant_context = await self._resolve_tenant_context(UserArea=UserArea, TenantCode=TenantCode, TenantName=TenantName)
    managed_area = await self._resolve_managed_area(UserRole=UserRole, UserArea=UserArea, TenantContext=tenant_context)
    managed_tenant_code = self._resolve_managed_tenant_code(UserRole=UserRole, TenantContext=tenant_context)
    filters = ["d.deleted_at IS NULL"]
    params: dict = {
        "offset": max(Page - 1, 0) * PageSize,
        "limit": PageSize,
    }
    requested_tenants = await self._normalize_requested_tenants(Area, TenantFilterCode)
    if managed_tenant_code:
        if any(str(item.get("tenant_code") or "").strip() != managed_tenant_code for item in requested_tenants):
            raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户只能查看本地区知识库配置")
        filters.extend(self._dataset_tenant_filter_sql(managed_tenant_code, managed_area, prefix="managed", alias="d", params=params))
    elif managed_area:
        # Non-global user without a tenant code. Previously this case fell through to
        # the request filters (or to no filter at all) and listed every tenant's
        # datasets. Pin it to the user's own area; request filters cannot widen it.
        filters.extend(self._dataset_tenant_filter_sql(None, managed_area, prefix="managed", alias="d", params=params))
    elif len(requested_tenants) == 1:
        filters.extend(
            self._dataset_tenant_filter_sql(
                requested_tenants[0].get("tenant_code"),
                requested_tenants[0].get("normalized_area"),
                prefix="requested",
                alias="d",
                params=params,
            )
        )
    elif len(requested_tenants) > 1:
        # Several requested tenants: OR together one parenthesised condition per tenant,
        # each with uniquely-prefixed bind parameters.
        tenant_conditions: list[str] = []
        for index, item in enumerate(requested_tenants):
            condition_parts = self._dataset_tenant_filter_sql(
                item.get("tenant_code"),
                item.get("normalized_area"),
                prefix=f"requested_{index}",
                alias="d",
                params=params,
            )
            tenant_conditions.append("(" + " AND ".join(condition_parts) + ")")
        filters.append("(" + " OR ".join(tenant_conditions) + ")")
    if OnlyEnabled is not None:
        filters.append("d.status = :status")
        params["status"] = 1 if OnlyEnabled else 0
    where_sql = " AND ".join(filters)
    # The COUNT query has no OFFSET/LIMIT placeholders, so those binds are dropped.
    count_params = {key: value for key, value in params.items() if key not in ("offset", "limit")}
    async with GetAsyncSession() as session:
        await self._ensure_rag_tenant_schema(session)
        total = (
            await session.execute(
                text(f"SELECT COUNT(1) FROM rag_dataset d WHERE {where_sql}"),
                count_params,
            )
        ).scalar_one()
        rows = (
            await session.execute(
                text(
                    f"""
SELECT d.id, d.name, d.description, d.area,
COALESCE(NULLIF(BTRIM(d.tenant_code), ''), NULL) AS tenant_code,
{self._DATASET_TENANT_NAME_SQL} AS tenant_name,
d.is_public, d.is_default, d.document_count, d.total_chunks, d.status,
d.sort_order, d.created_at, d.updated_at,
a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default
FROM rag_dataset d
{self._APP_LINK_SQL}
WHERE {where_sql}
ORDER BY d.sort_order ASC, d.id ASC
OFFSET :offset LIMIT :limit
"""
                ),
                params,
            )
        ).mappings().all()
    return RagDatasetPageVO(data=[await self._to_item_vo(dict(row)) for row in rows], total=int(total or 0))
async def CreateAdminDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    Body: dict,
) -> RagDatasetDetailVO:
    """Create a dataset (plus its linked chat app) from an admin request body.

    ``Body`` keys read here: ``area``, ``tenant_code``, ``tenant_name``,
    ``dataset_name``/``name``, ``dataset_description``/``description``,
    ``is_public``, ``is_default``, ``sort_order`` and ``status``. Embedding,
    chunking and retrieval settings are copied from the current default (or
    oldest) dataset, falling back to hard-coded defaults when none exists.

    Returns:
        The created dataset's detail VO, or ``None`` if it cannot be re-read.

    Raises:
        LeauditException: 400 when area or name is empty; 403 when the caller
            may not manage the target tenant/area.
    """
    tenant_context = await self._resolve_tenant_context(UserArea=UserArea, TenantCode=TenantCode, TenantName=TenantName)
    area, resolved_tenant_code, _ = await self._resolve_dataset_area_input(
        RawArea=Body.get("area"),
        TenantCode=Body.get("tenant_code"),
        TenantName=Body.get("tenant_name"),
    )
    name = str(Body.get("dataset_name") or Body.get("name") or "").strip()
    description = str(Body.get("dataset_description") or Body.get("description") or "").strip()
    if not area or not name:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "租户/地区和知识库名称不能为空")
    await self._assert_manage_area_scope(
        UserRole=UserRole,
        UserArea=UserArea,
        TenantContext=tenant_context,
        DatasetArea=area,
        DatasetTenantCode=resolved_tenant_code,
    )
    collection_name = self._slugify_collection_name(area, name)
    is_default = bool(Body.get("is_default"))
    raw_status = Body.get("status")
    # `Body.get("status") or 1` turned an explicit 0 (disabled) into 1. Fall back to
    # enabled only when no status was supplied at all.
    status = int(raw_status) if raw_status not in (None, "") else 1
    retrieval_model = {}
    async with GetAsyncSession() as session:
        await self._ensure_rag_tenant_schema(session)
        # Template dataset whose embedding/chunking/retrieval settings are inherited.
        base = (
            await session.execute(
                text(
                    """
SELECT embedding_model, embedding_dim, chunk_max_size, chunk_min_size, retrieval_model
FROM rag_dataset
WHERE deleted_at IS NULL
ORDER BY is_default DESC, id ASC
LIMIT 1
"""
                )
            )
        ).mappings().first()
        if base:
            retrieval_model = base.get("retrieval_model") or {}
        if is_default:
            # Only one default per tenant: clear the others before inserting.
            await self._clear_default_flags(session, tenant_code=resolved_tenant_code)
        row = (
            await session.execute(
                text(
                    """
INSERT INTO rag_dataset (
name, description, area, tenant_code, is_public, is_default, collection_name,
embedding_model, embedding_dim, chunk_max_size, chunk_min_size,
retrieval_model, sort_order, status, created_by, updated_by
) VALUES (
:name, :description, :area, :tenant_code, :is_public, :is_default, :collection_name,
:embedding_model, :embedding_dim, :chunk_max_size, :chunk_min_size,
CAST(:retrieval_model AS jsonb), :sort_order, :status, :created_by, :updated_by
)
RETURNING id
"""
                ),
                {
                    "name": name,
                    "description": description,
                    "area": area,
                    "tenant_code": resolved_tenant_code,
                    "is_public": bool(Body.get("is_public")),
                    "is_default": is_default,
                    "collection_name": collection_name,
                    "embedding_model": (base.get("embedding_model") if base else "text-embedding-v4"),
                    "embedding_dim": (base.get("embedding_dim") if base else 1024),
                    "chunk_max_size": (base.get("chunk_max_size") if base else 800),
                    "chunk_min_size": (base.get("chunk_min_size") if base else 20),
                    "retrieval_model": json.dumps(retrieval_model, ensure_ascii=False),
                    "sort_order": int(Body.get("sort_order") or 0),
                    "status": status,
                    "created_by": CurrentUserId,
                    "updated_by": CurrentUserId,
                },
            )
        ).mappings().first()
        dataset_id = int(row["id"])
        await self._ensure_linked_app(
            session=session,
            dataset_id=dataset_id,
            dataset_name=name,
            dataset_area=area,
            dataset_tenant_code=resolved_tenant_code,
            current_user_id=CurrentUserId,
            is_default=is_default,
        )
    refreshed = await self._get_dataset_row(dataset_id)
    return await self._to_detail_vo(refreshed) if refreshed else None
async def UpdateAdminDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    Body: dict,
) -> RagDatasetDetailVO | None:
    """Update an existing dataset from an admin request body and sync its chat app.

    The caller must be allowed to manage both the dataset's current tenant/area
    and the target one. Keys missing from ``Body`` keep their stored values; an
    explicit ``None`` for ``is_default`` or ``status`` is treated as "not provided".

    Returns:
        The refreshed detail VO, or ``None`` when the dataset does not exist.

    Raises:
        LeauditException: 403 on scope violations; 400 when un-setting the
            tenant's only default dataset.
    """
    tenant_context = await self._resolve_tenant_context(UserArea=UserArea, TenantCode=TenantCode, TenantName=TenantName)
    existing = await self._get_dataset_row(DatasetId)
    if not existing:
        return None
    # The caller must manage the dataset's current scope ...
    await self._assert_manage_area_scope(
        UserRole=UserRole,
        UserArea=UserArea,
        TenantContext=tenant_context,
        DatasetArea=str(existing.get("area") or ""),
        DatasetTenantCode=str(existing.get("tenant_code") or "") or None,
    )
    area, resolved_tenant_code, _ = await self._resolve_dataset_area_input(
        RawArea=Body.get("area") or existing.get("area"),
        TenantCode=Body.get("tenant_code") or existing.get("tenant_code"),
        TenantName=Body.get("tenant_name"),
    )
    # ... and the scope it is being moved to.
    await self._assert_manage_area_scope(
        UserRole=UserRole,
        UserArea=UserArea,
        TenantContext=tenant_context,
        DatasetArea=area,
        DatasetTenantCode=resolved_tenant_code,
    )
    dataset_name = str(Body.get("dataset_name") or Body.get("name") or existing.get("name") or "").strip()
    requested_default = Body.get("is_default")
    # bool(None) used to clear the default flag silently, skipping the
    # "last default" guard below, which only fired on `is False`.
    target_is_default = bool(existing.get("is_default")) if requested_default is None else bool(requested_default)
    requested_status = Body.get("status")
    existing_status = existing.get("status")
    if requested_status is not None:
        target_status = int(requested_status)
    elif existing_status is not None:
        # Keep a stored 0 (disabled); `existing_status or 1` used to re-enable it.
        target_status = int(existing_status)
    else:
        target_status = 1
    async with GetAsyncSession() as session:
        await self._ensure_rag_tenant_schema(session)
        if target_is_default:
            await self._clear_default_flags(session, tenant_code=resolved_tenant_code)
        elif existing.get("is_default"):
            # Un-setting the default is only allowed if another default remains for the tenant.
            normalized_tenant_code = str(resolved_tenant_code or "").strip()
            default_filters = [
                "deleted_at IS NULL",
                "is_default = TRUE",
                "id <> :dataset_id",
            ]
            default_params = {"dataset_id": DatasetId}
            if normalized_tenant_code:
                default_filters.append("tenant_code = :tenant_code")
                default_params["tenant_code"] = normalized_tenant_code
            else:
                default_filters.append("(tenant_code IS NULL OR BTRIM(tenant_code) = '')")
            other_default_count = (
                await session.execute(
                    text(
                        f"""
SELECT COUNT(1)
FROM rag_dataset
WHERE {" AND ".join(default_filters)}
"""
                    ),
                    default_params,
                )
            ).scalar_one()
            if int(other_default_count or 0) <= 0:
                raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不能直接取消,请先将其他知识库设为默认")
        await session.execute(
            text(
                """
UPDATE rag_dataset
SET name = :name,
description = :description,
area = :area,
tenant_code = :tenant_code,
is_public = :is_public,
is_default = :is_default,
sort_order = :sort_order,
status = :status,
updated_by = :updated_by,
updated_at = NOW()
WHERE id = :dataset_id
"""
            ),
            {
                "dataset_id": DatasetId,
                "name": dataset_name,
                "description": str(Body.get("dataset_description") or Body.get("description") or existing.get("description") or "").strip(),
                "area": area,
                "tenant_code": resolved_tenant_code,
                "is_public": bool(Body.get("is_public", existing.get("is_public"))),
                "is_default": target_is_default,
                "sort_order": int(Body.get("sort_order") if Body.get("sort_order") is not None else (existing.get("sort_order") or 0)),
                "status": target_status,
                "updated_by": CurrentUserId,
            },
        )
        await self._ensure_linked_app(
            session=session,
            dataset_id=DatasetId,
            dataset_name=dataset_name,
            dataset_area=area,
            dataset_tenant_code=resolved_tenant_code,
            current_user_id=CurrentUserId,
            is_default=target_is_default,
        )
    refreshed = await self._get_dataset_row(DatasetId)
    return await self._to_detail_vo(refreshed) if refreshed else None
async def DeleteAdminDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
) -> RagOperationResultVO:
    """Soft-delete a dataset and every non-deleted chat app linked to it.

    Raises:
        LeauditException: 404 if the dataset does not exist, 403 if the caller
            may not manage its tenant/area, 400 if it is the default dataset.
    """
    tenant_context = await self._resolve_tenant_context(UserArea=UserArea, TenantCode=TenantCode, TenantName=TenantName)
    dataset_row = await self._get_dataset_row(DatasetId)
    if not dataset_row:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    await self._assert_manage_area_scope(
        UserRole=UserRole,
        UserArea=UserArea,
        TenantContext=tenant_context,
        DatasetArea=str(dataset_row.get("area") or ""),
        DatasetTenantCode=str(dataset_row.get("tenant_code") or "") or None,
    )
    if dataset_row.get("is_default"):
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不允许删除,请先切换默认知识库")
    # Dataset first, then its apps; both share the same bind parameters.
    soft_delete_statements = (
        "UPDATE rag_dataset SET deleted_at = NOW(), updated_by = :updated_by, updated_at = NOW() WHERE id = :dataset_id",
        "UPDATE rag_chat_app SET deleted_at = NOW(), updated_by = :updated_by, updated_at = NOW() WHERE dataset_id = :dataset_id AND deleted_at IS NULL",
    )
    async with GetAsyncSession() as session:
        await self._ensure_rag_tenant_schema(session)
        for statement in soft_delete_statements:
            await session.execute(text(statement), {"dataset_id": DatasetId, "updated_by": CurrentUserId})
    return RagOperationResultVO(result="success")
async def GetMyDatasets(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
) -> RagDatasetPageVO:
    """Return every enabled, non-deleted dataset the caller may see (unpaginated).

    Visibility per row is decided by ``_build_visible_item_vo``. Each row is now
    evaluated once; previously the visibility check, and with it the tenant
    resolution, ran twice per row (once for ``data``, once for ``total``).
    """
    tenant_context = await self._resolve_tenant_context(UserArea=UserArea, TenantCode=TenantCode, TenantName=TenantName)
    async with GetAsyncSession() as session:
        await self._ensure_rag_tenant_schema(session)
        rows = (
            await session.execute(
                text(
                    f"""
SELECT d.id, d.name, d.description, d.area,
COALESCE(NULLIF(BTRIM(d.tenant_code), ''), NULL) AS tenant_code,
{self._DATASET_TENANT_NAME_SQL} AS tenant_name,
d.is_public, d.is_default, d.document_count, d.total_chunks, d.status,
d.sort_order, d.created_at, d.updated_at,
a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default
FROM rag_dataset d
{self._APP_LINK_SQL}
WHERE d.deleted_at IS NULL
AND d.status = 1
ORDER BY d.sort_order ASC, d.created_at DESC
"""
                ),
                {},
            )
        ).mappings().all()
    visible_items = []
    for row in rows:
        item = await self._build_visible_item_vo(dict(row), tenant_context, UserRole)
        if item is not None:
            visible_items.append(item)
    return RagDatasetPageVO(data=visible_items, total=len(visible_items))
async def GetDatasetDetail(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
) -> RagDatasetDetailVO | None:
    """Return the detail VO of an enabled dataset visible to the caller, else ``None``."""
    visible_row = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId)
    if visible_row:
        return await self._to_detail_vo(visible_row)
    return None
async def UpdateDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    Body: RagDatasetUpdateDTO,
) -> RagDatasetDetailVO | None:
    """Update the name and/or retrieval settings of a dataset visible to the caller.

    Only fields set (non-``None``) on ``Body`` are written; ``retrieval_model``
    is stored as JSONB.

    Returns:
        The refreshed detail VO, or ``None`` when the dataset is not visible.
    """
    row = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId)
    if not row:
        return None
    update_fields: list[str] = []
    params: dict = {"dataset_id": DatasetId, "updated_by": CurrentUserId}
    if Body.name is not None:
        update_fields.append("name = :name")
        params["name"] = Body.name.strip()
    if Body.retrieval_model is not None:
        update_fields.append("retrieval_model = CAST(:retrieval_model AS jsonb)")
        params["retrieval_model"] = json.dumps(Body.retrieval_model, ensure_ascii=False)
    if not update_fields:
        # _to_detail_vo is a coroutine function. Without `await` this branch
        # returned an un-awaited coroutine object instead of the VO.
        return await self._to_detail_vo(row)
    update_fields.append("updated_by = :updated_by")
    update_fields.append("updated_at = NOW()")
    async with GetAsyncSession() as session:
        await session.execute(
            text(
                f"""
UPDATE rag_dataset
SET {", ".join(update_fields)}
WHERE id = :dataset_id
"""
            ),
            params,
        )
    refreshed = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId)
    return await self._to_detail_vo(refreshed) if refreshed else None
async def GetDatasetDocuments(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    Page: int,
    Limit: int,
    Keyword: str | None,
) -> RagDatasetDocumentPageVO:
    """Return one page of non-deleted documents from a dataset visible to the caller.

    ``Keyword`` is matched case-insensitively as a literal substring of
    ``original_name`` or ``filename``. One extra row is fetched so ``hasMore``
    can be computed without a second query.

    Raises:
        LeauditException: 404 when the dataset does not exist or is not visible.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    where_sql = [
        "dataset_id = :dataset_id",
        "deleted_at IS NULL",
    ]
    params: dict = {
        "dataset_id": DatasetId,
        "offset": max(Page - 1, 0) * Limit,
        "limit": Limit + 1,
    }
    if Keyword and Keyword.strip():
        # Escape LIKE metacharacters so "%" / "_" typed by the user match literally.
        # Backslash is PostgreSQL's default LIKE escape character, so no ESCAPE clause is needed.
        escaped_keyword = Keyword.strip().replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        where_sql.append("(original_name ILIKE :keyword OR filename ILIKE :keyword)")
        params["keyword"] = f"%{escaped_keyword}%"
    async with GetAsyncSession() as session:
        total = (
            await session.execute(
                text(
                    f"""
SELECT COUNT(1)
FROM rag_document
WHERE {" AND ".join(where_sql)}
"""
                ),
                {key: value for key, value in params.items() if key not in ("offset", "limit")},
            )
        ).scalar_one()
        rows = (
            await session.execute(
                text(
                    f"""
SELECT id, dataset_id, original_name, file_type, file_size, chunk_count,
indexing_status, COALESCE(indexing_error, '') AS indexing_error,
enabled, hit_count, created_by, created_at, updated_at
FROM rag_document
WHERE {" AND ".join(where_sql)}
ORDER BY created_at DESC
OFFSET :offset LIMIT :limit
"""
                ),
                params,
            )
        ).mappings().all()
    has_more = len(rows) > Limit
    items = rows[:Limit]
    return RagDatasetDocumentPageVO(
        data=[
            RagDatasetDocumentItemVO(
                id=row["id"],
                datasetId=row["dataset_id"],
                name=row.get("original_name") or "",
                fileType=row.get("file_type") or "",
                fileSize=row.get("file_size") or 0,
                chunkCount=row.get("chunk_count") or 0,
                indexingStatus=row.get("indexing_status") or "waiting",
                error=row.get("indexing_error") or "",
                enabled=bool(row.get("enabled")),
                hitCount=row.get("hit_count") or 0,
                createdBy=row.get("created_by"),
                createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
                updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0,
            )
            for row in items
        ],
        total=int(total or 0),
        page=Page,
        limit=Limit,
        hasMore=has_more,
    )
async def GetDatasetDocumentDetail(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    DocumentId: int,
) -> RagDatasetDocumentItemVO | None:
    """Return one non-deleted document of a visible dataset, or ``None``.

    ``None`` is returned both when the dataset is not visible to the caller and
    when the document does not belong to it.
    """
    if not await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId):
        return None
    async with GetAsyncSession() as session:
        result = await session.execute(
            text(
                """
SELECT id, dataset_id, original_name, file_type, file_size, chunk_count,
indexing_status, COALESCE(indexing_error, '') AS indexing_error,
enabled, hit_count, created_by, created_at, updated_at
FROM rag_document
WHERE id = :document_id
AND dataset_id = :dataset_id
AND deleted_at IS NULL
LIMIT 1
"""
            ),
            {"document_id": DocumentId, "dataset_id": DatasetId},
        )
        document = result.mappings().first()
    if not document:
        return None

    def _epoch(column: str) -> int:
        # Datetime column -> integer Unix seconds; missing/NULL -> 0.
        moment = document.get(column)
        return int(moment.timestamp()) if moment else 0

    return RagDatasetDocumentItemVO(
        id=document["id"],
        datasetId=document["dataset_id"],
        name=document.get("original_name") or "",
        fileType=document.get("file_type") or "",
        fileSize=document.get("file_size") or 0,
        chunkCount=document.get("chunk_count") or 0,
        indexingStatus=document.get("indexing_status") or "waiting",
        error=document.get("indexing_error") or "",
        enabled=bool(document.get("enabled")),
        hitCount=document.get("hit_count") or 0,
        createdBy=document.get("created_by"),
        createdAt=_epoch("created_at"),
        updatedAt=_epoch("updated_at"),
    )
async def _get_visible_dataset(
    self,
    user_area: str | None,
    user_role: str | None,
    tenant_code: str | None,
    tenant_name: str | None,
    dataset_id: int,
) -> dict | None:
    """Load an enabled, non-deleted dataset row and return it if the caller may see it.

    Returns a plain dict (dataset columns, display tenant name and linked-app
    columns), or ``None`` when the dataset is missing/disabled or not visible.
    """
    async with GetAsyncSession() as session:
        await self._ensure_rag_tenant_schema(session)
        result = await session.execute(
            text(
                f"""
SELECT d.id, d.name, d.description, d.area,
COALESCE(NULLIF(BTRIM(d.tenant_code), ''), NULL) AS tenant_code,
{self._DATASET_TENANT_NAME_SQL} AS tenant_name,
d.is_public, d.is_default, d.status,
d.document_count, d.total_chunks, d.chunk_max_size, d.chunk_min_size, d.sort_order,
d.retrieval_model, d.collection_name, d.embedding_model, d.created_at, d.updated_at,
a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default
FROM rag_dataset d
{self._APP_LINK_SQL}
WHERE d.id = :dataset_id
AND d.deleted_at IS NULL
AND d.status = 1
LIMIT 1
"""
            ),
            {"dataset_id": dataset_id},
        )
        found = result.mappings().first()
    if not found:
        return None
    candidate = dict(found)
    is_visible = await self._dataset_visible(
        candidate,
        UserArea=user_area,
        UserRole=user_role,
        TenantCode=tenant_code,
        TenantName=tenant_name,
    )
    return candidate if is_visible else None
async def _to_detail_vo(self, row: dict) -> RagDatasetDetailVO:
    """Map a dataset row (dataset columns plus linked-app columns) to a detail VO.

    ``status`` is passed through when present, including 0 (disabled). It only
    defaults to 1 when the value is missing/NULL; the previous
    ``row.get("status") or 1`` reported disabled datasets as enabled.
    Other missing numeric settings get the same defaults as before.
    """
    status = row.get("status")
    return RagDatasetDetailVO(
        id=row["id"],
        name=row["name"],
        description=row.get("description") or "",
        area=row.get("area") or "",
        tenantCode=str(row.get("tenant_code") or ""),
        tenantName=str(row.get("tenant_name") or row.get("area") or ""),
        isPublic=bool(row.get("is_public")),
        isDefault=bool(row.get("is_default")),
        status=status if status is not None else 1,
        documentCount=row.get("document_count") or 0,
        sortOrder=row.get("sort_order") or 0,
        totalChunks=row.get("total_chunks") or 0,
        chunkMaxSize=row.get("chunk_max_size") or 800,
        chunkMinSize=row.get("chunk_min_size") or 20,
        retrievalModel=row.get("retrieval_model") or {},
        createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
        updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0,
        appId=row.get("app_id"),
        appName=row.get("app_name") or "",
        appIsDefault=bool(row.get("app_is_default")),
    )
async def _to_item_vo(self, row: dict) -> RagDatasetItemVO:
    """Map a dataset row (dataset columns plus linked-app columns) to a list-item VO.

    ``status`` is passed through when present, including 0 (disabled). It only
    defaults to 1 when missing/NULL; the previous ``row.get("status") or 1``
    showed disabled datasets as enabled in admin lists.
    """
    status = row.get("status")
    return RagDatasetItemVO(
        id=row["id"],
        name=row.get("name") or "",
        description=row.get("description") or "",
        area=row.get("area") or "",
        tenantCode=str(row.get("tenant_code") or ""),
        tenantName=str(row.get("tenant_name") or row.get("area") or ""),
        isPublic=bool(row.get("is_public")),
        isDefault=bool(row.get("is_default")),
        documentCount=row.get("document_count") or 0,
        totalChunks=row.get("total_chunks") or 0,
        status=status if status is not None else 1,
        sortOrder=row.get("sort_order") or 0,
        createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
        updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0,
        appId=row.get("app_id"),
        appName=row.get("app_name") or "",
        appIsDefault=bool(row.get("app_is_default")),
    )
async def _get_dataset_row(self, dataset_id: int) -> dict | None:
    """Fetch one non-deleted dataset (any status) with display tenant name and linked app.

    Returns a plain dict, or ``None`` when no such dataset exists.
    """
    query = text(
        f"""
SELECT d.id, d.name, d.description, d.area,
COALESCE(NULLIF(BTRIM(d.tenant_code), ''), NULL) AS tenant_code,
{self._DATASET_TENANT_NAME_SQL} AS tenant_name,
d.is_public, d.is_default, d.status,
d.document_count, d.total_chunks, d.chunk_max_size, d.chunk_min_size, d.sort_order,
d.retrieval_model, d.collection_name, d.embedding_model, d.created_at, d.updated_at,
a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default
FROM rag_dataset d
{self._APP_LINK_SQL}
WHERE d.id = :dataset_id AND d.deleted_at IS NULL
LIMIT 1
"""
    )
    async with GetAsyncSession() as session:
        await self._ensure_rag_tenant_schema(session)
        result = await session.execute(query, {"dataset_id": dataset_id})
        found = result.mappings().first()
    if not found:
        return None
    return dict(found)
async def _resolve_tenant_context(
    self,
    *,
    UserArea: str | None,
    TenantCode: str | None,
    TenantName: str | None,
) -> dict[str, str | None]:
    """Resolve the caller's tenant via TenantResolver into a plain dict.

    Keys: ``tenant_code``, ``tenant_name``, ``tenant_type`` and ``area``. The
    area is the resolver's normalized value, falling back to the raw ``UserArea``.
    """
    context = await self.TenantResolver.ResolveUserContext(
        Area=UserArea,
        TenantCode=TenantCode,
        TenantName=TenantName,
        Source="rag_dataset_user",
    )
    return dict(
        tenant_code=context.tenant_code,
        tenant_name=context.tenant_name,
        tenant_type=context.tenant_type,
        area=context.normalized_value or UserArea,
    )
async def _resolve_record_tenant(self, raw_value: str | None):
    """Resolve a raw tenant/area value stored on a record through TenantResolver."""
    resolver = self.TenantResolver
    return await resolver.Resolve(RawValue=raw_value, Source="rag_dataset_record")
async def _dataset_visible(
    self,
    row: dict,
    *,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
) -> bool:
    """Return whether the caller may see the dataset ``row``.

    Global roles and public datasets are always visible. Otherwise the caller's
    tenant context is resolved and matched against the row's tenant_code/area.
    """
    if self._role_is_global(UserRole) or bool(row.get("is_public")):
        return True
    caller_context = await self._resolve_tenant_context(UserArea=UserArea, TenantCode=TenantCode, TenantName=TenantName)
    return self._row_matches_tenant_scope(
        row_tenant_code=row.get("tenant_code"),
        row_area=row.get("area"),
        tenant_context=caller_context,
    )
async def _build_visible_item_vo(
    self,
    row: dict,
    tenant_context: dict[str, str | None],
    user_role: str | None,
) -> RagDatasetItemVO | None:
    """Return the item VO for ``row`` when it is visible to the given tenant context, else ``None``."""
    visible = await self._dataset_visible(
        row,
        UserArea=tenant_context.get("area"),
        UserRole=user_role,
        TenantCode=tenant_context.get("tenant_code"),
        TenantName=tenant_context.get("tenant_name"),
    )
    if visible:
        return await self._to_item_vo(row)
    return None
async def _normalize_requested_tenants(
    self,
    area_filter: str | None,
    tenant_code_filter: str | None = None,
) -> list[dict[str, str]]:
    """Resolve comma-separated area / tenant-code filters into tenant descriptors.

    Area values are resolved first, then tenant codes. Each entry is a dict with
    ``tenant_code``, ``tenant_name`` and ``normalized_area``. Entries are
    de-duplicated by tenant code (or by area when no code is known), keeping the
    first occurrence.
    """
    resolved_entries: list[dict[str, str]] = []
    taken: set[str] = set()

    def _add(code: str, display_name: str, area_label: str) -> None:
        # Dedupe on the tenant code when known, otherwise on the area label.
        marker = code or area_label
        if marker in taken:
            return
        taken.add(marker)
        resolved_entries.append(
            {
                "tenant_code": code,
                "tenant_name": display_name,
                "normalized_area": area_label,
            }
        )

    raw_areas = [piece.strip() for piece in str(area_filter or "").split(",") if piece.strip()]
    raw_codes = [piece.strip() for piece in str(tenant_code_filter or "").split(",") if piece.strip()]
    for raw_area in raw_areas:
        outcome = await self.TenantResolver.Resolve(
            RawValue=raw_area,
            Source="rag_dataset_filter",
        )
        area_label = outcome.tenant_name or outcome.normalized_value or raw_area
        _add(outcome.tenant_code or "", outcome.tenant_name or area_label, area_label)
    for raw_code in raw_codes:
        outcome = await self.TenantResolver.Resolve(
            RawValue=None,
            Source="rag_dataset_filter_code",
            PreferredTenantCode=raw_code,
        )
        area_label = (
            outcome.tenant_name
            or self._fallback_area_from_tenant_code(raw_code)
            or outcome.normalized_value
            or raw_code
        )
        _add(outcome.tenant_code or raw_code, outcome.tenant_name or area_label, area_label)
    return resolved_entries
async def _resolve_dataset_area_input(
    self,
    *,
    RawArea: str | None,
    TenantCode: str | None,
    TenantName: str | None,
) -> tuple[str, str | None, object]:
    """Resolve an admin-supplied area/tenant into ``(area, tenant_code, resolution)``.

    The area prefers the resolved tenant name, then the normalized value, then
    the stripped raw input. A blank ``TenantCode`` is passed on as ``None``.
    """
    preferred_code = str(TenantCode or "").strip() or None
    resolution = await self.TenantResolver.Resolve(
        RawValue=RawArea,
        Source="rag_dataset_input",
        PreferredTenantCode=preferred_code,
        FallbackTenantName=TenantName,
    )
    area = resolution.tenant_name or resolution.normalized_value or str(RawArea or "").strip()
    return area, resolution.tenant_code, resolution
async def _clear_default_flags(self, session, tenant_code: str | None = None) -> None:
    """Unset ``is_default`` on one tenant's non-deleted datasets and chat apps.

    A blank or ``None`` ``tenant_code`` targets rows whose tenant_code is NULL or
    blank. The updates run inside the caller's ``session``.
    """
    await self._ensure_rag_tenant_schema(session)
    scoped_code = str(tenant_code or "").strip()
    if not scoped_code:
        for statement in (
            """
UPDATE rag_dataset
SET is_default = FALSE
WHERE deleted_at IS NULL
AND (tenant_code IS NULL OR BTRIM(tenant_code) = '')
""",
            """
UPDATE rag_chat_app
SET is_default = FALSE
WHERE deleted_at IS NULL
AND (tenant_code IS NULL OR BTRIM(tenant_code) = '')
""",
        ):
            await session.execute(text(statement))
        return
    for statement in (
        """
UPDATE rag_dataset
SET is_default = FALSE
WHERE deleted_at IS NULL
AND tenant_code = :tenant_code
""",
        """
UPDATE rag_chat_app
SET is_default = FALSE
WHERE deleted_at IS NULL
AND tenant_code = :tenant_code
""",
    ):
        await session.execute(text(statement), {"tenant_code": scoped_code})
async def _ensure_linked_app(
    self,
    session,
    dataset_id: int,
    dataset_name: str,
    dataset_area: str,
    dataset_tenant_code: str | None,
    current_user_id: int,
    is_default: bool,
) -> None:
    """Create or refresh the chat app bound to ``dataset_id`` within ``session``.

    The first non-deleted app for the dataset (default first, then sort_order,
    then id) is updated in place and set to status 1. If none exists, a new app
    is inserted with status 1, sort_order 0 and no suggested questions. Name,
    description, area, tenant code and default flag are derived from the dataset.
    """
    await self._ensure_rag_tenant_schema(session)
    existing_app = (
        await session.execute(
            text(
                """
SELECT id
FROM rag_chat_app
WHERE dataset_id = :dataset_id
AND deleted_at IS NULL
ORDER BY is_default DESC, sort_order ASC, id ASC
LIMIT 1
"""
            ),
            {"dataset_id": dataset_id},
        )
    ).mappings().first()
    display_name = self._build_app_name(dataset_area=dataset_area, dataset_name=dataset_name)
    app_description = f"{dataset_area or '默认地区'}知识库问答助手"
    app_area = dataset_area or ""
    if not existing_app:
        await session.execute(
            text(
                """
INSERT INTO rag_chat_app (
name, description, area, tenant_code, dataset_id, suggested_questions,
opening_statement, sort_order, status, is_default, created_by, updated_by
) VALUES (
:name, :description, :area, :tenant_code, :dataset_id, CAST(:suggested_questions AS jsonb),
:opening_statement, 0, 1, :is_default, :created_by, :updated_by
)
"""
            ),
            {
                "name": display_name,
                "description": app_description,
                "area": app_area,
                "tenant_code": dataset_tenant_code,
                "dataset_id": dataset_id,
                "suggested_questions": json.dumps([], ensure_ascii=False),
                "opening_statement": f"您好,我是{display_name}",
                "is_default": is_default,
                "created_by": current_user_id,
                "updated_by": current_user_id,
            },
        )
        return
    await session.execute(
        text(
            """
UPDATE rag_chat_app
SET name = :name,
description = :description,
area = :area,
tenant_code = :tenant_code,
is_default = :is_default,
status = 1,
updated_by = :updated_by,
updated_at = NOW()
WHERE id = :app_id
"""
        ),
        {
            "app_id": int(existing_app["id"]),
            "name": display_name,
            "description": app_description,
            "area": app_area,
            "tenant_code": dataset_tenant_code,
            "is_default": is_default,
            "updated_by": current_user_id,
        },
    )
def _build_app_name(self, dataset_area: str, dataset_name: str) -> str:
cleaned = (dataset_name or "").strip()
if cleaned.endswith("知识库"):
cleaned = cleaned[:-3]
if cleaned.endswith("助手"):
return cleaned
if cleaned:
return f"{cleaned}助手"
return f"{dataset_area or '默认地区'}法务助手"
def _slugify_collection_name(self, area: str, name: str) -> str:
source = f"{area}_{name}".lower()
normalized = re.sub(r"[^a-z0-9]+", "_", source).strip("_")
if normalized:
return f"legal_kb_{normalized}"[:96]
return f"legal_kb_{uuid.uuid4().hex[:12]}"
@staticmethod
def _fallback_area_from_tenant_code(tenant_code: str | None) -> str | None:
normalized = str(tenant_code or "").strip().upper()
if normalized == "PUBLIC":
return "公共"
if normalized == "PROVINCIAL":
return "省级"
return None
def _resolve_managed_tenant_code(self, UserRole: str | None, TenantContext: dict[str, str | None]) -> str | None:
if self._role_is_global(UserRole):
return None
tenant_code = str(TenantContext.get("tenant_code") or "").strip()
return tenant_code or None
async def _resolve_managed_area(self, UserRole: str | None, UserArea: str | None, TenantContext: dict[str, str | None]) -> str | None:
if self._role_is_global(UserRole):
return None
area = str(TenantContext.get("area") or UserArea or "").strip()
if not area:
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户未配置租户/地区,无法管理知识库")
return area
async def _assert_manage_area_scope(
self,
UserRole: str | None,
UserArea: str | None,
TenantContext: dict[str, str | None],
DatasetArea: str,
DatasetTenantCode: str | None = None,
) -> None:
if self._role_is_global(UserRole):
return
managed_tenant_code = self._resolve_managed_tenant_code(UserRole=UserRole, TenantContext=TenantContext)
if managed_tenant_code:
if str(DatasetTenantCode or "").strip() != managed_tenant_code:
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户只能管理本租户知识库")
return
managed_area = await self._resolve_managed_area(UserRole=UserRole, UserArea=UserArea, TenantContext=TenantContext)
if DatasetArea != managed_area:
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户只能管理本地区知识库")
async def _ensure_rag_tenant_schema(self, session) -> None:
    """Make sure rag_dataset and rag_chat_app have a tenant_code column and index.

    Uses double-checked locking: a class-level flag skips the work after the
    first success, and a class-level asyncio lock keeps concurrent coroutines in
    this process from racing. That lock does not cover other worker processes,
    so the column DDL uses ``IF NOT EXISTS`` (as the index DDL already did).
    Otherwise a second worker could fail with "column already exists" after
    another worker added the column between the check and the ALTER. The DDL
    runs in the caller's session with a 1-second lock timeout.
    """
    owner = self.__class__
    if owner._tenant_schema_checked:
        return
    async with owner._tenant_schema_lock:
        if owner._tenant_schema_checked:
            return
        columns = (
            await session.execute(
                text(
                    """
SELECT table_name
FROM information_schema.columns
WHERE table_schema = current_schema()
AND table_name IN ('rag_dataset', 'rag_chat_app')
AND column_name = 'tenant_code'
"""
                )
            )
        ).scalars().all()
        existing = set(columns)
        if existing == {"rag_dataset", "rag_chat_app"}:
            owner._tenant_schema_checked = True
            return
        # Fail fast instead of queueing behind long-running transactions on these tables.
        await session.execute(text("SET LOCAL lock_timeout = '1000ms'"))
        if "rag_dataset" not in existing:
            await session.execute(text("ALTER TABLE rag_dataset ADD COLUMN IF NOT EXISTS tenant_code VARCHAR(64) NULL"))
        if "rag_chat_app" not in existing:
            await session.execute(text("ALTER TABLE rag_chat_app ADD COLUMN IF NOT EXISTS tenant_code VARCHAR(64) NULL"))
        await session.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_dataset_tenant_code ON rag_dataset(tenant_code) WHERE deleted_at IS NULL"))
        await session.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chat_app_tenant_code ON rag_chat_app(tenant_code) WHERE deleted_at IS NULL"))
        owner._tenant_schema_checked = True
def _dataset_tenant_filter_sql(
self,
tenant_code: str | None,
area: str | None,
*,
prefix: str,
alias: str,
params: dict,
) -> list[str]:
normalized_tenant_code = str(tenant_code or "").strip()
normalized_area = str(area or "").strip()
if normalized_tenant_code:
params[f"{prefix}_tenant_code"] = normalized_tenant_code
return [f"{alias}.tenant_code = :{prefix}_tenant_code"]
if normalized_area:
params[f"{prefix}_area"] = normalized_area
return [f"COALESCE({alias}.area, '') = :{prefix}_area"]
return ["1 = 0"]
@staticmethod
def _tenant_context_is_global(tenant_context: dict[str, str | None]) -> bool:
tenant_code = str(tenant_context.get("tenant_code") or "").strip().upper()
return tenant_code in {"PUBLIC", "PROVINCIAL"}
@staticmethod
def _role_is_global(user_role: str | None) -> bool:
normalized = str(user_role or "").strip()
return normalized in {"super_admin", "provincial_admin"}
def _row_matches_tenant_scope(
self,
*,
row_tenant_code: str | None,
row_area: str | None,
tenant_context: dict[str, str | None],
) -> bool:
user_tenant_code = str(tenant_context.get("tenant_code") or "").strip()
if user_tenant_code:
return str(row_tenant_code or "").strip() == user_tenant_code
return str(row_area or "").strip() == str(tenant_context.get("area") or "").strip()
async def UploadDatasetDocument(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    FileName: str,
    ContentType: str | None,
    Content: bytes,
    ProcessConfig: dict | None,
) -> RagDatasetUploadDocumentVO:
    """Store an uploaded file in a visible dataset and schedule its indexing.

    The bytes go to object storage, a ``rag_document`` row is inserted with status
    ``indexing``, and ``_run_document_indexing_task`` is started in the background.
    The call returns right away with the new document id.

    Raises:
        LeauditException: 404 if the dataset is not visible to the caller,
            400 if the file extension is not supported.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    suffix = os.path.splitext(FileName)[1].lower()
    if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON")
    # FileName is client-controlled and may carry directory parts such as "../x.pdf"
    # or "C:\fakepath\x.pdf"; only its last segment may appear in the object key.
    key_name = re.split(r"[\\/]", FileName)[-1]
    object_key = f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{key_name}"
    content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
    OssClient().EnsureBucket()
    stored_key = OssClient().UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type)
    async with GetAsyncSession() as session:
        inserted = (
            await session.execute(
                text(
                    """
                    INSERT INTO rag_document (
                        dataset_id, filename, original_name, minio_path, file_type, file_size,
                        chunk_count, indexing_status, enabled, hit_count, created_by
                    ) VALUES (
                        :dataset_id, :filename, :original_name, :minio_path, :file_type, :file_size,
                        0, 'indexing', TRUE, 0, :created_by
                    )
                    RETURNING id, created_at, updated_at
                    """
                ),
                {
                    "dataset_id": DatasetId,
                    "filename": uuid.uuid4().hex + suffix,
                    "original_name": FileName,
                    "minio_path": stored_key,
                    "file_type": suffix.lstrip("."),
                    "file_size": len(Content),
                    "created_by": CurrentUserId,
                },
            )
        ).mappings().first()
    document_id = int(inserted["id"])
    asyncio.create_task(
        self._run_document_indexing_task(
            dataset=dataset,
            dataset_id=DatasetId,
            document_id=document_id,
            file_name=FileName,
            content=Content,
            process_config=ProcessConfig or {},
        )
    )
    return RagDatasetUploadDocumentVO(
        document={
            "id": str(document_id),
            "name": FileName,
            "indexing_status": "indexing",
            "word_count": 0,
            "hit_count": 0,
            "enabled": True,
        },
        batch=str(document_id),
    )
async def GetDatasetDocumentSegments(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    DocumentId: int,
    Page: int,
    Limit: int,
    Keyword: str | None,
) -> RagDatasetSegmentPageVO:
    """List a document's chunks from the dataset's Chroma collection, paginated.

    ``position`` is the 1-based order in which Chroma returned the chunk. It is
    assigned before keyword filtering. ``Keyword`` (stripped) keeps only chunks
    whose text contains it. Pagination is done in memory after filtering.

    Raises:
        LeauditException: 404 if the dataset is not visible to the caller.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    raw = collection.get(where={"document_id": DocumentId}, include=["documents", "metadatas"])
    segment_ids = raw.get("ids") or []
    documents = raw.get("documents") or []
    keyword = (Keyword or "").strip()
    matches: list[tuple[str, int, str]] = []
    for index, segment_id in enumerate(segment_ids):
        # Chroma can hold a record without a document body (None); treat it as "".
        content = (documents[index] if index < len(documents) else "") or ""
        if keyword and keyword not in content:
            continue
        matches.append((str(segment_id), index + 1, content))
    offset = max(Page - 1, 0) * Limit
    page_matches = matches[offset: offset + Limit]
    return RagDatasetSegmentPageVO(
        data=[
            RagDatasetSegmentItemVO(
                id=segment_id,
                position=position,
                documentId=str(DocumentId),
                content=content,
                wordCount=len(content),
                hitCount=0,
                enabled=True,
                status="completed",
                createdAt=0,
            )
            for segment_id, position, content in page_matches
        ],
        total=len(matches),
        limit=Limit,
        hasMore=offset + Limit < len(matches),
    )
async def GetDatasetDocumentSegmentDetail(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    DocumentId: int,
    SegmentId: str,
) -> RagDatasetSegmentItemVO | None:
    """Fetch one chunk by id, or ``None`` if it is not visible.

    ``None`` is returned when the dataset is not visible, the chunk id is unknown,
    or the chunk's ``document_id`` metadata is not ``DocumentId``.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId)
    if not dataset:
        return None
    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    record = collection.get(ids=[SegmentId], include=["documents", "metadatas"])
    if not (record.get("ids") or []):
        return None
    bodies = record.get("documents") or [""]
    metas = record.get("metadatas") or [{}]
    segment_text = bodies[0] or ""
    meta = metas[0] or {}
    # Guard against reading a chunk of another document through this URL.
    if str(meta.get("document_id") or "") != str(DocumentId):
        return None
    return RagDatasetSegmentItemVO(
        id=str(SegmentId),
        position=int(meta.get("chunk_index") or 0) + 1,
        documentId=str(DocumentId),
        content=segment_text,
        wordCount=len(segment_text),
        hitCount=0,
        enabled=True,
        status="completed",
        createdAt=0,
    )
async def UpdateDatasetDocumentSegment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    DocumentId: int,
    SegmentId: str,
    Body: dict,
) -> RagDatasetSegmentItemVO | None:
    """Replace a chunk's text and re-embed it in place.

    The new text is read from ``Body["segment"]["content"]`` when ``segment`` is a
    dict, otherwise from ``Body["content"]``. If it is missing, the current text is
    kept. Chunk metadata is preserved. Returns ``None`` when the dataset or chunk
    is not visible.

    Raises:
        LeauditException: 400 if the resulting text is blank.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId)
    if not dataset:
        return None
    existing = await self.GetDatasetDocumentSegmentDetail(
        CurrentUserId=CurrentUserId,
        UserArea=UserArea,
        UserRole=UserRole,
        TenantCode=TenantCode,
        TenantName=TenantName,
        DatasetId=DatasetId,
        DocumentId=DocumentId,
        SegmentId=SegmentId,
    )
    if not existing:
        return None
    payload = Body.get("segment")
    if not isinstance(payload, dict):
        payload = Body
    new_text = str(payload.get("content") or existing.content).strip()
    if not new_text:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "分段内容不能为空")
    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    stored = collection.get(ids=[SegmentId], include=["metadatas"])
    meta = (stored.get("metadatas") or [{}])[0] or {}
    vectors = await self._embed_texts([new_text], dataset.get("embedding_model") or "")
    collection.update(ids=[SegmentId], documents=[new_text], embeddings=vectors, metadatas=[meta])
    return RagDatasetSegmentItemVO(
        id=str(SegmentId),
        position=int(meta.get("chunk_index") or 0) + 1,
        documentId=str(DocumentId),
        content=new_text,
        wordCount=len(new_text),
        hitCount=0,
        enabled=True,
        status="completed",
        createdAt=0,
    )
async def DeleteDatasetDocumentSegment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    DocumentId: int,
    SegmentId: str,
) -> RagOperationResultVO:
    """Delete one chunk from Chroma and decrement the document/dataset counters.

    Raises:
        LeauditException: 404 if the dataset or the chunk (within ``DocumentId``)
            is not visible.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    current = await self.GetDatasetDocumentSegmentDetail(
        CurrentUserId=CurrentUserId,
        UserArea=UserArea,
        UserRole=UserRole,
        TenantCode=TenantCode,
        TenantName=TenantName,
        DatasetId=DatasetId,
        DocumentId=DocumentId,
        SegmentId=SegmentId,
    )
    if not current:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "分段不存在")
    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    collection.delete(ids=[SegmentId])
    async with GetAsyncSession() as session:
        # Scope the counter update to the dataset as well, so a mismatched
        # DocumentId can never touch a document from another dataset.
        await session.execute(
            text(
                """
                UPDATE rag_document
                SET chunk_count = GREATEST(chunk_count - 1, 0),
                    updated_at = NOW()
                WHERE id = :document_id
                  AND dataset_id = :dataset_id
                """
            ),
            {"document_id": DocumentId, "dataset_id": DatasetId},
        )
        await session.execute(
            text(
                """
                UPDATE rag_dataset
                SET total_chunks = GREATEST(total_chunks - 1, 0),
                    updated_at = NOW()
                WHERE id = :dataset_id
                """
            ),
            {"dataset_id": DatasetId},
        )
    return RagOperationResultVO(result="success")
async def DeleteDatasetDocument(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    DocumentId: int,
) -> RagOperationResultVO:
    """Delete a single document through ``BatchDeleteDatasetDocuments``.

    Raises:
        LeauditException: 400 with the first skip reason (or a generic message)
            when nothing was deleted, plus any error the batch call raises.
    """
    outcome = await self.BatchDeleteDatasetDocuments(
        CurrentUserId=CurrentUserId,
        UserArea=UserArea,
        UserRole=UserRole,
        TenantCode=TenantCode,
        TenantName=TenantName,
        DatasetId=DatasetId,
        DocumentIds=[DocumentId],
    )
    if outcome.deletedCount > 0:
        return RagOperationResultVO(result="success")
    failure_reason = outcome.skipped[0].reason if outcome.skipped else "文档删除失败"
    raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, failure_reason)
async def BatchDeleteDatasetDocuments(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    DocumentIds: list[int],
) -> RagDatasetBatchDeleteResultVO:
    """Soft-delete several documents of one dataset.

    Steps:
        1. Skip documents whose status is not in ``_DELETABLE_DOCUMENT_STATUSES``.
        2. Remove the Chroma vectors and stored files of the deletable ones.
        3. Set ``deleted_at`` on them and recompute the dataset's totals.

    Requested ids that do not exist in the dataset are reported in ``skipped``.
    Duplicate ids are counted once.

    Raises:
        LeauditException: 404 if the dataset is not visible or none of the ids
            exist, 400 if no usable id was supplied.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    # dict.fromkeys de-duplicates while keeping the caller's order.
    normalized_ids = list(dict.fromkeys(int(item) for item in DocumentIds if item))
    if not normalized_ids:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "缺少待删除文档")
    async with GetAsyncSession() as session:
        document_rows = (
            await session.execute(
                text(
                    """
                    SELECT id, dataset_id, minio_path, original_name, indexing_status
                    FROM rag_document
                    WHERE id = ANY(:document_ids)
                      AND dataset_id = :dataset_id
                      AND deleted_at IS NULL
                    """
                ),
                {"document_ids": normalized_ids, "dataset_id": DatasetId},
            )
        ).mappings().all()
    if not document_rows:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")
    found_ids = {int(row["id"]) for row in document_rows}
    deletable_rows = []
    skipped = []
    for row in document_rows:
        if row.get("indexing_status") not in self._DELETABLE_DOCUMENT_STATUSES:
            skipped.append(
                {
                    "id": int(row["id"]),
                    "name": str(row.get("original_name") or ""),
                    "reason": "文档仍在处理中,暂不允许删除",
                }
            )
            continue
        deletable_rows.append(row)
    # Report ids that do not exist in this dataset instead of dropping them silently.
    skipped.extend(
        {"id": doc_id, "name": "", "reason": "文档不存在"}
        for doc_id in normalized_ids
        if doc_id not in found_ids
    )
    if not deletable_rows:
        return RagDatasetBatchDeleteResultVO(
            result="success",
            requestedCount=len(normalized_ids),
            deletedCount=0,
            skippedCount=len(skipped),
            deletedIds=[],
            skipped=skipped,
        )
    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    chunk_ids: list[str] = []
    for document_row in deletable_rows:
        raw = collection.get(where={"document_id": int(document_row["id"])}, include=[])
        chunk_ids.extend(raw.get("ids") or [])
    if chunk_ids:
        collection.delete(ids=chunk_ids)
    for document_row in deletable_rows:
        self._delete_oss_object(document_row.get("minio_path"))
    existing_ids = [int(row["id"]) for row in deletable_rows]
    async with GetAsyncSession() as session:
        await session.execute(
            text(
                """
                UPDATE rag_document
                SET deleted_at = NOW(),
                    updated_at = NOW()
                WHERE id = ANY(:document_ids)
                """
            ),
            {"document_ids": existing_ids},
        )
        await session.execute(
            text(
                """
                UPDATE rag_dataset
                SET document_count = (
                        SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
                    ),
                    total_chunks = COALESCE((
                        SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
                    ), 0),
                    updated_at = NOW()
                WHERE id = :dataset_id
                """
            ),
            {"dataset_id": DatasetId},
        )
    return RagDatasetBatchDeleteResultVO(
        result="success",
        requestedCount=len(normalized_ids),
        deletedCount=len(existing_ids),
        skippedCount=len(skipped),
        deletedIds=existing_ids,
        skipped=skipped,
    )
async def RetrieveDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    Query: str,
    RetrievalModel: dict | None,
) -> RagDatasetRetrieveResponseVO:
    """Run a vector similarity search against one visible dataset.

    ``RetrievalModel`` options:
        * ``top_k`` — clamped to 1..20, default 5.
        * ``score_threshold_enabled`` / ``score_threshold`` — drop weaker hits.

    Each Chroma distance ``d`` is mapped to a score ``1 / (1 + max(d, 0))``.

    Raises:
        LeauditException: 404 if the dataset is not visible, 400 on an empty query.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    query_text = (Query or "").strip()
    if not query_text:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "检索内容不能为空")
    retrieval_model = RetrievalModel or {}
    top_k = max(1, min(int(retrieval_model.get("top_k") or 5), 20))
    score_threshold_enabled = bool(retrieval_model.get("score_threshold_enabled"))
    score_threshold = float(retrieval_model.get("score_threshold") or 0) if score_threshold_enabled else None
    query_embedding = await self._embed_texts([query_text], dataset.get("embedding_model") or "")
    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    raw = collection.query(
        query_embeddings=query_embedding,
        n_results=top_k,
        include=["documents", "metadatas", "distances"],
    )

    def first_row(key: str) -> list:
        # query() returns one result list per query embedding; exactly one was sent.
        rows = raw.get(key) or []
        return (rows[0] or []) if rows else []

    ids = first_row("ids")
    documents = first_row("documents")
    metadatas = first_row("metadatas")
    distances = first_row("distances")
    records: list[RagDatasetRetrieveRecordVO] = []
    for index, segment_id in enumerate(ids):
        content = (documents[index] if index < len(documents) else "") or ""
        metadata = metadatas[index] if index < len(metadatas) and isinstance(metadatas[index], dict) else {}
        distance = float(distances[index]) if index < len(distances) and distances[index] is not None else 1.0
        score = max(0.0, min(1.0, 1.0 / (1.0 + max(0.0, distance))))
        if score_threshold is not None and score < score_threshold:
            continue
        document_name = metadata.get("document_name") or metadata.get("source") or ""
        document_id = str(metadata.get("document_id") or "")
        raw_chunk_index = metadata.get("chunk_index")
        # chunk_index 0 is valid (first chunk of a page); fall back to the result
        # rank only when the metadata is actually missing.
        chunk_index = int(raw_chunk_index) if raw_chunk_index not in (None, "") else index
        records.append(
            RagDatasetRetrieveRecordVO(
                score=round(score, 6),
                segment=RagDatasetRetrieveSegmentVO(
                    id=str(segment_id),
                    position=chunk_index + 1,
                    documentId=document_id,
                    content=content,
                    answer="",
                    wordCount=len(content),
                    hitCount=0,
                    enabled=True,
                    status="completed",
                    createdAt=0,
                    document=RagDatasetRetrieveDocumentVO(
                        id=document_id,
                        dataSourceType="upload_file",
                        name=document_name,
                        docType=None,
                    ),
                ),
            )
        )
    return RagDatasetRetrieveResponseVO(
        query={"content": query_text},
        records=records,
    )
async def GetDatasetDocumentIndexingStatus(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    DocumentId: int,
) -> dict:
    """Report a document's indexing progress.

    No per-stage timestamps are stored: each stage the current status has already
    passed is reported with the document's ``updatedAt``; later stages are ``None``.

    Raises:
        LeauditException: 404 if the document is not found.
    """
    document = await self.GetDatasetDocumentDetail(
        CurrentUserId=CurrentUserId,
        UserArea=UserArea,
        UserRole=UserRole,
        TenantCode=TenantCode,
        TenantName=TenantName,
        DatasetId=DatasetId,
        DocumentId=DocumentId,
    )
    if not document:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")
    status = document.indexingStatus
    last_update = document.updatedAt
    is_completed = status == "completed"

    def stamp_when(*reached_in: str):
        return last_update if status in reached_in else None

    completed_segments = document.chunkCount if is_completed else 0
    total_segments = document.chunkCount if document.chunkCount > 0 else 0
    progress = {
        "id": str(document.id),
        "indexing_status": status,
        "processing_started_at": document.updatedAt or document.createdAt or None,
        "parsing_completed_at": stamp_when("cleaning", "splitting", "indexing", "completed"),
        "cleaning_completed_at": stamp_when("splitting", "indexing", "completed"),
        "splitting_completed_at": stamp_when("indexing", "completed"),
        "completed_at": last_update if is_completed else None,
        "paused_at": None,
        "error": document.error or None,
        "stopped_at": None,
        "completed_segments": completed_segments,
        "total_segments": total_segments,
    }
    return {"data": [progress]}
async def UpdateDatasetDocumentByFile(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    DocumentId: int,
    FileName: str,
    ContentType: str | None,
    Content: bytes,
    ProcessConfig: dict | None,
) -> RagDatasetUploadDocumentVO:
    """Replace an existing document's file and re-index it in the background.

    Steps:
        1. Overwrite the stored object; its key is reused when present.
        2. Delete the document's old Chroma vectors.
        3. Reset the row to ``indexing`` with zero chunks.
        4. Schedule ``_run_document_indexing_task``.

    Raises:
        LeauditException: 404 if the dataset or document is not visible,
            400 if the file extension is not supported.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    async with GetAsyncSession() as session:
        current = (
            await session.execute(
                text(
                    """
                    SELECT id, minio_path
                    FROM rag_document
                    WHERE id = :document_id
                      AND dataset_id = :dataset_id
                      AND deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"document_id": DocumentId, "dataset_id": DatasetId},
            )
        ).mappings().first()
    if not current:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")
    suffix = os.path.splitext(FileName)[1].lower()
    if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON")
    content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
    # FileName is client-controlled; keep only its last path segment so a name
    # like "../x.pdf" cannot add prefixes to a newly generated object key.
    key_name = re.split(r"[\\/]", FileName)[-1]
    object_key = current.get("minio_path") or f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{key_name}"
    OssClient().EnsureBucket()
    stored_key = OssClient().UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type)
    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    raw = collection.get(where={"document_id": DocumentId}, include=[])
    stale_ids = raw.get("ids") or []
    if stale_ids:
        collection.delete(ids=stale_ids)
    async with GetAsyncSession() as session:
        await session.execute(
            text(
                """
                UPDATE rag_document
                SET original_name = :original_name,
                    minio_path = :minio_path,
                    file_type = :file_type,
                    file_size = :file_size,
                    chunk_count = 0,
                    indexing_status = 'indexing',
                    indexing_error = NULL,
                    updated_at = NOW()
                WHERE id = :document_id
                """
            ),
            {
                "document_id": DocumentId,
                "original_name": FileName,
                "minio_path": stored_key,
                "file_type": suffix.lstrip("."),
                "file_size": len(Content),
            },
        )
    asyncio.create_task(
        self._run_document_indexing_task(
            dataset=dataset,
            dataset_id=DatasetId,
            document_id=DocumentId,
            file_name=FileName,
            content=Content,
            process_config=ProcessConfig or {},
        )
    )
    return RagDatasetUploadDocumentVO(
        document={
            "id": str(DocumentId),
            "name": FileName,
            "indexing_status": "indexing",
            "word_count": 0,
            "hit_count": 0,
            "enabled": True,
        },
        batch=str(DocumentId),
    )
async def BatchUpdateDatasetDocumentStatus(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    DatasetId: int,
    DocumentIds: list[int],
    Enabled: bool,
) -> RagOperationResultVO:
    """Set ``enabled`` on several non-deleted documents of one dataset at once.

    Ids that do not belong to the dataset are simply not matched by the UPDATE.

    Raises:
        LeauditException: 404 if the dataset is not visible, 400 on an empty id list.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    document_ids = list(map(int, DocumentIds))
    if len(document_ids) == 0:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "未传入文档ID")
    bind_values = {"dataset_id": DatasetId, "document_ids": document_ids, "enabled": Enabled}
    async with GetAsyncSession() as session:
        await session.execute(
            text(
                """
                UPDATE rag_document
                SET enabled = :enabled,
                    updated_at = NOW()
                WHERE dataset_id = :dataset_id
                  AND id = ANY(:document_ids)
                  AND deleted_at IS NULL
                """
            ),
            bind_values,
        )
    return RagOperationResultVO(result="success")
def _extract_page_texts(self, *, FileName: str, Content: bytes) -> list[tuple[int, str]]:
suffix = os.path.splitext(FileName)[1].lower()
with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
temp_file.write(Content)
temp_path = temp_file.name
try:
from fastapi_modules.fastapi_leaudit.services.impl.documentServiceImpl import (
_extract_page_texts_from_docx,
_extract_page_texts_from_pdf,
)
if suffix == ".pdf":
return _extract_page_texts_from_pdf(Path(temp_path))
if suffix == ".docx":
return _extract_page_texts_from_docx(Path(temp_path))
if suffix == ".json":
return self._extract_page_texts_from_json(Content)
text_value = Content.decode("utf-8", errors="ignore").strip()
return [(1, text_value)] if text_value else []
finally:
try:
os.unlink(temp_path)
except OSError:
pass
def _extract_page_texts_from_json(self, content: bytes) -> list[tuple[int, str]]:
text_value = content.decode("utf-8", errors="ignore").strip()
if not text_value:
return []
try:
payload = json.loads(text_value)
except json.JSONDecodeError:
return [(1, text_value)]
flattened = self._json_to_text(payload)
flattened = flattened.strip()
return [(1, flattened)] if flattened else []
def _json_to_text(self, value: object, prefix: str = "") -> str:
if value is None:
return ""
if isinstance(value, str):
return f"{prefix}: {value}" if prefix else value
if isinstance(value, (int, float, bool)):
return f"{prefix}: {value}" if prefix else str(value)
if isinstance(value, list):
parts: list[str] = []
for index, item in enumerate(value, start=1):
item_prefix = f"{prefix}[{index}]" if prefix else f"item_{index}"
chunk = self._json_to_text(item, item_prefix)
if chunk:
parts.append(chunk)
return "\n".join(parts)
if isinstance(value, dict):
parts = []
for key, item in value.items():
item_prefix = f"{prefix}.{key}" if prefix else str(key)
chunk = self._json_to_text(item, item_prefix)
if chunk:
parts.append(chunk)
return "\n".join(parts)
rendered = str(value).strip()
if not rendered:
return ""
return f"{prefix}: {rendered}" if prefix else rendered
def _build_chunks(self, *, file_name: str, page_texts: list[tuple[int, str]], dataset: dict, process_config: dict, document_id: int) -> list[dict]:
    """Clean and split page texts into Chroma-ready chunk dicts.

    Options are read from ``process_config["process_rule"]["rules"]``:
        * ``segmentation`` — separator (default "\\n\\n"), max_tokens (default:
          the dataset's ``chunk_max_size``, else 800) and chunk_overlap (default
          50, clamped to at most half of max_tokens).
        * ``pre_processing_rules`` — flags ``remove_extra_spaces`` and
          ``remove_urls_emails``.

    Each chunk id is ``"{document_id}:{page}:{index}"``; ``index`` restarts on
    every page.
    """
    rules = ((process_config or {}).get("process_rule") or {}).get("rules") or {}
    segmentation = rules.get("segmentation") or {}
    # Ignore malformed (non-dict) rule entries instead of crashing on .get().
    pre_rules = [rule for rule in (rules.get("pre_processing_rules") or []) if isinstance(rule, dict)]
    remove_spaces = any(rule.get("id") == "remove_extra_spaces" and rule.get("enabled") for rule in pre_rules)
    remove_urls = any(rule.get("id") == "remove_urls_emails" and rule.get("enabled") for rule in pre_rules)
    separator = segmentation.get("separator") or "\n\n"
    max_tokens = int(segmentation.get("max_tokens") or dataset.get("chunk_max_size") or 800)
    raw_overlap = segmentation.get("chunk_overlap")
    # An explicit 0 means "no overlap"; only fall back to 50 when the value is unset.
    chunk_overlap = 50 if raw_overlap in (None, "") else int(raw_overlap)
    chunk_overlap = max(0, min(chunk_overlap, max_tokens // 2 if max_tokens > 1 else 0))
    chunks: list[dict] = []
    for page_no, raw_text in page_texts:
        text_value = self._preprocess_text(raw_text, remove_spaces=remove_spaces, remove_urls=remove_urls)
        if not text_value:
            continue
        pieces = self._split_text(text_value, separator=separator, max_chars=max_tokens, overlap=chunk_overlap)
        for index, chunk_text in enumerate(pieces):
            if not chunk_text.strip():
                continue
            chunk_id = f"{document_id}:{page_no}:{index}"
            chunks.append(
                {
                    "id": chunk_id,
                    "text": chunk_text,
                    "metadata": {
                        "id": chunk_id,
                        "source": file_name,
                        "document_name": file_name,
                        "document_id": document_id,
                        "page": page_no,
                        "chunk_index": index,
                    },
                }
            )
    return chunks
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
    """Embed ``texts`` via an OpenAI-compatible ``/embeddings`` endpoint.

    Endpoint and key come from ``RAG_CONFIG``: the ``EMBED_*`` entries, falling
    back to the ``LLM_*`` ones. The model is ``model_name``, then ``EMBED_MODEL``,
    then ``text-embedding-v4``. Requests are sent in batches of ``EMBED_BATCH_SIZE``
    (default 10). Vectors are returned in the same order as ``texts``.

    Raises:
        LeauditException: 500 when the service is not configured, cannot be
            reached, answers with an error or non-JSON body, or returns a
            different number of vectors than inputs.
    """
    embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"])
    embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
    embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
    batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
    if not embed_url or not embed_key:
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")
    embeddings: list[list[float]] = []
    async with httpx.AsyncClient(timeout=120.0) as client:
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            try:
                response = await client.post(
                    embed_url,
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {embed_key}",
                    },
                    json={"model": embed_model, "input": batch_texts},
                )
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc
            except httpx.RequestError as exc:
                # Connection errors and timeouts; str() of a timeout can be empty.
                error_message = str(exc).strip() or type(exc).__name__
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc
            try:
                payload = response.json()
            except ValueError as exc:
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    "向量化服务调用失败: 返回内容不是有效的 JSON",
                ) from exc
            rows = payload.get("data") if isinstance(payload, dict) else None
            rows = [row for row in (rows or []) if isinstance(row, dict)]
            # OpenAI-style responses tag each vector with the input "index"; honour
            # it when every row has one instead of assuming response order.
            if rows and all(isinstance(row.get("index"), int) for row in rows):
                rows.sort(key=lambda row: row["index"])
            batch_embeddings = [row.get("embedding") for row in rows if row.get("embedding")]
            if len(batch_embeddings) != len(batch_texts):
                raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
            embeddings.extend(batch_embeddings)
    if len(embeddings) != len(texts):
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
    return embeddings
async def _run_document_indexing_task(
    self,
    *,
    dataset: dict,
    dataset_id: int,
    document_id: int,
    file_name: str,
    content: bytes,
    process_config: dict,
) -> None:
    """Background job: parse, chunk, embed and store one document.

    Progress is written to ``rag_document.indexing_status`` in the order
    parsing -> cleaning -> splitting -> indexing -> completed. Nothing awaits
    this coroutine, so any failure is recorded on the row instead of being raised.
    """
    # Callers start this with a bare ``asyncio.create_task(...)`` and drop the Task.
    # The event loop keeps only weak references to tasks, so hold a strong one here
    # until the job finishes.
    owner = type(self)
    running_tasks = getattr(owner, "_running_indexing_tasks", None)
    if running_tasks is None:
        running_tasks = set()
        owner._running_indexing_tasks = running_tasks
    this_task = asyncio.current_task()
    if this_task is not None:
        running_tasks.add(this_task)
    vectors_stored = False
    try:
        await self._update_document_processing_state(document_id=document_id, status="parsing")
        page_texts = self._extract_page_texts(FileName=file_name, Content=content)
        await self._update_document_processing_state(document_id=document_id, status="cleaning")
        processed = self._build_chunks(
            file_name=file_name,
            page_texts=page_texts,
            dataset=dataset,
            process_config=process_config,
            document_id=document_id,
        )
        if not processed:
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本")
        await self._update_document_processing_state(
            document_id=document_id,
            status="splitting",
            chunk_count=len(processed),
        )
        embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "")
        await self._update_document_processing_state(
            document_id=document_id,
            status="indexing",
            chunk_count=len(processed),
        )
        collection = get_chroma().get_or_create_collection(dataset["collection_name"])
        collection.add(
            ids=[item["id"] for item in processed],
            documents=[item["text"] for item in processed],
            embeddings=embeddings,
            metadatas=[item["metadata"] for item in processed],
        )
        vectors_stored = True
        async with GetAsyncSession() as session:
            await session.execute(
                text(
                    """
                    UPDATE rag_document
                    SET chunk_count = :chunk_count,
                        indexing_status = 'completed',
                        indexing_error = NULL,
                        indexing_started_at = COALESCE(indexing_started_at, NOW()),
                        indexing_completed_at = NOW(),
                        updated_at = NOW()
                    WHERE id = :document_id
                    """
                ),
                {"chunk_count": len(processed), "document_id": document_id},
            )
        await self._sync_dataset_counts(dataset_id)
    except Exception as exc:
        await self._record_indexing_failure(
            dataset_id=dataset_id,
            document_id=document_id,
            error=exc,
            reset_chunk_count=not vectors_stored,
        )
    finally:
        if this_task is not None:
            running_tasks.discard(this_task)

async def _record_indexing_failure(
    self,
    *,
    dataset_id: int,
    document_id: int,
    error: BaseException,
    reset_chunk_count: bool,
) -> None:
    """Mark a document as ``error`` and refresh dataset totals; never raises."""
    import logging

    logger = logging.getLogger(__name__)
    logger.warning("RAG indexing failed for document %s: %s", document_id, error)
    assignments = ["indexing_status = 'error'", "indexing_error = :error", "updated_at = NOW()"]
    if reset_chunk_count:
        # No vectors were written, so the chunk_count set at the "splitting" step
        # (or the count of a replaced file) no longer matches Chroma.
        assignments.append("chunk_count = 0")
    try:
        async with GetAsyncSession() as session:
            await session.execute(
                text(
                    f"""
                    UPDATE rag_document
                    SET {", ".join(assignments)}
                    WHERE id = :document_id
                    """
                ),
                {"document_id": document_id, "error": str(error)[:2000]},
            )
        # Keep dataset totals consistent with the now-failed document.
        await self._sync_dataset_counts(dataset_id)
    except Exception:
        logger.exception("Failed to record indexing failure for RAG document %s", document_id)
async def _update_document_processing_state(self, *, document_id: int, status: str, chunk_count: int | None = None) -> None:
    """Write an intermediate indexing status for a document.

    Clears ``indexing_error``, stamps ``indexing_started_at`` the first time, bumps
    ``updated_at`` and, when given, stores ``chunk_count``.
    """
    bind_values: dict[str, object] = {"document_id": document_id, "status": status}
    assignments: tuple[str, ...] = (
        "indexing_status = :status",
        "indexing_error = NULL",
        "indexing_started_at = COALESCE(indexing_started_at, NOW())",
        "updated_at = NOW()",
    )
    if chunk_count is not None:
        assignments += ("chunk_count = :chunk_count",)
        bind_values["chunk_count"] = chunk_count
    # Only fixed column fragments are interpolated; all values are bound parameters.
    statement = text(
        f"""
        UPDATE rag_document
        SET {", ".join(assignments)}
        WHERE id = :document_id
        """
    )
    async with GetAsyncSession() as session:
        await session.execute(statement, bind_values)
async def _sync_dataset_counts(self, dataset_id: int) -> None:
    """Recompute a dataset's ``document_count`` and ``total_chunks``.

    Both totals are derived from its non-deleted ``rag_document`` rows.
    """
    recount = text(
        """
        UPDATE rag_dataset
        SET document_count = (
                SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
            ),
            total_chunks = COALESCE((
                SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
            ), 0),
            updated_at = NOW()
        WHERE id = :dataset_id
        """
    )
    async with GetAsyncSession() as session:
        await session.execute(recount, {"dataset_id": dataset_id})
def _preprocess_text(self, text_value: str, *, remove_spaces: bool, remove_urls: bool) -> str:
result = text_value or ""
if remove_urls:
result = re.sub(r"https?://\\S+|www\\.\\S+|[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}", " ", result)
if remove_spaces:
result = re.sub(r"[ \\t]+", " ", result)
result = re.sub(r"\n{3,}", "\n\n", result)
return result.strip()
def _split_text(self, text_value: str, *, separator: str, max_chars: int, overlap: int) -> list[str]:
parts = [part.strip() for part in text_value.split(separator) if part.strip()] if separator else [text_value]
chunks: list[str] = []
current = ""
for part in parts:
candidate = f"{current}{separator if current else ''}{part}" if separator else f"{current}{part}"
if len(candidate) <= max_chars:
current = candidate
continue
if current:
chunks.append(current)
if len(part) <= max_chars:
current = part
continue
start = 0
step = max(max_chars - overlap, 1)
while start < len(part):
chunks.append(part[start:start + max_chars])
start += step
current = ""
if current:
chunks.append(current)
return chunks
def _delete_oss_object(self, source: str | None) -> None:
if not source:
return
try:
oss = OssClient()
ref = oss.ResolveObjectRef(Source=source)
if ref.isDirectUrl or not ref.objectKey:
return
oss._GetMinioClient().remove_object(ref.bucket, ref.objectKey)
except Exception:
# 对象存储删除失败不阻塞业务主流程,避免历史脏数据导致文档无法删除。
return