diff --git a/vsdeband/noise.py b/vsdeband/noise.py index c084b62..8a17f10 100644 --- a/vsdeband/noise.py +++ b/vsdeband/noise.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass, field from functools import reduce -from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeAlias, cast +from typing import TYPE_CHECKING, Any, Callable, Iterable, Protocol, TypeAlias, cast from vsdenoise import Prefilter from vsexprtools import complexpr_available, norm_expr @@ -12,7 +12,7 @@ from vsrgtools import BlurMatrix from vstools import ( CustomIndexError, CustomOverflowError, CustomValueError, InvalidColorFamilyError, KwargsT, Matrix, MatrixT, - VSFunctionNoArgs, check_variable, core, depth, fallback, get_neutral_values, get_peak_value, get_y, inject_self, + check_variable, core, depth, fallback, get_neutral_values, get_peak_value, get_sample_type, get_y, inject_self, join, mod_x, normalize_seq, plane, scale_8bit, split, to_arr, vs ) @@ -42,6 +42,16 @@ class _gpp: Resolver: TypeAlias = Callable[[vs.VideoNode], Any] +class ResolverOneClipArgs(Protocol): + def __call__(self, grained: vs.VideoNode) -> vs.VideoNode: + ... + + +class ResolverTwoClipsArgs(Protocol): + def __call__(self, grained: vs.VideoNode, clip: vs.VideoNode) -> vs.VideoNode: + ... + + @dataclass class GrainPP(_gpp): value: str @@ -53,7 +63,7 @@ def Bump(cls, strength: float = 0.1) -> GrainPP: FadeLimits = tuple[int | Iterable[int] | None, int | Iterable[int] | None] -GrainPostProcessT = VSFunctionNoArgs | str | GrainPP | GrainPP.Resolver +GrainPostProcessT = ResolverOneClipArgs | ResolverTwoClipsArgs | str | GrainPP | GrainPP.Resolver GrainPostProcessesT = GrainPostProcessT | list[GrainPostProcessT] @@ -260,7 +270,10 @@ def _try_grain(src: vs.VideoNode, stre: tuple[float, float] = strength, **args: if self.postprocess: for postprocess in cast(list[GrainPostProcessT], to_arr(self.postprocess)): if callable(postprocess): - postprocess = postprocess(grained) + try: + postprocess = postprocess(grained, clip) + except TypeError: + postprocess = postprocess(grained) if isinstance(postprocess, vs.VideoNode): grained = postprocess