From 8b15b742ade2713d9623db8ad7cd01e80881421c Mon Sep 17 00:00:00 2001 From: Bowen Liang Date: Wed, 13 Mar 2024 20:29:38 +0800 Subject: [PATCH] generalize position helper for parsing _position.yaml and sorting objects by name (#2803) --- api/core/extension/extensible.py | 14 ++-- .../model_providers/__base/ai_model.py | 14 +--- .../model_providers/model_provider_factory.py | 22 ++---- api/core/tools/provider/builtin/_positions.py | 21 +++--- api/core/utils/position_helper.py | 70 +++++++++++++++++++ 5 files changed, 95 insertions(+), 46 deletions(-) create mode 100644 api/core/utils/position_helper.py diff --git a/api/core/extension/extensible.py b/api/core/extension/extensible.py index c19aaefe9efafe..1809dcd8df0bb9 100644 --- a/api/core/extension/extensible.py +++ b/api/core/extension/extensible.py @@ -3,11 +3,12 @@ import json import logging import os -from collections import OrderedDict from typing import Any, Optional from pydantic import BaseModel +from core.utils.position_helper import sort_to_dict_by_position_map + class ExtensionModule(enum.Enum): MODERATION = 'moderation' @@ -36,7 +37,8 @@ def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None: @classmethod def scan_extensions(cls): - extensions = {} + extensions: list[ModuleExtension] = [] + position_map = {} # get the path of the current class current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py') @@ -63,6 +65,7 @@ def scan_extensions(cls): if os.path.exists(builtin_file_path): with open(builtin_file_path, encoding='utf-8') as f: position = int(f.read().strip()) + position_map[extension_name] = position if (extension_name + '.py') not in file_names: logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.") @@ -96,16 +99,15 @@ def scan_extensions(cls): with open(json_path, encoding='utf-8') as f: json_data = json.load(f) - extensions[extension_name] = ModuleExtension( + extensions.append(ModuleExtension( extension_class=extension_class, name=extension_name, label=json_data.get('label'), form_schema=json_data.get('form_schema'), builtin=builtin, position=position - ) + )) - sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position)) - sorted_extensions = OrderedDict(sorted_items) + sorted_extensions = sort_to_dict_by_position_map(position_map, extensions, lambda x: x.name) return sorted_extensions diff --git a/api/core/model_runtime/model_providers/__base/ai_model.py b/api/core/model_runtime/model_providers/__base/ai_model.py index 026e6eca218f38..34a737549381de 100644 --- a/api/core/model_runtime/model_providers/__base/ai_model.py +++ b/api/core/model_runtime/model_providers/__base/ai_model.py @@ -18,6 +18,7 @@ ) from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer +from core.utils.position_helper import get_position_map, sort_by_position_map class AIModel(ABC): @@ -148,15 +149,7 @@ def predefined_models(self) -> list[AIModelEntity]: ] # get _position.yaml file path - position_file_path = os.path.join(provider_model_type_path, '_position.yaml') - - # read _position.yaml file - position_map = {} - if os.path.exists(position_file_path): - with open(position_file_path, encoding='utf-8') as f: - positions = yaml.safe_load(f) - # convert list to dict with key as model provider name, value as index - position_map = {position: index for index, position in enumerate(positions)} + position_map = get_position_map(provider_model_type_path) # traverse all model_schema_yaml_paths for model_schema_yaml_path in model_schema_yaml_paths: @@ -206,8 +199,7 @@ def predefined_models(self) -> list[AIModelEntity]: model_schemas.append(model_schema) # resort model schemas by position - if position_map: - model_schemas.sort(key=lambda x: position_map.get(x.model, 999)) + model_schemas = sort_by_position_map(position_map, model_schemas, lambda x: x.model) # cache model schemas self.model_schemas = model_schemas diff --git a/api/core/model_runtime/model_providers/model_provider_factory.py b/api/core/model_runtime/model_providers/model_provider_factory.py index 185ff627116bb2..ee0385c6d08080 100644 --- a/api/core/model_runtime/model_providers/model_provider_factory.py +++ b/api/core/model_runtime/model_providers/model_provider_factory.py @@ -1,10 +1,8 @@ import importlib import logging import os -from collections import OrderedDict from typing import Optional -import yaml from pydantic import BaseModel from core.model_runtime.entities.model_entities import ModelType @@ -12,6 +10,7 @@ from core.model_runtime.model_providers.__base.model_provider import ModelProvider from core.model_runtime.schema_validators.model_credential_schema_validator import ModelCredentialSchemaValidator from core.model_runtime.schema_validators.provider_credential_schema_validator import ProviderCredentialSchemaValidator +from core.utils.position_helper import get_position_map, sort_to_dict_by_position_map logger = logging.getLogger(__name__) @@ -200,7 +199,6 @@ def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]: if self.model_provider_extensions: return self.model_provider_extensions - model_providers = {} # get the path of current classes current_path = os.path.abspath(__file__) @@ -215,17 +213,10 @@ def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]: ] # get _position.yaml file path - position_file_path = os.path.join(model_providers_path, '_position.yaml') - - # read _position.yaml file - position_map = {} - if os.path.exists(position_file_path): - with open(position_file_path, encoding='utf-8') as f: - positions = yaml.safe_load(f) - # convert list to dict with key as model provider name, value as index - position_map = {position: index for index, position in enumerate(positions)} + position_map = get_position_map(model_providers_path) # traverse all model_provider_dir_paths + model_providers: list[ModelProviderExtension] = [] for model_provider_dir_path in model_provider_dir_paths: # get model_provider dir name model_provider_name = os.path.basename(model_provider_dir_path) @@ -256,14 +247,13 @@ def _get_model_provider_map(self) -> dict[str, ModelProviderExtension]: logger.warning(f"Missing {model_provider_name}.yaml file in {model_provider_dir_path}, Skip.") continue - model_providers[model_provider_name] = ModelProviderExtension( + model_providers.append(ModelProviderExtension( name=model_provider_name, provider_instance=model_provider_class(), position=position_map.get(model_provider_name) - ) + )) - sorted_items = sorted(model_providers.items(), key=lambda x: (x[1].position is None, x[1].position)) - sorted_extensions = OrderedDict(sorted_items) + sorted_extensions = sort_to_dict_by_position_map(position_map, model_providers, lambda x: x.name) self.model_provider_extensions = sorted_extensions diff --git a/api/core/tools/provider/builtin/_positions.py b/api/core/tools/provider/builtin/_positions.py index fa2c5d27ef823e..2bf70bd35643d0 100644 --- a/api/core/tools/provider/builtin/_positions.py +++ b/api/core/tools/provider/builtin/_positions.py @@ -1,8 +1,7 @@ import os.path -from yaml import FullLoader, load - from core.tools.entities.user_entities import UserToolProvider +from core.utils.position_helper import get_position_map, sort_by_position_map class BuiltinToolProviderSort: @@ -11,18 +10,14 @@ class BuiltinToolProviderSort: @classmethod def sort(cls, providers: list[UserToolProvider]) -> list[UserToolProvider]: if not cls._position: - tmp_position = {} - file_path = os.path.join(os.path.dirname(__file__), '..', '_position.yaml') - with open(file_path) as f: - for pos, val in enumerate(load(f, Loader=FullLoader)): - tmp_position[val] = pos - cls._position = tmp_position + cls._position = get_position_map(os.path.join(os.path.dirname(__file__), '..')) - def sort_compare(provider: UserToolProvider) -> int: + def name_func(provider: UserToolProvider) -> str: if provider.type == UserToolProvider.ProviderType.MODEL: - return cls._position.get(f'model.{provider.name}', 10000) - return cls._position.get(provider.name, 10000) - - sorted_providers = sorted(providers, key=sort_compare) + return f'model.{provider.name}' + else: + return provider.name + + sorted_providers = sort_by_position_map(cls._position, providers, name_func) return sorted_providers \ No newline at end of file diff --git a/api/core/utils/position_helper.py b/api/core/utils/position_helper.py new file mode 100644 index 00000000000000..e038390e096caa --- /dev/null +++ b/api/core/utils/position_helper.py @@ -0,0 +1,70 @@ +import logging +import os +from collections import OrderedDict +from collections.abc import Callable +from typing import Any, AnyStr + +import yaml + + +def get_position_map( + folder_path: AnyStr, + file_name: str = '_position.yaml', +) -> dict[str, int]: + """ + Get the mapping from name to index from a YAML file + :param folder_path: + :param file_name: the YAML file name, default to '_position.yaml' + :return: a dict with name as key and index as value + """ + try: + position_file_name = os.path.join(folder_path, file_name) + if not os.path.exists(position_file_name): + return {} + + with open(position_file_name, encoding='utf-8') as f: + positions = yaml.safe_load(f) + position_map = {} + for index, name in enumerate(positions): + if name and isinstance(name, str): + position_map[name.strip()] = index + return position_map + except: + logging.warning(f'Failed to load the YAML position file {folder_path}/{file_name}.') + return {} + + +def sort_by_position_map( + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], +) -> list[Any]: + """ + Sort the objects by the position map. + If the name of the object is not in the position map, it will be put at the end. + :param position_map: the map holding positions in the form of {name: index} + :param name_func: the function to get the name of the object + :param data: the data to be sorted + :return: the sorted objects + """ + if not position_map or not data: + return data + + return sorted(data, key=lambda x: position_map.get(name_func(x), float('inf'))) + + +def sort_to_dict_by_position_map( + position_map: dict[str, int], + data: list[Any], + name_func: Callable[[Any], str], +) -> OrderedDict[str, Any]: + """ + Sort the objects into a ordered dict by the position map. + If the name of the object is not in the position map, it will be put at the end. + :param position_map: the map holding positions in the form of {name: index} + :param name_func: the function to get the name of the object + :param data: the data to be sorted + :return: an OrderedDict with the sorted pairs of name and object + """ + sorted_items = sort_by_position_map(position_map, data, name_func) + return OrderedDict([(name_func(item), item) for item in sorted_items])