feat: improve rag dataset document management
This commit is contained in:
@@ -16,7 +16,10 @@ from fastapi_modules.fastapi_leaudit.domian.Dto.ragChatDto import (
|
|||||||
RagChatSendMessageDTO,
|
RagChatSendMessageDTO,
|
||||||
RagMessageFeedbackDTO,
|
RagMessageFeedbackDTO,
|
||||||
)
|
)
|
||||||
from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO
|
from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import (
|
||||||
|
RagDatasetBatchDocumentDeleteDTO,
|
||||||
|
RagDatasetUpdateDTO,
|
||||||
|
)
|
||||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
|
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
|
||||||
RagAppParametersVO,
|
RagAppParametersVO,
|
||||||
RagChatAppListVO,
|
RagChatAppListVO,
|
||||||
@@ -27,6 +30,7 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
|
|||||||
RagOperationResultVO,
|
RagOperationResultVO,
|
||||||
)
|
)
|
||||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
|
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
|
||||||
|
RagDatasetBatchDeleteResultVO,
|
||||||
RagDatasetDetailVO,
|
RagDatasetDetailVO,
|
||||||
RagDatasetDocumentItemVO,
|
RagDatasetDocumentItemVO,
|
||||||
RagDatasetDocumentPageVO,
|
RagDatasetDocumentPageVO,
|
||||||
@@ -347,6 +351,23 @@ class RagChatController(BaseController):
|
|||||||
)
|
)
|
||||||
return Result.success(data=result)
|
return Result.success(data=result)
|
||||||
|
|
||||||
|
@self.router.post("/datasets/{DatasetId}/documents/batch-delete", response_model=Result[RagDatasetBatchDeleteResultVO])
|
||||||
|
async def BatchDeleteDatasetDocuments(
|
||||||
|
DatasetId: int,
|
||||||
|
Body: RagDatasetBatchDocumentDeleteDTO,
|
||||||
|
payload: dict[str, Any] = Depends(verify_access_token),
|
||||||
|
):
|
||||||
|
if not await self._check_permission(int(payload["user_id"]), [self._PERMISSIONS["dataset_delete"]]):
|
||||||
|
return JSONResponse(status_code=403, content={"code": 403, "msg": "当前用户没有删除知识库文档权限", "data": None})
|
||||||
|
result = await self.RagDatasetService.BatchDeleteDatasetDocuments(
|
||||||
|
CurrentUserId=int(payload["user_id"]),
|
||||||
|
UserArea=payload.get("area"),
|
||||||
|
UserRole=payload.get("user_role"),
|
||||||
|
DatasetId=DatasetId,
|
||||||
|
DocumentIds=Body.document_ids,
|
||||||
|
)
|
||||||
|
return Result.success(data=result)
|
||||||
|
|
||||||
@self.router.post("/datasets/{DatasetId}/retrieve", response_model=Result[RagDatasetRetrieveResponseVO])
|
@self.router.post("/datasets/{DatasetId}/retrieve", response_model=Result[RagDatasetRetrieveResponseVO])
|
||||||
async def RetrieveDataset(
|
async def RetrieveDataset(
|
||||||
DatasetId: int,
|
DatasetId: int,
|
||||||
|
|||||||
@@ -4,3 +4,7 @@ from pydantic import BaseModel, Field
|
|||||||
class RagDatasetUpdateDTO(BaseModel):
|
class RagDatasetUpdateDTO(BaseModel):
|
||||||
name: str | None = Field(None, min_length=1, max_length=255)
|
name: str | None = Field(None, min_length=1, max_length=255)
|
||||||
retrieval_model: dict | None = Field(None)
|
retrieval_model: dict | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
|
class RagDatasetBatchDocumentDeleteDTO(BaseModel):
|
||||||
|
document_ids: list[int] = Field(..., min_length=1)
|
||||||
|
|||||||
@@ -74,6 +74,21 @@ class RagDatasetUploadDocumentVO(BaseModel):
|
|||||||
batch: str = Field("")
|
batch: str = Field("")
|
||||||
|
|
||||||
|
|
||||||
|
class RagDatasetBatchDeleteFailedItemVO(BaseModel):
|
||||||
|
id: int = Field(...)
|
||||||
|
name: str = Field("")
|
||||||
|
reason: str = Field("")
|
||||||
|
|
||||||
|
|
||||||
|
class RagDatasetBatchDeleteResultVO(BaseModel):
|
||||||
|
result: str = Field("success")
|
||||||
|
requestedCount: int = Field(0)
|
||||||
|
deletedCount: int = Field(0)
|
||||||
|
skippedCount: int = Field(0)
|
||||||
|
deletedIds: list[int] = Field(default_factory=list)
|
||||||
|
skipped: list[RagDatasetBatchDeleteFailedItemVO] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class RagDatasetSegmentItemVO(BaseModel):
|
class RagDatasetSegmentItemVO(BaseModel):
|
||||||
id: str = Field(...)
|
id: str = Field(...)
|
||||||
position: int = Field(0)
|
position: int = Field(0)
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import json
|
import json
|
||||||
import mimetypes
|
import mimetypes
|
||||||
import os
|
import os
|
||||||
@@ -19,6 +20,7 @@ from fastapi_common.fastapi_common_web.exception.LeauditException import Leaudit
|
|||||||
|
|
||||||
from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO
|
from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO
|
||||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
|
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
|
||||||
|
RagDatasetBatchDeleteResultVO,
|
||||||
RagDatasetDetailVO,
|
RagDatasetDetailVO,
|
||||||
RagDatasetDocumentItemVO,
|
RagDatasetDocumentItemVO,
|
||||||
RagDatasetDocumentPageVO,
|
RagDatasetDocumentPageVO,
|
||||||
@@ -39,6 +41,8 @@ from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatas
|
|||||||
|
|
||||||
|
|
||||||
class RagDatasetServiceImpl(IRagDatasetService):
|
class RagDatasetServiceImpl(IRagDatasetService):
|
||||||
|
_ACTIVE_INDEXING_STATUSES = {"waiting", "parsing", "cleaning", "splitting", "indexing"}
|
||||||
|
_DELETABLE_DOCUMENT_STATUSES = {"completed", "error", "paused"}
|
||||||
_APP_LINK_SQL = """
|
_APP_LINK_SQL = """
|
||||||
LEFT JOIN (
|
LEFT JOIN (
|
||||||
SELECT DISTINCT ON (dataset_id)
|
SELECT DISTINCT ON (dataset_id)
|
||||||
@@ -735,8 +739,8 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
|
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
|
||||||
|
|
||||||
suffix = os.path.splitext(FileName)[1].lower()
|
suffix = os.path.splitext(FileName)[1].lower()
|
||||||
if suffix not in {".pdf", ".docx", ".txt", ".md"}:
|
if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}:
|
||||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD")
|
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON")
|
||||||
|
|
||||||
object_key = f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}"
|
object_key = f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}"
|
||||||
content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
|
content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
|
||||||
@@ -771,78 +775,23 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
).mappings().first()
|
).mappings().first()
|
||||||
document_id = int(inserted["id"])
|
document_id = int(inserted["id"])
|
||||||
|
|
||||||
try:
|
asyncio.create_task(
|
||||||
page_texts = self._extract_page_texts(FileName=FileName, Content=Content)
|
self._run_document_indexing_task(
|
||||||
processed = self._build_chunks(
|
|
||||||
file_name=FileName,
|
|
||||||
page_texts=page_texts,
|
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
process_config=ProcessConfig or {},
|
dataset_id=DatasetId,
|
||||||
document_id=document_id,
|
document_id=document_id,
|
||||||
|
file_name=FileName,
|
||||||
|
content=Content,
|
||||||
|
process_config=ProcessConfig or {},
|
||||||
)
|
)
|
||||||
if not processed:
|
)
|
||||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本")
|
|
||||||
|
|
||||||
embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "")
|
|
||||||
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
|
|
||||||
collection.add(
|
|
||||||
ids=[item["id"] for item in processed],
|
|
||||||
documents=[item["text"] for item in processed],
|
|
||||||
embeddings=embeddings,
|
|
||||||
metadatas=[item["metadata"] for item in processed],
|
|
||||||
)
|
|
||||||
|
|
||||||
async with GetAsyncSession() as session:
|
|
||||||
await session.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
UPDATE rag_document
|
|
||||||
SET chunk_count = :chunk_count,
|
|
||||||
indexing_status = 'completed',
|
|
||||||
indexing_error = NULL,
|
|
||||||
indexing_started_at = COALESCE(indexing_started_at, NOW()),
|
|
||||||
indexing_completed_at = NOW()
|
|
||||||
WHERE id = :document_id
|
|
||||||
"""
|
|
||||||
),
|
|
||||||
{"chunk_count": len(processed), "document_id": document_id},
|
|
||||||
)
|
|
||||||
await session.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
UPDATE rag_dataset
|
|
||||||
SET document_count = (
|
|
||||||
SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
|
|
||||||
),
|
|
||||||
total_chunks = COALESCE((
|
|
||||||
SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
|
|
||||||
), 0)
|
|
||||||
WHERE id = :dataset_id
|
|
||||||
"""
|
|
||||||
),
|
|
||||||
{"dataset_id": DatasetId},
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
async with GetAsyncSession() as session:
|
|
||||||
await session.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
UPDATE rag_document
|
|
||||||
SET indexing_status = 'error',
|
|
||||||
indexing_error = :error
|
|
||||||
WHERE id = :document_id
|
|
||||||
"""
|
|
||||||
),
|
|
||||||
{"document_id": document_id, "error": str(exc)[:2000]},
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return RagDatasetUploadDocumentVO(
|
return RagDatasetUploadDocumentVO(
|
||||||
document={
|
document={
|
||||||
"id": str(document_id),
|
"id": str(document_id),
|
||||||
"name": FileName,
|
"name": FileName,
|
||||||
"indexing_status": "completed",
|
"indexing_status": "indexing",
|
||||||
"word_count": len(processed),
|
"word_count": 0,
|
||||||
"hit_count": 0,
|
"hit_count": 0,
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
},
|
},
|
||||||
@@ -1013,7 +962,7 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
DatasetId: int,
|
DatasetId: int,
|
||||||
DocumentId: int,
|
DocumentId: int,
|
||||||
SegmentId: str,
|
SegmentId: str,
|
||||||
) -> RagOperationResultVO:
|
) -> RagDatasetBatchDeleteResultVO:
|
||||||
dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
|
dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
|
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
|
||||||
@@ -1068,6 +1017,26 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
UserRole: str | None,
|
UserRole: str | None,
|
||||||
DatasetId: int,
|
DatasetId: int,
|
||||||
DocumentId: int,
|
DocumentId: int,
|
||||||
|
) -> RagOperationResultVO:
|
||||||
|
result = await self.BatchDeleteDatasetDocuments(
|
||||||
|
CurrentUserId=CurrentUserId,
|
||||||
|
UserArea=UserArea,
|
||||||
|
UserRole=UserRole,
|
||||||
|
DatasetId=DatasetId,
|
||||||
|
DocumentIds=[DocumentId],
|
||||||
|
)
|
||||||
|
if result.deletedCount <= 0:
|
||||||
|
reason = result.skipped[0].reason if result.skipped else "文档删除失败"
|
||||||
|
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, reason)
|
||||||
|
return RagOperationResultVO(result="success")
|
||||||
|
|
||||||
|
async def BatchDeleteDatasetDocuments(
|
||||||
|
self,
|
||||||
|
CurrentUserId: int,
|
||||||
|
UserArea: str | None,
|
||||||
|
UserRole: str | None,
|
||||||
|
DatasetId: int,
|
||||||
|
DocumentIds: list[int],
|
||||||
) -> RagOperationResultVO:
|
) -> RagOperationResultVO:
|
||||||
dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
|
dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
|
||||||
if not dataset:
|
if not dataset:
|
||||||
@@ -1076,33 +1045,67 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
if UserRole not in ("provincial_admin", "admin", "super_admin"):
|
if UserRole not in ("provincial_admin", "admin", "super_admin"):
|
||||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库文档权限")
|
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库文档权限")
|
||||||
|
|
||||||
|
normalized_ids = [int(item) for item in DocumentIds if item]
|
||||||
|
if not normalized_ids:
|
||||||
|
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "缺少待删除文档")
|
||||||
|
|
||||||
async with GetAsyncSession() as session:
|
async with GetAsyncSession() as session:
|
||||||
document_row = (
|
document_rows = (
|
||||||
await session.execute(
|
await session.execute(
|
||||||
text(
|
text(
|
||||||
"""
|
"""
|
||||||
SELECT id, dataset_id, minio_path
|
SELECT id, dataset_id, minio_path, original_name, indexing_status
|
||||||
FROM rag_document
|
FROM rag_document
|
||||||
WHERE id = :document_id
|
WHERE id = ANY(:document_ids)
|
||||||
AND dataset_id = :dataset_id
|
AND dataset_id = :dataset_id
|
||||||
AND deleted_at IS NULL
|
AND deleted_at IS NULL
|
||||||
LIMIT 1
|
|
||||||
"""
|
"""
|
||||||
),
|
),
|
||||||
{"document_id": DocumentId, "dataset_id": DatasetId},
|
{"document_ids": normalized_ids, "dataset_id": DatasetId},
|
||||||
)
|
)
|
||||||
).mappings().first()
|
).mappings().all()
|
||||||
|
|
||||||
if not document_row:
|
if not document_rows:
|
||||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")
|
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")
|
||||||
|
|
||||||
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
|
deletable_rows = []
|
||||||
raw = collection.get(where={"document_id": DocumentId}, include=[])
|
skipped = []
|
||||||
ids = raw.get("ids") or []
|
for row in document_rows:
|
||||||
if ids:
|
if row.get("indexing_status") not in self._DELETABLE_DOCUMENT_STATUSES:
|
||||||
collection.delete(ids=ids)
|
skipped.append(
|
||||||
|
{
|
||||||
|
"id": int(row["id"]),
|
||||||
|
"name": str(row.get("original_name") or ""),
|
||||||
|
"reason": "文档仍在处理中,暂不允许删除",
|
||||||
|
}
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
deletable_rows.append(row)
|
||||||
|
|
||||||
self._delete_oss_object(document_row.get("minio_path"))
|
if not deletable_rows:
|
||||||
|
return RagDatasetBatchDeleteResultVO(
|
||||||
|
result="success",
|
||||||
|
requestedCount=len(normalized_ids),
|
||||||
|
deletedCount=0,
|
||||||
|
skippedCount=len(skipped),
|
||||||
|
deletedIds=[],
|
||||||
|
skipped=skipped,
|
||||||
|
)
|
||||||
|
|
||||||
|
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
|
||||||
|
chunk_ids: list[str] = []
|
||||||
|
for document_row in deletable_rows:
|
||||||
|
raw = collection.get(where={"document_id": int(document_row["id"])}, include=[])
|
||||||
|
ids = raw.get("ids") or []
|
||||||
|
if ids:
|
||||||
|
chunk_ids.extend(ids)
|
||||||
|
if chunk_ids:
|
||||||
|
collection.delete(ids=chunk_ids)
|
||||||
|
|
||||||
|
for document_row in deletable_rows:
|
||||||
|
self._delete_oss_object(document_row.get("minio_path"))
|
||||||
|
|
||||||
|
existing_ids = [int(row["id"]) for row in deletable_rows]
|
||||||
|
|
||||||
async with GetAsyncSession() as session:
|
async with GetAsyncSession() as session:
|
||||||
await session.execute(
|
await session.execute(
|
||||||
@@ -1111,10 +1114,10 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
UPDATE rag_document
|
UPDATE rag_document
|
||||||
SET deleted_at = NOW(),
|
SET deleted_at = NOW(),
|
||||||
updated_at = NOW()
|
updated_at = NOW()
|
||||||
WHERE id = :document_id
|
WHERE id = ANY(:document_ids)
|
||||||
"""
|
"""
|
||||||
),
|
),
|
||||||
{"document_id": DocumentId},
|
{"document_ids": existing_ids},
|
||||||
)
|
)
|
||||||
await session.execute(
|
await session.execute(
|
||||||
text(
|
text(
|
||||||
@@ -1133,7 +1136,14 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
{"dataset_id": DatasetId},
|
{"dataset_id": DatasetId},
|
||||||
)
|
)
|
||||||
|
|
||||||
return RagOperationResultVO(result="success")
|
return RagDatasetBatchDeleteResultVO(
|
||||||
|
result="success",
|
||||||
|
requestedCount=len(normalized_ids),
|
||||||
|
deletedCount=len(existing_ids),
|
||||||
|
skippedCount=len(skipped),
|
||||||
|
deletedIds=existing_ids,
|
||||||
|
skipped=skipped,
|
||||||
|
)
|
||||||
|
|
||||||
async def RetrieveDataset(
|
async def RetrieveDataset(
|
||||||
self,
|
self,
|
||||||
@@ -1290,8 +1300,8 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")
|
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")
|
||||||
|
|
||||||
suffix = os.path.splitext(FileName)[1].lower()
|
suffix = os.path.splitext(FileName)[1].lower()
|
||||||
if suffix not in {".pdf", ".docx", ".txt", ".md"}:
|
if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}:
|
||||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD")
|
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON")
|
||||||
|
|
||||||
content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
|
content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
|
||||||
object_key = current.get("minio_path") or f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}"
|
object_key = current.get("minio_path") or f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}"
|
||||||
@@ -1328,81 +1338,23 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
"file_size": len(Content),
|
"file_size": len(Content),
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
asyncio.create_task(
|
||||||
try:
|
self._run_document_indexing_task(
|
||||||
page_texts = self._extract_page_texts(FileName=FileName, Content=Content)
|
|
||||||
processed = self._build_chunks(
|
|
||||||
file_name=FileName,
|
|
||||||
page_texts=page_texts,
|
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
process_config=ProcessConfig or {},
|
dataset_id=DatasetId,
|
||||||
document_id=DocumentId,
|
document_id=DocumentId,
|
||||||
|
file_name=FileName,
|
||||||
|
content=Content,
|
||||||
|
process_config=ProcessConfig or {},
|
||||||
)
|
)
|
||||||
if not processed:
|
)
|
||||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本")
|
|
||||||
|
|
||||||
embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "")
|
|
||||||
collection.add(
|
|
||||||
ids=[item["id"] for item in processed],
|
|
||||||
documents=[item["text"] for item in processed],
|
|
||||||
embeddings=embeddings,
|
|
||||||
metadatas=[item["metadata"] for item in processed],
|
|
||||||
)
|
|
||||||
|
|
||||||
async with GetAsyncSession() as session:
|
|
||||||
await session.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
UPDATE rag_document
|
|
||||||
SET chunk_count = :chunk_count,
|
|
||||||
indexing_status = 'completed',
|
|
||||||
indexing_error = NULL,
|
|
||||||
indexing_started_at = COALESCE(indexing_started_at, NOW()),
|
|
||||||
indexing_completed_at = NOW(),
|
|
||||||
updated_at = NOW()
|
|
||||||
WHERE id = :document_id
|
|
||||||
"""
|
|
||||||
),
|
|
||||||
{"chunk_count": len(processed), "document_id": DocumentId},
|
|
||||||
)
|
|
||||||
await session.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
UPDATE rag_dataset
|
|
||||||
SET document_count = (
|
|
||||||
SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
|
|
||||||
),
|
|
||||||
total_chunks = COALESCE((
|
|
||||||
SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
|
|
||||||
), 0),
|
|
||||||
updated_at = NOW()
|
|
||||||
WHERE id = :dataset_id
|
|
||||||
"""
|
|
||||||
),
|
|
||||||
{"dataset_id": DatasetId},
|
|
||||||
)
|
|
||||||
except Exception as exc:
|
|
||||||
async with GetAsyncSession() as session:
|
|
||||||
await session.execute(
|
|
||||||
text(
|
|
||||||
"""
|
|
||||||
UPDATE rag_document
|
|
||||||
SET indexing_status = 'error',
|
|
||||||
indexing_error = :error,
|
|
||||||
updated_at = NOW()
|
|
||||||
WHERE id = :document_id
|
|
||||||
"""
|
|
||||||
),
|
|
||||||
{"document_id": DocumentId, "error": str(exc)[:2000]},
|
|
||||||
)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return RagDatasetUploadDocumentVO(
|
return RagDatasetUploadDocumentVO(
|
||||||
document={
|
document={
|
||||||
"id": str(DocumentId),
|
"id": str(DocumentId),
|
||||||
"name": FileName,
|
"name": FileName,
|
||||||
"indexing_status": "completed",
|
"indexing_status": "indexing",
|
||||||
"word_count": len(processed),
|
"word_count": 0,
|
||||||
"hit_count": 0,
|
"hit_count": 0,
|
||||||
"enabled": True,
|
"enabled": True,
|
||||||
},
|
},
|
||||||
@@ -1460,6 +1412,8 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
return _extract_page_texts_from_pdf(Path(temp_path))
|
return _extract_page_texts_from_pdf(Path(temp_path))
|
||||||
if suffix == ".docx":
|
if suffix == ".docx":
|
||||||
return _extract_page_texts_from_docx(Path(temp_path))
|
return _extract_page_texts_from_docx(Path(temp_path))
|
||||||
|
if suffix == ".json":
|
||||||
|
return self._extract_page_texts_from_json(Content)
|
||||||
|
|
||||||
text_value = Content.decode("utf-8", errors="ignore").strip()
|
text_value = Content.decode("utf-8", errors="ignore").strip()
|
||||||
return [(1, text_value)] if text_value else []
|
return [(1, text_value)] if text_value else []
|
||||||
@@ -1469,6 +1423,48 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
except OSError:
|
except OSError:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
def _extract_page_texts_from_json(self, content: bytes) -> list[tuple[int, str]]:
|
||||||
|
text_value = content.decode("utf-8", errors="ignore").strip()
|
||||||
|
if not text_value:
|
||||||
|
return []
|
||||||
|
|
||||||
|
try:
|
||||||
|
payload = json.loads(text_value)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
return [(1, text_value)]
|
||||||
|
|
||||||
|
flattened = self._json_to_text(payload)
|
||||||
|
flattened = flattened.strip()
|
||||||
|
return [(1, flattened)] if flattened else []
|
||||||
|
|
||||||
|
def _json_to_text(self, value: object, prefix: str = "") -> str:
|
||||||
|
if value is None:
|
||||||
|
return ""
|
||||||
|
if isinstance(value, str):
|
||||||
|
return f"{prefix}: {value}" if prefix else value
|
||||||
|
if isinstance(value, (int, float, bool)):
|
||||||
|
return f"{prefix}: {value}" if prefix else str(value)
|
||||||
|
if isinstance(value, list):
|
||||||
|
parts: list[str] = []
|
||||||
|
for index, item in enumerate(value, start=1):
|
||||||
|
item_prefix = f"{prefix}[{index}]" if prefix else f"item_{index}"
|
||||||
|
chunk = self._json_to_text(item, item_prefix)
|
||||||
|
if chunk:
|
||||||
|
parts.append(chunk)
|
||||||
|
return "\n".join(parts)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
parts = []
|
||||||
|
for key, item in value.items():
|
||||||
|
item_prefix = f"{prefix}.{key}" if prefix else str(key)
|
||||||
|
chunk = self._json_to_text(item, item_prefix)
|
||||||
|
if chunk:
|
||||||
|
parts.append(chunk)
|
||||||
|
return "\n".join(parts)
|
||||||
|
rendered = str(value).strip()
|
||||||
|
if not rendered:
|
||||||
|
return ""
|
||||||
|
return f"{prefix}: {rendered}" if prefix else rendered
|
||||||
|
|
||||||
def _build_chunks(self, *, file_name: str, page_texts: list[tuple[int, str]], dataset: dict, process_config: dict, document_id: int) -> list[dict]:
|
def _build_chunks(self, *, file_name: str, page_texts: list[tuple[int, str]], dataset: dict, process_config: dict, document_id: int) -> list[dict]:
|
||||||
rules = ((process_config or {}).get("process_rule") or {}).get("rules") or {}
|
rules = ((process_config or {}).get("process_rule") or {}).get("rules") or {}
|
||||||
segmentation = rules.get("segmentation") or {}
|
segmentation = rules.get("segmentation") or {}
|
||||||
@@ -1510,26 +1506,165 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/embeddings"
|
embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/embeddings"
|
||||||
embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
|
embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
|
||||||
embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
|
embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
|
||||||
|
batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
|
||||||
if not embed_url or not embed_key:
|
if not embed_url or not embed_key:
|
||||||
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")
|
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")
|
||||||
|
|
||||||
|
embeddings: list[list[float]] = []
|
||||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||||
response = await client.post(
|
for start in range(0, len(texts), batch_size):
|
||||||
embed_url,
|
batch_texts = texts[start:start + batch_size]
|
||||||
headers={
|
try:
|
||||||
"Content-Type": "application/json",
|
response = await client.post(
|
||||||
"Authorization": f"Bearer {embed_key}",
|
embed_url,
|
||||||
},
|
headers={
|
||||||
json={"model": embed_model, "input": texts},
|
"Content-Type": "application/json",
|
||||||
)
|
"Authorization": f"Bearer {embed_key}",
|
||||||
response.raise_for_status()
|
},
|
||||||
payload = response.json()
|
json={"model": embed_model, "input": batch_texts},
|
||||||
rows = payload.get("data") or []
|
)
|
||||||
embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
|
response.raise_for_status()
|
||||||
|
except httpx.HTTPStatusError as exc:
|
||||||
|
error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
|
||||||
|
raise LeauditException(
|
||||||
|
StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||||
|
f"向量化服务调用失败: {error_message[:300]}",
|
||||||
|
) from exc
|
||||||
|
|
||||||
|
payload = response.json()
|
||||||
|
rows = payload.get("data") or []
|
||||||
|
batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
|
||||||
|
if len(batch_embeddings) != len(batch_texts):
|
||||||
|
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
|
||||||
|
embeddings.extend(batch_embeddings)
|
||||||
|
|
||||||
if len(embeddings) != len(texts):
|
if len(embeddings) != len(texts):
|
||||||
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
|
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
|
||||||
return embeddings
|
return embeddings
|
||||||
|
|
||||||
|
async def _run_document_indexing_task(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
dataset: dict,
|
||||||
|
dataset_id: int,
|
||||||
|
document_id: int,
|
||||||
|
file_name: str,
|
||||||
|
content: bytes,
|
||||||
|
process_config: dict,
|
||||||
|
) -> None:
|
||||||
|
try:
|
||||||
|
await self._update_document_processing_state(document_id=document_id, status="parsing")
|
||||||
|
page_texts = self._extract_page_texts(FileName=file_name, Content=content)
|
||||||
|
|
||||||
|
await self._update_document_processing_state(document_id=document_id, status="cleaning")
|
||||||
|
processed = self._build_chunks(
|
||||||
|
file_name=file_name,
|
||||||
|
page_texts=page_texts,
|
||||||
|
dataset=dataset,
|
||||||
|
process_config=process_config,
|
||||||
|
document_id=document_id,
|
||||||
|
)
|
||||||
|
if not processed:
|
||||||
|
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本")
|
||||||
|
|
||||||
|
await self._update_document_processing_state(
|
||||||
|
document_id=document_id,
|
||||||
|
status="splitting",
|
||||||
|
chunk_count=len(processed),
|
||||||
|
)
|
||||||
|
embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "")
|
||||||
|
|
||||||
|
await self._update_document_processing_state(
|
||||||
|
document_id=document_id,
|
||||||
|
status="indexing",
|
||||||
|
chunk_count=len(processed),
|
||||||
|
)
|
||||||
|
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
|
||||||
|
collection.add(
|
||||||
|
ids=[item["id"] for item in processed],
|
||||||
|
documents=[item["text"] for item in processed],
|
||||||
|
embeddings=embeddings,
|
||||||
|
metadatas=[item["metadata"] for item in processed],
|
||||||
|
)
|
||||||
|
|
||||||
|
async with GetAsyncSession() as session:
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
UPDATE rag_document
|
||||||
|
SET chunk_count = :chunk_count,
|
||||||
|
indexing_status = 'completed',
|
||||||
|
indexing_error = NULL,
|
||||||
|
indexing_started_at = COALESCE(indexing_started_at, NOW()),
|
||||||
|
indexing_completed_at = NOW(),
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = :document_id
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"chunk_count": len(processed), "document_id": document_id},
|
||||||
|
)
|
||||||
|
await self._sync_dataset_counts(dataset_id)
|
||||||
|
except Exception as exc:
|
||||||
|
async with GetAsyncSession() as session:
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
UPDATE rag_document
|
||||||
|
SET indexing_status = 'error',
|
||||||
|
indexing_error = :error,
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = :document_id
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"document_id": document_id, "error": str(exc)[:2000]},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _update_document_processing_state(self, *, document_id: int, status: str, chunk_count: int | None = None) -> None:
|
||||||
|
fields = [
|
||||||
|
"indexing_status = :status",
|
||||||
|
"indexing_error = NULL",
|
||||||
|
"indexing_started_at = COALESCE(indexing_started_at, NOW())",
|
||||||
|
"updated_at = NOW()",
|
||||||
|
]
|
||||||
|
params: dict[str, object] = {
|
||||||
|
"document_id": document_id,
|
||||||
|
"status": status,
|
||||||
|
}
|
||||||
|
if chunk_count is not None:
|
||||||
|
fields.append("chunk_count = :chunk_count")
|
||||||
|
params["chunk_count"] = chunk_count
|
||||||
|
|
||||||
|
async with GetAsyncSession() as session:
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
f"""
|
||||||
|
UPDATE rag_document
|
||||||
|
SET {", ".join(fields)}
|
||||||
|
WHERE id = :document_id
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
params,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _sync_dataset_counts(self, dataset_id: int) -> None:
|
||||||
|
async with GetAsyncSession() as session:
|
||||||
|
await session.execute(
|
||||||
|
text(
|
||||||
|
"""
|
||||||
|
UPDATE rag_dataset
|
||||||
|
SET document_count = (
|
||||||
|
SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
|
||||||
|
),
|
||||||
|
total_chunks = COALESCE((
|
||||||
|
SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
|
||||||
|
), 0),
|
||||||
|
updated_at = NOW()
|
||||||
|
WHERE id = :dataset_id
|
||||||
|
"""
|
||||||
|
),
|
||||||
|
{"dataset_id": dataset_id},
|
||||||
|
)
|
||||||
|
|
||||||
def _preprocess_text(self, text_value: str, *, remove_spaces: bool, remove_urls: bool) -> str:
|
def _preprocess_text(self, text_value: str, *, remove_spaces: bool, remove_urls: bool) -> str:
|
||||||
result = text_value or ""
|
result = text_value or ""
|
||||||
if remove_urls:
|
if remove_urls:
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ from abc import ABC, abstractmethod
|
|||||||
|
|
||||||
from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO
|
from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO
|
||||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
|
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
|
||||||
|
RagDatasetBatchDeleteResultVO,
|
||||||
RagDatasetDetailVO,
|
RagDatasetDetailVO,
|
||||||
RagDatasetDocumentItemVO,
|
RagDatasetDocumentItemVO,
|
||||||
RagDatasetDocumentPageVO,
|
RagDatasetDocumentPageVO,
|
||||||
@@ -123,6 +124,16 @@ class IRagDatasetService(ABC):
|
|||||||
DocumentId: int,
|
DocumentId: int,
|
||||||
) -> RagOperationResultVO: ...
|
) -> RagOperationResultVO: ...
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def BatchDeleteDatasetDocuments(
|
||||||
|
self,
|
||||||
|
CurrentUserId: int,
|
||||||
|
UserArea: str | None,
|
||||||
|
UserRole: str | None,
|
||||||
|
DatasetId: int,
|
||||||
|
DocumentIds: list[int],
|
||||||
|
) -> RagDatasetBatchDeleteResultVO: ...
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def RetrieveDataset(
|
async def RetrieveDataset(
|
||||||
self,
|
self,
|
||||||
|
|||||||
Reference in New Issue
Block a user