diff --git a/keras/src/layers/layer.py b/keras/src/layers/layer.py index d065fdd2fdf..4f6a31986f9 100644 --- a/keras/src/layers/layer.py +++ b/keras/src/layers/layer.py @@ -872,6 +872,12 @@ def maybe_convert(x): mask = tree.map_structure(backend.get_keras_mask, v) kwargs[expected_mask_arg_name] = mask + # We need to cache the `previous_mask` before `__call__` because the + # mask might be removed during the call, such as `MultiHeadAttention`. + previous_mask = tree.map_structure( + backend.get_keras_mask, call_spec.first_arg + ) + #################### # 7. Call the layer. try: @@ -918,12 +924,9 @@ def maybe_convert(x): if backend.is_tensor(output): self.add_loss(self.activity_regularizer(output)) - # Set masks on outputs, - # provided only the first positional input arg and its mask. + # Set `previous_mask` on outputs if available. It is provided only + # for the first positional input arg and its mask. # TODO: consider extending this to all args and kwargs. - previous_mask = tree.map_structure( - backend.get_keras_mask, call_spec.first_arg - ) if self.supports_masking: self._set_mask_metadata( call_spec.first_arg, outputs, previous_mask diff --git a/keras/src/layers/layer_test.py b/keras/src/layers/layer_test.py index 5fa8bca7354..a74018b007c 100644 --- a/keras/src/layers/layer_test.py +++ b/keras/src/layers/layer_test.py @@ -673,6 +673,23 @@ def call(self, x1, x2, x1_mask=None, x2_mask=None): layer((x1_1, x1_2), x2) layer(x1=(x1_1, x1_2), x2=x2) + class MaskUnsetDuringCallLayer(layers.Layer): + def __init__(self): + super().__init__() + self.supports_masking = True + + def call(self, x, mask=None): + assert mask is not None + backend.set_keras_mask(x, None) # Unset mask + return x + + layer = MaskUnsetDuringCallLayer() + x = backend.numpy.ones((4, 4)) + mask = backend.numpy.ones((4,)) + backend.set_keras_mask(x, mask) + y = layer(x) + self.assertAllClose(y._keras_mask, mask) + def test_stateless_call(self): class TestLayer(layers.Layer): def __init__(self):