Skip to content

Commit

Permalink
use one method to get boto client for aws bedrock
Browse files Browse the repository at this point in the history
  • Loading branch information
warren830 committed Dec 9, 2024
1 parent 32f8439 commit 023e604
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 22 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import boto3
from botocore.config import Config


def get_bedrock_client(service_name, credentials=None):
client_config = Config(region_name=credentials["aws_region"])
aws_access_key_id = credentials["aws_access_key_id"],
aws_secret_access_key = credentials["aws_secret_access_key"]
if aws_access_key_id and aws_secret_access_key:
# 使用 AKSK 方式
client = boto3.client(
service_name=service_name,
config=client_config,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key
)
else:
# 使用 IAM 角色方式
client = boto3.client(service_name=service_name, config=client_config)

return client
10 changes: 3 additions & 7 deletions api/core/model_runtime/model_providers/bedrock/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

from api.core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client

logger = logging.getLogger(__name__)
ANTHROPIC_BLOCK_MODE_PROMPT = """You should always follow the instructions and output a valid {{block}} object.
The structure of the {{block}} object you can found in the instructions, use {"answer": "$your_answer"} as the default structure
Expand Down Expand Up @@ -173,13 +175,7 @@ def _generate_with_converse(
:param stream: is stream response
:return: full response or stream response chunk generator result
"""
bedrock_client = boto3.client(
service_name="bedrock-runtime",
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
region_name=credentials["aws_region"],
)

bedrock_client = get_bedrock_client("bedrock-runtime", credentials)
system, prompt_message_dicts = self._convert_converse_prompt_messages(prompt_messages)
inference_config, additional_model_fields = self._convert_converse_api_model_parameters(model_parameters, stop)

Expand Down
10 changes: 3 additions & 7 deletions api/core/model_runtime/model_providers/bedrock/rerank/rerank.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.rerank_model import RerankModel

from api.core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client


class BedrockRerankModel(RerankModel):
"""
Expand Down Expand Up @@ -48,13 +50,7 @@ def _invoke(
return RerankResult(model=model, docs=docs)

# initialize client
client_config = Config(region_name=credentials["aws_region"])
bedrock_runtime = boto3.client(
service_name="bedrock-agent-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id", ""),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
bedrock_runtime = get_bedrock_client("bedrock-agent-runtime", credentials)
queries = [{"type": "TEXT", "textQuery": {"text": query}}]
text_sources = []
for text in docs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@
)
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel

from api.core.model_runtime.model_providers.bedrock.get_bedrock_client import get_bedrock_client

logger = logging.getLogger(__name__)


Expand All @@ -48,14 +50,7 @@ def _invoke(
:param input_type: input type
:return: embeddings result
"""
client_config = Config(region_name=credentials["aws_region"])

bedrock_runtime = boto3.client(
service_name="bedrock-runtime",
config=client_config,
aws_access_key_id=credentials.get("aws_access_key_id"),
aws_secret_access_key=credentials.get("aws_secret_access_key"),
)
bedrock_runtime = get_bedrock_client("bedrock-runtime", credentials)

embeddings = []
token_usage = 0
Expand Down

0 comments on commit 023e604

Please sign in to comment.