Files
leaudit-platform-frontend/app/hooks/dify-dataset-manager/dataset-settings.ts
T
TanWenyan d53742948d feat: 知识库设置页面增加 retrieval_model 检索配置功能
1. 召回测试页面增加 Score 阈值参数配置
2. 知识库设置页面新增检索模型配置:
   - 检索方式 (向量/全文/混合/关键字检索)
   - Reranking 模型 (默认开启,不可关闭)
   - Top K 返回数量
   - Score 阈值 (默认开启,可调节数值)
3. 修复 Dify API 字段名问题 (retrieval_model_dict)
4. 优化数据加载流程,使用详情接口获取完整配置

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
2025-12-05 22:07:16 +08:00

210 lines
7.0 KiB
TypeScript
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import { useState, useEffect, useCallback, useRef } from 'react';
import { message } from 'antd';
import type { FormInstance } from 'antd';
import type { Dataset, RetrievalModel } from '~/api/dify-dataset/type/datasetTypes';
import { updateDatasetSettings, fetchDataset } from '~/api/dify-dataset/api/datasetApi';
import { DIFY_CONFIG } from '~/config/api-config';
/**
 * Retrieval strategies the settings form can select.
 * The value is sent unchanged as `search_method` when the settings are saved.
 */
export type SearchMethod =
  | 'keyword_search'
  | 'semantic_search'
  | 'full_text_search'
  | 'hybrid_search';

/**
 * Retrieval settings as edited in the UI form
 * (camelCase counterpart of the API's retrieval model).
 */
export interface RetrievalSettingsFormValues {
  /** Retrieval strategy. */
  searchMethod: SearchMethod;
  /** Number of results to return (top K). */
  topK: number;
  /** Whether the score threshold is enabled. */
  scoreThresholdEnabled: boolean;
  /** Score threshold value. */
  scoreThreshold: number;
  /** Whether reranking is enabled. */
  rerankingEnable: boolean;
  /** Semantic weight used by hybrid search, in the range 0-1. */
  weights: number;
}

/**
 * Settings shown when a dataset has no retrieval model at all.
 * Both the score threshold and reranking start out enabled.
 */
const DEFAULT_RETRIEVAL_SETTINGS: RetrievalSettingsFormValues = {
  searchMethod: 'semantic_search',
  topK: 3,
  scoreThresholdEnabled: true,
  scoreThreshold: 0.5,
  rerankingEnable: true,
  weights: 0.7,
};
/**
 * Convert a dataset's stored retrieval model into form values.
 *
 * Without a model, a copy of DEFAULT_RETRIEVAL_SETTINGS is returned. With a model,
 * each missing field falls back on its own: the search method and the numeric fields
 * use the values from DEFAULT_RETRIEVAL_SETTINGS (so the two never drift apart),
 * while the two boolean flags fall back to `false`, i.e. a model that omits a flag
 * is shown with that flag disabled (unlike the defaults, which enable both).
 *
 * @param model Retrieval model from the API, if any.
 * @returns A fresh form-values object (never the shared default instance).
 */
function retrievalModelToFormValues(model?: RetrievalModel): RetrievalSettingsFormValues {
  const defaults = DEFAULT_RETRIEVAL_SETTINGS;
  if (!model) {
    return { ...defaults };
  }
  return {
    // `||` rather than `??`: an empty-string method also falls back to the default.
    searchMethod: model.search_method || defaults.searchMethod,
    topK: model.top_k ?? defaults.topK,
    // Flags fall back to false, not to the (enabled) defaults — kept as in the original.
    scoreThresholdEnabled: model.score_threshold_enabled ?? false,
    scoreThreshold: model.score_threshold ?? defaults.scoreThreshold,
    rerankingEnable: model.reranking_enable ?? false,
    weights: model.weights ?? defaults.weights,
  };
}
/**
 * Build the `retrieval_model` payload for the update request from form values.
 *
 * Only the search method, top K, threshold value and (for hybrid search) the weight
 * come from the form. Reranking is switched on for semantic and hybrid search
 * regardless of `rerankingEnable`, the reranking model always comes from DIFY_CONFIG,
 * and `score_threshold_enabled` is always sent as true regardless of
 * `scoreThresholdEnabled`.
 */
function formValuesToRetrievalModel(values: RetrievalSettingsFormValues): RetrievalModel {
  const { searchMethod, topK, scoreThreshold, weights } = values;
  const isHybrid = searchMethod === 'hybrid_search';
  // Semantic and hybrid search always rerank; the user cannot turn this off.
  const rerankingOn = isHybrid || searchMethod === 'semantic_search';
  const rerankingModel = {
    reranking_provider_name: DIFY_CONFIG.rerankingProviderName,
    reranking_model_name: DIFY_CONFIG.rerankingModelName,
  };

  return {
    search_method: searchMethod,
    reranking_enable: rerankingOn,
    reranking_mode: null,
    reranking_model: rerankingModel,
    // The semantic weight only applies to hybrid search.
    weights: isHybrid ? weights : null,
    top_k: topK,
    // Always enabled; only the threshold value itself is user-adjustable.
    score_threshold_enabled: true,
    score_threshold: scoreThreshold,
  };
}
/**
 * Returns true when two retrieval-settings snapshots hold the same value in every field.
 */
function isSameRetrievalSettings(
  a: RetrievalSettingsFormValues,
  b: RetrievalSettingsFormValues
): boolean {
  return (
    a.searchMethod === b.searchMethod &&
    a.topK === b.topK &&
    a.scoreThresholdEnabled === b.scoreThresholdEnabled &&
    a.scoreThreshold === b.scoreThreshold &&
    a.rerankingEnable === b.rerankingEnable &&
    a.weights === b.weights
  );
}

/**
 * Extract a non-empty string `message` from an unknown thrown value, if it has one.
 */
function getErrorMessage(err: unknown): string | undefined {
  if (typeof err === 'object' && err !== null) {
    const candidate = (err as { message?: unknown }).message;
    if (typeof candidate === 'string' && candidate) {
      return candidate;
    }
  }
  return undefined;
}

/**
 * True for the rejection value of antd's `form.validateFields()`, which carries `errorFields`.
 */
function isFormValidationError(err: unknown): boolean {
  return typeof err === 'object' && err !== null && 'errorFields' in err;
}

/**
 * State management for the dataset (knowledge base) settings page.
 *
 * Keeps the antd form (`name`, `description`) and the retrieval settings in sync with
 * `dataset`, tracks whether the name or retrieval settings differ from the stored values,
 * and saves the name plus the retrieval model.
 *
 * NOTE(review): `description` is populated into the form but is neither diffed nor
 * saved — presumably read-only on this page; confirm.
 *
 * @param dataset The dataset being edited, or null when none is loaded.
 * @param form antd form instance holding `name` and `description`.
 * @param onDatasetUpdated Called with the re-fetched dataset after a successful save.
 * @returns Saving/dirty flags, the current retrieval settings and the page handlers.
 */
export function useDatasetSettings(
  dataset: Dataset | null,
  form: FormInstance,
  onDatasetUpdated: (dataset: Dataset) => void
) {
  const [saving, setSaving] = useState(false);
  const [hasChanges, setHasChanges] = useState(false);
  // Retrieval settings. Note: the Dify API returns the stored model as `retrieval_model_dict`.
  const [retrievalSettings, setRetrievalSettings] = useState<RetrievalSettingsFormValues>(
    () => retrievalModelToFormValues(dataset?.retrieval_model_dict)
  );
  // Mirror of the most recently requested settings. It is written only from handlers and
  // effects (never during render), so back-to-back updates in one tick build on each
  // other and the change check can run outside the (must-be-pure) state updater.
  const latestSettingsRef = useRef(retrievalSettings);

  /** Replace the retrieval settings, keeping the mirror ref in step with the state. */
  const replaceRetrievalSettings = useCallback((next: RetrievalSettingsFormValues) => {
    latestSettingsRef.current = next;
    setRetrievalSettings(next);
  }, []);

  /** Reset form fields and retrieval settings to the values stored on `source`. */
  const resetFromDataset = useCallback((source: Dataset) => {
    form.setFieldsValue({
      name: source.name,
      description: source.description || '',
    });
    replaceRetrievalSettings(retrievalModelToFormValues(source.retrieval_model_dict));
    setHasChanges(false);
  }, [form, replaceRetrievalSettings]);

  // Initialise (and re-initialise) the form whenever a dataset is provided.
  useEffect(() => {
    if (dataset) {
      console.log('[DatasetSettings] 初始化检索设置, retrieval_model_dict:', dataset.retrieval_model_dict);
      resetFromDataset(dataset);
    }
  }, [dataset, resetFromDataset]);

  /**
   * Recompute `hasChanges` by comparing the form's name and the retrieval settings
   * (the given snapshot, or the latest requested one) with what `dataset` stores.
   */
  const checkForChanges = useCallback((nextSettings?: RetrievalSettingsFormValues) => {
    const values = form.getFieldsValue();
    const nameChanged = values.name !== dataset?.name;
    const original = retrievalModelToFormValues(dataset?.retrieval_model_dict);
    const retrievalChanged = !isSameRetrievalSettings(
      nextSettings ?? latestSettingsRef.current,
      original
    );
    setHasChanges(nameChanged || retrievalChanged);
  }, [form, dataset]);

  /**
   * Update one retrieval setting and refresh the dirty flag.
   * The next snapshot is built outside any state updater so no side effects run inside one.
   */
  const updateRetrievalSettings = useCallback(<K extends keyof RetrievalSettingsFormValues>(
    key: K,
    value: RetrievalSettingsFormValues[K]
  ) => {
    const next: RetrievalSettingsFormValues = { ...latestSettingsRef.current, [key]: value };
    replaceRetrievalSettings(next);
    checkForChanges(next);
  }, [replaceRetrievalSettings, checkForChanges]);

  /** antd `onValuesChange` handler: re-evaluate the dirty flag. */
  const handleValuesChange = useCallback(() => {
    checkForChanges();
  }, [checkForChanges]);

  /**
   * Validate the form, PATCH the name and retrieval model, then re-fetch the full
   * dataset and hand it to `onDatasetUpdated`.
   */
  const handleSave = useCallback(async () => {
    if (!dataset) {
      message.error('知识库不存在');
      return;
    }
    try {
      const values = await form.validateFields();
      setSaving(true);
      await updateDatasetSettings(dataset.id, {
        name: values.name,
        retrieval_model: formValuesToRetrievalModel(retrievalSettings),
      });
      // The PATCH response may be incomplete, so fetch the full detail again.
      const fullDataset = await fetchDataset(dataset.id);
      console.log('[DatasetSettings] 保存后获取完整数据:', fullDataset);
      message.success('保存成功');
      onDatasetUpdated(fullDataset);
      setHasChanges(false);
    } catch (err: unknown) {
      if (isFormValidationError(err)) {
        // antd already marks the invalid fields inline; a generic toast would only add noise.
        return;
      }
      console.error('保存设置失败:', err);
      message.error(getErrorMessage(err) ?? '保存失败');
    } finally {
      setSaving(false);
    }
  }, [dataset, form, retrievalSettings, onDatasetUpdated]);

  /** Discard edits and restore the values stored on the current dataset. */
  const handleReset = useCallback(() => {
    if (dataset) {
      resetFromDataset(dataset);
    }
  }, [dataset, resetFromDataset]);

  return {
    // State
    saving,
    hasChanges,
    retrievalSettings,
    // Handlers
    handleValuesChange,
    handleSave,
    handleReset,
    updateRetrievalSettings,
  };
}

export type UseDatasetSettingsReturn = ReturnType<typeof useDatasetSettings>;