From 3bdaf2dcaecb3c2b61dcc69202df7440cae3fc03 Mon Sep 17 00:00:00 2001 From: Yi Date: Thu, 9 Jan 2025 16:48:02 +0800 Subject: [PATCH] chore: avoid unnecessary loading for the model selector in agent node --- .../agent-model-trigger.tsx | 87 ++++++++++++------- 1 file changed, 58 insertions(+), 29 deletions(-) diff --git a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/agent-model-trigger.tsx b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/agent-model-trigger.tsx index 52b73924cc..4ee772e692 100644 --- a/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/agent-model-trigger.tsx +++ b/web/app/components/header/account-setting/model-provider-page/model-parameter-modal/agent-model-trigger.tsx @@ -67,42 +67,71 @@ const AgentModelTrigger: FC = ({ } }, [modelProviders, providerName]) const [pluginInfo, setPluginInfo] = useState(null) - const [isPluginChecked, setIsPluginChecked] = useState(false) + const [isPluginChecked, setIsPluginChecked] = useState(!!modelProvider) const [installed, setInstalled] = useState(false) const [inModelList, setInModelList] = useState(false) const invalidateInstalledPluginList = useInvalidateInstalledPluginList() const handleOpenModal = useModelModalHandler() + const checkPluginInfo = useMemo(async () => { + if (!providerName || !modelId) + return null + + const parts = providerName.split('/') + try { + const pluginInfo = await fetchPluginInfoFromMarketPlace({ + org: parts[0], + name: parts[1], + }) + if (pluginInfo.data.plugin.category === PluginType.model) + return pluginInfo.data.plugin + } + catch (error) { + // pass + } + return null + }, [providerName, modelId]) + + const checkModelList = useMemo(async () => { + if (!modelId || !currentProvider) + return false + + try { + const modelsData = await fetchModelProviderModelList( + `/workspaces/current/model-providers/${currentProvider?.provider}/models`, + ) + return !!modelsData.data.find(item => item.model === modelId) + } + catch (error) { + // pass + } + return false + }, [modelId, currentProvider]) + useEffect(() => { - (async () => { - if (modelId && currentProvider) { - try { - const modelsData = await fetchModelProviderModelList(`/workspaces/current/model-providers/${currentProvider?.provider}/models`) - if (modelId && modelsData.data.find(item => item.model === modelId)) - setInModelList(true) - } - catch (error) { - // pass + let isSubscribed = true + + const initializeChecks = async () => { + if (!isPluginChecked) { + const [pluginResult, modelListResult] = await Promise.all([ + checkPluginInfo, + checkModelList, + ]) + + if (isSubscribed) { + if (pluginResult) + setPluginInfo(pluginResult) + setInModelList(modelListResult) + setIsPluginChecked(true) } } - if (providerName) { - const parts = providerName.split('/') - const org = parts[0] - const name = parts[1] - try { - const pluginInfo = await fetchPluginInfoFromMarketPlace({ org, name }) - if (pluginInfo.data.plugin.category === PluginType.model) - setPluginInfo(pluginInfo.data.plugin) - } - catch (error) { - // pass - } - setIsPluginChecked(true) - } - else { - setIsPluginChecked(true) - } - })() - }, [providerName, modelId, currentProvider]) + } + + initializeChecks() + + return () => { + isSubscribed = false + } + }, [checkPluginInfo, checkModelList, isPluginChecked, modelId, currentProvider]) if (modelId && !isPluginChecked) return