Skip to content

Commit

Permalink
Merge branch 'feat/model-provider-based-on-runtime' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
zxhlyh committed Dec 28, 2023
2 parents 06268ad + cf8f708 commit ae9bce7
Show file tree
Hide file tree
Showing 12 changed files with 184 additions and 77 deletions.
5 changes: 4 additions & 1 deletion web/app/components/base/prompt-log-modal/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,9 @@ const PromptLogModal: FC<PromptLogModalProps> = ({
setMounted(true)
}, [])

if (!log)
return null

return (
<div
className='fixed top-16 left-2 bottom-2 flex flex-col bg-white border-[0.5px] border-gray-200 rounded-xl shadow-xl z-10'
Expand All @@ -37,7 +40,7 @@ const PromptLogModal: FC<PromptLogModalProps> = ({
<div className='text-base font-semibold text-gray-900'>PROMPT LOG</div>
<div className='flex items-center'>
{
log.length === 1 && (
log?.length === 1 && (
<>
<CopyFeedbackNew className='w-6 h-6' content={log[0].text} />
<div className='mx-2.5 w-[1px] h-[14px] bg-gray-200' />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import type {
DefaultModel,
DefaultModelResponse,
Model,
ModelProvider,
} from './declarations'
import {
ConfigurateMethodEnum,
Expand All @@ -22,6 +23,7 @@ import {
fetchDefaultModal,
fetchModelList,
fetchModelProviderCredentials,
fetchModelProviders,
getPayUrl,
submitFreeQuota,
} from '@/service/common'
Expand Down Expand Up @@ -81,14 +83,23 @@ export const useProviderCrenditialsFormSchemasValue = (
fetchModelProviderCredentials,
)

return configurateMethod === ConfigurateMethodEnum.predefinedModel
? predefinedFormSchemasValue?.credentials
: customFormSchemasValue?.credentials
? {
...customFormSchemasValue?.credentials,
...currentCustomConfigrationModelFixedFields,
}
: undefined
const value = useMemo(() => {
return configurateMethod === ConfigurateMethodEnum.predefinedModel
? predefinedFormSchemasValue?.credentials
: customFormSchemasValue?.credentials
? {
...customFormSchemasValue?.credentials,
...currentCustomConfigrationModelFixedFields,
}
: undefined
}, [
configurateMethod,
currentCustomConfigrationModelFixedFields,
customFormSchemasValue?.credentials,
predefinedFormSchemasValue?.credentials,
])

return value
}

export type ModelTypeIndex = 1 | 2 | 3 | 4
Expand Down Expand Up @@ -237,3 +248,27 @@ export const useFreeQuota = (onSuccess: () => void) => {

return handleClick
}

export const useModelProviders = () => {
const { data: providersData, mutate, isLoading } = useSWR('/workspaces/current/model-providers', fetchModelProviders)

return {
data: providersData?.data || [],
mutate,
isLoading,
}
}

export const useUpdateModelProvidersAndModelList = () => {
const { mutate } = useSWRConfig()
const updateModelList = useUpdateModelList()

const updateModelProvidersAndModelList = useCallback((provider: ModelProvider) => {
mutate('/workspaces/current/model-providers')
provider?.supported_model_types.forEach((modelType) => {
updateModelList(modelType)
})
}, [mutate, updateModelList])

return updateModelProvidersAndModelList
}
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import { useMemo, useState } from 'react'
import useSWR from 'swr'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import SystemModelSelector from './system-model-selector'
import ProviderAddedCard from './provider-added-card'
import ProviderAddedCard, { UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST } from './provider-added-card'
import ProviderCard from './provider-card'
import ModelModal from './model-modal'
import type {
ConfigurateMethodEnum,
CustomConfigrationModelFixedFields,
Expand All @@ -13,27 +11,24 @@ import type {
import { CustomConfigurationStatusEnum } from './declarations'
import {
useDefaultModel,
useUpdateModelList,
useUpdateModelProvidersAndModelList,
} from './hooks'
import { fetchModelProviders } from '@/service/common'
import { AlertTriangle } from '@/app/components/base/icons/src/vender/solid/alertsAndFeedback'
import Loading from '@/app/components/base/loading'
import { useProviderContext } from '@/context/provider-context'
import { useModalContext } from '@/context/modal-context'
import { useEventEmitterContextContext } from '@/context/event-emitter'

const ModelProviderPage = () => {
const { t } = useTranslation()
const updateModelList = useUpdateModelList()
const { eventEmitter } = useEventEmitterContextContext()
const updateModelProvidersAndModelList = useUpdateModelProvidersAndModelList()
const { data: textGenerationDefaultModel } = useDefaultModel(1)
const { data: embeddingsDefaultModel } = useDefaultModel(2)
const { data: rerankDefaultModel } = useDefaultModel(3)
const { data: speech2textDefaultModel } = useDefaultModel(4)
const [currentProvider, setCurrentProvider] = useState<ModelProvider | null>(null)
const [currentConfigurateMethod, setCurrentConfigurateMethod] = useState<ConfigurateMethodEnum | null>(null)
const [currentCustomConfigrationModelFixedFields, setCurrentCustomConfigrationModelFixedFields] = useState<CustomConfigrationModelFixedFields | undefined>(undefined)
const { data: providersData, mutate: mutateProviders, isLoading } = useSWR('/workspaces/current/model-providers', fetchModelProviders)
const { modelProviders: providers } = useProviderContext()
const { setShowModelModal } = useModalContext()
const defaultModelNotConfigured = !textGenerationDefaultModel && !embeddingsDefaultModel && !speech2textDefaultModel && !rerankDefaultModel
const providers = useMemo(() => {
return providersData ? providersData.data : []
}, [providersData])
const [configedProviders, notConfigedProviders] = useMemo(() => {
const configedProviders: ModelProvider[] = []
const notConfigedProviders: ModelProvider[] = []
Expand All @@ -53,21 +48,23 @@ const ModelProviderPage = () => {
configurateMethod: ConfigurateMethodEnum,
customConfigrationModelFixedFields?: CustomConfigrationModelFixedFields,
) => {
setCurrentProvider(provider)
setCurrentConfigurateMethod(configurateMethod)
setCurrentCustomConfigrationModelFixedFields(customConfigrationModelFixedFields)
}

const handleCancelModelModal = () => {
setCurrentProvider(null)
setCurrentConfigurateMethod(null)
setCurrentCustomConfigrationModelFixedFields(undefined)
}
setShowModelModal({
payload: {
currentProvider: provider,
currentConfigurateMethod: configurateMethod,
currentCustomConfigrationModelFixedFields: customConfigrationModelFixedFields,
},
onSaveCallback: () => {
updateModelProvidersAndModelList(provider)

const handleSaveCrendentials = () => {
mutateProviders()
currentProvider?.supported_model_types.forEach((modelType) => {
updateModelList(modelType)
if (customConfigrationModelFixedFields && provider.custom_configuration.status === CustomConfigurationStatusEnum.active) {
console.log('1')
eventEmitter?.emit({
type: UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST,
payload: provider.provider,
} as any)
}
},
})
}

Expand All @@ -89,16 +86,8 @@ const ModelProviderPage = () => {
embeddingsDefaultModel={embeddingsDefaultModel}
rerankDefaultModel={rerankDefaultModel}
speech2textDefaultModel={speech2textDefaultModel}
onUpdate={() => mutateProviders()}
/>
</div>
{
isLoading && (
<div className='mt-[240px]'>
<Loading />
</div>
)
}
{
!!configedProviders?.length && (
<div className='pb-3'>
Expand Down Expand Up @@ -135,17 +124,6 @@ const ModelProviderPage = () => {
</>
)
}
{
!!currentProvider && !!currentConfigurateMethod && (
<ModelModal
provider={currentProvider}
configurateMethod={currentConfigurateMethod}
currentCustomConfigrationModelFixedFields={currentCustomConfigrationModelFixedFields}
onCancel={handleCancelModelModal}
onSave={handleSaveCrendentials}
/>
)
}
</div>
)
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ const Form: FC<FormProps> = ({
shouldClearVariable[clearVariable] = undefined
})
}
console.log(key, val, shouldClearVariable)
onChange({ ...value, [key]: val, ...shouldClearVariable })
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
import type { FC } from 'react'
import { useCallback, useEffect, useMemo, useState } from 'react'
import {
memo,
useCallback,
useEffect,
useMemo,
useState,
} from 'react'
import { useTranslation } from 'react-i18next'
import type {
CredentialFormSchema,
Expand Down Expand Up @@ -341,4 +347,4 @@ const ModelModal: FC<ModelModalProps> = ({
)
}

export default ModelModal
export default memo(ModelModal)
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ const ModelSelector: FC<ModelSelectorProps> = ({
<div className='relative'>
<PortalToFollowElemTrigger
onClick={handleToggle}
className='block cursor-not-allowed'
className='block'
>
{
currentModel && currentProvider && (
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ const ModelTrigger: FC<ModelTriggerProps> = ({
return (
<div
className={`
group flex items-center px-2 h-8 rounded-lg bg-gray-100 hover:bg-gray-200 cursopr-pointer
group flex items-center px-2 h-8 rounded-lg bg-gray-100 hover:bg-gray-200 cursor-pointer
${className}
${open && '!bg-gray-200'}
`}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,19 @@ import type {
Model,
ModelItem,
} from '../declarations'
import { useLanguage } from '../hooks'
import {
useLanguage,
useUpdateModelProvidersAndModelList,
} from '../hooks'
import ModelIcon from '../model-icon'
import ModelName from '../model-name'
import { ModelStatusEnum } from '../declarations'
import {
ConfigurateMethodEnum,
ModelStatusEnum,
} from '../declarations'
import { Check } from '@/app/components/base/icons/src/vender/line/general'
import { useModalContext } from '@/context/modal-context'
import { useProviderContext } from '@/context/provider-context'

type PopupItemProps = {
defaultModel?: DefaultModel
Expand All @@ -23,12 +31,27 @@ const PopupItem: FC<PopupItemProps> = ({
}) => {
const { t } = useTranslation()
const language = useLanguage()
const { setShowModelModal } = useModalContext()
const { modelProviders } = useProviderContext()
const updateModelProvidersAndModelList = useUpdateModelProvidersAndModelList()
const currentProvider = modelProviders.find(provider => provider.provider === model.provider)!
const handleSelect = (provider: string, modelItem: ModelItem) => {
if (modelItem.status !== ModelStatusEnum.active)
return

onSelect(provider, modelItem)
}
const handleOpenModelModal = () => {
setShowModelModal({
payload: {
currentProvider,
currentConfigurateMethod: ConfigurateMethodEnum.predefinedModel,
},
onSaveCallback: () => {
updateModelProvidersAndModelList(currentProvider)
},
})
}

return (
<div className='mb-1'>
Expand Down Expand Up @@ -69,7 +92,10 @@ const PopupItem: FC<PopupItemProps> = ({
}
{
modelItem.status === ModelStatusEnum.noConfigure && (
<div className='hidden group-hover:block text-xs font-medium text-primary-600 cursor-pointer'>
<div
className='hidden group-hover:block text-xs font-medium text-primary-600 cursor-pointer'
onClick={handleOpenModelModal}
>
{t('common.operation.add').toLocaleUpperCase()}
</div>
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,9 @@ import AddModelButton from './add-model-button'
import { ChevronDownDouble } from '@/app/components/base/icons/src/vender/line/arrows'
import { Loading02 } from '@/app/components/base/icons/src/vender/line/general'
import { fetchModelProviderModelList } from '@/service/common'
import { useEventEmitterContextContext } from '@/context/event-emitter'

export const UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST = 'UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST'
type ProviderAddedCardProps = {
provider: ModelProvider
onOpenModal: (configurateMethod: ConfigurateMethodEnum, currentCustomConfigrationModelFixedFields?: CustomConfigrationModelFixedFields) => void
Expand All @@ -30,6 +32,7 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
onOpenModal,
}) => {
const { t } = useTranslation()
const { eventEmitter } = useEventEmitterContextContext()
const [fetched, setFetched] = useState(false)
const [loading, setLoading] = useState(false)
const [collapsed, setCollapsed] = useState(true)
Expand All @@ -39,15 +42,13 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
const hasModelList = fetched && !!modelList.length
const showQuota = systemConfig.enabled || ['minimax', 'spark', 'zhipuai', 'anthropic'].includes(provider.provider)

const handleOpenModelList = async () => {
if (fetched) {
setCollapsed(false)
const getModelList = async (providerName: string) => {
console.log('3')
if (loading)
return
}

try {
setLoading(true)
const modelsData = await fetchModelProviderModelList(`/workspaces/current/model-providers/${provider.provider}/models`)
const modelsData = await fetchModelProviderModelList(`/workspaces/current/model-providers/${providerName}/models`)
setModelList(modelsData.data)
setCollapsed(false)
setFetched(true)
Expand All @@ -56,6 +57,20 @@ const ProviderAddedCard: FC<ProviderAddedCardProps> = ({
setLoading(false)
}
}
const handleOpenModelList = () => {
if (fetched) {
setCollapsed(false)
return
}

getModelList(provider.provider)
}

eventEmitter?.useSubscription((v: any) => {
console.log('2')
if (v?.type === UPDATE_MODEL_PROVIDER_CUSTOM_MODEL_LIST)
getModelList(v.payload as string)
})

return (
<div
Expand Down
Loading

0 comments on commit ae9bce7

Please sign in to comment.