Skip to content

Commit

Permalink
Fix issues with randomgrayscale layer
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Dec 14, 2024
1 parent d96f910 commit 4c05e0c
Showing 1 changed file with 3 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ class RandomGrayscale(BaseImagePreprocessingLayer):
will have the same value.
"""

def __init__(self, factor=0.5, data_format=None, **kwargs):
def __init__(self, factor=0.5, data_format=None, seed=None, **kwargs):
super().__init__(**kwargs)
if factor < 0 or factor > 1:
raise ValueError(
Expand All @@ -54,7 +54,8 @@ def __init__(self, factor=0.5, data_format=None, **kwargs):
)
self.factor = factor
self.data_format = backend.standardize_data_format(data_format)
self.generator = self.backend.random.SeedGenerator()
self.seed = seed
self.generator = self.backend.random.SeedGenerator(seed)

def get_random_transformation(self, images, training=True, seed=None):
if seed is None:
Expand Down

0 comments on commit 4c05e0c

Please sign in to comment.