dify/web/app/components/plugins/plugin-detail-panel/model-selector/index.tsx

251 lines
8.9 KiB
TypeScript
Raw Normal View History

2024-12-20 16:36:22 +08:00
import type {
FC,
ReactNode,
} from 'react'
import { useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import type {
DefaultModel,
FormValue,
} from '@/app/components/header/account-setting/model-provider-page/declarations'
import { ModelStatusEnum, ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import {
useModelList,
} from '@/app/components/header/account-setting/model-provider-page/hooks'
import AgentModelTrigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/agent-model-trigger'
2024-12-20 16:36:22 +08:00
import Trigger from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
import type { TriggerProps } from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal/trigger'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
2024-12-24 11:00:07 +08:00
import LLMParamsPanel from './llm-params-panel'
2024-12-24 14:15:18 +08:00
import TTSParamsPanel from './tts-params-panel'
2024-12-20 16:36:22 +08:00
import { useProviderContext } from '@/context/provider-context'
2024-12-24 11:00:07 +08:00
import cn from '@/utils/classnames'
2024-12-20 16:36:22 +08:00
/**
 * Props accepted by the `ModelParameterModal` popover.
 */
export type ModelParameterModalProps = {
  // ----- selection state -----
  /**
   * Current selection. The modal reads `provider`, `model`,
   * `completion_params`, `language` and `voice` from it.
   */
  value: any
  /** Called with the complete next value whenever the model or its parameters change. */
  setModel: (model: any) => void
  /**
   * `&`-separated list of model types (and/or feature filters) to offer,
   * or `all`. The component defaults this to the text-generation model type.
   */
  scope?: string

  // ----- behaviour flags -----
  /** Forwarded to the LLM parameter panel. */
  isAdvancedMode: boolean
  /** When true, clicking the trigger does not open the popover. */
  readonly?: boolean
  /** Switches popover placement to `left` and is forwarded to the default trigger. */
  isInWorkflow?: boolean
  /** Renders the agent-strategy trigger instead of the default trigger. */
  isAgentStrategy?: boolean

  // ----- presentation -----
  /** Custom trigger renderer; takes precedence over the built-in triggers. */
  renderTrigger?: (v: TriggerProps) => ReactNode
  /** Extra classes for the popup panel rendered inside the portal. */
  popupClassName?: string
  /** Extra classes for the portal content wrapper. */
  portalToFollowElemContentClassName?: string
}
/**
 * Scope entries that name a model type. Any other entry in `scope` is
 * treated as a feature filter and forwarded to `ModelSelector`.
 * Hoisted so the list is not rebuilt on every render.
 */
const MODEL_TYPE_SCOPES: string[] = [
  ModelTypeEnum.textGeneration,
  ModelTypeEnum.textEmbedding,
  ModelTypeEnum.rerank,
  ModelTypeEnum.moderation,
  ModelTypeEnum.speech2text,
  ModelTypeEnum.tts,
]

/**
 * Popover for picking a model (restricted by `scope`) and editing its parameters.
 *
 * The trigger is `renderTrigger` when supplied, otherwise `AgentModelTrigger`
 * when `isAgentStrategy` is set, otherwise the default `Trigger`. The popup
 * shows a `ModelSelector`, plus `LLMParamsPanel` for text-generation models or
 * `TTSParamsPanel` for TTS models. Every change is reported through `setModel`
 * with the complete next value.
 */
const ModelParameterModal: FC<ModelParameterModalProps> = ({
  popupClassName,
  portalToFollowElemContentClassName,
  isAdvancedMode,
  value,
  setModel,
  renderTrigger,
  readonly,
  isInWorkflow,
  isAgentStrategy,
  scope = ModelTypeEnum.textGeneration,
}) => {
  const { t } = useTranslation()
  const { isAPIKeySet } = useProviderContext()
  const [open, setOpen] = useState(false)

  // Memoized on the `scope` string. A fresh `split` result on every render
  // would invalidate every memo below that depends on `scopeArray`.
  const scopeArray = useMemo(() => scope.split('&'), [scope])

  // Non-model-type entries act as feature filters; 'all' disables filtering.
  const scopeFeatures = useMemo(() => {
    if (scopeArray.includes('all'))
      return []
    return scopeArray.filter(item => !MODEL_TYPE_SCOPES.includes(item))
  }, [scopeArray])

  const { data: textGenerationList } = useModelList(ModelTypeEnum.textGeneration)
  const { data: textEmbeddingList } = useModelList(ModelTypeEnum.textEmbedding)
  const { data: rerankList } = useModelList(ModelTypeEnum.rerank)
  const { data: moderationList } = useModelList(ModelTypeEnum.moderation)
  const { data: sttList } = useModelList(ModelTypeEnum.speech2text)
  const { data: ttsList } = useModelList(ModelTypeEnum.tts)

  // Provider list offered to the selector. 'all' merges every type; otherwise
  // the first model type found in `scope` (in the order below) wins.
  const scopedModelList = useMemo(() => {
    const emptyList: any[] = []
    if (scopeArray.includes('all')) {
      return [
        ...textGenerationList,
        ...textEmbeddingList,
        ...rerankList,
        ...sttList,
        ...ttsList,
        ...moderationList,
      ]
    }
    if (scopeArray.includes(ModelTypeEnum.textGeneration))
      return textGenerationList
    if (scopeArray.includes(ModelTypeEnum.textEmbedding))
      return textEmbeddingList
    if (scopeArray.includes(ModelTypeEnum.rerank))
      return rerankList
    if (scopeArray.includes(ModelTypeEnum.moderation))
      return moderationList
    if (scopeArray.includes(ModelTypeEnum.speech2text))
      return sttList
    if (scopeArray.includes(ModelTypeEnum.tts))
      return ttsList
    return emptyList
  }, [scopeArray, textGenerationList, textEmbeddingList, rerankList, sttList, ttsList, moderationList])

  // Resolve the selected provider/model against the scoped list.
  const { currentProvider, currentModel } = useMemo(() => {
    const currentProvider = scopedModelList.find(item => item.provider === value?.provider)
    const currentModel = currentProvider?.models.find((model: { model: string }) => model.model === value?.model)
    return {
      currentProvider,
      currentModel,
    }
  }, [scopedModelList, value?.provider, value?.model])

  // Selection not found in the scoped list.
  const hasDeprecated = !currentProvider || !currentModel
  // Also true when there is no current model at all.
  const modelDisabled = currentModel?.status !== ModelStatusEnum.active
  const disabled = !isAPIKeySet || hasDeprecated || modelDisabled

  const isTextGenerationModel = currentModel?.model_type === ModelTypeEnum.textGeneration
  const isTTSModel = currentModel?.model_type === ModelTypeEnum.tts

  // Picking a model replaces the whole value: parameters of the previous
  // model are not carried over. Text-generation models also record `mode`.
  const handleChangeModel = ({ provider, model }: DefaultModel) => {
    const targetProvider = scopedModelList.find(modelItem => modelItem.provider === provider)
    const targetModelItem = targetProvider?.models.find((modelItem: { model: string }) => modelItem.model === model)
    const model_type = targetModelItem?.model_type as string
    setModel({
      provider,
      model,
      model_type,
      ...(model_type === ModelTypeEnum.textGeneration
        ? { mode: targetModelItem?.model_properties?.mode as string }
        : {}),
    })
  }

  // Only `completion_params` changes. The previous implementation also spread
  // a camelCase `completionParams` object into the top level of the value.
  const handleLLMParamsChange = (newParams: FormValue) => {
    setModel({
      ...value,
      completion_params: newParams,
    })
  }

  const handleTTSParamsChange = (language: string, voice: string) => {
    setModel({
      ...value,
      language,
      voice,
    })
  }

  return (
    <PortalToFollowElem
      open={open}
      onOpenChange={setOpen}
      placement={isInWorkflow ? 'left' : 'bottom-end'}
      offset={4}
    >
      <div className='relative'>
        <PortalToFollowElemTrigger
          onClick={() => {
            if (readonly)
              return
            setOpen(v => !v)
          }}
          className='block'
        >
          {
            renderTrigger
              ? renderTrigger({
                open,
                disabled,
                modelDisabled,
                hasDeprecated,
                currentProvider,
                currentModel,
                providerName: value?.provider,
                modelId: value?.model,
              })
              : (isAgentStrategy
                ? <AgentModelTrigger
                  disabled={disabled}
                  hasDeprecated={hasDeprecated}
                  currentProvider={currentProvider}
                  currentModel={currentModel}
                  providerName={value?.provider}
                  modelId={value?.model}
                  scope={scope}
                />
                : <Trigger
                  disabled={disabled}
                  isInWorkflow={isInWorkflow}
                  modelDisabled={modelDisabled}
                  hasDeprecated={hasDeprecated}
                  currentProvider={currentProvider}
                  currentModel={currentModel}
                  providerName={value?.provider}
                  modelId={value?.model}
                />
              )
          }
        </PortalToFollowElemTrigger>
        <PortalToFollowElemContent className={cn('z-50', portalToFollowElemContentClassName)}>
          <div className={cn(popupClassName, 'w-[389px] rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg')}>
            <div className={cn('max-h-[420px] p-4 pt-3 overflow-y-auto')}>
              <div className='relative'>
                <div className={cn('mb-1 h-6 flex items-center text-text-secondary system-sm-semibold')}>
                  {t('common.modelProvider.model').toLocaleUpperCase()}
                </div>
                <ModelSelector
                  defaultModel={(value?.provider || value?.model) ? { provider: value?.provider, model: value?.model } : undefined}
                  modelList={scopedModelList}
                  scopeFeatures={scopeFeatures}
                  onSelect={handleChangeModel}
                />
              </div>
              {(isTextGenerationModel || isTTSModel) && (
                <div className='my-3 h-[1px] bg-divider-subtle' />
              )}
              {isTextGenerationModel && (
                <LLMParamsPanel
                  provider={value?.provider}
                  modelId={value?.model}
                  completionParams={value?.completion_params || {}}
                  onCompletionParamsChange={handleLLMParamsChange}
                  isAdvancedMode={isAdvancedMode}
                />
              )}
              {isTTSModel && (
                <TTSParamsPanel
                  currentModel={currentModel}
                  language={value?.language}
                  voice={value?.voice}
                  onChange={handleTTSParamsChange}
                />
              )}
            </div>
          </div>
        </PortalToFollowElemContent>
      </div>
    </PortalToFollowElem>
  )
}

export default ModelParameterModal