diff --git a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py index 2edd13d56d4d87..04789197eec2a1 100644 --- a/api/core/model_runtime/model_providers/sagemaker/llm/llm.py +++ b/api/core/model_runtime/model_providers/sagemaker/llm/llm.py @@ -85,7 +85,6 @@ class SageMakerLargeLanguageModel(LargeLanguageModel): """ sagemaker_client: Any = None - sagemaker_sess: Any = None predictor: Any = None def _handle_chat_generate_response( @@ -213,23 +212,22 @@ def _invoke( :return: full response or stream response chunk generator result """ if not self.sagemaker_client: - access_key = credentials.get("access_key") - secret_key = credentials.get("secret_key") + access_key = credentials.get("aws_access_key_id") + secret_key = credentials.get("aws_secret_access_key") aws_region = credentials.get("aws_region") + boto_session = None if aws_region: if access_key and secret_key: - self.sagemaker_client = boto3.client( - "sagemaker-runtime", - aws_access_key_id=access_key, - aws_secret_access_key=secret_key, - region_name=aws_region, + boto_session = boto3.Session( + aws_access_key_id=access_key, aws_secret_access_key=secret_key, region_name=aws_region ) else: - self.sagemaker_client = boto3.client("sagemaker-runtime", region_name=aws_region) + boto_session = boto3.Session(region_name=aws_region) else: - self.sagemaker_client = boto3.client("sagemaker-runtime") + boto_session = boto3.Session() - sagemaker_session = Session(sagemaker_runtime_client=self.sagemaker_client) + self.sagemaker_client = boto_session.client("sagemaker") + sagemaker_session = Session(boto_session=boto_session, sagemaker_client=self.sagemaker_client) self.predictor = Predictor( endpoint_name=credentials.get("sagemaker_endpoint"), sagemaker_session=sagemaker_session,