feat: 知识库设置页面增加 retrieval_model 检索配置功能

1. 召回测试页面增加 Score 阈值参数配置
2. 知识库设置页面新增检索模型配置:
   - 检索方式 (向量/全文/混合/关键字检索)
   - Reranking 模型 (默认开启,不可关闭)
   - Top K 返回数量
   - Score 阈值 (默认开启,可调节数值)
3. 修复 Dify API 字段名问题 (retrieval_model_dict)
4. 优化数据加载流程,使用详情接口获取完整配置

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
2025-12-05 22:07:16 +08:00
parent 5f9ce2fe9f
commit d53742948d
9 changed files with 477 additions and 65 deletions
@@ -1,8 +1,77 @@
import { useState, useEffect, useCallback } from 'react';
import { message } from 'antd';
import type { FormInstance } from 'antd';
import type { Dataset } from '~/api/dify-dataset/type/datasetTypes';
import { updateDatasetName } from '~/api/dify-dataset/api/datasetApi';
import type { Dataset, RetrievalModel } from '~/api/dify-dataset/type/datasetTypes';
import { updateDatasetSettings, fetchDataset } from '~/api/dify-dataset/api/datasetApi';
import { DIFY_CONFIG } from '~/config/api-config';
/**
* 检索方法类型
*/
export type SearchMethod = 'keyword_search' | 'semantic_search' | 'full_text_search' | 'hybrid_search';
/**
* 检索设置表单值
*/
export interface RetrievalSettingsFormValues {
searchMethod: SearchMethod;
topK: number;
scoreThresholdEnabled: boolean;
scoreThreshold: number;
rerankingEnable: boolean;
weights: number; // 混合检索的语义权重 (0-1)
}
/**
* 默认检索设置
*/
const DEFAULT_RETRIEVAL_SETTINGS: RetrievalSettingsFormValues = {
searchMethod: 'semantic_search',
topK: 3,
scoreThresholdEnabled: true, // 默认开启
scoreThreshold: 0.5,
rerankingEnable: true, // 默认开启
weights: 0.7,
};
/**
* 从 Dataset 的 retrieval_model 转换为表单值
*/
function retrievalModelToFormValues(model?: RetrievalModel): RetrievalSettingsFormValues {
if (!model) {
return { ...DEFAULT_RETRIEVAL_SETTINGS };
}
return {
searchMethod: model.search_method || 'semantic_search',
topK: model.top_k ?? 3,
scoreThresholdEnabled: model.score_threshold_enabled ?? false,
scoreThreshold: model.score_threshold ?? 0.5,
rerankingEnable: model.reranking_enable ?? false,
weights: model.weights ?? 0.7,
};
}
/**
* 从表单值转换为 API 请求的 retrieval_model
*/
function formValuesToRetrievalModel(values: RetrievalSettingsFormValues): RetrievalModel {
// 语义检索和混合检索需要 Reranking,强制开启
const needReranking = values.searchMethod === 'semantic_search' || values.searchMethod === 'hybrid_search';
return {
search_method: values.searchMethod,
reranking_enable: needReranking, // 强制开启,不受用户控制
reranking_mode: null,
reranking_model: {
reranking_provider_name: DIFY_CONFIG.rerankingProviderName,
reranking_model_name: DIFY_CONFIG.rerankingModelName,
},
weights: values.searchMethod === 'hybrid_search' ? values.weights : null,
top_k: values.topK,
score_threshold_enabled: true, // 强制开启,不受用户控制
score_threshold: values.scoreThreshold, // 用户可调节数值
};
}
/**
* 知识库设置状态管理 Hook
@@ -15,6 +84,11 @@ export function useDatasetSettings(
const [saving, setSaving] = useState(false);
const [hasChanges, setHasChanges] = useState(false);
// 检索设置状态(注意:Dify API 返回的字段名是 retrieval_model_dict
const [retrievalSettings, setRetrievalSettings] = useState<RetrievalSettingsFormValues>(
() => retrievalModelToFormValues(dataset?.retrieval_model_dict)
);
// 初始化表单数据
useEffect(() => {
if (dataset) {
@@ -22,20 +96,53 @@ export function useDatasetSettings(
name: dataset.name,
description: dataset.description || '',
});
console.log('[DatasetSettings] 初始化检索设置, retrieval_model_dict:', dataset.retrieval_model_dict);
setRetrievalSettings(retrievalModelToFormValues(dataset.retrieval_model_dict));
setHasChanges(false);
}
}, [dataset, form]);
/**
* 更新检索设置
*/
const updateRetrievalSettings = useCallback(<K extends keyof RetrievalSettingsFormValues>(
key: K,
value: RetrievalSettingsFormValues[K]
) => {
setRetrievalSettings(prev => {
const newSettings = { ...prev, [key]: value };
// 检查是否有变化
checkForChanges(newSettings);
return newSettings;
});
}, [dataset]);
/**
* 检查是否有变化
*/
const checkForChanges = useCallback((newRetrievalSettings?: RetrievalSettingsFormValues) => {
const values = form.getFieldsValue();
const currentRetrieval = newRetrievalSettings || retrievalSettings;
const originalRetrieval = retrievalModelToFormValues(dataset?.retrieval_model_dict);
const nameChanged = values.name !== dataset?.name;
const retrievalChanged =
currentRetrieval.searchMethod !== originalRetrieval.searchMethod ||
currentRetrieval.topK !== originalRetrieval.topK ||
currentRetrieval.scoreThresholdEnabled !== originalRetrieval.scoreThresholdEnabled ||
currentRetrieval.scoreThreshold !== originalRetrieval.scoreThreshold ||
currentRetrieval.rerankingEnable !== originalRetrieval.rerankingEnable ||
currentRetrieval.weights !== originalRetrieval.weights;
setHasChanges(nameChanged || retrievalChanged);
}, [form, dataset, retrievalSettings]);
/**
* 处理表单值变化
*/
const handleValuesChange = useCallback(() => {
const values = form.getFieldsValue();
const changed =
values.name !== dataset?.name ||
values.description !== (dataset?.description || '');
setHasChanges(changed);
}, [form, dataset]);
checkForChanges();
}, [checkForChanges]);
/**
* 保存设置
@@ -50,11 +157,18 @@ export function useDatasetSettings(
const values = await form.validateFields();
setSaving(true);
// 目前只支持修改名称
const updatedDataset = await updateDatasetName(dataset.id, values.name);
// 构建完整的更新请求
await updateDatasetSettings(dataset.id, {
name: values.name,
retrieval_model: formValuesToRetrievalModel(retrievalSettings),
});
// PATCH 接口返回的数据可能不完整,重新获取详情
const fullDataset = await fetchDataset(dataset.id);
console.log('[DatasetSettings] 保存后获取完整数据:', fullDataset);
message.success('保存成功');
onDatasetUpdated(updatedDataset);
onDatasetUpdated(fullDataset);
setHasChanges(false);
} catch (err: any) {
console.error('保存设置失败:', err);
@@ -62,7 +176,7 @@ export function useDatasetSettings(
} finally {
setSaving(false);
}
}, [dataset, form, onDatasetUpdated]);
}, [dataset, form, retrievalSettings, onDatasetUpdated]);
/**
* 重置表单
@@ -73,6 +187,7 @@ export function useDatasetSettings(
name: dataset.name,
description: dataset.description || '',
});
setRetrievalSettings(retrievalModelToFormValues(dataset.retrieval_model_dict));
setHasChanges(false);
}
}, [dataset, form]);
@@ -81,11 +196,13 @@ export function useDatasetSettings(
// 状态
saving,
hasChanges,
retrievalSettings,
// 方法
handleValuesChange,
handleSave,
handleReset,
updateRetrievalSettings,
};
}