Skip to content

Commit

Permalink
Clean-up
Browse files Browse the repository at this point in the history
  • Loading branch information
LightArrowsEXE committed Sep 29, 2024
1 parent 8b442fc commit 39d61fb
Show file tree
Hide file tree
Showing 4 changed files with 81 additions and 92 deletions.
32 changes: 18 additions & 14 deletions vstools/dependencies/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,23 +83,25 @@ def add_plugin(
self,
dependency: str,
functions: Iterable[str] | None = None,
url: str | None = None,
parent_package: str | None = None,
*,
url: str | None = None,
optional: bool = False,
) -> None:
"""
Register a plugin for a package.
:param dependency: The name of the plugin to depend on.
:param functions: A list of functions used by the plugin.
A check is performed to ensure the installed plugin has these functions.
:param url: The url to the plugin's download page. Defaults to None.
:param parent_package: The name of the package that depends on this plugin.
:param optional: Whether the plugin is optional. Defaults to False.
:param dependency: The name of the plugin to depend on.
:param functions: A list of functions used by the plugin.
A check is performed to ensure the installed plugin has these functions.
:param parent_package: The name of the package that depends on this plugin.
:param url: The url to the plugin's download page. Defaults to None.
:param optional: Whether the plugin is optional. Defaults to False.
"""

if not parent_package:
from ..utils.package import get_calling_package

parent_package = get_calling_package()

if not dependency:
Expand All @@ -123,21 +125,23 @@ def add_plugin(
def add_package(
self,
dependency: str,
functions: str | list[str] | None = None,
parent_package: str | None = None,
*,
version: str | None = None,
functions: str | list[str] | None = None,
optional: bool = False,
url: str | None = None,
) -> None:
"""
Register a package dependency.
:param dependency: The name of the dependency to register.
:param parent_package: The name of the package that depends on this dependency.
:param version: The required version of the dependency. Defaults to None.
:param functions: A function or list of functions to check for in the dependency. Defaults to None.
:param optional: Whether the dependency is optional. Defaults to False.
:param url: The url to the dependency's download page. Defaults to None.
:param dependency: The name of the dependency to register.
:param functions: A function or list of functions to check for in the dependency.
Defaults to None.
:param parent_package: The name of the package that depends on this dependency.
:param version: The required version of the dependency. Defaults to None.
:param optional: Whether the dependency is optional. Defaults to False.
:param url: The url to the dependency's download page. Defaults to None.
"""

if not parent_package:
Expand Down
75 changes: 46 additions & 29 deletions vstools/exceptions/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
import importlib.util
import warnings
from typing import Any

from stgpytools import CustomValueError, FuncExceptT

from ..dependencies.registry import dependency_registry, PackageInfo, PluginInfo
from ..dependencies.registry import PackageInfo, PluginInfo, dependency_registry

__all__: list[str] = [
'DependencyRegistryError',

'PluginNotFoundError',
'PackageNotFoundError',
]
Expand All @@ -19,8 +22,9 @@ def _get_package(cls, parent_package: str | None = None) -> str:
"""
Get the package name, either from the provided value or by auto-detection.
:param parent_package: The package name (optional).
:return: The package name.
:param parent_package: The name of the parent package. If None, auto-detect.
:return: The package name.
"""

if parent_package is not None:
Expand All @@ -47,36 +51,36 @@ def _check_vsrepo(plugin: str) -> bool:

def run_vsrepo(args: Any) -> Any:
try:
return subprocess.run(["vsrepo"] + args, capture_output=True, text=True, check=True)
return subprocess.run(['vsrepo'] + args, capture_output=True, text=True, check=True)
except subprocess.CalledProcessError as e:
return e
except FileNotFoundError:
print("Failed to run vsrepo: command not found")
print('Failed to run vsrepo: command not found')

if run_vsrepo(["installed", plugin]).returncode == 0:
if run_vsrepo(['installed', plugin]).returncode == 0:
return True

result = run_vsrepo(["available", plugin])
result = run_vsrepo(['available', plugin])

if result.returncode != 0 or "not found" in result.stdout.lower():
if result.returncode != 0 or 'not found' in result.stdout.lower():
return False

print(f"Plugin '{plugin}' is available in VSRepo but not installed.")
print(f'Plugin \'{plugin}\' is available in VSRepo but not installed.')

user_input = input(f"Do you want to install \'{plugin}\'? (y/n): ").lower().strip()
user_input = input(f'Do you want to install \'{plugin}\'? (y/n): ').lower().strip()

if not user_input or user_input != 'y':
print(f"Installation of \'{plugin}\' cancelled by user.")
print(f'Installation of \'{plugin}\' cancelled by user.')

return False

result = run_vsrepo(["install", plugin])
result = run_vsrepo(['install', plugin])

if result and result.returncode == 0:
print(f"Successfully installed \'{plugin}\'")
print(f'Successfully installed \'{plugin}\'')
return True

print(f"Failed to install \'{plugin}\'")
print(f'Failed to install \'{plugin}\'')
return False

@classmethod
Expand All @@ -96,16 +100,16 @@ def _format_message(
version: str | None = None, url: str | None = None,
prompt_update: bool = False
) -> str:
msg = f"{base_msg}: [{', '.join(missing)}]. "
msg = f'{base_msg}: [{', '.join(missing)}]. '

if version:
msg += f"Required version: {version}. "
msg += f'Required version: {version}. '

if url:
msg += f"Download URL: {url}"
msg += f'Download URL: {url}'

if prompt_update:
msg += "You may need to update!"
msg += 'You may need to update!'

return msg

Expand Down Expand Up @@ -151,16 +155,22 @@ def check(
missing_functions = {}

for plugin in plugins_to_check:
plugin_data = dependency_registry.plugin_registry[parent_package].get(plugin)

if plugin_data and plugin_data.optional:
if not hasattr(core, plugin):
warnings.warn(f"Optional plugin '{plugin}' is not installed.", UserWarning)

continue

if not hasattr(core, plugin):
if cls._check_vsrepo(plugin):
continue

missing_plugins.append(plugin)
continue

plugin_data: PluginInfo = dependency_registry.plugin_registry[parent_package][plugin]

if not plugin_data.required_functions:
if not plugin_data or not plugin_data.required_functions:
continue

if missing_funcs := cls._get_missing_functions(plugin_data.required_functions, plugin):
Expand All @@ -171,14 +181,14 @@ def check(

if missing_plugins:
error_messages.append(
f"Plugin(s) not found for package '{parent_package}': {', '.join(missing_plugins)}"
f'Plugin(s) not found for package \'{parent_package}\': {', '.join(missing_plugins)}'
)

for plugin, funcs in missing_functions.items():
plugin_info: PluginInfo = dependency_registry.plugin_registry[parent_package][plugin]

error_messages.append(cls._format_message(
f"Plugin '{plugin}' for package '{parent_package}' is missing the following function(s)",
f'Plugin \'{plugin}\' for package \'{parent_package}\' is missing the following function(s)',
missing=funcs, url=plugin_info.url, prompt_update=True
))

Expand All @@ -199,7 +209,7 @@ def __init__(
**kwargs: Any
) -> None:
super().__init__(
message or f"Package '{package or parent_package}' not found in the registry",
message or f'Package \'{package or parent_package}\' not found in the registry',
func, parent_package=parent_package, **kwargs
)

Expand All @@ -222,6 +232,7 @@ def check(
"""

parent_package = cls._get_package(parent_package)

if packages is not None:
packages_to_check = [packages] if isinstance(packages, str) else packages
else:
Expand All @@ -235,27 +246,33 @@ def check(
missing_packages.append(pkg)
continue

package_info: PackageInfo = dependency_registry.package_registry[parent_package][pkg]
package_data: PackageInfo = dependency_registry.package_registry[parent_package][pkg]

if package_data.optional:
if importlib.util.find_spec(pkg) is None:
warnings.warn(f"Optional package '{pkg}' is not installed.", UserWarning)

continue

if not package_info.required_functions:
if not package_data.required_functions:
continue

if missing_funcs := cls._get_missing_functions(package_info.required_functions, pkg):
if missing_funcs := cls._get_missing_functions(package_data.required_functions, pkg):
missing_functions[pkg] = missing_funcs

if missing_packages or missing_functions:
error_messages = []

if missing_packages:
error_messages.append(
f"Package(s) not found for package '{parent_package}': {', '.join(missing_packages)}"
f'Package(s) not found for package \'{parent_package}\': {', '.join(missing_packages)}'
)

for pkg, funcs in missing_functions.items():
pkg_info = dependency_registry.package_registry[parent_package][pkg]

error_messages.append(cls._format_message(
f"Package '{pkg}' for package '{parent_package}' is missing the following function(s)",
f'Package \'{pkg}\' for package \'{parent_package}\' is missing the following function(s)',
missing=funcs, url=pkg_info.url, version=pkg_info.version
))

Expand Down
25 changes: 14 additions & 11 deletions vstools/functions/check.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ def check_correct_subsampling(
:raises InvalidSubsamplingError: The clip has invalid subsampling.
"""

from ..exceptions import InvalidSubsamplingError

if clip.format:
Expand All @@ -209,18 +210,18 @@ def check_correct_subsampling(


def check_dependencies(
func: FuncExceptT, message: str | None = None, parent_package: str | None = None, **kwargs: Any
) -> None:
parent_package: str | None = None, func: FuncExceptT | None = None, **kwargs: Any
) -> bool:
"""
Check for both plugin and package dependencies.
:param func: The function to check.
:param message: Custom error message (optional).
:param parent_package: The package name (optional, will be auto-detected if not provided).
:param kwargs: Additional keyword arguments.
:param parent_package: The name of the parent package. If None, automatically determine.
:param func: Function returned for custom error handling.
This should only be set by VS package developers.
:param kwargs: Additional keyword arguments to pass on to the `check` methods.
:raises PluginNotFoundError: If a required plugin is not found.
:raises PackageNotFoundError: If a required package is not found.
:raises PluginNotFoundError: If a required plugin is not found.
:raises PackageNotFoundError: If a required package is not found.
"""

func = func or check_dependencies
Expand All @@ -230,15 +231,15 @@ def check_dependencies(

parent_package = get_calling_package(2)

errors = []
errors = list[DependencyRegistryError]()

try:
PluginNotFoundError.check(func, None, message, parent_package, **kwargs)
PluginNotFoundError.check(func, None, None, parent_package, **kwargs)
except PluginNotFoundError as e:
errors.append(e)

try:
PackageNotFoundError.check(func, None, message, parent_package, **kwargs)
PackageNotFoundError.check(func, None, None, parent_package, **kwargs)
except PackageNotFoundError as e:
errors.append(e)

Expand All @@ -247,3 +248,5 @@ def check_dependencies(
raise errors[0]

raise DependencyRegistryError(func, '\n'.join(str(e) for e in errors))

return True
41 changes: 3 additions & 38 deletions vstools/utils/package.py
Original file line number Diff line number Diff line change
@@ -1,56 +1,21 @@
import inspect
from typing import ParamSpec, TypeVar

from stgpytools import SPath

__all__: list[str] = [
'get_calling_package_name',

'get_calling_package'
]


P = ParamSpec('P')
R = TypeVar('R')


def get_calling_package_name() -> str:
"""
Get the name of the package to which the calling function belongs.
:param depth: The depth in the call stack to look for the package name. Default is 1.
:return: The name of the package containing the calling function.
"""

frame = inspect.currentframe()

try:
if frame is None:
return "unknown"

module = inspect.getmodule(frame)

if module is None:
return "unknown"

package = module.__package__

if package is None:
return module.__name__.split('.')[0]

return package.split('.')[0]
finally:
del frame


def get_calling_package(depth: int = 2) -> str:
"""
Get the name of the package from which this function is called.
If the name is "__main__", use the caller's filename instead.
:param depth: The number of frames to go back in the call stack. Default is 2.
:return: The name of the calling package.
:return: The name of the calling package or file.
"""

stack = inspect.stack()
Expand Down

0 comments on commit 39d61fb

Please sign in to comment.