Commit
Merge remote-tracking branch 'origin/deploy/dev' into deploy/dev
JohnJyong committed Dec 28, 2023
2 parents 1c5b11a + 6f8ecdb commit 8426bdd
Showing 8 changed files with 30 additions and 12 deletions.
2 changes: 1 addition & 1 deletion api/core/model_manager.py
@@ -36,7 +36,7 @@ def _fetch_credentials_from_bundle(self, provider_model_bundle: ProviderModelBun
:return:
"""
credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model_type=provider_model_bundle.model_type_instance.model_type,
model=model
)
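This change means _fetch_credentials_from_bundle no longer assumes every bundle wraps an LLM: credentials are resolved for whatever model type the bundle's model_type_instance declares, so rerank, moderation, speech-to-text and embedding bundles pick up the right credentials as well. A minimal sketch of the resulting method follows; the return type and the error handling are assumptions, not the repository's exact code.

def _fetch_credentials_from_bundle(self, provider_model_bundle, model: str) -> dict:
    """Fetch credentials for the bundle's own model type (simplified sketch)."""
    credentials = provider_model_bundle.configuration.get_current_credentials(
        model_type=provider_model_bundle.model_type_instance.model_type,
        model=model
    )

    if credentials is None:
        # hypothetical guard; the real method may raise a provider-specific error
        raise ValueError(f"no credentials configured for model {model}")

    return credentials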

4 changes: 3 additions & 1 deletion api/core/model_runtime/model_providers/__base/ai_model.py
@@ -10,14 +10,16 @@

from core.model_runtime.entities.defaults import PARAMETER_RULE_TEMPLATE
from core.model_runtime.entities.model_entities import PriceInfo, AIModelEntity, PriceType, PriceConfig, \
DefaultParameterName, FetchFrom
DefaultParameterName, FetchFrom, ModelType
from core.model_runtime.errors.invoke import InvokeError, InvokeAuthorizationError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer


class AIModel(ABC):
"""
Base class for all models.
"""
model_type: ModelType
model_schemas: list[AIModelEntity] = None
started_at: float = 0
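The model_type annotation added here gives every AIModel subclass a declared type that generic callers, such as the credential lookup in api/core/model_manager.py above, can read back without hard-coding ModelType.LLM. A simplified sketch of the pattern; the enum members shown are stand-ins for the real ModelType in core.model_runtime.entities.model_entities.

from enum import Enum


class ModelType(Enum):
    # stand-in members; the real enum lives in core.model_runtime.entities.model_entities
    LLM = "llm"
    TEXT_EMBEDDING = "text-embedding"
    RERANK = "rerank"
    MODERATION = "moderation"
    SPEECH2TEXT = "speech2text"


class AIModel:
    # every concrete model class declares which type it is
    model_type: ModelType


class LargeLanguageModel(AIModel):
    model_type: ModelType = ModelType.LLM


class RerankModel(AIModel):
    model_type: ModelType = ModelType.RERANK


# callers can branch or fetch credentials based on the declared type
assert RerankModel.model_type is ModelType.RERANK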

4 changes: 3 additions & 1 deletion api/core/model_runtime/model_providers/__base/large_language_model.py
@@ -7,7 +7,8 @@
from core.model_runtime.callbacks.base_callback import Callback
from core.model_runtime.callbacks.logging_callback import LoggingCallback
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageTool, AssistantPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, PriceType, ParameterType, ParameterRule
from core.model_runtime.entities.model_entities import ModelPropertyKey, PriceType, ParameterType, ParameterRule, \
ModelType
from core.model_runtime.entities.llm_entities import LLMResult, LLMMode, LLMUsage, \
LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.model_providers.__base.ai_model import AIModel
@@ -19,6 +20,7 @@ class LargeLanguageModel(AIModel):
"""
Model class for large language model.
"""
model_type: ModelType = ModelType.LLM

def invoke(self, model: str, credentials: dict,
prompt_messages: list[PromptMessage], model_parameters: Optional[dict] = None,
2 changes: 2 additions & 0 deletions api/core/model_runtime/model_providers/__base/moderation_model.py
@@ -2,13 +2,15 @@
from abc import abstractmethod
from typing import Optional

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel


class ModerationModel(AIModel):
"""
Model class for moderation model.
"""
model_type: ModelType = ModelType.MODERATION

def invoke(self, model: str, credentials: dict,
text: str, user: Optional[str] = None) \
2 changes: 2 additions & 0 deletions api/core/model_runtime/model_providers/__base/rerank_model.py
@@ -2,6 +2,7 @@
from abc import abstractmethod
from typing import Optional

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.rerank_entities import RerankResult
from core.model_runtime.model_providers.__base.ai_model import AIModel

@@ -10,6 +11,7 @@ class RerankModel(AIModel):
"""
Base Model class for rerank model.
"""
model_type: ModelType = ModelType.RERANK

def invoke(self, model: str, credentials: dict,
query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None,
2 changes: 2 additions & 0 deletions api/core/model_runtime/model_providers/__base/speech2text_model.py
@@ -2,13 +2,15 @@
from abc import abstractmethod
from typing import Optional, IO

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.ai_model import AIModel


class Speech2TextModel(AIModel):
"""
Model class for speech2text model.
"""
model_type: ModelType = ModelType.SPEECH2TEXT

def invoke(self, model: str, credentials: dict,
file: IO[bytes], user: Optional[str] = None) \
3 changes: 2 additions & 1 deletion api/core/model_runtime/model_providers/__base/text_embedding_model.py
@@ -2,7 +2,7 @@
from abc import abstractmethod
from typing import Optional

from core.model_runtime.entities.model_entities import ModelPropertyKey
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from core.model_runtime.entities.text_embedding_entities import TextEmbeddingResult
from core.model_runtime.model_providers.__base.ai_model import AIModel

@@ -11,6 +11,7 @@ class TextEmbeddingModel(AIModel):
"""
Model class for text embedding model.
"""
model_type: ModelType = ModelType.TEXT_EMBEDDING

def invoke(self, model: str, credentials: dict,
texts: list[str], user: Optional[str] = None) \
23 changes: 15 additions & 8 deletions api/core/model_runtime/model_providers/google/llm/llm.py
@@ -2,6 +2,7 @@

import google.generativeai as genai
import google.api_core.exceptions as exceptions
import google.generativeai.client as client

from google.generativeai.types import GenerateContentResponse, ContentType
from google.generativeai.types.content_types import to_part
@@ -13,6 +14,7 @@
from core.model_runtime.errors.invoke import InvokeConnectionError, InvokeServerUnavailableError, InvokeRateLimitError, \
InvokeAuthorizationError, InvokeBadRequestError, InvokeError
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers import google
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

class GoogleLargeLanguageModel(LargeLanguageModel):
@@ -79,12 +81,8 @@ def validate_credentials(self, model: str, credentials: dict) -> None:
"""

try:
genai.configure(api_key=credentials['google_api_key'])

models = genai.list_models() # verifies key by listing models
for model in models:
if 'generateContent' in model.supported_generation_methods:
_ = model.name
ping_message = PromptMessage(content="ping", role="system")
self._generate(model, credentials, [ping_message], {"max_tokens_to_sample": 5})

except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
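Validation now issues a tiny throwaway completion ("ping", capped at 5 tokens) through _generate instead of configuring the global genai client and listing models, so the key is checked on the same per-request client path used for real calls. A hedged usage sketch; the model name and key are placeholders and the instantiation is simplified.

llm = GoogleLargeLanguageModel()

try:
    llm.validate_credentials(
        model="gemini-pro",                        # placeholder model name
        credentials={"google_api_key": "AIza..."}  # placeholder key
    )
except CredentialsValidateFailedError as ex:
    print(f"Google API key rejected: {ex}")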
@@ -112,7 +110,7 @@ def _generate(self, model: str, credentials: dict,
if stop:
config_kwargs["stop_sequences"] = stop

glm_model = genai.GenerativeModel(
google_model = genai.GenerativeModel(
model_name=model
)

@@ -124,7 +122,15 @@
else:
history.append(content)

response = glm_model.generate_content(

# Create a new ClientManager with tenant's API key
new_client_manager = client._ClientManager()
new_client_manager.configure(api_key=credentials["google_api_key"])
new_custom_client = new_client_manager.make_client("generative")

google_model._client = new_custom_client

response = google_model.generate_content(
contents=history,
generation_config=genai.types.GenerationConfig(
**config_kwargs
@@ -153,6 +159,7 @@ def _handle_generate_response(self, model: str, credentials: dict, response: Gen
content=response.text
)


# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
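Taken together, the _generate changes above stop relying on the process-global genai.configure(...) state: a dedicated client is built from the tenant's API key and attached to the GenerativeModel instance, so concurrent requests with different keys cannot overwrite each other. A condensed sketch of that wiring; _ClientManager and make_client are the internal google.generativeai helpers used verbatim in the hunk, while build_model_for_tenant is an illustrative helper name, not part of the commit.

import google.generativeai as genai
import google.generativeai.client as client


def build_model_for_tenant(model_name: str, api_key: str) -> genai.GenerativeModel:
    google_model = genai.GenerativeModel(model_name=model_name)

    # isolated per-tenant client instead of mutating global genai.configure() state
    new_client_manager = client._ClientManager()
    new_client_manager.configure(api_key=api_key)
    google_model._client = new_client_manager.make_client("generative")

    return google_model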
