diff --git a/setup.py b/setup.py index 113a9fae..26ff0089 100644 --- a/setup.py +++ b/setup.py @@ -29,6 +29,7 @@ python_requires='>=3.12', packages=[ package_name, + f'{package_name}.dependencies', f'{package_name}.enums', f'{package_name}.exceptions', f'{package_name}.functions', diff --git a/vstools/__init__.py b/vstools/__init__.py index 2ac507f3..4e3f641b 100644 --- a/vstools/__init__.py +++ b/vstools/__init__.py @@ -1,3 +1,4 @@ +from .dependencies import * # noqa: F401, F403 from .enums import * # noqa: F401, F403 from .exceptions import * # noqa: F401, F403 from .functions import * # noqa: F401, F403 diff --git a/vstools/dependencies/__init__.py b/vstools/dependencies/__init__.py new file mode 100644 index 00000000..cc6190ea --- /dev/null +++ b/vstools/dependencies/__init__.py @@ -0,0 +1,2 @@ +from .enums import * # noqa: F401, F403 +from .registry import * # noqa: F401, F403 diff --git a/vstools/dependencies/enums.py b/vstools/dependencies/enums.py new file mode 100644 index 00000000..cf92645f --- /dev/null +++ b/vstools/dependencies/enums.py @@ -0,0 +1,18 @@ +from stgpytools import CustomIntEnum + +__all__ = [ + 'InstallModeEnum' +] + + +class InstallModeEnum(CustomIntEnum): + """Enumeration for different installation modes of dependencies.""" + + AUTO = 0 + """Automatically install missing dependencies without prompting.""" + + PROMPT = 1 + """Prompt the user before installing missing dependencies.""" + + MANUAL = 2 + """Do not install dependencies automatically; user must install manually.""" diff --git a/vstools/dependencies/registry.py b/vstools/dependencies/registry.py new file mode 100644 index 00000000..ce51d7bb --- /dev/null +++ b/vstools/dependencies/registry.py @@ -0,0 +1,208 @@ +import importlib.util +from dataclasses import dataclass, field +from typing import Iterable + +from stgpytools import CustomPermissionError, CustomTypeError + +from .enums import InstallModeEnum + +__all__: list[str] = [ + 'PackageDependencyRegistry', + 'dependency_registry', + + 'PluginInfo', + 'PackageInfo', +] + + +@dataclass +class PluginInfo: + """Information about a plugin.""" + + required_functions: list[str] = field(default_factory=list) + """A list of function names that must be provided by the plugin.""" + + url: str | None = None + """The URL where the plugin can be downloaded or found, if available.""" + + optional: bool = False + """Indicates whether the plugin is optional or required.""" + + +@dataclass +class PackageInfo: + """Information about a package.""" + + required_functions: list[str] = field(default_factory=list) + """A list of function names that must be provided by the package.""" + + version: str | None = None + """The version of the package, if specified.""" + + url: str | None = None + """The URL where the package can be downloaded or found, if available.""" + + optional: bool = False + """Indicates whether the package is optional or required.""" + + +@dataclass +class PackageDependencyRegistry: + """A registry for managing package dependencies and plugins.""" + + plugin_registry: dict[str, dict[str, PluginInfo]] = field(default_factory=dict) + """ + A registry of plugins and their metadata. + + Structure: + { + package_name (str): { + plugin_name (str): PluginInfo, + # Additional plugins... + }, + # Additional packages... + } + """ + + package_registry: dict[str, dict[str, PackageInfo]] = field(default_factory=dict) + """ + A registry of packages and their metadata. + + Structure: + { + package_name (str): PackageInfo, + # Additional packages... + } + """ + + vsrepo_available: bool = field(init=False) + """A flag indicating whether the 'vsrepo' module is available.""" + + install_mode: InstallModeEnum = field(default=InstallModeEnum.PROMPT) + """The installation mode for handling missing dependencies.""" + + def __post_init__(self) -> None: + self.vsrepo_available = importlib.util.find_spec('vsrepo') is not None + # TODO: Ensure vspreview's setting is set here if possible + + def set_install_mode(self, install_mode: InstallModeEnum) -> None: + """ + Set the installation mode for handling missing dependencies. + + This method can only be called from the vspreview package. + This is to prevent abuse from package maintainers. + + :param install_mode: The installation mode to set. + :type install_mode: The installation mode. + + :raises CustomPermissionError: If called from a package other than vspreview. + """ + + from ..utils.package import get_calling_package + + if get_calling_package() != 'vspreview': + raise CustomPermissionError("This method can only be called from the vspreview package.") + + if not isinstance(install_mode, InstallModeEnum): + raise CustomTypeError( + "install_mode must be an instance of InstallModeEnum, " + f"not {type(install_mode).__name__}", self.set_install_mode + ) + + self.install_mode = install_mode + + def add_plugin( + self, + dependency: str, + functions: Iterable[str] | None = None, + parent_package: str | None = None, + *, + url: str | None = None, + optional: bool = False, + ) -> None: + """ + Register a plugin for a package. + + :param dependency: The name of the plugin to depend on. + :param functions: A list of functions used by the plugin. + A check is performed to ensure the installed plugin has these functions. + :param parent_package: The name of the package that depends on this plugin. + :param url: The url to the plugin's download page. Defaults to None. + :param optional: Whether the plugin is optional. Defaults to False. + """ + + if not parent_package: + from ..utils.package import get_calling_package + + parent_package = get_calling_package() + + if not dependency: + return + + if parent_package not in self.plugin_registry: + self.plugin_registry[parent_package] = {} + + plugin_info = self.plugin_registry[parent_package].get(dependency, PluginInfo()) + + if functions: + new_functions = [functions] if isinstance(functions, str) else functions + plugin_info.required_functions.extend(list(set(new_functions) - set(plugin_info.required_functions))) + + if url: + plugin_info.url = url + + plugin_info.optional = optional + + self.plugin_registry[parent_package][dependency] = plugin_info + + def add_package( + self, + dependency: str, + functions: str | list[str] | None = None, + parent_package: str | None = None, + *, + version: str | None = None, + optional: bool = False, + url: str | None = None, + ) -> None: + """ + Register a package dependency. + + :param dependency: The name of the dependency to register. + :param functions: A function or list of functions to check for in the dependency. + Defaults to None. + :param parent_package: The name of the package that depends on this dependency. + :param version: The required version of the dependency. Defaults to None. + :param optional: Whether the dependency is optional. Defaults to False. + :param url: The url to the dependency's download page. Defaults to None. + """ + + if not parent_package: + from ..utils.package import get_calling_package + + parent_package = get_calling_package() + + if not dependency: + return + + if parent_package not in self.package_registry: + self.package_registry[parent_package] = {} + + package_info = self.package_registry[parent_package].get(dependency, PackageInfo()) + + if functions: + new_functions = [functions] if isinstance(functions, str) else functions + package_info.required_functions.extend(list(set(new_functions) - set(package_info.required_functions))) + + if url: + package_info.url = url + + if version: + package_info.version = version + + package_info.optional = optional + + self.package_registry[parent_package][dependency] = package_info + + +dependency_registry = PackageDependencyRegistry() diff --git a/vstools/exceptions/__init__.py b/vstools/exceptions/__init__.py index 698f8a86..ea55b53b 100644 --- a/vstools/exceptions/__init__.py +++ b/vstools/exceptions/__init__.py @@ -1,5 +1,6 @@ from .base import * # noqa: F401, F403 from .color import * # noqa: F401, F403 +from .dependencies import * # noqa: F401, F403 from .enum import * # noqa: F401, F403 from .file import * # noqa: F401, F403 from .generic import * # noqa: F401, F403 diff --git a/vstools/exceptions/dependencies.py b/vstools/exceptions/dependencies.py new file mode 100644 index 00000000..a7b299a8 --- /dev/null +++ b/vstools/exceptions/dependencies.py @@ -0,0 +1,370 @@ +import importlib.util +import warnings +from typing import Any + +from stgpytools import CustomValueError, FuncExceptT + +from ..dependencies.registry import PackageInfo, PluginInfo, dependency_registry +from ..dependencies.enums import InstallModeEnum + +__all__: list[str] = [ + 'DependencyRegistryError', + + 'PluginNotFoundError', + 'PackageNotFoundError', +] + + +prompt_y = ('yes', 'y', '') +prompt_opts = "(yes/no) [yes]" + + +class DependencyRegistryError(CustomValueError): + """Base class for dependency registry errors.""" + + @classmethod + def _get_package(cls, parent_package: str | None = None) -> str: + """ + Get the package name, either from the provided value or by auto-detection. + + :param parent_package: The name of the parent package. If None, auto-detect. + + :return: The package name. + """ + + if parent_package is not None: + return parent_package + + from ..utils.package import get_calling_package + + return get_calling_package(3) + + @staticmethod + def _check_pypi(package: str) -> bool: + """ + Check if the package is available on PyPI and install if necessary based on the install mode. + + :param package: The package to check. + + :return: True if the package was installed or already exists, False otherwise. + """ + + if dependency_registry.install_mode is InstallModeEnum.MANUAL: + return False + + try: + importlib.import_module(package) + return True + except ImportError: + pass + + import subprocess + import sys + + def run_pip(args: list[str]) -> subprocess.CompletedProcess[str]: + return subprocess.run([sys.executable, '-m', 'pip'] + args, capture_output=True, text=True, check=False) + + if importlib.util.find_spec(package) is not None: + return True + + print(f'Package \'{package}\' is not installed.') + + should_install = DependencyRegistryError._prompt_user_for_install(package) + + if not should_install: + print(f'Installation of \'{package}\' cancelled.') + return False + + result = run_pip(['install', package]) + + if result.returncode == 0: + print(f'Successfully installed \'{package}\'') + return True + + print(f'Failed to install \'{package}\'. Error: {result.stderr}') + return False + + @staticmethod + def _check_vsrepo(plugin: str) -> bool: + """ + Check if the plugin is available in VSRepo and install if necessary based on the install mode. + + :param plugin: The plugin to check. + + :return: True if the plugin was installed or already exists, False otherwise. + """ + + if dependency_registry.install_mode is InstallModeEnum.MANUAL: + return False + + if not dependency_registry.vsrepo_available: + return False + + import subprocess + + def run_vsrepo(args: Any) -> Any: + try: + return subprocess.run(['vsrepo'] + args, capture_output=True, text=True, check=True) + except subprocess.CalledProcessError as e: + return e + except FileNotFoundError: + print('Failed to run vsrepo: command not found') + + result = run_vsrepo(['available', plugin]) + + if result.returncode != 0 or 'not found' in result.stdout.lower(): + return False + + print(f'Plugin \'{plugin}\' is available in VSRepo but not installed.') + + should_install = DependencyRegistryError._prompt_user_for_install(plugin) + + if not should_install: + print(f'Installation of \'{plugin}\' cancelled.') + return False + + result = run_vsrepo(['install', plugin]) + + if result and result.returncode == 0: + print(f'Successfully installed \'{plugin}\'') + return True + + print(f'Failed to install \'{plugin}\'') + return False + + @staticmethod + def _prompt_user_for_install(package_or_plugin: str) -> bool: + """ + Prompt the user for installation based on the install mode. + + :param package_or_plugin: The package or plugin name to install. + + :return: True if the user wants to install, False otherwise. + """ + + if dependency_registry.install_mode is InstallModeEnum.AUTO: + return True + + try: + prompt = input(f'Do you want to install \'{package_or_plugin}\'? {prompt_opts}: ') + return prompt.lower().strip() in prompt_y + except EOFError as e: + msg = f'Could not prompt the user to install \'{package_or_plugin}\'! ' + msg += 'Please install this plugin/package manually' + + try: + from vspreview import is_preview + + if is_preview(): + msg += ' or set the dependency install mode in VSPreview to \'auto\'' + except ModuleNotFoundError: + pass + + msg += f'! Error: {e}' + + warnings.warn(msg, UserWarning) + + return False + + @classmethod + def _get_missing_functions(cls, plugin_functions: list[str], plugin: str) -> list[str]: + from vstools import core + + plugin_obj = getattr(core, plugin) + + return [ + f for f in plugin_functions if not hasattr(plugin_obj, f) + and not hasattr(getattr(plugin_obj, f, None), '__call__') + ] + + @classmethod + def _format_message( + cls, base_msg: str, missing: list[str], + version: str | None = None, url: str | None = None, + prompt_update: bool = False + ) -> str: + msg = f'{base_msg}: [{', '.join(missing)}]. ' + + if version: + msg += f'Required version: {version}. ' + + if url: + msg += f'Download URL: {url}' + + if prompt_update: + msg += 'You may need to update!' + + return msg + + +class PluginNotFoundError(DependencyRegistryError): + """Raised when a required plugin is not found in the registry.""" + + def __init__( + self, func: FuncExceptT, parent_package: str, plugin: str, + message: str = 'Plugin \'{plugin}\' not found for package \'{parent_package}\'', + **kwargs: Any + ) -> None: + super().__init__(message, func, parent_package=parent_package, plugin=plugin, **kwargs) + + @classmethod + def check( + cls, func: FuncExceptT, plugins: str | list[str] | None = None, + message: str | None = None, parent_package: str | None = None, **kwargs: Any + ) -> None: + """ + Check if plugin(s) exist in the registry and raise an error if they don't. + + :param func: The function to check. + :param plugins: The plugin or list of plugins to check for. + If None, check all plugins in the caller package's namespace. + :param message: Optional custom error message. + :param parent_package: The package name to check for. If None, check all packages. + :param kwargs: Additional keyword arguments. + + :raises PluginNotFoundError: If any plugin is not found in the registry. + """ + + from vstools import core + + parent_package = cls._get_package(parent_package) + + if plugins is not None: + plugins_to_check = [plugins] if isinstance(plugins, str) else plugins + else: + plugins_to_check = list(dependency_registry.plugin_registry.get(parent_package, {}).keys()) + + missing_plugins = [] + missing_functions = {} + + for plugin in plugins_to_check: + plugin_data = dependency_registry.plugin_registry[parent_package].get(plugin) + + if plugin_data and plugin_data.optional: + if not hasattr(core, plugin): + warnings.warn( + f'Optional plugin \'{plugin}\' for \'{parent_package}\' is not installed.', UserWarning + ) + + continue + + if not hasattr(core, plugin): + if cls._check_vsrepo(plugin): + continue + + missing_plugins.append(plugin) + continue + + if not plugin_data or not plugin_data.required_functions: + continue + + if missing_funcs := cls._get_missing_functions(plugin_data.required_functions, plugin): + missing_functions[plugin] = missing_funcs + + if missing_plugins or missing_functions: + error_messages = [] + + if missing_plugins: + error_messages.append( + f'Plugin(s) not found for package \'{parent_package}\': {', '.join(missing_plugins)}' + ) + + for plugin, funcs in missing_functions.items(): + plugin_info: PluginInfo = dependency_registry.plugin_registry[parent_package][plugin] + + error_messages.append(cls._format_message( + f'Plugin \'{plugin}\' for package \'{parent_package}\' is missing the following function(s)', + missing=funcs, url=plugin_info.url, prompt_update=True + )) + + raise cls( + func, parent_package, + ', '.join(missing_plugins + list(missing_functions.keys())), + message or '\n'.join(error_messages), + **kwargs + ) + + +class PackageNotFoundError(DependencyRegistryError): + """Raised when a required package is not found in the registry.""" + + def __init__( + self, func: FuncExceptT, parent_package: str, + package: str | None = None, message: str | None = None, + **kwargs: Any + ) -> None: + super().__init__( + message or f'Package \'{package or parent_package}\' not found in the registry', + func, parent_package=parent_package, **kwargs + ) + + @classmethod + def check( + cls, func: FuncExceptT, packages: str | list[str] | None = None, + message: str | None = None, parent_package: str | None = None, **kwargs: Any + ) -> None: + """ + Check if package(s) exist in the registry and raise an error if they don't. + + :param func: The function to check. + :param packages: The package or list of packages to check for. + If None, check all packages in the caller package's namespace. + :param message: Optional custom error message. + :param parent_package: The package name to check for. If None, check all packages. + :param kwargs: Additional keyword arguments. + + :raises PackageNotFoundError: If any package is not found in the registry. + """ + + parent_package = cls._get_package(parent_package) + + if packages is not None: + packages_to_check = [packages] if isinstance(packages, str) else packages + else: + packages_to_check = list(dependency_registry.package_registry.get(parent_package, {}).keys()) + + missing_packages = [] + missing_functions = {} + + for pkg in packages_to_check: + package_data: PackageInfo = dependency_registry.package_registry[parent_package][pkg] + + if package_data.optional: + if importlib.util.find_spec(pkg) is None: + warnings.warn( + f'Optional package \'{pkg}\' for \'{parent_package}\' is not installed.', UserWarning + ) + + continue + + if not package_data.required_functions: + continue + + if cls._check_pypi(pkg) or cls._check_vsrepo(pkg): + continue + + if missing_funcs := cls._get_missing_functions(package_data.required_functions, pkg): + missing_functions[pkg] = missing_funcs + + if missing_packages or missing_functions: + error_messages = [] + + if missing_packages: + error_messages.append( + f'Package(s) not found for package \'{parent_package}\': {', '.join(missing_packages)}' + ) + + for pkg, funcs in missing_functions.items(): + pkg_info = dependency_registry.package_registry[parent_package][pkg] + + error_messages.append(cls._format_message( + f'Package \'{pkg}\' for package \'{parent_package}\' is missing the following function(s)', + missing=funcs, url=pkg_info.url, version=pkg_info.version + )) + + raise cls( + func, parent_package, + package=', '.join(missing_packages + list(missing_functions.keys())), + message=message or '\n'.join(error_messages), + **kwargs + ) diff --git a/vstools/functions/check.py b/vstools/functions/check.py index e87a9294..314878e7 100644 --- a/vstools/functions/check.py +++ b/vstools/functions/check.py @@ -8,7 +8,8 @@ from stgpytools import CustomError, F, FuncExceptT from ..exceptions import ( - FormatsRefClipMismatchError, ResolutionsRefClipMismatchError, VariableFormatError, VariableResolutionError + DependencyRegistryError, FormatsRefClipMismatchError, PluginNotFoundError, PackageNotFoundError, + ResolutionsRefClipMismatchError, VariableFormatError, VariableResolutionError, ) from ..types import ConstantFormatVideoNode @@ -19,7 +20,8 @@ 'check_variable_format', 'check_variable_resolution', 'check_variable', - 'check_correct_subsampling' + 'check_correct_subsampling', + 'check_dependencies', ] @@ -192,6 +194,7 @@ def check_correct_subsampling( :raises InvalidSubsamplingError: The clip has invalid subsampling. """ + from ..exceptions import InvalidSubsamplingError if clip.format: @@ -204,3 +207,46 @@ def check_correct_subsampling( 'The {subsampling} subsampling is not supported for this resolution!', reason=dict(width=width, height=height) ) + + +def check_dependencies( + parent_package: str | None = None, func: FuncExceptT | None = None, **kwargs: Any +) -> bool: + """ + Check for both plugin and package dependencies. + + :param parent_package: The name of the parent package. If None, automatically determine. + :param func: Function returned for custom error handling. + This should only be set by VS package developers. + :param kwargs: Additional keyword arguments to pass on to the `check` methods. + + :raises PluginNotFoundError: If a required plugin is not found. + :raises PackageNotFoundError: If a required package is not found. + """ + + func = func or check_dependencies + + if parent_package is None: + from ..utils.package import get_calling_package + + parent_package = get_calling_package(2) + + errors = list[DependencyRegistryError]() + + try: + PluginNotFoundError.check(func, None, None, parent_package, **kwargs) + except PluginNotFoundError as e: + errors.append(e) + + try: + PackageNotFoundError.check(func, None, None, parent_package, **kwargs) + except PackageNotFoundError as e: + errors.append(e) + + if errors: + if len(errors) == 1: + raise errors[0] + + raise DependencyRegistryError(func, '\n'.join(str(e) for e in errors)) + + return True diff --git a/vstools/utils/__init__.py b/vstools/utils/__init__.py index f25b8a94..7f3e70e4 100644 --- a/vstools/utils/__init__.py +++ b/vstools/utils/__init__.py @@ -10,6 +10,7 @@ from .mime import * # noqa: F401, F403 from .misc import * # noqa: F401, F403 from .other import * # noqa: F401, F403 +from .package import * # noqa: F401, F403 from .props import * # noqa: F401, F403 from .ranges import * # noqa: F401, F403 from .scale import * # noqa: F401, F403 diff --git a/vstools/utils/package.py b/vstools/utils/package.py new file mode 100644 index 00000000..79868aa3 --- /dev/null +++ b/vstools/utils/package.py @@ -0,0 +1,35 @@ +import inspect + +from stgpytools import SPath + +__all__: list[str] = [ + 'get_calling_package' +] + + +def get_calling_package(depth: int = 2) -> str: + """ + Get the name of the package from which this function is called. + + If the name is "__main__", use the caller's filename instead. + + :param depth: The number of frames to go back in the call stack. Default is 2. + + :return: The name of the calling package or file. + """ + + stack = inspect.stack() + + if len(stack) <= depth: + return 'unknown' + + frame_info = stack[depth] + module = inspect.getmodule(frame_info.frame) + + if not module: + return 'unknown' + + if module.__name__ == '__main__': + return SPath(frame_info.filename).name + + return module.__name__.split('.')[0]