1713 lines
71 KiB
Python
1713 lines
71 KiB
Python
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import json
|
|
import mimetypes
|
|
import os
|
|
import re
|
|
import tempfile
|
|
import uuid
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
import httpx
|
|
from sqlalchemy import text
|
|
|
|
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
|
from fastapi_common.fastapi_common_storage.oss_client import OssClient
|
|
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
|
|
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
|
|
|
|
from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO
|
|
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
|
|
RagDatasetBatchDeleteResultVO,
|
|
RagDatasetDetailVO,
|
|
RagDatasetDocumentItemVO,
|
|
RagDatasetDocumentPageVO,
|
|
RagDatasetItemVO,
|
|
RagDatasetPageVO,
|
|
RagDatasetRetrieveDocumentVO,
|
|
RagDatasetRetrieveRecordVO,
|
|
RagDatasetRetrieveResponseVO,
|
|
RagDatasetRetrieveSegmentVO,
|
|
RagDatasetSegmentItemVO,
|
|
RagDatasetSegmentPageVO,
|
|
RagDatasetUploadDocumentVO,
|
|
)
|
|
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import RagOperationResultVO
|
|
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
|
|
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
|
|
from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatasetService
|
|
|
|
|
|
class RagDatasetServiceImpl(IRagDatasetService):
|
|
# Document indexing states in which the background pipeline is presumably still
# working on a file (usage is outside this view). frozenset: these are class-level
# constants shared by every instance, so they are immutable to prevent accidental
# mutation leaking across requests.
_ACTIVE_INDEXING_STATUSES = frozenset({"waiting", "parsing", "cleaning", "splitting", "indexing"})

# Document indexing states that presumably permit deleting a document
# (usage is outside this view).
_DELETABLE_DOCUMENT_STATUSES = frozenset({"completed", "error", "paused"})

# PostgreSQL LEFT JOIN fragment attaching at most one enabled (status = 1), non-deleted
# chat app to each dataset. DISTINCT ON keeps the first row per dataset_id in ORDER BY
# order: a default app wins, then the lowest sort_order, then the lowest id.
# The including query must alias the dataset table as ``d``; the app is exposed as ``a``.
_APP_LINK_SQL = """
    LEFT JOIN (
        SELECT DISTINCT ON (dataset_id)
            dataset_id,
            id,
            name,
            is_default
        FROM rag_chat_app
        WHERE deleted_at IS NULL
          AND status = 1
        ORDER BY dataset_id, is_default DESC, sort_order ASC, id ASC
    ) a ON a.dataset_id = d.id
"""
|
|
|
|
async def GetAdminDatasets(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    Area: str | None,
    OnlyEnabled: bool | None,
    Page: int,
    PageSize: int,
) -> RagDatasetPageVO:
    """Return one page of datasets for the admin console.

    Only ``provincial_admin``, ``admin`` and ``super_admin`` may call this. A city
    ``admin`` is pinned to their own area. The other admin roles may filter by a
    comma-separated ``Area`` list.

    Args:
        CurrentUserId: id of the caller (not used by the query).
        UserArea: area configured on the caller's account.
        UserRole: role name of the caller.
        Area: optional comma-separated area filter.
        OnlyEnabled: ``True``/``False`` filters ``status`` to 1/0; ``None`` means no filter.
        Page: 1-based page number; values below 1 behave like 1.
        PageSize: number of rows per page.

    Returns:
        RagDatasetPageVO holding the page rows and the total matching row count.

    Raises:
        LeauditException: 403 when the role is not an admin role, or when a city
            admin requests an area other than their own.
    """
    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有管理知识库权限")
    scoped_area = self._resolve_managed_area(UserRole=UserRole, UserArea=UserArea)

    requested_areas = [part.strip() for part in str(Area or "").split(",") if part.strip()]
    conditions = ["d.deleted_at IS NULL"]
    filter_params: dict = {}

    if scoped_area:
        # A city admin may only look at their own area, whatever they ask for.
        for requested in requested_areas:
            if requested != scoped_area:
                raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户只能查看本地区知识库配置")
        conditions.append("d.area = :managed_area")
        filter_params["managed_area"] = scoped_area
    elif len(requested_areas) == 1:
        conditions.append("d.area = :area")
        filter_params["area"] = requested_areas[0]
    elif requested_areas:
        conditions.append("d.area = ANY(:areas)")
        filter_params["areas"] = requested_areas

    if OnlyEnabled is not None:
        conditions.append("d.status = :status")
        filter_params["status"] = 1 if OnlyEnabled else 0

    where_clause = " AND ".join(conditions)
    page_params = {
        **filter_params,
        "offset": max(Page - 1, 0) * PageSize,
        "limit": PageSize,
    }
    count_query = text(f"SELECT COUNT(1) FROM rag_dataset d WHERE {where_clause}")
    page_query = text(
        f"""
        SELECT d.id, d.name, d.description, d.area, d.is_public, d.is_default, d.document_count, d.total_chunks, d.status,
               d.sort_order, d.created_at, d.updated_at,
               a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default
        FROM rag_dataset d
        {self._APP_LINK_SQL}
        WHERE {where_clause}
        ORDER BY d.sort_order ASC, d.id ASC
        OFFSET :offset LIMIT :limit
        """
    )

    async with GetAsyncSession() as session:
        count_result = await session.execute(count_query, filter_params)
        total = count_result.scalar_one()
        page_result = await session.execute(page_query, page_params)
        rows = page_result.mappings().all()

    items = [self._to_item_vo(dict(row)) for row in rows]
    return RagDatasetPageVO(data=items, total=int(total or 0))
|
|
|
|
async def CreateAdminDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    Body: dict,
) -> RagDatasetDetailVO:
    """Create a dataset and its linked chat app from an admin request body.

    Embedding model/dimension, chunk sizes and retrieval settings are copied from
    the first live dataset ordered by ``is_default DESC, id ASC``. Hard-coded
    fallbacks are used when no dataset exists yet.

    Args:
        CurrentUserId: id recorded as ``created_by`` / ``updated_by``.
        UserArea: area configured on the caller's account.
        UserRole: role name of the caller.
        Body: request payload. Reads ``area``, ``dataset_name``/``name``,
            ``dataset_description``/``description``, ``is_public``, ``is_default``,
            ``sort_order`` and ``status``.

    Returns:
        The detail view of the new dataset, or ``None`` if it cannot be re-read.

    Raises:
        LeauditException: 403 for non-admin roles or an out-of-scope area; 400 when
            the area or the name is empty.
    """
    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有创建知识库权限")

    area = str(Body.get("area") or "").strip()
    name = str(Body.get("dataset_name") or Body.get("name") or "").strip()
    description = str(Body.get("dataset_description") or Body.get("description") or "").strip()
    if not area or not name:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "地区和知识库名称不能为空")
    self._assert_manage_area_scope(UserRole=UserRole, UserArea=UserArea, DatasetArea=area)

    collection_name = self._slugify_collection_name(area, name)
    is_default = bool(Body.get("is_default"))
    # An explicit 0 means "create disabled" and must be kept; fall back to 1 (enabled)
    # only when no status was sent. (``int(status or 1)`` silently turned 0 into 1.)
    raw_status = Body.get("status")
    status = 1 if raw_status is None or raw_status == "" else int(raw_status)

    async with GetAsyncSession() as session:
        base = (
            await session.execute(
                text(
                    """
                    SELECT embedding_model, embedding_dim, chunk_max_size, chunk_min_size, retrieval_model
                    FROM rag_dataset
                    WHERE deleted_at IS NULL
                    ORDER BY is_default DESC, id ASC
                    LIMIT 1
                    """
                )
            )
        ).mappings().first()
        retrieval_model = (base.get("retrieval_model") if base else None) or {}

        if is_default:
            # Only one dataset / chat app may carry the default flag at a time.
            await self._clear_default_flags(session)

        row = (
            await session.execute(
                text(
                    """
                    INSERT INTO rag_dataset (
                        name, description, area, is_public, is_default, collection_name,
                        embedding_model, embedding_dim, chunk_max_size, chunk_min_size,
                        retrieval_model, sort_order, status, created_by, updated_by
                    ) VALUES (
                        :name, :description, :area, :is_public, :is_default, :collection_name,
                        :embedding_model, :embedding_dim, :chunk_max_size, :chunk_min_size,
                        CAST(:retrieval_model AS jsonb), :sort_order, :status, :created_by, :updated_by
                    )
                    RETURNING id
                    """
                ),
                {
                    "name": name,
                    "description": description,
                    "area": area,
                    "is_public": bool(Body.get("is_public")),
                    "is_default": is_default,
                    "collection_name": collection_name,
                    "embedding_model": (base.get("embedding_model") if base else "text-embedding-v4"),
                    "embedding_dim": (base.get("embedding_dim") if base else 1024),
                    "chunk_max_size": (base.get("chunk_max_size") if base else 800),
                    "chunk_min_size": (base.get("chunk_min_size") if base else 20),
                    "retrieval_model": json.dumps(retrieval_model, ensure_ascii=False),
                    "sort_order": int(Body.get("sort_order") or 0),
                    "status": status,
                    "created_by": CurrentUserId,
                    "updated_by": CurrentUserId,
                },
            )
        ).mappings().first()
        dataset_id = int(row["id"])
        await self._ensure_linked_app(
            session=session,
            dataset_id=dataset_id,
            dataset_name=name,
            dataset_area=area,
            current_user_id=CurrentUserId,
            is_default=is_default,
        )

    # Re-read through a fresh session once the block above has closed. This assumes
    # GetAsyncSession commits on exit, so the new row and its linked app are visible.
    refreshed = await self._get_dataset_row(dataset_id)
    return self._to_detail_vo(refreshed) if refreshed else None
|
|
|
|
async def UpdateAdminDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    Body: dict,
) -> RagDatasetDetailVO | None:
    """Update an existing dataset's basic settings and keep its linked chat app in sync.

    Fields missing from ``Body`` keep their current value.

    Args:
        CurrentUserId: id recorded as ``updated_by``.
        UserArea: area configured on the caller's account.
        UserRole: role name of the caller.
        DatasetId: dataset to update.
        Body: request payload. Reads ``area``, ``dataset_name``/``name``,
            ``dataset_description``/``description``, ``is_public``, ``is_default``,
            ``sort_order`` and ``status``.

    Returns:
        The refreshed detail view, or ``None`` when the dataset does not exist.

    Raises:
        LeauditException: 403 for non-admin roles or out-of-scope areas (current or
            target); 400 when trying to unset the default flag of the default dataset.
    """
    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有更新知识库权限")
    existing = await self._get_dataset_row(DatasetId)
    if not existing:
        return None
    # The caller must be allowed to manage both the current area and the target area.
    self._assert_manage_area_scope(UserRole=UserRole, UserArea=UserArea, DatasetArea=str(existing.get("area") or ""))
    area = str(Body.get("area") or existing.get("area") or "").strip()
    self._assert_manage_area_scope(UserRole=UserRole, UserArea=UserArea, DatasetArea=area)

    target_is_default = bool(Body.get("is_default", existing.get("is_default")))
    if not target_is_default and existing.get("is_default") and Body.get("is_default") is False:
        # Another dataset has to become the default first; never leave none.
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不能直接取消,请先将其他知识库设为默认")

    name = str(Body.get("dataset_name") or Body.get("name") or existing.get("name") or "").strip()
    description = str(Body.get("dataset_description") or Body.get("description") or existing.get("description") or "").strip()
    sort_order = int(Body.get("sort_order") if Body.get("sort_order") is not None else (existing.get("sort_order") or 0))
    if Body.get("status") is not None:
        status = int(Body.get("status"))
    elif existing.get("status") is not None:
        # Keep a disabled (0) dataset disabled; ``existing status or 1`` used to re-enable it.
        status = int(existing.get("status"))
    else:
        status = 1

    async with GetAsyncSession() as session:
        if target_is_default:
            await self._clear_default_flags(session)
        await session.execute(
            text(
                """
                UPDATE rag_dataset
                SET name = :name,
                    description = :description,
                    area = :area,
                    is_public = :is_public,
                    is_default = :is_default,
                    sort_order = :sort_order,
                    status = :status,
                    updated_by = :updated_by,
                    updated_at = NOW()
                WHERE id = :dataset_id
                """
            ),
            {
                "dataset_id": DatasetId,
                "name": name,
                "description": description,
                "area": area,
                "is_public": bool(Body.get("is_public", existing.get("is_public"))),
                "is_default": target_is_default,
                "sort_order": sort_order,
                "status": status,
                "updated_by": CurrentUserId,
            },
        )
        await self._ensure_linked_app(
            session=session,
            dataset_id=DatasetId,
            dataset_name=name,
            dataset_area=area,
            current_user_id=CurrentUserId,
            is_default=target_is_default,
        )

    # Re-read after the session block has closed (assumes GetAsyncSession commits on exit).
    refreshed = await self._get_dataset_row(DatasetId)
    return self._to_detail_vo(refreshed) if refreshed else None
|
|
|
|
async def DeleteAdminDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
) -> RagOperationResultVO:
    """Soft-delete a dataset together with every live chat app linked to it.

    Both tables get ``deleted_at``/``updated_at`` set to NOW() and ``updated_by``
    set to the caller.

    Raises:
        LeauditException: 403 for non-admin roles or an out-of-scope area; 404 when
            the dataset does not exist; 400 when it is the default dataset.
    """
    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库权限")
    target = await self._get_dataset_row(DatasetId)
    if not target:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    self._assert_manage_area_scope(UserRole=UserRole, UserArea=UserArea, DatasetArea=str(target.get("area") or ""))
    if bool(target.get("is_default")):
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不允许删除,请先切换默认知识库")

    audit_params = {"dataset_id": DatasetId, "updated_by": CurrentUserId}
    soft_delete_statements = (
        "UPDATE rag_dataset SET deleted_at = NOW(), updated_by = :updated_by, updated_at = NOW() WHERE id = :dataset_id",
        "UPDATE rag_chat_app SET deleted_at = NOW(), updated_by = :updated_by, updated_at = NOW() WHERE dataset_id = :dataset_id AND deleted_at IS NULL",
    )
    async with GetAsyncSession() as session:
        for statement in soft_delete_statements:
            await session.execute(text(statement), audit_params)
    return RagOperationResultVO(result="success")
|
|
|
|
async def GetMyDatasets(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagDatasetPageVO:
    """List every enabled, non-deleted dataset the caller may use.

    ``provincial_admin`` sees all of them. Other callers see datasets of their own
    area, province-level ("省级") or area-less datasets, and public ones. The rule
    matches ``_get_visible_dataset``; a NULL area counts as empty, so such datasets
    are listed as well as being openable.

    Returns:
        RagDatasetPageVO with all rows (no pagination); ``total`` is the row count.
    """
    async with GetAsyncSession() as session:
        rows = (
            await session.execute(
                text(
                    f"""
                    SELECT d.id, d.name, d.description, d.area, d.is_public, d.is_default, d.document_count, d.total_chunks, d.status,
                           d.sort_order, d.created_at, d.updated_at,
                           a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default
                    FROM rag_dataset d
                    {self._APP_LINK_SQL}
                    WHERE d.deleted_at IS NULL
                      AND d.status = 1
                      AND (
                          :is_provincial = TRUE
                          OR COALESCE(d.area, '') IN (:user_area, '省级', '')
                          OR d.is_public = TRUE
                      )
                    ORDER BY d.sort_order ASC, d.created_at DESC
                    """
                ),
                {
                    "is_provincial": UserRole == "provincial_admin",
                    "user_area": UserArea or "",
                },
            )
        ).mappings().all()

    # _to_item_vo already returns a RagDatasetItemVO; no need to dump and rebuild it.
    items = [self._to_item_vo(dict(row)) for row in rows]
    return RagDatasetPageVO(data=items, total=len(items))
|
|
|
|
async def GetDatasetDetail(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, DatasetId: int) -> RagDatasetDetailVO | None:
    """Return the detail view of ``DatasetId``, or ``None`` when the caller cannot see it.

    Visibility follows ``_get_visible_dataset``; ``CurrentUserId`` is not used.
    """
    visible_row = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    return self._to_detail_vo(visible_row) if visible_row else None
|
|
|
|
async def UpdateDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    Body: RagDatasetUpdateDTO,
) -> RagDatasetDetailVO | None:
    """Update the name and/or retrieval settings of a visible dataset.

    Only fields set on ``Body`` (``name``, ``retrieval_model``) are written.

    Returns:
        The refreshed detail view, or ``None`` when the dataset is missing or not
        visible to the caller.

    Raises:
        LeauditException: 403 for non-admin roles, or when a city admin targets a
            dataset outside their own area.
    """
    row = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not row:
        return None

    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有修改知识库配置权限")
    # Visibility is not edit permission: public and province-level datasets are visible
    # to every area, so apply the same area scope as UpdateAdminDataset.
    self._assert_manage_area_scope(UserRole=UserRole, UserArea=UserArea, DatasetArea=str(row.get("area") or ""))

    update_fields: list[str] = []
    params: dict = {"dataset_id": DatasetId, "updated_by": CurrentUserId}

    if Body.name is not None:
        update_fields.append("name = :name")
        params["name"] = Body.name.strip()

    if Body.retrieval_model is not None:
        update_fields.append("retrieval_model = CAST(:retrieval_model AS jsonb)")
        params["retrieval_model"] = json.dumps(Body.retrieval_model, ensure_ascii=False)

    if not update_fields:
        # Nothing to change: return the current state without touching the row.
        return self._to_detail_vo(row)

    update_fields.append("updated_by = :updated_by")
    update_fields.append("updated_at = NOW()")

    async with GetAsyncSession() as session:
        await session.execute(
            text(
                f"""
                UPDATE rag_dataset
                SET {", ".join(update_fields)}
                WHERE id = :dataset_id
                """
            ),
            params,
        )

    refreshed = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    return self._to_detail_vo(refreshed) if refreshed else None
|
|
|
|
async def GetDatasetDocuments(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    Page: int,
    Limit: int,
    Keyword: str | None,
) -> RagDatasetDocumentPageVO:
    """List one page of non-deleted documents of a dataset visible to the caller.

    Args:
        CurrentUserId: id of the caller (not used by the query).
        UserArea: area configured on the caller's account.
        UserRole: role name of the caller.
        DatasetId: dataset whose documents are listed.
        Page: 1-based page number; values below 1 behave like 1.
        Limit: page size.
        Keyword: optional substring matched case-insensitively against
            ``original_name`` and ``filename``. Percent, underscore and backslash
            characters in it are matched literally, not as LIKE wildcards.

    Returns:
        RagDatasetDocumentPageVO ordered by ``created_at DESC``.

    Raises:
        LeauditException: 404 when the dataset is missing or not visible.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")

    where_sql = [
        "dataset_id = :dataset_id",
        "deleted_at IS NULL",
    ]
    params: dict = {
        "dataset_id": DatasetId,
        "offset": max(Page - 1, 0) * Limit,
        # One extra row is fetched so hasMore can be derived from the page query itself.
        "limit": Limit + 1,
    }
    keyword = (Keyword or "").strip()
    if keyword:
        # Escape LIKE metacharacters so user input such as "%" or "_" cannot act as a
        # wildcard (a bare "%" would otherwise match every document).
        escaped = keyword.replace("\\", "\\\\").replace("%", "\\%").replace("_", "\\_")
        where_sql.append("(original_name ILIKE :keyword ESCAPE '\\' OR filename ILIKE :keyword ESCAPE '\\')")
        params["keyword"] = f"%{escaped}%"
    where_clause = " AND ".join(where_sql)

    async with GetAsyncSession() as session:
        total = (
            await session.execute(
                text(
                    f"""
                    SELECT COUNT(1)
                    FROM rag_document
                    WHERE {where_clause}
                    """
                ),
                {key: value for key, value in params.items() if key not in ("offset", "limit")},
            )
        ).scalar_one()
        rows = (
            await session.execute(
                text(
                    f"""
                    SELECT id, dataset_id, original_name, file_type, file_size, chunk_count,
                           indexing_status, COALESCE(indexing_error, '') AS indexing_error,
                           enabled, hit_count, created_by, created_at, updated_at
                    FROM rag_document
                    WHERE {where_clause}
                    ORDER BY created_at DESC
                    OFFSET :offset LIMIT :limit
                    """
                ),
                params,
            )
        ).mappings().all()

    has_more = len(rows) > Limit
    items = rows[:Limit]
    return RagDatasetDocumentPageVO(
        data=[
            RagDatasetDocumentItemVO(
                id=row["id"],
                datasetId=row["dataset_id"],
                name=row.get("original_name") or "",
                fileType=row.get("file_type") or "",
                fileSize=row.get("file_size") or 0,
                chunkCount=row.get("chunk_count") or 0,
                indexingStatus=row.get("indexing_status") or "waiting",
                error=row.get("indexing_error") or "",
                enabled=bool(row.get("enabled")),
                hitCount=row.get("hit_count") or 0,
                createdBy=row.get("created_by"),
                createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
                updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0,
            )
            for row in items
        ],
        total=int(total or 0),
        page=Page,
        limit=Limit,
        hasMore=has_more,
    )
|
|
|
|
async def GetDatasetDocumentDetail(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
) -> RagDatasetDocumentItemVO | None:
    """Return one non-deleted document of a visible dataset, or ``None``.

    ``None`` is returned both when the dataset is not visible to the caller and
    when the document does not exist in that dataset.
    """
    visible = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not visible:
        return None

    lookup = text(
        """
        SELECT id, dataset_id, original_name, file_type, file_size, chunk_count,
               indexing_status, COALESCE(indexing_error, '') AS indexing_error,
               enabled, hit_count, created_by, created_at, updated_at
        FROM rag_document
        WHERE id = :document_id
          AND dataset_id = :dataset_id
          AND deleted_at IS NULL
        LIMIT 1
        """
    )
    async with GetAsyncSession() as session:
        result = await session.execute(lookup, {"document_id": DocumentId, "dataset_id": DatasetId})
        document = result.mappings().first()

    if not document:
        return None

    created = document.get("created_at")
    updated = document.get("updated_at")
    return RagDatasetDocumentItemVO(
        id=document["id"],
        datasetId=document["dataset_id"],
        name=document.get("original_name") or "",
        fileType=document.get("file_type") or "",
        fileSize=document.get("file_size") or 0,
        chunkCount=document.get("chunk_count") or 0,
        indexingStatus=document.get("indexing_status") or "waiting",
        error=document.get("indexing_error") or "",
        enabled=bool(document.get("enabled")),
        hitCount=document.get("hit_count") or 0,
        createdBy=document.get("created_by"),
        createdAt=int(document["created_at"].timestamp()) if created else 0,
        updatedAt=int(document["updated_at"].timestamp()) if updated else 0,
    )
|
|
|
|
async def _get_visible_dataset(self, user_area: str | None, user_role: str | None, dataset_id: int) -> dict | None:
    """Load an enabled, non-deleted dataset (with its linked app) if the caller may see it.

    ``provincial_admin`` sees every dataset. Everyone else sees datasets whose area
    is empty/NULL, "省级" (province level) or their own area, plus any public dataset.

    Returns:
        The row as a plain dict, or ``None`` when missing, disabled, deleted or hidden.
    """
    query = text(
        f"""
        SELECT d.id, d.name, d.description, d.area, d.is_public, d.is_default, d.status,
               d.document_count, d.total_chunks, d.chunk_max_size, d.chunk_min_size, d.sort_order,
               d.retrieval_model, d.collection_name, d.embedding_model, d.created_at, d.updated_at,
               a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default
        FROM rag_dataset d
        {self._APP_LINK_SQL}
        WHERE d.id = :dataset_id
          AND d.deleted_at IS NULL
          AND d.status = 1
        LIMIT 1
        """
    )
    async with GetAsyncSession() as session:
        result = await session.execute(query, {"dataset_id": dataset_id})
        found = result.mappings().first()

    if not found:
        return None
    record = dict(found)
    if user_role == "provincial_admin":
        return record
    dataset_area = record.get("area") or ""
    shared_areas = ("", "省级", user_area or "")
    return record if dataset_area in shared_areas or bool(record.get("is_public")) else None
|
|
|
|
def _to_detail_vo(self, row: dict) -> RagDatasetDetailVO:
    """Map a dataset row (joined with its linked app columns) to ``RagDatasetDetailVO``.

    Missing numeric columns fall back to defaults (chunk sizes 800/20, counts 0).
    Timestamps are converted to integer epoch seconds, or 0 when absent.
    """
    created_at = row.get("created_at")
    updated_at = row.get("updated_at")
    raw_status = row.get("status")
    return RagDatasetDetailVO(
        id=row["id"],
        name=row["name"],
        description=row.get("description") or "",
        area=row.get("area") or "",
        isPublic=bool(row.get("is_public")),
        isDefault=bool(row.get("is_default")),
        # Only a missing status defaults to 1; 0 (disabled) is a real value that
        # ``status or 1`` used to report as enabled.
        status=1 if raw_status is None else raw_status,
        documentCount=row.get("document_count") or 0,
        sortOrder=row.get("sort_order") or 0,
        totalChunks=row.get("total_chunks") or 0,
        chunkMaxSize=row.get("chunk_max_size") or 800,
        chunkMinSize=row.get("chunk_min_size") or 20,
        retrievalModel=row.get("retrieval_model") or {},
        createdAt=int(created_at.timestamp()) if created_at else 0,
        updatedAt=int(updated_at.timestamp()) if updated_at else 0,
        appId=row.get("app_id"),
        appName=row.get("app_name") or "",
        appIsDefault=bool(row.get("app_is_default")),
    )
|
|
|
|
def _to_item_vo(self, row: dict) -> RagDatasetItemVO:
    """Map a dataset list row (joined with its linked app columns) to ``RagDatasetItemVO``.

    Timestamps are converted to integer epoch seconds, or 0 when absent.
    """
    created_at = row.get("created_at")
    updated_at = row.get("updated_at")
    raw_status = row.get("status")
    return RagDatasetItemVO(
        id=row["id"],
        name=row.get("name") or "",
        description=row.get("description") or "",
        area=row.get("area") or "",
        isPublic=bool(row.get("is_public")),
        isDefault=bool(row.get("is_default")),
        documentCount=row.get("document_count") or 0,
        totalChunks=row.get("total_chunks") or 0,
        # Only a missing status defaults to 1; disabled datasets (status 0) must be
        # shown as disabled in the admin list, which ``status or 1`` prevented.
        status=1 if raw_status is None else raw_status,
        sortOrder=row.get("sort_order") or 0,
        createdAt=int(created_at.timestamp()) if created_at else 0,
        updatedAt=int(updated_at.timestamp()) if updated_at else 0,
        appId=row.get("app_id"),
        appName=row.get("app_name") or "",
        appIsDefault=bool(row.get("app_is_default")),
    )
|
|
|
|
async def _get_dataset_row(self, dataset_id: int) -> dict | None:
    """Load a non-deleted dataset (enabled or not) with its linked app, without visibility checks.

    Returns:
        The row as a plain dict, or ``None`` when it does not exist or is soft-deleted.
    """
    query = text(
        f"""
        SELECT d.id, d.name, d.description, d.area, d.is_public, d.is_default, d.status,
               d.document_count, d.total_chunks, d.chunk_max_size, d.chunk_min_size, d.sort_order,
               d.retrieval_model, d.collection_name, d.embedding_model, d.created_at, d.updated_at,
               a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default
        FROM rag_dataset d
        {self._APP_LINK_SQL}
        WHERE d.id = :dataset_id AND d.deleted_at IS NULL
        LIMIT 1
        """
    )
    async with GetAsyncSession() as session:
        result = await session.execute(query, {"dataset_id": dataset_id})
        found = result.mappings().first()
    if not found:
        return None
    return dict(found)
|
|
|
|
async def _clear_default_flags(self, session) -> None:
    """Unset ``is_default`` on every non-deleted dataset, then on every non-deleted chat app.

    Runs on the caller's ``session`` so it is part of the caller's unit of work.
    """
    for table_name in ("rag_dataset", "rag_chat_app"):
        await session.execute(text(f"UPDATE {table_name} SET is_default = FALSE WHERE deleted_at IS NULL"))
|
|
|
|
async def _ensure_linked_app(
    self,
    session,
    dataset_id: int,
    dataset_name: str,
    dataset_area: str,
    current_user_id: int,
    is_default: bool,
) -> None:
    """Make sure the dataset has an enabled chat app whose name, area and default flag match it.

    If a non-deleted app already points at the dataset, the first one (default first,
    then lowest sort_order, then lowest id) is updated and re-enabled. Otherwise a new
    app is inserted. Runs on the caller's ``session``.

    Args:
        session: open async session of the caller.
        dataset_id: dataset the app belongs to.
        dataset_name: used to derive the app name via ``_build_app_name``.
        dataset_area: area copied onto the app.
        current_user_id: id recorded as ``created_by`` / ``updated_by``.
        is_default: value for the app's ``is_default`` flag.
    """
    app_row = (
        await session.execute(
            text(
                """
                SELECT id
                FROM rag_chat_app
                WHERE dataset_id = :dataset_id
                  AND deleted_at IS NULL
                ORDER BY is_default DESC, sort_order ASC, id ASC
                LIMIT 1
                """
            ),
            {"dataset_id": dataset_id},
        )
    ).mappings().first()

    app_name = self._build_app_name(dataset_area=dataset_area, dataset_name=dataset_name)
    if app_row:
        # Keep ``area`` in sync too: the dataset's area may have just been changed by
        # UpdateAdminDataset, and the INSERT branch below also sets it.
        await session.execute(
            text(
                """
                UPDATE rag_chat_app
                SET name = :name,
                    area = :area,
                    is_default = :is_default,
                    status = 1,
                    updated_by = :updated_by,
                    updated_at = NOW()
                WHERE id = :app_id
                """
            ),
            {
                "app_id": int(app_row["id"]),
                "name": app_name,
                "area": dataset_area or "",
                "is_default": is_default,
                "updated_by": current_user_id,
            },
        )
        return

    await session.execute(
        text(
            """
            INSERT INTO rag_chat_app (
                name, description, area, dataset_id, suggested_questions,
                opening_statement, sort_order, status, is_default, created_by, updated_by
            ) VALUES (
                :name, :description, :area, :dataset_id, CAST(:suggested_questions AS jsonb),
                :opening_statement, 0, 1, :is_default, :created_by, :updated_by
            )
            """
        ),
        {
            "name": app_name,
            "description": f"{dataset_area or '默认地区'}知识库问答助手",
            "area": dataset_area or "",
            "dataset_id": dataset_id,
            "suggested_questions": json.dumps([], ensure_ascii=False),
            "opening_statement": f"您好,我是{app_name}。",
            "is_default": is_default,
            "created_by": current_user_id,
            "updated_by": current_user_id,
        },
    )
|
|
|
|
def _build_app_name(self, dataset_area: str, dataset_name: str) -> str:
    """Derive the display name of the chat app linked to a dataset.

    A trailing "知识库" ("knowledge base") is dropped once from the stripped dataset
    name. If the remainder already ends in "助手" ("assistant") it is returned as-is;
    otherwise "助手" is appended. When nothing is left, the name falls back to
    "<area>法务助手", using "默认地区" ("default area") for an empty area.
    """
    kb_suffix = "知识库"
    base_name = (dataset_name or "").strip()
    if base_name.endswith(kb_suffix):
        base_name = base_name[: len(base_name) - len(kb_suffix)]
    if not base_name:
        return f"{dataset_area or '默认地区'}法务助手"
    return base_name if base_name.endswith("助手") else f"{base_name}助手"
|
|
|
|
def _slugify_collection_name(self, area: str, name: str, max_length: int = 63) -> str:
    """Build a unique vector-store collection name for a new dataset.

    The ASCII letters/digits of ``"{area}_{name}"`` are kept for readability. The text
    is lower-cased and every other run of characters is collapsed to ``_``. A random
    8-hex suffix is always appended. Without it, names whose non-ASCII parts are
    stripped (e.g. two Chinese areas that both use the name "法规2024") collapse to the
    same slug, so two datasets would share one collection. A re-created dataset would
    also inherit the vectors of a soft-deleted one with the same name.

    Args:
        area: dataset area.
        name: dataset name.
        max_length: upper bound on the result length when a readable part exists.
            Defaults to 63, a conservative bound (older Chroma releases limit
            collection names to 3-63 characters).

    Returns:
        A name built only from ``[a-z0-9_]`` that starts with ``legal_kb_``, ends
        with an alphanumeric character and contains no ``__``.
    """
    prefix = "legal_kb_"
    normalized = re.sub(r"[^a-z0-9]+", "_", f"{area}_{name}".lower()).strip("_")
    if not normalized:
        # Nothing ASCII to keep: a purely random name, as before.
        return f"{prefix}{uuid.uuid4().hex[:12]}"

    suffix = uuid.uuid4().hex[:8]
    # Room for the readable part after the fixed prefix and the "_<suffix>" tail.
    budget = max(max_length - len(prefix) - len(suffix) - 1, 0)
    # Truncation may stop right after an underscore; strip it so the name stays clean.
    readable = normalized[:budget].rstrip("_")
    if not readable:
        return f"{prefix}{suffix}"
    return f"{prefix}{readable}_{suffix}"
|
|
|
|
def _resolve_managed_area(self, UserRole: str | None, UserArea: str | None) -> str | None:
    """Return the single area a city ``admin`` is restricted to, or ``None`` for other roles.

    Only the ``admin`` role is area-scoped by this helper; every other role gets ``None``.

    Raises:
        LeauditException: 403 when an ``admin`` has no (non-blank) area configured.
    """
    if UserRole != "admin":
        return None
    own_area = str(UserArea or "").strip()
    if own_area:
        return own_area
    raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前市级管理员未配置地区,无法管理知识库")
|
|
|
|
def _assert_manage_area_scope(self, UserRole: str | None, UserArea: str | None, DatasetArea: str) -> None:
    """Ensure the caller may manage a dataset located in ``DatasetArea``.

    ``provincial_admin`` and ``super_admin`` may manage any area. A city ``admin``
    may manage only the area returned by ``_resolve_managed_area``. Every other
    role is rejected.

    Raises:
        LeauditException: 403 when the role cannot manage datasets or the area is out of scope.
    """
    if UserRole in ("provincial_admin", "super_admin"):
        return
    if UserRole == "admin":
        if DatasetArea == self._resolve_managed_area(UserRole=UserRole, UserArea=UserArea):
            return
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户只能管理本地区知识库")
    raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有管理知识库权限")
|
|
|
|
async def UploadDatasetDocument(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    FileName: str,
    ContentType: str | None,
    Content: bytes,
    ProcessConfig: dict | None,
) -> RagDatasetUploadDocumentVO:
    """Store an uploaded file, register it as a document and start indexing in the background.

    Args:
        CurrentUserId: id recorded as ``created_by``.
        UserArea: area configured on the caller's account.
        UserRole: role name of the caller.
        DatasetId: target dataset (must be visible to the caller).
        FileName: client-supplied file name; its extension selects the file type.
        ContentType: optional MIME type; guessed from ``FileName`` when absent.
        Content: raw file bytes.
        ProcessConfig: optional options forwarded to the indexing task.

    Returns:
        RagDatasetUploadDocumentVO describing the new document in ``indexing`` state.

    Raises:
        LeauditException: 404 when the dataset is not visible; 400 for unsupported extensions.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    # NOTE(review): unlike the segment mutators there is no role / area-scope check here,
    # so anyone who can see the dataset may upload to it. Confirm this is intended.

    suffix = os.path.splitext(FileName)[1].lower()
    if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON")

    # Use only the base name in the object key: a client name containing "/" or "\"
    # would otherwise inject extra key prefixes. The original name is still stored below.
    safe_name = os.path.basename(FileName.replace("\\", "/")) or f"upload{suffix}"
    object_key = f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{safe_name}"
    content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"

    def _store_object() -> str:
        # OssClient's methods are called synchronously (their results are used without
        # awaiting), so they run in a worker thread instead of blocking the event loop.
        client = OssClient()
        client.EnsureBucket()
        return client.UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type)

    stored_key = await asyncio.get_running_loop().run_in_executor(None, _store_object)

    async with GetAsyncSession() as session:
        inserted = (
            await session.execute(
                text(
                    """
                    INSERT INTO rag_document (
                        dataset_id, filename, original_name, minio_path, file_type, file_size,
                        chunk_count, indexing_status, enabled, hit_count, created_by
                    ) VALUES (
                        :dataset_id, :filename, :original_name, :minio_path, :file_type, :file_size,
                        0, 'indexing', TRUE, 0, :created_by
                    )
                    RETURNING id, created_at, updated_at
                    """
                ),
                {
                    "dataset_id": DatasetId,
                    "filename": uuid.uuid4().hex + suffix,
                    "original_name": FileName,
                    "minio_path": stored_key,
                    "file_type": suffix.lstrip("."),
                    "file_size": len(Content),
                    "created_by": CurrentUserId,
                },
            )
        ).mappings().first()
        document_id = int(inserted["id"])

    # Start indexing only after the session block has closed, so the task can see the
    # new row (assumes GetAsyncSession commits on exit).
    task = asyncio.create_task(
        self._run_document_indexing_task(
            dataset=dataset,
            dataset_id=DatasetId,
            document_id=document_id,
            file_name=FileName,
            content=Content,
            process_config=ProcessConfig or {},
        )
    )
    # The event loop keeps only weak references to tasks. Hold a strong one until the
    # task finishes so the indexing job cannot be garbage-collected mid-flight.
    running_tasks = getattr(type(self), "_indexing_tasks", None)
    if running_tasks is None:
        running_tasks = set()
        type(self)._indexing_tasks = running_tasks
    running_tasks.add(task)
    task.add_done_callback(running_tasks.discard)

    return RagDatasetUploadDocumentVO(
        document={
            "id": str(document_id),
            "name": FileName,
            "indexing_status": "indexing",
            "word_count": 0,
            "hit_count": 0,
            "enabled": True,
        },
        batch=str(document_id),
    )
|
|
|
|
async def GetDatasetDocumentSegments(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
    Page: int,
    Limit: int,
    Keyword: str | None,
) -> RagDatasetSegmentPageVO:
    """List one page of a document's vector-store segments, ordered by chunk position.

    Segments are read from the dataset's Chroma collection (filtered on the
    ``document_id`` metadata) and paginated in memory. ``position`` is the 1-based
    ``chunk_index`` from the metadata, the same numbering as
    ``GetDatasetDocumentSegmentDetail``. It falls back to the fetch order when the
    metadata has no usable chunk_index.

    Args:
        Keyword: optional case-sensitive substring a segment's content must contain.
        Page: 1-based page number; values below 1 behave like 1.
        Limit: page size.

    Raises:
        LeauditException: 404 when the dataset is missing or not visible.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")

    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    raw = collection.get(where={"document_id": DocumentId}, include=["documents", "metadatas"])
    ids = raw.get("ids") or []
    docs = raw.get("documents") or []
    metas = raw.get("metadatas") or []
    needle = (Keyword or "").strip()

    # Entries are (position, fetch_index, segment_id, content).
    segments: list[tuple[int, int, str, str]] = []
    for index, segment_id in enumerate(ids):
        # Normalise a missing/None document to "" so the keyword test cannot fail on None.
        content = (docs[index] if index < len(docs) else "") or ""
        if needle and needle not in content:
            continue
        meta = metas[index] if index < len(metas) and isinstance(metas[index], dict) else {}
        try:
            position = int(meta.get("chunk_index")) + 1
        except (TypeError, ValueError):
            position = index + 1
        segments.append((position, index, str(segment_id), content))
    # Do not rely on the store's return order; present segments in chunk order.
    segments.sort(key=lambda entry: (entry[0], entry[1]))

    offset = max(Page - 1, 0) * Limit
    page_items = segments[offset: offset + Limit]
    return RagDatasetSegmentPageVO(
        data=[
            RagDatasetSegmentItemVO(
                id=segment_id,
                position=position,
                documentId=str(DocumentId),
                content=content,
                wordCount=len(content),
                hitCount=0,
                enabled=True,
                status="completed",
                createdAt=0,
            )
            for position, _, segment_id, content in page_items
        ],
        total=len(segments),
        limit=Limit,
        hasMore=offset + Limit < len(segments),
    )
|
|
|
|
async def GetDatasetDocumentSegmentDetail(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
    SegmentId: str,
) -> RagDatasetSegmentItemVO | None:
    """Return one vector-store segment of a document, or ``None``.

    ``None`` is returned when the dataset is not visible, the segment id is unknown,
    or the segment's ``document_id`` metadata points at a different document.
    ``position`` is the metadata ``chunk_index`` + 1 (1 when absent).
    """
    visible = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not visible:
        return None

    store = get_chroma().get_or_create_collection(visible["collection_name"])
    fetched = store.get(ids=[SegmentId], include=["documents", "metadatas"])
    if not (fetched.get("ids") or []):
        return None

    segment_text = (fetched.get("documents") or [""])[0] or ""
    segment_meta = (fetched.get("metadatas") or [{}])[0] or {}
    # Segment ids are looked up across the whole collection, so check the owning document.
    if str(segment_meta.get("document_id") or "") != str(DocumentId):
        return None

    chunk_number = int(segment_meta.get("chunk_index") or 0) + 1
    return RagDatasetSegmentItemVO(
        id=str(SegmentId),
        position=chunk_number,
        documentId=str(DocumentId),
        content=segment_text,
        wordCount=len(segment_text),
        hitCount=0,
        enabled=True,
        status="completed",
        createdAt=0,
    )
|
|
|
|
async def UpdateDatasetDocumentSegment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
    SegmentId: str,
    Body: dict,
) -> RagDatasetSegmentItemVO | None:
    """Replace a segment's text, re-embed it and write it back to the vector store.

    ``Body`` may carry the fields directly or nested under ``"segment"``. Only
    ``content`` is read; an empty value keeps the current text. Existing metadata is
    preserved.

    Returns:
        The updated segment, or ``None`` when the dataset or segment is not found.

    Raises:
        LeauditException: 403 for non-admin roles or a dataset outside a city admin's
            area; 400 when the resulting content is empty.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        return None

    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有修改知识库分段权限")
    # Same area scope as the admin dataset endpoints: a city admin can *see* public and
    # province-level datasets but must not edit their content.
    self._assert_manage_area_scope(UserRole=UserRole, UserArea=UserArea, DatasetArea=str(dataset.get("area") or ""))

    current = await self.GetDatasetDocumentSegmentDetail(
        CurrentUserId=CurrentUserId,
        UserArea=UserArea,
        UserRole=UserRole,
        DatasetId=DatasetId,
        DocumentId=DocumentId,
        SegmentId=SegmentId,
    )
    if not current:
        return None

    segment_body = Body.get("segment") if isinstance(Body.get("segment"), dict) else Body
    content = str(segment_body.get("content") or current.content).strip()
    if not content:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "分段内容不能为空")

    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    raw = collection.get(ids=[SegmentId], include=["metadatas"])
    metadata = (raw.get("metadatas") or [{}])[0] or {}
    embeddings = await self._embed_texts([content], dataset.get("embedding_model") or "")
    collection.update(
        ids=[SegmentId],
        documents=[content],
        embeddings=embeddings,
        metadatas=[metadata],
    )

    return RagDatasetSegmentItemVO(
        id=str(SegmentId),
        position=int(metadata.get("chunk_index") or 0) + 1,
        documentId=str(DocumentId),
        content=content,
        wordCount=len(content),
        hitCount=0,
        enabled=True,
        status="completed",
        createdAt=0,
    )
|
|
|
|
async def DeleteDatasetDocumentSegment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
    SegmentId: str,
) -> RagOperationResultVO:
    """Delete one segment from the dataset's vector collection.

    The document's ``chunk_count`` is decremented (never below 0). The
    dataset's ``total_chunks`` is then recomputed from its live documents in
    the same session, instead of blindly decremented, so the counter cannot
    drift from the per-document counts.

    Raises:
        LeauditException: 404 when the dataset or segment is missing, 403 for
            roles without delete permission.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")

    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库分段权限")

    current = await self.GetDatasetDocumentSegmentDetail(
        CurrentUserId=CurrentUserId,
        UserArea=UserArea,
        UserRole=UserRole,
        DatasetId=DatasetId,
        DocumentId=DocumentId,
        SegmentId=SegmentId,
    )
    if not current:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "分段不存在")

    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    collection.delete(ids=[SegmentId])

    async with GetAsyncSession() as session:
        await session.execute(
            text(
                """
                UPDATE rag_document
                SET chunk_count = GREATEST(chunk_count - 1, 0),
                    updated_at = NOW()
                WHERE id = :document_id
                  AND dataset_id = :dataset_id
                """
            ),
            {"document_id": DocumentId, "dataset_id": DatasetId},
        )
        # Same formula as the dataset recount used after batch deletes; the
        # subquery sees the decrement above because it runs in this session.
        await session.execute(
            text(
                """
                UPDATE rag_dataset
                SET total_chunks = COALESCE((
                        SELECT SUM(chunk_count) FROM rag_document
                        WHERE dataset_id = :dataset_id AND deleted_at IS NULL
                    ), 0),
                    updated_at = NOW()
                WHERE id = :dataset_id
                """
            ),
            {"dataset_id": DatasetId},
        )

    return RagOperationResultVO(result="success")
|
|
|
|
async def DeleteDatasetDocument(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
) -> RagOperationResultVO:
    """Delete a single document by delegating to ``BatchDeleteDatasetDocuments``.

    Raises:
        LeauditException: anything the batch call raises, or 400 carrying the
            first skip reason when the document was not deleted.
    """
    result = await self.BatchDeleteDatasetDocuments(
        CurrentUserId=CurrentUserId,
        UserArea=UserArea,
        UserRole=UserRole,
        DatasetId=DatasetId,
        DocumentIds=[DocumentId],
    )
    if result.deletedCount <= 0:
        reason = "文档删除失败"
        if result.skipped:
            first_skip = result.skipped[0]
            # The batch method builds skipped entries as plain dicts; depending on
            # how the VO declares the field they may or may not be coerced into
            # objects, so accept both shapes instead of assuming attribute access.
            if isinstance(first_skip, dict):
                detail = first_skip.get("reason")
            else:
                detail = getattr(first_skip, "reason", None)
            reason = detail or reason
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, reason)
    return RagOperationResultVO(result="success")
|
|
|
|
async def BatchDeleteDatasetDocuments(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentIds: list[int],
) -> RagDatasetBatchDeleteResultVO:
    """Soft-delete several documents of one dataset with their vectors and files.

    Only documents whose ``indexing_status`` is in
    ``_DELETABLE_DOCUMENT_STATUSES`` are removed. Other matching documents are
    reported in ``skipped``. For deleted documents, their Chroma chunks and
    stored objects are removed, the rows get ``deleted_at`` set, and the
    dataset's document/chunk totals are recomputed.

    Raises:
        LeauditException: 404 when the dataset is not visible or no id matches
            a live document of this dataset, 403 for roles without delete
            permission, 400 when no document id is given.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")

    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库文档权限")

    # Drop falsy ids and duplicates (keeping first-seen order) so that
    # requestedCount counts distinct documents.
    normalized_ids = list(dict.fromkeys(int(item) for item in (DocumentIds or []) if item))
    if not normalized_ids:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "缺少待删除文档")

    async with GetAsyncSession() as session:
        document_rows = (
            await session.execute(
                text(
                    """
                    SELECT id, dataset_id, minio_path, original_name, indexing_status
                    FROM rag_document
                    WHERE id = ANY(:document_ids)
                      AND dataset_id = :dataset_id
                      AND deleted_at IS NULL
                    """
                ),
                {"document_ids": normalized_ids, "dataset_id": DatasetId},
            )
        ).mappings().all()

    if not document_rows:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")

    deletable_rows = []
    skipped = []
    for row in document_rows:
        # Documents that are not in a terminal state are reported, not deleted.
        if row.get("indexing_status") not in self._DELETABLE_DOCUMENT_STATUSES:
            skipped.append(
                {
                    "id": int(row["id"]),
                    "name": str(row.get("original_name") or ""),
                    "reason": "文档仍在处理中,暂不允许删除",
                }
            )
            continue
        deletable_rows.append(row)

    if not deletable_rows:
        return RagDatasetBatchDeleteResultVO(
            result="success",
            requestedCount=len(normalized_ids),
            deletedCount=0,
            skippedCount=len(skipped),
            deletedIds=[],
            skipped=skipped,
        )

    # Collect every chunk id belonging to the deletable documents, then delete
    # them in a single call.
    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    chunk_ids: list[str] = []
    for document_row in deletable_rows:
        raw = collection.get(where={"document_id": int(document_row["id"])}, include=[])
        ids = raw.get("ids") or []
        if ids:
            chunk_ids.extend(ids)
    if chunk_ids:
        collection.delete(ids=chunk_ids)

    for document_row in deletable_rows:
        self._delete_oss_object(document_row.get("minio_path"))

    existing_ids = [int(row["id"]) for row in deletable_rows]

    async with GetAsyncSession() as session:
        await session.execute(
            text(
                """
                UPDATE rag_document
                SET deleted_at = NOW(),
                    updated_at = NOW()
                WHERE id = ANY(:document_ids)
                """
            ),
            {"document_ids": existing_ids},
        )
        await session.execute(
            text(
                """
                UPDATE rag_dataset
                SET document_count = (
                        SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
                    ),
                    total_chunks = COALESCE((
                        SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
                    ), 0),
                    updated_at = NOW()
                WHERE id = :dataset_id
                """
            ),
            {"dataset_id": DatasetId},
        )

    return RagDatasetBatchDeleteResultVO(
        result="success",
        requestedCount=len(normalized_ids),
        deletedCount=len(existing_ids),
        skippedCount=len(skipped),
        deletedIds=existing_ids,
        skipped=skipped,
    )
|
|
|
|
async def RetrieveDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    Query: str,
    RetrievalModel: dict | None,
) -> RagDatasetRetrieveResponseVO:
    """Run a vector similarity search against one dataset's Chroma collection.

    ``RetrievalModel`` may carry ``top_k`` (clamped to 1..20, default 5) and
    ``score_threshold_enabled`` / ``score_threshold``. Hits scoring below the
    threshold are dropped.

    Raises:
        LeauditException: 404 when the dataset is not visible, 400 for a blank
            query.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")

    query_text = (Query or "").strip()
    if not query_text:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "检索内容不能为空")

    retrieval_model = RetrievalModel or {}
    top_k = max(1, min(int(retrieval_model.get("top_k") or 5), 20))
    score_threshold_enabled = bool(retrieval_model.get("score_threshold_enabled"))
    score_threshold = float(retrieval_model.get("score_threshold") or 0) if score_threshold_enabled else None

    query_embedding = await self._embed_texts([query_text], dataset.get("embedding_model") or "")
    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    raw = collection.query(
        query_embeddings=query_embedding,
        n_results=top_k,
        include=["documents", "metadatas", "distances"],
    )

    def first_result_row(key: str) -> list:
        # Chroma returns one inner list per query embedding; one query was sent.
        outer = raw.get(key) or []
        return (outer[0] or []) if outer else []

    ids = first_result_row("ids")
    documents = first_result_row("documents")
    metadatas = first_result_row("metadatas")
    distances = first_result_row("distances")

    # NOTE(review): the ``enabled`` flag written by BatchUpdateDatasetDocumentStatus
    # lives only in rag_document and is not consulted here.
    records: list[RagDatasetRetrieveRecordVO] = []
    for index, segment_id in enumerate(ids):
        content = (documents[index] if index < len(documents) else "") or ""
        metadata = metadatas[index] if index < len(metadatas) and isinstance(metadatas[index], dict) else {}
        distance = float(distances[index]) if index < len(distances) and distances[index] is not None else 1.0
        # NOTE(review): 1 - distance is only a [0, 1] similarity for a cosine
        # distance space; confirm the collection's configured metric.
        score = max(0.0, min(1.0, 1.0 - distance))
        if score_threshold is not None and score < score_threshold:
            continue

        document_name = metadata.get("document_name") or metadata.get("source") or ""
        document_id = str(metadata.get("document_id") or "")
        raw_chunk_index = metadata.get("chunk_index")
        # chunk_index 0 is a valid value; fall back to the hit rank only when
        # the metadata really lacks it (previously `or index` replaced 0).
        chunk_index = int(raw_chunk_index) if raw_chunk_index not in (None, "") else index

        records.append(
            RagDatasetRetrieveRecordVO(
                score=round(score, 6),
                segment=RagDatasetRetrieveSegmentVO(
                    id=str(segment_id),
                    position=chunk_index + 1,
                    documentId=document_id,
                    content=content,
                    answer="",
                    wordCount=len(content),
                    hitCount=0,
                    enabled=True,
                    status="completed",
                    createdAt=0,
                    document=RagDatasetRetrieveDocumentVO(
                        id=document_id,
                        dataSourceType="upload_file",
                        name=document_name,
                        docType=None,
                    ),
                ),
            )
        )

    return RagDatasetRetrieveResponseVO(
        query={"content": query_text},
        records=records,
    )
|
|
|
|
async def GetDatasetDocumentIndexingStatus(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
) -> dict:
    """Report a document's indexing progress as ``{"data": [status_row]}``.

    Per-stage completion timestamps are approximated with the document's
    ``updatedAt`` once a later stage has been reached. ``paused_at`` and
    ``stopped_at`` are always ``None``.

    Raises:
        LeauditException: 404 when the document detail lookup returns nothing.
    """
    detail = await self.GetDatasetDocumentDetail(
        CurrentUserId=CurrentUserId,
        UserArea=UserArea,
        UserRole=UserRole,
        DatasetId=DatasetId,
        DocumentId=DocumentId,
    )
    if not detail:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")

    status = detail.indexingStatus
    finished = status == "completed"
    done_count = detail.chunkCount if finished else 0
    planned_count = detail.chunkCount if detail.chunkCount > 0 else 0

    def stamp_once(*later_stages: str):
        # A stage counts as finished when the document is in any later stage.
        return detail.updatedAt if status in later_stages else None

    status_row = {
        "id": str(detail.id),
        "indexing_status": status,
        "processing_started_at": detail.updatedAt or detail.createdAt or None,
        "parsing_completed_at": stamp_once("cleaning", "splitting", "indexing", "completed"),
        "cleaning_completed_at": stamp_once("splitting", "indexing", "completed"),
        "splitting_completed_at": stamp_once("indexing", "completed"),
        "completed_at": stamp_once("completed"),
        "paused_at": None,
        "error": detail.error or None,
        "stopped_at": None,
        "completed_segments": done_count,
        "total_segments": planned_count,
    }
    return {"data": [status_row]}
|
|
|
|
async def UpdateDatasetDocumentByFile(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
    FileName: str,
    ContentType: str | None,
    Content: bytes,
    ProcessConfig: dict | None,
) -> RagDatasetUploadDocumentVO:
    """Replace an existing document's source file and re-index it in the background.

    Steps:
    1. Store the new file under a freshly generated object key.
    2. Delete the document's old Chroma chunks.
    3. Reset the row to ``indexing`` with the new file attributes.
    4. Remove the previous stored object (best effort).
    5. Schedule ``_run_document_indexing_task``.

    Raises:
        LeauditException: 404 when the dataset or document is missing, 403 for
            roles without permission, 400 for unsupported file types.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")

    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有重处理知识库文档权限")

    async with GetAsyncSession() as session:
        current = (
            await session.execute(
                text(
                    """
                    SELECT id, minio_path
                    FROM rag_document
                    WHERE id = :document_id
                      AND dataset_id = :dataset_id
                      AND deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"document_id": DocumentId, "dataset_id": DatasetId},
            )
        ).mappings().first()
    if not current:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")

    suffix = os.path.splitext(FileName)[1].lower()
    if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON")

    content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
    # Always write to a new key. Reusing the old minio_path would keep the old
    # file name/extension in the key, and cannot work when the stored path is
    # a URL rather than an object key.
    key_name = os.path.basename(FileName) or FileName
    object_key = f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{key_name}"
    oss = OssClient()
    oss.EnsureBucket()
    stored_key = oss.UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type)
    previous_path = current.get("minio_path")

    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    stale_ids = collection.get(where={"document_id": DocumentId}, include=[]).get("ids") or []
    if stale_ids:
        collection.delete(ids=stale_ids)

    async with GetAsyncSession() as session:
        await session.execute(
            text(
                """
                UPDATE rag_document
                SET original_name = :original_name,
                    minio_path = :minio_path,
                    file_type = :file_type,
                    file_size = :file_size,
                    chunk_count = 0,
                    indexing_status = 'indexing',
                    indexing_error = NULL,
                    updated_at = NOW()
                WHERE id = :document_id
                """
            ),
            {
                "document_id": DocumentId,
                "original_name": FileName,
                "minio_path": stored_key,
                "file_type": suffix.lstrip("."),
                "file_size": len(Content),
            },
        )

    # The row now points at the new object; the old one is no longer referenced.
    if previous_path and previous_path != stored_key:
        self._delete_oss_object(previous_path)

    # NOTE(review): an indexing run that is still in flight for this document is
    # not detected here; both runs would write chunks with the same ids.
    self._schedule_background_task(
        self._run_document_indexing_task(
            dataset=dataset,
            dataset_id=DatasetId,
            document_id=DocumentId,
            file_name=FileName,
            content=Content,
            process_config=ProcessConfig or {},
        )
    )

    return RagDatasetUploadDocumentVO(
        document={
            "id": str(DocumentId),
            "name": FileName,
            "indexing_status": "indexing",
            "word_count": 0,
            "hit_count": 0,
            "enabled": True,
        },
        batch=str(DocumentId),
    )


def _schedule_background_task(self, coroutine) -> asyncio.Task:
    """Schedule ``coroutine`` as a task and hold a strong reference until it finishes.

    The event loop keeps only weak references to tasks, so a bare
    ``asyncio.create_task(...)`` whose result is discarded can be
    garbage-collected mid-execution. The registry lives on the class so it
    outlives short-lived service instances.
    """
    owner = type(self)
    registry = getattr(owner, "_background_tasks", None)
    if registry is None:
        registry = set()
        owner._background_tasks = registry
    task = asyncio.create_task(coroutine)
    registry.add(task)
    task.add_done_callback(registry.discard)
    return task
|
|
|
|
async def BatchUpdateDatasetDocumentStatus(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentIds: list[int],
    Enabled: bool,
) -> RagOperationResultVO:
    """Set ``enabled`` for several live documents of one dataset in one UPDATE.

    Raises:
        LeauditException: 404 when the dataset is not visible, 403 for roles
            without permission, 400 when no document id is given.
    """
    target_dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not target_dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")

    permitted = UserRole in ("provincial_admin", "admin", "super_admin")
    if not permitted:
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有修改知识库文档状态权限")

    document_ids = list(map(int, DocumentIds))
    if not document_ids:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "未传入文档ID")

    toggle_sql = text(
        """
        UPDATE rag_document
        SET enabled = :enabled,
            updated_at = NOW()
        WHERE dataset_id = :dataset_id
          AND id = ANY(:document_ids)
          AND deleted_at IS NULL
        """
    )
    bind_values = {"dataset_id": DatasetId, "document_ids": document_ids, "enabled": Enabled}
    async with GetAsyncSession() as session:
        await session.execute(toggle_sql, bind_values)

    return RagOperationResultVO(result="success")
|
|
|
|
def _extract_page_texts(self, *, FileName: str, Content: bytes) -> list[tuple[int, str]]:
    """Extract ``(page_number, text)`` pairs from an uploaded file.

    - PDF and DOCX: delegated to the path-based parsers in documentServiceImpl,
      so only those formats are spooled to a temporary file (removed afterwards).
    - ``.json``: flattened by ``_extract_page_texts_from_json``.
    - Anything else: decoded as UTF-8 and returned as a single page, or ``[]``
      when blank.
    """
    suffix = os.path.splitext(FileName)[1].lower()

    if suffix == ".json":
        return self._extract_page_texts_from_json(Content)

    if suffix not in (".pdf", ".docx"):
        # utf-8-sig drops a leading BOM, which .strip() would otherwise keep.
        text_value = Content.decode("utf-8-sig", errors="ignore").strip()
        return [(1, text_value)] if text_value else []

    from fastapi_modules.fastapi_leaudit.services.impl.documentServiceImpl import (
        _extract_page_texts_from_docx,
        _extract_page_texts_from_pdf,
    )

    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
        temp_file.write(Content)
        temp_path = Path(temp_file.name)
    try:
        if suffix == ".pdf":
            return _extract_page_texts_from_pdf(temp_path)
        return _extract_page_texts_from_docx(temp_path)
    finally:
        try:
            temp_path.unlink()
        except OSError:
            pass
|
|
|
|
def _extract_page_texts_from_json(self, content: bytes) -> list[tuple[int, str]]:
    """Turn a JSON upload into a single ``(1, text)`` page.

    Valid JSON is flattened with ``_json_to_text``. Text that does not parse
    as JSON is returned verbatim. Blank input, or JSON that flattens to
    nothing, yields ``[]``.
    """
    # utf-8-sig strips a leading BOM. json.loads rejects BOM-prefixed strings,
    # so BOM-encoded JSON files were previously indexed as raw text.
    text_value = content.decode("utf-8-sig", errors="ignore").strip()
    if not text_value:
        return []

    try:
        payload = json.loads(text_value)
    except json.JSONDecodeError:
        return [(1, text_value)]

    flattened = self._json_to_text(payload).strip()
    return [(1, flattened)] if flattened else []
|
|
|
|
def _json_to_text(self, value: object, prefix: str = "") -> str:
    """Flatten a decoded JSON value into newline-separated ``path: value`` lines.

    Dict keys are joined onto the path with ``.``. List items are addressed as
    ``prefix[n]`` (1-based), or as ``item_n`` when the list has no prefix.
    ``None`` and empty containers yield ``""`` and are omitted by their parent.
    A scalar without a prefix is returned without a label.
    """
    if value is None:
        return ""

    if isinstance(value, str):
        return f"{prefix}: {value}" if prefix else value

    if isinstance(value, (int, float, bool)):
        return f"{prefix}: {value}" if prefix else str(value)

    if isinstance(value, list):
        labelled_children = (
            (f"{prefix}[{position}]" if prefix else f"item_{position}", child)
            for position, child in enumerate(value, start=1)
        )
    elif isinstance(value, dict):
        labelled_children = (
            (f"{prefix}.{key}" if prefix else str(key), child)
            for key, child in value.items()
        )
    else:
        labelled_children = None

    if labelled_children is not None:
        rendered_lines = (self._json_to_text(child, label) for label, child in labelled_children)
        return "\n".join(line for line in rendered_lines if line)

    # Any other object type: fall back to its stripped string form.
    fallback = str(value).strip()
    if not fallback:
        return ""
    return f"{prefix}: {fallback}" if prefix else fallback
|
|
|
|
def _build_chunks(
    self,
    *,
    file_name: str,
    page_texts: list[tuple[int, str]],
    dataset: dict,
    process_config: dict,
    document_id: int,
) -> list[dict]:
    """Split extracted pages into chunk dicts ready for ``collection.add``.

    Segmentation settings come from ``process_config["process_rule"]["rules"]``.
    Defaults: separator ``"\\n\\n"``; size from ``max_tokens``, else the
    dataset's ``chunk_max_size``, else 800; overlap 50, clamped to half the
    size. Pre-processing flags come from ``pre_processing_rules`` entries with
    ids ``remove_extra_spaces`` / ``remove_urls_emails``.

    Each chunk's id stays ``"{document_id}:{page}:{per-page index}"``.
    ``metadata["chunk_index"]`` is a document-wide running number, so segment
    positions do not restart on every page.
    """
    rules = ((process_config or {}).get("process_rule") or {}).get("rules") or {}
    segmentation = rules.get("segmentation") or {}
    # Ignore malformed (non-dict) rule entries instead of failing on .get().
    pre_rules = [rule for rule in (rules.get("pre_processing_rules") or []) if isinstance(rule, dict)]
    remove_spaces = any(rule.get("id") == "remove_extra_spaces" and rule.get("enabled") for rule in pre_rules)
    remove_urls = any(rule.get("id") == "remove_urls_emails" and rule.get("enabled") for rule in pre_rules)

    separator = segmentation.get("separator") or "\n\n"
    # A non-positive size would make _split_text produce meaningless slices.
    max_tokens = max(1, int(segmentation.get("max_tokens") or dataset.get("chunk_max_size") or 800))
    chunk_overlap = int(segmentation.get("chunk_overlap") or 50)
    chunk_overlap = max(0, min(chunk_overlap, max_tokens // 2 if max_tokens > 1 else 0))

    chunks: list[dict] = []
    ordinal = 0  # document-wide position stored as chunk_index
    for page_no, raw_text in page_texts:
        text_value = self._preprocess_text(raw_text, remove_spaces=remove_spaces, remove_urls=remove_urls)
        if not text_value:
            continue
        pieces = self._split_text(text_value, separator=separator, max_chars=max_tokens, overlap=chunk_overlap)
        for index, chunk_text in enumerate(pieces):
            if not chunk_text.strip():
                continue
            chunk_id = f"{document_id}:{page_no}:{index}"
            chunks.append(
                {
                    "id": chunk_id,
                    "text": chunk_text,
                    "metadata": {
                        "id": chunk_id,
                        "source": file_name,
                        "document_name": file_name,
                        "document_id": document_id,
                        "page": page_no,
                        "chunk_index": ordinal,
                    },
                }
            )
            ordinal += 1
    return chunks
|
|
|
|
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
    """Embed ``texts`` through an OpenAI-compatible ``/embeddings`` endpoint.

    Endpoint configuration:
    - URL: ``EMBED_URL``, else ``LLM_BASE_URL`` + ``/embeddings``.
    - Key: ``EMBED_KEY``, else ``LLM_API_KEY``.
    - Model: ``model_name``, else ``EMBED_MODEL``, else ``"text-embedding-v4"``.

    Requests are sent in batches of ``EMBED_BATCH_SIZE`` (default 10).

    Returns:
        One embedding per input text, in input order.

    Raises:
        LeauditException: 500 when the service is not configured, the request
            fails (HTTP error status or transport error), the response is not
            JSON, or the number of returned vectors does not match the input.
    """
    embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/embeddings"
    embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
    embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
    batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
    if not embed_url or not embed_key:
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")
    if not texts:
        return []

    embeddings: list[list[float]] = []
    async with httpx.AsyncClient(timeout=120.0) as client:
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            try:
                response = await client.post(
                    embed_url,
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {embed_key}",
                    },
                    json={"model": embed_model, "input": batch_texts},
                )
                response.raise_for_status()
                payload = response.json()
            except httpx.HTTPStatusError as exc:
                error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc
            except httpx.RequestError as exc:
                # Timeouts / connection failures previously escaped unwrapped.
                error_message = str(exc).strip() or type(exc).__name__
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc
            except ValueError as exc:
                # response.json() raises a ValueError subclass on a non-JSON body.
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    "向量化服务返回了无法解析的响应",
                ) from exc

            rows = (payload.get("data") or []) if isinstance(payload, dict) else []
            rows = [row for row in rows if isinstance(row, dict) and row.get("embedding")]
            # OpenAI-style responses tag each item with its input position;
            # honour it rather than trusting the array order.
            if rows and all(isinstance(row.get("index"), int) for row in rows):
                rows.sort(key=lambda row: row["index"])
            batch_embeddings = [row["embedding"] for row in rows]
            if len(batch_embeddings) != len(batch_texts):
                raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
            embeddings.extend(batch_embeddings)

    if len(embeddings) != len(texts):
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
    return embeddings
|
|
|
|
async def _run_document_indexing_task(
    self,
    *,
    dataset: dict,
    dataset_id: int,
    document_id: int,
    file_name: str,
    content: bytes,
    process_config: dict,
) -> None:
    """Background indexing pipeline: parse, chunk, embed, write to Chroma, mark completed.

    Each stage is recorded through ``_update_document_processing_state``.
    Pipeline failures are logged and recorded on the row rather than
    propagated. On failure the document is marked ``error``, ``chunk_count``
    is reset to 0, vectors already written for it are removed (best effort),
    and the dataset totals are re-synced.
    """
    import logging

    logger = logging.getLogger(__name__)
    try:
        await self._update_document_processing_state(document_id=document_id, status="parsing")
        # NOTE(review): parsing and the Chroma calls below are synchronous and
        # run on the event-loop thread.
        page_texts = self._extract_page_texts(FileName=file_name, Content=content)

        await self._update_document_processing_state(document_id=document_id, status="cleaning")
        processed = self._build_chunks(
            file_name=file_name,
            page_texts=page_texts,
            dataset=dataset,
            process_config=process_config,
            document_id=document_id,
        )
        if not processed:
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本")

        await self._update_document_processing_state(
            document_id=document_id,
            status="splitting",
            chunk_count=len(processed),
        )
        embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "")

        await self._update_document_processing_state(
            document_id=document_id,
            status="indexing",
            chunk_count=len(processed),
        )
        collection = get_chroma().get_or_create_collection(dataset["collection_name"])
        collection.add(
            ids=[item["id"] for item in processed],
            documents=[item["text"] for item in processed],
            embeddings=embeddings,
            metadatas=[item["metadata"] for item in processed],
        )

        async with GetAsyncSession() as session:
            await session.execute(
                text(
                    """
                    UPDATE rag_document
                    SET chunk_count = :chunk_count,
                        indexing_status = 'completed',
                        indexing_error = NULL,
                        indexing_started_at = COALESCE(indexing_started_at, NOW()),
                        indexing_completed_at = NOW(),
                        updated_at = NOW()
                    WHERE id = :document_id
                    """
                ),
                {"chunk_count": len(processed), "document_id": document_id},
            )
        await self._sync_dataset_counts(dataset_id)
    except Exception as exc:
        # This runs as a detached task; without logging the traceback is lost.
        logger.exception("RAG document %s indexing failed", document_id)

        # Keep the vector store consistent with chunk_count = 0 below.
        try:
            collection = get_chroma().get_or_create_collection(dataset["collection_name"])
            leftover_ids = collection.get(where={"document_id": document_id}, include=[]).get("ids") or []
            if leftover_ids:
                collection.delete(ids=leftover_ids)
        except Exception:
            logger.warning("Failed to remove vectors of failed RAG document %s", document_id, exc_info=True)

        # chunk_count was set at the "splitting" stage; no usable vectors back it now.
        async with GetAsyncSession() as session:
            await session.execute(
                text(
                    """
                    UPDATE rag_document
                    SET indexing_status = 'error',
                        indexing_error = :error,
                        chunk_count = 0,
                        updated_at = NOW()
                    WHERE id = :document_id
                    """
                ),
                {"document_id": document_id, "error": str(exc)[:2000]},
            )

        try:
            await self._sync_dataset_counts(dataset_id)
        except Exception:
            logger.warning("Failed to resync counts of RAG dataset %s", dataset_id, exc_info=True)
|
|
|
|
async def _update_document_processing_state(self, *, document_id: int, status: str, chunk_count: int | None = None) -> None:
    """Record an intermediate indexing stage on a ``rag_document`` row.

    The update clears ``indexing_error`` and stamps ``indexing_started_at``
    only if it is not already set. ``chunk_count`` is written only when given.
    """
    assignments = [
        "indexing_status = :status",
        "indexing_error = NULL",
        "indexing_started_at = COALESCE(indexing_started_at, NOW())",
        "updated_at = NOW()",
    ]
    bind_params: dict[str, object] = {"document_id": document_id, "status": status}
    if chunk_count is not None:
        assignments.append("chunk_count = :chunk_count")
        bind_params["chunk_count"] = chunk_count

    # Only the fixed column fragments above are interpolated; every value is
    # passed as a bind parameter.
    statement = text(f"UPDATE rag_document SET {', '.join(assignments)} WHERE id = :document_id")
    async with GetAsyncSession() as session:
        await session.execute(statement, bind_params)
|
|
|
|
async def _sync_dataset_counts(self, dataset_id: int) -> None:
    """Recompute a dataset's cached ``document_count`` and ``total_chunks``.

    Both totals are derived from the dataset's rag_document rows that have
    ``deleted_at IS NULL``. ``total_chunks`` falls back to 0 when there are no
    such rows.
    """
    recount_sql = text(
        """
        UPDATE rag_dataset
        SET document_count = (
                SELECT COUNT(1) FROM rag_document
                WHERE dataset_id = :dataset_id AND deleted_at IS NULL
            ),
            total_chunks = COALESCE((
                SELECT SUM(chunk_count) FROM rag_document
                WHERE dataset_id = :dataset_id AND deleted_at IS NULL
            ), 0),
            updated_at = NOW()
        WHERE id = :dataset_id
        """
    )
    async with GetAsyncSession() as session:
        await session.execute(recount_sql, {"dataset_id": dataset_id})
|
|
|
|
def _preprocess_text(self, text_value: str, *, remove_spaces: bool, remove_urls: bool) -> str:
    """Apply the optional text clean-up rules before chunking.

    Args:
        text_value: Raw page text; ``None``/empty yields ``""``.
        remove_spaces: Collapse runs of spaces/tabs into a single space.
        remove_urls: Replace http(s)/``www.`` URLs and e-mail addresses with a space.

    Returns:
        The cleaned text. Runs of three or more newlines are always reduced
        to a blank line, and the result is stripped.
    """
    result = text_value or ""
    if remove_urls:
        # Raw strings need single backslashes: the previous r"\\S" matched a
        # literal backslash followed by "S", so URLs/e-mails were never removed.
        result = re.sub(r"https?://\S+|www\.\S+|[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}", " ", result)
    if remove_spaces:
        # The previous r"[ \\t]+" also matched backslashes and the letter "t",
        # corrupting ordinary words (e.g. "better" -> "be er").
        result = re.sub(r"[ \t]+", " ", result)
    result = re.sub(r"\n{3,}", "\n\n", result)
    return result.strip()
|
|
|
|
def _split_text(self, text_value: str, *, separator: str, max_chars: int, overlap: int) -> list[str]:
    """Greedily pack separator-delimited pieces into chunks of at most ``max_chars``.

    With a separator, pieces are stripped and blank ones dropped. Adjacent
    pieces are re-joined with ``separator`` while the result still fits. A
    single piece longer than ``max_chars`` is cut into windows that advance by
    ``max_chars - overlap`` characters (at least 1). Overlap applies only to
    those windows, not between packed chunks.
    """
    if separator:
        pieces = [stripped for stripped in (raw.strip() for raw in text_value.split(separator)) if stripped]
    else:
        pieces = [text_value]

    output: list[str] = []
    buffer = ""
    for piece in pieces:
        joiner = separator if (separator and buffer) else ""
        merged = buffer + joiner + piece
        if len(merged) <= max_chars:
            buffer = merged
            continue

        if buffer:
            output.append(buffer)
        if len(piece) <= max_chars:
            buffer = piece
            continue

        # Oversized piece: emit fixed-size windows and start a fresh buffer.
        stride = max(max_chars - overlap, 1)
        output.extend(piece[offset:offset + max_chars] for offset in range(0, len(piece), stride))
        buffer = ""

    if buffer:
        output.append(buffer)
    return output
|
|
|
|
def _delete_oss_object(self, source: str | None) -> None:
    """Best-effort removal of a stored file from object storage.

    ``source`` is resolved with ``OssClient.ResolveObjectRef``. Direct URLs and
    refs without an object key are skipped. Otherwise the object is removed
    through the MinIO client. Failures never propagate, so legacy/dirty
    records cannot block document deletion, but they are logged so orphaned
    objects can be traced.
    """
    import logging

    if not source:
        return
    try:
        oss = OssClient()
        ref = oss.ResolveObjectRef(Source=source)
        if ref.isDirectUrl or not ref.objectKey:
            return
        oss._GetMinioClient().remove_object(ref.bucket, ref.objectKey)
    except Exception:
        logging.getLogger(__name__).warning("Failed to delete OSS object %r", source, exc_info=True)
|