feat: improve rag dataset document management

This commit is contained in:
wren
2026-05-11 19:25:50 +08:00
parent 2aa5a6d1d6
commit 8206ed7d43
5 changed files with 351 additions and 165 deletions
@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
import json
import mimetypes
import os
@@ -19,6 +20,7 @@ from fastapi_common.fastapi_common_web.exception.LeauditException import Leaudit
from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
RagDatasetBatchDeleteResultVO,
RagDatasetDetailVO,
RagDatasetDocumentItemVO,
RagDatasetDocumentPageVO,
@@ -39,6 +41,8 @@ from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatas
class RagDatasetServiceImpl(IRagDatasetService):
_ACTIVE_INDEXING_STATUSES = {"waiting", "parsing", "cleaning", "splitting", "indexing"}
_DELETABLE_DOCUMENT_STATUSES = {"completed", "error", "paused"}
_APP_LINK_SQL = """
LEFT JOIN (
SELECT DISTINCT ON (dataset_id)
@@ -735,8 +739,8 @@ class RagDatasetServiceImpl(IRagDatasetService):
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
suffix = os.path.splitext(FileName)[1].lower()
if suffix not in {".pdf", ".docx", ".txt", ".md"}:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD")
if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON")
object_key = f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}"
content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
@@ -771,78 +775,23 @@ class RagDatasetServiceImpl(IRagDatasetService):
).mappings().first()
document_id = int(inserted["id"])
try:
page_texts = self._extract_page_texts(FileName=FileName, Content=Content)
processed = self._build_chunks(
file_name=FileName,
page_texts=page_texts,
asyncio.create_task(
self._run_document_indexing_task(
dataset=dataset,
process_config=ProcessConfig or {},
dataset_id=DatasetId,
document_id=document_id,
file_name=FileName,
content=Content,
process_config=ProcessConfig or {},
)
if not processed:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本")
embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "")
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
collection.add(
ids=[item["id"] for item in processed],
documents=[item["text"] for item in processed],
embeddings=embeddings,
metadatas=[item["metadata"] for item in processed],
)
async with GetAsyncSession() as session:
await session.execute(
text(
"""
UPDATE rag_document
SET chunk_count = :chunk_count,
indexing_status = 'completed',
indexing_error = NULL,
indexing_started_at = COALESCE(indexing_started_at, NOW()),
indexing_completed_at = NOW()
WHERE id = :document_id
"""
),
{"chunk_count": len(processed), "document_id": document_id},
)
await session.execute(
text(
"""
UPDATE rag_dataset
SET document_count = (
SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
),
total_chunks = COALESCE((
SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
), 0)
WHERE id = :dataset_id
"""
),
{"dataset_id": DatasetId},
)
except Exception as exc:
async with GetAsyncSession() as session:
await session.execute(
text(
"""
UPDATE rag_document
SET indexing_status = 'error',
indexing_error = :error
WHERE id = :document_id
"""
),
{"document_id": document_id, "error": str(exc)[:2000]},
)
raise
)
return RagDatasetUploadDocumentVO(
document={
"id": str(document_id),
"name": FileName,
"indexing_status": "completed",
"word_count": len(processed),
"indexing_status": "indexing",
"word_count": 0,
"hit_count": 0,
"enabled": True,
},
@@ -1013,7 +962,7 @@ class RagDatasetServiceImpl(IRagDatasetService):
DatasetId: int,
DocumentId: int,
SegmentId: str,
) -> RagOperationResultVO:
) -> RagDatasetBatchDeleteResultVO:
dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
if not dataset:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
@@ -1068,6 +1017,26 @@ class RagDatasetServiceImpl(IRagDatasetService):
UserRole: str | None,
DatasetId: int,
DocumentId: int,
) -> RagOperationResultVO:
result = await self.BatchDeleteDatasetDocuments(
CurrentUserId=CurrentUserId,
UserArea=UserArea,
UserRole=UserRole,
DatasetId=DatasetId,
DocumentIds=[DocumentId],
)
if result.deletedCount <= 0:
reason = result.skipped[0].reason if result.skipped else "文档删除失败"
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, reason)
return RagOperationResultVO(result="success")
async def BatchDeleteDatasetDocuments(
self,
CurrentUserId: int,
UserArea: str | None,
UserRole: str | None,
DatasetId: int,
DocumentIds: list[int],
) -> RagOperationResultVO:
dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
if not dataset:
@@ -1076,33 +1045,67 @@ class RagDatasetServiceImpl(IRagDatasetService):
if UserRole not in ("provincial_admin", "admin", "super_admin"):
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库文档权限")
normalized_ids = [int(item) for item in DocumentIds if item]
if not normalized_ids:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "缺少待删除文档")
async with GetAsyncSession() as session:
document_row = (
document_rows = (
await session.execute(
text(
"""
SELECT id, dataset_id, minio_path
SELECT id, dataset_id, minio_path, original_name, indexing_status
FROM rag_document
WHERE id = :document_id
WHERE id = ANY(:document_ids)
AND dataset_id = :dataset_id
AND deleted_at IS NULL
LIMIT 1
"""
),
{"document_id": DocumentId, "dataset_id": DatasetId},
{"document_ids": normalized_ids, "dataset_id": DatasetId},
)
).mappings().first()
).mappings().all()
if not document_row:
if not document_rows:
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
raw = collection.get(where={"document_id": DocumentId}, include=[])
ids = raw.get("ids") or []
if ids:
collection.delete(ids=ids)
deletable_rows = []
skipped = []
for row in document_rows:
if row.get("indexing_status") not in self._DELETABLE_DOCUMENT_STATUSES:
skipped.append(
{
"id": int(row["id"]),
"name": str(row.get("original_name") or ""),
"reason": "文档仍在处理中,暂不允许删除",
}
)
continue
deletable_rows.append(row)
self._delete_oss_object(document_row.get("minio_path"))
if not deletable_rows:
return RagDatasetBatchDeleteResultVO(
result="success",
requestedCount=len(normalized_ids),
deletedCount=0,
skippedCount=len(skipped),
deletedIds=[],
skipped=skipped,
)
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
chunk_ids: list[str] = []
for document_row in deletable_rows:
raw = collection.get(where={"document_id": int(document_row["id"])}, include=[])
ids = raw.get("ids") or []
if ids:
chunk_ids.extend(ids)
if chunk_ids:
collection.delete(ids=chunk_ids)
for document_row in deletable_rows:
self._delete_oss_object(document_row.get("minio_path"))
existing_ids = [int(row["id"]) for row in deletable_rows]
async with GetAsyncSession() as session:
await session.execute(
@@ -1111,10 +1114,10 @@ class RagDatasetServiceImpl(IRagDatasetService):
UPDATE rag_document
SET deleted_at = NOW(),
updated_at = NOW()
WHERE id = :document_id
WHERE id = ANY(:document_ids)
"""
),
{"document_id": DocumentId},
{"document_ids": existing_ids},
)
await session.execute(
text(
@@ -1133,7 +1136,14 @@ class RagDatasetServiceImpl(IRagDatasetService):
{"dataset_id": DatasetId},
)
return RagOperationResultVO(result="success")
return RagDatasetBatchDeleteResultVO(
result="success",
requestedCount=len(normalized_ids),
deletedCount=len(existing_ids),
skippedCount=len(skipped),
deletedIds=existing_ids,
skipped=skipped,
)
async def RetrieveDataset(
self,
@@ -1290,8 +1300,8 @@ class RagDatasetServiceImpl(IRagDatasetService):
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")
suffix = os.path.splitext(FileName)[1].lower()
if suffix not in {".pdf", ".docx", ".txt", ".md"}:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD")
if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON")
content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
object_key = current.get("minio_path") or f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}"
@@ -1328,81 +1338,23 @@ class RagDatasetServiceImpl(IRagDatasetService):
"file_size": len(Content),
},
)
try:
page_texts = self._extract_page_texts(FileName=FileName, Content=Content)
processed = self._build_chunks(
file_name=FileName,
page_texts=page_texts,
asyncio.create_task(
self._run_document_indexing_task(
dataset=dataset,
process_config=ProcessConfig or {},
dataset_id=DatasetId,
document_id=DocumentId,
file_name=FileName,
content=Content,
process_config=ProcessConfig or {},
)
if not processed:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本")
embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "")
collection.add(
ids=[item["id"] for item in processed],
documents=[item["text"] for item in processed],
embeddings=embeddings,
metadatas=[item["metadata"] for item in processed],
)
async with GetAsyncSession() as session:
await session.execute(
text(
"""
UPDATE rag_document
SET chunk_count = :chunk_count,
indexing_status = 'completed',
indexing_error = NULL,
indexing_started_at = COALESCE(indexing_started_at, NOW()),
indexing_completed_at = NOW(),
updated_at = NOW()
WHERE id = :document_id
"""
),
{"chunk_count": len(processed), "document_id": DocumentId},
)
await session.execute(
text(
"""
UPDATE rag_dataset
SET document_count = (
SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
),
total_chunks = COALESCE((
SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
), 0),
updated_at = NOW()
WHERE id = :dataset_id
"""
),
{"dataset_id": DatasetId},
)
except Exception as exc:
async with GetAsyncSession() as session:
await session.execute(
text(
"""
UPDATE rag_document
SET indexing_status = 'error',
indexing_error = :error,
updated_at = NOW()
WHERE id = :document_id
"""
),
{"document_id": DocumentId, "error": str(exc)[:2000]},
)
raise
)
return RagDatasetUploadDocumentVO(
document={
"id": str(DocumentId),
"name": FileName,
"indexing_status": "completed",
"word_count": len(processed),
"indexing_status": "indexing",
"word_count": 0,
"hit_count": 0,
"enabled": True,
},
@@ -1460,6 +1412,8 @@ class RagDatasetServiceImpl(IRagDatasetService):
return _extract_page_texts_from_pdf(Path(temp_path))
if suffix == ".docx":
return _extract_page_texts_from_docx(Path(temp_path))
if suffix == ".json":
return self._extract_page_texts_from_json(Content)
text_value = Content.decode("utf-8", errors="ignore").strip()
return [(1, text_value)] if text_value else []
@@ -1469,6 +1423,48 @@ class RagDatasetServiceImpl(IRagDatasetService):
except OSError:
pass
def _extract_page_texts_from_json(self, content: bytes) -> list[tuple[int, str]]:
text_value = content.decode("utf-8", errors="ignore").strip()
if not text_value:
return []
try:
payload = json.loads(text_value)
except json.JSONDecodeError:
return [(1, text_value)]
flattened = self._json_to_text(payload)
flattened = flattened.strip()
return [(1, flattened)] if flattened else []
def _json_to_text(self, value: object, prefix: str = "") -> str:
if value is None:
return ""
if isinstance(value, str):
return f"{prefix}: {value}" if prefix else value
if isinstance(value, (int, float, bool)):
return f"{prefix}: {value}" if prefix else str(value)
if isinstance(value, list):
parts: list[str] = []
for index, item in enumerate(value, start=1):
item_prefix = f"{prefix}[{index}]" if prefix else f"item_{index}"
chunk = self._json_to_text(item, item_prefix)
if chunk:
parts.append(chunk)
return "\n".join(parts)
if isinstance(value, dict):
parts = []
for key, item in value.items():
item_prefix = f"{prefix}.{key}" if prefix else str(key)
chunk = self._json_to_text(item, item_prefix)
if chunk:
parts.append(chunk)
return "\n".join(parts)
rendered = str(value).strip()
if not rendered:
return ""
return f"{prefix}: {rendered}" if prefix else rendered
def _build_chunks(self, *, file_name: str, page_texts: list[tuple[int, str]], dataset: dict, process_config: dict, document_id: int) -> list[dict]:
rules = ((process_config or {}).get("process_rule") or {}).get("rules") or {}
segmentation = rules.get("segmentation") or {}
@@ -1510,26 +1506,165 @@ class RagDatasetServiceImpl(IRagDatasetService):
embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/embeddings"
embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
if not embed_url or not embed_key:
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")
embeddings: list[list[float]] = []
async with httpx.AsyncClient(timeout=120.0) as client:
response = await client.post(
embed_url,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {embed_key}",
},
json={"model": embed_model, "input": texts},
)
response.raise_for_status()
payload = response.json()
rows = payload.get("data") or []
embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
for start in range(0, len(texts), batch_size):
batch_texts = texts[start:start + batch_size]
try:
response = await client.post(
embed_url,
headers={
"Content-Type": "application/json",
"Authorization": f"Bearer {embed_key}",
},
json={"model": embed_model, "input": batch_texts},
)
response.raise_for_status()
except httpx.HTTPStatusError as exc:
error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
raise LeauditException(
StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
f"向量化服务调用失败: {error_message[:300]}",
) from exc
payload = response.json()
rows = payload.get("data") or []
batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
if len(batch_embeddings) != len(batch_texts):
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
embeddings.extend(batch_embeddings)
if len(embeddings) != len(texts):
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
return embeddings
async def _run_document_indexing_task(
self,
*,
dataset: dict,
dataset_id: int,
document_id: int,
file_name: str,
content: bytes,
process_config: dict,
) -> None:
try:
await self._update_document_processing_state(document_id=document_id, status="parsing")
page_texts = self._extract_page_texts(FileName=file_name, Content=content)
await self._update_document_processing_state(document_id=document_id, status="cleaning")
processed = self._build_chunks(
file_name=file_name,
page_texts=page_texts,
dataset=dataset,
process_config=process_config,
document_id=document_id,
)
if not processed:
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本")
await self._update_document_processing_state(
document_id=document_id,
status="splitting",
chunk_count=len(processed),
)
embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "")
await self._update_document_processing_state(
document_id=document_id,
status="indexing",
chunk_count=len(processed),
)
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
collection.add(
ids=[item["id"] for item in processed],
documents=[item["text"] for item in processed],
embeddings=embeddings,
metadatas=[item["metadata"] for item in processed],
)
async with GetAsyncSession() as session:
await session.execute(
text(
"""
UPDATE rag_document
SET chunk_count = :chunk_count,
indexing_status = 'completed',
indexing_error = NULL,
indexing_started_at = COALESCE(indexing_started_at, NOW()),
indexing_completed_at = NOW(),
updated_at = NOW()
WHERE id = :document_id
"""
),
{"chunk_count": len(processed), "document_id": document_id},
)
await self._sync_dataset_counts(dataset_id)
except Exception as exc:
async with GetAsyncSession() as session:
await session.execute(
text(
"""
UPDATE rag_document
SET indexing_status = 'error',
indexing_error = :error,
updated_at = NOW()
WHERE id = :document_id
"""
),
{"document_id": document_id, "error": str(exc)[:2000]},
)
async def _update_document_processing_state(self, *, document_id: int, status: str, chunk_count: int | None = None) -> None:
fields = [
"indexing_status = :status",
"indexing_error = NULL",
"indexing_started_at = COALESCE(indexing_started_at, NOW())",
"updated_at = NOW()",
]
params: dict[str, object] = {
"document_id": document_id,
"status": status,
}
if chunk_count is not None:
fields.append("chunk_count = :chunk_count")
params["chunk_count"] = chunk_count
async with GetAsyncSession() as session:
await session.execute(
text(
f"""
UPDATE rag_document
SET {", ".join(fields)}
WHERE id = :document_id
"""
),
params,
)
async def _sync_dataset_counts(self, dataset_id: int) -> None:
async with GetAsyncSession() as session:
await session.execute(
text(
"""
UPDATE rag_dataset
SET document_count = (
SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
),
total_chunks = COALESCE((
SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
), 0),
updated_at = NOW()
WHERE id = :dataset_id
"""
),
{"dataset_id": dataset_id},
)
def _preprocess_text(self, text_value: str, *, remove_spaces: bool, remove_urls: bool) -> str:
result = text_value or ""
if remove_urls: