import { message } from 'antd';
import { useCallback, useEffect, useMemo, useState } from 'react';
import { fetchDataset, updateDatasetSettings } from '~/api/dify-dataset/api/datasetApi';
import type { Dataset, RetrievalModel } from '~/api/dify-dataset/type/datasetTypes';
import { DIFY_CONFIG } from '~/config/api-config';

/**
 * Retrieval method of a dataset.
 */
export type SearchMethod = 'keyword_search' | 'semantic_search' | 'full_text_search' | 'hybrid_search';

/**
 * Form values backing the retrieval-settings UI.
 */
export interface RetrievalSettingsFormValues {
  searchMethod: SearchMethod;
  topK: number;
  scoreThresholdEnabled: boolean;
  scoreThreshold: number;
  rerankingEnable: boolean;
  /** Semantic weight for hybrid search (0-1). */
  weights: number;
}

/**
 * Default retrieval settings. Also used as the per-field fallback when a
 * dataset's retrieval model omits a value, so there is a single source of defaults.
 */
const DEFAULT_RETRIEVAL_SETTINGS: RetrievalSettingsFormValues = {
  searchMethod: 'semantic_search',
  topK: 3,
  scoreThresholdEnabled: true, // on by default
  scoreThreshold: 0.5,
  rerankingEnable: true, // on by default
  weights: 0.7,
};

/**
 * Converts a dataset's retrieval model into form values.
 *
 * @param model - The dataset's retrieval model; when absent, a copy of the defaults is returned.
 * @returns Form values with every missing field filled from `DEFAULT_RETRIEVAL_SETTINGS`.
 */
function retrievalModelToFormValues(model?: RetrievalModel): RetrievalSettingsFormValues {
  if (!model) {
    return { ...DEFAULT_RETRIEVAL_SETTINGS };
  }
  const defaults = DEFAULT_RETRIEVAL_SETTINGS;
  return {
    // `||` rather than `??` so an empty string also falls back to the default method.
    searchMethod: model.search_method || defaults.searchMethod,
    topK: model.top_k ?? defaults.topK,
    scoreThresholdEnabled: model.score_threshold_enabled ?? defaults.scoreThresholdEnabled,
    scoreThreshold: model.score_threshold ?? defaults.scoreThreshold,
    rerankingEnable: model.reranking_enable ?? defaults.rerankingEnable,
    weights: model.weights ?? defaults.weights,
  };
}

/**
 * Converts form values into the `retrieval_model` payload sent to the API.
 *
 * `reranking_enable` and `score_threshold_enabled` are not taken from the form:
 * reranking is derived from the search method, and the score threshold is always enabled.
 */
function formValuesToRetrievalModel(values: RetrievalSettingsFormValues): RetrievalModel {
  // Semantic and hybrid search require reranking, so it is switched on for
  // exactly those methods (and off for keyword / full-text search).
  const needReranking =
    values.searchMethod === 'semantic_search' || values.searchMethod === 'hybrid_search';

  return {
    search_method: values.searchMethod,
    reranking_enable: needReranking, // derived from the method, not user-controlled
    reranking_mode: null,
    reranking_model: {
      reranking_provider_name: DIFY_CONFIG.rerankingProviderName,
      reranking_model_name: DIFY_CONFIG.rerankingModelName,
    },
    // Weights are only meaningful for hybrid search.
    weights: values.searchMethod === 'hybrid_search' ? values.weights : null,
    top_k: values.topK,
    score_threshold_enabled: true, // always on, not user-controlled
    score_threshold: values.scoreThreshold, // user-adjustable value
  };
}

/**
 * Returns true when two retrieval-settings snapshots are field-by-field equal.
 */
function isSameRetrievalSettings(
  a: RetrievalSettingsFormValues,
  b: RetrievalSettingsFormValues
): boolean {
  return (
    a.searchMethod === b.searchMethod &&
    a.topK === b.topK &&
    a.scoreThresholdEnabled === b.scoreThresholdEnabled &&
    a.scoreThreshold === b.scoreThreshold &&
    a.rerankingEnable === b.rerankingEnable &&
    a.weights === b.weights
  );
}

/**
 * Extracts a non-empty string `message` from an unknown thrown value, or returns `fallback`.
 */
function getErrorMessage(err: unknown, fallback: string): string {
  if (typeof err === 'object' && err !== null && 'message' in err) {
    const msg = (err as { message?: unknown }).message;
    if (typeof msg === 'string' && msg) {
      return msg;
    }
  }
  return fallback;
}

/**
 * State management hook for dataset (knowledge base) settings.
 *
 * Only retrieval settings are editable/saved here; per the original note, the
 * Dify API does not support changing the dataset name or description.
 *
 * @param dataset - The dataset being edited, or null when not loaded.
 * @param onDatasetUpdated - Called with the freshly re-fetched dataset after a successful save.
 */
export function useDatasetSettings(
  dataset: Dataset | null,
  onDatasetUpdated: (dataset: Dataset) => void
) {
  const [saving, setSaving] = useState(false);

  // Editable retrieval settings. Note: the Dify API returns the field as `retrieval_model_dict`.
  const [retrievalSettings, setRetrievalSettings] = useState<RetrievalSettingsFormValues>(() =>
    retrievalModelToFormValues(dataset?.retrieval_model_dict)
  );

  // Last loaded/saved settings: the baseline used for change detection and reset.
  const [originalSettings, setOriginalSettings] = useState<RetrievalSettingsFormValues>(() =>
    retrievalModelToFormValues(dataset?.retrieval_model_dict)
  );

  // Re-initialize both the editable copy and the baseline whenever `dataset` changes.
  useEffect(() => {
    if (dataset) {
      console.log('[DatasetSettings] 初始化检索设置, retrieval_model_dict:', dataset.retrieval_model_dict);
      const settings = retrievalModelToFormValues(dataset.retrieval_model_dict);
      setRetrievalSettings(settings);
      setOriginalSettings(settings);
    }
  }, [dataset]);

  // Derived rather than stored, so it can never drift from the two snapshots
  // (e.g. when the user keeps editing while a save is in flight).
  const hasChanges = useMemo(
    () => !isSameRetrievalSettings(retrievalSettings, originalSettings),
    [retrievalSettings, originalSettings]
  );

  /**
   * Updates a single retrieval setting.
   */
  const updateRetrievalSettings = useCallback(
    <K extends keyof RetrievalSettingsFormValues>(key: K, value: RetrievalSettingsFormValues[K]) => {
      // Pure updater: no other state is touched from inside it.
      setRetrievalSettings((prev) => ({ ...prev, [key]: value }));
    },
    []
  );

  /**
   * Saves the retrieval settings, then re-fetches the full dataset.
   * Only retrieval settings are sent; name/description are not editable via the API.
   */
  const handleSave = useCallback(async () => {
    if (!dataset) {
      message.error('知识库不存在');
      return;
    }

    try {
      setSaving(true);

      // Only the retrieval settings are updated.
      await updateDatasetSettings(dataset.id, {
        retrieval_model: formValuesToRetrievalModel(retrievalSettings),
      });

      // The PATCH response may be incomplete, so re-fetch the full dataset details.
      const fullDataset = await fetchDataset(dataset.id);
      console.log('[DatasetSettings] 保存后获取完整数据:', fullDataset);

      message.success('检索设置保存成功');
      onDatasetUpdated(fullDataset);
      // Baseline becomes exactly what was submitted (this callback's snapshot),
      // so later in-flight edits still register as changes.
      setOriginalSettings(retrievalSettings);
    } catch (err: unknown) {
      console.error('保存设置失败:', err);
      message.error(getErrorMessage(err, '保存失败'));
    } finally {
      setSaving(false);
    }
  }, [dataset, retrievalSettings, onDatasetUpdated]);

  /**
   * Discards unsaved edits by restoring the baseline settings.
   */
  const handleReset = useCallback(() => {
    if (dataset) {
      setRetrievalSettings(originalSettings);
    }
  }, [dataset, originalSettings]);

  return {
    // state
    saving,
    hasChanges,
    retrievalSettings,
    // actions
    handleSave,
    handleReset,
    updateRetrievalSettings,
  };
}

export type UseDatasetSettingsReturn = ReturnType<typeof useDatasetSettings>;