diff --git a/fastapi_modules/fastapi_leaudit/leaudit_bridge/pipeline.py b/fastapi_modules/fastapi_leaudit/leaudit_bridge/pipeline.py index aa82ce7..2fafba0 100644 --- a/fastapi_modules/fastapi_leaudit/leaudit_bridge/pipeline.py +++ b/fastapi_modules/fastapi_leaudit/leaudit_bridge/pipeline.py @@ -24,6 +24,7 @@ from leaudit.ocr.base import BaseOCRClient from leaudit.ocr.models import OcrResult from fastapi_modules.fastapi_leaudit.leaudit_bridge.storage_adapter import StorageAdapter +from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever log = logging.getLogger(__name__) @@ -89,10 +90,12 @@ class LauditPipeline: ocr_client: BaseOCRClient, llm_client: BaseLLMClient | None = None, storage_adapter: StorageAdapter | None = None, + rag_retriever: RagRetriever | None = None, ) -> None: self.ocr_client = ocr_client self.llm_client = llm_client self.storage = storage_adapter or StorageAdapter() + self.rag_retriever = rag_retriever or RagRetriever() async def run( self, @@ -219,6 +222,7 @@ class LauditPipeline: visual_manifest=visual_manifest, phase=detected_phase, external_mocks=external_mocks, + retriever=self.rag_retriever, ) timing["evaluation"] = round(time.time() - t0, 2) log.info( diff --git a/fastapi_modules/fastapi_leaudit/rag_engine/retriever.py b/fastapi_modules/fastapi_leaudit/rag_engine/retriever.py new file mode 100644 index 0000000..a63f382 --- /dev/null +++ b/fastapi_modules/fastapi_leaudit/rag_engine/retriever.py @@ -0,0 +1,470 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass, field +from typing import Any, Awaitable, Callable + +import httpx +from sqlalchemy import text + +from fastapi_common.fastapi_common_sqlalchemy.database import GetAsyncSession +from fastapi_common.fastapi_common_web.domain.responses import StatusCodeEnum +from fastapi_common.fastapi_common_web.exception.LeauditException import LeauditException +from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma +from fastapi_modules.fastapi_leaudit.rag_engine.config import RAG_CONFIG, build_openai_embeddings_url + + +EmbedTexts = Callable[[list[str], str], Awaitable[list[list[float]]] | list[list[float]]] + + +@dataclass +class RagRetrieveResult: + rag_context: str = "" + rag_resources: list[dict[str, Any]] = field(default_factory=list) + chunks: list[dict[str, Any]] = field(default_factory=list) + dataset_name: str = "" + + def as_dict(self) -> dict[str, Any]: + return { + "rag_context": self.rag_context, + "rag_resources": self.rag_resources, + } + + +class RagRetriever: + def __init__( + self, + *, + chroma_client: Any | None = None, + embed_texts: EmbedTexts | None = None, + hydrate_documents: bool = True, + ) -> None: + self._chroma_client = chroma_client + self._embed_texts_override = embed_texts + self._hydrate_documents_enabled = hydrate_documents + + async def retrieve( + self, + query: str, + collection_name: str | None = None, + dataset_id: int | None = None, + top_k: int = 5, + source_names: list[str] | None = None, + ) -> RagRetrieveResult: + query_text = (query or "").strip() + if not query_text: + return RagRetrieveResult() + + top_k = max(1, int(top_k or 5)) + dataset = await self._load_dataset(dataset_id) if dataset_id else None + resolved_collection = (collection_name or (dataset or {}).get("collection_name") or "").strip() + if not resolved_collection: + return RagRetrieveResult(dataset_name=str((dataset or {}).get("name") or "")) + + retrieval_model = (dataset or {}).get("retrieval_model") or {} + if dataset and not collection_name: + top_k = max(1, int(retrieval_model.get("top_k") or top_k)) + score_threshold = self._resolve_score_threshold(retrieval_model) + dataset_name = str((dataset or {}).get("name") or "") + embedding_model = str((dataset or {}).get("embedding_model") or "") + + chunks: list[dict[str, Any]] = [] + try: + chunks = await self._vector_retrieve( + query=query_text, + collection_name=resolved_collection, + dataset_name=dataset_name, + embedding_model=embedding_model, + top_k=top_k, + score_threshold=score_threshold, + source_names=source_names, + ) + except Exception: + chunks = [] + + if not chunks: + try: + chunks = await self._keyword_retrieve_context( + dataset_id=dataset_id, + collection_name=resolved_collection, + dataset_name=dataset_name, + query=query_text, + top_k=top_k, + score_threshold=score_threshold, + source_names=source_names, + ) + except Exception: + chunks = [] + + if dataset_id and self._hydrate_documents_enabled: + chunks = await self._hydrate_document_hits(dataset_id, chunks) + + chunks = chunks[:top_k] + return RagRetrieveResult( + rag_context=self._build_context(chunks), + rag_resources=self.build_sources(chunks, dataset_name), + chunks=chunks, + dataset_name=dataset_name, + ) + + async def _load_dataset(self, dataset_id: int | None) -> dict[str, Any] | None: + if not dataset_id: + return None + async with GetAsyncSession() as session: + row = ( + await session.execute( + text( + """ + SELECT id, name, collection_name, retrieval_model, embedding_model + FROM rag_dataset + WHERE id = :dataset_id AND deleted_at IS NULL + LIMIT 1 + """ + ), + {"dataset_id": dataset_id}, + ) + ).mappings().first() + return dict(row) if row else None + + def _resolve_score_threshold(self, retrieval_model: dict[str, Any]) -> float | None: + if not retrieval_model.get("score_threshold_enabled"): + return None + try: + return float(retrieval_model.get("score_threshold")) + except (TypeError, ValueError): + return None + + async def _vector_retrieve( + self, + *, + query: str, + collection_name: str, + dataset_name: str, + embedding_model: str, + top_k: int, + score_threshold: float | None, + source_names: list[str] | None, + ) -> list[dict[str, Any]]: + query_embedding = await self._embed_texts([query], embedding_model) + collection = self._get_chroma().get_or_create_collection(collection_name) + result = collection.query( + query_embeddings=query_embedding, + n_results=max(top_k, 1), + include=["documents", "metadatas", "distances"], + ) + ids = (result.get("ids") or [[]])[0] if result.get("ids") else [] + docs = (result.get("documents") or [[]])[0] + metas = (result.get("metadatas") or [[]])[0] + distances = (result.get("distances") or [[]])[0] + allowed_sources = self._normalize_source_filter(source_names) + chunks: list[dict[str, Any]] = [] + for idx, doc in enumerate(docs): + meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {} + document_name = str(meta.get("document_name") or meta.get("source") or "") + if allowed_sources and document_name not in allowed_sources: + continue + dist = float(distances[idx]) if idx < len(distances) and distances[idx] is not None else 1.0 + score = 1.0 / (1.0 + max(dist, 0.0)) + if score_threshold is not None and score < score_threshold: + continue + chunks.append( + { + "id": str(ids[idx] if idx < len(ids) else meta.get("id") or idx), + "text": doc or "", + "source": meta.get("source") or meta.get("document_name") or dataset_name, + "score": score, + "chunk_index": int(meta.get("chunk_index") or idx), + "document_name": document_name, + "document_id": meta.get("document_id"), + "page": meta.get("page"), + } + ) + return chunks + + async def _embed_texts(self, texts: list[str], model_name: str = "") -> list[list[float]]: + if self._embed_texts_override is not None: + result = self._embed_texts_override(texts, model_name) + if hasattr(result, "__await__"): + return await result # type: ignore[misc] + return result + + embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"]) + embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"] + embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4" + batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10)) + if not embed_url or not embed_key: + raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务") + + embeddings: list[list[float]] = [] + async with httpx.AsyncClient(timeout=120.0) as client: + for start in range(0, len(texts), batch_size): + batch_texts = texts[start:start + batch_size] + try: + response = await client.post( + embed_url, + headers={ + "Content-Type": "application/json", + "Authorization": f"Bearer {embed_key}", + }, + json={"model": embed_model, "input": batch_texts}, + ) + response.raise_for_status() + except httpx.HTTPStatusError as exc: + error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}" + raise LeauditException( + StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, + f"向量化服务调用失败: {error_message[:300]}", + ) from exc + + payload = response.json() + rows = payload.get("data") or [] + batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")] + if len(batch_embeddings) != len(batch_texts): + raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") + embeddings.extend(batch_embeddings) + + if len(embeddings) != len(texts): + raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") + return embeddings + + async def _keyword_retrieve_context( + self, + *, + dataset_id: int | None, + collection_name: str, + dataset_name: str, + query: str, + top_k: int, + score_threshold: float | None, + source_names: list[str] | None, + ) -> list[dict[str, Any]]: + collection = self._get_chroma().get_or_create_collection(collection_name) + raw = collection.get(include=["documents", "metadatas"]) + ids = raw.get("ids") or [] + docs = raw.get("documents") or [] + metas = raw.get("metadatas") or [] + allowed_sources = self._normalize_source_filter(source_names) + + terms = self._build_keyword_terms(query) + if not terms: + return [] + + scored_chunks: list[dict[str, Any]] = [] + for idx, chunk_id in enumerate(ids): + doc = docs[idx] if idx < len(docs) else "" + meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {} + document_name = str(meta.get("document_name") or meta.get("source") or "") + if allowed_sources and document_name not in allowed_sources: + continue + score = self._score_keyword_chunk( + query=query, + terms=terms, + content=doc or "", + document_name=document_name, + ) + if score <= 0: + continue + if score_threshold is not None and score < score_threshold: + continue + scored_chunks.append( + { + "id": str(chunk_id), + "text": doc or "", + "source": meta.get("source") or meta.get("document_name") or dataset_name, + "score": score, + "chunk_index": int(meta.get("chunk_index") or idx), + "document_name": document_name, + "document_id": meta.get("document_id"), + "page": meta.get("page"), + } + ) + + scored_chunks.sort(key=lambda item: (-float(item.get("score") or 0.0), int(item.get("chunk_index") or 0))) + return scored_chunks[: max(top_k * 3, top_k)] + + def _build_keyword_terms(self, query: str) -> list[str]: + normalized = self._normalize_keyword_query(query) + spans = [item.strip() for item in re.findall(r"[\u4e00-\u9fffA-Za-z0-9]+", normalized) if item.strip()] + if not spans: + return [] + + stop_terms = { + "什么", + "请问", + "一下", + "有关", + "关于", + "如何", + "哪些", + "怎么", + "是否", + "规定", + "办法", + "条例", + "法律", + } + terms: list[str] = [] + for span in spans: + if span in stop_terms: + continue + terms.append(span) + if re.fullmatch(r"[\u4e00-\u9fff]+", span): + for size in (2, 3, 4): + if len(span) > size: + for start in range(0, len(span) - size + 1): + token = span[start:start + size] + if token not in stop_terms: + terms.append(token) + + unique_terms: list[str] = [] + seen: set[str] = set() + for term in sorted(terms, key=len, reverse=True): + if term and term not in seen: + unique_terms.append(term) + seen.add(term) + return unique_terms[:20] + + def _normalize_keyword_query(self, query: str) -> str: + normalized = (query or "").strip().lower() + patterns = [ + "是什么", + "什么是", + "有哪些", + "有什么", + "是什么?", + "是什么?", + "请问", + "介绍一下", + "解释一下", + "帮我分析", + "帮我看看", + ] + for pattern in patterns: + normalized = normalized.replace(pattern, " ") + return re.sub(r"\s+", " ", normalized).strip() + + def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float: + haystack = f"{document_name}\n{content}".lower() + if not haystack: + return 0.0 + + exact_query = self._normalize_keyword_query(query) + if exact_query and exact_query in haystack: + return 0.98 + + matched_weight = 0.0 + total_weight = 0.0 + name_bonus = 0.0 + for term in terms: + weight = float(max(len(term), 1) ** 2) + total_weight += weight + if term.lower() in haystack: + matched_weight += weight + if term.lower() in document_name.lower(): + name_bonus += min(0.15, 0.03 * len(term)) + + if total_weight <= 0: + return 0.0 + score = (matched_weight / total_weight) + name_bonus + return round(min(score, 0.99), 6) + + def _build_context(self, chunks: list[dict[str, Any]]) -> str: + lines: list[str] = [] + for index, chunk in enumerate(chunks, start=1): + document_name = chunk.get("document_name") or chunk.get("source") or "未知来源" + text_value = str(chunk.get("text") or "").strip() + if not text_value: + continue + lines.append(f"[{index}] 来源:{document_name}\n{text_value}") + return "\n\n".join(lines) + + def build_sources(self, context_chunks: list[dict[str, Any]], dataset_name: str = "") -> list[dict[str, Any]]: + return [ + { + "position": index + 1, + "dataset_id": str(chunk.get("dataset_id") or ""), + "dataset_name": dataset_name, + "document_id": str(chunk.get("document_id") or ""), + "document_name": chunk.get("document_name") or chunk.get("source", ""), + "data_source_type": "upload_file", + "segment_id": chunk.get("id", ""), + "retriever_from": "rag", + "score": round(float(chunk.get("score") or 0.0), 4), + "hit_count": chunk.get("hit_count", 0), + "word_count": len(chunk.get("text", "")), + "segment_position": index + 1, + "index_node_hash": "", + "content": chunk.get("text", "")[:500], + "page": None, + } + for index, chunk in enumerate(context_chunks) + ] + + async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict[str, Any]]) -> list[dict[str, Any]]: + source_names = sorted( + { + str(chunk.get("document_name") or chunk.get("source") or "").strip() + for chunk in chunks + if str(chunk.get("document_name") or chunk.get("source") or "").strip() + } + ) + if not source_names: + return chunks + + async with GetAsyncSession() as session: + rows = ( + await session.execute( + text( + """ + SELECT id, original_name, enabled, hit_count + FROM rag_document + WHERE dataset_id = :dataset_id + AND deleted_at IS NULL + AND original_name = ANY(:source_names) + """ + ), + { + "dataset_id": dataset_id, + "source_names": source_names, + }, + ) + ).mappings().all() + + document_map = {str(row["original_name"]): row for row in rows} + visible_chunks: list[dict[str, Any]] = [] + hit_document_ids: list[int] = [] + for chunk in chunks: + source_name = str(chunk.get("document_name") or chunk.get("source") or "").strip() + document = document_map.get(source_name) + if document and not bool(document.get("enabled")): + continue + if document: + chunk["document_id"] = document["id"] + chunk["dataset_id"] = dataset_id + chunk["document_name"] = document["original_name"] + chunk["hit_count"] = document.get("hit_count") or 0 + hit_document_ids.append(int(document["id"])) + visible_chunks.append(chunk) + + if hit_document_ids: + async with GetAsyncSession() as session: + async with session.begin(): + await session.execute( + text( + """ + UPDATE rag_document + SET hit_count = hit_count + 1, + updated_at = NOW() + WHERE id = ANY(:document_ids) + """ + ), + {"document_ids": sorted(set(hit_document_ids))}, + ) + + return visible_chunks + + def _normalize_source_filter(self, source_names: list[str] | None) -> set[str]: + return {str(name).strip() for name in (source_names or []) if str(name).strip()} + + def _get_chroma(self) -> Any: + return self._chroma_client or get_chroma() diff --git a/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py b/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py index bcb8641..dc66a34 100644 --- a/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py +++ b/fastapi_modules/fastapi_leaudit/services/impl/ragChatServiceImpl.py @@ -33,11 +33,10 @@ from fastapi_modules.fastapi_leaudit.domian.vo.ragChatVo import ( from fastapi_modules.fastapi_leaudit.rag_engine.config import ( RAG_CONFIG, build_openai_chat_completions_url, - build_openai_embeddings_url, ) -from fastapi_modules.fastapi_leaudit.rag_engine.chroma_client import get_chroma from fastapi_modules.fastapi_leaudit.rag_engine.generator import generate_stream from fastapi_modules.fastapi_leaudit.rag_engine.question_chains import generate_followups +from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever from fastapi_modules.fastapi_leaudit.services.ragChatService import IRagChatService @@ -54,6 +53,9 @@ class RagChatServiceImpl(IRagChatService): _task_locks: dict[str, asyncio.Lock] = {} _title_tasks: dict[str, asyncio.Task] = {} + def __init__(self, retriever: RagRetriever | None = None) -> None: + self.retriever = retriever or RagRetriever() + async def GetApps(self, CurrentUserId: int, UserArea: str | None, UserRole: str | None) -> RagChatAppListVO: apps = await self._load_apps(UserArea, UserRole, only_default=False) return RagChatAppListVO(data=apps, total=len(apps)) @@ -592,121 +594,11 @@ class RagChatServiceImpl(IRagChatService): raise LeauditException(StatusCodeEnum.HTTP_403_FORBIDDEN, "当前用户无权访问该会话") async def _retrieve_context(self, dataset_id: int | None, query: str) -> tuple[list[dict], str]: - if not dataset_id: - return [], "" - async with GetAsyncSession() as session: - dataset = ( - await session.execute( - text( - """ - SELECT id, name, collection_name, retrieval_model, embedding_model - FROM rag_dataset - WHERE id = :dataset_id AND deleted_at IS NULL - LIMIT 1 - """ - ), - {"dataset_id": dataset_id}, - ) - ).mappings().first() - if not dataset: - return [], "" - retrieval_model = dataset.get("retrieval_model") or {} - top_k = int(retrieval_model.get("top_k") or 5) - score_threshold = None - if retrieval_model.get("score_threshold_enabled"): - try: - score_threshold = float(retrieval_model.get("score_threshold")) - except (TypeError, ValueError): - score_threshold = None - try: - query_embedding = await self._embed_texts([query], dataset.get("embedding_model") or "") - collection = get_chroma().get_or_create_collection(dataset["collection_name"]) - result = collection.query( - query_embeddings=query_embedding, - n_results=max(top_k, 1), - include=["documents", "metadatas", "distances"], - ) - ids = (result.get("ids") or [[]])[0] if result.get("ids") else [] - docs = (result.get("documents") or [[]])[0] - metas = (result.get("metadatas") or [[]])[0] - distances = (result.get("distances") or [[]])[0] - chunks: list[dict] = [] - for idx, doc in enumerate(docs): - meta = metas[idx] if idx < len(metas) else {} - dist = float(distances[idx]) if idx < len(distances) and distances[idx] is not None else 1.0 - score = 1.0 / (1.0 + max(dist, 0.0)) - if score_threshold is not None and score < score_threshold: - continue - chunks.append( - { - "id": str(ids[idx] if idx < len(ids) else meta.get("id") or idx), - "text": doc, - "source": meta.get("source") or meta.get("document_name") or dataset.get("name") or "", - "score": score, - "chunk_index": int(meta.get("chunk_index") or idx), - "document_name": meta.get("document_name") or meta.get("source") or "", - "document_id": meta.get("document_id"), - "page": meta.get("page"), - } - ) - chunks = await self._hydrate_document_hits(dataset_id, chunks) - if chunks: - return chunks[:top_k], dataset.get("name") or "" - except Exception: - pass - - try: - chunks = await self._keyword_retrieve_context( - dataset_id=dataset_id, - collection_name=str(dataset["collection_name"]), - dataset_name=str(dataset.get("name") or ""), - query=query, - top_k=top_k, - score_threshold=score_threshold, - ) - return chunks[:top_k], dataset.get("name") or "" - except Exception: - return [], dataset.get("name") or "" + result = await self.retriever.retrieve(query=query, dataset_id=dataset_id) + return result.chunks, result.dataset_name async def _embed_texts(self, texts: list[str], model_name: str) -> list[list[float]]: - embed_url = (RAG_CONFIG.get("EMBED_URL") or "").strip() or build_openai_embeddings_url(RAG_CONFIG["LLM_BASE_URL"]) - embed_key = (RAG_CONFIG.get("EMBED_KEY") or "").strip() or RAG_CONFIG["LLM_API_KEY"] - embed_model = model_name or (RAG_CONFIG.get("EMBED_MODEL") or "").strip() or "text-embedding-v4" - batch_size = max(1, int(RAG_CONFIG.get("EMBED_BATCH_SIZE") or 10)) - if not embed_url or not embed_key: - raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "未配置可用的向量化服务") - - embeddings: list[list[float]] = [] - async with httpx.AsyncClient(timeout=120.0) as client: - for start in range(0, len(texts), batch_size): - batch_texts = texts[start:start + batch_size] - try: - response = await client.post( - embed_url, - headers={ - "Content-Type": "application/json", - "Authorization": f"Bearer {embed_key}", - }, - json={"model": embed_model, "input": batch_texts}, - ) - response.raise_for_status() - except httpx.HTTPStatusError as exc: - error_message = exc.response.text.strip() or f"{exc.response.status_code} {exc.response.reason_phrase}" - raise LeauditException( - StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, - f"向量化服务调用失败: {error_message[:300]}", - ) from exc - - payload = response.json() - rows = payload.get("data") or [] - batch_embeddings = [row.get("embedding") for row in rows if isinstance(row, dict) and row.get("embedding")] - if len(batch_embeddings) != len(batch_texts): - raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") - embeddings.extend(batch_embeddings) - - if len(embeddings) != len(texts): - raise LeauditException(StatusCodeEnum.HTTP_500_INTERNAL_SERVER_ERROR, "向量化结果数量异常") - return embeddings + return await self.retriever._embed_texts(texts, model_name) async def _start_message_task( self, @@ -1219,220 +1111,42 @@ class RagChatServiceImpl(IRagChatService): top_k: int, score_threshold: float | None, ) -> list[dict]: - collection = get_chroma().get_or_create_collection(collection_name) - raw = collection.get(include=["documents", "metadatas"]) - ids = raw.get("ids") or [] - docs = raw.get("documents") or [] - metas = raw.get("metadatas") or [] - - terms = self._build_keyword_terms(query) - if not terms: - return [] - - scored_chunks: list[dict] = [] - for idx, chunk_id in enumerate(ids): - doc = docs[idx] if idx < len(docs) else "" - meta = metas[idx] if idx < len(metas) and isinstance(metas[idx], dict) else {} - score = self._score_keyword_chunk( - query=query, - terms=terms, - content=doc or "", - document_name=str(meta.get("document_name") or meta.get("source") or ""), - ) - if score <= 0: - continue - if score_threshold is not None and score < score_threshold: - continue - scored_chunks.append( - { - "id": str(chunk_id), - "text": doc or "", - "source": meta.get("source") or meta.get("document_name") or dataset_name, - "score": score, - "chunk_index": int(meta.get("chunk_index") or idx), - "document_name": meta.get("document_name") or meta.get("source") or "", - "document_id": meta.get("document_id"), - "page": meta.get("page"), - } - ) - - scored_chunks.sort(key=lambda item: (-float(item.get("score") or 0.0), int(item.get("chunk_index") or 0))) - hydrated = await self._hydrate_document_hits(dataset_id, scored_chunks[: max(top_k * 3, top_k)]) - return hydrated[:top_k] + chunks = await self.retriever._keyword_retrieve_context( + dataset_id=dataset_id, + collection_name=collection_name, + dataset_name=dataset_name, + query=query, + top_k=top_k, + score_threshold=score_threshold, + source_names=None, + ) + return chunks[:top_k] def _build_keyword_terms(self, query: str) -> list[str]: - normalized = self._normalize_keyword_query(query) - spans = [item.strip() for item in re.findall(r"[\u4e00-\u9fffA-Za-z0-9]+", normalized) if item.strip()] - if not spans: - return [] - - stop_terms = { - "什么", - "请问", - "一下", - "有关", - "关于", - "如何", - "哪些", - "怎么", - "是否", - "规定", - "办法", - "条例", - "法律", - } - terms: list[str] = [] - for span in spans: - if span in stop_terms: - continue - terms.append(span) - if re.fullmatch(r"[\u4e00-\u9fff]+", span): - for size in (2, 3, 4): - if len(span) > size: - for start in range(0, len(span) - size + 1): - token = span[start:start + size] - if token not in stop_terms: - terms.append(token) - - unique_terms: list[str] = [] - seen: set[str] = set() - for term in sorted(terms, key=len, reverse=True): - if term and term not in seen: - unique_terms.append(term) - seen.add(term) - return unique_terms[:20] + return self.retriever._build_keyword_terms(query) def _normalize_keyword_query(self, query: str) -> str: - normalized = (query or "").strip().lower() - patterns = [ - "是什么", - "什么是", - "有哪些", - "有什么", - "是什么?", - "是什么?", - "请问", - "介绍一下", - "解释一下", - "帮我分析", - "帮我看看", - ] - for pattern in patterns: - normalized = normalized.replace(pattern, " ") - return re.sub(r"\s+", " ", normalized).strip() + return self.retriever._normalize_keyword_query(query) def _score_keyword_chunk(self, *, query: str, terms: list[str], content: str, document_name: str) -> float: - haystack = f"{document_name}\n{content}".lower() - if not haystack: - return 0.0 - - exact_query = self._normalize_keyword_query(query) - if exact_query and exact_query in haystack: - return 0.98 - - matched_weight = 0.0 - total_weight = 0.0 - name_bonus = 0.0 - for term in terms: - weight = float(max(len(term), 1) ** 2) - total_weight += weight - if term.lower() in haystack: - matched_weight += weight - if term.lower() in document_name.lower(): - name_bonus += min(0.15, 0.03 * len(term)) - - if total_weight <= 0: - return 0.0 - score = (matched_weight / total_weight) + name_bonus - return round(min(score, 0.99), 6) + return self.retriever._score_keyword_chunk( + query=query, + terms=terms, + content=content, + document_name=document_name, + ) def _format_sse(self, payload: dict) -> bytes: return f"data: {json.dumps(payload, ensure_ascii=False)}\n\n".encode("utf-8") def _build_sources(self, context_chunks: list[dict], dataset_name: str) -> list[dict]: - return [ - { - "position": index + 1, - "dataset_id": str(chunk.get("dataset_id") or ""), - "dataset_name": dataset_name, - "document_id": str(chunk.get("document_id") or ""), - "document_name": chunk.get("document_name") or chunk.get("source", ""), - "data_source_type": "upload_file", - "segment_id": chunk.get("id", ""), - "retriever_from": "rag", - "score": round(chunk.get("score", 0.0), 4), - "hit_count": chunk.get("hit_count", 0), - "word_count": len(chunk.get("text", "")), - "segment_position": index + 1, - "index_node_hash": "", - "content": chunk.get("text", "")[:500], - "page": None, - } - for index, chunk in enumerate(context_chunks) - ] + build_sources = getattr(self.retriever, "build_sources", None) + if callable(build_sources): + return build_sources(context_chunks, dataset_name) + return RagRetriever(hydrate_documents=False).build_sources(context_chunks, dataset_name) async def _hydrate_document_hits(self, dataset_id: int, chunks: list[dict]) -> list[dict]: - source_names = sorted( - { - str(chunk.get("document_name") or chunk.get("source") or "").strip() - for chunk in chunks - if str(chunk.get("document_name") or chunk.get("source") or "").strip() - } - ) - if not source_names: - return chunks - - async with GetAsyncSession() as session: - rows = ( - await session.execute( - text( - """ - SELECT id, original_name, enabled, hit_count - FROM rag_document - WHERE dataset_id = :dataset_id - AND deleted_at IS NULL - AND original_name = ANY(:source_names) - """ - ), - { - "dataset_id": dataset_id, - "source_names": source_names, - }, - ) - ).mappings().all() - - document_map = {str(row["original_name"]): row for row in rows} - visible_chunks: list[dict] = [] - hit_document_ids: list[int] = [] - for chunk in chunks: - source_name = str(chunk.get("document_name") or chunk.get("source") or "").strip() - document = document_map.get(source_name) - if document and not bool(document.get("enabled")): - continue - if document: - chunk["document_id"] = document["id"] - chunk["dataset_id"] = dataset_id - chunk["document_name"] = document["original_name"] - chunk["hit_count"] = document.get("hit_count") or 0 - hit_document_ids.append(int(document["id"])) - visible_chunks.append(chunk) - - if hit_document_ids: - async with GetAsyncSession() as session: - async with session.begin(): - await session.execute( - text( - """ - UPDATE rag_document - SET hit_count = hit_count + 1, - updated_at = NOW() - WHERE id = ANY(:document_ids) - """ - ), - {"document_ids": sorted(set(hit_document_ids))}, - ) - - return visible_chunks + return await self.retriever._hydrate_document_hits(dataset_id, chunks) def _parse_sse_event(self, chunk: str) -> dict | None: data_lines: list[str] = [] diff --git a/rules/行政处罚/rules.yaml b/rules/行政处罚/rules.yaml index cf5a576..9390e0b 100644 --- a/rules/行政处罚/rules.yaml +++ b/rules/行政处罚/rules.yaml @@ -1392,6 +1392,67 @@ rules: references_laws: - 《中华人民共和国行政处罚法》第五十九条 type: deterministic + - rule_id: JZ-JD-005 + name: 案由及裁量标准适用准确性 + desc: 结合处罚决定书认定依据、处罚依据、罚款项目和罚款金额,检索案由与裁量标准,判断处罚种类和罚款幅度是否适用准确。 + risk: medium + score: 10 + scope: + - 处罚决定书 + rag: + collection: general_legal_kb + top_k: 5 + source_names: + - 广东省烟草专卖行政处罚裁量执行标准-rag.md + - 案由_行政处罚与反走私管理治理办法.md + query_template: | + 认定依据:{{处罚决定书.认定依据}} + 处罚依据:{{处罚决定书.处罚依据}} + 罚款项目:{{处罚决定书.罚款项目}} + 罚款基数:{{处罚决定书.罚款基数}} + 罚款比例:{{处罚决定书.罚款比例}} + 罚款总额:{{处罚决定书.罚款总额}} + 问题:检索对应案由、裁量档次、处罚种类和罚款幅度 + inject_as: rag_context + resources_as: rag_resources + stages: + - id: '1' + check: required + fields: + - 处罚决定书.认定依据 + - 处罚决定书.处罚依据 + - 处罚决定书.罚款项目 + - 处罚决定书.罚款基数 + - 处罚决定书.罚款比例 + - 处罚决定书.罚款总额 + - id: '2' + check: ai + prompt: | + 请结合检索到的法律知识和卷宗处罚决定书字段,判断案由、裁量档次、处罚种类和罚款幅度是否适用准确。 + + 【检索依据】 + {{rag_context}} + + 【处罚决定书字段】 + 认定依据:{{处罚决定书.认定依据}} + 处罚依据:{{处罚决定书.处罚依据}} + 罚款项目:{{处罚决定书.罚款项目}} + 罚款基数:{{处罚决定书.罚款基数}} + 罚款比例:{{处罚决定书.罚款比例}} + 罚款总额:{{处罚决定书.罚款总额}} + + 【判断要求】 + 1. 判断违法事实对应案由是否准确; + 2. 判断处罚依据是否能支撑对应处罚种类; + 3. 判断罚款基数、比例、总额是否落在裁量标准允许幅度内; + 4. 若检索依据不足以确认,应给出 warn,不要编造依据。 + logic: 1 AND 2 + messages: + pass: 案由、裁量档次、处罚种类和罚款幅度适用准确。 + fail: 案由、裁量档次、处罚种类或罚款幅度可能适用不准确,请核对。 + references_laws: + - 《中华人民共和国行政处罚法》第五十九条 + type: ai_rule - group: JZG-SD rules: - rule_id: JZ-SD-001 diff --git a/tests/test_leaudit_rag_bridge.py b/tests/test_leaudit_rag_bridge.py new file mode 100644 index 0000000..0e27abf --- /dev/null +++ b/tests/test_leaudit_rag_bridge.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import tempfile +import sys +import types +import unittest +from pathlib import Path +from unittest.mock import patch + +from leaudit.dsl.schema import Metadata, Rule, RuleAuthoringGroup, RulesFile, Stage +from leaudit.engine.models import EvaluationResult +from leaudit.extraction.bundle import bundle_from_single +from leaudit.extraction.models import ExtractionResult, FieldValue +from leaudit.ocr.models import OcrResult, Page + +from fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline import LauditPipeline + + +class FakeOcrClient: + async def ocr(self, file_path): + return OcrResult( + pages=[Page(page_num=1, text="处罚决定书")], + full_text="处罚决定书", + ) + + +class FakeStorage: + async def update_document_status(self, document_id, status): + return None + + async def save_ocr_result(self, document_id, ocr_result): + return None + + async def save_extraction_result(self, document_id, extraction_bundle): + return None + + async def save_evaluation_results(self, document_id, rules_file, evaluation_result, extraction_bundle): + return None + + +class FakeRetriever: + pass + + +def _rules_file() -> RulesFile: + return RulesFile( + metadata=Metadata( + type_id="test.rag_bridge", + name="RAG bridge test", + version="1.0", + last_updated="2026-05-21", + ), + rules=[ + RuleAuthoringGroup( + group="测试", + rules=[ + Rule( + rule_id="R-RAG", + name="RAG 规则", + risk="medium", + score=1, + stages=[Stage(id="1", check="required", field="处罚决定书.处罚依据")], + logic="1", + messages={"pass": "ok", "fail": "missing"}, + ) + ], + ) + ], + ) + + +class LeauditRagBridgeTest(unittest.IsolatedAsyncioTestCase): + async def test_pipeline_passes_injected_rag_retriever_to_evaluation(self): + captured = {} + retriever = FakeRetriever() + field_value = FieldValue(value="依据", confidence=0.9) + object.__setattr__(field_value, "position", None) + extraction = bundle_from_single( + ExtractionResult( + fields={"处罚决定书.处罚依据": field_value}, + source_text="处罚决定书", + ) + ) + + async def fake_dispatch_extract(*args, **kwargs): + return extraction + + async def fake_determine_phase(*args, **kwargs): + return "executed" + + async def fake_evaluate_extraction(*args, **kwargs): + captured["retriever"] = kwargs.get("retriever") + return EvaluationResult() + + class TestPipeline(LauditPipeline): + async def _extract_and_save_case_number(self, document_id, ocr_result): + return None + + with tempfile.NamedTemporaryFile() as tmp: + pipeline = TestPipeline( + ocr_client=FakeOcrClient(), + storage_adapter=FakeStorage(), + rag_retriever=retriever, + ) + with patch( + "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.dispatch_extract", + fake_dispatch_extract, + ), patch( + "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.determine_phase", + fake_determine_phase, + ), patch( + "fastapi_modules.fastapi_leaudit.leaudit_bridge.pipeline.evaluate_extraction", + fake_evaluate_extraction, + ), patch.dict( + sys.modules, + { + "leaudit.extraction.coordinate_resolver": types.SimpleNamespace( + resolve_bundle_positions=lambda *args, **kwargs: None + ) + }, + ): + await pipeline.run(1, Path(tmp.name), _rules_file()) + + self.assertIs(captured["retriever"], retriever) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/test_rag_retriever.py b/tests/test_rag_retriever.py new file mode 100644 index 0000000..05bc478 --- /dev/null +++ b/tests/test_rag_retriever.py @@ -0,0 +1,102 @@ +from __future__ import annotations + +import unittest + +from fastapi_modules.fastapi_leaudit.rag_engine.retriever import RagRetriever + + +class FakeCollection: + def query(self, **kwargs): + return { + "ids": [["seg-1", "seg-2"]], + "documents": [["烟草处罚裁量标准内容", "其他来源内容"]], + "metadatas": [[ + { + "document_name": "广东省烟草专卖行政处罚裁量执行标准-rag.md", + "chunk_index": 0, + }, + { + "document_name": "其他.md", + "chunk_index": 1, + }, + ]], + "distances": [[0.0, 0.2]], + } + + +class FallbackCollection: + def query(self, **kwargs): + raise RuntimeError("vector unavailable") + + def get(self, **kwargs): + return { + "ids": ["seg-fallback"], + "documents": ["未在当地烟草专卖批发企业进货,对应裁量档次内容"], + "metadatas": [ + { + "document_name": "案由_行政处罚与反走私管理治理办法.md", + "chunk_index": 3, + } + ], + } + + +class FakeChroma: + def __init__(self): + self.collection_name = None + + def get_or_create_collection(self, name): + self.collection_name = name + return FakeCollection() + + +class FallbackChroma: + def get_or_create_collection(self, name): + return FallbackCollection() + + +class RagRetrieverTest(unittest.IsolatedAsyncioTestCase): + async def test_retrieve_from_collection_filters_sources_and_builds_resources(self): + chroma = FakeChroma() + retriever = RagRetriever( + chroma_client=chroma, + embed_texts=lambda texts, model_name="": [[0.1, 0.2] for _ in texts], + hydrate_documents=False, + ) + + result = await retriever.retrieve( + query="罚款幅度", + collection_name="general_legal_kb", + top_k=5, + source_names=["广东省烟草专卖行政处罚裁量执行标准-rag.md"], + ) + + self.assertEqual(chroma.collection_name, "general_legal_kb") + self.assertIn("烟草处罚裁量标准内容", result.rag_context) + self.assertNotIn("其他来源内容", result.rag_context) + self.assertEqual(len(result.rag_resources), 1) + self.assertEqual( + result.rag_resources[0]["document_name"], + "广东省烟草专卖行政处罚裁量执行标准-rag.md", + ) + + async def test_retrieve_uses_keyword_fallback_when_vector_search_fails(self): + retriever = RagRetriever( + chroma_client=FallbackChroma(), + embed_texts=lambda texts, model_name="": [[0.1, 0.2] for _ in texts], + hydrate_documents=False, + ) + + result = await retriever.retrieve( + query="未在当地烟草专卖批发企业进货", + collection_name="general_legal_kb", + source_names=["案由_行政处罚与反走私管理治理办法.md"], + ) + + self.assertIn("对应裁量档次内容", result.rag_context) + self.assertEqual(len(result.chunks), 1) + self.assertEqual(result.rag_resources[0]["segment_id"], "seg-fallback") + + +if __name__ == "__main__": + unittest.main()