822 lines
34 KiB
Python
822 lines
34 KiB
Python
"""Temporary RAG attachments for chat conversations."""
|
|
|
|
from __future__ import annotations

import asyncio
import csv
import hashlib
import io
import json
import logging
import mimetypes
import os
import re
import tempfile
import uuid
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Any, Awaitable, Callable

from sqlalchemy import text

from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
from fastapi_common.fastapi_common_storage.oss_client import OssClient
from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum
from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException
from fastapi_modules.fastapi_leaudit.domian.vo.ragChatAttachmentVo import (
    RagChatAttachmentDeleteVO,
    RagChatAttachmentVO,
)
from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma
from fastapi_modules.fastapi_leaudit.rag_engine.retriever import EmbedTexts, RagRetriever
from fastapi_modules.fastapi_leaudit.services.impl.ragChatServiceImpl import RagChatServiceImpl
from fastapi_modules.fastapi_leaudit.services.impl.ragDatasetServiceImpl import RagDatasetServiceImpl
from fastapi_modules.fastapi_leaudit.services.ragChatAttachmentService import IRagChatAttachmentService
|
|
|
|
|
|
# Number of days a freshly uploaded attachment stays valid before it expires.
DEFAULT_ATTACHMENT_TTL_DAYS = 7

# Lower-case file suffixes accepted for upload, grouped by how they are parsed.
SUPPORTED_TEXT_SUFFIXES = {
    ".txt",
    ".md",
    ".json",
    ".csv",
}
SUPPORTED_DOCUMENT_SUFFIXES = {
    ".docx",
    ".pdf",
    ".xlsx",
}
SUPPORTED_IMAGE_SUFFIXES = {
    ".png",
    ".jpg",
    ".jpeg",
    ".webp",
    ".bmp",
    ".tif",
    ".tiff",
}

# Every suffix the upload endpoint accepts.
SUPPORTED_SUFFIXES = SUPPORTED_TEXT_SUFFIXES.union(
    SUPPORTED_DOCUMENT_SUFFIXES,
    SUPPORTED_IMAGE_SUFFIXES,
)
|
|
|
|
|
|
class RagChatAttachmentServiceImpl(IRagChatAttachmentService):
    """Manage temporary, conversation-scoped RAG attachment indexes."""

    # Class-level flag shared by all instances: flipped to True once the
    # attachment table and its indexes have been created/verified.
    _attachment_schema_checked = False
    # Guards the one-time schema bootstrap so concurrent coroutines do not
    # run the DDL statements at the same time.
    _attachment_schema_lock = asyncio.Lock()
|
|
|
|
def __init__(
    self,
    *,
    chroma_client: Any | None = None,
    embed_texts: EmbedTexts | None = None,
    chat_service: RagChatServiceImpl | None = None,
    ocr_client_factory: Callable[[], Any] | None = None,
) -> None:
    """Wire up the collaborators used by the attachment service.

    Args:
        chroma_client: Optional vector-store client; when omitted the shared
            client from ``get_chroma()`` is resolved lazily on first use.
        embed_texts: Optional embedding callable forwarded to the retriever.
        chat_service: Optional chat service; a default one is created when
            none is supplied.
        ocr_client_factory: Optional zero-argument factory returning the OCR
            client used for image attachments.
    """
    retriever_options = {
        "chroma_client": chroma_client,
        "embed_texts": embed_texts,
        "hydrate_documents": False,
    }
    self._chroma_client = chroma_client
    self.retriever = RagRetriever(**retriever_options)
    self.chat_service = chat_service if chat_service else RagChatServiceImpl()
    # Only used for its text-extraction and splitting helpers.
    self.dataset_helpers = RagDatasetServiceImpl()
    self._ocr_client_factory = ocr_client_factory
|
|
|
|
def _default_expires_at(self, now: datetime | None = None) -> datetime:
    """Return the expiry instant for a new attachment.

    The result is *now* (or the current UTC time when *now* is not given),
    made timezone-aware, plus ``DEFAULT_ATTACHMENT_TTL_DAYS`` days.
    """
    reference = now if now else datetime.now(timezone.utc)
    lifetime = timedelta(days=DEFAULT_ATTACHMENT_TTL_DAYS)
    return self._ensure_aware_datetime(reference) + lifetime
|
|
|
|
def BuildCollectionName(self, *, TenantCode: str | None, UserId: int, ConversationId: str, AttachmentId: str) -> str:
|
|
tenant = self._slugify(TenantCode or "default")
|
|
user = self._slugify(str(UserId))
|
|
conversation_hash = hashlib.sha1(ConversationId.encode("utf-8")).hexdigest()[:16]
|
|
attachment = self._slugify(AttachmentId)
|
|
collection_name = f"chat_attachment_{tenant}_{user}_{conversation_hash}_{attachment}"
|
|
return collection_name[:120]
|
|
|
|
def BuildChunks(
|
|
self,
|
|
*,
|
|
AttachmentId: str,
|
|
TenantCode: str | None,
|
|
UserId: int,
|
|
ConversationId: str,
|
|
FileName: str,
|
|
PageTexts: list[tuple[int, str]],
|
|
ChunkMaxSize: int = 800,
|
|
ChunkOverlap: int = 80,
|
|
) -> list[dict[str, Any]]:
|
|
chunk_max_size = max(50, int(ChunkMaxSize or 800))
|
|
chunk_overlap = max(0, min(int(ChunkOverlap or 0), chunk_max_size // 2))
|
|
chunks: list[dict[str, Any]] = []
|
|
for page_no, raw_text in PageTexts:
|
|
text_value = self._preprocess_text(raw_text)
|
|
if not text_value:
|
|
continue
|
|
split_chunks = self._split_text(text_value, separator="\n\n", max_chars=chunk_max_size, overlap=chunk_overlap)
|
|
for index, chunk_text in enumerate(split_chunks):
|
|
if not chunk_text.strip():
|
|
continue
|
|
chunk_id = f"{AttachmentId}:{page_no}:{index}"
|
|
chunks.append(
|
|
{
|
|
"id": chunk_id,
|
|
"text": chunk_text,
|
|
"metadata": {
|
|
"id": chunk_id,
|
|
"source": FileName,
|
|
"document_name": FileName,
|
|
"attachment_id": AttachmentId,
|
|
"tenant_code": str(TenantCode or ""),
|
|
"user_id": str(UserId),
|
|
"conversation_id": ConversationId,
|
|
"source_scope": "chat_attachment",
|
|
"page": int(page_no),
|
|
"chunk_index": index,
|
|
},
|
|
}
|
|
)
|
|
return chunks
|
|
|
|
async def CreateAttachment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    ConversationId: str | None,
    AppId: int | None,
    FileName: str,
    ContentType: str | None,
    Content: bytes,
) -> RagChatAttachmentVO:
    """Store an uploaded file and start indexing it in the background.

    The file is validated, uploaded to object storage, recorded in
    ``rag_chat_attachment`` with status ``waiting`` and then handed to a
    background task that parses, chunks and embeds it.

    Args:
        CurrentUserId: Id of the uploading user.
        UserArea: Area of the user, stored with the record.
        UserRole: Role of the user, used to resolve the chat application.
        TenantCode: Tenant of the user; ``"default"`` is used in storage paths
            when empty.
        TenantName: Tenant display name, used to resolve the chat application.
        ConversationId: Existing conversation id, or a value the chat service
            turns into a new conversation.
        AppId: Requested chat application id.
        FileName: Client-supplied file name.
        ContentType: Client-supplied MIME type; guessed from the name when empty.
        Content: Raw file bytes.

    Returns:
        The view object describing the newly created attachment.

    Raises:
        LeauditException: 400 when the name or content is empty or the suffix
            is unsupported; 404 when no chat application can be resolved.
    """
    if not FileName:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "上传文件名不能为空")
    if not Content:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "上传文件内容不能为空")

    suffix = Path(FileName).suffix.lower()
    if suffix not in SUPPORTED_SUFFIXES:
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前仅支持 TXT、MD、JSON、CSV、DOCX、PDF、XLSX 和图片文件")

    app = await self.chat_service._resolve_app(AppId, UserArea, UserRole, TenantCode, TenantName)
    if not app:
        raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "未配置可用聊天应用")

    conversation_id = await self.chat_service._ensure_conversation(
        user_id=CurrentUserId,
        conversation_id=ConversationId,
        app_id=app["id"],
        user_area=UserArea,
        user_role=UserRole,
        tenant_code=TenantCode,
        tenant_name=TenantName,
    )

    attachment_id = str(uuid.uuid4())
    collection_name = self.BuildCollectionName(
        TenantCode=TenantCode,
        UserId=CurrentUserId,
        ConversationId=conversation_id,
        AttachmentId=attachment_id,
    )
    content_type = ContentType or mimetypes.guess_type(FileName)[0] or "application/octet-stream"
    expires_at = self._default_expires_at()

    # The file name comes from the client: keep only its last path component
    # so directory separators or ".." segments cannot shape the object key.
    safe_name = Path(FileName.replace("\\", "/")).name
    object_key = f"rag-chat-attachments/{TenantCode or 'default'}/{CurrentUserId}/{conversation_id}/{attachment_id}_{safe_name}"
    oss = OssClient()
    oss.EnsureBucket()
    stored_key = oss.UploadBytes(ObjectKey=object_key, Content=Content, ContentType=content_type)

    try:
        async with GetAsyncSession() as session:
            await self._ensure_attachment_schema(session)
            row = (
                await session.execute(
                    text(
                        """
                        INSERT INTO rag_chat_attachment (
                            attachment_id, conversation_id, user_id, tenant_code, area,
                            filename, original_name, content_type, file_size, minio_path,
                            collection_name, indexing_status, chunk_count, expires_at
                        ) VALUES (
                            :attachment_id, :conversation_id, :user_id, :tenant_code, :area,
                            :filename, :original_name, :content_type, :file_size, :minio_path,
                            :collection_name, 'waiting', 0, :expires_at
                        )
                        RETURNING *
                        """
                    ),
                    {
                        "attachment_id": attachment_id,
                        "conversation_id": conversation_id,
                        "user_id": CurrentUserId,
                        "tenant_code": TenantCode,
                        "area": UserArea,
                        "filename": f"{attachment_id}{suffix}",
                        "original_name": FileName,
                        "content_type": content_type,
                        "file_size": len(Content),
                        "minio_path": stored_key,
                        "collection_name": collection_name,
                        "expires_at": expires_at,
                    },
                )
            ).mappings().first()
    except Exception:
        # Without a database row nothing would ever clean up the uploaded
        # object, so remove it (best effort) before propagating the error.
        self._delete_oss_object(stored_key)
        raise

    indexing_task = asyncio.create_task(
        self._run_attachment_indexing_task(
            attachment_id=attachment_id,
            tenant_code=TenantCode,
            user_id=CurrentUserId,
            conversation_id=conversation_id,
            file_name=FileName,
            content=Content,
            collection_name=collection_name,
        )
    )
    # The event loop keeps only weak references to tasks, so an otherwise
    # unreferenced task may be garbage-collected before it finishes. Hold a
    # strong reference on the class until the task is done.
    service_cls = self.__class__
    background_tasks = getattr(service_cls, "_indexing_tasks", None)
    if background_tasks is None:
        background_tasks = set()
        service_cls._indexing_tasks = background_tasks
    background_tasks.add(indexing_task)
    indexing_task.add_done_callback(background_tasks.discard)

    return self._to_vo(dict(row))
|
|
|
|
async def GetAttachment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    ConversationId: str,
    AttachmentId: str,
) -> RagChatAttachmentVO:
    """Return one attachment after checking it belongs to the caller.

    The attachment does not have to be fully indexed. ``UserRole`` and
    ``TenantName`` are accepted but not used by this method.

    Raises:
        LeauditException: When the conversation id is blank, the attachment
            does not exist, or the scope/expiry checks fail.
    """
    self._validate_conversation_id(ConversationId)
    attachment = await self._get_attachment_record(AttachmentId)
    scope = {
        "CurrentUserId": CurrentUserId,
        "TenantCode": TenantCode,
        "UserArea": UserArea,
        "ConversationId": ConversationId,
        "RequireCompleted": False,
        "Now": datetime.now(timezone.utc),
    }
    self._assert_attachment_scope(attachment, **scope)
    return self._to_vo(attachment)
|
|
|
|
async def DeleteAttachment(
    self,
    CurrentUserId: int,
    UserArea: str | None,
    UserRole: str | None,
    TenantCode: str | None,
    TenantName: str | None,
    ConversationId: str,
    AttachmentId: str,
) -> RagChatAttachmentDeleteVO:
    """Soft-delete an attachment and drop its collection and stored object.

    The record is verified against the caller's scope, marked deleted in the
    database, and then its vector collection and object-storage file are
    removed on a best-effort basis. ``UserRole`` and ``TenantName`` are
    accepted but not used by this method.

    Raises:
        LeauditException: When the conversation id is blank, the attachment
            does not exist, or the scope/expiry checks fail.
    """
    self._validate_conversation_id(ConversationId)
    target = await self._get_attachment_record(AttachmentId)
    self._assert_attachment_scope(
        target,
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
        RequireCompleted=False,
        Now=datetime.now(timezone.utc),
    )

    soft_delete = text(
        """
        UPDATE rag_chat_attachment
        SET deleted_at = NOW(), updated_at = NOW()
        WHERE attachment_id = :attachment_id
        """
    )
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        await session.execute(soft_delete, {"attachment_id": AttachmentId})

    collection_name = str(target.get("collection_name") or "")
    stored_path = str(target.get("minio_path") or "")
    self._delete_collection(collection_name)
    self._delete_oss_object(stored_path)
    return RagChatAttachmentDeleteVO(result="success")
|
|
|
|
async def ValidateAttachmentForChat(
    self,
    CurrentUserId: int,
    TenantCode: str | None,
    ConversationId: str,
    AttachmentId: str,
    UserArea: str | None = None,
) -> dict:
    """Load an attachment and confirm it may be used in a chat turn.

    In addition to the ownership and expiry checks, the attachment must have
    reached the ``completed`` indexing status.

    Returns:
        The raw attachment row as a dictionary.

    Raises:
        LeauditException: When the attachment is missing, out of scope,
            expired, or not fully indexed yet.
    """
    attachment = await self._get_attachment_record(AttachmentId)
    checked_at = datetime.now(timezone.utc)
    self._assert_attachment_scope(
        attachment,
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
        RequireCompleted=True,
        Now=checked_at,
    )
    return attachment
|
|
|
|
async def RetrieveAttachmentContext(
    self,
    CurrentUserId: int,
    TenantCode: str | None,
    ConversationId: str,
    AttachmentId: str,
    Query: str,
    TopK: int = 5,
    UserArea: str | None = None,
) -> tuple[list[dict], str]:
    """Retrieve the chunks of one attachment that best match *Query*.

    The attachment is validated for chat use first, then its collection is
    queried through the retriever. Every returned chunk is annotated with the
    attachment/conversation ids and marked as coming from a chat attachment.

    Returns:
        A ``(chunks, display_name)`` tuple, where ``display_name`` is the
        original file name or a generic fallback label.

    Raises:
        LeauditException: Propagated from the attachment validation.
    """
    record = await self.ValidateAttachmentForChat(
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
        AttachmentId=AttachmentId,
    )
    result = await self.retriever.retrieve(
        query=Query,
        collection_name=str(record["collection_name"]),
        top_k=TopK,
    )

    # Keys listed after ``**chunk`` deliberately override what the retriever returned.
    annotated: list[dict] = [
        {
            **chunk,
            "dataset_name": "临时上传文件",
            "attachment_id": AttachmentId,
            "conversation_id": ConversationId,
            "source_scope": "chat_attachment",
            "data_source_type": "chat_attachment",
            "document_name": chunk.get("document_name") or record.get("original_name") or chunk.get("source"),
            "source": chunk.get("source") or record.get("original_name") or "上传文件",
        }
        for chunk in result.chunks
    ]
    display_name = str(record.get("original_name") or "上传文件")
    return annotated, display_name
|
|
|
|
async def ResolveActiveAttachmentIdForConversation(
    self,
    CurrentUserId: int,
    TenantCode: str | None,
    ConversationId: str,
    UserArea: str | None = None,
) -> str | None:
    """Return the first active attachment id of a conversation, if any.

    This is the head of the list produced by
    ``ResolveActiveAttachmentIdsForConversation``; ``None`` is returned when
    that list is empty.
    """
    active_ids = await self.ResolveActiveAttachmentIdsForConversation(
        CurrentUserId=CurrentUserId,
        TenantCode=TenantCode,
        UserArea=UserArea,
        ConversationId=ConversationId,
    )
    if not active_ids:
        return None
    return active_ids[0]
|
|
|
|
async def ResolveActiveAttachmentIdsForConversation(
    self,
    CurrentUserId: int,
    TenantCode: str | None,
    ConversationId: str,
    UserArea: str | None = None,
) -> list[str]:
    """List the usable attachment ids of a conversation, newest first.

    An attachment qualifies when it belongs to the user and conversation, is
    fully indexed, not expired and not deleted. The tenant filter only applies
    when both the caller and the row carry a tenant code; the area filter only
    applies when the caller has no tenant code and both sides carry an area.

    Raises:
        LeauditException: When the conversation id is blank or ``-1``.
    """
    self._validate_conversation_id(ConversationId)

    query = text(
        """
        SELECT attachment_id, tenant_code, area
        FROM rag_chat_attachment
        WHERE user_id = :user_id
          AND conversation_id = :conversation_id
          AND indexing_status = 'completed'
          AND expires_at > NOW()
          AND deleted_at IS NULL
          AND (
              NULLIF(:tenant_code, '') IS NULL
              OR NULLIF(tenant_code, '') IS NULL
              OR tenant_code = :tenant_code
          )
          AND (
              NULLIF(:tenant_code, '') IS NOT NULL
              OR NULLIF(:user_area, '') IS NULL
              OR NULLIF(area, '') IS NULL
              OR area = :user_area
          )
        ORDER BY indexing_completed_at DESC NULLS LAST, created_at DESC
        """
    )
    bind_values = {
        "user_id": CurrentUserId,
        "conversation_id": ConversationId,
        "tenant_code": TenantCode,
        "user_area": UserArea,
    }

    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        outcome = await session.execute(query, bind_values)
        matches = outcome.mappings().all()

    return [str(match["attachment_id"]) for match in matches]
|
|
|
|
async def CleanupExpiredAttachments(self, *, Limit: int = 100) -> int:
    """Soft-delete expired attachments and remove their stored artefacts.

    At most ``Limit`` rows (minimum 1) are processed per call, oldest expiry
    first. Each selected row is marked deleted, then its vector collection
    and object-storage file are removed on a best-effort basis.

    Returns:
        The number of expired attachments that were processed.
    """
    batch_size = max(1, int(Limit or 100))
    select_expired = text(
        """
        SELECT attachment_id, collection_name, minio_path
        FROM rag_chat_attachment
        WHERE deleted_at IS NULL
          AND expires_at <= NOW()
        ORDER BY expires_at ASC
        LIMIT :limit
        """
    )
    mark_deleted = text(
        """
        UPDATE rag_chat_attachment
        SET deleted_at = NOW(), updated_at = NOW()
        WHERE attachment_id = ANY(:attachment_ids)
        """
    )

    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        expired = (await session.execute(select_expired, {"limit": batch_size})).mappings().all()
        expired_ids = [str(entry["attachment_id"]) for entry in expired]
        if expired_ids:
            await session.execute(mark_deleted, {"attachment_ids": expired_ids})

    for entry in expired:
        self._delete_collection(str(entry.get("collection_name") or ""))
        self._delete_oss_object(str(entry.get("minio_path") or ""))
    return len(expired)
|
|
|
|
async def _get_attachment_record(self, attachment_id: str) -> dict:
    """Fetch the non-deleted attachment row with the given id.

    Returns:
        The row converted to a plain dictionary.

    Raises:
        LeauditException: 400 when the id is blank, 404 when no live row
            matches it.
    """
    if str(attachment_id or "").strip() == "":
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "附件ID不能为空")

    lookup = text(
        """
        SELECT *
        FROM rag_chat_attachment
        WHERE attachment_id = :attachment_id
          AND deleted_at IS NULL
        LIMIT 1
        """
    )
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        outcome = await session.execute(lookup, {"attachment_id": attachment_id})
        found = outcome.mappings().first()

    if found:
        return dict(found)
    raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "临时附件不存在")
|
|
|
|
def _assert_attachment_scope(
|
|
self,
|
|
record: dict,
|
|
*,
|
|
CurrentUserId: int,
|
|
TenantCode: str | None,
|
|
ConversationId: str,
|
|
RequireCompleted: bool,
|
|
Now: datetime | None = None,
|
|
UserArea: str | None = None,
|
|
) -> None:
|
|
if record.get("deleted_at") is not None:
|
|
raise LeauditException(StatusCodeEnum.HTTP_404_NOT_FOUND, "临时附件不存在")
|
|
if int(record.get("user_id") or 0) != int(CurrentUserId):
|
|
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该临时附件")
|
|
if str(record.get("conversation_id") or "") != str(ConversationId or ""):
|
|
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前会话")
|
|
expected_tenant = str(TenantCode or "").strip()
|
|
record_tenant = str(record.get("tenant_code") or "").strip()
|
|
if expected_tenant and record_tenant and expected_tenant != record_tenant:
|
|
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前租户")
|
|
expected_area = str(UserArea or "").strip()
|
|
record_area = str(record.get("area") or "").strip()
|
|
if not expected_tenant and expected_area and record_area and expected_area != record_area:
|
|
raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "临时附件不属于当前地区")
|
|
|
|
expires_at = self._ensure_aware_datetime(record.get("expires_at"))
|
|
now = self._ensure_aware_datetime(Now or datetime.now(timezone.utc))
|
|
if expires_at <= now:
|
|
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件已过期,请重新上传")
|
|
|
|
if RequireCompleted and str(record.get("indexing_status") or "") != "completed":
|
|
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件尚未完成解析,请稍后再试")
|
|
|
|
async def _run_attachment_indexing_task(
    self,
    *,
    attachment_id: str,
    tenant_code: str | None,
    user_id: int,
    conversation_id: str,
    file_name: str,
    content: bytes,
    collection_name: str,
) -> None:
    """Parse, chunk, embed and index one attachment in the background.

    The attachment status is advanced through ``parsing``, ``splitting``,
    ``indexing`` and ``completed``. Any failure is logged and recorded on the
    row as status ``error`` with the (truncated) error message.

    This coroutine is meant to run as a fire-and-forget task, so it never
    lets an ordinary exception escape: a failure while recording the error
    state is logged instead of being raised.

    Args:
        attachment_id: Attachment being indexed.
        tenant_code: Tenant stored in the chunk metadata.
        user_id: Owner stored in the chunk metadata.
        conversation_id: Conversation stored in the chunk metadata.
        file_name: Original file name; its suffix selects the parser.
        content: Raw file bytes.
        collection_name: Target vector-store collection.
    """
    try:
        await self._update_attachment_state(attachment_id=attachment_id, status="parsing")
        page_texts = await self._extract_page_texts(FileName=file_name, Content=content)
        if not page_texts:
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文件未提取到可检索文本")

        await self._update_attachment_state(attachment_id=attachment_id, status="splitting")
        chunks = self.BuildChunks(
            AttachmentId=attachment_id,
            TenantCode=tenant_code,
            UserId=user_id,
            ConversationId=conversation_id,
            FileName=file_name,
            PageTexts=page_texts,
        )
        if not chunks:
            raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "文件未切分出可检索文本")

        await self._update_attachment_state(attachment_id=attachment_id, status="indexing", chunk_count=len(chunks))
        chunk_texts = [item["text"] for item in chunks]
        embeddings = await self.retriever._embed_texts(chunk_texts, "")
        collection = self._get_chroma().get_or_create_collection(collection_name)
        collection.add(
            ids=[item["id"] for item in chunks],
            documents=chunk_texts,
            embeddings=embeddings,
            metadatas=[item["metadata"] for item in chunks],
        )

        await self._update_attachment_state(
            attachment_id=attachment_id,
            status="completed",
            chunk_count=len(chunks),
            completed=True,
        )
    except Exception as exc:
        logger = logging.getLogger(__name__)
        # Keep the traceback: the database row only stores the message text.
        logger.exception("Indexing failed for chat attachment %s", attachment_id)
        try:
            await self._update_attachment_state(
                attachment_id=attachment_id,
                status="error",
                error=str(exc)[:2000],
            )
        except Exception:
            # Nobody awaits this task, so a second failure here would only
            # surface as an unretrieved task exception; log it instead.
            logger.exception("Could not record indexing failure for chat attachment %s", attachment_id)
|
|
|
|
async def _extract_page_texts(self, *, FileName: str, Content: bytes) -> list[tuple[int, str]]:
|
|
suffix = Path(FileName).suffix.lower()
|
|
if suffix in {".txt", ".md"}:
|
|
text_value = Content.decode("utf-8", errors="ignore").strip()
|
|
return [(1, text_value)] if text_value else []
|
|
if suffix == ".json":
|
|
return self.dataset_helpers._extract_page_texts_from_json(Content)
|
|
if suffix == ".csv":
|
|
return [(1, self._csv_to_text(Content))] if self._csv_to_text(Content).strip() else []
|
|
if suffix == ".xlsx":
|
|
return self._extract_page_texts_from_xlsx(Content)
|
|
if suffix in SUPPORTED_IMAGE_SUFFIXES:
|
|
return await self._extract_page_texts_from_image(FileName, Content)
|
|
return self.dataset_helpers._extract_page_texts(FileName=FileName, Content=Content)
|
|
|
|
def _csv_to_text(self, content: bytes) -> str:
|
|
text_value = content.decode("utf-8-sig", errors="ignore")
|
|
rows: list[str] = []
|
|
for row_index, row in enumerate(csv.reader(text_value.splitlines()), start=1):
|
|
values = [str(cell).strip() for cell in row if str(cell).strip()]
|
|
if values:
|
|
rows.append(f"row_{row_index}: " + " | ".join(values))
|
|
return "\n".join(rows)
|
|
|
|
def _extract_page_texts_from_xlsx(self, content: bytes) -> list[tuple[int, str]]:
    """Render every worksheet of an XLSX file as one page of text.

    The bytes are written to a temporary file and opened with openpyxl in
    read-only, values-only mode. Each sheet becomes
    ``(sheet_number, "Sheet: <title>\\n<row> ...")`` where a row is its
    non-blank cells joined with ``" | "``.

    A read-only workbook keeps its source file open, so it is closed
    explicitly before the temporary file is removed; otherwise the handle
    leaks and the unlink can fail on platforms that lock open files.

    Args:
        content: Raw XLSX file content.

    Returns:
        One ``(sheet_number, text)`` pair per worksheet.

    Raises:
        LeauditException: 400 when openpyxl cannot be imported.
    """
    try:
        from openpyxl import load_workbook
    except Exception as exc:  # pragma: no cover - depends on optional env package
        raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "当前环境未安装 openpyxl,暂无法解析 Excel 文件") from exc

    with tempfile.NamedTemporaryFile(suffix=".xlsx", delete=False) as temp_file:
        temp_file.write(content)
        temp_path = temp_file.name
    try:
        workbook = load_workbook(temp_path, data_only=True, read_only=True)
        try:
            page_texts: list[tuple[int, str]] = []
            for sheet_index, sheet in enumerate(workbook.worksheets, start=1):
                lines: list[str] = [f"Sheet: {sheet.title}"]
                for row in sheet.iter_rows(values_only=True):
                    values = [str(cell).strip() for cell in row if cell is not None and str(cell).strip()]
                    if values:
                        lines.append(" | ".join(values))
                rendered = "\n".join(lines).strip()
                if rendered:
                    page_texts.append((sheet_index, rendered))
            return page_texts
        finally:
            workbook.close()
    finally:
        try:
            os.unlink(temp_path)
        except OSError:
            # Best effort: a leftover temp file must not fail the extraction.
            pass
|
|
|
|
async def _extract_page_texts_from_image(self, file_name: str, content: bytes) -> list[tuple[int, str]]:
    """Run OCR on an image and return its text as a single page.

    The bytes are written to a temporary file that keeps the original suffix
    (``.png`` when the name has none), passed to the OCR client, and the
    temporary file is removed afterwards regardless of the outcome.

    Returns:
        ``[(1, text)]`` when OCR produced text, otherwise an empty list.
    """
    image_suffix = Path(file_name).suffix.lower() or ".png"
    with tempfile.NamedTemporaryFile(suffix=image_suffix, delete=False) as handle:
        handle.write(content)
        image_path = Path(handle.name)

    try:
        client = self._create_ocr_client()
        recognised = self._ocr_result_to_text(await client.ocr(image_path))
        if not recognised:
            return []
        return [(1, recognised)]
    finally:
        try:
            image_path.unlink()
        except OSError:
            # Cleanup is best effort only.
            pass
|
|
|
|
def _ocr_result_to_text(self, value: Any) -> str:
|
|
if value is None:
|
|
return ""
|
|
if isinstance(value, str):
|
|
return value.strip()
|
|
if isinstance(value, dict):
|
|
candidates: list[str] = []
|
|
for key in ("full_text", "text", "content", "markdown"):
|
|
if isinstance(value.get(key), str) and value.get(key).strip():
|
|
candidates.append(value[key].strip())
|
|
pages = value.get("pages")
|
|
if isinstance(pages, list):
|
|
for page in pages:
|
|
page_text = self._ocr_result_to_text(page)
|
|
if page_text:
|
|
candidates.append(page_text)
|
|
lines = value.get("lines") or value.get("blocks")
|
|
if isinstance(lines, list):
|
|
for line in lines:
|
|
line_text = self._ocr_result_to_text(line)
|
|
if line_text:
|
|
candidates.append(line_text)
|
|
return "\n".join(dict.fromkeys(candidates)).strip()
|
|
if isinstance(value, list):
|
|
return "\n".join(item for item in (self._ocr_result_to_text(item) for item in value) if item).strip()
|
|
|
|
text_attrs = []
|
|
for attr in ("full_text", "text", "content", "markdown"):
|
|
raw = getattr(value, attr, None)
|
|
if isinstance(raw, str) and raw.strip():
|
|
text_attrs.append(raw.strip())
|
|
pages = getattr(value, "pages", None)
|
|
if isinstance(pages, list):
|
|
for page in pages:
|
|
page_text = self._ocr_result_to_text(page)
|
|
if page_text:
|
|
text_attrs.append(page_text)
|
|
return "\n".join(dict.fromkeys(text_attrs)).strip()
|
|
|
|
def _create_ocr_client(self):
|
|
if self._ocr_client_factory is not None:
|
|
return self._ocr_client_factory()
|
|
from fastapi_modules.fastapi_leaudit.leaudit_bridge.client_factory import create_ocr_client
|
|
|
|
return create_ocr_client()
|
|
|
|
async def _update_attachment_state(
    self,
    *,
    attachment_id: str,
    status: str,
    chunk_count: int | None = None,
    error: str | None = None,
    completed: bool = False,
) -> None:
    """Persist the indexing progress of one attachment.

    Every call overwrites the status and the error text (so passing no error
    clears a previous one), refreshes ``updated_at`` and sets
    ``indexing_started_at`` if it is still empty.

    Args:
        attachment_id: Attachment to update.
        status: New indexing status.
        chunk_count: New chunk count; left untouched when ``None``.
        error: Error text to store, or ``None``.
        completed: When true, ``indexing_completed_at`` is set to the current
            database time.
    """
    assignments = [
        "indexing_status = :status",
        "indexing_error = :error",
        "updated_at = NOW()",
        "indexing_started_at = COALESCE(indexing_started_at, NOW())",
    ]
    if chunk_count is not None:
        assignments.append("chunk_count = :chunk_count")
    if completed:
        assignments.append("indexing_completed_at = NOW()")

    # Only fixed column assignments are interpolated; all values are bound.
    statement = text(
        f"""
        UPDATE rag_chat_attachment
        SET {", ".join(assignments)}
        WHERE attachment_id = :attachment_id
        """
    )
    bind_values: dict[str, Any] = {
        "attachment_id": attachment_id,
        "status": status,
        "error": error,
        "chunk_count": chunk_count,
    }
    async with GetAsyncSession() as session:
        await self._ensure_attachment_schema(session)
        await session.execute(statement, bind_values)
|
|
|
|
async def _ensure_attachment_schema(self, session) -> None:
    """Create the attachment table and its indexes once per process.

    The class-level flag is checked before and after acquiring the
    class-level lock, so the DDL runs at most once even when several
    coroutines arrive at the same time. All statements use ``IF NOT EXISTS``.

    Args:
        session: Open database session used to run the DDL statements.
    """
    service_cls = self.__class__
    if service_cls._attachment_schema_checked:
        return

    ddl_statements = (
        """
        CREATE TABLE IF NOT EXISTS rag_chat_attachment (
            id BIGSERIAL PRIMARY KEY,
            attachment_id VARCHAR(64) NOT NULL UNIQUE,
            conversation_id VARCHAR(64) NOT NULL,
            user_id BIGINT NOT NULL,
            tenant_code VARCHAR(64) NULL,
            area VARCHAR(255) NULL,
            filename VARCHAR(512) NOT NULL,
            original_name VARCHAR(512) NOT NULL,
            content_type VARCHAR(255) NULL,
            file_size BIGINT NOT NULL DEFAULT 0,
            minio_path TEXT NULL,
            collection_name VARCHAR(160) NOT NULL,
            indexing_status VARCHAR(32) NOT NULL DEFAULT 'waiting',
            indexing_error TEXT NULL,
            chunk_count INTEGER NOT NULL DEFAULT 0,
            indexing_started_at TIMESTAMPTZ NULL,
            indexing_completed_at TIMESTAMPTZ NULL,
            expires_at TIMESTAMPTZ NOT NULL,
            created_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            updated_at TIMESTAMPTZ NOT NULL DEFAULT NOW(),
            deleted_at TIMESTAMPTZ NULL
        )
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_scope
        ON rag_chat_attachment(tenant_code, user_id, conversation_id, attachment_id)
        WHERE deleted_at IS NULL
        """,
        """
        CREATE INDEX IF NOT EXISTS idx_rag_chat_attachment_expires
        ON rag_chat_attachment(expires_at)
        WHERE deleted_at IS NULL
        """,
    )

    async with service_cls._attachment_schema_lock:
        # Another coroutine may have finished the bootstrap while we waited.
        if not service_cls._attachment_schema_checked:
            for ddl in ddl_statements:
                await session.execute(text(ddl))
            service_cls._attachment_schema_checked = True
|
|
|
|
def _get_chroma(self) -> Any:
|
|
return self._chroma_client or get_chroma()
|
|
|
|
def _delete_collection(self, collection_name: str) -> None:
|
|
if not collection_name:
|
|
return
|
|
try:
|
|
self._get_chroma().delete_collection(collection_name)
|
|
except Exception:
|
|
return
|
|
|
|
def _delete_oss_object(self, source: str | None) -> None:
|
|
if not source:
|
|
return
|
|
try:
|
|
oss = OssClient()
|
|
ref = oss.ResolveObjectRef(Source=source)
|
|
if ref.isDirectUrl or not ref.objectKey:
|
|
return
|
|
oss._GetMinioClient().remove_object(ref.bucket, ref.objectKey)
|
|
except Exception:
|
|
return
|
|
|
|
def _to_vo(self, row: dict) -> RagChatAttachmentVO:
    """Convert an attachment row into its API view object.

    Missing or empty values fall back to an empty string, ``0`` or, for the
    indexing status, ``"waiting"``. The display name prefers
    ``original_name`` over ``filename``. Timestamps are exposed as Unix
    seconds.
    """

    def as_text(key: str, fallback: str = "") -> str:
        return str(row.get(key) or fallback)

    def as_int(key: str) -> int:
        return int(row.get(key) or 0)

    display_name = str(row.get("original_name") or row.get("filename") or "")
    return RagChatAttachmentVO(
        attachmentId=as_text("attachment_id"),
        conversationId=as_text("conversation_id"),
        fileName=display_name,
        contentType=as_text("content_type"),
        fileSize=as_int("file_size"),
        indexingStatus=as_text("indexing_status", "waiting"),
        indexingError=row.get("indexing_error"),
        chunkCount=as_int("chunk_count"),
        collectionName=as_text("collection_name"),
        expiresAt=self._timestamp(row.get("expires_at")),
        createdAt=self._timestamp(row.get("created_at")),
    )
|
|
|
|
def _preprocess_text(self, text_value: str) -> str:
|
|
result = text_value or ""
|
|
result = re.sub(r"[ \t]+", " ", result)
|
|
result = re.sub(r"\n{3,}", "\n\n", result)
|
|
return result.strip()
|
|
|
|
def _split_text(self, text_value: str, *, separator: str, max_chars: int, overlap: int) -> list[str]:
|
|
return self.dataset_helpers._split_text(text_value, separator=separator, max_chars=max_chars, overlap=overlap)
|
|
|
|
def _slugify(self, value: str) -> str:
|
|
normalized = re.sub(r"[^A-Za-z0-9_]+", "_", str(value or "").strip())
|
|
normalized = re.sub(r"_+", "_", normalized).strip("_").lower()
|
|
return normalized or "default"
|
|
|
|
def _validate_conversation_id(self, conversation_id: str | None) -> None:
|
|
if not str(conversation_id or "").strip() or str(conversation_id or "").strip() == "-1":
|
|
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "会话ID不能为空")
|
|
|
|
def _ensure_aware_datetime(self, value: Any) -> datetime:
|
|
if isinstance(value, datetime):
|
|
if value.tzinfo is None:
|
|
return value.replace(tzinfo=timezone.utc)
|
|
return value
|
|
if isinstance(value, str):
|
|
parsed = datetime.fromisoformat(value.replace("Z", "+00:00"))
|
|
if parsed.tzinfo is None:
|
|
return parsed.replace(tzinfo=timezone.utc)
|
|
return parsed
|
|
raise LeauditException(StatusCodeEnum.HTTP_400_BAD_REQUEST, "临时附件时间字段异常")
|
|
|
|
def _timestamp(self, value: Any) -> int:
|
|
if not value:
|
|
return 0
|
|
return int(self._ensure_aware_datetime(value).timestamp())
|