Skip to content

Commit

Permalink
Fix/localai (#2840)
Browse files Browse the repository at this point in the history
  • Loading branch information
Yeuoly authored Mar 15, 2024
1 parent af98954 commit 742be06
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 7 deletions.
14 changes: 10 additions & 4 deletions api/core/model_runtime/model_providers/localai/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from collections.abc import Generator
from typing import cast
from urllib.parse import urljoin

from httpx import Timeout
from openai import (
Expand All @@ -19,6 +18,7 @@
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion_message import FunctionCall
from openai.types.completion import Completion
from yarl import URL

from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
Expand Down Expand Up @@ -181,7 +181,7 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
UserPromptMessage(content='ping')
], model_parameters={
'max_tokens': 10,
}, stop=[])
}, stop=[], stream=False)
except Exception as ex:
raise CredentialsValidateFailedError(f'Invalid credentials {str(ex)}')

Expand Down Expand Up @@ -227,14 +227,20 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
)
]

model_properties = {
ModelPropertyKey.MODE: completion_model,
} if completion_model else {}

model_properties[ModelPropertyKey.CONTEXT_SIZE] = int(credentials.get('context_size', '2048'))

entity = AIModelEntity(
model=model,
label=I18nObject(
en_US=model
),
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_type=ModelType.LLM,
model_properties={ ModelPropertyKey.MODE: completion_model } if completion_model else {},
model_properties=model_properties,
parameter_rules=rules
)

Expand Down Expand Up @@ -319,7 +325,7 @@ def _to_client_kwargs(self, credentials: dict) -> dict:
client_kwargs = {
"timeout": Timeout(315.0, read=300.0, write=10.0, connect=5.0),
"api_key": "1",
"base_url": urljoin(credentials['server_url'], 'v1'),
"base_url": str(URL(credentials['server_url']) / 'v1'),
}

return client_kwargs
Expand Down
9 changes: 9 additions & 0 deletions api/core/model_runtime/model_providers/localai/localai.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,12 @@ model_credential_schema:
placeholder:
zh_Hans: 在此输入LocalAI的服务器地址,如 http://192.168.1.100:8080
en_US: Enter the url of your LocalAI, e.g. http://192.168.1.100:8080
- variable: context_size
label:
zh_Hans: 上下文大小
en_US: Context size
placeholder:
zh_Hans: 输入上下文大小
en_US: Enter context size
required: false
type: text-input
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import time
from json import JSONDecodeError, dumps
from os.path import join
from typing import Optional

from requests import post
from yarl import URL

from core.model_runtime.entities.model_entities import PriceType
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.entities.model_entities import AIModelEntity, FetchFrom, ModelPropertyKey, ModelType, PriceType
from core.model_runtime.entities.text_embedding_entities import EmbeddingUsage, TextEmbeddingResult
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
Expand Down Expand Up @@ -57,7 +58,7 @@ def _invoke(self, model: str, credentials: dict,
}

try:
response = post(join(url, 'embeddings'), headers=headers, data=dumps(data), timeout=10)
response = post(str(URL(url) / 'embeddings'), headers=headers, data=dumps(data), timeout=10)
except Exception as e:
raise InvokeConnectionError(str(e))

Expand Down Expand Up @@ -113,6 +114,27 @@ def get_num_tokens(self, model: str, credentials: dict, texts: list[str]) -> int
# use GPT2Tokenizer to get num tokens
num_tokens += self._get_num_tokens_by_gpt2(text)
return num_tokens

def _get_customizable_model_schema(self, model: str, credentials: dict) -> AIModelEntity | None:
"""
Get customizable model schema
:param model: model name
:param credentials: model credentials
:return: model schema
"""
return AIModelEntity(
model=model,
label=I18nObject(zh_Hans=model, en_US=model),
model_type=ModelType.TEXT_EMBEDDING,
features=[],
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', '512')),
ModelPropertyKey.MAX_CHUNKS: 1,
},
parameter_rules=[]
)

def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Expand Down

0 comments on commit 742be06

Please sign in to comment.