From 8206ed7d43e63501d60d2284c6f47dac2517afe6 Mon Sep 17 00:00:00 2001 From: wren <“porlong@qq.com”> Date: Mon, 11 May 2026 19:25:50 +0800 Subject: [PATCH] feat: improve rag dataset document management --- .../controllers/ragChatController.py | 23 +- .../domian/Dto/ragDatasetDto.py | 4 + .../fastapi_leaudit/domian/vo/ragDatasetVo.py | 15 + .../services/impl/ragDatasetServiceImpl.py | 463 +++++++++++------- .../services/ragDatasetService.py | 11 + 5 files changed, 351 insertions(+), 165 deletions(-) diff --git a/fastapi_modules/fastapi_leaudit/controllers/ragChatController.py b/fastapi_modules/fastapi_leaudit/controllers/ragChatController.py index ca2c697..84c777a 100644 --- a/fastapi_modules/fastapi_leaudit/controllers/ragChatController.py +++ b/fastapi_modules/fastapi_leaudit/controllers/ragChatController.py @@ -16,7 +16,10 @@ from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import ( RagChatSendMessageDTO, RagMessageFeedbackDTO, ) -from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO +from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import ( + RagDatasetBatchDocumentDeleteDTO, + RagDatasetUpdateDTO, +) from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import ( RagAppParametersVO, RagChatAppListVO, @@ -27,6 +30,7 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import ( RagOperationResultVO, ) from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import ( + RagDatasetBatchDeleteResultVO, RagDatasetDetailVO, RagDatasetDocumentItemVO, RagDatasetDocumentPageVO, @@ -347,6 +351,23 @@ class RagChatController(BaseController): ) return Result.success(data=result) + @self.router.post("/datasets/{DatasetId}/documents/batch-delete", response_model=Result[RagDatasetBatchDeleteResultVO]) + async def BatchDeleteDatasetDocuments( + DatasetId: int, + Body: RagDatasetBatchDocumentDeleteDTO, + payload: dict[str, Any] = Depends(verify_access_token), + ): + if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_delete"]]): + return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有删除知识库文档权限", "data": None}) + result = await self.RagDatasetService.BatchDeleteDatasetDocuments( + CurrentUserId=int(payload["user_id"]), + UserArea=payload.get("area"), + UserRole=payload.get("user_role"), + DatasetId=DatasetId, + DocumentIds=Body.document_ids, + ) + return Result.success(data=result) + @self.router.post("/datasets/{DatasetId}/retrieve", response_model=Result[RagDatasetRetrieveResponseVO]) async def RetrieveDataset( DatasetId: int, diff --git a/fastapi_modules/fastapi_leaudit/domian/Dto/ragDatasetDto.py b/fastapi_modules/fastapi_leaudit/domian/Dto/ragDatasetDto.py index 12301ee..d61413f 100644 --- a/fastapi_modules/fastapi_leaudit/domian/Dto/ragDatasetDto.py +++ b/fastapi_modules/fastapi_leaudit/domian/Dto/ragDatasetDto.py @@ -4,3 +4,7 @@ from pydantic import BaseModel, Field class RagDatasetUpdateDTO(BaseModel): name: str | None = Field(None, min_length=1, max_length=255) retrieval_model: dict | None = Field(None) + + +class RagDatasetBatchDocumentDeleteDTO(BaseModel): + document_ids: list[int] = Field(..., min_length=1) diff --git a/fastapi_modules/fastapi_leaudit/domian/vo/ragDatasetVo.py b/fastapi_modules/fastapi_leaudit/domian/vo/ragDatasetVo.py index 92dc2a1..02bbb97 100644 --- a/fastapi_modules/fastapi_leaudit/domian/vo/ragDatasetVo.py +++ b/fastapi_modules/fastapi_leaudit/domian/vo/ragDatasetVo.py @@ -74,6 +74,21 @@ class RagDatasetUploadDocumentVO(BaseModel): batch: str = Field("") +class RagDatasetBatchDeleteFailedItemVO(BaseModel): + id: int = Field(...) + name: str = Field("") + reason: str = Field("") + + +class RagDatasetBatchDeleteResultVO(BaseModel): + result: str = Field("success") + requestedCount: int = Field(0) + deletedCount: int = Field(0) + skippedCount: int = Field(0) + deletedIds: list[int] = Field(default_factory=list) + skipped: list[RagDatasetBatchDeleteFailedItemVO] = Field(default_factory=list) + + class RagDatasetSegmentItemVO(BaseModel): id: str = Field(...) position: int = Field(0) diff --git a/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py b/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py index a4d14f4..893290d 100644 --- a/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py +++ b/fastapi_modules/fastapi_leaudit/services/impl/ragDatasetServiceImpl.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import json import mimetypes import os @@ -19,6 +20,7 @@ from fastapi_common.fastapi_common_web.exception.LeauditException import Leaudit from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import ( + RagDatasetBatchDeleteResultVO, RagDatasetDetailVO, RagDatasetDocumentItemVO, RagDatasetDocumentPageVO, @@ -39,6 +41,8 @@ from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatas class RagDatasetServiceImpl(IRagDatasetService): + _ACTIVE_INDEXING_STATUSES = {"waiting", "parsing", "cleaning", "splitting", "indexing"} + _DELETABLE_DOCUMENT_STATUSES = {"completed", "error", "paused"} _APP_LINK_SQL = """ LEFT JOIN ( SELECT DISTINCT ON (dataset_id) @@ -735,8 +739,8 @@ class RagDatasetServiceImpl(IRagDatasetService): raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") suffix = os.path.splitext(FileName)[1].lower() - if suffix not in {".pdf", ".docx", ".txt", ".md"}: - raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD") + if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}: + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON") object_key = f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}" content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream" @@ -771,78 +775,23 @@ class RagDatasetServiceImpl(IRagDatasetService): ).mappings().first() document_id = int(inserted["id"]) - try: - page_texts = self._extract_page_texts(FileName=FileName, Content=Content) - processed = self._build_chunks( - file_name=FileName, - page_texts=page_texts, + asyncio.create_task( + self._run_document_indexing_task( dataset=dataset, - process_config=ProcessConfig or {}, + dataset_id=DatasetId, document_id=document_id, + file_name=FileName, + content=Content, + process_config=ProcessConfig or {}, ) - if not processed: - raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本") - - embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "") - collection = get_chroma().get_or_create_collection(dataset["collection_name"]) - collection.add( - ids=[item["id"] for item in processed], - documents=[item["text"] for item in processed], - embeddings=embeddings, - metadatas=[item["metadata"] for item in processed], - ) - - async with GetAsyncSession() as session: - await session.execute( - text( - """ - UPDATE rag_document - SET chunk_count = :chunk_count, - indexing_status = 'completed', - indexing_error = NULL, - indexing_started_at = COALESCE(indexing_started_at, NOW()), - indexing_completed_at = NOW() - WHERE id = :document_id - """ - ), - {"chunk_count": len(processed), "document_id": document_id}, - ) - await session.execute( - text( - """ - UPDATE rag_dataset - SET document_count = ( - SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL - ), - total_chunks = COALESCE(( - SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL - ), 0) - WHERE id = :dataset_id - """ - ), - {"dataset_id": DatasetId}, - ) - except Exception as exc: - async with GetAsyncSession() as session: - await session.execute( - text( - """ - UPDATE rag_document - SET indexing_status = 'error', - indexing_error = :error - WHERE id = :document_id - """ - ), - {"document_id": document_id, "error": str(exc)[:2000]}, - ) - raise + ) return RagDatasetUploadDocumentVO( document={ "id": str(document_id), "name": FileName, - "indexing_status": "completed", - "word_count": len(processed), + "indexing_status": "indexing", + "word_count": 0, "hit_count": 0, "enabled": True, }, @@ -1013,7 +962,7 @@ class RagDatasetServiceImpl(IRagDatasetService): DatasetId: int, DocumentId: int, SegmentId: str, - ) -> RagOperationResultVO: + ) -> RagDatasetBatchDeleteResultVO: dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) if not dataset: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") @@ -1068,6 +1017,26 @@ class RagDatasetServiceImpl(IRagDatasetService): UserRole: str | None, DatasetId: int, DocumentId: int, + ) -> RagOperationResultVO: + result = await self.BatchDeleteDatasetDocuments( + CurrentUserId=CurrentUserId, + UserArea=UserArea, + UserRole=UserRole, + DatasetId=DatasetId, + DocumentIds=[DocumentId], + ) + if result.deletedCount <= 0: + reason = result.skipped[0].reason if result.skipped else "文档删除失败" + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, reason) + return RagOperationResultVO(result="success") + + async def BatchDeleteDatasetDocuments( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + DocumentIds: list[int], ) -> RagOperationResultVO: dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId) if not dataset: @@ -1076,33 +1045,67 @@ class RagDatasetServiceImpl(IRagDatasetService): if UserRole not in ("provincial_admin", "admin", "super_admin"): raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库文档权限") + normalized_ids = [int(item) for item in DocumentIds if item] + if not normalized_ids: + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "缺少待删除文档") + async with GetAsyncSession() as session: - document_row = ( + document_rows = ( await session.execute( text( """ - SELECT id, dataset_id, minio_path + SELECT id, dataset_id, minio_path, original_name, indexing_status FROM rag_document - WHERE id = :document_id + WHERE id = ANY(:document_ids) AND dataset_id = :dataset_id AND deleted_at IS NULL - LIMIT 1 """ ), - {"document_id": DocumentId, "dataset_id": DatasetId}, + {"document_ids": normalized_ids, "dataset_id": DatasetId}, ) - ).mappings().first() + ).mappings().all() - if not document_row: + if not document_rows: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在") - collection = get_chroma().get_or_create_collection(dataset["collection_name"]) - raw = collection.get(where={"document_id": DocumentId}, include=[]) - ids = raw.get("ids") or [] - if ids: - collection.delete(ids=ids) + deletable_rows = [] + skipped = [] + for row in document_rows: + if row.get("indexing_status") not in self._DELETABLE_DOCUMENT_STATUSES: + skipped.append( + { + "id": int(row["id"]), + "name": str(row.get("original_name") or ""), + "reason": "文档仍在处理中,暂不允许删除", + } + ) + continue + deletable_rows.append(row) - self._delete_oss_object(document_row.get("minio_path")) + if not deletable_rows: + return RagDatasetBatchDeleteResultVO( + result="success", + requestedCount=len(normalized_ids), + deletedCount=0, + skippedCount=len(skipped), + deletedIds=[], + skipped=skipped, + ) + + collection = get_chroma().get_or_create_collection(dataset["collection_name"]) + chunk_ids: list[str] = [] + for document_row in deletable_rows: + raw = collection.get(where={"document_id": int(document_row["id"])}, include=[]) + ids = raw.get("ids") or [] + if ids: + chunk_ids.extend(ids) + if chunk_ids: + collection.delete(ids=chunk_ids) + + for document_row in deletable_rows: + self._delete_oss_object(document_row.get("minio_path")) + + existing_ids = [int(row["id"]) for row in deletable_rows] async with GetAsyncSession() as session: await session.execute( @@ -1111,10 +1114,10 @@ class RagDatasetServiceImpl(IRagDatasetService): UPDATE rag_document SET deleted_at = NOW(), updated_at = NOW() - WHERE id = :document_id + WHERE id = ANY(:document_ids) """ ), - {"document_id": DocumentId}, + {"document_ids": existing_ids}, ) await session.execute( text( @@ -1133,7 +1136,14 @@ class RagDatasetServiceImpl(IRagDatasetService): {"dataset_id": DatasetId}, ) - return RagOperationResultVO(result="success") + return RagDatasetBatchDeleteResultVO( + result="success", + requestedCount=len(normalized_ids), + deletedCount=len(existing_ids), + skippedCount=len(skipped), + deletedIds=existing_ids, + skipped=skipped, + ) async def RetrieveDataset( self, @@ -1290,8 +1300,8 @@ class RagDatasetServiceImpl(IRagDatasetService): raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在") suffix = os.path.splitext(FileName)[1].lower() - if suffix not in {".pdf", ".docx", ".txt", ".md"}: - raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD") + if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}: + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON") content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream" object_key = current.get("minio_path") or f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}" @@ -1328,81 +1338,23 @@ class RagDatasetServiceImpl(IRagDatasetService): "file_size": len(Content), }, ) - - try: - page_texts = self._extract_page_texts(FileName=FileName, Content=Content) - processed = self._build_chunks( - file_name=FileName, - page_texts=page_texts, + asyncio.create_task( + self._run_document_indexing_task( dataset=dataset, - process_config=ProcessConfig or {}, + dataset_id=DatasetId, document_id=DocumentId, + file_name=FileName, + content=Content, + process_config=ProcessConfig or {}, ) - if not processed: - raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本") - - embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "") - collection.add( - ids=[item["id"] for item in processed], - documents=[item["text"] for item in processed], - embeddings=embeddings, - metadatas=[item["metadata"] for item in processed], - ) - - async with GetAsyncSession() as session: - await session.execute( - text( - """ - UPDATE rag_document - SET chunk_count = :chunk_count, - indexing_status = 'completed', - indexing_error = NULL, - indexing_started_at = COALESCE(indexing_started_at, NOW()), - indexing_completed_at = NOW(), - updated_at = NOW() - WHERE id = :document_id - """ - ), - {"chunk_count": len(processed), "document_id": DocumentId}, - ) - await session.execute( - text( - """ - UPDATE rag_dataset - SET document_count = ( - SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL - ), - total_chunks = COALESCE(( - SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL - ), 0), - updated_at = NOW() - WHERE id = :dataset_id - """ - ), - {"dataset_id": DatasetId}, - ) - except Exception as exc: - async with GetAsyncSession() as session: - await session.execute( - text( - """ - UPDATE rag_document - SET indexing_status = 'error', - indexing_error = :error, - updated_at = NOW() - WHERE id = :document_id - """ - ), - {"document_id": DocumentId, "error": str(exc)[:2000]}, - ) - raise + ) return RagDatasetUploadDocumentVO( document={ "id": str(DocumentId), "name": FileName, - "indexing_status": "completed", - "word_count": len(processed), + "indexing_status": "indexing", + "word_count": 0, "hit_count": 0, "enabled": True, }, @@ -1460,6 +1412,8 @@ class RagDatasetServiceImpl(IRagDatasetService): return _extract_page_texts_from_pdf(Path(temp_path)) if suffix == ".docx": return _extract_page_texts_from_docx(Path(temp_path)) + if suffix == ".json": + return self._extract_page_texts_from_json(Content) text_value = Content.decode("utf-8", errors="ignore").strip() return [(1, text_value)] if text_value else [] @@ -1469,6 +1423,48 @@ class RagDatasetServiceImpl(IRagDatasetService): except OSError: pass + def _extract_page_texts_from_json(self, content: bytes) -> list[tuple[int, str]]: + text_value = content.decode("utf-8", errors="ignore").strip() + if not text_value: + return [] + + try: + payload = json.loads(text_value) + except json.JSONDecodeError: + return [(1, text_value)] + + flattened = self._json_to_text(payload) + flattened = flattened.strip() + return [(1, flattened)] if flattened else [] + + def _json_to_text(self, value: object, prefix: str = "") -> str: + if value is None: + return "" + if isinstance(value, str): + return f"{prefix}: {value}" if prefix else value + if isinstance(value, (int, float, bool)): + return f"{prefix}: {value}" if prefix else str(value) + if isinstance(value, list): + parts: list[str] = [] + for index, item in enumerate(value, start=1): + item_prefix = f"{prefix}[{index}]" if prefix else f"item_{index}" + chunk = self._json_to_text(item, item_prefix) + if chunk: + parts.append(chunk) + return "\n".join(parts) + if isinstance(value, dict): + parts = [] + for key, item in value.items(): + item_prefix = f"{prefix}.{key}" if prefix else str(key) + chunk = self._json_to_text(item, item_prefix) + if chunk: + parts.append(chunk) + return "\n".join(parts) + rendered = str(value).strip() + if not rendered: + return "" + return f"{prefix}: {rendered}" if prefix else rendered + def _build_chunks(self, *, file_name: str, page_texts: list[tuple[int, str]], dataset: dict, process_config: dict, document_id: int) -> list[dict]: rules = ((process_config or {}).get("process_rule") or {}).get("rules") or {} segmentation = rules.get("segmentation") or {} @@ -1510,26 +1506,165 @@ class RagDatasetServiceImpl(IRagDatasetService): embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/embeddings" embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"] embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4" + batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10)) if not embed_url or not embed_key: raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务") + embeddings: list[list[float]] = [] async with httpx.AsyncClient(timeout=120.0) as client: - response = await client.post( - embed_url, - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {embed_key}", - }, - json={"model": embed_model, "input": texts}, - ) - response.raise_for_status() - payload = response.json() - rows = payload.get("data") or [] - embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")] + for start in range(0, len(texts), batch_size): + batch_texts = texts[start:start + batch_size] + try: + response = await client.post( + embed_url, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {embed_key}", + }, + json={"model": embed_model, "input": batch_texts}, + ) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}" + raise LeauditException( + StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, + f"向量化服务调用失败: {error_message[:300]}", + ) from exc + + payload = response.json() + rows = payload.get("data") or [] + batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")] + if len(batch_embeddings) != len(batch_texts): + raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") + embeddings.extend(batch_embeddings) + if len(embeddings) != len(texts): raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") return embeddings + async def _run_document_indexing_task( + self, + *, + dataset: dict, + dataset_id: int, + document_id: int, + file_name: str, + content: bytes, + process_config: dict, + ) -> None: + try: + await self._update_document_processing_state(document_id=document_id, status="parsing") + page_texts = self._extract_page_texts(FileName=file_name, Content=content) + + await self._update_document_processing_state(document_id=document_id, status="cleaning") + processed = self._build_chunks( + file_name=file_name, + page_texts=page_texts, + dataset=dataset, + process_config=process_config, + document_id=document_id, + ) + if not processed: + raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本") + + await self._update_document_processing_state( + document_id=document_id, + status="splitting", + chunk_count=len(processed), + ) + embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "") + + await self._update_document_processing_state( + document_id=document_id, + status="indexing", + chunk_count=len(processed), + ) + collection = get_chroma().get_or_create_collection(dataset["collection_name"]) + collection.add( + ids=[item["id"] for item in processed], + documents=[item["text"] for item in processed], + embeddings=embeddings, + metadatas=[item["metadata"] for item in processed], + ) + + async with GetAsyncSession() as session: + await session.execute( + text( + """ + UPDATE rag_document + SET chunk_count = :chunk_count, + indexing_status = 'completed', + indexing_error = NULL, + indexing_started_at = COALESCE(indexing_started_at, NOW()), + indexing_completed_at = NOW(), + updated_at = NOW() + WHERE id = :document_id + """ + ), + {"chunk_count": len(processed), "document_id": document_id}, + ) + await self._sync_dataset_counts(dataset_id) + except Exception as exc: + async with GetAsyncSession() as session: + await session.execute( + text( + """ + UPDATE rag_document + SET indexing_status = 'error', + indexing_error = :error, + updated_at = NOW() + WHERE id = :document_id + """ + ), + {"document_id": document_id, "error": str(exc)[:2000]}, + ) + + async def _update_document_processing_state(self, *, document_id: int, status: str, chunk_count: int | None = None) -> None: + fields = [ + "indexing_status = :status", + "indexing_error = NULL", + "indexing_started_at = COALESCE(indexing_started_at, NOW())", + "updated_at = NOW()", + ] + params: dict[str, object] = { + "document_id": document_id, + "status": status, + } + if chunk_count is not None: + fields.append("chunk_count = :chunk_count") + params["chunk_count"] = chunk_count + + async with GetAsyncSession() as session: + await session.execute( + text( + f""" + UPDATE rag_document + SET {", ".join(fields)} + WHERE id = :document_id + """ + ), + params, + ) + + async def _sync_dataset_counts(self, dataset_id: int) -> None: + async with GetAsyncSession() as session: + await session.execute( + text( + """ + UPDATE rag_dataset + SET document_count = ( + SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL + ), + total_chunks = COALESCE(( + SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL + ), 0), + updated_at = NOW() + WHERE id = :dataset_id + """ + ), + {"dataset_id": dataset_id}, + ) + def _preprocess_text(self, text_value: str, *, remove_spaces: bool, remove_urls: bool) -> str: result = text_value or "" if remove_urls: diff --git a/fastapi_modules/fastapi_leaudit/services/ragDatasetService.py b/fastapi_modules/fastapi_leaudit/services/ragDatasetService.py index 09053e5..451b580 100644 --- a/fastapi_modules/fastapi_leaudit/services/ragDatasetService.py +++ b/fastapi_modules/fastapi_leaudit/services/ragDatasetService.py @@ -4,6 +4,7 @@ from abc import ABC, abstractmethod from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import ( + RagDatasetBatchDeleteResultVO, RagDatasetDetailVO, RagDatasetDocumentItemVO, RagDatasetDocumentPageVO, @@ -123,6 +124,16 @@ class IRagDatasetService(ABC): DocumentId: int, ) -> RagOperationResultVO: ... + @abstractmethod + async def BatchDeleteDatasetDocuments( + self, + CurrentUserId: int, + UserArea: str | None, + UserRole: str | None, + DatasetId: int, + DocumentIds: list[int], + ) -> RagDatasetBatchDeleteResultVO: ... + @abstractmethod async def RetrieveDataset( self,