Skip to content

Commit

Permalink
generalize position helper for parsing _position.yaml and sorting obj…
Browse files Browse the repository at this point in the history
…ects by name (#2803)
  • Loading branch information
bowenliang123 authored Mar 13, 2024
1 parent 849dc05 commit 8b15b74
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 46 deletions.
14 changes: 8 additions & 6 deletions api/core/extension/extensible.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@
import json
import logging
import os
from collections import OrderedDict
from typing import Any, Optional

from pydantic import BaseModel

from core.utils.position_helper import sort_to_dict_by_position_map


class ExtensionModule(enum.Enum):
MODERATION = 'moderation'
Expand Down Expand Up @@ -36,7 +37,8 @@ def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:

@classmethod
def scan_extensions(cls):
extensions = {}
extensions: list[ModuleExtension] = []
position_map = {}

# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
Expand All @@ -63,6 +65,7 @@ def scan_extensions(cls):
if os.path.exists(builtin_file_path):
with open(builtin_file_path, encoding='utf-8') as f:
position = int(f.read().strip())
position_map[extension_name] = position

if (extension_name + '.py') not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
Expand Down Expand Up @@ -96,16 +99,15 @@ def scan_extensions(cls):
with open(json_path, encoding='utf-8') as f:
json_data = json.load(f)

extensions[extension_name] = ModuleExtension(
extensions.append(ModuleExtension(
extension_class=extension_class,
name=extension_name,
label=json_data.get('label'),
form_schema=json_data.get('form_schema'),
builtin=builtin,
position=position
)
))

sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
sorted_extensions = OrderedDict(sorted_items)
sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name)

return sorted_extensions
14 changes: 3 additions & 11 deletions api/core/model_runtime/model_providers/__base/ai_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
)
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.utils.position_helper import get_position_map, sort_by_position_map


class AIModel(ABC):
Expand Down Expand Up @@ -148,15 +149,7 @@ def predefined_models(self) -> list[AIModelEntity]:
]

# get _position.yaml file path
position_file_path = os.path.join(provider_model_type_path, '_position.yaml')

# read _position.yaml file
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, encoding='utf-8') as f:
positions = yaml.safe_load(f)
# convert list to dict with key as model provider name, value as index
position_map = {position: index for index, position in enumerate(positions)}
position_map = get_position_map(provider_model_type_path)

# traverse all model_schema_yaml_paths
for model_schema_yaml_path in model_schema_yaml_paths:
Expand Down Expand Up @@ -206,8 +199,7 @@ def predefined_models(self) -> list[AIModelEntity]:
model_schemas.append(model_schema)

# resort model schemas by position
if position_map:
model_schemas.sort(key=lambda x: position_map.get(x.model, 999))
model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model)

# cache model schemas
self.model_schemas = model_schemas
Expand Down
22 changes: 6 additions & 16 deletions api/core/model_runtime/model_providers/model_provider_factory.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
import importlib
import logging
import os
from collections import OrderedDict
from typing import Optional

import yaml
from pydantic import BaseModel

from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.entities.provider_entities import ProviderConfig, ProviderEntity, SimpleProviderEntity
from core.model_runtime.model_providers.__base.model_provider import ModelProvider
from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator
from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator
from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -200,7 +199,6 @@ def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
if self.model_provider_extensions:
return self.model_provider_extensions

model_providers = {}

# get the path of current classes
current_path = os.path.abspath(__file__)
Expand All @@ -215,17 +213,10 @@ def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
]

# get _position.yaml file path
position_file_path = os.path.join(model_providers_path, '_position.yaml')

# read _position.yaml file
position_map = {}
if os.path.exists(position_file_path):
with open(position_file_path, encoding='utf-8') as f:
positions = yaml.safe_load(f)
# convert list to dict with key as model provider name, value as index
position_map = {position: index for index, position in enumerate(positions)}
position_map = get_position_map(model_providers_path)

# traverse all model_provider_dir_paths
model_providers: list[ModelProviderExtension] = []
for model_provider_dir_path in model_provider_dir_paths:
# get model_provider dir name
model_provider_name = os.path.basename(model_provider_dir_path)
Expand Down Expand Up @@ -256,14 +247,13 @@ def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]:
logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.")
continue

model_providers[model_provider_name] = ModelProviderExtension(
model_providers.append(ModelProviderExtension(
name=model_provider_name,
provider_instance=model_provider_class(),
position=position_map.get(model_provider_name)
)
))

sorted_items = sorted(model_providers.items(), key=lambda x: (x[1].position is None, x[1].position))
sorted_extensions = OrderedDict(sorted_items)
sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name)

self.model_provider_extensions = sorted_extensions

Expand Down
21 changes: 8 additions & 13 deletions api/core/tools/provider/builtin/_positions.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import os.path

from yaml import FullLoader, load

from core.tools.entities.user_entities import UserToolProvider
from core.utils.position_helper import get_position_map, sort_by_position_map


class BuiltinToolProviderSort:
Expand All @@ -11,18 +10,14 @@ class BuiltinToolProviderSort:
@classmethod
def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]:
if not cls._position:
tmp_position = {}
file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml')
with open(file_path) as f:
for pos, val in enumerate(load(f, Loader=FullLoader)):
tmp_position[val] = pos
cls._position = tmp_position
cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..'))

def sort_compare(provider: UserToolProvider) -> int:
def name_func(provider: UserToolProvider) -> str:
if provider.type == UserToolProvider.ProviderType.MODEL:
return cls._position.get(f'model.{provider.name}', 10000)
return cls._position.get(provider.name, 10000)

sorted_providers = sorted(providers, key=sort_compare)
return f'model.{provider.name}'
else:
return provider.name

sorted_providers = sort_by_position_map(cls._position, providers, name_func)

return sorted_providers
70 changes: 70 additions & 0 deletions api/core/utils/position_helper.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
import logging
import os
from collections import OrderedDict
from collections.abc import Callable
from typing import Any, AnyStr

import yaml


def get_position_map(
folder_path: AnyStr,
file_name: str = '_position.yaml',
) -> dict[str, int]:
"""
Get the mapping from name to index from a YAML file
:param folder_path:
:param file_name: the YAML file name, default to '_position.yaml'
:return: a dict with name as key and index as value
"""
try:
position_file_name = os.path.join(folder_path, file_name)
if not os.path.exists(position_file_name):
return {}

with open(position_file_name, encoding='utf-8') as f:
positions = yaml.safe_load(f)
position_map = {}
for index, name in enumerate(positions):
if name and isinstance(name, str):
position_map[name.strip()] = index
return position_map
except:
logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.')
return {}


def sort_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> list[Any]:
"""
Sort the objects by the position map.
If the name of the object is not in the position map, it will be put at the end.
:param position_map: the map holding positions in the form of {name: index}
:param name_func: the function to get the name of the object
:param data: the data to be sorted
:return: the sorted objects
"""
if not position_map or not data:
return data

return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf')))


def sort_to_dict_by_position_map(
position_map: dict[str, int],
data: list[Any],
name_func: Callable[[Any], str],
) -> OrderedDict[str, Any]:
"""
Sort the objects into a ordered dict by the position map.
If the name of the object is not in the position map, it will be put at the end.
:param position_map: the map holding positions in the form of {name: index}
:param name_func: the function to get the name of the object
:param data: the data to be sorted
:return: an OrderedDict with the sorted pairs of name and object
"""
sorted_items = sort_by_position_map(position_map, data, name_func)
return OrderedDict([(name_func(item), item) for item in sorted_items])

0 comments on commit 8b15b74

Please sign in to comment.