Skip to content

Commit

Permalink
Update grain postprocess
Browse files Browse the repository at this point in the history
  • Loading branch information
Setsugennoao committed Feb 5, 2024
1 parent 01f9f0d commit 6f78cd7
Showing 1 changed file with 17 additions and 4 deletions.
21 changes: 17 additions & 4 deletions vsdeband/noise.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from functools import reduce
from typing import TYPE_CHECKING, Any, Callable, Iterable, TypeAlias, cast
from typing import TYPE_CHECKING, Any, Callable, Iterable, Protocol, TypeAlias, cast

from vsdenoise import Prefilter
from vsexprtools import complexpr_available, norm_expr
Expand All @@ -12,7 +12,7 @@
from vsrgtools import BlurMatrix
from vstools import (
CustomIndexError, CustomOverflowError, CustomValueError, InvalidColorFamilyError, KwargsT, Matrix, MatrixT,
VSFunctionNoArgs, check_variable, core, depth, fallback, get_neutral_values, get_peak_value, get_y, inject_self,
check_variable, core, depth, fallback, get_neutral_values, get_peak_value, get_sample_type, get_y, inject_self,
join, mod_x, normalize_seq, plane, scale_8bit, split, to_arr, vs
)

Expand Down Expand Up @@ -42,6 +42,16 @@ class _gpp:
Resolver: TypeAlias = Callable[[vs.VideoNode], Any]


class ResolverOneClipArgs(Protocol):
def __call__(self, grained: vs.VideoNode) -> vs.VideoNode:
...


class ResolverTwoClipsArgs(Protocol):
def __call__(self, grained: vs.VideoNode, clip: vs.VideoNode) -> vs.VideoNode:
...


@dataclass
class GrainPP(_gpp):
value: str
Expand All @@ -53,7 +63,7 @@ def Bump(cls, strength: float = 0.1) -> GrainPP:


FadeLimits = tuple[int | Iterable[int] | None, int | Iterable[int] | None]
GrainPostProcessT = VSFunctionNoArgs | str | GrainPP | GrainPP.Resolver
GrainPostProcessT = ResolverOneClipArgs | ResolverTwoClipsArgs | str | GrainPP | GrainPP.Resolver
GrainPostProcessesT = GrainPostProcessT | list[GrainPostProcessT]


Expand Down Expand Up @@ -260,7 +270,10 @@ def _try_grain(src: vs.VideoNode, stre: tuple[float, float] = strength, **args:
if self.postprocess:
for postprocess in cast(list[GrainPostProcessT], to_arr(self.postprocess)):
if callable(postprocess):
postprocess = postprocess(grained)
try:
postprocess = postprocess(grained, clip)
except TypeError:
postprocess = postprocess(grained)

if isinstance(postprocess, vs.VideoNode):
grained = postprocess
Expand Down

0 comments on commit 6f78cd7

Please sign in to comment.