diff --git a/vstools/dependencies/registry.py b/vstools/dependencies/registry.py index d7ac1cd..f59af2c 100644 --- a/vstools/dependencies/registry.py +++ b/vstools/dependencies/registry.py @@ -83,23 +83,25 @@ def add_plugin( self, dependency: str, functions: Iterable[str] | None = None, - url: str | None = None, parent_package: str | None = None, + *, + url: str | None = None, optional: bool = False, ) -> None: """ Register a plugin for a package. - :param dependency: The name of the plugin to depend on. - :param functions: A list of functions used by the plugin. - A check is performed to ensure the installed plugin has these functions. - :param url: The url to the plugin's download page. Defaults to None. - :param parent_package: The name of the package that depends on this plugin. - :param optional: Whether the plugin is optional. Defaults to False. + :param dependency: The name of the plugin to depend on. + :param functions: A list of functions used by the plugin. + A check is performed to ensure the installed plugin has these functions. + :param parent_package: The name of the package that depends on this plugin. + :param url: The url to the plugin's download page. Defaults to None. + :param optional: Whether the plugin is optional. Defaults to False. """ if not parent_package: from ..utils.package import get_calling_package + parent_package = get_calling_package() if not dependency: @@ -123,21 +125,23 @@ def add_plugin( def add_package( self, dependency: str, + functions: str | list[str] | None = None, parent_package: str | None = None, + *, version: str | None = None, - functions: str | list[str] | None = None, optional: bool = False, url: str | None = None, ) -> None: """ Register a package dependency. - :param dependency: The name of the dependency to register. - :param parent_package: The name of the package that depends on this dependency. - :param version: The required version of the dependency. Defaults to None. - :param functions: A function or list of functions to check for in the dependency. Defaults to None. - :param optional: Whether the dependency is optional. Defaults to False. - :param url: The url to the dependency's download page. Defaults to None. + :param dependency: The name of the dependency to register. + :param functions: A function or list of functions to check for in the dependency. + Defaults to None. + :param parent_package: The name of the package that depends on this dependency. + :param version: The required version of the dependency. Defaults to None. + :param optional: Whether the dependency is optional. Defaults to False. + :param url: The url to the dependency's download page. Defaults to None. """ if not parent_package: diff --git a/vstools/exceptions/dependencies.py b/vstools/exceptions/dependencies.py index 5344874..9ded1ef 100644 --- a/vstools/exceptions/dependencies.py +++ b/vstools/exceptions/dependencies.py @@ -1,11 +1,14 @@ +import importlib.util +import warnings from typing import Any from stgpytools import CustomValueError, FuncExceptT -from ..dependencies.registry import dependency_registry, PackageInfo, PluginInfo +from ..dependencies.registry import PackageInfo, PluginInfo, dependency_registry __all__: list[str] = [ 'DependencyRegistryError', + 'PluginNotFoundError', 'PackageNotFoundError', ] @@ -19,8 +22,9 @@ def _get_package(cls, parent_package: str | None = None) -> str: """ Get the package name, either from the provided value or by auto-detection. - :param parent_package: The package name (optional). - :return: The package name. + :param parent_package: The name of the parent package. If None, auto-detect. + + :return: The package name. """ if parent_package is not None: @@ -47,36 +51,36 @@ def _check_vsrepo(plugin: str) -> bool: def run_vsrepo(args: Any) -> Any: try: - return subprocess.run(["vsrepo"] + args, capture_output=True, text=True, check=True) + return subprocess.run(['vsrepo'] + args, capture_output=True, text=True, check=True) except subprocess.CalledProcessError as e: return e except FileNotFoundError: - print("Failed to run vsrepo: command not found") + print('Failed to run vsrepo: command not found') - if run_vsrepo(["installed", plugin]).returncode == 0: + if run_vsrepo(['installed', plugin]).returncode == 0: return True - result = run_vsrepo(["available", plugin]) + result = run_vsrepo(['available', plugin]) - if result.returncode != 0 or "not found" in result.stdout.lower(): + if result.returncode != 0 or 'not found' in result.stdout.lower(): return False - print(f"Plugin '{plugin}' is available in VSRepo but not installed.") + print(f'Plugin \'{plugin}\' is available in VSRepo but not installed.') - user_input = input(f"Do you want to install \'{plugin}\'? (y/n): ").lower().strip() + user_input = input(f'Do you want to install \'{plugin}\'? (y/n): ').lower().strip() if not user_input or user_input != 'y': - print(f"Installation of \'{plugin}\' cancelled by user.") + print(f'Installation of \'{plugin}\' cancelled by user.') return False - result = run_vsrepo(["install", plugin]) + result = run_vsrepo(['install', plugin]) if result and result.returncode == 0: - print(f"Successfully installed \'{plugin}\'") + print(f'Successfully installed \'{plugin}\'') return True - print(f"Failed to install \'{plugin}\'") + print(f'Failed to install \'{plugin}\'') return False @classmethod @@ -96,16 +100,16 @@ def _format_message( version: str | None = None, url: str | None = None, prompt_update: bool = False ) -> str: - msg = f"{base_msg}: [{', '.join(missing)}]. " + msg = f'{base_msg}: [{', '.join(missing)}]. ' if version: - msg += f"Required version: {version}. " + msg += f'Required version: {version}. ' if url: - msg += f"Download URL: {url}" + msg += f'Download URL: {url}' if prompt_update: - msg += "You may need to update!" + msg += 'You may need to update!' return msg @@ -151,6 +155,14 @@ def check( missing_functions = {} for plugin in plugins_to_check: + plugin_data = dependency_registry.plugin_registry[parent_package].get(plugin) + + if plugin_data and plugin_data.optional: + if not hasattr(core, plugin): + warnings.warn(f"Optional plugin '{plugin}' is not installed.", UserWarning) + + continue + if not hasattr(core, plugin): if cls._check_vsrepo(plugin): continue @@ -158,9 +170,7 @@ def check( missing_plugins.append(plugin) continue - plugin_data: PluginInfo = dependency_registry.plugin_registry[parent_package][plugin] - - if not plugin_data.required_functions: + if not plugin_data or not plugin_data.required_functions: continue if missing_funcs := cls._get_missing_functions(plugin_data.required_functions, plugin): @@ -171,14 +181,14 @@ def check( if missing_plugins: error_messages.append( - f"Plugin(s) not found for package '{parent_package}': {', '.join(missing_plugins)}" + f'Plugin(s) not found for package \'{parent_package}\': {', '.join(missing_plugins)}' ) for plugin, funcs in missing_functions.items(): plugin_info: PluginInfo = dependency_registry.plugin_registry[parent_package][plugin] error_messages.append(cls._format_message( - f"Plugin '{plugin}' for package '{parent_package}' is missing the following function(s)", + f'Plugin \'{plugin}\' for package \'{parent_package}\' is missing the following function(s)', missing=funcs, url=plugin_info.url, prompt_update=True )) @@ -199,7 +209,7 @@ def __init__( **kwargs: Any ) -> None: super().__init__( - message or f"Package '{package or parent_package}' not found in the registry", + message or f'Package \'{package or parent_package}\' not found in the registry', func, parent_package=parent_package, **kwargs ) @@ -222,6 +232,7 @@ def check( """ parent_package = cls._get_package(parent_package) + if packages is not None: packages_to_check = [packages] if isinstance(packages, str) else packages else: @@ -235,12 +246,18 @@ def check( missing_packages.append(pkg) continue - package_info: PackageInfo = dependency_registry.package_registry[parent_package][pkg] + package_data: PackageInfo = dependency_registry.package_registry[parent_package][pkg] + + if package_data.optional: + if importlib.util.find_spec(pkg) is None: + warnings.warn(f"Optional package '{pkg}' is not installed.", UserWarning) + + continue - if not package_info.required_functions: + if not package_data.required_functions: continue - if missing_funcs := cls._get_missing_functions(package_info.required_functions, pkg): + if missing_funcs := cls._get_missing_functions(package_data.required_functions, pkg): missing_functions[pkg] = missing_funcs if missing_packages or missing_functions: @@ -248,14 +265,14 @@ def check( if missing_packages: error_messages.append( - f"Package(s) not found for package '{parent_package}': {', '.join(missing_packages)}" + f'Package(s) not found for package \'{parent_package}\': {', '.join(missing_packages)}' ) for pkg, funcs in missing_functions.items(): pkg_info = dependency_registry.package_registry[parent_package][pkg] error_messages.append(cls._format_message( - f"Package '{pkg}' for package '{parent_package}' is missing the following function(s)", + f'Package \'{pkg}\' for package \'{parent_package}\' is missing the following function(s)', missing=funcs, url=pkg_info.url, version=pkg_info.version )) diff --git a/vstools/functions/check.py b/vstools/functions/check.py index 015d023..314878e 100644 --- a/vstools/functions/check.py +++ b/vstools/functions/check.py @@ -194,6 +194,7 @@ def check_correct_subsampling( :raises InvalidSubsamplingError: The clip has invalid subsampling. """ + from ..exceptions import InvalidSubsamplingError if clip.format: @@ -209,18 +210,18 @@ def check_correct_subsampling( def check_dependencies( - func: FuncExceptT, message: str | None = None, parent_package: str | None = None, **kwargs: Any -) -> None: + parent_package: str | None = None, func: FuncExceptT | None = None, **kwargs: Any +) -> bool: """ Check for both plugin and package dependencies. - :param func: The function to check. - :param message: Custom error message (optional). - :param parent_package: The package name (optional, will be auto-detected if not provided). - :param kwargs: Additional keyword arguments. + :param parent_package: The name of the parent package. If None, automatically determine. + :param func: Function returned for custom error handling. + This should only be set by VS package developers. + :param kwargs: Additional keyword arguments to pass on to the `check` methods. - :raises PluginNotFoundError: If a required plugin is not found. - :raises PackageNotFoundError: If a required package is not found. + :raises PluginNotFoundError: If a required plugin is not found. + :raises PackageNotFoundError: If a required package is not found. """ func = func or check_dependencies @@ -230,15 +231,15 @@ def check_dependencies( parent_package = get_calling_package(2) - errors = [] + errors = list[DependencyRegistryError]() try: - PluginNotFoundError.check(func, None, message, parent_package, **kwargs) + PluginNotFoundError.check(func, None, None, parent_package, **kwargs) except PluginNotFoundError as e: errors.append(e) try: - PackageNotFoundError.check(func, None, message, parent_package, **kwargs) + PackageNotFoundError.check(func, None, None, parent_package, **kwargs) except PackageNotFoundError as e: errors.append(e) @@ -247,3 +248,5 @@ def check_dependencies( raise errors[0] raise DependencyRegistryError(func, '\n'.join(str(e) for e in errors)) + + return True diff --git a/vstools/utils/package.py b/vstools/utils/package.py index 8ec774b..79868aa 100644 --- a/vstools/utils/package.py +++ b/vstools/utils/package.py @@ -1,56 +1,21 @@ import inspect -from typing import ParamSpec, TypeVar from stgpytools import SPath __all__: list[str] = [ - 'get_calling_package_name', - 'get_calling_package' ] -P = ParamSpec('P') -R = TypeVar('R') - - -def get_calling_package_name() -> str: - """ - Get the name of the package to which the calling function belongs. - - :param depth: The depth in the call stack to look for the package name. Default is 1. - - :return: The name of the package containing the calling function. - """ - - frame = inspect.currentframe() - - try: - if frame is None: - return "unknown" - - module = inspect.getmodule(frame) - - if module is None: - return "unknown" - - package = module.__package__ - - if package is None: - return module.__name__.split('.')[0] - - return package.split('.')[0] - finally: - del frame - - def get_calling_package(depth: int = 2) -> str: """ Get the name of the package from which this function is called. + If the name is "__main__", use the caller's filename instead. + :param depth: The number of frames to go back in the call stack. Default is 2. - :return: The name of the calling package. + :return: The name of the calling package or file. """ stack = inspect.stack()