From 7c485f8bb80d170b5756575c150e7cdb41613145 Mon Sep 17 00:00:00 2001 From: ybalbert001 <120714773+ybalbert001@users.noreply.github.com> Date: Tue, 24 Sep 2024 10:33:30 +0800 Subject: [PATCH] fix llm integration problem: It doesn't work on docker env (#8701) Co-authored-by: Yuanbo Li --- .../model_providers/sagemaker/llm/llm.py | 20 +++++++++---------- 1 file changed, 9 insertions(+), 11 deletions(-) diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index 2edd13d56d4d87..04789197eec2a1 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -85,7 +85,6 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): """ sagemaker_client: Any = None - sagemaker_sess: Any = None predictor: Any = None def _handle_chat_generate_response( @@ -213,23 +212,22 @@ def _invoke( :return: full response or stream response chunk generator result """ if not self.sagemaker_client: - access_key = credentials.get("access_key") - secret_key = credentials.get("secret_key") + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") aws_region = credentials.get("aws_region") + boto_session = None if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client( - "sagemaker-runtime", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=aws_region, + boto_session = boto3.Session( + aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region ) else: - self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + boto_session = boto3.Session(region_name=aws_region) else: - self.sagemaker_client = boto3.client("sagemaker-runtime") + boto_session = boto3.Session() - sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client) + self.sagemaker_client = boto_session.client("sagemaker") + sagemaker_session = Session(boto_session=boto_session, sagemaker_client=self.sagemaker_client) self.predictor = Predictor( endpoint_name=credentials.get("sagemaker_endpoint"), sagemaker_session=sagemaker_session,