Files
leaudit-platform-frontend/app/hooks/dify-dataset-manager/retrieve-test.ts
T

91 lines
3.0 KiB
TypeScript

import { useState, useCallback } from 'react';
import { message } from 'antd';
import type { RetrieveRecord, RetrievalModel } from '~/api/dify-dataset/type';
import { retrieveDataset } from '~/api/dify-dataset/api/segmentApi';
import { DIFY_CONFIG } from '~/config/api-config';
import type { SearchMethod } from '~/types/dify-dataset-manager/retrieve-test';
/**
 * Build the complete `retrieval_model` payload for the Dify dataset retrieve API.
 *
 * Reranking is switched on only for semantic and hybrid search. The reranking
 * provider/model names are always filled in from `DIFY_CONFIG`, so the payload
 * shape is the same for every search method. Score-threshold filtering is
 * disabled.
 *
 * @param searchMethod - Retrieval method chosen in the UI.
 * @param topK - Maximum number of records to request.
 * @returns The `RetrievalModel` object to send with the retrieve request.
 */
function buildRetrievalModel(searchMethod: SearchMethod, topK: number): RetrievalModel {
  // Semantic and hybrid search are the only methods that enable reranking.
  const needReranking = searchMethod === 'semantic_search' || searchMethod === 'hybrid_search';
  return {
    search_method: searchMethod,
    reranking_enable: needReranking,
    // Always null, whether or not reranking is enabled.
    // NOTE(review): Dify may expect 'reranking_model' / 'weighted_score' here when
    // reranking is enabled (especially for hybrid search) — confirm against the API.
    reranking_mode: null,
    reranking_model: {
      reranking_provider_name: DIFY_CONFIG.rerankingProviderName,
      reranking_model_name: DIFY_CONFIG.rerankingModelName,
    },
    weights: null,
    top_k: topK,
    score_threshold_enabled: false,
    score_threshold: null,
  };
}
/**
 * Extract a human-readable message from an unknown thrown value.
 *
 * Accepts `Error` instances and plain objects carrying a non-empty string
 * `message` field. Anything else (including `null`/`undefined`) yields `fallback`.
 */
function getErrorMessage(err: unknown, fallback: string): string {
  if (typeof err === 'object' && err !== null) {
    // Runtime-checked access: only a non-empty string `message` is used.
    const raw = (err as { message?: unknown }).message;
    if (typeof raw === 'string' && raw) {
      return raw;
    }
  }
  return fallback;
}

/**
 * State management hook for the dataset retrieval ("recall") test panel.
 *
 * Holds the query text, selected search method, top-K and the latest
 * retrieval results. Exposes `handleRetrieve` to query the Dify dataset
 * identified by `datasetId`.
 *
 * @param datasetId - ID of the dataset (knowledge base) to query. An empty
 *   string makes `handleRetrieve` show a warning instead of calling the API.
 * @returns The current state, its setters, and the `handleRetrieve` action.
 */
export function useRetrieveTest(datasetId: string) {
  const [searchQuery, setSearchQuery] = useState('');
  const [retrieveResults, setRetrieveResults] = useState<RetrieveRecord[]>([]);
  const [retrieving, setRetrieving] = useState(false);
  // Semantic search is the default method.
  const [searchMethod, setSearchMethod] = useState<SearchMethod>('semantic_search');
  const [topK, setTopK] = useState<number>(5);

  /**
   * Run a retrieval using the current query, search method and top-K.
   *
   * Validation failures and API errors are reported through antd `message`
   * rather than thrown. `retrieving` is true while the request is in flight.
   */
  const handleRetrieve = useCallback(async () => {
    // Validate and send the same trimmed text, so surrounding whitespace is not
    // forwarded to the API.
    const query = searchQuery.trim();
    if (!query) {
      message.warning('请输入检索关键词');
      return;
    }
    if (!datasetId) {
      message.warning('知识库ID不存在');
      return;
    }
    setRetrieving(true);
    try {
      const retrievalModel = buildRetrievalModel(searchMethod, topK);
      console.log('[Hook] 检索参数:', { datasetId, query, retrievalModel });
      const response = await retrieveDataset(datasetId, query, retrievalModel);
      // Normalise a missing `records` field once, so the stored results and the
      // "no results" notice always agree.
      const records = response.records ?? [];
      setRetrieveResults(records);
      if (records.length === 0) {
        message.info('未找到匹配的结果');
      }
    } catch (err: unknown) {
      console.error('检索失败:', err);
      message.error(getErrorMessage(err, '检索失败'));
    } finally {
      setRetrieving(false);
    }
  }, [datasetId, searchQuery, searchMethod, topK]);

  return {
    // State
    searchQuery,
    setSearchQuery,
    retrieveResults,
    retrieving,
    searchMethod,
    setSearchMethod,
    topK,
    setTopK,
    // Actions
    handleRetrieve,
  };
}

export type UseRetrieveTestReturn = ReturnType<typeof useRetrieveTest>;