Skip to content

Commit

Permalink
Fix masking when _keras_mask is unset during call (#20594)
Browse files Browse the repository at this point in the history
  • Loading branch information
james77777778 authored Dec 5, 2024
1 parent 44e0723 commit e53e61d
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 5 deletions.
13 changes: 8 additions & 5 deletions keras/src/layers/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,12 @@ def maybe_convert(x):
mask = tree.map_structure(backend.get_keras_mask, v)
kwargs[expected_mask_arg_name] = mask

# We need to cache the `previous_mask` before `__call__` because the
# mask might be removed during the call, such as `MultiHeadAttention`.
previous_mask = tree.map_structure(
backend.get_keras_mask, call_spec.first_arg
)

####################
# 7. Call the layer.
try:
Expand Down Expand Up @@ -918,12 +924,9 @@ def maybe_convert(x):
if backend.is_tensor(output):
self.add_loss(self.activity_regularizer(output))

# Set masks on outputs,
# provided only the first positional input arg and its mask.
# Set `previous_mask` on outputs if available. It is provided only
# for the first positional input arg and its mask.
# TODO: consider extending this to all args and kwargs.
previous_mask = tree.map_structure(
backend.get_keras_mask, call_spec.first_arg
)
if self.supports_masking:
self._set_mask_metadata(
call_spec.first_arg, outputs, previous_mask
Expand Down
17 changes: 17 additions & 0 deletions keras/src/layers/layer_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,23 @@ def call(self, x1, x2, x1_mask=None, x2_mask=None):
layer((x1_1, x1_2), x2)
layer(x1=(x1_1, x1_2), x2=x2)

class MaskUnsetDuringCallLayer(layers.Layer):
def __init__(self):
super().__init__()
self.supports_masking = True

def call(self, x, mask=None):
assert mask is not None
backend.set_keras_mask(x, None) # Unset mask
return x

layer = MaskUnsetDuringCallLayer()
x = backend.numpy.ones((4, 4))
mask = backend.numpy.ones((4,))
backend.set_keras_mask(x, mask)
y = layer(x)
self.assertAllClose(y._keras_mask, mask)

def test_stateless_call(self):
class TestLayer(layers.Layer):
def __init__(self):
Expand Down

0 comments on commit e53e61d

Please sign in to comment.