Files
leaudit-platform-frontend/app/hooks/dify-dataset-manager/dataset-settings.ts
T

195 lines
6.5 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import { message } from 'antd';
import { useCallback, useEffect, useState } from 'react';
import { fetchDataset, updateDatasetSettings } from '~/api/dify-dataset/api/datasetApi';
import type { Dataset, RetrievalModel } from '~/api/dify-dataset/type/datasetTypes';
import { DIFY_CONFIG } from '~/config/api-config';
/**
 * Retrieval strategies that can be selected for a knowledge base.
 */
export type SearchMethod =
  | 'semantic_search'
  | 'keyword_search'
  | 'full_text_search'
  | 'hybrid_search';
/**
 * Editable values backing the retrieval-settings form.
 *
 * Each field is the camelCase counterpart of a snake_case field on the API's
 * retrieval model (the API name is given in backticks on each member).
 */
export interface RetrievalSettingsFormValues {
  /** Selected retrieval strategy (`search_method`). */
  searchMethod: SearchMethod;
  /** Semantic weight used by hybrid retrieval, in the range 0-1 (`weights`). */
  weights: number;
  /** Maps to `top_k`. */
  topK: number;
  /** Score threshold value (`score_threshold`). */
  scoreThreshold: number;
  /** Whether score-threshold filtering is on (`score_threshold_enabled`). */
  scoreThresholdEnabled: boolean;
  /** Whether reranking is on (`reranking_enable`). */
  rerankingEnable: boolean;
}
/**
 * Default retrieval settings, used when a dataset has no stored retrieval model
 * (and as the per-field fallback when converting one).
 *
 * Frozen so the shared defaults can never be mutated by accident; callers that
 * need an editable object take a shallow copy (`{ ...DEFAULT_RETRIEVAL_SETTINGS }`).
 */
const DEFAULT_RETRIEVAL_SETTINGS = Object.freeze<RetrievalSettingsFormValues>({
  searchMethod: 'semantic_search',
  topK: 3,
  scoreThresholdEnabled: true, // on by default
  scoreThreshold: 0.5,
  rerankingEnable: true, // on by default
  weights: 0.7,
});
/**
 * Converts a dataset's stored retrieval model into editable form values.
 *
 * Every field that is missing (`undefined`/`null`) on the model falls back to the
 * matching entry of `DEFAULT_RETRIEVAL_SETTINGS`. A partially populated model and
 * an absent model therefore resolve to the same defaults, rather than one
 * producing `true` and the other `false` for the same boolean flag.
 *
 * @param model - Retrieval model from the API; may be absent.
 * @returns A fresh, mutable form-values object (never the shared defaults object).
 */
function retrievalModelToFormValues(model?: RetrievalModel): RetrievalSettingsFormValues {
  const defaults = DEFAULT_RETRIEVAL_SETTINGS;
  if (!model) {
    return { ...defaults };
  }
  return {
    // `||` rather than `??` so an empty-string method also falls back.
    searchMethod: model.search_method || defaults.searchMethod,
    topK: model.top_k ?? defaults.topK,
    scoreThresholdEnabled: model.score_threshold_enabled ?? defaults.scoreThresholdEnabled,
    scoreThreshold: model.score_threshold ?? defaults.scoreThreshold,
    rerankingEnable: model.reranking_enable ?? defaults.rerankingEnable,
    weights: model.weights ?? defaults.weights,
  };
}
/**
 * Builds the `retrieval_model` payload for the update API from form values.
 *
 * Reranking is switched on for semantic and hybrid search and off otherwise,
 * regardless of the form's `rerankingEnable`. Score-threshold filtering is always
 * sent as enabled; only its numeric value comes from the form. Hybrid search is the
 * only method that carries `weights`; every other method sends `null`.
 */
function formValuesToRetrievalModel(values: RetrievalSettingsFormValues): RetrievalModel {
  const { searchMethod, topK, scoreThreshold, weights } = values;
  const isHybrid = searchMethod === 'hybrid_search';
  // Semantic and hybrid retrieval rely on a reranker, so it is forced on for them.
  const rerankingOn = isHybrid || searchMethod === 'semantic_search';

  return {
    search_method: searchMethod,
    reranking_enable: rerankingOn, // not user-controllable
    reranking_mode: null,
    reranking_model: {
      reranking_provider_name: DIFY_CONFIG.rerankingProviderName,
      reranking_model_name: DIFY_CONFIG.rerankingModelName,
    },
    weights: isHybrid ? weights : null,
    top_k: topK,
    score_threshold_enabled: true, // not user-controllable
    score_threshold: scoreThreshold, // the value itself is user-adjustable
  };
}
/**
 * Returns true when two retrieval-settings snapshots hold identical values for
 * every form field.
 */
function isSameRetrievalSettings(
  a: RetrievalSettingsFormValues,
  b: RetrievalSettingsFormValues
): boolean {
  return (
    a.searchMethod === b.searchMethod &&
    a.topK === b.topK &&
    a.scoreThresholdEnabled === b.scoreThresholdEnabled &&
    a.scoreThreshold === b.scoreThreshold &&
    a.rerankingEnable === b.rerankingEnable &&
    a.weights === b.weights
  );
}

/**
 * Extracts a non-empty string `message` from an unknown thrown value.
 * Falls back to `fallback` when none exists, including when `null`/`undefined` was thrown.
 */
function errorMessageOf(err: unknown, fallback: string): string {
  if (typeof err === 'object' && err !== null && 'message' in err) {
    const text = (err as { message: unknown }).message;
    if (typeof text === 'string' && text) {
      return text;
    }
  }
  return fallback;
}

/**
 * State-management hook for the knowledge-base (dataset) settings panel.
 *
 * Only retrieval settings are edited and saved. Per the original note, the Dify
 * API does not support changing a dataset's name or description.
 *
 * @param dataset - Dataset being edited, or `null` while none is loaded.
 * @param onDatasetUpdated - Called with the re-fetched dataset after a successful save.
 * @returns The save/dirty flags, the current form values, and the handlers that change them.
 */
export function useDatasetSettings(
  dataset: Dataset | null,
  onDatasetUpdated: (dataset: Dataset) => void
) {
  const [saving, setSaving] = useState(false);
  // Current form values. Note: the Dify API returns the retrieval config as `retrieval_model_dict`.
  const [retrievalSettings, setRetrievalSettings] = useState<RetrievalSettingsFormValues>(
    () => retrievalModelToFormValues(dataset?.retrieval_model_dict)
  );
  // Last loaded/saved values: the baseline for dirty-checking and reset.
  const [originalSettings, setOriginalSettings] = useState<RetrievalSettingsFormValues>(
    () => retrievalModelToFormValues(dataset?.retrieval_model_dict)
  );

  // Derived during render instead of kept in separate state, so it can never
  // drift out of sync with the two snapshots it compares.
  const hasChanges = !isSameRetrievalSettings(retrievalSettings, originalSettings);

  // Re-seed both snapshots whenever the `dataset` prop changes (this discards unsaved edits).
  useEffect(() => {
    if (dataset) {
      console.log('[DatasetSettings] 初始化检索设置, retrieval_model_dict:', dataset.retrieval_model_dict);
      const settings = retrievalModelToFormValues(dataset.retrieval_model_dict);
      setRetrievalSettings(settings);
      setOriginalSettings(settings);
    }
  }, [dataset]);

  /**
   * Updates a single retrieval-settings field.
   */
  const updateRetrievalSettings = useCallback(<K extends keyof RetrievalSettingsFormValues>(
    key: K,
    value: RetrievalSettingsFormValues[K]
  ) => {
    // Pure updater: no side effects, so React may safely invoke it more than once.
    setRetrievalSettings(prev => ({ ...prev, [key]: value }));
  }, []);

  /**
   * Persists the current retrieval settings, then re-fetches the dataset.
   * Only retrieval settings are saved (the Dify API cannot change name/description).
   */
  const handleSave = useCallback(async () => {
    if (!dataset) {
      message.error('知识库不存在');
      return;
    }
    try {
      setSaving(true);
      await updateDatasetSettings(dataset.id, {
        retrieval_model: formValuesToRetrievalModel(retrievalSettings),
      });
      // The PATCH response may be incomplete, so fetch the full dataset again.
      const fullDataset = await fetchDataset(dataset.id);
      console.log('[DatasetSettings] 保存后获取完整数据:', fullDataset);
      message.success('检索设置保存成功');
      onDatasetUpdated(fullDataset);
      // `retrievalSettings` is this callback's snapshot, i.e. exactly what was
      // submitted; edits made during the request stay flagged as unsaved.
      setOriginalSettings(retrievalSettings);
    } catch (err: unknown) {
      console.error('保存设置失败:', err);
      message.error(errorMessageOf(err, '保存失败'));
    } finally {
      setSaving(false);
    }
  }, [dataset, retrievalSettings, onDatasetUpdated]);

  /**
   * Discards unsaved edits by restoring the baseline snapshot (only while a dataset is loaded).
   */
  const handleReset = useCallback(() => {
    if (dataset) {
      setRetrievalSettings(originalSettings);
    }
  }, [dataset, originalSettings]);

  return {
    // State
    saving,
    hasChanges,
    retrievalSettings,
    // Actions
    handleSave,
    handleReset,
    updateRetrievalSettings,
  };
}

/** Shape of the object returned by `useDatasetSettings`. */
export type UseDatasetSettingsReturn = ReturnType<typeof useDatasetSettings>;