From 3ddc52b642fbaad4c3e94f765da37dc13c4addac Mon Sep 17 00:00:00 2001 From: Setsugennoao <41454651+Setsugennoao@users.noreply.github.com> Date: Thu, 14 Dec 2023 20:04:07 +0100 Subject: [PATCH] Force concrete subclasses of BaseScaler implement kernel_radius, add API to add abstract classes from other packages --- vskernels/kernels/abstract.py | 52 +++++++++++++++++++++++++++++------ vskernels/util.py | 5 ++++ 2 files changed, 48 insertions(+), 9 deletions(-) diff --git a/vskernels/kernels/abstract.py b/vskernels/kernels/abstract.py index f883527..cee9292 100644 --- a/vskernels/kernels/abstract.py +++ b/vskernels/kernels/abstract.py @@ -7,9 +7,9 @@ from stgpytools import inject_kwargs_params from vstools import ( - CustomIndexError, CustomValueError, FieldBased, FuncExceptT, GenericVSFunction, HoldsVideoFormatT, KwargsT, Matrix, - MatrixT, T, VideoFormatT, check_correct_subsampling, check_variable_resolution, core, depth, expect_bits, - get_subclasses, get_video_format, inject_self, vs, vs_object + CustomIndexError, CustomRuntimeError, CustomValueError, FieldBased, FuncExceptT, GenericVSFunction, + HoldsVideoFormatT, KwargsT, Matrix, MatrixT, T, VideoFormatT, check_correct_subsampling, check_variable_resolution, + core, depth, expect_bits, get_subclasses, get_video_format, inject_self, vs, vs_object ) from ..exceptions import UnknownDescalerError, UnknownKernelError, UnknownResamplerError, UnknownScalerError @@ -22,17 +22,14 @@ 'Kernel', 'KernelT' ] +_finished_loading_abstract = False + def _default_kernel_radius(cls: type[T], self: T) -> int: if hasattr(self, '_static_kernel_radius'): return ceil(self._static_kernel_radius) # type: ignore - try: - return super(cls, self).kernel_radius # type: ignore - except AttributeError: - ... - - raise NotImplementedError + return super(cls, self).kernel_radius # type: ignore @lru_cache @@ -132,6 +129,43 @@ class BaseScaler(vs_object): def __init__(self, **kwargs: Any) -> None: self.kwargs = kwargs + def __init_subclass__(cls) -> None: + if not _finished_loading_abstract: + return + + from .zimg import ZimgComplexKernel + from ..util import abstract_kernels + + if cls in abstract_kernels: + return + + import sys + + module = sys.modules[cls.__module__] + + if hasattr(module, '__abstract__'): + if cls.__name__ in module.__abstract__: + abstract_kernels.append(cls) # type: ignore + return + + if 'kernel_radius' in cls.__dict__.keys(): + return + + mro = [cls, *({*cls.mro()} - {*ZimgComplexKernel.mro()})] + + for sub_cls in mro: + if hasattr(sub_cls, '_static_kernel_radius'): + break + + try: + if hasattr(sub_cls, 'kernel_radius'): + break + except Exception: + ... + else: + if mro: + raise CustomRuntimeError('You must implement kernel_radius when inheriting BaseScaler!', reason=cls) + @classmethod def from_param( cls: type[BaseScalerT], scaler: str | type[BaseScalerT] | BaseScalerT | None = None, /, diff --git a/vskernels/util.py b/vskernels/util.py index 265885a..4fe0b79 100644 --- a/vskernels/util.py +++ b/vskernels/util.py @@ -257,3 +257,8 @@ def resample_to( return Point.resample(clip, out_fmt, matrix) return resampler.resample(clip, out_fmt, matrix) + + +if True: + from .kernels import abstract + abstract._finished_loading_abstract = True