Merge pull request 'fix: restore rag chat retrieval sources and follow-up metadata' (#3) from fix/chat-rag-retrieval-and-history into main
Merge pull request #3 from fix/chat-rag-retrieval-and-history fix: restore rag chat retrieval sources and follow-up metadata
This commit was merged in pull request #3.
This commit is contained in:
@@ -34,6 +34,7 @@ class RagMessageItemVO(BaseModel):
|
|||||||
answer: str = Field(...)
|
answer: str = Field(...)
|
||||||
feedback: dict | None = Field(None)
|
feedback: dict | None = Field(None)
|
||||||
retrieverResources: list[dict] | None = Field(None)
|
retrieverResources: list[dict] | None = Field(None)
|
||||||
|
suggestedQuestions: list[str] = Field(default_factory=list)
|
||||||
createdAt: int = Field(0)
|
createdAt: int = Field(0)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import json
|
|||||||
import uuid
|
import uuid
|
||||||
from typing import AsyncGenerator
|
from typing import AsyncGenerator
|
||||||
|
|
||||||
|
import httpx
|
||||||
from sqlalchemy import text
|
from sqlalchemy import text
|
||||||
|
|
||||||
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession
|
||||||
@@ -25,6 +26,7 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import (
|
|||||||
RagMessagePageVO,
|
RagMessagePageVO,
|
||||||
RagOperationResultVO,
|
RagOperationResultVO,
|
||||||
)
|
)
|
||||||
|
from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG
|
||||||
from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream
|
from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream
|
||||||
from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups
|
from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups
|
||||||
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
|
from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService
|
||||||
@@ -194,7 +196,7 @@ class RagChatServiceImpl(IRagChatService):
|
|||||||
await session.execute(
|
await session.execute(
|
||||||
text(
|
text(
|
||||||
"""
|
"""
|
||||||
SELECT message_id, role, content, sources, feedback, created_at
|
SELECT message_id, role, content, sources, metadata, feedback, created_at
|
||||||
FROM rag_message
|
FROM rag_message
|
||||||
WHERE conversation_id = :conversation_id
|
WHERE conversation_id = :conversation_id
|
||||||
ORDER BY created_at ASC
|
ORDER BY created_at ASC
|
||||||
@@ -216,6 +218,11 @@ class RagChatServiceImpl(IRagChatService):
|
|||||||
row = items[idx]
|
row = items[idx]
|
||||||
if row["role"] == "user":
|
if row["role"] == "user":
|
||||||
answer = items[idx + 1] if idx + 1 < len(items) and items[idx + 1]["role"] == "assistant" else None
|
answer = items[idx + 1] if idx + 1 < len(items) and items[idx + 1]["role"] == "assistant" else None
|
||||||
|
answer_sources = self._parse_json_field(answer.get("sources")) if answer else []
|
||||||
|
answer_metadata = self._parse_json_field(answer.get("metadata")) if answer else {}
|
||||||
|
suggested_questions = answer_metadata.get("suggested_questions") if isinstance(answer_metadata, dict) else []
|
||||||
|
if not isinstance(suggested_questions, list):
|
||||||
|
suggested_questions = []
|
||||||
data.append(
|
data.append(
|
||||||
RagMessageItemVO(
|
RagMessageItemVO(
|
||||||
id=(answer["message_id"] if answer else row["message_id"]),
|
id=(answer["message_id"] if answer else row["message_id"]),
|
||||||
@@ -223,7 +230,8 @@ class RagChatServiceImpl(IRagChatService):
|
|||||||
query=row["content"],
|
query=row["content"],
|
||||||
answer=answer["content"] if answer else "",
|
answer=answer["content"] if answer else "",
|
||||||
feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None),
|
feedback=({"rating": answer["feedback"]} if answer and answer.get("feedback") else None),
|
||||||
retrieverResources=(answer.get("sources") if answer else None),
|
retrieverResources=answer_sources or None,
|
||||||
|
suggestedQuestions=[str(item) for item in suggested_questions],
|
||||||
createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
|
createdAt=int(row["created_at"].timestamp()) if row.get("created_at") else 0,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -392,6 +400,18 @@ class RagChatServiceImpl(IRagChatService):
|
|||||||
area = row.get("area") or ""
|
area = row.get("area") or ""
|
||||||
return area in ("", "省级", user_area or "") or bool(row.get("dataset_public"))
|
return area in ("", "省级", user_area or "") or bool(row.get("dataset_public"))
|
||||||
|
|
||||||
|
def _parse_json_field(self, value):
|
||||||
|
if value is None:
|
||||||
|
return {}
|
||||||
|
if isinstance(value, (dict, list)):
|
||||||
|
return value
|
||||||
|
if isinstance(value, str):
|
||||||
|
try:
|
||||||
|
return json.loads(value)
|
||||||
|
except Exception:
|
||||||
|
return {}
|
||||||
|
return {}
|
||||||
|
|
||||||
async def _ensure_conversation(self, user_id: int, conversation_id: str | None, app_id: int | None) -> str:
|
async def _ensure_conversation(self, user_id: int, conversation_id: str | None, app_id: int | None) -> str:
|
||||||
if conversation_id and conversation_id != "-1":
|
if conversation_id and conversation_id != "-1":
|
||||||
async with GetAsyncSession() as session:
|
async with GetAsyncSession() as session:
|
||||||
@@ -450,7 +470,7 @@ class RagChatServiceImpl(IRagChatService):
|
|||||||
await session.execute(
|
await session.execute(
|
||||||
text(
|
text(
|
||||||
"""
|
"""
|
||||||
SELECT id, name, collection_name, retrieval_model
|
SELECT id, name, collection_name, retrieval_model, embedding_model
|
||||||
FROM rag_dataset
|
FROM rag_dataset
|
||||||
WHERE id = :dataset_id AND deleted_at IS NULL
|
WHERE id = :dataset_id AND deleted_at IS NULL
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
@@ -475,7 +495,12 @@ class RagChatServiceImpl(IRagChatService):
|
|||||||
return [], dataset.get("name") or ""
|
return [], dataset.get("name") or ""
|
||||||
try:
|
try:
|
||||||
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
|
collection = get_chroma().get_or_create_collection(dataset["collection_name"])
|
||||||
result = collection.query(query_texts=[query], n_results=max(top_k, 1))
|
query_embedding = await self._embed_texts([query], dataset.get("embedding_model") or "")
|
||||||
|
result = collection.query(
|
||||||
|
query_embeddings=query_embedding,
|
||||||
|
n_results=max(top_k, 1),
|
||||||
|
include=["documents", "metadatas", "distances"],
|
||||||
|
)
|
||||||
docs = (result.get("documents") or [[]])[0]
|
docs = (result.get("documents") or [[]])[0]
|
||||||
metas = (result.get("metadatas") or [[]])[0]
|
metas = (result.get("metadatas") or [[]])[0]
|
||||||
distances = (result.get("distances") or [[]])[0]
|
distances = (result.get("distances") or [[]])[0]
|
||||||
@@ -483,7 +508,8 @@ class RagChatServiceImpl(IRagChatService):
|
|||||||
for idx, doc in enumerate(docs):
|
for idx, doc in enumerate(docs):
|
||||||
meta = metas[idx] if idx < len(metas) else {}
|
meta = metas[idx] if idx < len(metas) else {}
|
||||||
dist = distances[idx] if idx < len(distances) else 0.0
|
dist = distances[idx] if idx < len(distances) else 0.0
|
||||||
score = 1 - float(dist or 0.0)
|
distance = max(0.0, float(dist or 0.0))
|
||||||
|
score = 1.0 / (1.0 + distance)
|
||||||
if score_threshold is not None and score < score_threshold:
|
if score_threshold is not None and score < score_threshold:
|
||||||
continue
|
continue
|
||||||
chunks.append(
|
chunks.append(
|
||||||
@@ -501,6 +527,46 @@ class RagChatServiceImpl(IRagChatService):
|
|||||||
except Exception:
|
except Exception:
|
||||||
return [], dataset.get("name") or ""
|
return [], dataset.get("name") or ""
|
||||||
|
|
||||||
|
async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]:
    """Embed ``texts`` through the configured HTTP embeddings endpoint.

    The endpoint, API key, fallback model and batch size are read from
    ``RAG_CONFIG``. Texts are sent in batches as
    ``{"model": ..., "input": [...]}`` and vectors are read from
    ``data[*].embedding`` in the JSON response.

    Args:
        texts: Strings to embed. An empty list yields an empty result.
        model_name: Embedding model to request; when empty, falls back to
            ``RAG_CONFIG["EMBED_MODEL"]`` and then ``"text-embedding-v4"``.

    Returns:
        One embedding per input text, in input order.

    Raises:
        LeauditException: If no endpoint/key is configured, the request
            fails (HTTP error status, transport error, or non-JSON body),
            or the number of returned vectors does not match the batch.
    """
    embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or f"{RAG_CONFIG['LLM_BASE_URL'].rstrip('/')}/embeddings"
    embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"]
    embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4"
    batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10))
    if not embed_url or not embed_key:
        raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务")

    # Headers do not change between batches, so build them once.
    headers = {
        "Content-Type": "application/json",
        "Authorization": f"Bearer {embed_key}",
    }

    embeddings: list[list[float]] = []
    async with httpx.AsyncClient(timeout=120.0) as client:
        for start in range(0, len(texts), batch_size):
            batch_texts = texts[start:start + batch_size]
            try:
                response = await client.post(
                    embed_url,
                    headers=headers,
                    json={"model": embed_model, "input": batch_texts},
                )
                response.raise_for_status()
                payload = response.json()
            except httpx.HTTPStatusError as exc:
                error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}"
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc
            except httpx.RequestError as exc:
                # Timeouts / connection failures carry no response; surface
                # them with the same domain error as HTTP status failures.
                error_message = str(exc).strip() or type(exc).__name__
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    f"向量化服务调用失败: {error_message[:300]}",
                ) from exc
            except ValueError as exc:
                # response.json() could not decode the body.
                raise LeauditException(
                    StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR,
                    "向量化服务调用失败: 响应不是合法的 JSON",
                ) from exc

            rows = payload.get("data") if isinstance(payload, dict) else None
            if not isinstance(rows, list):
                rows = []
            valid_rows = [row for row in rows if isinstance(row, dict) and row.get("embedding")]
            if len(valid_rows) != len(batch_texts):
                raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常")

            # When every row carries an integer "index" (as OpenAI-compatible
            # services do), restore input order explicitly; otherwise the
            # response order is trusted as-is.
            if all(isinstance(row.get("index"), int) for row in valid_rows):
                valid_rows.sort(key=lambda row: row["index"])
            embeddings.extend(row["embedding"] for row in valid_rows)

    # Each batch was verified above, so the total already matches len(texts).
    return embeddings
|
||||||
|
|
||||||
def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]:
|
def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]:
|
||||||
return [
|
return [
|
||||||
{
|
{
|
||||||
|
|||||||
@@ -1186,7 +1186,7 @@ class RagDatasetServiceImpl(IRagDatasetService):
|
|||||||
content = documents[index] if index < len(documents) else ""
|
content = documents[index] if index < len(documents) else ""
|
||||||
metadata = metadatas[index] if index < len(metadatas) and isinstance(metadatas[index], dict) else {}
|
metadata = metadatas[index] if index < len(metadatas) and isinstance(metadatas[index], dict) else {}
|
||||||
distance = float(distances[index]) if index < len(distances) and distances[index] is not None else 1.0
|
distance = float(distances[index]) if index < len(distances) and distances[index] is not None else 1.0
|
||||||
score = max(0.0, min(1.0, 1.0 - distance))
|
score = max(0.0, min(1.0, 1.0 / (1.0 + max(0.0, distance))))
|
||||||
if score_threshold_enabled and score_threshold is not None and score < score_threshold:
|
if score_threshold_enabled and score_threshold is not None and score < score_threshold:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user