feat: 知识库设置页面增加 retrieval_model 检索配置功能
1. 召回测试页面增加 Score 阈值参数配置
2. 知识库设置页面新增检索模型配置:
   - 检索方式 (向量/全文/混合/关键字检索)
   - Reranking 模型 (默认开启,不可关闭)
   - Top K 返回数量
   - Score 阈值 (默认开启,可调节数值)
3. 修复 Dify API 字段名问题 (retrieval_model_dict)
4. 优化数据加载流程,使用详情接口获取完整配置

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
@@ -1,8 +1,77 @@
|
||||
import { useState, useEffect, useCallback } from 'react';
|
||||
import { message } from 'antd';
|
||||
import type { FormInstance } from 'antd';
|
||||
import type { Dataset } from '~/api/dify-dataset/type/datasetTypes';
|
||||
import { updateDatasetName } from '~/api/dify-dataset/api/datasetApi';
|
||||
import type { Dataset, RetrievalModel } from '~/api/dify-dataset/type/datasetTypes';
|
||||
import { updateDatasetSettings, fetchDataset } from '~/api/dify-dataset/api/datasetApi';
|
||||
import { DIFY_CONFIG } from '~/config/api-config';
|
||||
|
||||
/**
 * Retrieval strategies that can be selected for a knowledge base.
 */
export type SearchMethod =
  | 'keyword_search'
  | 'semantic_search'
  | 'full_text_search'
  | 'hybrid_search';

/**
 * Values held by the retrieval-settings form.
 */
export interface RetrievalSettingsFormValues {
  /** Selected retrieval strategy. */
  searchMethod: SearchMethod;
  /** Number of results to return (sent as `top_k`). */
  topK: number;
  /** Whether the score threshold is switched on. */
  scoreThresholdEnabled: boolean;
  /** Score threshold value (sent as `score_threshold`). */
  scoreThreshold: number;
  /** Whether reranking is switched on. */
  rerankingEnable: boolean;
  /** Semantic weight used by hybrid search (0-1). */
  weights: number;
}

/**
 * Retrieval settings used when a dataset carries no retrieval model.
 */
const DEFAULT_RETRIEVAL_SETTINGS: RetrievalSettingsFormValues = {
  searchMethod: 'semantic_search',
  topK: 3,
  // Score threshold is on by default.
  scoreThresholdEnabled: true,
  scoreThreshold: 0.5,
  // Reranking is on by default.
  rerankingEnable: true,
  weights: 0.7,
};
|
||||
|
||||
/**
 * Converts a dataset's retrieval model into retrieval-settings form values.
 *
 * Every field that is absent (`null` / `undefined`) in the payload falls back
 * to the matching entry of `DEFAULT_RETRIEVAL_SETTINGS`, so a partial payload
 * and a missing payload resolve to the same defaults. Values explicitly
 * present in the payload (including `false` and `0`) are kept as-is.
 *
 * @param model - Retrieval model taken from the dataset; may be absent.
 * @returns A freshly created form-values object (never the shared defaults
 *   object itself, so callers may mutate the result safely).
 */
function retrievalModelToFormValues(model?: RetrievalModel): RetrievalSettingsFormValues {
  const defaults = DEFAULT_RETRIEVAL_SETTINGS;

  if (!model) {
    return { ...defaults };
  }

  return {
    // `||` rather than `??`: an empty-string search method also falls back.
    searchMethod: model.search_method || defaults.searchMethod,
    topK: model.top_k ?? defaults.topK,
    scoreThresholdEnabled: model.score_threshold_enabled ?? defaults.scoreThresholdEnabled,
    scoreThreshold: model.score_threshold ?? defaults.scoreThreshold,
    rerankingEnable: model.reranking_enable ?? defaults.rerankingEnable,
    weights: model.weights ?? defaults.weights,
  };
}
|
||||
|
||||
/**
 * Converts retrieval-settings form values into the `retrieval_model` payload
 * sent with the dataset update request.
 *
 * Behaviour worth knowing:
 * - Reranking is enabled for semantic and hybrid search and disabled for the
 *   other methods; `values.rerankingEnable` is not consulted.
 * - `score_threshold_enabled` is always sent as `true`; only the threshold
 *   value comes from the form (`values.scoreThresholdEnabled` is ignored).
 * - `weights` is only sent for hybrid search, otherwise `null`.
 * - Numeric inputs are sanitised before being sent: non-finite or non-numeric
 *   values fall back to `DEFAULT_RETRIEVAL_SETTINGS`, `top_k` is truncated to
 *   an integer of at least 1, and `weights` is clamped to the 0-1 range.
 *   NOTE(review): presumably needed because a cleared numeric input can yield
 *   `null` despite the `number` typing; confirm against the form controls.
 *
 * @param values - Current retrieval-settings form values.
 * @returns The retrieval model payload for the API request.
 */
function formValuesToRetrievalModel(values: RetrievalSettingsFormValues): RetrievalModel {
  const defaults = DEFAULT_RETRIEVAL_SETTINGS;

  // Accept only real, finite numbers; anything else uses the fallback.
  const finiteOr = (candidate: unknown, fallback: number): number =>
    typeof candidate === 'number' && Number.isFinite(candidate) ? candidate : fallback;

  const isHybrid = values.searchMethod === 'hybrid_search';
  // Semantic and hybrid search require reranking, so it is forced on for them.
  const needReranking = isHybrid || values.searchMethod === 'semantic_search';

  const topK = Math.max(1, Math.trunc(finiteOr(values.topK, defaults.topK)));
  const scoreThreshold = finiteOr(values.scoreThreshold, defaults.scoreThreshold);
  const weights = Math.min(1, Math.max(0, finiteOr(values.weights, defaults.weights)));

  return {
    search_method: values.searchMethod,
    reranking_enable: needReranking, // forced, not user-controlled
    reranking_mode: null,
    reranking_model: {
      reranking_provider_name: DIFY_CONFIG.rerankingProviderName,
      reranking_model_name: DIFY_CONFIG.rerankingModelName,
    },
    weights: isHybrid ? weights : null,
    top_k: topK,
    score_threshold_enabled: true, // forced on, not user-controlled
    score_threshold: scoreThreshold, // user-adjustable value
  };
}
|
||||
|
||||
/**
|
||||
* 知识库设置状态管理 Hook
|
||||
@@ -15,6 +84,11 @@ export function useDatasetSettings(
|
||||
const [saving, setSaving] = useState(false);
|
||||
const [hasChanges, setHasChanges] = useState(false);
|
||||
|
||||
// 检索设置状态(注意:Dify API 返回的字段名是 retrieval_model_dict)
|
||||
const [retrievalSettings, setRetrievalSettings] = useState<RetrievalSettingsFormValues>(
|
||||
() => retrievalModelToFormValues(dataset?.retrieval_model_dict)
|
||||
);
|
||||
|
||||
// 初始化表单数据
|
||||
useEffect(() => {
|
||||
if (dataset) {
|
||||
@@ -22,20 +96,53 @@ export function useDatasetSettings(
|
||||
name: dataset.name,
|
||||
description: dataset.description || '',
|
||||
});
|
||||
console.log('[DatasetSettings] 初始化检索设置, retrieval_model_dict:', dataset.retrieval_model_dict);
|
||||
setRetrievalSettings(retrievalModelToFormValues(dataset.retrieval_model_dict));
|
||||
setHasChanges(false);
|
||||
}
|
||||
}, [dataset, form]);
|
||||
|
||||
/**
 * Updates a single retrieval-settings field and re-evaluates the dirty flag.
 *
 * The new settings object is built from the previous state inside the state
 * updater and handed to `checkForChanges` before being returned as the next
 * state.
 *
 * NOTE(review): `checkForChanges` is invoked from inside the state updater and
 * is not listed in the dependency array (only `dataset` is); confirm this is
 * intentional before relying on it.
 *
 * @param key - Name of the retrieval-settings field to update.
 * @param value - New value for that field.
 */
const updateRetrievalSettings = useCallback(
  <K extends keyof RetrievalSettingsFormValues>(key: K, value: RetrievalSettingsFormValues[K]) => {
    setRetrievalSettings(current => {
      const merged = { ...current, [key]: value };
      // Re-evaluate whether anything differs from the saved dataset.
      checkForChanges(merged);
      return merged;
    });
  },
  [dataset],
);
|
||||
|
||||
/**
 * Recomputes `hasChanges` by comparing the current form name and retrieval
 * settings with the values derived from the loaded dataset.
 *
 * Only the `name` form field is compared; other form fields (such as the
 * description) do not influence the result.
 *
 * @param newRetrievalSettings - Settings to compare instead of the current
 *   `retrievalSettings` state; used when the state update has not landed yet.
 */
const checkForChanges = useCallback(
  (newRetrievalSettings?: RetrievalSettingsFormValues) => {
    const { name } = form.getFieldsValue();
    const candidate = newRetrievalSettings || retrievalSettings;
    // Baseline: what the form would show for the dataset as currently loaded.
    const baseline = retrievalModelToFormValues(dataset?.retrieval_model_dict);

    const comparedKeys: (keyof RetrievalSettingsFormValues)[] = [
      'searchMethod',
      'topK',
      'scoreThresholdEnabled',
      'scoreThreshold',
      'rerankingEnable',
      'weights',
    ];
    const retrievalChanged = comparedKeys.some(field => candidate[field] !== baseline[field]);
    const nameChanged = name !== dataset?.name;

    setHasChanges(nameChanged || retrievalChanged);
  },
  [form, dataset, retrievalSettings],
);
|
||||
|
||||
/**
|
||||
* 处理表单值变化
|
||||
*/
|
||||
const handleValuesChange = useCallback(() => {
|
||||
const values = form.getFieldsValue();
|
||||
const changed =
|
||||
values.name !== dataset?.name ||
|
||||
values.description !== (dataset?.description || '');
|
||||
setHasChanges(changed);
|
||||
}, [form, dataset]);
|
||||
checkForChanges();
|
||||
}, [checkForChanges]);
|
||||
|
||||
/**
|
||||
* 保存设置
|
||||
@@ -50,11 +157,18 @@ export function useDatasetSettings(
|
||||
const values = await form.validateFields();
|
||||
setSaving(true);
|
||||
|
||||
// 目前只支持修改名称
|
||||
const updatedDataset = await updateDatasetName(dataset.id, values.name);
|
||||
// 构建完整的更新请求
|
||||
await updateDatasetSettings(dataset.id, {
|
||||
name: values.name,
|
||||
retrieval_model: formValuesToRetrievalModel(retrievalSettings),
|
||||
});
|
||||
|
||||
// PATCH 接口返回的数据可能不完整,重新获取详情
|
||||
const fullDataset = await fetchDataset(dataset.id);
|
||||
console.log('[DatasetSettings] 保存后获取完整数据:', fullDataset);
|
||||
|
||||
message.success('保存成功');
|
||||
onDatasetUpdated(updatedDataset);
|
||||
onDatasetUpdated(fullDataset);
|
||||
setHasChanges(false);
|
||||
} catch (err: any) {
|
||||
console.error('保存设置失败:', err);
|
||||
@@ -62,7 +176,7 @@ export function useDatasetSettings(
|
||||
} finally {
|
||||
setSaving(false);
|
||||
}
|
||||
}, [dataset, form, onDatasetUpdated]);
|
||||
}, [dataset, form, retrievalSettings, onDatasetUpdated]);
|
||||
|
||||
/**
|
||||
* 重置表单
|
||||
@@ -73,6 +187,7 @@ export function useDatasetSettings(
|
||||
name: dataset.name,
|
||||
description: dataset.description || '',
|
||||
});
|
||||
setRetrievalSettings(retrievalModelToFormValues(dataset.retrieval_model_dict));
|
||||
setHasChanges(false);
|
||||
}
|
||||
}, [dataset, form]);
|
||||
@@ -81,11 +196,13 @@ export function useDatasetSettings(
|
||||
// 状态
|
||||
saving,
|
||||
hasChanges,
|
||||
|
||||
retrievalSettings,
|
||||
|
||||
// 方法
|
||||
handleValuesChange,
|
||||
handleSave,
|
||||
handleReset,
|
||||
updateRetrievalSettings,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import { useState, useEffect, useCallback } from 'react';
|
||||
import { message } from 'antd';
|
||||
import type { Dataset } from '~/api/dify-dataset/type/datasetTypes';
|
||||
import type { Document } from '~/api/dify-dataset/type/documentTypes';
|
||||
import { fetchDatasets } from '~/api/dify-dataset/api/datasetApi';
|
||||
import { fetchDatasets, fetchDataset } from '~/api/dify-dataset/api/datasetApi';
|
||||
import { fetchDocuments } from '~/api/dify-dataset/api/documentApi';
|
||||
import type { MenuTab } from '~/types/dify-dataset-manager/layout';
|
||||
import { DEFAULT_DOCUMENT_PAGE_SIZE } from '~/types/dify-dataset-manager/index';
|
||||
@@ -58,20 +58,26 @@ export function useDatasetManager() {
|
||||
}, [documentPageSize]);
|
||||
|
||||
/**
|
||||
* 加载知识库(获取第一个知识库)
|
||||
* 加载知识库(获取第一个知识库,再获取详情以包含 retrieval_model)
|
||||
*/
|
||||
const loadDataset = useCallback(async () => {
|
||||
setLoadingDataset(true);
|
||||
try {
|
||||
console.log('[DatasetManager] 加载知识库...');
|
||||
// 先获取列表,找到第一个知识库的 ID
|
||||
const response = await fetchDatasets(1, 1);
|
||||
console.log('[DatasetManager] 知识库响应:', response);
|
||||
console.log('[DatasetManager] 知识库列表响应:', response);
|
||||
|
||||
if (response && response.data && response.data.length > 0) {
|
||||
const firstDataset = response.data[0];
|
||||
setDataset(firstDataset);
|
||||
const firstDatasetId = response.data[0].id;
|
||||
|
||||
// 再获取详情,包含完整的 retrieval_model 等字段
|
||||
const fullDataset = await fetchDataset(firstDatasetId);
|
||||
console.log('[DatasetManager] 知识库详情响应:', fullDataset);
|
||||
|
||||
setDataset(fullDataset);
|
||||
// 立即加载文档
|
||||
await loadDocuments(firstDataset.id, 1);
|
||||
await loadDocuments(firstDatasetId, 1);
|
||||
} else {
|
||||
setError('未找到知识库,请先在Dify中创建知识库');
|
||||
}
|
||||
|
||||
@@ -9,7 +9,12 @@ import type { SearchMethod } from '~/types/dify-dataset-manager/retrieve-test';
|
||||
* 构建完整的 retrieval_model 参数(匹配 Dify API 规范)
|
||||
* 根据检索方式启用 Reranking(语义搜索和混合搜索需要启用)
|
||||
*/
|
||||
function buildRetrievalModel(searchMethod: SearchMethod, topK: number): RetrievalModel {
|
||||
function buildRetrievalModel(
|
||||
searchMethod: SearchMethod,
|
||||
topK: number,
|
||||
scoreThresholdEnabled: boolean,
|
||||
scoreThreshold: number
|
||||
): RetrievalModel {
|
||||
// 语义搜索和混合搜索需要启用 Reranking
|
||||
const needReranking = searchMethod === 'semantic_search' || searchMethod === 'hybrid_search';
|
||||
|
||||
@@ -23,8 +28,8 @@ function buildRetrievalModel(searchMethod: SearchMethod, topK: number): Retrieva
|
||||
},
|
||||
weights: null,
|
||||
top_k: topK,
|
||||
score_threshold_enabled: false,
|
||||
score_threshold: null,
|
||||
score_threshold_enabled: scoreThresholdEnabled,
|
||||
score_threshold: scoreThresholdEnabled ? scoreThreshold : null,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -38,6 +43,9 @@ export function useRetrieveTest(datasetId: string) {
|
||||
// 默认使用语义搜索
|
||||
const [searchMethod, setSearchMethod] = useState<SearchMethod>('semantic_search');
|
||||
const [topK, setTopK] = useState<number>(5);
|
||||
// Score 阈值相关状态
|
||||
const [scoreThresholdEnabled, setScoreThresholdEnabled] = useState(false);
|
||||
const [scoreThreshold, setScoreThreshold] = useState<number>(0.5);
|
||||
|
||||
/**
|
||||
* 执行检索
|
||||
@@ -55,7 +63,7 @@ export function useRetrieveTest(datasetId: string) {
|
||||
|
||||
setRetrieving(true);
|
||||
try {
|
||||
const retrievalModel = buildRetrievalModel(searchMethod, topK);
|
||||
const retrievalModel = buildRetrievalModel(searchMethod, topK, scoreThresholdEnabled, scoreThreshold);
|
||||
console.log('[Hook] 检索参数:', { datasetId, query: searchQuery, retrievalModel });
|
||||
|
||||
const response = await retrieveDataset(datasetId, searchQuery, retrievalModel);
|
||||
@@ -69,7 +77,7 @@ export function useRetrieveTest(datasetId: string) {
|
||||
} finally {
|
||||
setRetrieving(false);
|
||||
}
|
||||
}, [datasetId, searchQuery, searchMethod, topK]);
|
||||
}, [datasetId, searchQuery, searchMethod, topK, scoreThresholdEnabled, scoreThreshold]);
|
||||
|
||||
return {
|
||||
// 状态
|
||||
@@ -81,6 +89,10 @@ export function useRetrieveTest(datasetId: string) {
|
||||
setSearchMethod,
|
||||
topK,
|
||||
setTopK,
|
||||
scoreThresholdEnabled,
|
||||
setScoreThresholdEnabled,
|
||||
scoreThreshold,
|
||||
setScoreThreshold,
|
||||
|
||||
// 方法
|
||||
handleRetrieve,
|
||||
|
||||
Reference in New Issue
Block a user