from __future__ import annotations import asyncio import json import mimetypes import os import re import tempfile import uuid from datetime import datetime from pathlib import Path import httpx from sqlalchemy import text from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession from fastapi_common.fastapi_common_storage.oss_client import OssClient from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException from fastapi_modules.fastapi_leaudit.domian.Dto.ragDatasetDto import RagDatasetUpdateDTO from fastapi_modules.fastapi_leaudit.domian.vo.ragDatasetVo import ( RagDatasetBatchDeleteResultVO, RagDatasetDetailVO, RagDatasetDocumentItemVO, RagDatasetDocumentPageVO, RagDatasetItemVO, RagDatasetPageVO, RagDatasetRetrieveDocumentVO, RagDatasetRetrieveRecordVO, RagDatasetRetrieveResponseVO, RagDatasetRetrieveSegmentVO, RagDatasetSegmentItemVO, RagDatasetSegmentPageVO, RagDatasetUploadDocumentVO, ) from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import RagOperationResultVO from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_embeddings_url from fastapi_modules.fastapi_leaudit.services.ragDatasetService import IRagDatasetService from fastapi_modules.fastapi_leaudit.services.impl.tenantResolver import TenantResolver class RagDatasetServiceImpl(IRagDatasetService): _ACTIVE_INDEXING_STATUSES = {"waiting", "parsing", "cleaning", "splitting", "indexing"} _DELETABLE_DOCUMENT_STATUSES = {"completed", "error", "paused"} _APP_LINK_SQL = """ LEFT JOIN ( SELECT DISTINCT ON (dataset_id) dataset_id, id, name, is_default FROM rag_chat_app WHERE deleted_at IS NULL AND status = 1 ORDER BY dataset_id, is_default DESC, sort_order ASC, id ASC ) a ON a.dataset_id = d.id """ _tenant_schema_checked = False _tenant_schema_lock = asyncio.Lock() def __init__(self) -> None: self.TenantResolver = TenantResolver() _DATASET_TENANT_NAME_SQL = ( "CASE " "WHEN NULLIF(BTRIM(d.tenant_code), '') = 'PUBLIC' THEN '公共' " "WHEN NULLIF(BTRIM(d.tenant_code), '') = 'PROVINCIAL' THEN '省级' " "ELSE COALESCE(NULLIF(BTRIM(d.area), ''), '未分配地区') " "END" ) async def GetAdminDatasets( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, Area: str | None, TenantFilterCode: str | None, OnlyEnabled: bool | None, Page: int, PageSize: int, ) -> RagDatasetPageVO: tenant_context = await self._resolve_tenant_context(UserArea=UserArea, TenantCode=TenantCode, TenantName=TenantName) async with GetAsyncSession() as session: await self._ensure_rag_tenant_schema(session) managed_area = await self._resolve_managed_area(UserRole=UserRole, UserArea=UserArea, TenantContext=tenant_context) managed_tenant_code = self._resolve_managed_tenant_code(UserRole=UserRole, TenantContext=tenant_context) filters = ["d.deleted_at IS NULL"] params: dict = { "offset": max(Page - 1, 0) * PageSize, "limit": PageSize, } requested_tenants = await self._normalize_requested_tenants(Area, TenantFilterCode) if managed_tenant_code: if requested_tenants and any(str(item.get("tenant_code") or "").strip() != managed_tenant_code for item in requested_tenants): raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户只能查看本地区知识库配置") filters.extend(self._dataset_tenant_filter_sql(managed_tenant_code, managed_area, prefix="managed", alias="d", params=params)) elif len(requested_tenants) == 1: filters.extend( self._dataset_tenant_filter_sql( requested_tenants[0].get("tenant_code"), requested_tenants[0].get("normalized_area"), prefix="requested", alias="d", params=params, ) ) elif len(requested_tenants) > 1: tenant_conditions: list[str] = [] for index, item in enumerate(requested_tenants): condition_parts = self._dataset_tenant_filter_sql( item.get("tenant_code"), item.get("normalized_area"), prefix=f"requested_{index}", alias="d", params=params, ) tenant_conditions.append("(" + " AND ".join(condition_parts) + ")") filters.append("(" + " OR ".join(tenant_conditions) + ")") if OnlyEnabled is not None: filters.append("d.status = :status") params["status"] = 1 if OnlyEnabled else 0 where_sql = " AND ".join(filters) async with GetAsyncSession() as session: await self._ensure_rag_tenant_schema(session) total = ( await session.execute( text(f"SELECT COUNT(1) FROM rag_dataset d WHERE {where_sql}"), {k: v for k, v in params.items() if k not in ("offset", "limit")}, ) ).scalar_one() rows = ( await session.execute( text( f""" SELECT d.id, d.name, d.description, d.area, COALESCE(NULLIF(BTRIM(d.tenant_code), ''), NULL) AS tenant_code, {self._DATASET_TENANT_NAME_SQL} AS tenant_name, d.is_public, d.is_default, d.document_count, d.total_chunks, d.status, d.sort_order, d.created_at, d.updated_at, a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default FROM rag_dataset d {self._APP_LINK_SQL} WHERE {where_sql} ORDER BY d.sort_order ASC, d.id ASC OFFSET :offset LIMIT :limit """ ), params, ) ).mappings().all() return RagDatasetPageVO(data=[await self._to_item_vo(dict(row)) for row in rows], total=int(total or 0)) async def CreateAdminDataset( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, Body: dict, ) -> RagDatasetDetailVO: tenant_context = await self._resolve_tenant_context(UserArea=UserArea, TenantCode=TenantCode, TenantName=TenantName) area, resolved_tenant_code, tenant_resolution = await self._resolve_dataset_area_input( RawArea=Body.get("area"), TenantCode=Body.get("tenant_code"), TenantName=Body.get("tenant_name"), ) name = str(Body.get("dataset_name") or Body.get("name") or "").strip() description = str(Body.get("dataset_description") or Body.get("description") or "").strip() if not area or not name: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "租户/地区和知识库名称不能为空") await self._assert_manage_area_scope( UserRole=UserRole, UserArea=UserArea, TenantContext=tenant_context, DatasetArea=area, DatasetTenantCode=resolved_tenant_code, ) collection_name = self._slugify_collection_name(area, name) retrieval_model = {} async with GetAsyncSession() as session: await self._ensure_rag_tenant_schema(session) base = ( await session.execute( text( """ SELECT embedding_model, embedding_dim, chunk_max_size, chunk_min_size, retrieval_model FROM rag_dataset WHERE deleted_at IS NULL ORDER BY is_default DESC, id ASC LIMIT 1 """ ) ) ).mappings().first() if base: retrieval_model = base.get("retrieval_model") or {} if bool(Body.get("is_default")): await self._clear_default_flags(session, tenant_code=resolved_tenant_code) row = ( await session.execute( text( """ INSERT INTO rag_dataset ( name, description, area, tenant_code, is_public, is_default, collection_name, embedding_model, embedding_dim, chunk_max_size, chunk_min_size, retrieval_model, sort_order, status, created_by, updated_by ) VALUES ( :name, :description, :area, :tenant_code, :is_public, :is_default, :collection_name, :embedding_model, :embedding_dim, :chunk_max_size, :chunk_min_size, CAST(:retrieval_model AS jsonb), :sort_order, :status, :created_by, :updated_by ) RETURNING id """ ), { "name": name, "description": description, "area": area, "tenant_code": resolved_tenant_code, "is_public": bool(Body.get("is_public")), "is_default": bool(Body.get("is_default")), "collection_name": collection_name, "embedding_model": (base.get("embedding_model") if base else "text-embedding-v4"), "embedding_dim": (base.get("embedding_dim") if base else 1024), "chunk_max_size": (base.get("chunk_max_size") if base else 800), "chunk_min_size": (base.get("chunk_min_size") if base else 20), "retrieval_model": json.dumps(retrieval_model, ensure_ascii=False), "sort_order": int(Body.get("sort_order") or 0), "status": int(Body.get("status") or 1), "created_by": CurrentUserId, "updated_by": CurrentUserId, }, ) ).mappings().first() dataset_id = int(row["id"]) await self._ensure_linked_app( session=session, dataset_id=dataset_id, dataset_name=name, dataset_area=area, dataset_tenant_code=resolved_tenant_code, current_user_id=CurrentUserId, is_default=bool(Body.get("is_default")), ) refreshed = await self._get_dataset_row(dataset_id) return await self._to_detail_vo(refreshed) if refreshed else None async def UpdateAdminDataset( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, Body: dict, ) -> RagDatasetDetailVO | None: tenant_context = await self._resolve_tenant_context(UserArea=UserArea, TenantCode=TenantCode, TenantName=TenantName) existing = await self._get_dataset_row(DatasetId) if not existing: return None await self._assert_manage_area_scope( UserRole=UserRole, UserArea=UserArea, TenantContext=tenant_context, DatasetArea=str(existing.get("area") or ""), DatasetTenantCode=str(existing.get("tenant_code") or "") or None, ) area, resolved_tenant_code, _ = await self._resolve_dataset_area_input( RawArea=Body.get("area") or existing.get("area"), TenantCode=Body.get("tenant_code") or existing.get("tenant_code"), TenantName=Body.get("tenant_name"), ) await self._assert_manage_area_scope( UserRole=UserRole, UserArea=UserArea, TenantContext=tenant_context, DatasetArea=area, DatasetTenantCode=resolved_tenant_code, ) async with GetAsyncSession() as session: await self._ensure_rag_tenant_schema(session) target_is_default = bool(Body.get("is_default", existing.get("is_default"))) if target_is_default: await self._clear_default_flags(session, tenant_code=resolved_tenant_code) elif existing.get("is_default") and Body.get("is_default") is False: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不能直接取消,请先将其他知识库设为默认") await session.execute( text( """ UPDATE rag_dataset SET name = :name, description = :description, area = :area, tenant_code = :tenant_code, is_public = :is_public, is_default = :is_default, sort_order = :sort_order, status = :status, updated_by = :updated_by, updated_at = NOW() WHERE id = :dataset_id """ ), { "dataset_id": DatasetId, "name": str(Body.get("dataset_name") or Body.get("name") or existing.get("name") or "").strip(), "description": str(Body.get("dataset_description") or Body.get("description") or existing.get("description") or "").strip(), "area": area, "tenant_code": resolved_tenant_code, "is_public": bool(Body.get("is_public", existing.get("is_public"))), "is_default": target_is_default, "sort_order": int(Body.get("sort_order") if Body.get("sort_order") is not None else (existing.get("sort_order") or 0)), "status": int(Body.get("status") if Body.get("status") is not None else (existing.get("status") or 1)), "updated_by": CurrentUserId, }, ) await self._ensure_linked_app( session=session, dataset_id=DatasetId, dataset_name=str(Body.get("dataset_name") or Body.get("name") or existing.get("name") or "").strip(), dataset_area=area, dataset_tenant_code=resolved_tenant_code, current_user_id=CurrentUserId, is_default=target_is_default, ) refreshed = await self._get_dataset_row(DatasetId) return await self._to_detail_vo(refreshed) if refreshed else None async def DeleteAdminDataset( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, ) -> RagOperationResultVO: tenant_context = await self._resolve_tenant_context(UserArea=UserArea, TenantCode=TenantCode, TenantName=TenantName) existing = await self._get_dataset_row(DatasetId) if not existing: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") await self._assert_manage_area_scope( UserRole=UserRole, UserArea=UserArea, TenantContext=tenant_context, DatasetArea=str(existing.get("area") or ""), DatasetTenantCode=str(existing.get("tenant_code") or "") or None, ) if bool(existing.get("is_default")): raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "默认知识库不允许删除,请先切换默认知识库") async with GetAsyncSession() as session: await self._ensure_rag_tenant_schema(session) await session.execute( text("UPDATE rag_dataset SET deleted_at = NOW(), updated_by = :updated_by, updated_at = NOW() WHERE id = :dataset_id"), {"dataset_id": DatasetId, "updated_by": CurrentUserId}, ) await session.execute( text("UPDATE rag_chat_app SET deleted_at = NOW(), updated_by = :updated_by, updated_at = NOW() WHERE dataset_id = :dataset_id AND deleted_at IS NULL"), {"dataset_id": DatasetId, "updated_by": CurrentUserId}, ) return RagOperationResultVO(result="success") async def GetMyDatasets( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, ) -> RagDatasetPageVO: tenant_context = await self._resolve_tenant_context(UserArea=UserArea, TenantCode=TenantCode, TenantName=TenantName) async with GetAsyncSession() as session: await self._ensure_rag_tenant_schema(session) rows = ( await session.execute( text( f""" SELECT d.id, d.name, d.description, d.area, COALESCE(NULLIF(BTRIM(d.tenant_code), ''), NULL) AS tenant_code, {self._DATASET_TENANT_NAME_SQL} AS tenant_name, d.is_public, d.is_default, d.document_count, d.total_chunks, d.status, d.sort_order, d.created_at, d.updated_at, a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default FROM rag_dataset d {self._APP_LINK_SQL} WHERE d.deleted_at IS NULL AND d.status = 1 AND d.status = 1 ORDER BY d.sort_order ASC, d.created_at DESC """ ), {}, ) ).mappings().all() return RagDatasetPageVO( data=[item for item in [await self._build_visible_item_vo(dict(row), tenant_context, UserRole) for row in rows] if item], total=len([item for item in [await self._build_visible_item_vo(dict(row), tenant_context, UserRole) for row in rows] if item]), ) async def GetDatasetDetail( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, ) -> RagDatasetDetailVO | None: row = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) if not row: return None return await self._to_detail_vo(row) async def UpdateDataset( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, Body: RagDatasetUpdateDTO, ) -> RagDatasetDetailVO | None: row = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) if not row: return None update_fields: list[str] = [] params: dict = {"dataset_id": DatasetId, "updated_by": CurrentUserId} if Body.name is not None: update_fields.append("name = :name") params["name"] = Body.name.strip() if Body.retrieval_model is not None: update_fields.append("retrieval_model = CAST(:retrieval_model AS jsonb)") params["retrieval_model"] = json.dumps(Body.retrieval_model, ensure_ascii=False) if not update_fields: return self._to_detail_vo(row) update_fields.append("updated_by = :updated_by") update_fields.append("updated_at = NOW()") async with GetAsyncSession() as session: await session.execute( text( f""" UPDATE rag_dataset SET {", ".join(update_fields)} WHERE id = :dataset_id """ ), params, ) refreshed = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) return await self._to_detail_vo(refreshed) if refreshed else None async def GetDatasetDocuments( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, Page: int, Limit: int, Keyword: str | None, ) -> RagDatasetDocumentPageVO: dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) if not dataset: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") where_sql = [ "dataset_id = :dataset_id", "deleted_at IS NULL", ] params: dict = { "dataset_id": DatasetId, "offset": max(Page - 1, 0) * Limit, "limit": Limit + 1, } if Keyword and Keyword.strip(): where_sql.append("(original_name ILIKE :keyword OR filename ILIKE :keyword)") params["keyword"] = f"%{Keyword.strip()}%" async with GetAsyncSession() as session: total = ( await session.execute( text( f""" SELECT COUNT(1) FROM rag_document WHERE {" AND ".join(where_sql)} """ ), {key: value for key, value in params.items() if key not in ("offset", "limit")}, ) ).scalar_one() rows = ( await session.execute( text( f""" SELECT id, dataset_id, original_name, file_type, file_size, chunk_count, indexing_status, COALESCE(indexing_error, '') AS indexing_error, enabled, hit_count, created_by, created_at, updated_at FROM rag_document WHERE {" AND ".join(where_sql)} ORDER BY created_at DESC OFFSET :offset LIMIT :limit """ ), params, ) ).mappings().all() has_more = len(rows) > Limit items = rows[:Limit] return RagDatasetDocumentPageVO( data=[ RagDatasetDocumentItemVO( id=row["id"], datasetId=row["dataset_id"], name=row.get("original_name") or "", fileType=row.get("file_type") or "", fileSize=row.get("file_size") or 0, chunkCount=row.get("chunk_count") or 0, indexingStatus=row.get("indexing_status") or "waiting", error=row.get("indexing_error") or "", enabled=bool(row.get("enabled")), hitCount=row.get("hit_count") or 0, createdBy=row.get("created_by"), createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, ) for row in items ], total=int(total or 0), page=Page, limit=Limit, hasMore=has_more, ) async def GetDatasetDocumentDetail( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, DocumentId: int, ) -> RagDatasetDocumentItemVO | None: dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) if not dataset: return None async with GetAsyncSession() as session: row = ( await session.execute( text( """ SELECT id, dataset_id, original_name, file_type, file_size, chunk_count, indexing_status, COALESCE(indexing_error, '') AS indexing_error, enabled, hit_count, created_by, created_at, updated_at FROM rag_document WHERE id = :document_id AND dataset_id = :dataset_id AND deleted_at IS NULL LIMIT 1 """ ), {"document_id": DocumentId, "dataset_id": DatasetId}, ) ).mappings().first() if not row: return None return RagDatasetDocumentItemVO( id=row["id"], datasetId=row["dataset_id"], name=row.get("original_name") or "", fileType=row.get("file_type") or "", fileSize=row.get("file_size") or 0, chunkCount=row.get("chunk_count") or 0, indexingStatus=row.get("indexing_status") or "waiting", error=row.get("indexing_error") or "", enabled=bool(row.get("enabled")), hitCount=row.get("hit_count") or 0, createdBy=row.get("created_by"), createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, ) async def _get_visible_dataset( self, user_area: str | None, user_role: str | None, tenant_code: str | None, tenant_name: str | None, dataset_id: int, ) -> dict | None: async with GetAsyncSession() as session: await self._ensure_rag_tenant_schema(session) row = ( await session.execute( text( f""" SELECT d.id, d.name, d.description, d.area, COALESCE(NULLIF(BTRIM(d.tenant_code), ''), NULL) AS tenant_code, {self._DATASET_TENANT_NAME_SQL} AS tenant_name, d.is_public, d.is_default, d.status, d.document_count, d.total_chunks, d.chunk_max_size, d.chunk_min_size, d.sort_order, d.retrieval_model, d.collection_name, d.embedding_model, d.created_at, d.updated_at, a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default FROM rag_dataset d {self._APP_LINK_SQL} WHERE d.id = :dataset_id AND d.deleted_at IS NULL AND d.status = 1 LIMIT 1 """ ), {"dataset_id": dataset_id}, ) ).mappings().first() if not row: return None if await self._dataset_visible(dict(row), UserArea=user_area, UserRole=user_role, TenantCode=tenant_code, TenantName=tenant_name): return dict(row) return None async def _to_detail_vo(self, row: dict) -> RagDatasetDetailVO: return RagDatasetDetailVO( id=row["id"], name=row["name"], description=row.get("description") or "", area=row.get("area") or "", tenantCode=str(row.get("tenant_code") or ""), tenantName=str(row.get("tenant_name") or row.get("area") or ""), isPublic=bool(row.get("is_public")), isDefault=bool(row.get("is_default")), status=row.get("status") or 1, documentCount=row.get("document_count") or 0, sortOrder=row.get("sort_order") or 0, totalChunks=row.get("total_chunks") or 0, chunkMaxSize=row.get("chunk_max_size") or 800, chunkMinSize=row.get("chunk_min_size") or 20, retrievalModel=row.get("retrieval_model") or {}, createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, appId=row.get("app_id"), appName=row.get("app_name") or "", appIsDefault=bool(row.get("app_is_default")), ) async def _to_item_vo(self, row: dict) -> RagDatasetItemVO: return RagDatasetItemVO( id=row["id"], name=row.get("name") or "", description=row.get("description") or "", area=row.get("area") or "", tenantCode=str(row.get("tenant_code") or ""), tenantName=str(row.get("tenant_name") or row.get("area") or ""), isPublic=bool(row.get("is_public")), isDefault=bool(row.get("is_default")), documentCount=row.get("document_count") or 0, totalChunks=row.get("total_chunks") or 0, status=row.get("status") or 1, sortOrder=row.get("sort_order") or 0, createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0, updatedAt=int(row["updated_at"].timestamp()) if row.get("updated_at") else 0, appId=row.get("app_id"), appName=row.get("app_name") or "", appIsDefault=bool(row.get("app_is_default")), ) async def _get_dataset_row(self, dataset_id: int) -> dict | None: async with GetAsyncSession() as session: await self._ensure_rag_tenant_schema(session) row = ( await session.execute( text( f""" SELECT d.id, d.name, d.description, d.area, COALESCE(NULLIF(BTRIM(d.tenant_code), ''), NULL) AS tenant_code, {self._DATASET_TENANT_NAME_SQL} AS tenant_name, d.is_public, d.is_default, d.status, d.document_count, d.total_chunks, d.chunk_max_size, d.chunk_min_size, d.sort_order, d.retrieval_model, d.collection_name, d.embedding_model, d.created_at, d.updated_at, a.id AS app_id, COALESCE(a.name, '') AS app_name, COALESCE(a.is_default, FALSE) AS app_is_default FROM rag_dataset d {self._APP_LINK_SQL} WHERE d.id = :dataset_id AND d.deleted_at IS NULL LIMIT 1 """ ), {"dataset_id": dataset_id}, ) ).mappings().first() return dict(row) if row else None async def _resolve_tenant_context( self, *, UserArea: str | None, TenantCode: str | None, TenantName: str | None, ) -> dict[str, str | None]: resolved = await self.TenantResolver.ResolveUserContext( Area=UserArea, TenantCode=TenantCode, TenantName=TenantName, Source="rag_dataset_user", ) return { "tenant_code": resolved.tenant_code, "tenant_name": resolved.tenant_name, "tenant_type": resolved.tenant_type, "area": resolved.normalized_value or UserArea, } async def _resolve_record_tenant(self, raw_value: str | None): return await self.TenantResolver.Resolve( RawValue=raw_value, Source="rag_dataset_record", ) async def _dataset_visible( self, row: dict, *, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, ) -> bool: if self._role_is_global(UserRole): return True if bool(row.get("is_public")): return True tenant_context = await self._resolve_tenant_context(UserArea=UserArea, TenantCode=TenantCode, TenantName=TenantName) return self._row_matches_tenant_scope( row_tenant_code=row.get("tenant_code"), row_area=row.get("area"), tenant_context=tenant_context, ) async def _build_visible_item_vo( self, row: dict, tenant_context: dict[str, str | None], user_role: str | None, ) -> RagDatasetItemVO | None: if not await self._dataset_visible( row, UserArea=tenant_context.get("area"), UserRole=user_role, TenantCode=tenant_context.get("tenant_code"), TenantName=tenant_context.get("tenant_name"), ): return None return await self._to_item_vo(row) async def _normalize_requested_tenants( self, area_filter: str | None, tenant_code_filter: str | None = None, ) -> list[dict[str, str]]: area_items = [item.strip() for item in str(area_filter or "").split(",") if item.strip()] code_items = [item.strip() for item in str(tenant_code_filter or "").split(",") if item.strip()] normalized: list[dict[str, str]] = [] seen_keys: set[str] = set() for item in area_items: resolution = await self.TenantResolver.Resolve( RawValue=item, Source="rag_dataset_filter", ) normalized_area = resolution.tenant_name or resolution.normalized_value or item tenant_code = resolution.tenant_code or "" dedupe_key = tenant_code or normalized_area if dedupe_key in seen_keys: continue seen_keys.add(dedupe_key) normalized.append( { "tenant_code": tenant_code, "tenant_name": resolution.tenant_name or normalized_area, "normalized_area": normalized_area, } ) for tenant_code in code_items: resolution = await self.TenantResolver.Resolve( RawValue=None, Source="rag_dataset_filter_code", PreferredTenantCode=tenant_code, ) normalized_area = ( resolution.tenant_name or self._fallback_area_from_tenant_code(tenant_code) or resolution.normalized_value or tenant_code ) resolved_code = resolution.tenant_code or tenant_code dedupe_key = resolved_code or normalized_area if dedupe_key in seen_keys: continue seen_keys.add(dedupe_key) normalized.append( { "tenant_code": resolved_code, "tenant_name": resolution.tenant_name or normalized_area, "normalized_area": normalized_area, } ) return normalized async def _resolve_dataset_area_input( self, *, RawArea: str | None, TenantCode: str | None, TenantName: str | None, ) -> tuple[str, str | None, object]: resolution = await self.TenantResolver.Resolve( RawValue=RawArea, Source="rag_dataset_input", PreferredTenantCode=str(TenantCode or "").strip() or None, FallbackTenantName=TenantName, ) normalized_area = resolution.tenant_name or resolution.normalized_value or str(RawArea or "").strip() return normalized_area, resolution.tenant_code, resolution async def _clear_default_flags(self, session, tenant_code: str | None = None) -> None: await self._ensure_rag_tenant_schema(session) normalized_tenant_code = str(tenant_code or "").strip() if normalized_tenant_code: await session.execute( text( """ UPDATE rag_dataset SET is_default = FALSE WHERE deleted_at IS NULL AND tenant_code = :tenant_code """ ), {"tenant_code": normalized_tenant_code}, ) await session.execute( text( """ UPDATE rag_chat_app SET is_default = FALSE WHERE deleted_at IS NULL AND tenant_code = :tenant_code """ ), {"tenant_code": normalized_tenant_code}, ) return await session.execute( text( """ UPDATE rag_dataset SET is_default = FALSE WHERE deleted_at IS NULL AND (tenant_code IS NULL OR BTRIM(tenant_code) = '') """ ) ) await session.execute( text( """ UPDATE rag_chat_app SET is_default = FALSE WHERE deleted_at IS NULL AND (tenant_code IS NULL OR BTRIM(tenant_code) = '') """ ) ) async def _ensure_linked_app( self, session, dataset_id: int, dataset_name: str, dataset_area: str, dataset_tenant_code: str | None, current_user_id: int, is_default: bool, ) -> None: await self._ensure_rag_tenant_schema(session) app_row = ( await session.execute( text( """ SELECT id FROM rag_chat_app WHERE dataset_id = :dataset_id AND deleted_at IS NULL ORDER BY is_default DESC, sort_order ASC, id ASC LIMIT 1 """ ), {"dataset_id": dataset_id}, ) ).mappings().first() app_name = self._build_app_name(dataset_area=dataset_area, dataset_name=dataset_name) if app_row: await session.execute( text( """ UPDATE rag_chat_app SET name = :name, description = :description, area = :area, tenant_code = :tenant_code, is_default = :is_default, status = 1, updated_by = :updated_by, updated_at = NOW() WHERE id = :app_id """ ), { "app_id": int(app_row["id"]), "name": app_name, "description": f"{dataset_area or '默认地区'}知识库问答助手", "area": dataset_area or "", "tenant_code": dataset_tenant_code, "is_default": is_default, "updated_by": current_user_id, }, ) return await session.execute( text( """ INSERT INTO rag_chat_app ( name, description, area, tenant_code, dataset_id, suggested_questions, opening_statement, sort_order, status, is_default, created_by, updated_by ) VALUES ( :name, :description, :area, :tenant_code, :dataset_id, CAST(:suggested_questions AS jsonb), :opening_statement, 0, 1, :is_default, :created_by, :updated_by ) """ ), { "name": app_name, "description": f"{dataset_area or '默认地区'}知识库问答助手", "area": dataset_area or "", "tenant_code": dataset_tenant_code, "dataset_id": dataset_id, "suggested_questions": json.dumps([], ensure_ascii=False), "opening_statement": f"您好,我是{app_name}。", "is_default": is_default, "created_by": current_user_id, "updated_by": current_user_id, }, ) def _build_app_name(self, dataset_area: str, dataset_name: str) -> str: cleaned = (dataset_name or "").strip() if cleaned.endswith("知识库"): cleaned = cleaned[:-3] if cleaned.endswith("助手"): return cleaned if cleaned: return f"{cleaned}助手" return f"{dataset_area or '默认地区'}法务助手" def _slugify_collection_name(self, area: str, name: str) -> str: source = f"{area}_{name}".lower() normalized = re.sub(r"[^a-z0-9]+", "_", source).strip("_") if normalized: return f"legal_kb_{normalized}"[:96] return f"legal_kb_{uuid.uuid4().hex[:12]}" @staticmethod def _fallback_area_from_tenant_code(tenant_code: str | None) -> str | None: normalized = str(tenant_code or "").strip().upper() if normalized == "PUBLIC": return "公共" if normalized == "PROVINCIAL": return "省级" return None def _resolve_managed_tenant_code(self, UserRole: str | None, TenantContext: dict[str, str | None]) -> str | None: if self._role_is_global(UserRole): return None tenant_code = str(TenantContext.get("tenant_code") or "").strip() return tenant_code or None async def _resolve_managed_area(self, UserRole: str | None, UserArea: str | None, TenantContext: dict[str, str | None]) -> str | None: if self._role_is_global(UserRole): return None area = str(TenantContext.get("area") or UserArea or "").strip() if not area: raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户未配置租户/地区,无法管理知识库") return area async def _assert_manage_area_scope( self, UserRole: str | None, UserArea: str | None, TenantContext: dict[str, str | None], DatasetArea: str, DatasetTenantCode: str | None = None, ) -> None: if self._role_is_global(UserRole): return managed_tenant_code = self._resolve_managed_tenant_code(UserRole=UserRole, TenantContext=TenantContext) if managed_tenant_code: if str(DatasetTenantCode or "").strip() != managed_tenant_code: raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户只能管理本租户知识库") return managed_area = await self._resolve_managed_area(UserRole=UserRole, UserArea=UserArea, TenantContext=TenantContext) if DatasetArea != managed_area: raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户只能管理本地区知识库") async def _ensure_rag_tenant_schema(self, session) -> None: if self.__class__._tenant_schema_checked: return async with self.__class__._tenant_schema_lock: if self.__class__._tenant_schema_checked: return columns = ( await session.execute( text( """ SELECT table_name FROM information_schema.columns WHERE table_schema = current_schema() AND table_name IN ('rag_dataset', 'rag_chat_app') AND column_name = 'tenant_code' """ ) ) ).scalars().all() existing = set(columns) if existing == {"rag_dataset", "rag_chat_app"}: self.__class__._tenant_schema_checked = True return await session.execute(text("SET LOCAL lock_timeout = '1000ms'")) if "rag_dataset" not in existing: await session.execute(text("ALTER TABLE rag_dataset ADD COLUMN tenant_code VARCHAR(64) NULL")) if "rag_chat_app" not in existing: await session.execute(text("ALTER TABLE rag_chat_app ADD COLUMN tenant_code VARCHAR(64) NULL")) await session.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_dataset_tenant_code ON rag_dataset(tenant_code) WHERE deleted_at IS NULL")) await session.execute(text("CREATE INDEX IF NOT EXISTS idx_rag_chat_app_tenant_code ON rag_chat_app(tenant_code) WHERE deleted_at IS NULL")) self.__class__._tenant_schema_checked = True def _dataset_tenant_filter_sql( self, tenant_code: str | None, area: str | None, *, prefix: str, alias: str, params: dict, ) -> list[str]: normalized_tenant_code = str(tenant_code or "").strip() normalized_area = str(area or "").strip() if normalized_tenant_code: params[f"{prefix}_tenant_code"] = normalized_tenant_code return [f"{alias}.tenant_code = :{prefix}_tenant_code"] if normalized_area: params[f"{prefix}_area"] = normalized_area return [f"COALESCE({alias}.area, '') = :{prefix}_area"] return ["1 = 0"] @staticmethod def _tenant_context_is_global(tenant_context: dict[str, str | None]) -> bool: tenant_code = str(tenant_context.get("tenant_code") or "").strip().upper() return tenant_code in {"PUBLIC", "PROVINCIAL"} @staticmethod def _role_is_global(user_role: str | None) -> bool: normalized = str(user_role or "").strip() return normalized in {"super_admin", "provincial_admin"} def _row_matches_tenant_scope( self, *, row_tenant_code: str | None, row_area: str | None, tenant_context: dict[str, str | None], ) -> bool: user_tenant_code = str(tenant_context.get("tenant_code") or "").strip() if user_tenant_code: return str(row_tenant_code or "").strip() == user_tenant_code return str(row_area or "").strip() == str(tenant_context.get("area") or "").strip() async def UploadDatasetDocument( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, FileName: str, ContentType: str | None, Content: bytes, ProcessConfig: dict | None, ) -> RagDatasetUploadDocumentVO: dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) if not dataset: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") suffix = os.path.splitext(FileName)[1].lower() if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON") object_key = f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}" content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream" OssClient().EnsureBucket() stored_key = OssClient().UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type) async with GetAsyncSession() as session: inserted = ( await session.execute( text( """ INSERT INTO rag_document ( dataset_id, filename, original_name, minio_path, file_type, file_size, chunk_count, indexing_status, enabled, hit_count, created_by ) VALUES ( :dataset_id, :filename, :original_name, :minio_path, :file_type, :file_size, 0, 'indexing', TRUE, 0, :created_by ) RETURNING id, created_at, updated_at """ ), { "dataset_id": DatasetId, "filename": uuid.uuid4().hex + suffix, "original_name": FileName, "minio_path": stored_key, "file_type": suffix.lstrip("."), "file_size": len(Content), "created_by": CurrentUserId, }, ) ).mappings().first() document_id = int(inserted["id"]) asyncio.create_task( self._run_document_indexing_task( dataset=dataset, dataset_id=DatasetId, document_id=document_id, file_name=FileName, content=Content, process_config=ProcessConfig or {}, ) ) return RagDatasetUploadDocumentVO( document={ "id": str(document_id), "name": FileName, "indexing_status": "indexing", "word_count": 0, "hit_count": 0, "enabled": True, }, batch=str(document_id), ) async def GetDatasetDocumentSegments( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, DocumentId: int, Page: int, Limit: int, Keyword: str | None, ) -> RagDatasetSegmentPageVO: dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) if not dataset: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") collection = get_chroma().get_or_create_collection(dataset["collection_name"]) raw = collection.get(where={"document_id": DocumentId}, include=["documents", "metadatas"]) ids = raw.get("ids") or [] docs = raw.get("documents") or [] metas = raw.get("metadatas") or [] items: list[dict] = [] for index, segment_id in enumerate(ids): content = docs[index] if index < len(docs) else "" meta = metas[index] if index < len(metas) and isinstance(metas[index], dict) else {} if Keyword and Keyword.strip() and Keyword.strip() not in content: continue items.append( { "id": str(segment_id), "position": index + 1, "document_id": str(DocumentId), "content": content or "", "word_count": len(content or ""), "hit_count": 0, "enabled": True, "status": "completed", "created_at": 0, } ) offset = max(Page - 1, 0) * Limit page_items = items[offset: offset + Limit] has_more = offset + Limit < len(items) return RagDatasetSegmentPageVO( data=[ RagDatasetSegmentItemVO( id=item["id"], position=item["position"], documentId=item["document_id"], content=item["content"], wordCount=item["word_count"], hitCount=item["hit_count"], enabled=item["enabled"], status=item["status"], createdAt=item["created_at"], ) for item in page_items ], total=len(items), limit=Limit, hasMore=has_more, ) async def GetDatasetDocumentSegmentDetail( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, DocumentId: int, SegmentId: str, ) -> RagDatasetSegmentItemVO | None: dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) if not dataset: return None collection = get_chroma().get_or_create_collection(dataset["collection_name"]) raw = collection.get(ids=[SegmentId], include=["documents", "metadatas"]) ids = raw.get("ids") or [] if not ids: return None content = (raw.get("documents") or [""])[0] or "" metadata = (raw.get("metadatas") or [{}])[0] or {} if str(metadata.get("document_id") or "") != str(DocumentId): return None return RagDatasetSegmentItemVO( id=str(SegmentId), position=int(metadata.get("chunk_index") or 0) + 1, documentId=str(DocumentId), content=content, wordCount=len(content), hitCount=0, enabled=True, status="completed", createdAt=0, ) async def UpdateDatasetDocumentSegment( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, DocumentId: int, SegmentId: str, Body: dict, ) -> RagDatasetSegmentItemVO | None: dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) if not dataset: return None current = await self.GetDatasetDocumentSegmentDetail( CurrentUserId=CurrentUserId, UserArea=UserArea, UserRole=UserRole, TenantCode=TenantCode, TenantName=TenantName, DatasetId=DatasetId, DocumentId=DocumentId, SegmentId=SegmentId, ) if not current: return None segment_body = Body.get("segment") if isinstance(Body.get("segment"), dict) else Body content = str(segment_body.get("content") or current.content).strip() if not content: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "分段内容不能为空") collection = get_chroma().get_or_create_collection(dataset["collection_name"]) raw = collection.get(ids=[SegmentId], include=["metadatas"]) metadata = (raw.get("metadatas") or [{}])[0] or {} embeddings = await self._embed_texts([content], dataset.get("embedding_model") or "") collection.update( ids=[SegmentId], documents=[content], embeddings=embeddings, metadatas=[metadata], ) return RagDatasetSegmentItemVO( id=str(SegmentId), position=int(metadata.get("chunk_index") or 0) + 1, documentId=str(DocumentId), content=content, wordCount=len(content), hitCount=0, enabled=True, status="completed", createdAt=0, ) async def DeleteDatasetDocumentSegment( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, DocumentId: int, SegmentId: str, ) -> RagDatasetBatchDeleteResultVO: dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) if not dataset: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") current = await self.GetDatasetDocumentSegmentDetail( CurrentUserId=CurrentUserId, UserArea=UserArea, UserRole=UserRole, TenantCode=TenantCode, TenantName=TenantName, DatasetId=DatasetId, DocumentId=DocumentId, SegmentId=SegmentId, ) if not current: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "分段不存在") collection = get_chroma().get_or_create_collection(dataset["collection_name"]) collection.delete(ids=[SegmentId]) async with GetAsyncSession() as session: await session.execute( text( """ UPDATE rag_document SET chunk_count = GREATEST(chunk_count - 1, 0), updated_at = NOW() WHERE id = :document_id """ ), {"document_id": DocumentId}, ) await session.execute( text( """ UPDATE rag_dataset SET total_chunks = GREATEST(total_chunks - 1, 0), updated_at = NOW() WHERE id = :dataset_id """ ), {"dataset_id": DatasetId}, ) return RagOperationResultVO(result="success") async def DeleteDatasetDocument( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, DocumentId: int, ) -> RagOperationResultVO: result = await self.BatchDeleteDatasetDocuments( CurrentUserId=CurrentUserId, UserArea=UserArea, UserRole=UserRole, TenantCode=TenantCode, TenantName=TenantName, DatasetId=DatasetId, DocumentIds=[DocumentId], ) if result.deletedCount <= 0: reason = result.skipped[0].reason if result.skipped else "文档删除失败" raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, reason) return RagOperationResultVO(result="success") async def BatchDeleteDatasetDocuments( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, DocumentIds: list[int], ) -> RagOperationResultVO: dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) if not dataset: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") normalized_ids = [int(item) for item in DocumentIds if item] if not normalized_ids: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "缺少待删除文档") async with GetAsyncSession() as session: document_rows = ( await session.execute( text( """ SELECT id, dataset_id, minio_path, original_name, indexing_status FROM rag_document WHERE id = ANY(:document_ids) AND dataset_id = :dataset_id AND deleted_at IS NULL """ ), {"document_ids": normalized_ids, "dataset_id": DatasetId}, ) ).mappings().all() if not document_rows: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在") deletable_rows = [] skipped = [] for row in document_rows: if row.get("indexing_status") not in self._DELETABLE_DOCUMENT_STATUSES: skipped.append( { "id": int(row["id"]), "name": str(row.get("original_name") or ""), "reason": "文档仍在处理中,暂不允许删除", } ) continue deletable_rows.append(row) if not deletable_rows: return RagDatasetBatchDeleteResultVO( result="success", requestedCount=len(normalized_ids), deletedCount=0, skippedCount=len(skipped), deletedIds=[], skipped=skipped, ) collection = get_chroma().get_or_create_collection(dataset["collection_name"]) chunk_ids: list[str] = [] for document_row in deletable_rows: raw = collection.get(where={"document_id": int(document_row["id"])}, include=[]) ids = raw.get("ids") or [] if ids: chunk_ids.extend(ids) if chunk_ids: collection.delete(ids=chunk_ids) for document_row in deletable_rows: self._delete_oss_object(document_row.get("minio_path")) existing_ids = [int(row["id"]) for row in deletable_rows] async with GetAsyncSession() as session: await session.execute( text( """ UPDATE rag_document SET deleted_at = NOW(), updated_at = NOW() WHERE id = ANY(:document_ids) """ ), {"document_ids": existing_ids}, ) await session.execute( text( """ UPDATE rag_dataset SET document_count = ( SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL ), total_chunks = COALESCE(( SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL ), 0), updated_at = NOW() WHERE id = :dataset_id """ ), {"dataset_id": DatasetId}, ) return RagDatasetBatchDeleteResultVO( result="success", requestedCount=len(normalized_ids), deletedCount=len(existing_ids), skippedCount=len(skipped), deletedIds=existing_ids, skipped=skipped, ) async def RetrieveDataset( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, Query: str, RetrievalModel: dict | None, ) -> RagDatasetRetrieveResponseVO: dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) if not dataset: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") query_text = (Query or "").strip() if not query_text: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "检索内容不能为空") retrieval_model = RetrievalModel or {} top_k = int(retrieval_model.get("top_k") or 5) top_k = max(1, min(top_k, 20)) score_threshold_enabled = bool(retrieval_model.get("score_threshold_enabled")) score_threshold = float(retrieval_model.get("score_threshold") or 0) if score_threshold_enabled else None query_embedding = await self._embed_texts([query_text], dataset.get("embedding_model") or "") collection = get_chroma().get_or_create_collection(dataset["collection_name"]) raw = collection.query( query_embeddings=query_embedding, n_results=top_k, include=["documents", "metadatas", "distances"], ) ids = (raw.get("ids") or [[]])[0] if raw.get("ids") else [] documents = (raw.get("documents") or [[]])[0] if raw.get("documents") else [] metadatas = (raw.get("metadatas") or [[]])[0] if raw.get("metadatas") else [] distances = (raw.get("distances") or [[]])[0] if raw.get("distances") else [] records: list[RagDatasetRetrieveRecordVO] = [] for index, segment_id in enumerate(ids): content = documents[index] if index < len(documents) else "" metadata = metadatas[index] if index < len(metadatas) and isinstance(metadatas[index], dict) else {} distance = float(distances[index]) if index < len(distances) and distances[index] is not None else 1.0 score = max(0.0, min(1.0, 1.0 / (1.0 + max(0.0, distance)))) if score_threshold_enabled and score_threshold is not None and score < score_threshold: continue document_name = metadata.get("document_name") or metadata.get("source") or "" document_id = str(metadata.get("document_id") or "") chunk_index = int(metadata.get("chunk_index") or index) records.append( RagDatasetRetrieveRecordVO( score=round(score, 6), segment=RagDatasetRetrieveSegmentVO( id=str(segment_id), position=chunk_index + 1, documentId=document_id, content=content or "", answer="", wordCount=len(content or ""), hitCount=0, enabled=True, status="completed", createdAt=0, document=RagDatasetRetrieveDocumentVO( id=document_id, dataSourceType="upload_file", name=document_name, docType=None, ), ), ) ) return RagDatasetRetrieveResponseVO( query={"content": query_text}, records=records, ) async def GetDatasetDocumentIndexingStatus( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, DocumentId: int, ) -> dict: document = await self.GetDatasetDocumentDetail( CurrentUserId=CurrentUserId, UserArea=UserArea, UserRole=UserRole, TenantCode=TenantCode, TenantName=TenantName, DatasetId=DatasetId, DocumentId=DocumentId, ) if not document: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在") completed_segments = document.chunkCount if document.indexingStatus == "completed" else 0 total_segments = document.chunkCount if document.chunkCount > 0 else 0 return { "data": [ { "id": str(document.id), "indexing_status": document.indexingStatus, "processing_started_at": document.updatedAt or document.createdAt or None, "parsing_completed_at": document.updatedAt if document.indexingStatus in ("cleaning", "splitting", "indexing", "completed") else None, "cleaning_completed_at": document.updatedAt if document.indexingStatus in ("splitting", "indexing", "completed") else None, "splitting_completed_at": document.updatedAt if document.indexingStatus in ("indexing", "completed") else None, "completed_at": document.updatedAt if document.indexingStatus == "completed" else None, "paused_at": None, "error": document.error or None, "stopped_at": None, "completed_segments": completed_segments, "total_segments": total_segments, } ] } async def UpdateDatasetDocumentByFile( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, DocumentId: int, FileName: str, ContentType: str | None, Content: bytes, ProcessConfig: dict | None, ) -> RagDatasetUploadDocumentVO: dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) if not dataset: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") async with GetAsyncSession() as session: current = ( await session.execute( text( """ SELECT id, minio_path FROM rag_document WHERE id = :document_id AND dataset_id = :dataset_id AND deleted_at IS NULL LIMIT 1 """ ), {"document_id": DocumentId, "dataset_id": DatasetId}, ) ).mappings().first() if not current: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "文档不存在") suffix = os.path.splitext(FileName)[1].lower() if suffix not in {".pdf", ".docx", ".txt", ".md", ".json"}: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 PDF、DOCX、TXT、MD、JSON") content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream" object_key = current.get("minio_path") or f"rag/{DatasetId}/{datetime.now().strftime('%Y/%m/%d')}/{uuid.uuid4().hex}_{FileName}" OssClient().EnsureBucket() stored_key = OssClient().UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type) collection = get_chroma().get_or_create_collection(dataset["collection_name"]) raw = collection.get(where={"document_id": DocumentId}, include=[]) ids = raw.get("ids") or [] if ids: collection.delete(ids=ids) async with GetAsyncSession() as session: await session.execute( text( """ UPDATE rag_document SET original_name = :original_name, minio_path = :minio_path, file_type = :file_type, file_size = :file_size, chunk_count = 0, indexing_status = 'indexing', indexing_error = NULL, updated_at = NOW() WHERE id = :document_id """ ), { "document_id": DocumentId, "original_name": FileName, "minio_path": stored_key, "file_type": suffix.lstrip("."), "file_size": len(Content), }, ) asyncio.create_task( self._run_document_indexing_task( dataset=dataset, dataset_id=DatasetId, document_id=DocumentId, file_name=FileName, content=Content, process_config=ProcessConfig or {}, ) ) return RagDatasetUploadDocumentVO( document={ "id": str(DocumentId), "name": FileName, "indexing_status": "indexing", "word_count": 0, "hit_count": 0, "enabled": True, }, batch=str(DocumentId), ) async def BatchUpdateDatasetDocumentStatus( self, CurrentUserId: int, UserArea: str | None, UserRole: str | None, TenantCode: str | None, TenantName: str | None, DatasetId: int, DocumentIds: list[int], Enabled: bool, ) -> RagOperationResultVO: dataset = await self._get_visible_dataset(UserArea, UserRole, TenantCode, TenantName, DatasetId) if not dataset: raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "知识库不存在") ids = [int(doc_id) for doc_id in DocumentIds] if not ids: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "未传入文档ID") async with GetAsyncSession() as session: await session.execute( text( """ UPDATE rag_document SET enabled = :enabled, updated_at = NOW() WHERE dataset_id = :dataset_id AND id = ANY(:document_ids) AND deleted_at IS NULL """ ), {"dataset_id": DatasetId, "document_ids": ids, "enabled": Enabled}, ) return RagOperationResultVO(result="success") def _extract_page_texts(self, *, FileName: str, Content: bytes) -> list[tuple[int, str]]: suffix = os.path.splitext(FileName)[1].lower() with tempfile.NamedTemporaryFile(suffix=suffix, delete=False) as temp_file: temp_file.write(Content) temp_path = temp_file.name try: from fastapi_modules.fastapi_leaudit.services.impl.documentServiceImpl import ( _extract_page_texts_from_docx, _extract_page_texts_from_pdf, ) if suffix == ".pdf": return _extract_page_texts_from_pdf(Path(temp_path)) if suffix == ".docx": return _extract_page_texts_from_docx(Path(temp_path)) if suffix == ".json": return self._extract_page_texts_from_json(Content) text_value = Content.decode("utf-8", errors="ignore").strip() return [(1, text_value)] if text_value else [] finally: try: os.unlink(temp_path) except OSError: pass def _extract_page_texts_from_json(self, content: bytes) -> list[tuple[int, str]]: text_value = content.decode("utf-8", errors="ignore").strip() if not text_value: return [] try: payload = json.loads(text_value) except json.JSONDecodeError: return [(1, text_value)] flattened = self._json_to_text(payload) flattened = flattened.strip() return [(1, flattened)] if flattened else [] def _json_to_text(self, value: object, prefix: str = "") -> str: if value is None: return "" if isinstance(value, str): return f"{prefix}: {value}" if prefix else value if isinstance(value, (int, float, bool)): return f"{prefix}: {value}" if prefix else str(value) if isinstance(value, list): parts: list[str] = [] for index, item in enumerate(value, start=1): item_prefix = f"{prefix}[{index}]" if prefix else f"item_{index}" chunk = self._json_to_text(item, item_prefix) if chunk: parts.append(chunk) return "\n".join(parts) if isinstance(value, dict): parts = [] for key, item in value.items(): item_prefix = f"{prefix}.{key}" if prefix else str(key) chunk = self._json_to_text(item, item_prefix) if chunk: parts.append(chunk) return "\n".join(parts) rendered = str(value).strip() if not rendered: return "" return f"{prefix}: {rendered}" if prefix else rendered def _build_chunks(self, *, file_name: str, page_texts: list[tuple[int, str]], dataset: dict, process_config: dict, document_id: int) -> list[dict]: rules = ((process_config or {}).get("process_rule") or {}).get("rules") or {} segmentation = rules.get("segmentation") or {} pre_rules = rules.get("pre_processing_rules") or [] remove_spaces = any(rule.get("id") == "remove_extra_spaces" and rule.get("enabled") for rule in pre_rules) remove_urls = any(rule.get("id") == "remove_urls_emails" and rule.get("enabled") for rule in pre_rules) separator = segmentation.get("separator") or "\n\n" max_tokens = int(segmentation.get("max_tokens") or dataset.get("chunk_max_size") or 800) chunk_overlap = int(segmentation.get("chunk_overlap") or 50) chunk_overlap = max(0, min(chunk_overlap, max_tokens // 2 if max_tokens > 1 else 0)) chunks: list[dict] = [] for page_no, raw_text in page_texts: text_value = self._preprocess_text(raw_text, remove_spaces=remove_spaces, remove_urls=remove_urls) if not text_value: continue for index, chunk_text in enumerate(self._split_text(text_value, separator=separator, max_chars=max_tokens, overlap=chunk_overlap)): if not chunk_text.strip(): continue chunk_id = f"{document_id}:{page_no}:{index}" chunks.append( { "id": chunk_id, "text": chunk_text, "metadata": { "id": chunk_id, "source": file_name, "document_name": file_name, "document_id": document_id, "page": page_no, "chunk_index": index, }, } ) return chunks async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]: embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"]) embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"] embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4" batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10)) if not embed_url or not embed_key: raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务") embeddings: list[list[float]] = [] async with httpx.AsyncClient(timeout=120.0) as client: for start in range(0, len(texts), batch_size): batch_texts = texts[start:start + batch_size] try: response = await client.post( embed_url, headers={ "Content-Type": "application/json", "Authorization": f"Bearer {embed_key}", }, json={"model": embed_model, "input": batch_texts}, ) response.raise_for_status() except httpx.HTTPStatusError as exc: error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}" raise LeauditException( StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, f"向量化服务调用失败: {error_message[:300]}", ) from exc payload = response.json() rows = payload.get("data") or [] batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")] if len(batch_embeddings) != len(batch_texts): raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") embeddings.extend(batch_embeddings) if len(embeddings) != len(texts): raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") return embeddings async def _run_document_indexing_task( self, *, dataset: dict, dataset_id: int, document_id: int, file_name: str, content: bytes, process_config: dict, ) -> None: try: await self._update_document_processing_state(document_id=document_id, status="parsing") page_texts = self._extract_page_texts(FileName=file_name, Content=content) await self._update_document_processing_state(document_id=document_id, status="cleaning") processed = self._build_chunks( file_name=file_name, page_texts=page_texts, dataset=dataset, process_config=process_config, document_id=document_id, ) if not processed: raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文档未提取到可入库文本") await self._update_document_processing_state( document_id=document_id, status="splitting", chunk_count=len(processed), ) embeddings = await self._embed_texts([item["text"] for item in processed], dataset.get("embedding_model") or "") await self._update_document_processing_state( document_id=document_id, status="indexing", chunk_count=len(processed), ) collection = get_chroma().get_or_create_collection(dataset["collection_name"]) collection.add( ids=[item["id"] for item in processed], documents=[item["text"] for item in processed], embeddings=embeddings, metadatas=[item["metadata"] for item in processed], ) async with GetAsyncSession() as session: await session.execute( text( """ UPDATE rag_document SET chunk_count = :chunk_count, indexing_status = 'completed', indexing_error = NULL, indexing_started_at = COALESCE(indexing_started_at, NOW()), indexing_completed_at = NOW(), updated_at = NOW() WHERE id = :document_id """ ), {"chunk_count": len(processed), "document_id": document_id}, ) await self._sync_dataset_counts(dataset_id) except Exception as exc: async with GetAsyncSession() as session: await session.execute( text( """ UPDATE rag_document SET indexing_status = 'error', indexing_error = :error, updated_at = NOW() WHERE id = :document_id """ ), {"document_id": document_id, "error": str(exc)[:2000]}, ) async def _update_document_processing_state(self, *, document_id: int, status: str, chunk_count: int | None = None) -> None: fields = [ "indexing_status = :status", "indexing_error = NULL", "indexing_started_at = COALESCE(indexing_started_at, NOW())", "updated_at = NOW()", ] params: dict[str, object] = { "document_id": document_id, "status": status, } if chunk_count is not None: fields.append("chunk_count = :chunk_count") params["chunk_count"] = chunk_count async with GetAsyncSession() as session: await session.execute( text( f""" UPDATE rag_document SET {", ".join(fields)} WHERE id = :document_id """ ), params, ) async def _sync_dataset_counts(self, dataset_id: int) -> None: async with GetAsyncSession() as session: await session.execute( text( """ UPDATE rag_dataset SET document_count = ( SELECT COUNT(1) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL ), total_chunks = COALESCE(( SELECT SUM(chunk_count) FROM rag_document WHERE dataset_id = :dataset_id AND deleted_at IS NULL ), 0), updated_at = NOW() WHERE id = :dataset_id """ ), {"dataset_id": dataset_id}, ) def _preprocess_text(self, text_value: str, *, remove_spaces: bool, remove_urls: bool) -> str: result = text_value or "" if remove_urls: result = re.sub(r"https?://\\S+|www\\.\\S+|[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\\.[A-Za-z]{2,}", " ", result) if remove_spaces: result = re.sub(r"[ \\t]+", " ", result) result = re.sub(r"\n{3,}", "\n\n", result) return result.strip() def _split_text(self, text_value: str, *, separator: str, max_chars: int, overlap: int) -> list[str]: parts = [part.strip() for part in text_value.split(separator) if part.strip()] if separator else [text_value] chunks: list[str] = [] current = "" for part in parts: candidate = f"{current}{separator if current else ''}{part}" if separator else f"{current}{part}" if len(candidate) <= max_chars: current = candidate continue if current: chunks.append(current) if len(part) <= max_chars: current = part continue start = 0 step = max(max_chars - overlap, 1) while start < len(part): chunks.append(part[start:start + max_chars]) start += step current = "" if current: chunks.append(current) return chunks def _delete_oss_object(self, source: str | None) -> None: if not source: return try: oss = OssClient() ref = oss.ResolveObjectRef(Source=source) if ref.isDirectUrl or not ref.objectKey: return oss._GetMinioClient().remove_object(ref.bucket, ref.objectKey) except Exception: # 对象存储删除失败不阻塞业务主流程,避免历史脏数据导致文档无法删除。 return