from __future__ import annotations import asyncio import json import mimetypes import os import re import tempfile import uuid from datetime import datetime from pathlib import Path import httpx from sqlalchemy import text from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from fastapi_common.fastapi_common_storage.oss_client import OssClient from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import ( RagDatasetBatchDeleteResultVO, RagDatasetDetailVO, RagDatasetDocumentItemVO, RagDatasetDocumentPageVO, RagDatasetItemVO, RagDatasetPageVO, RagDatasetRetrieveDocumentVO, RagDatasetRetrieveRecordVO, RagDatasetRetrieveResponseVO, RagDatasetRetrieveSegmentVO, RagDatasetSegmentItemVO, RagDatasetSegmentPageVO, RagDatasetUploadDocumentVO, ) from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import RagOperationResultVO from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatasetService class RagDatasetServiceImpl(IRagDatasetService): _ACTIVE_INDEXING_STATUSES = {"waiting", "parsing", "cleaning", "splitting", "indexing"} _DELETABLE_DOCUMENT_STATUSES = {"completed", "error", "paused"} _APP_LINK_SQL = """ LEFT JOIN ( SELECT DISTINCT ON (dataset_id) dataset_id, id, name, is_default FROM rag_chat_app WHERE deleted_at IS NULL AND status = 1 ORDER BY dataset_id, is_default DESC, sort_order ASC, id ASC ) a ON a.dataset_id = d.id """ async def GetAdminDatasets( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, Area: str | None, OnlyEnabled: bool | None, Page: int, PageSize: int, ) -> 
RagDatasetPageVO: if UserRole not in ("provincial_admin", "admin", "super_admin"): raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有管理知识库权限") managed_area = self._resolve_managed_area(UserRole=UserRole, UserArea=UserArea) filters = ["d.deleted_at IS NULL"] params: dict = { "offset": max(Page - 1, 0) * PageSize, "limit": PageSize, } areas = [item.strip() for item in str(Area or "").split(",") if item.strip()] if managed_area: if areas and any(item != managed_area for item in areas): raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户只能查看本地区知识库配置") filters.append("d.area = :managed_area") params["managed_area"] = managed_area elif len(areas) == 1: filters.append("d.area = :area") params["area"] = areas[0] elif len(areas) > 1: filters.append("d.area = ANY(:areas)") params["areas"] = areas if OnlyEnabled is not None: filters.append("d.status = :status") params["status"] = 1 if OnlyEnabled else 0 where_sql = " AND ".join(filters) async with GetAsyncSession() as session: total = ( await session.execute( text(f"SELECT COUNT(1) FROM rag_dataset d WHERE {where_sql}"), {k: v for k, v in params.items() if k not in ("offset", "limit")}, ) ).scalar_one() rows = ( await session.execute( text( f""" SELECT d.id, d.name, d.description, d.area, d.is_public, d.is_default, d.document_count, d.total_chunks, d.status, d.sort_order, d.created_at, d.updated_at, a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default FROM rag_dataset d {self._APP_LINK_SQL} WHERE {where_sql} ORDER BY d.sort_order ASC, d.id ASC OFFSET :offset LIMIT :limit """ ), params, ) ).mappings().all() return RagDatasetPageVO( data=[self._to_item_vo(dict(row)) for row in rows], total=int(total or 0), ) async def CreateAdminDataset( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, Body: dict, ) -> RagDatasetDetailVO: if UserRole not in ("provincial_admin", "admin", "super_admin"): raise 
LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有创建知识库权限") area = str(Body.get("area") or "").strip() name = str(Body.get("dataset_name") or Body.get("name") or "").strip() description = str(Body.get("dataset_description") or Body.get("description") or "").strip() if not area or not name: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "地区和知识库名称不能为空") self._assert_manage_area_scope(UserRole=UserRole, UserArea=UserArea, DatasetArea=area) collection_name = self._slugify_collection_name(area, name) retrieval_model = {} async with GetAsyncSession() as session: base = ( await session.execute( text( """ SELECT embedding_model, embedding_dim, chunk_max_size, chunk_min_size, retrieval_model FROM rag_dataset WHERE deleted_at IS NULL ORDER BY is_default DESC, id ASC LIMIT 1 """ ) ) ).mappings().first() if base: retrieval_model = base.get("retrieval_model") or {} if bool(Body.get("is_default")): await self._clear_default_flags(session) row = ( await session.execute( text( """ INSERT INTO rag_dataset ( name, description, area, is_public, is_default, collection_name, embedding_model, embedding_dim, chunk_max_size, chunk_min_size, retrieval_model, sort_order, status, created_by, updated_by ) VALUES ( :name, :description, :area, :is_public, :is_default, :collection_name, :embedding_model, :embedding_dim, :chunk_max_size, :chunk_min_size, CAST(:retrieval_model AS jsonb), :sort_order, :status, :created_by, :updated_by ) RETURNING id """ ), { "name": name, "description": description, "area": area, "is_public": bool(Body.get("is_public")), "is_default": bool(Body.get("is_default")), "collection_name": collection_name, "embedding_model": (base.get("embedding_model") if base else "text-embedding-v4"), "embedding_dim": (base.get("embedding_dim") if base else 1024), "chunk_max_size": (base.get("chunk_max_size") if base else 800), "chunk_min_size": (base.get("chunk_min_size") if base else 20), "retrieval_model": json.dumps(retrieval_model, ensure_ascii=False), 
"sort_order": int(Body.get("sort_order") or 0), "status": int(Body.get("status") or 1), "created_by": CurrentUserId, "updated_by": CurrentUserId, }, ) ).mappings().first() dataset_id = int(row["id"]) await self._ensure_linked_app( session=session, dataset_id=dataset_id, dataset_name=name, dataset_area=area, current_user_id=CurrentUserId, is_default=bool(Body.get("is_default")), ) refreshed = await self._get_dataset_row(dataset_id) return self._to_detail_vo(refreshed) if refreshed else None async def UpdateAdminDataset( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, DatasetId: int, Body: dict, ) -> RagDatasetDetailVO | None: if UserRole not in ("provincial_admin", "admin", "super_admin"): raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有更新知识库权限") existing = await self._get_dataset_row(DatasetId) if not existing: return None self._assert_manage_area_scope(UserRole=UserRole, UserArea=UserArea, DatasetArea=str(existing.get("area") or "")) area = str(Body.get("area") or existing.get("area") or "").strip() self._assert_manage_area_scope(UserRole=UserRole, UserArea=UserArea, DatasetArea=area) async with GetAsyncSession() as session: target_is_default = bool(Body.get("is_default", existing.get("is_default"))) if target_is_default: await self._clear_default_flags(session) elif existing.get("is_default") and Body.get("is_default") is False: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不能直接取消,请先将其他知识库设为默认") await session.execute( text( """ UPDATE rag_dataset SET name = :name, description = :description, area = :area, is_public = :is_public, is_default = :is_default, sort_order = :sort_order, status = :status, updated_by = :updated_by, updated_at = NOW() WHERE id = :dataset_id """ ), { "dataset_id": DatasetId, "name": str(Body.get("dataset_name") or Body.get("name") or existing.get("name") or "").strip(), "description": str(Body.get("dataset_description") or Body.get("description") or existing.get("description") or 
"").strip(), "area": area, "is_public": bool(Body.get("is_public", existing.get("is_public"))), "is_default": target_is_default, "sort_order": int(Body.get("sort_order") if Body.get("sort_order") is not None else (existing.get("sort_order") or 0)), "status": int(Body.get("status") if Body.get("status") is not None else (existing.get("status") or 1)), "updated_by": CurrentUserId, }, ) await self._ensure_linked_app( session=session, dataset_id=DatasetId, dataset_name=str(Body.get("dataset_name") or Body.get("name") or existing.get("name") or "").strip(), dataset_area=area, current_user_id=CurrentUserId, is_default=target_is_default, ) refreshed = await self._get_dataset_row(DatasetId) return self._to_detail_vo(refreshed) if refreshed else None async def DeleteAdminDataset( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, DatasetId: int, ) -> RagOperationResultVO: if UserRole not in ("provincial_admin", "admin", "super_admin"): raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库权限") existing = await self._get_dataset_row(DatasetId) if not existing: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") self._assert_manage_area_scope(UserRole=UserRole, UserArea=UserArea, DatasetArea=str(existing.get("area") or "")) if bool(existing.get("is_default")): raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不允许删除,请先切换默认知识库") async with GetAsyncSession() as session: await session.execute( text("UPDATE rag_dataset SET deleted_at = NOW(), updated_by = :updated_by, updated_at = NOW() WHERE id = :dataset_id"), {"dataset_id": DatasetId, "updated_by": CurrentUserId}, ) await session.execute( text("UPDATE rag_chat_app SET deleted_at = NOW(), updated_by = :updated_by, updated_at = NOW() WHERE dataset_id = :dataset_id AND deleted_at IS NULL"), {"dataset_id": DatasetId, "updated_by": CurrentUserId}, ) return RagOperationResultVO(result="success") async def GetMyDatasets(self, CurrentUserId: int, UserArea: str | None, 
UserRole: str | None) -> RagDatasetPageVO: async with GetAsyncSession() as session: rows = ( await session.execute( text( f""" SELECT d.id, d.name, d.description, d.area, d.is_public, d.is_default, d.document_count, d.total_chunks, d.status, d.sort_order, d.created_at, d.updated_at, a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default FROM rag_dataset d {self._APP_LINK_SQL} WHERE d.deleted_at IS NULL AND d.status = 1 AND ( :is_provincial = TRUE OR d.area IN (:user_area, '省级', '') OR d.is_public = TRUE ) ORDER BY d.sort_order ASC, d.created_at DESC """ ), { "is_provincial": UserRole == "provincial_admin", "user_area": UserArea or "", }, ) ).mappings().all() return RagDatasetPageVO( data=[ RagDatasetItemVO( **self._to_item_vo(dict(row)).model_dump() ) for row in rows ], total=len(rows), ) async def GetDatasetDetail(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, DatasetId: int) -> RagDatasetDetailVO | None: row = await self._get_visible_dataset(UserArea, UserRole, DatasetId) if not row: return None return self._to_detail_vo(row) async def UpdateDataset( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, DatasetId: int, Body: RagDatasetUpdateDTO, ) -> RagDatasetDetailVO | None: row = await self._get_visible_dataset(UserArea, UserRole, DatasetId) if not row: return None if UserRole not in ("provincial_admin", "admin", "super_admin"): raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有修改知识库配置权限") update_fields: list[str] = [] params: dict = {"dataset_id": DatasetId, "updated_by": CurrentUserId} if Body.name is not None: update_fields.append("name = :name") params["name"] = Body.name.strip() if Body.retrieval_model is not None: update_fields.append("retrieval_model = CAST(:retrieval_model AS jsonb)") params["retrieval_model"] = json.dumps(Body.retrieval_model, ensure_ascii=False) if not update_fields: return self._to_detail_vo(row) update_fields.append("updated_by = 
:updated_by") update_fields.append("updated_at = NOW()") async with GetAsyncSession() as session: await session.execute( text( f""" UPDATE rag_dataset SET {", ".join(update_fields)} WHERE id = :dataset_id """ ), params, ) refreshed = await self._get_visible_dataset(UserArea, UserRole, DatasetId) return self._to_detail_vo(refreshed) if refreshed else None async def GetDatasetDocuments( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, DatasetId: int, Page: int, Limit: int, Keyword: str | None, ) -> RagDatasetDocumentPageVO: dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) if not dataset: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") where_sql = [ "dataset_id = :dataset_id", "deleted_at IS NULL", ] params: dict = { "dataset_id": DatasetId, "offset": max(Page - 1, 0) * Limit, "limit": Limit + 1, } if Keyword and Keyword.strip(): where_sql.append("(original_name ILIKE :keyword OR filename ILIKE :keyword)") params["keyword"] = f"%{Keyword.strip()}%" async with GetAsyncSession() as session: total = ( await session.execute( text( f""" SELECT COUNT(1) FROM rag_document WHERE {" AND ".join(where_sql)} """ ), {key: value for key, value in params.items() if key not in ("offset", "limit")}, ) ).scalar_one() rows = ( await session.execute( text( f""" SELECT id, dataset_id, original_name, file_type, file_size, chunk_count, indexing_status, COALESCE(indexing_error, '') AS indexing_error, enabled, hit_count, created_by, created_at, updated_at FROM rag_document WHERE {" AND ".join(where_sql)} ORDER BY created_at DESC OFFSET :offset LIMIT :limit """ ), params, ) ).mappings().all() has_more = len(rows) > Limit items = rows[:Limit] return RagDatasetDocumentPageVO( data=[ RagDatasetDocumentItemVO( id=row["id"], datasetId=row["dataset_id"], name=row.get("original_name") or "", fileType=row.get("file_type") or "", fileSize=row.get("file_size") or 0, chunkCount=row.get("chunk_count") or 0, 
indexingStatus=row.get("indexing_status") or "waiting", error=row.get("indexing_error") or "", enabled=bool(row.get("enabled")), hitCount=row.get("hit_count") or 0, createdBy=row.get("created_by"), createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, ) for row in items ], total=int(total or 0), page=Page, limit=Limit, hasMore=has_more, ) async def GetDatasetDocumentDetail( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, DatasetId: int, DocumentId: int, ) -> RagDatasetDocumentItemVO | None: dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) if not dataset: return None async with GetAsyncSession() as session: row = ( await session.execute( text( """ SELECT id, dataset_id, original_name, file_type, file_size, chunk_count, indexing_status, COALESCE(indexing_error, '') AS indexing_error, enabled, hit_count, created_by, created_at, updated_at FROM rag_document WHERE id = :document_id AND dataset_id = :dataset_id AND deleted_at IS NULL LIMIT 1 """ ), {"document_id": DocumentId, "dataset_id": DatasetId}, ) ).mappings().first() if not row: return None return RagDatasetDocumentItemVO( id=row["id"], datasetId=row["dataset_id"], name=row.get("original_name") or "", fileType=row.get("file_type") or "", fileSize=row.get("file_size") or 0, chunkCount=row.get("chunk_count") or 0, indexingStatus=row.get("indexing_status") or "waiting", error=row.get("indexing_error") or "", enabled=bool(row.get("enabled")), hitCount=row.get("hit_count") or 0, createdBy=row.get("created_by"), createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, ) async def _get_visible_dataset(self, user_area: str | None, user_role: str | None, dataset_id: int) -> dict | None: async with GetAsyncSession() as session: row = ( await session.execute( text( f""" SELECT d.id, 
d.name, d.description, d.area, d.is_public, d.is_default, d.status, d.document_count, d.total_chunks, d.chunk_max_size, d.chunk_min_size, d.sort_order, d.retrieval_model, d.collection_name, d.embedding_model, d.created_at, d.updated_at, a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default FROM rag_dataset d {self._APP_LINK_SQL} WHERE d.id = :dataset_id AND d.deleted_at IS NULL AND d.status = 1 LIMIT 1 """ ), {"dataset_id": dataset_id}, ) ).mappings().first() if not row: return None if user_role == "provincial_admin": return dict(row) area = row.get("area") or "" if area in ("", "省级", user_area or "") or bool(row.get("is_public")): return dict(row) return None def _to_detail_vo(self, row: dict) -> RagDatasetDetailVO: return RagDatasetDetailVO( id=row["id"], name=row["name"], description=row.get("description") or "", area=row.get("area") or "", isPublic=bool(row.get("is_public")), isDefault=bool(row.get("is_default")), status=row.get("status") or 1, documentCount=row.get("document_count") or 0, sortOrder=row.get("sort_order") or 0, totalChunks=row.get("total_chunks") or 0, chunkMaxSize=row.get("chunk_max_size") or 800, chunkMinSize=row.get("chunk_min_size") or 20, retrievalModel=row.get("retrieval_model") or {}, createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, appId=row.get("app_id"), appName=row.get("app_name") or "", appIsDefault=bool(row.get("app_is_default")), ) def _to_item_vo(self, row: dict) -> RagDatasetItemVO: return RagDatasetItemVO( id=row["id"], name=row.get("name") or "", description=row.get("description") or "", area=row.get("area") or "", isPublic=bool(row.get("is_public")), isDefault=bool(row.get("is_default")), documentCount=row.get("document_count") or 0, totalChunks=row.get("total_chunks") or 0, status=row.get("status") or 1, sortOrder=row.get("sort_order") or 0, 
createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, appId=row.get("app_id"), appName=row.get("app_name") or "", appIsDefault=bool(row.get("app_is_default")), ) async def _get_dataset_row(self, dataset_id: int) -> dict | None: async with GetAsyncSession() as session: row = ( await session.execute( text( f""" SELECT d.id, d.name, d.description, d.area, d.is_public, d.is_default, d.status, d.document_count, d.total_chunks, d.chunk_max_size, d.chunk_min_size, d.sort_order, d.retrieval_model, d.collection_name, d.embedding_model, d.created_at, d.updated_at, a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default FROM rag_dataset d {self._APP_LINK_SQL} WHERE d.id = :dataset_id AND d.deleted_at IS NULL LIMIT 1 """ ), {"dataset_id": dataset_id}, ) ).mappings().first() return dict(row) if row else None async def _clear_default_flags(self, session) -> None: await session.execute(text("UPDATE rag_dataset SET is_default = FALSE WHERE deleted_at IS NULL")) await session.execute(text("UPDATE rag_chat_app SET is_default = FALSE WHERE deleted_at IS NULL")) async def _ensure_linked_app( self, session, dataset_id: int, dataset_name: str, dataset_area: str, current_user_id: int, is_default: bool, ) -> None: app_row = ( await session.execute( text( """ SELECT id FROM rag_chat_app WHERE dataset_id = :dataset_id AND deleted_at IS NULL ORDER BY is_default DESC, sort_order ASC, id ASC LIMIT 1 """ ), {"dataset_id": dataset_id}, ) ).mappings().first() app_name = self._build_app_name(dataset_area=dataset_area, dataset_name=dataset_name) if app_row: await session.execute( text( """ UPDATE rag_chat_app SET name = :name, is_default = :is_default, status = 1, updated_by = :updated_by, updated_at = NOW() WHERE id = :app_id """ ), { "app_id": int(app_row["id"]), "name": app_name, "is_default": is_default, "updated_by": current_user_id, }, ) return await 
session.execute( text( """ INSERT INTO rag_chat_app ( name, description, area, dataset_id, suggested_questions, opening_statement, sort_order, status, is_default, created_by, updated_by ) VALUES ( :name, :description, :area, :dataset_id, CAST(:suggested_questions AS jsonb), :opening_statement, 0, 1, :is_default, :created_by, :updated_by ) """ ), { "name": app_name, "description": f"{dataset_area or '默认地区'}知识库问答助手", "area": dataset_area or "", "dataset_id": dataset_id, "suggested_questions": json.dumps([], ensure_ascii=False), "opening_statement": f"您好,我是{app_name}。", "is_default": is_default, "created_by": current_user_id, "updated_by": current_user_id, }, ) def _build_app_name(self, dataset_area: str, dataset_name: str) -> str: cleaned = (dataset_name or "").strip() if cleaned.endswith("知识库"): cleaned = cleaned[:-3] if cleaned.endswith("助手"): return cleaned if cleaned: return f"{cleaned}助手" return f"{dataset_area or '默认地区'}法务助手" def _slugify_collection_name(self, area: str, name: str) -> str: source = f"{area}_{name}".lower() normalized = re.sub(r"[^a-z0-9]+", "_", source).strip("_") if normalized: return f"legal_kb_{normalized}"[:96] return f"legal_kb_{uuid.uuid4().hex[:12]}" def _resolve_managed_area(self, UserRole: str | None, UserArea: str | None) -> str | None: if UserRole == "admin": area = str(UserArea or "").strip() if not area: raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前市级管理员未配置地区,无法管理知识库") return area return None def _assert_manage_area_scope(self, UserRole: str | None, UserArea: str | None, DatasetArea: str) -> None: if UserRole in ("provincial_admin", "super_admin"): return if UserRole != "admin": raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有管理知识库权限") managed_area = self._resolve_managed_area(UserRole=UserRole, UserArea=UserArea) if DatasetArea != managed_area: raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户只能管理本地区知识库") async def UploadDatasetDocument( self, CurrentUserId: int, UserArea: str | None, 
UserRole: str | None, DatasetId: int, FileName: str, ContentType: str | None, Content: bytes, ProcessConfig: dict | None, ) -> RagDatasetUploadDocumentVO: dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) if not dataset: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") suffix = os.path.splitext(FileName)[1].lower() if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON") object_key = f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}" content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream" OssClient().EnsureBucket() stored_key = OssClient().UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type) async with GetAsyncSession() as session: inserted = ( await session.execute( text( """ INSERT INTO rag_document ( dataset_id, filename, original_name, minio_path, file_type, file_size, chunk_count, indexing_status, enabled, hit_count, created_by ) VALUES ( :dataset_id, :filename, :original_name, :minio_path, :file_type, :file_size, 0, 'indexing', TRUE, 0, :created_by ) RETURNING id, created_at, updated_at """ ), { "dataset_id": DatasetId, "filename": uuid.uuid4().hex + suffix, "original_name": FileName, "minio_path": stored_key, "file_type": suffix.lstrip("."), "file_size": len(Content), "created_by": CurrentUserId, }, ) ).mappings().first() document_id = int(inserted["id"]) asyncio.create_task( self._run_document_indexing_task( dataset=dataset, dataset_id=DatasetId, document_id=document_id, file_name=FileName, content=Content, process_config=ProcessConfig or {}, ) ) return RagDatasetUploadDocumentVO( document={ "id": str(document_id), "name": FileName, "indexing_status": "indexing", "word_count": 0, "hit_count": 0, "enabled": True, }, batch=str(document_id), ) async def GetDatasetDocumentSegments( self, CurrentUserId: int, UserArea: str | None, 
UserRole: str | None, DatasetId: int, DocumentId: int, Page: int, Limit: int, Keyword: str | None, ) -> RagDatasetSegmentPageVO: dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) if not dataset: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") collection = get_chroma().get_or_create_collection(dataset["collection_name"]) raw = collection.get(where={"document_id": DocumentId}, include=["documents", "metadatas"]) ids = raw.get("ids") or [] docs = raw.get("documents") or [] metas = raw.get("metadatas") or [] items: list[dict] = [] for index, segment_id in enumerate(ids): content = docs[index] if index < len(docs) else "" meta = metas[index] if index < len(metas) and isinstance(metas[index], dict) else {} if Keyword and Keyword.strip() and Keyword.strip() not in content: continue items.append( { "id": str(segment_id), "position": index + 1, "document_id": str(DocumentId), "content": content or "", "word_count": len(content or ""), "hit_count": 0, "enabled": True, "status": "completed", "created_at": 0, } ) offset = max(Page - 1, 0) * Limit page_items = items[offset: offset + Limit] has_more = offset + Limit < len(items) return RagDatasetSegmentPageVO( data=[ RagDatasetSegmentItemVO( id=item["id"], position=item["position"], documentId=item["document_id"], content=item["content"], wordCount=item["word_count"], hitCount=item["hit_count"], enabled=item["enabled"], status=item["status"], createdAt=item["created_at"], ) for item in page_items ], total=len(items), limit=Limit, hasMore=has_more, ) async def GetDatasetDocumentSegmentDetail( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, DatasetId: int, DocumentId: int, SegmentId: str, ) -> RagDatasetSegmentItemVO | None: dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) if not dataset: return None collection = get_chroma().get_or_create_collection(dataset["collection_name"]) raw = collection.get(ids=[SegmentId], include=["documents", 
"metadatas"]) ids = raw.get("ids") or [] if not ids: return None content = (raw.get("documents") or [""])[0] or "" metadata = (raw.get("metadatas") or [{}])[0] or {} if str(metadata.get("document_id") or "") != str(DocumentId): return None return RagDatasetSegmentItemVO( id=str(SegmentId), position=int(metadata.get("chunk_index") or 0) + 1, documentId=str(DocumentId), content=content, wordCount=len(content), hitCount=0, enabled=True, status="completed", createdAt=0, ) async def UpdateDatasetDocumentSegment( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, DatasetId: int, DocumentId: int, SegmentId: str, Body: dict, ) -> RagDatasetSegmentItemVO | None: dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) if not dataset: return None if UserRole not in ("provincial_admin", "admin", "super_admin"): raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有修改知识库分段权限") current = await self.GetDatasetDocumentSegmentDetail( CurrentUserId=CurrentUserId, UserArea=UserArea, UserRole=UserRole, DatasetId=DatasetId, DocumentId=DocumentId, SegmentId=SegmentId, ) if not current: return None segment_body = Body.get("segment") if isinstance(Body.get("segment"), dict) else Body content = str(segment_body.get("content") or current.content).strip() if not content: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "分段内容不能为空") collection = get_chroma().get_or_create_collection(dataset["collection_name"]) raw = collection.get(ids=[SegmentId], include=["metadatas"]) metadata = (raw.get("metadatas") or [{}])[0] or {} embeddings = await self._embed_texts([content], dataset.get("embedding_model") or "") collection.update( ids=[SegmentId], documents=[content], embeddings=embeddings, metadatas=[metadata], ) return RagDatasetSegmentItemVO( id=str(SegmentId), position=int(metadata.get("chunk_index") or 0) + 1, documentId=str(DocumentId), content=content, wordCount=len(content), hitCount=0, enabled=True, status="completed", createdAt=0, ) 
# Recomputes a dataset's denormalised counters from its live (not soft-deleted)
# documents. Shared by every writer below so they all agree on what is counted.
_SYNC_DATASET_COUNTS_SQL = """
    UPDATE rag_dataset
    SET document_count = (
            SELECT COUNT(1) FROM rag_document
            WHERE dataset_id = :dataset_id AND deleted_at IS NULL
        ),
        total_chunks = COALESCE((
            SELECT SUM(chunk_count) FROM rag_document
            WHERE dataset_id = :dataset_id AND deleted_at IS NULL
        ), 0),
        updated_at = NOW()
    WHERE id = :dataset_id
"""

# Strong references to fire-and-forget background tasks. asyncio only keeps weak
# references to tasks, so an unreferenced task may be garbage-collected before it
# finishes; each entry removes itself once its task is done.
_BACKGROUND_TASKS: set[asyncio.Task] = set()

async def DeleteDatasetDocumentSegment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
    SegmentId: str,
) -> RagOperationResultVO:
    """Delete one vector segment of a document and decrement the stored chunk counters.

    Returns:
        ``RagOperationResultVO(result="success")``.

    Raises:
        LeauditException: 404 when the dataset is not visible to the caller or the
            segment does not belong to ``DocumentId``; 403 when the role may not edit.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库分段权限")
    # Also verifies that the segment's metadata points at DocumentId.
    current = await self.GetDatasetDocumentSegmentDetail(
        CurrentUserId=CurrentUserId,
        UserArea=UserArea,
        UserRole=UserRole,
        DatasetId=DatasetId,
        DocumentId=DocumentId,
        SegmentId=SegmentId,
    )
    if not current:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "分段不存在")
    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    collection.delete(ids=[SegmentId])
    async with GetAsyncSession() as session:
        # GREATEST(..., 0) keeps the counters from going negative if they drifted.
        await session.execute(
            text(
                """
                UPDATE rag_document
                SET chunk_count = GREATEST(chunk_count - 1, 0),
                    updated_at = NOW()
                WHERE id = :document_id
                """
            ),
            {"document_id": DocumentId},
        )
        await session.execute(
            text(
                """
                UPDATE rag_dataset
                SET total_chunks = GREATEST(total_chunks - 1, 0),
                    updated_at = NOW()
                WHERE id = :dataset_id
                """
            ),
            {"dataset_id": DatasetId},
        )
    return RagOperationResultVO(result="success")

async def DeleteDatasetDocument(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
) -> RagOperationResultVO:
    """Delete a single document by delegating to ``BatchDeleteDatasetDocuments``.

    Raises:
        LeauditException: 400 carrying the skip reason (or a generic message) when
            nothing was deleted, plus anything the batch variant raises.
    """
    result = await self.BatchDeleteDatasetDocuments(
        CurrentUserId=CurrentUserId,
        UserArea=UserArea,
        UserRole=UserRole,
        DatasetId=DatasetId,
        DocumentIds=[DocumentId],
    )
    if result.deletedCount <= 0:
        reason = result.skipped[0].reason if result.skipped else "文档删除失败"
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, reason)
    return RagOperationResultVO(result="success")

async def BatchDeleteDatasetDocuments(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentIds: list[int],
) -> RagDatasetBatchDeleteResultVO:
    """Soft-delete several documents of a dataset together with their vectors and files.

    Documents whose ``indexing_status`` is not in ``_DELETABLE_DOCUMENT_STATUSES``
    are reported in ``skipped`` instead of being deleted.

    Raises:
        LeauditException: 404 if the dataset is not visible or none of the ids match
            a live document of the dataset; 403 for roles without edit rights; 400 if
            no usable id was supplied.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库文档权限")
    normalized_ids = [int(item) for item in DocumentIds if item]
    if not normalized_ids:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "缺少待删除文档")
    async with GetAsyncSession() as session:
        document_rows = (
            await session.execute(
                text(
                    """
                    SELECT id, dataset_id, minio_path, original_name, indexing_status
                    FROM rag_document
                    WHERE id = ANY(:document_ids)
                      AND dataset_id = :dataset_id
                      AND deleted_at IS NULL
                    """
                ),
                {"document_ids": normalized_ids, "dataset_id": DatasetId},
            )
        ).mappings().all()
    if not document_rows:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")

    # NOTE(review): ids that match no live document of this dataset are neither
    # deleted nor listed in ``skipped``, so requestedCount can exceed
    # deletedCount + skippedCount.
    deletable_rows = []
    skipped = []
    for row in document_rows:
        if row.get("indexing_status") not in self._DELETABLE_DOCUMENT_STATUSES:
            skipped.append(
                {
                    "id": int(row["id"]),
                    "name": str(row.get("original_name") or ""),
                    "reason": "文档仍在处理中,暂不允许删除",
                }
            )
            continue
        deletable_rows.append(row)

    if not deletable_rows:
        return RagDatasetBatchDeleteResultVO(
            result="success",
            requestedCount=len(normalized_ids),
            deletedCount=0,
            skippedCount=len(skipped),
            deletedIds=[],
            skipped=skipped,
        )

    # Remove the vectors first so retrieval stops returning them, then the files.
    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    chunk_ids: list[str] = []
    for document_row in deletable_rows:
        raw = collection.get(where={"document_id": int(document_row["id"])}, include=[])
        chunk_ids.extend(raw.get("ids") or [])
    if chunk_ids:
        collection.delete(ids=chunk_ids)
    for document_row in deletable_rows:
        self._delete_oss_object(document_row.get("minio_path"))

    existing_ids = [int(row["id"]) for row in deletable_rows]
    async with GetAsyncSession() as session:
        await session.execute(
            text(
                """
                UPDATE rag_document
                SET deleted_at = NOW(),
                    updated_at = NOW()
                WHERE id = ANY(:document_ids)
                """
            ),
            {"document_ids": existing_ids},
        )
        # Same session as the soft delete, so the counters see the deleted rows.
        await session.execute(text(self._SYNC_DATASET_COUNTS_SQL), {"dataset_id": DatasetId})
    return RagDatasetBatchDeleteResultVO(
        result="success",
        requestedCount=len(normalized_ids),
        deletedCount=len(existing_ids),
        skippedCount=len(skipped),
        deletedIds=existing_ids,
        skipped=skipped,
    )

async def RetrieveDataset(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    Query: str,
    RetrievalModel: dict | None,
) -> RagDatasetRetrieveResponseVO:
    """Run a vector similarity search over one dataset (hit-testing).

    ``RetrievalModel`` may carry ``top_k`` (clamped to 1..20, default 5),
    ``score_threshold_enabled`` and ``score_threshold``.

    Raises:
        LeauditException: 404 if the dataset is not visible; 400 for an empty query;
            500 if the embedding service fails.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    query_text = (Query or "").strip()
    if not query_text:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "检索内容不能为空")

    retrieval_model = RetrievalModel or {}
    top_k = max(1, min(int(retrieval_model.get("top_k") or 5), 20))
    # None means "no threshold filtering".
    score_threshold = (
        float(retrieval_model.get("score_threshold") or 0)
        if retrieval_model.get("score_threshold_enabled")
        else None
    )

    query_embedding = await self._embed_texts([query_text], dataset.get("embedding_model") or "")
    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    raw = collection.query(
        query_embeddings=query_embedding,
        n_results=top_k,
        include=["documents", "metadatas", "distances"],
    )

    def _first_row(key: str) -> list:
        # query() returns one result list per query embedding; we send exactly one.
        rows = raw.get(key)
        return (rows[0] if rows else None) or []

    ids = _first_row("ids")
    documents = _first_row("documents")
    metadatas = _first_row("metadatas")
    distances = _first_row("distances")

    # NOTE(review): disabled documents (rag_document.enabled = false) are not
    # filtered out here.
    records: list[RagDatasetRetrieveRecordVO] = []
    for index, segment_id in enumerate(ids):
        content = documents[index] if index < len(documents) else ""
        metadata = metadatas[index] if index < len(metadatas) and isinstance(metadatas[index], dict) else {}
        distance = float(distances[index]) if index < len(distances) and distances[index] is not None else 1.0
        # Assumes a distance for which 1 - d is a similarity (e.g. cosine) - TODO confirm
        # the collection's metric; clamped so the score always lies in [0, 1].
        score = max(0.0, min(1.0, 1.0 - distance))
        if score_threshold is not None and score < score_threshold:
            continue
        document_name = metadata.get("document_name") or metadata.get("source") or ""
        document_id = str(metadata.get("document_id") or "")
        # chunk_index 0 is a valid value; only fall back to the hit's rank when missing.
        raw_chunk_index = metadata.get("chunk_index")
        chunk_index = int(raw_chunk_index) if raw_chunk_index not in (None, "") else index
        records.append(
            RagDatasetRetrieveRecordVO(
                score=round(score, 6),
                segment=RagDatasetRetrieveSegmentVO(
                    id=str(segment_id),
                    position=chunk_index + 1,
                    documentId=document_id,
                    content=content or "",
                    answer="",
                    wordCount=len(content or ""),
                    hitCount=0,
                    enabled=True,
                    status="completed",
                    createdAt=0,
                    document=RagDatasetRetrieveDocumentVO(
                        id=document_id,
                        dataSourceType="upload_file",
                        name=document_name,
                        docType=None,
                    ),
                ),
            )
        )
    return RagDatasetRetrieveResponseVO(
        query={"content": query_text},
        records=records,
    )

async def GetDatasetDocumentIndexingStatus(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
) -> dict:
    """Return a document's indexing progress in a ``{"data": [ {...} ]}`` payload.

    Per-stage timestamps are approximations: each stage that has been reached
    reports the document's ``updatedAt``.

    Raises:
        LeauditException: 404 when the document is not found.
    """
    document = await self.GetDatasetDocumentDetail(
        CurrentUserId=CurrentUserId,
        UserArea=UserArea,
        UserRole=UserRole,
        DatasetId=DatasetId,
        DocumentId=DocumentId,
    )
    if not document:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")
    status = document.indexingStatus
    stamp = document.updatedAt
    completed_segments = document.chunkCount if status == "completed" else 0
    total_segments = document.chunkCount if document.chunkCount > 0 else 0
    return {
        "data": [
            {
                "id": str(document.id),
                "indexing_status": status,
                "processing_started_at": stamp or document.createdAt or None,
                "parsing_completed_at": stamp if status in ("cleaning", "splitting", "indexing", "completed") else None,
                "cleaning_completed_at": stamp if status in ("splitting", "indexing", "completed") else None,
                "splitting_completed_at": stamp if status in ("indexing", "completed") else None,
                "completed_at": stamp if status == "completed" else None,
                "paused_at": None,
                "error": document.error or None,
                "stopped_at": None,
                "completed_segments": completed_segments,
                "total_segments": total_segments,
            }
        ]
    }

async def UpdateDatasetDocumentByFile(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentId: int,
    FileName: str,
    ContentType: str | None,
    Content: bytes,
    ProcessConfig: dict | None,
) -> RagDatasetUploadDocumentVO:
    """Replace a document's file, drop its old vectors and re-index it in the background.

    Raises:
        LeauditException: 404 for an unknown dataset/document; 403 for roles without
            edit rights; 400 for an unsupported file extension.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有重处理知识库文档权限")
    async with GetAsyncSession() as session:
        current = (
            await session.execute(
                text(
                    """
                    SELECT id, minio_path
                    FROM rag_document
                    WHERE id = :document_id
                      AND dataset_id = :dataset_id
                      AND deleted_at IS NULL
                    LIMIT 1
                    """
                ),
                {"document_id": DocumentId, "dataset_id": DatasetId},
            )
        ).mappings().first()
    if not current:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")
    # NOTE(review): a document that is still being indexed is not rejected here, so
    # two indexing tasks for the same document can overlap.
    suffix = os.path.splitext(FileName)[1].lower()
    if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON")

    content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
    # Overwrite the existing object when there is one; otherwise create a dated key.
    object_key = current.get("minio_path") or f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}"
    OssClient().EnsureBucket()
    stored_key = OssClient().UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type)

    collection = get_chroma().get_or_create_collection(dataset["collection_name"])
    stale_ids = collection.get(where={"document_id": DocumentId}, include=[]).get("ids") or []
    if stale_ids:
        collection.delete(ids=stale_ids)

    async with GetAsyncSession() as session:
        await session.execute(
            text(
                """
                UPDATE rag_document
                SET original_name = :original_name,
                    minio_path = :minio_path,
                    file_type = :file_type,
                    file_size = :file_size,
                    chunk_count = 0,
                    indexing_status = 'indexing',
                    indexing_error = NULL,
                    updated_at = NOW()
                WHERE id = :document_id
                """
            ),
            {
                "document_id": DocumentId,
                "original_name": FileName,
                "minio_path": stored_key,
                "file_type": suffix.lstrip("."),
                "file_size": len(Content),
            },
        )

    self._spawn_background_task(
        self._run_document_indexing_task(
            dataset=dataset,
            dataset_id=DatasetId,
            document_id=DocumentId,
            file_name=FileName,
            content=Content,
            process_config=ProcessConfig or {},
        )
    )
    return RagDatasetUploadDocumentVO(
        document={
            "id": str(DocumentId),
            "name": FileName,
            "indexing_status": "indexing",
            "word_count": 0,
            "hit_count": 0,
            "enabled": True,
        },
        batch=str(DocumentId),
    )

async def BatchUpdateDatasetDocumentStatus(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    DatasetId: int,
    DocumentIds: list[int],
    Enabled: bool,
) -> RagOperationResultVO:
    """Set the ``enabled`` flag on several live documents of one dataset.

    Raises:
        LeauditException: 404 for an unknown dataset; 403 for roles without edit
            rights; 400 when no ids are supplied.
    """
    dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
    if not dataset:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
    if UserRole not in ("provincial_admin", "admin", "super_admin"):
        raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有修改知识库文档状态权限")
    ids = [int(doc_id) for doc_id in DocumentIds]
    if not ids:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "未传入文档ID")
    async with GetAsyncSession() as session:
        await session.execute(
            text(
                """
                UPDATE rag_document
                SET enabled = :enabled,
                    updated_at = NOW()
                WHERE dataset_id = :dataset_id
                  AND id = ANY(:document_ids)
                  AND deleted_at IS NULL
                """
            ),
            {"dataset_id": DatasetId, "document_ids": ids, "enabled": Enabled},
        )
    return RagOperationResultVO(result="success")

def _spawn_background_task(self, coro) -> asyncio.Task:
    """Schedule ``coro`` on the running loop and keep it referenced until it finishes."""
    task = asyncio.create_task(coro)
    self._BACKGROUND_TASKS.add(task)
    task.add_done_callback(self._BACKGROUND_TASKS.discard)
    return task

def _extract_page_texts(self, *, FileName: str, Content: bytes) -> list[tuple[int, str]]:
    """Extract ``(page_no, text)`` pairs from an uploaded file, dispatching on its extension.

    JSON is flattened to ``path: value`` lines; any other non-PDF/DOCX file is decoded
    as UTF-8 (undecodable bytes dropped) and returned as a single page 1.
    """
    suffix = os.path.splitext(FileName)[1].lower()
    if suffix == ".json":
        return self._extract_page_texts_from_json(Content)
    if suffix not in (".pdf", ".docx"):
        text_value = Content.decode("utf-8", errors="ignore").strip()
        return [(1, text_value)] if text_value else []

    from fastapi_modules.fastapi_leaudit.services.impl.documentServiceImpl import (
        _extract_page_texts_from_docx,
        _extract_page_texts_from_pdf,
    )

    # The PDF/DOCX extractors take a filesystem path, so spill the bytes to a temp file.
    with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file:
        temp_file.write(Content)
        temp_path = Path(temp_file.name)
    try:
        if suffix == ".pdf":
            return _extract_page_texts_from_pdf(temp_path)
        return _extract_page_texts_from_docx(temp_path)
    finally:
        try:
            os.unlink(temp_path)
        except OSError:
            pass

def _extract_page_texts_from_json(self, content: bytes) -> list[tuple[int, str]]:
    """Flatten a JSON document into one text page; invalid JSON is kept verbatim."""
    text_value = content.decode("utf-8", errors="ignore").strip()
    if not text_value:
        return []
    try:
        payload = json.loads(text_value)
    except json.JSONDecodeError:
        return [(1, text_value)]
    flattened = self._json_to_text(payload).strip()
    return [(1, flattened)] if flattened else []

def _json_to_text(self, value: object, prefix: str = "") -> str:
    """Render a JSON value as newline-separated ``path: value`` lines.

    Dict keys are joined with ``.``; list items use 1-based ``[i]`` suffixes (or
    ``item_i`` at the top level). ``None`` values are dropped.
    """
    if value is None:
        return ""
    if isinstance(value, str):
        return f"{prefix}: {value}" if prefix else value
    if isinstance(value, (int, float, bool)):
        return f"{prefix}: {value}" if prefix else str(value)
    if isinstance(value, list):
        parts: list[str] = []
        for index, item in enumerate(value, start=1):
            item_prefix = f"{prefix}[{index}]" if prefix else f"item_{index}"
            chunk = self._json_to_text(item, item_prefix)
            if chunk:
                parts.append(chunk)
        return "\n".join(parts)
    if isinstance(value, dict):
        parts = []
        for key, item in value.items():
            item_prefix = f"{prefix}.{key}" if prefix else str(key)
            chunk = self._json_to_text(item, item_prefix)
            if chunk:
                parts.append(chunk)
        return "\n".join(parts)
    rendered = str(value).strip()
    if not rendered:
        return ""
    return f"{prefix}: {rendered}" if prefix else rendered

def _build_chunks(self, *, file_name: str, page_texts: list[tuple[int, str]], dataset: dict, process_config: dict, document_id: int) -> list[dict]:
    """Turn extracted pages into chunk dicts (``id``/``text``/``metadata``) ready for Chroma.

    Reads ``process_config["process_rule"]["rules"]``: ``segmentation`` (separator,
    max_tokens, chunk_overlap) and ``pre_processing_rules``. Sizes are counted in
    characters. Chunk ids are ``"{document_id}:{page}:{index}"``; ``index`` restarts
    on every page.
    """
    rules = ((process_config or {}).get("process_rule") or {}).get("rules") or {}
    segmentation = rules.get("segmentation") or {}
    pre_rules = [rule for rule in (rules.get("pre_processing_rules") or []) if isinstance(rule, dict)]

    def _rule_enabled(rule_id: str) -> bool:
        return any(rule.get("id") == rule_id and rule.get("enabled") for rule in pre_rules)

    remove_spaces = _rule_enabled("remove_extra_spaces")
    remove_urls = _rule_enabled("remove_urls_emails")
    separator = segmentation.get("separator") or "\n\n"
    max_tokens = max(1, int(segmentation.get("max_tokens") or dataset.get("chunk_max_size") or 800))
    # An explicit 0 means "no overlap"; only fall back to the default when unset.
    raw_overlap = segmentation.get("chunk_overlap")
    chunk_overlap = int(raw_overlap) if raw_overlap not in (None, "") else 50
    chunk_overlap = max(0, min(chunk_overlap, max_tokens // 2 if max_tokens > 1 else 0))

    chunks: list[dict] = []
    for page_no, raw_text in page_texts:
        text_value = self._preprocess_text(raw_text, remove_spaces=remove_spaces, remove_urls=remove_urls)
        if not text_value:
            continue
        pieces = self._split_text(text_value, separator=separator, max_chars=max_tokens, overlap=chunk_overlap)
        for index, chunk_text in enumerate(pieces):
            if not chunk_text.strip():
                continue
            chunk_id = f"{document_id}:{page_no}:{index}"
            chunks.append(
                {
                    "id": chunk_id,
                    "text": chunk_text,
                    "metadata": {
                        "id": chunk_id,
                        "source": file_name,
                        "document_name": file_name,
                        "document_id": document_id,
                        "page": page_no,
                        "chunk_index": index,
                    },
                }
            )
    return chunks

async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
    """Embed ``texts`` through an OpenAI-compatible ``/embeddings`` endpoint, in batches.

    Endpoint, key, model and batch size come from ``RAG_CONFIG``; ``model_name``
    overrides the configured model when non-empty.

    Raises:
        LeauditException: 500 when no service is configured, the call fails
            (HTTP error or transport error), or the number of vectors is wrong.
    """
    embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/embeddings"
    embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
    embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
    batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
    if not embed_url or not embed_key:
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")

    embeddings: list[list[float]] = []
    async with httpx.AsyncClient(timeout=120.0) as client:
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            try:
                response = await client.post(
                    embed_url,
                    headers={
                        "Content-Type": "application/json",
                        "Authorization": f"Bearer {embed_key}",
                    },
                    json={"model": embed_model, "input": batch_texts},
                )
                response.raise_for_status()
            except httpx.HTTPStatusError as exc:
                error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc
            except httpx.RequestError as exc:
                # Timeouts / connection failures: report them like HTTP errors.
                error_message = str(exc).strip() or type(exc).__name__
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc
            payload = response.json()
            rows = payload.get("data") or []
            batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
            if len(batch_embeddings) != len(batch_texts):
                raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
            embeddings.extend(batch_embeddings)
    if len(embeddings) != len(texts):
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
    return embeddings

async def _run_document_indexing_task(
    self,
    *,
    dataset: dict,
    dataset_id: int,
    document_id: int,
    file_name: str,
    content: bytes,
    process_config: dict,
) -> None:
    """Background pipeline: parse -> clean -> split -> embed -> store vectors.

    Progress is recorded in ``rag_document.indexing_status``. Any failure is stored
    as status ``'error'`` with the message; if no vectors were written, the
    document's ``chunk_count`` is reset to 0 and dataset counters are re-synced.
    """
    vectors_added = False
    try:
        await self._update_document_processing_state(document_id=document_id, status="parsing")
        page_texts = self._extract_page_texts(FileName=file_name, Content=content)
        await self._update_document_processing_state(document_id=document_id, status="cleaning")
        processed = self._build_chunks(
            file_name=file_name,
            page_texts=page_texts,
            dataset=dataset,
            process_config=process_config,
            document_id=document_id,
        )
        if not processed:
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本")
        await self._update_document_processing_state(
            document_id=document_id,
            status="splitting",
            chunk_count=len(processed),
        )
        embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "")
        await self._update_document_processing_state(
            document_id=document_id,
            status="indexing",
            chunk_count=len(processed),
        )
        collection = get_chroma().get_or_create_collection(dataset["collection_name"])
        collection.add(
            ids=[item["id"] for item in processed],
            documents=[item["text"] for item in processed],
            embeddings=embeddings,
            metadatas=[item["metadata"] for item in processed],
        )
        vectors_added = True
        async with GetAsyncSession() as session:
            await session.execute(
                text(
                    """
                    UPDATE rag_document
                    SET chunk_count = :chunk_count,
                        indexing_status = 'completed',
                        indexing_error = NULL,
                        indexing_started_at = COALESCE(indexing_started_at, NOW()),
                        indexing_completed_at = NOW(),
                        updated_at = NOW()
                    WHERE id = :document_id
                    """
                ),
                {"chunk_count": len(processed), "document_id": document_id},
            )
        await self._sync_dataset_counts(dataset_id)
    except Exception as exc:
        # This coroutine runs detached (nobody awaits it), so the row is the only
        # place the failure becomes visible. chunk_count was set before the vectors
        # existed; if they were never written it must not count towards totals.
        reset_chunks = "" if vectors_added else "chunk_count = 0,"
        async with GetAsyncSession() as session:
            await session.execute(
                text(
                    f"""
                    UPDATE rag_document
                    SET indexing_status = 'error',
                        indexing_error = :error,
                        {reset_chunks}
                        updated_at = NOW()
                    WHERE id = :document_id
                    """
                ),
                {"document_id": document_id, "error": str(exc)[:2000]},
            )
        await self._sync_dataset_counts(dataset_id)

async def _update_document_processing_state(self, *, document_id: int, status: str, chunk_count: int | None = None) -> None:
    """Record an intermediate indexing status (and optionally chunk_count); clears the error."""
    fields = [
        "indexing_status = :status",
        "indexing_error = NULL",
        "indexing_started_at = COALESCE(indexing_started_at, NOW())",
        "updated_at = NOW()",
    ]
    params: dict[str, object] = {
        "document_id": document_id,
        "status": status,
    }
    if chunk_count is not None:
        fields.append("chunk_count = :chunk_count")
        params["chunk_count"] = chunk_count
    async with GetAsyncSession() as session:
        # Only fixed column fragments are interpolated; values stay bound parameters.
        await session.execute(
            text(
                f"""
                UPDATE rag_document
                SET {", ".join(fields)}
                WHERE id = :document_id
                """
            ),
            params,
        )

async def _sync_dataset_counts(self, dataset_id: int) -> None:
    """Recompute ``document_count`` / ``total_chunks`` of a dataset from its live documents."""
    async with GetAsyncSession() as session:
        await session.execute(text(self._SYNC_DATASET_COUNTS_SQL), {"dataset_id": dataset_id})

def _preprocess_text(self, text_value: str, *, remove_spaces: bool, remove_urls: bool) -> str:
    """Optionally strip URLs/e-mails and collapse runs of spaces/tabs; always squeeze blank lines.

    Runs of three or more newlines become two, and the result is stripped.
    """
    result = text_value or ""
    if remove_urls:
        result = re.sub(r"https?://\S+|www\.\S+|[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}", " ", result)
    if remove_spaces:
        result = re.sub(r"[ \t]+", " ", result)
    result = re.sub(r"\n{3,}", "\n\n", result)
    return result.strip()

def _split_text(self, text_value: str, *, separator: str, max_chars: int, overlap: int) -> list[str]:
    """Greedily pack separator-delimited parts into chunks of at most ``max_chars``.

    Parts longer than ``max_chars`` are cut into sliding windows that advance by
    ``max_chars - overlap`` characters (at least 1); the windowing stops once a window
    reaches the end of the part, so no window is fully contained in the previous one.
    """
    parts = [part.strip() for part in text_value.split(separator) if part.strip()] if separator else [text_value]
    chunks: list[str] = []
    current = ""
    for part in parts:
        candidate = f"{current}{separator if current else ''}{part}" if separator else f"{current}{part}"
        if len(candidate) <= max_chars:
            current = candidate
            continue
        if current:
            chunks.append(current)
        if len(part) <= max_chars:
            current = part
            continue
        step = max(max_chars - overlap, 1)
        start = 0
        while True:
            chunks.append(part[start:start + max_chars])
            if start + max_chars >= len(part):
                break
            start += step
        current = ""
    if current:
        chunks.append(current)
    return chunks

def _delete_oss_object(self, source: str | None) -> None:
    """Best-effort removal of a stored object; direct URLs and empty refs are ignored."""
    if not source:
        return
    try:
        oss = OssClient()
        ref = oss.ResolveObjectRef(Source=source)
        if ref.isDirectUrl or not ref.objectKey:
            return
        oss._GetMinioClient().remove_object(ref.bucket, ref.objectKey)
    except Exception:
        # Object-storage deletion failures must not block the main business flow;
        # otherwise stale legacy data could make a document impossible to delete.
        return