195 lines
6.5 KiB
TypeScript
195 lines
6.5 KiB
TypeScript
import { message } from 'antd';
|
||
import { useCallback, useEffect, useState } from 'react';
|
||
import { fetchDataset, updateDatasetSettings } from '~/api/dify-dataset/api/datasetApi';
|
||
import type { Dataset, RetrievalModel } from '~/api/dify-dataset/type/datasetTypes';
|
||
import { DIFY_CONFIG } from '~/config/api-config';
|
||
|
||
/**
 * Retrieval strategy, sent to the API as the `search_method` field of the
 * dataset's retrieval model.
 */
export type SearchMethod =
  | 'keyword_search'
  | 'semantic_search'
  | 'full_text_search'
  | 'hybrid_search';
|
||
|
||
/**
 * Form-side (camelCase) view of a dataset's retrieval settings.
 * Converted to and from the API's snake_case retrieval model elsewhere in this module.
 */
export interface RetrievalSettingsFormValues {
  /** Retrieval strategy (`search_method`). */
  searchMethod: SearchMethod;
  /** Number of top results requested (`top_k`). */
  topK: number;
  /** Score-threshold switch (`score_threshold_enabled`). */
  scoreThresholdEnabled: boolean;
  /** Score-threshold value (`score_threshold`). */
  scoreThreshold: number;
  /** Reranking switch (`reranking_enable`). */
  rerankingEnable: boolean;
  /** Semantic weight for hybrid retrieval, expected in the range 0-1. */
  weights: number;
}
|
||
|
||
/**
 * Baseline retrieval settings, used as-is when a dataset has no stored
 * retrieval model. Score-threshold filtering and reranking start switched on.
 */
const DEFAULT_RETRIEVAL_SETTINGS: RetrievalSettingsFormValues = {
  searchMethod: 'semantic_search',
  topK: 3,
  scoreThresholdEnabled: true, // on by default
  scoreThreshold: 0.5,
  rerankingEnable: true, // on by default
  weights: 0.7, // semantic weight; only sent for hybrid_search
};
|
||
|
||
/**
 * Map the API's snake_case retrieval model onto form values.
 *
 * With no model at all, a copy of DEFAULT_RETRIEVAL_SETTINGS is returned.
 * Otherwise each null/undefined field falls back on its own: the search method
 * and the numeric fields take the default values, while the two boolean
 * switches fall back to `false`.
 * NOTE(review): DEFAULT_RETRIEVAL_SETTINGS enables both switches, so a model
 * that merely omits them is treated differently from no model - confirm intended.
 *
 * @param model - Retrieval model from the API (`retrieval_model_dict`), if any.
 * @returns A newly created form-values object.
 */
function retrievalModelToFormValues(model?: RetrievalModel): RetrievalSettingsFormValues {
  if (model == null) {
    return { ...DEFAULT_RETRIEVAL_SETTINGS };
  }

  const fallback = DEFAULT_RETRIEVAL_SETTINGS;
  return {
    // `||` (not `??`) so an empty string is also treated as missing.
    searchMethod: model.search_method || fallback.searchMethod,
    topK: model.top_k ?? fallback.topK,
    scoreThresholdEnabled: model.score_threshold_enabled ?? false,
    scoreThreshold: model.score_threshold ?? fallback.scoreThreshold,
    rerankingEnable: model.reranking_enable ?? false,
    weights: model.weights ?? fallback.weights,
  };
}
|
||
|
||
/**
 * Build the `retrieval_model` payload sent by updateDatasetSettings from form values.
 *
 * Several form fields are deliberately overridden rather than forwarded:
 * - `reranking_enable` is `true` for semantic and hybrid search and `false`
 *   otherwise, regardless of `values.rerankingEnable`;
 * - `score_threshold_enabled` is always `true`, regardless of
 *   `values.scoreThresholdEnabled`; only the threshold value is user-tunable;
 * - `weights` is sent only for hybrid search and is `null` for other methods.
 * The reranking provider and model names always come from DIFY_CONFIG.
 */
function formValuesToRetrievalModel(values: RetrievalSettingsFormValues): RetrievalModel {
  const { searchMethod, topK, scoreThreshold, weights } = values;
  const isHybrid = searchMethod === 'hybrid_search';
  // Semantic and hybrid retrieval always run with reranking; not user-controlled.
  const rerankingRequired = isHybrid || searchMethod === 'semantic_search';

  return {
    search_method: searchMethod,
    reranking_enable: rerankingRequired,
    reranking_mode: null,
    reranking_model: {
      reranking_provider_name: DIFY_CONFIG.rerankingProviderName,
      reranking_model_name: DIFY_CONFIG.rerankingModelName,
    },
    weights: isHybrid ? weights : null,
    top_k: topK,
    score_threshold_enabled: true, // forced on, not user-controlled
    score_threshold: scoreThreshold, // user-adjustable value
  };
}
|
||
|
||
/** Retrieval fields compared when deciding whether the form has unsaved edits. */
const RETRIEVAL_SETTINGS_KEYS: ReadonlyArray<keyof RetrievalSettingsFormValues> = [
  'searchMethod',
  'topK',
  'scoreThresholdEnabled',
  'scoreThreshold',
  'rerankingEnable',
  'weights',
];

/** True when every tracked retrieval field is strictly equal in `a` and `b`. */
function isSameRetrievalSettings(
  a: RetrievalSettingsFormValues,
  b: RetrievalSettingsFormValues
): boolean {
  return RETRIEVAL_SETTINGS_KEYS.every(key => a[key] === b[key]);
}

/**
 * Extract a displayable message from an unknown thrown/rejected value.
 * Returns `fallback` when the value carries no non-empty string `message`
 * (including when it is `null`/`undefined`).
 */
function getErrorMessage(err: unknown, fallback: string): string {
  if (typeof err === 'object' && err !== null && 'message' in err) {
    const text = (err as { message?: unknown }).message;
    if (typeof text === 'string' && text) {
      return text;
    }
  }
  return fallback;
}

/**
 * State-management hook for a knowledge-base (dataset) settings form.
 *
 * Note: the Dify API does not support changing a dataset's name or
 * description, so only the retrieval settings are editable here.
 *
 * @param dataset - The dataset being edited, or `null` when unavailable.
 * @param onDatasetUpdated - Called with the re-fetched dataset after a successful save.
 * @returns Form state (`saving`, `hasChanges`, `retrievalSettings`) plus the
 *   `handleSave`, `handleReset` and `updateRetrievalSettings` actions.
 */
export function useDatasetSettings(
  dataset: Dataset | null,
  onDatasetUpdated: (dataset: Dataset) => void
) {
  const [saving, setSaving] = useState(false);

  // Editable retrieval settings (note: the Dify API returns them as `retrieval_model_dict`).
  const [retrievalSettings, setRetrievalSettings] = useState<RetrievalSettingsFormValues>(
    () => retrievalModelToFormValues(dataset?.retrieval_model_dict)
  );

  // Last loaded/saved settings: the baseline for change detection and reset.
  const [originalSettings, setOriginalSettings] = useState<RetrievalSettingsFormValues>(
    () => retrievalModelToFormValues(dataset?.retrieval_model_dict)
  );

  // Derived rather than stored, so it can never drift from the two states above and
  // no setState call has to happen inside another state updater (updaters must be pure;
  // React invokes them twice in StrictMode).
  const hasChanges = !isSameRetrievalSettings(retrievalSettings, originalSettings);

  // (Re)initialise the form whenever a new dataset object is passed in.
  useEffect(() => {
    if (dataset) {
      console.log('[DatasetSettings] 初始化检索设置, retrieval_model_dict:', dataset.retrieval_model_dict);
      const settings = retrievalModelToFormValues(dataset.retrieval_model_dict);
      setRetrievalSettings(settings);
      setOriginalSettings(settings);
    }
  }, [dataset]);

  /**
   * Update a single retrieval setting field. Referentially stable across renders.
   */
  const updateRetrievalSettings = useCallback(
    <K extends keyof RetrievalSettingsFormValues>(key: K, value: RetrievalSettingsFormValues[K]) => {
      setRetrievalSettings(prev => ({ ...prev, [key]: value }));
    },
    []
  );

  /**
   * Save the retrieval settings (the only part of a dataset the Dify API lets us change),
   * then re-fetch the dataset and hand it to `onDatasetUpdated`.
   */
  const handleSave = useCallback(async () => {
    if (!dataset) {
      message.error('知识库不存在');
      return;
    }

    // The values being submitted; they become the new baseline on success.
    const submitted = retrievalSettings;

    try {
      setSaving(true);

      await updateDatasetSettings(dataset.id, {
        retrieval_model: formValuesToRetrievalModel(submitted),
      });

      // The PATCH response may be incomplete, so fetch the full dataset again.
      const fullDataset = await fetchDataset(dataset.id);
      console.log('[DatasetSettings] 保存后获取完整数据:', fullDataset);

      message.success('检索设置保存成功');
      onDatasetUpdated(fullDataset);
      setOriginalSettings(submitted);
    } catch (err: unknown) {
      console.error('保存设置失败:', err);
      message.error(getErrorMessage(err, '保存失败'));
    } finally {
      setSaving(false);
    }
  }, [dataset, retrievalSettings, onDatasetUpdated]);

  /**
   * Discard unsaved edits by restoring the last loaded/saved settings.
   */
  const handleReset = useCallback(() => {
    if (dataset) {
      setRetrievalSettings(originalSettings);
    }
  }, [dataset, originalSettings]);

  return {
    // State
    saving,
    hasChanges,
    retrievalSettings,

    // Actions
    handleSave,
    handleReset,
    updateRetrievalSettings,
  };
}

export type UseDatasetSettingsReturn = ReturnType<typeof useDatasetSettings>;
|