Skip to content

Commit

Permalink
Merge branch 'feat/model-runtime' into deploy/dev
Browse files Browse the repository at this point in the history
  • Loading branch information
guchenhe committed Dec 28, 2023
2 parents 709fbdb + 15d6806 commit effd62c
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 6 deletions.
19 changes: 13 additions & 6 deletions api/core/model_runtime/model_providers/google/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,12 +115,19 @@ def _generate(self, model: str, credentials: dict,
)

history = []
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)

# hack for gemini-pro-vision, which currently does not support multi-turn chat
if model == "gemini-pro-vision":
last_msg = prompt_messages[-1]
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)


# Create a new ClientManager with tenant's API key
Expand Down
52 changes: 52 additions & 0 deletions api/tests/integration_tests/model_runtime/google/test_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,58 @@ def test_invoke_chat_model_with_vision():
assert len(result.message.content) > 0


def test_invoke_chat_model_with_vision_multi_pics():
model = GoogleLargeLanguageModel()

result = model.invoke(
model='gemini-pro-vision',
credentials={
'google_api_key': os.environ.get('GOOGLE_API_KEY')
},
prompt_messages=[
SystemPromptMessage(
content='You are a helpful AI assistant.'
),
UserPromptMessage(
content=[
TextPromptMessageContent(
data="what do you see?"
),
ImagePromptMessageContent(
data=''
)
]
),
AssistantPromptMessage(
content="I see a blue letter 'D' with a gradient from light blue to dark blue."
),
UserPromptMessage(
content=[
TextPromptMessageContent(
data="what about now?"
),
ImagePromptMessageContent(
data=''
)
]
)
],
model_parameters={
'temperature': 0.3,
'top_p': 0.2,
'top_k': 3,
'max_tokens': 100
},
stream=False,
user="abc-123"
)

print(f"resultz: {result.message.content}")
assert isinstance(result, LLMResult)
assert len(result.message.content) > 0



def test_get_num_tokens():
model = GoogleLargeLanguageModel()

Expand Down

0 comments on commit effd62c

Please sign in to comment.