Feat/open ai compatible functioncall (#2783)
Co-authored-by: jyong <[email protected]>
JohnJyong authored Mar 11, 2024
1 parent f8951d7 commit e54c9cd
Showing 8 changed files with 105 additions and 40 deletions.
api/core/model_runtime/model_providers/cohere/llm/llm.py (2 changes: 1 addition & 1 deletion)
@@ -472,7 +472,7 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
else:
raise ValueError(f"Got unknown type {message}")

if message.name is not None:
if message.name:
message_dict["user_name"] = message.name

return message_dict
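
The Cohere change above swaps an "is not None" check for a truthiness check, so empty-string names are now skipped as well. A minimal sketch (value invented) of the difference:

message_name = ""                          # hypothetical: name arrives as an empty string
old_check = message_name is not None       # True  -> "user_name": "" would still be sent
new_check = bool(message_name)             # False -> the key is omitted entirely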
@@ -24,7 +24,7 @@ parameter_rules:
min: 1
max: 8000
- name: safe_prompt
defulat: false
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
@@ -24,7 +24,7 @@ parameter_rules:
min: 1
max: 8000
- name: safe_prompt
defulat: false
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
@@ -24,7 +24,7 @@ parameter_rules:
min: 1
max: 8000
- name: safe_prompt
defulat: false
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
@@ -24,7 +24,7 @@ parameter_rules:
min: 1
max: 2048
- name: safe_prompt
defulat: false
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
@@ -24,7 +24,7 @@ parameter_rules:
min: 1
max: 8000
- name: safe_prompt
defulat: false
default: false
type: boolean
help:
en_US: Whether to inject a safety prompt before all conversations.
@@ -25,6 +25,7 @@
AIModelEntity,
DefaultParameterName,
FetchFrom,
ModelFeature,
ModelPropertyKey,
ModelType,
ParameterRule,
@@ -166,11 +167,23 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
"""
generate custom model entities from credentials
"""
support_function_call = False
features = []
function_calling_type = credentials.get('function_calling_type', 'no_call')
if function_calling_type == 'function_call':
features = [ModelFeature.TOOL_CALL]
support_function_call = True
endpoint_url = credentials["endpoint_url"]
# if not endpoint_url.endswith('/'):
# endpoint_url += '/'
# if 'https://api.openai.com/v1/' == endpoint_url:
# features = [ModelFeature.STREAM_TOOL_CALL]
entity = AIModelEntity(
model=model,
label=I18nObject(en_US=model),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
features=features if support_function_call else [],
model_properties={
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size', "4096")),
ModelPropertyKey.MODE: credentials.get('mode'),
@@ -194,14 +207,6 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
max=1,
precision=2
),
ParameterRule(
name="top_k",
label=I18nObject(en_US="Top K"),
type=ParameterType.INT,
default=int(credentials.get('top_k', 1)),
min=1,
max=100
),
ParameterRule(
name=DefaultParameterName.FREQUENCY_PENALTY.value,
label=I18nObject(en_US="Frequency Penalty"),
@@ -232,7 +237,7 @@ def get_customizable_model_schema(self, model: str, credentials: dict) -> AIMode
output=Decimal(credentials.get('output_price', 0)),
unit=Decimal(credentials.get('unit', 0)),
currency=credentials.get('currency', "USD")
)
),
)

if credentials['mode'] == 'chat':
@@ -292,14 +297,22 @@ def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptM
raise ValueError("Unsupported completion type for model configuration.")

# annotate tools with names, descriptions, etc.
function_calling_type = credentials.get('function_calling_type', 'no_call')
formatted_tools = []
if tools:
data["tool_choice"] = "auto"
if function_calling_type == 'function_call':
data['functions'] = [{
"name": tool.name,
"description": tool.description,
"parameters": tool.parameters
} for tool in tools]
elif function_calling_type == 'tool_call':
data["tool_choice"] = "auto"

for tool in tools:
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
for tool in tools:
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))

data["tools"] = formatted_tools
data["tools"] = formatted_tools

if stop:
data["stop"] = stop
@@ -367,9 +380,9 @@ def create_final_llm_result_chunk(index: int, message: AssistantPromptMessage, f

for chunk in response.iter_lines(decode_unicode=True, delimiter=delimiter):
if chunk:
#ignore sse comments
# ignore sse comments
if chunk.startswith(':'):
continue
continue
decoded_chunk = chunk.strip().lstrip('data: ').lstrip()
chunk_json = None
try:
@@ -452,18 +465,24 @@ def _handle_generate_response(self, model: str, credentials: dict, response: req

response_content = ''
tool_calls = None

function_calling_type = credentials.get('function_calling_type', 'no_call')
if completion_type is LLMMode.CHAT:
response_content = output.get('message', {})['content']
tool_calls = output.get('message', {}).get('tool_calls')
if function_calling_type == 'tool_call':
tool_calls = output.get('message', {}).get('tool_calls')
elif function_calling_type == 'function_call':
tool_calls = output.get('message', {}).get('function_call')

elif completion_type is LLMMode.COMPLETION:
response_content = output['text']

assistant_message = AssistantPromptMessage(content=response_content, tool_calls=[])

if tool_calls:
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
if function_calling_type == 'tool_call':
assistant_message.tool_calls = self._extract_response_tool_calls(tool_calls)
elif function_calling_type == 'function_call':
assistant_message.tool_calls = [self._extract_response_function_call(tool_calls)]

usage = response_json.get("usage")
if usage:
@@ -522,33 +541,34 @@ def _convert_prompt_message_to_dict(self, message: PromptMessage) -> dict:
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
in
message.tool_calls]
# function_call = message.tool_calls[0]
# message_dict["function_call"] = {
# "name": function_call.function.name,
# "arguments": function_call.function.arguments,
# }
# message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
# in
# message.tool_calls]

function_call = message.tool_calls[0]
message_dict["function_call"] = {
"name": function_call.function.name,
"arguments": function_call.function.arguments,
}
elif isinstance(message, SystemPromptMessage):
message = cast(SystemPromptMessage, message)
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, ToolPromptMessage):
message = cast(ToolPromptMessage, message)
message_dict = {
"role": "tool",
"content": message.content,
"tool_call_id": message.tool_call_id
}
# message_dict = {
# "role": "function",
# "role": "tool",
# "content": message.content,
# "name": message.tool_call_id
# "tool_call_id": message.tool_call_id
# }
message_dict = {
"role": "function",
"content": message.content,
"name": message.tool_call_id
}
else:
raise ValueError(f"Got unknown type {message}")

if message.name is not None:
if message.name:
message_dict["name"] = message.name

return message_dict
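
For illustration, a hedged sketch (payload values invented) of the message dicts the revised _convert_prompt_message_to_dict now emits, using the legacy function-calling wire format:

# An assistant message carrying a tool call is serialized with "function_call"
# (only the first tool call is sent) rather than the newer "tool_calls" list:
assistant_dict = {
    "role": "assistant",
    "content": "",
    "function_call": {
        "name": "get_weather",                      # invented function name
        "arguments": "{\"city\": \"Berlin\"}",
    },
}

# A tool result is sent with the legacy "function" role, reusing the
# tool_call_id as the function name:
tool_result_dict = {
    "role": "function",
    "content": "{\"temperature_c\": 12}",
    "name": "get_weather",
}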
@@ -693,3 +713,26 @@ def _extract_response_tool_calls(self,
tool_calls.append(tool_call)

return tool_calls

def _extract_response_function_call(self, response_function_call) \
-> AssistantPromptMessage.ToolCall:
"""
Extract function call from response
:param response_function_call: response function call
:return: tool call
"""
tool_call = None
if response_function_call:
function = AssistantPromptMessage.ToolCall.ToolCallFunction(
name=response_function_call['name'],
arguments=response_function_call['arguments']
)

tool_call = AssistantPromptMessage.ToolCall(
id=response_function_call['name'],
type="function",
function=function
)

return tool_call
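
A usage sketch (payload invented) of the new helper; because a legacy function_call response carries no call id, the function name doubles as the ToolCall id:

# Hypothetical legacy-style fragment taken from a response message:
legacy_function_call = {
    "name": "get_weather",
    "arguments": "{\"city\": \"Berlin\"}",
}

# _extract_response_function_call(legacy_function_call) returns the equivalent of:
# AssistantPromptMessage.ToolCall(
#     id="get_weather",
#     type="function",
#     function=AssistantPromptMessage.ToolCall.ToolCallFunction(
#         name="get_weather",
#         arguments="{\"city\": \"Berlin\"}",
#     ),
# )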
@@ -75,6 +75,28 @@ model_credential_schema:
value: llm
default: '4096'
type: text-input
- variable: function_calling_type
show_on:
- variable: __model_type
value: llm
label:
en_US: Function calling
type: select
required: false
default: no_call
options:
- value: function_call
label:
en_US: Support
zh_Hans: 支持
# - value: tool_call
# label:
# en_US: Tool Call
# zh_Hans: Tool Call
- value: no_call
label:
en_US: Not Support
zh_Hans: 不支持
- variable: stream_mode_delimiter
label:
zh_Hans: 流模式返回结果的分隔符
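
A hedged sketch (values illustrative) of the provider credentials this new schema field yields, and how the runtime consumes it:

credentials = {
    "endpoint_url": "https://example.com/v1",      # illustrative endpoint
    "mode": "chat",
    "context_size": "4096",
    "function_calling_type": "function_call",      # new field; defaults to "no_call"
}

# get_customizable_model_schema() reads this key and, when it equals
# "function_call", advertises ModelFeature.TOOL_CALL on the returned
# AIModelEntity; _generate() uses the same key to choose between the
# legacy "functions" payload and the newer "tools" payload.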
