feat: enhance gemini models #11497

Merged (11 commits, Dec 17, 2024)
Changes from all commits
api/core/file/file_manager.py (16 changes: 5 additions & 11 deletions)

@@ -50,12 +50,12 @@ def to_prompt_message_content(
             else:
                 data = _to_base64_data_string(f)

-            return ImagePromptMessageContent(data=data, detail=image_detail_config)
+            return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
         case FileType.AUDIO:
-            encoded_string = _get_encoded_string(f)
+            data = _to_base64_data_string(f)
             if f.extension is None:
                 raise ValueError("Missing file extension")
-            return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
+            return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
         case FileType.VIDEO:
             if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
                 data = _to_url(f)
@@ -65,14 +65,8 @@ def to_prompt_message_content(
                 raise ValueError("Missing file extension")
             return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
         case FileType.DOCUMENT:
-            data = _get_encoded_string(f)
-            if f.mime_type is None:
-                raise ValueError("Missing file mime_type")
-            return DocumentPromptMessageContent(
-                encode_format="base64",
-                mime_type=f.mime_type,
-                data=data,
-            )
+            data = _to_base64_data_string(f)
+            return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
         case _:
            raise ValueError(f"file type {f.type} is not supported")
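The image, audio, and document branches now all funnel through `_to_base64_data_string`, and the model-facing `format` comes straight from the file extension. A plausible sketch of that helper, inferred only from the `;base64,` splitting its consumers perform below; `_file_bytes` is a hypothetical stand-in for however the file's raw bytes are fetched, not the PR's actual implementation:

    import base64

    def _to_base64_data_string(f) -> str:
        # Inferred shape: "data:<mime>;base64,<payload>", which is what the
        # Anthropic and Google providers below parse back apart.
        encoded = base64.b64encode(_file_bytes(f)).decode("utf-8")
        return f"data:{f.mime_type};base64,{encoded}"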
api/core/model_runtime/entities/message_entities.py (3 changes: 2 additions & 1 deletion)

@@ -101,13 +101,14 @@ class DETAIL(StrEnum):

     type: PromptMessageContentType = PromptMessageContentType.IMAGE
     detail: DETAIL = DETAIL.LOW
+    format: str = Field("jpg", description="Image format")


 class DocumentPromptMessageContent(PromptMessageContent):
     type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
     encode_format: Literal["base64"]
-    mime_type: str
     data: str
+    format: str = Field(..., description="Document format")


 class PromptMessage(ABC, BaseModel):
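For illustration, both content models can now be built without a separate MIME field; the values below are fabricated stubs, and `ImagePromptMessageContent.format` falls back to "jpg" when omitted:

    image = ImagePromptMessageContent(data="https://example.com/photo.png", format="png")
    document = DocumentPromptMessageContent(
        encode_format="base64",
        data="data:application/pdf;base64,JVBERi0xLjQK",  # stub payload
        format="pdf",
    )

The MIME type is no longer stored on the document model; consumers recover it from the data URI, as the Anthropic provider below does.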
api/core/model_runtime/model_providers/anthropic/llm/llm.py (12 changes: 7 additions & 5 deletions)

@@ -526,17 +526,19 @@ def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) ->
                         }
                         sub_messages.append(sub_message_dict)
                     elif isinstance(message_content, DocumentPromptMessageContent):
-                        if message_content.mime_type != "application/pdf":
+                        data_split = message_content.data.split(";base64,")
+                        mime_type = data_split[0].replace("data:", "")
+                        base64_data = data_split[1]
+                        if mime_type != "application/pdf":
                             raise ValueError(
-                                f"Unsupported document type {message_content.mime_type}, "
-                                "only support application/pdf"
+                                f"Unsupported document type {mime_type}, " "only support application/pdf"
                             )
                         sub_message_dict = {
                             "type": "document",
                             "source": {
                                 "type": message_content.encode_format,
-                                "media_type": message_content.mime_type,
-                                "data": message_content.data,
+                                "media_type": mime_type,
+                                "data": base64_data,
                             },
                         }
                         sub_messages.append(sub_message_dict)
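The same parsing, pulled out as a standalone sketch for clarity (the function name is mine, not part of the PR); it assumes the `data:<mime>;base64,<payload>` shape produced by `file_manager.py` above and will fail if handed a bare URL instead of a data URI:

    def parse_data_uri(data: str) -> tuple[str, str]:
        # "data:application/pdf;base64,JVBERi..." -> ("application/pdf", "JVBERi...")
        header, base64_payload = data.split(";base64,", 1)
        return header.removeprefix("data:"), base64_payload

    mime_type, base64_data = parse_data_uri("data:application/pdf;base64,JVBERi0xLjQK")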
Gemini model YAML configs (15 files; filenames not captured in this view). Fourteen of the configs gain the same two feature flags, appended after the existing document entry: seven declare context_size 1048576, six declare 2097152, and one declares 32767. The hunk is identical in each file apart from that context_size value:

@@ -8,6 +8,8 @@ features:
   - tool-call
   - stream-tool-call
   - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 1048576

The remaining 32767-context config did not previously declare document support and gains all three flags:

@@ -7,6 +7,9 @@ features:
   - vision
   - tool-call
   - stream-tool-call
+  - document
+  - video
+  - audio
 model_properties:
   mode: chat
   context_size: 32767
api/core/model_runtime/model_providers/google/llm/llm.py (119 changes: 56 additions & 63 deletions)

@@ -1,29 +1,30 @@
 import base64
-import io
 import json
+import os
+import tempfile
+import time
 from collections.abc import Generator
-from typing import Optional, Union, cast
+from typing import Optional, Union

 import google.ai.generativelanguage as glm
 import google.generativeai as genai
 import requests
 from google.api_core import exceptions
 from google.generativeai.client import _ClientManager
-from google.generativeai.types import ContentType, GenerateContentResponse
+from google.generativeai.types import ContentType, File, GenerateContentResponse
 from google.generativeai.types.content_types import to_part
-from PIL import Image

 from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.entities.message_entities import (
     AssistantPromptMessage,
-    DocumentPromptMessageContent,
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContent,
     PromptMessageContentType,
     PromptMessageTool,
     SystemPromptMessage,
     ToolPromptMessage,
     UserPromptMessage,
+    VideoPromptMessageContent,
 )
 from core.model_runtime.errors.invoke import (
     InvokeAuthorizationError,
@@ -35,21 +36,7 @@
 )
 from core.model_runtime.errors.validate import CredentialsValidateFailedError
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-
-GOOGLE_AVAILABLE_MIMETYPE = [
-    "application/pdf",
-    "application/x-javascript",
-    "text/javascript",
-    "application/x-python",
-    "text/x-python",
-    "text/plain",
-    "text/html",
-    "text/css",
-    "text/md",
-    "text/csv",
-    "text/xml",
-    "text/rtf",
-]
+from extensions.ext_redis import redis_client


 class GoogleLargeLanguageModel(LargeLanguageModel):

@@ -201,29 +188,17 @@ def _generate(
         if stop:
             config_kwargs["stop_sequences"] = stop

+        genai.configure(api_key=credentials["google_api_key"])
         google_model = genai.GenerativeModel(model_name=model)

         history = []

-        # hack for gemini-pro-vision, which currently does not support multi-turn chat
-        if model == "gemini-pro-vision":
-            last_msg = prompt_messages[-1]
-            content = self._format_message_to_glm_content(last_msg)
-            history.append(content)
-        else:
-            for msg in prompt_messages:  # makes message roles strictly alternating
-                content = self._format_message_to_glm_content(msg)
-                if history and history[-1]["role"] == content["role"]:
-                    history[-1]["parts"].extend(content["parts"])
-                else:
-                    history.append(content)
-
-        # Create a new ClientManager with tenant's API key
-        new_client_manager = _ClientManager()
-        new_client_manager.configure(api_key=credentials["google_api_key"])
-        new_custom_client = new_client_manager.make_client("generative")
-
-        google_model._client = new_custom_client
+        for msg in prompt_messages:  # makes message roles strictly alternating
+            content = self._format_message_to_glm_content(msg)
+            if history and history[-1]["role"] == content["role"]:
+                history[-1]["parts"].extend(content["parts"])
+            else:
+                history.append(content)

         response = google_model.generate_content(
             contents=history,

@@ -346,7 +321,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:
         content = message.content
         if isinstance(content, list):
-            content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
+            content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)

         if isinstance(message, UserPromptMessage):
             message_text = f"{human_prompt} {content}"

@@ -359,6 +334,44 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:

         return message_text

+    def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
+        key = f"{message_content.type.value}:{hash(message_content.data)}"
+        if redis_client.exists(key):
+            try:
+                return genai.get_file(redis_client.get(key).decode())
+            except:
+                pass
+        with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+            if message_content.data.startswith("data:"):
+                metadata, base64_data = message_content.data.split(",", 1)
+                file_content = base64.b64decode(base64_data)
+                mime_type = metadata.split(";", 1)[0].split(":")[1]
+                temp_file.write(file_content)
+            else:
+                # only ImagePromptMessageContent and VideoPromptMessageContent has url
+                try:
+                    response = requests.get(message_content.data)
+                    response.raise_for_status()
+                    if message_content.type is ImagePromptMessageContent:
+                        prefix = "image/"
+                    elif message_content.type is VideoPromptMessageContent:
+                        prefix = "video/"
+                    mime_type = prefix + message_content.format
+                    temp_file.write(response.content)
+                except Exception as ex:
+                    raise ValueError(f"Failed to fetch data from url {message_content.data}, {ex}")
+            temp_file.flush()
+        try:
+            file = genai.upload_file(path=temp_file.name, mime_type=mime_type)
+            while file.state.name == "PROCESSING":
+                time.sleep(5)
+                file = genai.get_file(file.name)
+            # google will delete your upload files in 2 days.
+            redis_client.setex(key, 47 * 60 * 60, file.name)
+            return file
+        finally:
+            os.unlink(temp_file.name)
+
     def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
         """
         Format a single message into glm.Content for Google API

@@ -374,28 +387,8 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
             for c in message.content:
                 if c.type == PromptMessageContentType.TEXT:
                     glm_content["parts"].append(to_part(c.data))
-                elif c.type == PromptMessageContentType.IMAGE:
-                    message_content = cast(ImagePromptMessageContent, c)
-                    if message_content.data.startswith("data:"):
-                        metadata, base64_data = c.data.split(",", 1)
-                        mime_type = metadata.split(";", 1)[0].split(":")[1]
-                    else:
-                        # fetch image data from url
-                        try:
-                            image_content = requests.get(message_content.data).content
-                            with Image.open(io.BytesIO(image_content)) as img:
-                                mime_type = f"image/{img.format.lower()}"
-                            base64_data = base64.b64encode(image_content).decode("utf-8")
-                        except Exception as ex:
-                            raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
-                    blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
-                    glm_content["parts"].append(blob)
-                elif c.type == PromptMessageContentType.DOCUMENT:
-                    message_content = cast(DocumentPromptMessageContent, c)
-                    if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
-                        raise ValueError(f"Unsupported mime type {message_content.mime_type}")
-                    blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
-                    glm_content["parts"].append(blob)
+                else:
+                    glm_content["parts"].append(self._upload_file_content_to_google(c))

             return glm_content
         elif isinstance(message, AssistantPromptMessage):
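Two details of the new upload path are worth making explicit. The Redis TTL of `47 * 60 * 60` seconds keeps the cached handle strictly shorter-lived than the Gemini Files API's roughly 48-hour retention of uploads, and the cache key uses Python's `hash()`, which is salted per process for strings, so cache hits are only guaranteed within a single interpreter run. A minimal standalone sketch of the same upload-and-cache pattern (the function name and the plain `Redis()` client are mine; the PR uses dify's `redis_client` extension):

    import time

    import google.generativeai as genai
    from redis import Redis

    redis_client = Redis()

    def upload_with_cache(path: str, mime_type: str, cache_key: str):
        """Upload a file via the Gemini Files API, reusing a cached handle if one exists."""
        cached_name = redis_client.get(cache_key)
        if cached_name:
            try:
                return genai.get_file(cached_name.decode())
            except Exception:
                pass  # remote file expired or was deleted; fall through and re-upload
        file = genai.upload_file(path=path, mime_type=mime_type)
        while file.state.name == "PROCESSING":  # poll until Google finishes ingesting
            time.sleep(5)
            file = genai.get_file(file.name)
        # Expire the cached handle before Google's ~48h deletion window closes.
        redis_client.setex(cache_key, 47 * 60 * 60, file.name)
        return file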