import { useState, useCallback } from 'react';
import { message } from 'antd';
import type { RetrieveRecord, RetrievalModel } from '~/api/dify-dataset/type';
import { retrieveDataset } from '~/api/dify-dataset/api/segmentApi';
import { DIFY_CONFIG } from '~/config/api-config';
import type { SearchMethod } from '~/types/dify-dataset-manager/retrieve-test';

/**
 * Builds the full `retrieval_model` payload sent with a retrieval request
 * (the original author notes that it follows the Dify API specification).
 *
 * Reranking is switched on for semantic and hybrid search and off for every
 * other search method. The reranking provider/model names always come from
 * `DIFY_CONFIG`, regardless of whether reranking is enabled.
 *
 * @param searchMethod - Retrieval strategy to request.
 * @param topK - Value forwarded as `top_k`.
 * @param scoreThresholdEnabled - Whether score filtering is requested.
 * @param scoreThreshold - Threshold forwarded only when filtering is enabled.
 * @returns The assembled retrieval model.
 */
function buildRetrievalModel(
  searchMethod: SearchMethod,
  topK: number,
  scoreThresholdEnabled: boolean,
  scoreThreshold: number
): RetrievalModel {
  // Semantic search and hybrid search need reranking enabled.
  const needReranking =
    searchMethod === 'semantic_search' || searchMethod === 'hybrid_search';

  return {
    search_method: searchMethod,
    reranking_enable: needReranking,
    // Always null: the previous ternary returned null on both branches.
    reranking_mode: null,
    reranking_model: {
      reranking_provider_name: DIFY_CONFIG.rerankingProviderName,
      reranking_model_name: DIFY_CONFIG.rerankingModelName,
    },
    weights: null,
    top_k: topK,
    score_threshold_enabled: scoreThresholdEnabled,
    // The threshold is only meaningful when filtering is enabled.
    score_threshold: scoreThresholdEnabled ? scoreThreshold : null,
  };
}

/**
 * Pulls a string `message` off a caught value, if it has one.
 *
 * Works for `Error` instances as well as plain objects carrying a `message`
 * field, so rejections that are not `Error` subclasses still surface their text.
 *
 * @param err - The value caught by a `catch` clause.
 * @returns The message text, or `undefined` when none is available.
 */
function extractErrorMessage(err: unknown): string | undefined {
  if (typeof err !== 'object' || err === null || !('message' in err)) {
    return undefined;
  }
  const text = (err as { message?: unknown }).message;
  return typeof text === 'string' ? text : undefined;
}

/**
 * State-management hook for the retrieval ("recall") test panel.
 *
 * Holds the query text, the search settings (method, top-K, score threshold),
 * the loading flag and the latest result records, and exposes
 * `handleRetrieve` to run a retrieval against the given dataset.
 *
 * @param datasetId - Id of the dataset to query; an empty id aborts the
 *   retrieval with a warning.
 * @returns The state values, their setters, and `handleRetrieve`.
 */
export function useRetrieveTest(datasetId: string) {
  const [searchQuery, setSearchQuery] = useState('');
  // Explicit element type: a bare `[]` would be inferred as `never[]`.
  const [retrieveResults, setRetrieveResults] = useState<RetrieveRecord[]>([]);
  const [retrieving, setRetrieving] = useState(false);

  // Semantic search is the default method. The explicit type keeps the state
  // from widening to `string`.
  const [searchMethod, setSearchMethod] = useState<SearchMethod>('semantic_search');
  const [topK, setTopK] = useState(5);

  // Score-threshold state.
  const [scoreThresholdEnabled, setScoreThresholdEnabled] = useState(false);
  const [scoreThreshold, setScoreThreshold] = useState(0.5);

  /**
   * Runs the retrieval with the current settings and stores the records.
   * Warns and returns early when the query is blank or the dataset id is empty.
   */
  const handleRetrieve = useCallback(async () => {
    if (!searchQuery.trim()) {
      message.warning('请输入检索关键词');
      return;
    }

    if (!datasetId) {
      message.warning('知识库ID不存在');
      return;
    }

    setRetrieving(true);
    try {
      const retrievalModel = buildRetrievalModel(
        searchMethod,
        topK,
        scoreThresholdEnabled,
        scoreThreshold
      );
      console.log('[Hook] 检索参数:', { datasetId, query: searchQuery, retrievalModel });

      const response = await retrieveDataset(datasetId, searchQuery, retrievalModel);

      // Normalize once so a response without `records` is treated exactly like
      // an empty result list, including the "nothing found" notice.
      const records = response.records ?? [];
      setRetrieveResults(records);

      if (records.length === 0) {
        message.info('未找到匹配的结果');
      }
    } catch (err: unknown) {
      console.error('检索失败:', err);
      message.error(extractErrorMessage(err) || '检索失败');
    } finally {
      // Always clear the loading flag, whether the request succeeded or failed.
      setRetrieving(false);
    }
  }, [datasetId, searchQuery, searchMethod, topK, scoreThresholdEnabled, scoreThreshold]);

  return {
    // State
    searchQuery,
    setSearchQuery,
    retrieveResults,
    retrieving,
    searchMethod,
    setSearchMethod,
    topK,
    setTopK,
    scoreThresholdEnabled,
    setScoreThresholdEnabled,
    scoreThreshold,
    setScoreThreshold,
    // Actions
    handleRetrieve,
  };
}

/** Shape of the object returned by {@link useRetrieveTest}. */
export type UseRetrieveTestReturn = ReturnType<typeof useRetrieveTest>;