feat: LLM routing integration via NotDiamond (#4543)
* WIP
* Integrate Notdiamond via API call
* Fix frontend lock
* Update
* Update ND integration to use HandleInput for LLM models
* Move NotDiamond out into bundles and add fallback call
* Update src/backend/base/langflow/components/notdiamond/notdiamond.py
Co-authored-by: Gabriel Luiz Freitas Almeida <[email protected]>
* Update src/backend/base/langflow/components/notdiamond/notdiamond.py
Co-authored-by: Gabriel Luiz Freitas Almeida <[email protected]>
* Update src/frontend/src/icons/NotDiamond/NotDiamondIcon.jsx
Co-authored-by: Gabriel Luiz Freitas Almeida <[email protected]>

---------

Co-authored-by: Gabriel Luiz Freitas Almeida <[email protected]>
1 parent 39c34fc · commit ef2199b
Showing 7 changed files with 572 additions and 316 deletions.
src/backend/base/langflow/components/notdiamond/notdiamond.py (218 additions, 0 deletions)
@@ -0,0 +1,218 @@
import warnings

import requests
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from pydantic.v1 import SecretStr

from langflow.base.models.chat_result import get_chat_result
from langflow.base.models.model_utils import get_model_name
from langflow.custom.custom_component.component import Component
from langflow.io import (
    BoolInput,
    DropdownInput,
    HandleInput,
    MessageInput,
    MessageTextInput,
    Output,
    SecretStrInput,
    StrInput,
)
from langflow.schema.message import Message
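
# Maps Langflow model names to the provider/model identifiers the NotDiamond API expects.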
ND_MODEL_MAPPING = {
    "gpt-4o": {"provider": "openai", "model": "gpt-4o"},
    "gpt-4o-mini": {"provider": "openai", "model": "gpt-4o-mini"},
    "gpt-4-turbo": {"provider": "openai", "model": "gpt-4-turbo-2024-04-09"},
    "claude-3-5-haiku-20241022": {"provider": "anthropic", "model": "claude-3-5-haiku-20241022"},
    "claude-3-5-sonnet-20241022": {"provider": "anthropic", "model": "claude-3-5-sonnet-20241022"},
    "anthropic.claude-3-5-sonnet-20241022-v2:0": {"provider": "anthropic", "model": "claude-3-5-sonnet-20241022"},
    "anthropic.claude-3-5-haiku-20241022-v1:0": {"provider": "anthropic", "model": "claude-3-5-haiku-20241022"},
    "gemini-1.5-pro": {"provider": "google", "model": "gemini-1.5-pro-latest"},
    "gemini-1.5-flash": {"provider": "google", "model": "gemini-1.5-flash-latest"},
    "llama-3.1-sonar-large-128k-online": {"provider": "perplexity", "model": "llama-3.1-sonar-large-128k-online"},
    "meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo": {
        "provider": "togetherai",
        "model": "Meta-Llama-3.1-70B-Instruct-Turbo",
    },
    "meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo": {
        "provider": "togetherai",
        "model": "Meta-Llama-3.1-405B-Instruct-Turbo",
    },
    "mistral-large-latest": {"provider": "mistral", "model": "mistral-large-2407"},
}


class NotDiamondComponent(Component):
    display_name = "Not Diamond Router"
    description = "Call the right model at the right time with the world's most powerful AI model router."
    documentation: str = "https://docs.notdiamond.ai/"
    icon = "NotDiamond"
    name = "NotDiamond"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    inputs = [
        MessageInput(name="input_value", display_name="Input"),
        MessageTextInput(
            name="system_message",
            display_name="System Message",
            info="System message to pass to the model.",
            advanced=False,
        ),
        HandleInput(
            name="models",
            display_name="Language Models",
            input_types=["LanguageModel"],
            required=True,
            is_list=True,
            info="Link the models you want to route between.",
        ),
        SecretStrInput(
            name="api_key",
            display_name="Not Diamond API Key",
            info="The Not Diamond API Key to use for routing.",
            advanced=False,
            value="NOTDIAMOND_API_KEY",
        ),
        StrInput(
            name="preference_id",
            display_name="Preference ID",
            info="The ID of the router preference that was configured via the Dashboard.",
            advanced=False,
        ),
        DropdownInput(
            name="tradeoff",
            display_name="Tradeoff",
            info="The tradeoff between cost and latency for the router to determine the best LLM for a given query.",
            advanced=False,
            options=["quality", "cost", "latency"],
            value="quality",
        ),
        BoolInput(
            name="hash_content",
            display_name="Hash Content",
            info="Whether to hash the content before being sent to the NotDiamond API.",
            advanced=False,
            value=False,
        ),
    ]

    outputs = [Output(display_name="Output", name="output", method="model_select")]
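
    # Ask the NotDiamond modelSelect API to pick the best model for the formatted
    # messages, then run the chat with the chosen model; any routing miss falls
    # back to the first linked model.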
    def model_select(self) -> Message:
        api_key = SecretStr(self.api_key).get_secret_value() if self.api_key else None
        input_value = self.input_value
        system_message = self.system_message
        messages = self._format_input(input_value, system_message)

        selected_models = []
        mapped_selected_models = []
        for model in self.models:
            model_name = get_model_name(model)

            if model_name in ND_MODEL_MAPPING:
                selected_models.append(model)
                mapped_selected_models.append(ND_MODEL_MAPPING[model_name])

        payload = {
            "messages": messages,
            "llm_providers": mapped_selected_models,
            "hash_content": self.hash_content,
        }

        if self.tradeoff != "quality":
            payload["tradeoff"] = self.tradeoff

        if self.preference_id and self.preference_id != "":
            payload["preference_id"] = self.preference_id

        header = {
            "Authorization": f"Bearer {api_key}",
            "accept": "application/json",
            "content-type": "application/json",
        }

        response = requests.post(
            "https://api.notdiamond.ai/v2/modelRouter/modelSelect",
            json=payload,
            headers=header,
            timeout=10,
        )

        result = response.json()

        if "providers" not in result:
            # No provider returned by NotDiamond API, likely failed. Fallback to first model.
            return self._call_get_chat_result(self.models[0], input_value, system_message)

        providers = result["providers"]

        if len(providers) == 0:
            # No provider returned by NotDiamond API, likely failed. Fallback to first model.
            return self._call_get_chat_result(self.models[0], input_value, system_message)

        nd_result = providers[0]

        chosen_model = None
        for nd_model, selected_model in zip(mapped_selected_models, selected_models, strict=False):
            if nd_model["provider"] == nd_result["provider"] and nd_model["model"] == nd_result["model"]:
                chosen_model = selected_model
                break

        if chosen_model is None:
            # No provider returned by NotDiamond API, likely failed. Fallback to first model.
            return self._call_get_chat_result(self.models[0], input_value, system_message)

        return self._call_get_chat_result(chosen_model, input_value, system_message)
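
    # Thin wrapper around Langflow's get_chat_result for the routed model.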
    def _call_get_chat_result(self, chosen_model, input_value, system_message):
        result = get_chat_result(
            runnable=chosen_model,
            input_value=input_value,
            system_message=system_message,
        )
        return result

    def _format_input(
        self,
        input_value: str | Message,
        system_message: str | None = None,
    ):
        messages: list[BaseMessage] = []
        if not input_value and not system_message:
            msg = "The message you want to send to the router is empty."
            raise ValueError(msg)
        system_message_added = False
        if input_value:
            if isinstance(input_value, Message):
                with warnings.catch_warnings():
                    warnings.simplefilter("ignore")
                    if "prompt" in input_value:
                        prompt = input_value.load_lc_prompt()
                        if system_message:
                            prompt.messages = [
                                SystemMessage(content=system_message),
                                *prompt.messages,  # type: ignore[has-type]
                            ]
                            system_message_added = True
                        messages.extend(prompt.messages)
                    else:
                        messages.append(input_value.to_lc_message())
            else:
                messages.append(HumanMessage(content=input_value))

        if system_message and not system_message_added:
            messages.insert(0, SystemMessage(content=system_message))

        # Convert Langchain messages to OpenAI format
        openai_messages = []
        for msg in messages:
            if isinstance(msg, HumanMessage):
                openai_messages.append({"role": "user", "content": msg.content})
            elif isinstance(msg, AIMessage):
                openai_messages.append({"role": "assistant", "content": msg.content})
            elif isinstance(msg, SystemMessage):
                openai_messages.append({"role": "system", "content": msg.content})

        return openai_messages
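
To exercise the routing endpoint outside Langflow, here is a minimal sketch. The endpoint URL and payload fields are taken from the component above; the response shape is inferred from its parsing logic rather than from NotDiamond's documentation, and the example messages, candidate models, and the NOTDIAMOND_API_KEY environment variable are assumptions for illustration:

import os

import requests

# Request body mirrors the payload built in model_select above.
payload = {
    "messages": [{"role": "user", "content": "Which model should answer this?"}],
    "llm_providers": [
        {"provider": "openai", "model": "gpt-4o-mini"},
        {"provider": "anthropic", "model": "claude-3-5-haiku-20241022"},
    ],
    "hash_content": False,
}
response = requests.post(
    "https://api.notdiamond.ai/v2/modelRouter/modelSelect",
    json=payload,
    headers={
        # Assumes a valid key is exported as NOTDIAMOND_API_KEY.
        "Authorization": f"Bearer {os.environ['NOTDIAMOND_API_KEY']}",
        "accept": "application/json",
        "content-type": "application/json",
    },
    timeout=10,
)
# The component expects a body like {"providers": [{"provider": "...", "model": "..."}]}
# and falls back to its first linked model when that key is missing or empty.
print(response.json())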