feat: improve rag dataset document management
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import mimetypes
|
||||
import os
|
||||
@@ -19,6 +20,7 @@ from fastapi_common.fastapi_common_web.exception.LeauditException import Leaudit
|
||||
|
||||
from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO
|
||||
from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import (
|
||||
RagDatasetBatchDeleteResultVO,
|
||||
RagDatasetDetailVO,
|
||||
RagDatasetDocumentItemVO,
|
||||
RagDatasetDocumentPageVO,
|
||||
@@ -39,6 +41,8 @@ from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatas
|
||||
|
||||
|
||||
class RagDatasetServiceImpl(IRagDatasetService):
|
||||
_ACTIVE_INDEXING_STATUSES = {"waiting", "parsing", "cleaning", "splitting", "indexing"}
|
||||
_DELETABLE_DOCUMENT_STATUSES = {"completed", "error", "paused"}
|
||||
_APP_LINK_SQL = """
|
||||
LEFT JOIN (
|
||||
SELECT DISTINCT ON (dataset_id)
|
||||
@@ -735,8 +739,8 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
|
||||
|
||||
suffix = os.path.splitext(FileName)[1].lower()
|
||||
if suffix not in {".pdf", ".docx", ".txt", ".md"}:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD")
|
||||
if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON")
|
||||
|
||||
object_key = f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}"
|
||||
content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
|
||||
@@ -771,78 +775,23 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
).mappings().first()
|
||||
document_id = int(inserted["id"])
|
||||
|
||||
try:
|
||||
page_texts = self._extract_page_texts(FileName=FileName, Content=Content)
|
||||
processed = self._build_chunks(
|
||||
file_name=FileName,
|
||||
page_texts=page_texts,
|
||||
asyncio.create_task(
|
||||
self._run_document_indexing_task(
|
||||
dataset=dataset,
|
||||
process_config=ProcessConfig or {},
|
||||
dataset_id=DatasetId,
|
||||
document_id=document_id,
|
||||
file_name=FileName,
|
||||
content=Content,
|
||||
process_config=ProcessConfig or {},
|
||||
)
|
||||
if not processed:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本")
|
||||
|
||||
embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "")
|
||||
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
|
||||
collection.add(
|
||||
ids=[item["id"] for item in processed],
|
||||
documents=[item["text"] for item in processed],
|
||||
embeddings=embeddings,
|
||||
metadatas=[item["metadata"] for item in processed],
|
||||
)
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE rag_document
|
||||
SET chunk_count = :chunk_count,
|
||||
indexing_status = 'completed',
|
||||
indexing_error = NULL,
|
||||
indexing_started_at = COALESCE(indexing_started_at, NOW()),
|
||||
indexing_completed_at = NOW()
|
||||
WHERE id = :document_id
|
||||
"""
|
||||
),
|
||||
{"chunk_count": len(processed), "document_id": document_id},
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE rag_dataset
|
||||
SET document_count = (
|
||||
SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
|
||||
),
|
||||
total_chunks = COALESCE((
|
||||
SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
|
||||
), 0)
|
||||
WHERE id = :dataset_id
|
||||
"""
|
||||
),
|
||||
{"dataset_id": DatasetId},
|
||||
)
|
||||
except Exception as exc:
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE rag_document
|
||||
SET indexing_status = 'error',
|
||||
indexing_error = :error
|
||||
WHERE id = :document_id
|
||||
"""
|
||||
),
|
||||
{"document_id": document_id, "error": str(exc)[:2000]},
|
||||
)
|
||||
raise
|
||||
)
|
||||
|
||||
return RagDatasetUploadDocumentVO(
|
||||
document={
|
||||
"id": str(document_id),
|
||||
"name": FileName,
|
||||
"indexing_status": "completed",
|
||||
"word_count": len(processed),
|
||||
"indexing_status": "indexing",
|
||||
"word_count": 0,
|
||||
"hit_count": 0,
|
||||
"enabled": True,
|
||||
},
|
||||
@@ -1013,7 +962,7 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
DatasetId: int,
|
||||
DocumentId: int,
|
||||
SegmentId: str,
|
||||
) -> RagOperationResultVO:
|
||||
) -> RagDatasetBatchDeleteResultVO:
|
||||
dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
|
||||
if not dataset:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在")
|
||||
@@ -1068,6 +1017,26 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
UserRole: str | None,
|
||||
DatasetId: int,
|
||||
DocumentId: int,
|
||||
) -> RagOperationResultVO:
|
||||
result = await self.BatchDeleteDatasetDocuments(
|
||||
CurrentUserId=CurrentUserId,
|
||||
UserArea=UserArea,
|
||||
UserRole=UserRole,
|
||||
DatasetId=DatasetId,
|
||||
DocumentIds=[DocumentId],
|
||||
)
|
||||
if result.deletedCount <= 0:
|
||||
reason = result.skipped[0].reason if result.skipped else "文档删除失败"
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, reason)
|
||||
return RagOperationResultVO(result="success")
|
||||
|
||||
async def BatchDeleteDatasetDocuments(
|
||||
self,
|
||||
CurrentUserId: int,
|
||||
UserArea: str | None,
|
||||
UserRole: str | None,
|
||||
DatasetId: int,
|
||||
DocumentIds: list[int],
|
||||
) -> RagOperationResultVO:
|
||||
dataset = await self._get_visible_dataset(UserArea, UserRole, DatasetId)
|
||||
if not dataset:
|
||||
@@ -1076,33 +1045,67 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
if UserRole not in ("provincial_admin", "admin", "super_admin"):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户没有删除知识库文档权限")
|
||||
|
||||
normalized_ids = [int(item) for item in DocumentIds if item]
|
||||
if not normalized_ids:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "缺少待删除文档")
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
document_row = (
|
||||
document_rows = (
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id, dataset_id, minio_path
|
||||
SELECT id, dataset_id, minio_path, original_name, indexing_status
|
||||
FROM rag_document
|
||||
WHERE id = :document_id
|
||||
WHERE id = ANY(:document_ids)
|
||||
AND dataset_id = :dataset_id
|
||||
AND deleted_at IS NULL
|
||||
LIMIT 1
|
||||
"""
|
||||
),
|
||||
{"document_id": DocumentId, "dataset_id": DatasetId},
|
||||
{"document_ids": normalized_ids, "dataset_id": DatasetId},
|
||||
)
|
||||
).mappings().first()
|
||||
).mappings().all()
|
||||
|
||||
if not document_row:
|
||||
if not document_rows:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")
|
||||
|
||||
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
|
||||
raw = collection.get(where={"document_id": DocumentId}, include=[])
|
||||
ids = raw.get("ids") or []
|
||||
if ids:
|
||||
collection.delete(ids=ids)
|
||||
deletable_rows = []
|
||||
skipped = []
|
||||
for row in document_rows:
|
||||
if row.get("indexing_status") not in self._DELETABLE_DOCUMENT_STATUSES:
|
||||
skipped.append(
|
||||
{
|
||||
"id": int(row["id"]),
|
||||
"name": str(row.get("original_name") or ""),
|
||||
"reason": "文档仍在处理中,暂不允许删除",
|
||||
}
|
||||
)
|
||||
continue
|
||||
deletable_rows.append(row)
|
||||
|
||||
self._delete_oss_object(document_row.get("minio_path"))
|
||||
if not deletable_rows:
|
||||
return RagDatasetBatchDeleteResultVO(
|
||||
result="success",
|
||||
requestedCount=len(normalized_ids),
|
||||
deletedCount=0,
|
||||
skippedCount=len(skipped),
|
||||
deletedIds=[],
|
||||
skipped=skipped,
|
||||
)
|
||||
|
||||
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
|
||||
chunk_ids: list[str] = []
|
||||
for document_row in deletable_rows:
|
||||
raw = collection.get(where={"document_id": int(document_row["id"])}, include=[])
|
||||
ids = raw.get("ids") or []
|
||||
if ids:
|
||||
chunk_ids.extend(ids)
|
||||
if chunk_ids:
|
||||
collection.delete(ids=chunk_ids)
|
||||
|
||||
for document_row in deletable_rows:
|
||||
self._delete_oss_object(document_row.get("minio_path"))
|
||||
|
||||
existing_ids = [int(row["id"]) for row in deletable_rows]
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
@@ -1111,10 +1114,10 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
UPDATE rag_document
|
||||
SET deleted_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = :document_id
|
||||
WHERE id = ANY(:document_ids)
|
||||
"""
|
||||
),
|
||||
{"document_id": DocumentId},
|
||||
{"document_ids": existing_ids},
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
@@ -1133,7 +1136,14 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
{"dataset_id": DatasetId},
|
||||
)
|
||||
|
||||
return RagOperationResultVO(result="success")
|
||||
return RagDatasetBatchDeleteResultVO(
|
||||
result="success",
|
||||
requestedCount=len(normalized_ids),
|
||||
deletedCount=len(existing_ids),
|
||||
skippedCount=len(skipped),
|
||||
deletedIds=existing_ids,
|
||||
skipped=skipped,
|
||||
)
|
||||
|
||||
async def RetrieveDataset(
|
||||
self,
|
||||
@@ -1290,8 +1300,8 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在")
|
||||
|
||||
suffix = os.path.splitext(FileName)[1].lower()
|
||||
if suffix not in {".pdf", ".docx", ".txt", ".md"}:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD")
|
||||
if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON")
|
||||
|
||||
content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
|
||||
object_key = current.get("minio_path") or f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}"
|
||||
@@ -1328,81 +1338,23 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
"file_size": len(Content),
|
||||
},
|
||||
)
|
||||
|
||||
try:
|
||||
page_texts = self._extract_page_texts(FileName=FileName, Content=Content)
|
||||
processed = self._build_chunks(
|
||||
file_name=FileName,
|
||||
page_texts=page_texts,
|
||||
asyncio.create_task(
|
||||
self._run_document_indexing_task(
|
||||
dataset=dataset,
|
||||
process_config=ProcessConfig or {},
|
||||
dataset_id=DatasetId,
|
||||
document_id=DocumentId,
|
||||
file_name=FileName,
|
||||
content=Content,
|
||||
process_config=ProcessConfig or {},
|
||||
)
|
||||
if not processed:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本")
|
||||
|
||||
embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "")
|
||||
collection.add(
|
||||
ids=[item["id"] for item in processed],
|
||||
documents=[item["text"] for item in processed],
|
||||
embeddings=embeddings,
|
||||
metadatas=[item["metadata"] for item in processed],
|
||||
)
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE rag_document
|
||||
SET chunk_count = :chunk_count,
|
||||
indexing_status = 'completed',
|
||||
indexing_error = NULL,
|
||||
indexing_started_at = COALESCE(indexing_started_at, NOW()),
|
||||
indexing_completed_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = :document_id
|
||||
"""
|
||||
),
|
||||
{"chunk_count": len(processed), "document_id": DocumentId},
|
||||
)
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE rag_dataset
|
||||
SET document_count = (
|
||||
SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
|
||||
),
|
||||
total_chunks = COALESCE((
|
||||
SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
|
||||
), 0),
|
||||
updated_at = NOW()
|
||||
WHERE id = :dataset_id
|
||||
"""
|
||||
),
|
||||
{"dataset_id": DatasetId},
|
||||
)
|
||||
except Exception as exc:
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE rag_document
|
||||
SET indexing_status = 'error',
|
||||
indexing_error = :error,
|
||||
updated_at = NOW()
|
||||
WHERE id = :document_id
|
||||
"""
|
||||
),
|
||||
{"document_id": DocumentId, "error": str(exc)[:2000]},
|
||||
)
|
||||
raise
|
||||
)
|
||||
|
||||
return RagDatasetUploadDocumentVO(
|
||||
document={
|
||||
"id": str(DocumentId),
|
||||
"name": FileName,
|
||||
"indexing_status": "completed",
|
||||
"word_count": len(processed),
|
||||
"indexing_status": "indexing",
|
||||
"word_count": 0,
|
||||
"hit_count": 0,
|
||||
"enabled": True,
|
||||
},
|
||||
@@ -1460,6 +1412,8 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
return _extract_page_texts_from_pdf(Path(temp_path))
|
||||
if suffix == ".docx":
|
||||
return _extract_page_texts_from_docx(Path(temp_path))
|
||||
if suffix == ".json":
|
||||
return self._extract_page_texts_from_json(Content)
|
||||
|
||||
text_value = Content.decode("utf-8", errors="ignore").strip()
|
||||
return [(1, text_value)] if text_value else []
|
||||
@@ -1469,6 +1423,48 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
def _extract_page_texts_from_json(self, content: bytes) -> list[tuple[int, str]]:
|
||||
text_value = content.decode("utf-8", errors="ignore").strip()
|
||||
if not text_value:
|
||||
return []
|
||||
|
||||
try:
|
||||
payload = json.loads(text_value)
|
||||
except json.JSONDecodeError:
|
||||
return [(1, text_value)]
|
||||
|
||||
flattened = self._json_to_text(payload)
|
||||
flattened = flattened.strip()
|
||||
return [(1, flattened)] if flattened else []
|
||||
|
||||
def _json_to_text(self, value: object, prefix: str = "") -> str:
|
||||
if value is None:
|
||||
return ""
|
||||
if isinstance(value, str):
|
||||
return f"{prefix}: {value}" if prefix else value
|
||||
if isinstance(value, (int, float, bool)):
|
||||
return f"{prefix}: {value}" if prefix else str(value)
|
||||
if isinstance(value, list):
|
||||
parts: list[str] = []
|
||||
for index, item in enumerate(value, start=1):
|
||||
item_prefix = f"{prefix}[{index}]" if prefix else f"item_{index}"
|
||||
chunk = self._json_to_text(item, item_prefix)
|
||||
if chunk:
|
||||
parts.append(chunk)
|
||||
return "\n".join(parts)
|
||||
if isinstance(value, dict):
|
||||
parts = []
|
||||
for key, item in value.items():
|
||||
item_prefix = f"{prefix}.{key}" if prefix else str(key)
|
||||
chunk = self._json_to_text(item, item_prefix)
|
||||
if chunk:
|
||||
parts.append(chunk)
|
||||
return "\n".join(parts)
|
||||
rendered = str(value).strip()
|
||||
if not rendered:
|
||||
return ""
|
||||
return f"{prefix}: {rendered}" if prefix else rendered
|
||||
|
||||
def _build_chunks(self, *, file_name: str, page_texts: list[tuple[int, str]], dataset: dict, process_config: dict, document_id: int) -> list[dict]:
|
||||
rules = ((process_config or {}).get("process_rule") or {}).get("rules") or {}
|
||||
segmentation = rules.get("segmentation") or {}
|
||||
@@ -1510,26 +1506,165 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
||||
embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/embeddings"
|
||||
embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
|
||||
embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
|
||||
batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
|
||||
if not embed_url or not embed_key:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")
|
||||
|
||||
embeddings: list[list[float]] = []
|
||||
async with httpx.AsyncClient(timeout=120.0) as client:
|
||||
response = await client.post(
|
||||
embed_url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {embed_key}",
|
||||
},
|
||||
json={"model": embed_model, "input": texts},
|
||||
)
|
||||
response.raise_for_status()
|
||||
payload = response.json()
|
||||
rows = payload.get("data") or []
|
||||
embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
|
||||
for start in range(0, len(texts), batch_size):
|
||||
batch_texts = texts[start:start + batch_size]
|
||||
try:
|
||||
response = await client.post(
|
||||
embed_url,
|
||||
headers={
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": f"Bearer {embed_key}",
|
||||
},
|
||||
json={"model": embed_model, "input": batch_texts},
|
||||
)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPStatusError as exc:
|
||||
error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
|
||||
raise LeauditException(
|
||||
StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
f"向量化服务调用失败: {error_message[:300]}",
|
||||
) from exc
|
||||
|
||||
payload = response.json()
|
||||
rows = payload.get("data") or []
|
||||
batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")]
|
||||
if len(batch_embeddings) != len(batch_texts):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
|
||||
embeddings.extend(batch_embeddings)
|
||||
|
||||
if len(embeddings) != len(texts):
|
||||
raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")
|
||||
return embeddings
|
||||
|
||||
async def _run_document_indexing_task(
|
||||
self,
|
||||
*,
|
||||
dataset: dict,
|
||||
dataset_id: int,
|
||||
document_id: int,
|
||||
file_name: str,
|
||||
content: bytes,
|
||||
process_config: dict,
|
||||
) -> None:
|
||||
try:
|
||||
await self._update_document_processing_state(document_id=document_id, status="parsing")
|
||||
page_texts = self._extract_page_texts(FileName=file_name, Content=content)
|
||||
|
||||
await self._update_document_processing_state(document_id=document_id, status="cleaning")
|
||||
processed = self._build_chunks(
|
||||
file_name=file_name,
|
||||
page_texts=page_texts,
|
||||
dataset=dataset,
|
||||
process_config=process_config,
|
||||
document_id=document_id,
|
||||
)
|
||||
if not processed:
|
||||
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本")
|
||||
|
||||
await self._update_document_processing_state(
|
||||
document_id=document_id,
|
||||
status="splitting",
|
||||
chunk_count=len(processed),
|
||||
)
|
||||
embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "")
|
||||
|
||||
await self._update_document_processing_state(
|
||||
document_id=document_id,
|
||||
status="indexing",
|
||||
chunk_count=len(processed),
|
||||
)
|
||||
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
|
||||
collection.add(
|
||||
ids=[item["id"] for item in processed],
|
||||
documents=[item["text"] for item in processed],
|
||||
embeddings=embeddings,
|
||||
metadatas=[item["metadata"] for item in processed],
|
||||
)
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE rag_document
|
||||
SET chunk_count = :chunk_count,
|
||||
indexing_status = 'completed',
|
||||
indexing_error = NULL,
|
||||
indexing_started_at = COALESCE(indexing_started_at, NOW()),
|
||||
indexing_completed_at = NOW(),
|
||||
updated_at = NOW()
|
||||
WHERE id = :document_id
|
||||
"""
|
||||
),
|
||||
{"chunk_count": len(processed), "document_id": document_id},
|
||||
)
|
||||
await self._sync_dataset_counts(dataset_id)
|
||||
except Exception as exc:
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE rag_document
|
||||
SET indexing_status = 'error',
|
||||
indexing_error = :error,
|
||||
updated_at = NOW()
|
||||
WHERE id = :document_id
|
||||
"""
|
||||
),
|
||||
{"document_id": document_id, "error": str(exc)[:2000]},
|
||||
)
|
||||
|
||||
async def _update_document_processing_state(self, *, document_id: int, status: str, chunk_count: int | None = None) -> None:
|
||||
fields = [
|
||||
"indexing_status = :status",
|
||||
"indexing_error = NULL",
|
||||
"indexing_started_at = COALESCE(indexing_started_at, NOW())",
|
||||
"updated_at = NOW()",
|
||||
]
|
||||
params: dict[str, object] = {
|
||||
"document_id": document_id,
|
||||
"status": status,
|
||||
}
|
||||
if chunk_count is not None:
|
||||
fields.append("chunk_count = :chunk_count")
|
||||
params["chunk_count"] = chunk_count
|
||||
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
f"""
|
||||
UPDATE rag_document
|
||||
SET {", ".join(fields)}
|
||||
WHERE id = :document_id
|
||||
"""
|
||||
),
|
||||
params,
|
||||
)
|
||||
|
||||
async def _sync_dataset_counts(self, dataset_id: int) -> None:
|
||||
async with GetAsyncSession() as session:
|
||||
await session.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE rag_dataset
|
||||
SET document_count = (
|
||||
SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
|
||||
),
|
||||
total_chunks = COALESCE((
|
||||
SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL
|
||||
), 0),
|
||||
updated_at = NOW()
|
||||
WHERE id = :dataset_id
|
||||
"""
|
||||
),
|
||||
{"dataset_id": dataset_id},
|
||||
)
|
||||
|
||||
def _preprocess_text(self, text_value: str, *, remove_spaces: bool, remove_urls: bool) -> str:
|
||||
result = text_value or ""
|
||||
if remove_urls:
|
||||
|
||||
Reference in New Issue
Block a user