import { useState, useEffect, useCallback } from 'react';
import { message } from 'antd';
import type { FormInstance } from 'antd';
import type { Dataset, RetrievalModel } from '~/api/dify-dataset/type/datasetTypes';
import { updateDatasetSettings, fetchDataset } from '~/api/dify-dataset/api/datasetApi';
import { DIFY_CONFIG } from '~/config/api-config';

/**
 * Retrieval method type.
 */
export type SearchMethod = 'keyword_search' | 'semantic_search' | 'full_text_search' | 'hybrid_search';

/**
 * Retrieval settings form values.
 */
export interface RetrievalSettingsFormValues {
  searchMethod: SearchMethod;
  topK: number;
  scoreThresholdEnabled: boolean;
  scoreThreshold: number;
  rerankingEnable: boolean;
  /** Semantic weight for hybrid search (0-1). */
  weights: number;
}

/**
 * Retrieval settings used when the dataset carries no retrieval model at all.
 */
const DEFAULT_RETRIEVAL_SETTINGS: RetrievalSettingsFormValues = {
  searchMethod: 'semantic_search',
  topK: 3,
  scoreThresholdEnabled: true, // enabled by default
  scoreThreshold: 0.5,
  rerankingEnable: true, // enabled by default
  weights: 0.7,
};

/**
 * Fields compared when deciding whether the retrieval settings differ from the saved ones.
 */
const RETRIEVAL_SETTING_KEYS: ReadonlyArray<keyof RetrievalSettingsFormValues> = [
  'searchMethod',
  'topK',
  'scoreThresholdEnabled',
  'scoreThreshold',
  'rerankingEnable',
  'weights',
];

/**
 * Convert a Dataset's retrieval model into form values.
 *
 * @param model - Retrieval model from the API; when absent, a copy of the defaults is returned.
 * @returns Form values with per-field fallbacks for missing properties.
 */
function retrievalModelToFormValues(model?: RetrievalModel): RetrievalSettingsFormValues {
  if (!model) {
    return { ...DEFAULT_RETRIEVAL_SETTINGS };
  }

  return {
    searchMethod: model.search_method || DEFAULT_RETRIEVAL_SETTINGS.searchMethod,
    topK: model.top_k ?? DEFAULT_RETRIEVAL_SETTINGS.topK,
    // Unlike DEFAULT_RETRIEVAL_SETTINGS, a missing flag on an existing model is read as `false`.
    scoreThresholdEnabled: model.score_threshold_enabled ?? false,
    scoreThreshold: model.score_threshold ?? DEFAULT_RETRIEVAL_SETTINGS.scoreThreshold,
    rerankingEnable: model.reranking_enable ?? false,
    weights: model.weights ?? DEFAULT_RETRIEVAL_SETTINGS.weights,
  };
}

/**
 * Convert form values into the `retrieval_model` payload for the update API.
 *
 * Note: `values.scoreThresholdEnabled` and `values.rerankingEnable` are NOT read here —
 * the payload forces them, so toggling them in the UI is never persisted.
 *
 * @param values - Current retrieval form values.
 * @returns The retrieval model to send to the server.
 */
function formValuesToRetrievalModel(values: RetrievalSettingsFormValues): RetrievalModel {
  // Semantic and hybrid search require reranking, so it is forced on for those methods.
  const needReranking =
    values.searchMethod === 'semantic_search' || values.searchMethod === 'hybrid_search';

  return {
    search_method: values.searchMethod,
    reranking_enable: needReranking, // forced, not user-controlled
    reranking_mode: null,
    reranking_model: {
      reranking_provider_name: DIFY_CONFIG.rerankingProviderName,
      reranking_model_name: DIFY_CONFIG.rerankingModelName,
    },
    // Weights only apply to hybrid search.
    weights: values.searchMethod === 'hybrid_search' ? values.weights : null,
    top_k: values.topK,
    score_threshold_enabled: true, // forced on, not user-controlled
    score_threshold: values.scoreThreshold, // user-adjustable value
  };
}

/**
 * Return true when any tracked retrieval field differs between the two settings objects.
 */
function isRetrievalChanged(
  current: RetrievalSettingsFormValues,
  original: RetrievalSettingsFormValues
): boolean {
  return RETRIEVAL_SETTING_KEYS.some((key) => current[key] !== original[key]);
}

/**
 * Extract a displayable message from an unknown thrown value, or fall back.
 *
 * Mirrors `err.message || fallback` but does not crash on `null`/non-object rejections.
 */
function getErrorMessage(err: unknown, fallback: string): string {
  if (typeof err === 'object' && err !== null && 'message' in err) {
    const text = (err as { message: unknown }).message;
    if (typeof text === 'string' && text) {
      return text;
    }
  }
  return fallback;
}

/**
 * State management hook for the dataset (knowledge base) settings page.
 *
 * Keeps the antd form (`name`, `description`) and the retrieval settings in sync with
 * `dataset`, tracks whether anything changed, and saves via `updateDatasetSettings`.
 *
 * NOTE(review): `description` is written into the form but is neither compared in
 * `checkForChanges` nor sent by `handleSave` — confirm whether it is meant to be editable.
 *
 * @param dataset - The dataset being edited, or null while not loaded.
 * @param form - The antd form instance holding `name` / `description`.
 * @param onDatasetUpdated - Called with the re-fetched dataset after a successful save.
 */
export function useDatasetSettings(
  dataset: Dataset | null,
  form: FormInstance,
  onDatasetUpdated: (dataset: Dataset) => void
) {
  const [saving, setSaving] = useState(false);
  const [hasChanges, setHasChanges] = useState(false);

  // Retrieval settings state (note: the Dify API returns the field as `retrieval_model_dict`).
  const [retrievalSettings, setRetrievalSettings] = useState<RetrievalSettingsFormValues>(
    () => retrievalModelToFormValues(dataset?.retrieval_model_dict)
  );

  /**
   * Reset the form and the retrieval settings to the values stored on `dataset`.
   */
  const handleReset = useCallback(() => {
    if (!dataset) {
      return;
    }
    form.setFieldsValue({
      name: dataset.name,
      description: dataset.description || '',
    });
    setRetrievalSettings(retrievalModelToFormValues(dataset.retrieval_model_dict));
    setHasChanges(false);
  }, [dataset, form]);

  // Re-initialise whenever the dataset (or form instance) changes.
  useEffect(() => {
    if (dataset) {
      console.log('[DatasetSettings] 初始化检索设置, retrieval_model_dict:', dataset.retrieval_model_dict);
    }
    handleReset();
  }, [dataset, handleReset]);

  /**
   * Recompute `hasChanges` by comparing the form name and the retrieval settings
   * against the values stored on `dataset`.
   *
   * @param newRetrievalSettings - Settings to compare; defaults to the current state.
   */
  const checkForChanges = useCallback(
    (newRetrievalSettings?: RetrievalSettingsFormValues) => {
      const values = form.getFieldsValue();
      const currentRetrieval = newRetrievalSettings ?? retrievalSettings;
      const originalRetrieval = retrievalModelToFormValues(dataset?.retrieval_model_dict);

      const nameChanged = values.name !== dataset?.name;
      setHasChanges(nameChanged || isRetrievalChanged(currentRetrieval, originalRetrieval));
    },
    [form, dataset, retrievalSettings]
  );

  /**
   * Update a single retrieval setting and refresh the dirty flag.
   *
   * The dirty check runs outside the state setter: React requires updater functions
   * to be pure (they may be invoked twice in Strict Mode).
   */
  const updateRetrievalSettings = useCallback(
    <K extends keyof RetrievalSettingsFormValues>(
      key: K,
      value: RetrievalSettingsFormValues[K]
    ) => {
      const newSettings: RetrievalSettingsFormValues = { ...retrievalSettings, [key]: value };
      setRetrievalSettings(newSettings);
      checkForChanges(newSettings);
    },
    [retrievalSettings, checkForChanges]
  );

  /**
   * antd `onValuesChange` handler: re-evaluate the dirty flag.
   */
  const handleValuesChange = useCallback(() => {
    checkForChanges();
  }, [checkForChanges]);

  /**
   * Validate the form, save name + retrieval settings, then re-fetch the full dataset.
   */
  const handleSave = useCallback(async () => {
    if (!dataset) {
      message.error('知识库不存在');
      return;
    }

    try {
      const values = await form.validateFields();
      setSaving(true);

      // Build the full update request.
      await updateDatasetSettings(dataset.id, {
        name: values.name,
        retrieval_model: formValuesToRetrievalModel(retrievalSettings),
      });

      // The PATCH response may be incomplete, so fetch the full details again.
      const fullDataset = await fetchDataset(dataset.id);
      console.log('[DatasetSettings] 保存后获取完整数据:', fullDataset);

      message.success('保存成功');
      onDatasetUpdated(fullDataset);
      setHasChanges(false);
    } catch (err: unknown) {
      console.error('保存设置失败:', err);
      message.error(getErrorMessage(err, '保存失败'));
    } finally {
      setSaving(false);
    }
  }, [dataset, form, retrievalSettings, onDatasetUpdated]);

  return {
    // State
    saving,
    hasChanges,
    retrievalSettings,
    // Actions
    handleValuesChange,
    handleSave,
    handleReset,
    updateRetrievalSettings,
  };
}

export type UseDatasetSettingsReturn = ReturnType<typeof useDatasetSettings>;