Skip to content

Commit

Permalink
feat: enhance gemini models (langgenius#11497)
Browse files Browse the repository at this point in the history
  • Loading branch information
hjlarry authored and 刘江波 committed Dec 20, 2024
1 parent 86e7e12 commit 641958c
Show file tree
Hide file tree
Showing 23 changed files with 138 additions and 113 deletions.
16 changes: 5 additions & 11 deletions api/core/file/file_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,12 @@ def to_prompt_message_content(
else:
data = _to_base64_data_string(f)

return ImagePromptMessageContent(data=data, detail=image_detail_config)
return ImagePromptMessageContent(data=data, detail=image_detail_config, format=f.extension.lstrip("."))
case FileType.AUDIO:
encoded_string = _get_encoded_string(f)
data = _to_base64_data_string(f)
if f.extension is None:
raise ValueError("Missing file extension")
return AudioPromptMessageContent(data=encoded_string, format=f.extension.lstrip("."))
return AudioPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.VIDEO:
if dify_config.MULTIMODAL_SEND_VIDEO_FORMAT == "url":
data = _to_url(f)
Expand All @@ -65,14 +65,8 @@ def to_prompt_message_content(
raise ValueError("Missing file extension")
return VideoPromptMessageContent(data=data, format=f.extension.lstrip("."))
case FileType.DOCUMENT:
data = _get_encoded_string(f)
if f.mime_type is None:
raise ValueError("Missing file mime_type")
return DocumentPromptMessageContent(
encode_format="base64",
mime_type=f.mime_type,
data=data,
)
data = _to_base64_data_string(f)
return DocumentPromptMessageContent(encode_format="base64", data=data, format=f.extension.lstrip("."))
case _:
raise ValueError(f"file type {f.type} is not supported")

Expand Down
3 changes: 2 additions & 1 deletion api/core/model_runtime/entities/message_entities.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,14 @@ class DETAIL(StrEnum):

type: PromptMessageContentType = PromptMessageContentType.IMAGE
detail: DETAIL = DETAIL.LOW
format: str = Field("jpg", description="Image format")


class DocumentPromptMessageContent(PromptMessageContent):
type: PromptMessageContentType = PromptMessageContentType.DOCUMENT
encode_format: Literal["base64"]
mime_type: str
data: str
format: str = Field(..., description="Document format")


class PromptMessage(ABC, BaseModel):
Expand Down
12 changes: 7 additions & 5 deletions api/core/model_runtime/model_providers/anthropic/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,17 +526,19 @@ def _convert_prompt_messages(self, prompt_messages: Sequence[PromptMessage]) ->
}
sub_messages.append(sub_message_dict)
elif isinstance(message_content, DocumentPromptMessageContent):
if message_content.mime_type != "application/pdf":
data_split = message_content.data.split(";base64,")
mime_type = data_split[0].replace("data:", "")
base64_data = data_split[1]
if mime_type != "application/pdf":
raise ValueError(
f"Unsupported document type {message_content.mime_type}, "
"only support application/pdf"
f"Unsupported document type {mime_type}, " "only support application/pdf"
)
sub_message_dict = {
"type": "document",
"source": {
"type": message_content.encode_format,
"media_type": message_content.mime_type,
"data": message_content.data,
"media_type": mime_type,
"data": base64_data,
},
}
sub_messages.append(sub_message_dict)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 1048576
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 2097152
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@ features:
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ features:
- vision
- tool-call
- stream-tool-call
- document
- video
- audio
model_properties:
mode: chat
context_size: 32767
Expand Down
119 changes: 56 additions & 63 deletions api/core/model_runtime/model_providers/google/llm/llm.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,30 @@
import base64
import io
import json
import os
import tempfile
import time
from collections.abc import Generator
from typing import Optional, Union, cast
from typing import Optional, Union

import google.ai.generativelanguage as glm
import google.generativeai as genai
import requests
from google.api_core import exceptions
from google.generativeai.client import _ClientManager
from google.generativeai.types import ContentType, GenerateContentResponse
from google.generativeai.types import ContentType, File, GenerateContentResponse
from google.generativeai.types.content_types import to_part
from PIL import Image

from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.entities.message_entities import (
AssistantPromptMessage,
DocumentPromptMessageContent,
ImagePromptMessageContent,
PromptMessage,
PromptMessageContent,
PromptMessageContentType,
PromptMessageTool,
SystemPromptMessage,
ToolPromptMessage,
UserPromptMessage,
VideoPromptMessageContent,
)
from core.model_runtime.errors.invoke import (
InvokeAuthorizationError,
Expand All @@ -35,21 +36,7 @@
)
from core.model_runtime.errors.validate import CredentialsValidateFailedError
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel

GOOGLE_AVAILABLE_MIMETYPE = [
"application/pdf",
"application/x-javascript",
"text/javascript",
"application/x-python",
"text/x-python",
"text/plain",
"text/html",
"text/css",
"text/md",
"text/csv",
"text/xml",
"text/rtf",
]
from extensions.ext_redis import redis_client


class GoogleLargeLanguageModel(LargeLanguageModel):
Expand Down Expand Up @@ -201,29 +188,17 @@ def _generate(
if stop:
config_kwargs["stop_sequences"] = stop

genai.configure(api_key=credentials["google_api_key"])
google_model = genai.GenerativeModel(model_name=model)

history = []

# hack for gemini-pro-vision, which currently does not support multi-turn chat
if model == "gemini-pro-vision":
last_msg = prompt_messages[-1]
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)

# Create a new ClientManager with tenant's API key
new_client_manager = _ClientManager()
new_client_manager.configure(api_key=credentials["google_api_key"])
new_custom_client = new_client_manager.make_client("generative")

google_model._client = new_custom_client
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
history[-1]["parts"].extend(content["parts"])
else:
history.append(content)

response = google_model.generate_content(
contents=history,
Expand Down Expand Up @@ -346,7 +321,7 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:

content = message.content
if isinstance(content, list):
content = "".join(c.data for c in content if c.type != PromptMessageContentType.IMAGE)
content = "".join(c.data for c in content if c.type == PromptMessageContentType.TEXT)

if isinstance(message, UserPromptMessage):
message_text = f"{human_prompt} {content}"
Expand All @@ -359,6 +334,44 @@ def _convert_one_message_to_text(self, message: PromptMessage) -> str:

return message_text

def _upload_file_content_to_google(self, message_content: PromptMessageContent) -> File:
    """
    Upload a non-text prompt content item to Google and return the uploaded file handle.

    The content's ``data`` is either a ``data:<mime>;base64,<payload>`` string or a URL.
    Successful uploads are remembered in Redis (file name keyed by content type and a
    digest of ``data``) so the same content is not uploaded again while Google still
    keeps the file.

    :param message_content: prompt content whose ``data`` is a base64 data string or a URL
    :return: the Google file object for the uploaded (or previously uploaded) content
    :raises ValueError: if the URL cannot be fetched, if URL data is given for a content
        class other than image/video, or if Google reports the upload as failed
    """
    # Function-scope import so this method does not depend on the module's import block.
    import hashlib

    # The built-in hash() of a str is salted per interpreter process, so it cannot be used
    # for a key shared through Redis; a SHA-256 digest is stable across workers and restarts.
    digest = hashlib.sha256(message_content.data.encode("utf-8")).hexdigest()
    key = f"{message_content.type.value}:{digest}"

    # Single GET instead of EXISTS + GET: the key may expire between the two calls.
    cached_name = redis_client.get(key)
    if cached_name is not None:
        if isinstance(cached_name, bytes):
            cached_name = cached_name.decode()
        try:
            return genai.get_file(cached_name)
        except Exception:
            # Best effort: the cached entry may point at a file Google no longer has
            # (or one owned by different credentials); fall through and upload again.
            pass

    # Resolve the payload and mime type BEFORE creating the temp file, so a failed
    # download cannot leave an orphaned file on disk.
    if message_content.data.startswith("data:"):
        metadata, base64_data = message_content.data.split(",", 1)
        mime_type = metadata.split(";", 1)[0].split(":")[1]
        file_content = base64.b64decode(base64_data)
    else:
        # Only image and video contents may carry a URL. Dispatch on the content class:
        # `message_content.type` is an enum value and is never identical to a class.
        if isinstance(message_content, ImagePromptMessageContent):
            prefix = "image/"
        elif isinstance(message_content, VideoPromptMessageContent):
            prefix = "video/"
        else:
            raise ValueError(f"Fetching {message_content.type.value} content from a url is not supported")
        mime_type = prefix + message_content.format
        try:
            # (connect, read) timeouts so an unresponsive host cannot block the worker forever.
            response = requests.get(message_content.data, timeout=(10, 300))
            response.raise_for_status()
        except requests.RequestException as ex:
            raise ValueError(f"Failed to fetch data from url {message_content.data}, {ex}") from ex
        file_content = response.content

    # delete=False because the file is handed to genai by path after it is closed;
    # the finally clause below is what removes it.
    temp_file = tempfile.NamedTemporaryFile(delete=False)
    try:
        with temp_file:
            temp_file.write(file_content)
        file = genai.upload_file(path=temp_file.name, mime_type=mime_type)
        while file.state.name == "PROCESSING":
            time.sleep(5)
            file = genai.get_file(file.name)
        if file.state.name == "FAILED":
            # Do not cache or return a file Google could not process.
            raise ValueError(f"Google failed to process uploaded file {file.name}")
        # Google deletes uploaded files after 2 days; expire the cache entry an hour earlier.
        redis_client.setex(key, 47 * 60 * 60, file.name)
        return file
    finally:
        os.unlink(temp_file.name)

def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
"""
Format a single message into glm.Content for Google API
Expand All @@ -374,28 +387,8 @@ def _format_message_to_glm_content(self, message: PromptMessage) -> ContentType:
for c in message.content:
if c.type == PromptMessageContentType.TEXT:
glm_content["parts"].append(to_part(c.data))
elif c.type == PromptMessageContentType.IMAGE:
message_content = cast(ImagePromptMessageContent, c)
if message_content.data.startswith("data:"):
metadata, base64_data = c.data.split(",", 1)
mime_type = metadata.split(";", 1)[0].split(":")[1]
else:
# fetch image data from url
try:
image_content = requests.get(message_content.data).content
with Image.open(io.BytesIO(image_content)) as img:
mime_type = f"image/{img.format.lower()}"
base64_data = base64.b64encode(image_content).decode("utf-8")
except Exception as ex:
raise ValueError(f"Failed to fetch image data from url {message_content.data}, {ex}")
blob = {"inline_data": {"mime_type": mime_type, "data": base64_data}}
glm_content["parts"].append(blob)
elif c.type == PromptMessageContentType.DOCUMENT:
message_content = cast(DocumentPromptMessageContent, c)
if message_content.mime_type not in GOOGLE_AVAILABLE_MIMETYPE:
raise ValueError(f"Unsupported mime type {message_content.mime_type}")
blob = {"inline_data": {"mime_type": message_content.mime_type, "data": message_content.data}}
glm_content["parts"].append(blob)
else:
glm_content["parts"].append(self._upload_file_content_to_google(c))

return glm_content
elif isinstance(message, AssistantPromptMessage):
Expand Down
Loading

0 comments on commit 641958c

Please sign in to comment.