From 4d7cfd0de510f98da4753a6ad24293ca1647226c Mon Sep 17 00:00:00 2001 From: Kazuki Takamatsu Date: Sun, 8 Dec 2024 09:44:49 +0900 Subject: [PATCH] Fix model provider of vertex ai (#11437) --- .../model_runtime/model_providers/vertex_ai/llm/llm.py | 10 ++++++---- .../vertex_ai/text_embedding/text_embedding.py | 10 ++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py index 1469de605525ef..934195cc3d6fa2 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py +++ b/api/core/model_runtime/model_providers/vertex_ai/llm/llm.py @@ -104,13 +104,14 @@ def _generate_anthropic( """ # use Anthropic official SDK references # - https://github.com/anthropics/anthropic-sdk-python - service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) + service_account_key = credentials.get("vertex_service_account_key", "") project_id = credentials["vertex_project_id"] SCOPES = ["https://www.googleapis.com/auth/cloud-platform"] token = "" # get access token from service account credential - if service_account_info: + if service_account_key: + service_account_info = json.loads(base64.b64decode(service_account_key)) credentials = service_account.Credentials.from_service_account_info(service_account_info, scopes=SCOPES) request = google.auth.transport.requests.Request() credentials.refresh(request) @@ -478,10 +479,11 @@ def _generate( if stop: config_kwargs["stop_sequences"] = stop - service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) + service_account_key = credentials.get("vertex_service_account_key", "") project_id = credentials["vertex_project_id"] location = credentials["vertex_location"] - if service_account_info: + if service_account_key: + service_account_info = json.loads(base64.b64decode(service_account_key)) service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) aiplatform.init(credentials=service_accountSA, project=project_id, location=location) else: diff --git a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py index 9cd0c78d99df24..eb54941e086752 100644 --- a/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py +++ b/api/core/model_runtime/model_providers/vertex_ai/text_embedding/text_embedding.py @@ -48,10 +48,11 @@ def _invoke( :param input_type: input type :return: embeddings result """ - service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) + service_account_key = credentials.get("vertex_service_account_key", "") project_id = credentials["vertex_project_id"] location = credentials["vertex_location"] - if service_account_info: + if service_account_key: + service_account_info = json.loads(base64.b64decode(service_account_key)) service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) aiplatform.init(credentials=service_accountSA, project=project_id, location=location) else: @@ -100,10 +101,11 @@ def validate_credentials(self, model: str, credentials: dict) -> None: :return: """ try: - service_account_info = json.loads(base64.b64decode(credentials["vertex_service_account_key"])) + service_account_key = credentials.get("vertex_service_account_key", "") project_id = credentials["vertex_project_id"] location = credentials["vertex_location"] - if service_account_info: + if service_account_key: + service_account_info = json.loads(base64.b64decode(service_account_key)) service_accountSA = service_account.Credentials.from_service_account_info(service_account_info) aiplatform.init(credentials=service_accountSA, project=project_id, location=location) else: