Skip to content

Commit

Permalink
Merge branch 'main' into feat/workflow
Browse files Browse the repository at this point in the history
  • Loading branch information
zxhlyh committed Mar 15, 2024
2 parents e3f1e14 + cef1686 commit 0240658
Show file tree
Hide file tree
Showing 7 changed files with 103 additions and 16 deletions.
3 changes: 3 additions & 0 deletions api/controllers/console/app/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ def post(self, app_id):
masked_parameter_map = {}
tool_map = {}
for tool in agent_mode.get('tools') or []:
if not isinstance(tool, dict) or len(tool.keys()) <= 3:
continue

agent_tool_entity = AgentToolEntity(**tool)
# get tool
try:
Expand Down
14 changes: 10 additions & 4 deletions api/core/model_runtime/model_providers/localai/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections.abc import Generator
from typing import cast
from urllib.parse import urljoin

from httpx import Timeout
from openai import (
Expand All @@ -19,6 +18,7 @@
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.completion import Completion
from yarl import URL

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
Expand Down Expand Up @@ -181,7 +181,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
UserPromptMessage(content='ping')
], model_parameters={
'max_tokens': 10,
}, stop=[])
}, stop=[], stream=False)
except Exception as ex:
raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')

Expand Down Expand Up @@ -227,14 +227,20 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
)
]

model_properties = {
ModelPropertyKey.MODE: completion_model,
} if completion_model else {}

model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048'))

entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
model_properties=model_properties,
parameter_rules=rules
)

Expand Down Expand Up @@ -319,7 +325,7 @@ def _to_client_kwargs(self, credentials: dict) -> dict:
client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1",
"base_url": urljoin(credentials['server_url'], 'v1'),
"base_url": str(URL(credentials['server_url']) / 'v1'),
}

return client_kwargs
Expand Down
9 changes: 9 additions & 0 deletions api/core/model_runtime/model_providers/localai/localai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ model_credential_schema:
placeholder:
zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
- variable: context_size
label:
zh_Hans: 上下文大小
en_US: Context size
placeholder:
zh_Hans: 输入上下文大小
en_US: Enter context size
required: false
type: text-input
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import time
from json import JSONDecodeError, dumps
from os.path import join
from typing import Optional

from requests import post
from yarl import URL

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
Expand Down Expand Up @@ -57,7 +58,7 @@ def _invoke(self, model: str, credentials: dict,
}

try:
response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10)
response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10)
except Exception as e:
raise InvokeConnectionError(str(e))

Expand Down Expand Up @@ -113,6 +114,27 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int
# use GPT2Tokenizer to get num tokens
num_tokens += self._get_num_tokens_by_gpt2(text)
return num_tokens

def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get customizable model schema
:param model: model name
:param credentials: model credentials
:return: model schema
"""
return AIModelEntity(
model=model,
label=I18nObject(zh_Hans=model, en_US=model),
model_type=ModelType.TEXT_EMBEDDING,
features=[],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')),
ModelPropertyKey.MAX_CHUNKS: 1,
},
parameter_rules=[]
)

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Expand Down
33 changes: 33 additions & 0 deletions api/core/tools/provider/builtin/chart/chart.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,44 @@
import matplotlib.pyplot as plt
from fontTools.ttLib import TTFont
from matplotlib.font_manager import findSystemFonts

from core.tools.errors import ToolProviderCredentialValidationError
from core.tools.provider.builtin.chart.tools.line import LinearChartTool
from core.tools.provider.builtin_tool_provider import BuiltinToolProviderController

# use a business theme
plt.style.use('seaborn-v0_8-darkgrid')
plt.rcParams['axes.unicode_minus'] = False

def init_fonts():
fonts = findSystemFonts()

popular_unicode_fonts = [
'Arial Unicode MS', 'DejaVu Sans', 'DejaVu Sans Mono', 'DejaVu Serif', 'FreeMono', 'FreeSans', 'FreeSerif',
'Liberation Mono', 'Liberation Sans', 'Liberation Serif', 'Noto Mono', 'Noto Sans', 'Noto Serif', 'Open Sans',
'Roboto', 'Source Code Pro', 'Source Sans Pro', 'Source Serif Pro', 'Ubuntu', 'Ubuntu Mono'
]

supported_fonts = []

for font_path in fonts:
try:
font = TTFont(font_path)
# get family name
family_name = font['name'].getName(1, 3, 1).toUnicode()
if family_name in popular_unicode_fonts:
supported_fonts.append(family_name)
except:
pass

plt.rcParams['font.family'] = 'sans-serif'
# sort by order of popular_unicode_fonts
for font in popular_unicode_fonts:
if font in supported_fonts:
plt.rcParams['font.sans-serif'] = font
break

init_fonts()

class ChartProvider(BuiltinToolProviderController):
def _validate_credentials(self, credentials: dict) -> None:
Expand Down
4 changes: 2 additions & 2 deletions web/app/activate/activateForm.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ const ActivateForm = () => {
</label>
<div className="relative mt-1 rounded-md shadow-sm">
<SimpleSelect
defaultValue={defaultLanguage()}
items={languages}
defaultValue={LanguagesSupported[0]}
items={languages.filter(item => item.supported)}
onSelect={(item) => {
setLanguage(item.value as string)
}}
Expand Down
28 changes: 21 additions & 7 deletions web/app/components/base/chat/chat/index.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import type {
} from 'react'
import {
memo,
useCallback,
useEffect,
useRef,
} from 'react'
Expand Down Expand Up @@ -76,19 +77,20 @@ const Chat: FC<ChatProps> = ({
const chatContainerInnerRef = useRef<HTMLDivElement>(null)
const chatFooterRef = useRef<HTMLDivElement>(null)
const chatFooterInnerRef = useRef<HTMLDivElement>(null)
const userScrolledRef = useRef(false)

const handleScrolltoBottom = () => {
if (chatContainerRef.current)
const handleScrolltoBottom = useCallback(() => {
if (chatContainerRef.current && !userScrolledRef.current)
chatContainerRef.current.scrollTop = chatContainerRef.current.scrollHeight
}
}, [])

const handleWindowResize = () => {
const handleWindowResize = useCallback(() => {
if (chatContainerRef.current && chatFooterRef.current)
chatFooterRef.current.style.width = `${chatContainerRef.current.clientWidth}px`

if (chatContainerInnerRef.current && chatFooterInnerRef.current)
chatFooterInnerRef.current.style.width = `${chatContainerInnerRef.current.clientWidth}px`
}
}, [])

useThrottleEffect(() => {
handleScrolltoBottom()
Expand All @@ -98,7 +100,7 @@ const Chat: FC<ChatProps> = ({
useEffect(() => {
window.addEventListener('resize', debounce(handleWindowResize))
return () => window.removeEventListener('resize', handleWindowResize)
}, [])
}, [handleWindowResize])

useEffect(() => {
if (chatFooterRef.current && chatContainerRef.current) {
Expand All @@ -117,7 +119,19 @@ const Chat: FC<ChatProps> = ({
resizeObserver.disconnect()
}
}
}, [chatFooterRef, chatContainerRef])
}, [handleScrolltoBottom])

useEffect(() => {
const chatContainer = chatContainerRef.current
if (chatContainer) {
const setUserScrolled = () => {
if (chatContainer)
userScrolledRef.current = chatContainer.scrollHeight - chatContainer.scrollTop >= chatContainer.clientHeight + 300
}
chatContainer.addEventListener('scroll', setUserScrolled)
return () => chatContainer.removeEventListener('scroll', setUserScrolled)
}
}, [])

const hasTryToAsk = config?.suggested_questions_after_answer?.enabled && !!suggestedQuestions?.length && onSend

Expand Down

0 comments on commit 0240658

Please sign in to comment.