Addition of Sparsemax activation (#20558)
* add: sparsemax ops

* add: sparsemax api references to inits

* add: sparsemax tests

* edit: changes after test

* edit: test case

* rename: function in numpy

* add: pointers to rest inits

* edit: docstrings

* change: x to logits in docstring
old-school-kid authored Nov 28, 2024
1 parent 75522e4 commit a3a368d
Showing 16 changed files with 217 additions and 1 deletion.
3 changes: 2 additions & 1 deletion .gitignore
@@ -18,4 +18,5 @@ dist/**
examples/**/*.jpg
.python-version
.coverage
-*coverage.xml
+*coverage.xml
+.ruff_cache
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/activations/__init__.py
@@ -33,6 +33,7 @@
from keras.src.activations.activations import softplus
from keras.src.activations.activations import softsign
from keras.src.activations.activations import sparse_plus
from keras.src.activations.activations import sparsemax
from keras.src.activations.activations import squareplus
from keras.src.activations.activations import tanh
from keras.src.activations.activations import tanh_shrink
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
@@ -100,6 +100,7 @@
from keras.src.ops.nn import softsign
from keras.src.ops.nn import sparse_categorical_crossentropy
from keras.src.ops.nn import sparse_plus
from keras.src.ops.nn import sparsemax
from keras.src.ops.nn import squareplus
from keras.src.ops.nn import tanh_shrink
from keras.src.ops.numpy import abs
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/nn/__init__.py
@@ -45,5 +45,6 @@
from keras.src.ops.nn import softsign
from keras.src.ops.nn import sparse_categorical_crossentropy
from keras.src.ops.nn import sparse_plus
from keras.src.ops.nn import sparsemax
from keras.src.ops.nn import squareplus
from keras.src.ops.nn import tanh_shrink
1 change: 1 addition & 0 deletions keras/api/activations/__init__.py
@@ -33,6 +33,7 @@
from keras.src.activations.activations import softplus
from keras.src.activations.activations import softsign
from keras.src.activations.activations import sparse_plus
from keras.src.activations.activations import sparsemax
from keras.src.activations.activations import squareplus
from keras.src.activations.activations import tanh
from keras.src.activations.activations import tanh_shrink
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
@@ -100,6 +100,7 @@
from keras.src.ops.nn import softsign
from keras.src.ops.nn import sparse_categorical_crossentropy
from keras.src.ops.nn import sparse_plus
from keras.src.ops.nn import sparsemax
from keras.src.ops.nn import squareplus
from keras.src.ops.nn import tanh_shrink
from keras.src.ops.numpy import abs
1 change: 1 addition & 0 deletions keras/api/ops/nn/__init__.py
@@ -45,5 +45,6 @@
from keras.src.ops.nn import softsign
from keras.src.ops.nn import sparse_categorical_crossentropy
from keras.src.ops.nn import sparse_plus
from keras.src.ops.nn import sparsemax
from keras.src.ops.nn import squareplus
from keras.src.ops.nn import tanh_shrink
2 changes: 2 additions & 0 deletions keras/src/activations/__init__.py
@@ -24,6 +24,7 @@
from keras.src.activations.activations import softplus
from keras.src.activations.activations import softsign
from keras.src.activations.activations import sparse_plus
from keras.src.activations.activations import sparsemax
from keras.src.activations.activations import squareplus
from keras.src.activations.activations import tanh
from keras.src.activations.activations import tanh_shrink
@@ -59,6 +60,7 @@
    mish,
    log_softmax,
    log_sigmoid,
    sparsemax,
}

ALL_OBJECTS_DICT = {fn.__name__: fn for fn in ALL_OBJECTS}
25 changes: 25 additions & 0 deletions keras/src/activations/activations.py
@@ -617,3 +617,28 @@ def log_softmax(x, axis=-1):
        axis: Integer, axis along which the softmax is applied.
    """
    return ops.log_softmax(x, axis=axis)


@keras_export(["keras.activations.sparsemax"])
def sparsemax(x, axis=-1):
    """Sparsemax activation function.

    For each batch `i` and class `j`, the sparsemax activation function
    is defined as:

    `sparsemax(x)[i, j] = max(x[i, j] - τ(x[i, :]), 0)`

    Args:
        x: Input tensor.
        axis: `int`, axis along which the sparsemax operation is applied.

    Returns:
        A tensor, output of sparsemax transformation. Has the same type and
        shape as `x`.

    Reference:

    - [Martins et al., 2016](https://arxiv.org/abs/1602.02068)
    """
    x = backend.convert_to_tensor(x)
    return ops.sparsemax(x, axis=axis)
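
Sparsemax projects its input onto the probability simplex, so the output sums to 1 along the chosen axis and low-scoring entries become exactly zero, where softmax would keep every probability positive. A minimal usage sketch, not part of the diff (the expected values are taken from the tests added below; lookup by name works because `sparsemax` is added to `ALL_OBJECTS` above):

import numpy as np
from keras import activations, layers

# Direct call: the largest logit dominates and takes all of the mass.
x = np.array([-0.5, 0.0, 1.0, 2.0, 3.0])
print(activations.sparsemax(x))  # [0. 0. 0. 0. 1.]

# The activation is registered by name, so layers can resolve it too.
layer = layers.Dense(4, activation="sparsemax")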
49 changes: 49 additions & 0 deletions keras/src/activations/activations_test.py
@@ -896,6 +896,55 @@ def test_linear(self):
        x_int32 = np.random.randint(-10, 10, (10, 5)).astype(np.int32)
        self.assertAllClose(x_int32, activations.linear(x_int32))

    def test_sparsemax(self):
        # result check with 1d
        x_1d = np.linspace(1, 12, num=12)
        expected_result = np.zeros_like(x_1d)
        expected_result[-1] = 1.0
        self.assertAllClose(expected_result, activations.sparsemax(x_1d))

        # result check with 2d
        x_2d = np.linspace(1, 12, num=12).reshape(-1, 2)
        expected_result = np.zeros_like(x_2d)
        expected_result[:, -1] = 1.0
        self.assertAllClose(expected_result, activations.sparsemax(x_2d))

        # result check with 3d
        x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)
        expected_result = np.zeros_like(x_3d)
        expected_result[:, :, -1] = 1.0
        self.assertAllClose(expected_result, activations.sparsemax(x_3d))

        # result check with axis=-2 with 2d input
        x_2d = np.linspace(1, 12, num=12).reshape(-1, 2)
        expected_result = np.zeros_like(x_2d)
        expected_result[-1, :] = 1.0
        self.assertAllClose(
            expected_result, activations.sparsemax(x_2d, axis=-2)
        )

        # result check with axis=-2 with 3d input
        # (axis -2 has size 1, so sparsemax over it is identically 1.0)
        x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)
        expected_result = np.ones_like(x_3d)
        self.assertAllClose(
            expected_result, activations.sparsemax(x_3d, axis=-2)
        )

        # result check with axis=-3 with 3d input
        x_3d = np.linspace(1, 12, num=12).reshape(-1, 1, 3)
        expected_result = np.zeros_like(x_3d)
        expected_result[-1, :, :] = 1.0
        self.assertAllClose(
            expected_result, activations.sparsemax(x_3d, axis=-3)
        )

        # result check with axis=-3 with 4d input
        # (axis -3 has size 1, so sparsemax over it is identically 1.0)
        x_4d = np.linspace(1, 12, num=12).reshape(-1, 1, 1, 2)
        expected_result = np.ones_like(x_4d)
        self.assertAllClose(
            expected_result, activations.sparsemax(x_4d, axis=-3)
        )

    def test_get_method(self):
        obj = activations.get("relu")
        self.assertEqual(obj, activations.relu)
18 changes: 18 additions & 0 deletions keras/src/backend/jax/nn.py
@@ -142,6 +142,24 @@ def log_softmax(x, axis=-1):
    return jnn.log_softmax(x, axis=axis)


def sparsemax(logits, axis=-1):
    # Sort logits along the specified axis in descending order.
    logits = convert_to_tensor(logits)
    logits_sorted = -1.0 * jnp.sort(logits * -1.0, axis=axis)
    logits_cumsum = jnp.cumsum(logits_sorted, axis=axis)  # Cumulative sums.
    r = jnp.arange(1, logits.shape[axis] + 1)  # Ranks 1..n along the axis.
    r_shape = [1] * logits.ndim
    r_shape[axis] = -1  # Broadcast to match the target axis.
    r = r.reshape(r_shape)
    # Position k is in the support iff 1 + k * z_(k) > cumsum_(k).
    support = logits_sorted - (logits_cumsum - 1) / r > 0
    # Threshold: tau = (sum of the supported logits - 1) / k.
    k = jnp.sum(support, axis=axis, keepdims=True)
    logits_sorted_safe = jnp.where(support, logits_sorted, 0.0)
    tau = (jnp.sum(logits_sorted_safe, axis=axis, keepdims=True) - 1) / k
    output = jnp.maximum(logits - tau, 0.0)
    return output
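
To make the support/threshold computation concrete, here is a hand-check on a two-element vector as a standalone NumPy sketch (input values chosen only for illustration):

import numpy as np

z = np.array([0.5, 1.0])
z_sorted = np.sort(z)[::-1]              # [1.0, 0.5]
cumsum = np.cumsum(z_sorted)             # [1.0, 1.5]
r = np.arange(1, z.size + 1)             # ranks [1, 2]
support = z_sorted > (cumsum - 1) / r    # [True, True], so k = 2
k = support.sum()
tau = (z_sorted[support].sum() - 1) / k  # (1.5 - 1) / 2 = 0.25
print(np.maximum(z - tau, 0.0))          # [0.25 0.75], sums to 1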


def _convert_to_spatial_operand(
    x,
    num_spatial_dims,
18 changes: 18 additions & 0 deletions keras/src/backend/numpy/nn.py
@@ -191,6 +191,24 @@ def log_softmax(x, axis=None):
    return x - max_x - logsumexp


def sparsemax(logits, axis=-1):
    # Sort logits along the specified axis in descending order.
    logits = convert_to_tensor(logits)
    logits_sorted = -1.0 * np.sort(-1.0 * logits, axis=axis)
    logits_cumsum = np.cumsum(logits_sorted, axis=axis)
    r = np.arange(1, logits.shape[axis] + 1)
    r_shape = [1] * logits.ndim
    r_shape[axis] = -1  # Broadcast to match the target axis.
    r = r.reshape(r_shape)
    support = logits_sorted - (logits_cumsum - 1) / r > 0
    # Threshold: tau = (sum of the supported logits - 1) / k.
    k = np.sum(support, axis=axis, keepdims=True)
    logits_sorted_safe = np.where(support, logits_sorted, 0.0)
    tau = (np.sum(logits_sorted_safe, axis=axis, keepdims=True) - 1) / k
    output = np.maximum(logits - tau, 0.0)
    return output
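
For contrast, the same logits under sparsemax and softmax, assuming the `sparsemax` function above is in scope (for the NumPy backend, `convert_to_tensor` amounts to an array conversion):

import numpy as np

z = np.array([1.0, 2.0, 3.0])
print(sparsemax(z))                    # [0. 0. 1.], losing classes get exactly 0
softmax = np.exp(z) / np.exp(z).sum()
print(softmax.round(3))                # [0.09  0.245 0.665], all strictly positive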


def _convert_to_spatial_operand(
    x,
    num_spatial_dims,
18 changes: 18 additions & 0 deletions keras/src/backend/tensorflow/nn.py
@@ -151,6 +151,24 @@ def log_softmax(x, axis=-1):
    return tf.nn.log_softmax(x, axis=axis)


def sparsemax(logits, axis=-1):
    # Sort logits along the specified axis in descending order.
    logits = convert_to_tensor(logits)
    logits_sorted = tf.sort(logits, direction="DESCENDING", axis=axis)
    logits_cumsum = tf.cumsum(logits_sorted, axis=axis)
    r = tf.range(1, tf.shape(logits)[axis] + 1, dtype=logits.dtype)
    r_shape = [1] * len(logits.shape)
    r_shape[axis] = -1  # Broadcast to match the target axis.
    r = tf.reshape(r, r_shape)
    support = logits_sorted - (logits_cumsum - 1) / r > 0
    # Threshold: tau = (sum of the supported logits - 1) / k.
    logits_sorted_safe = tf.where(support, logits_sorted, 0.0)
    k = tf.reduce_sum(tf.cast(support, logits.dtype), axis=axis, keepdims=True)
    tau = (tf.reduce_sum(logits_sorted_safe, axis=axis, keepdims=True) - 1) / k
    output = tf.maximum(logits - tau, 0.0)
    return output


def _transpose_spatial_inputs(inputs):
    num_spatial_dims = len(inputs.shape) - 2
    # Tensorflow pooling does not support `channels_first` format, so
22 changes: 22 additions & 0 deletions keras/src/backend/torch/nn.py
@@ -174,6 +174,28 @@ def log_softmax(x, axis=-1):
    return cast(output, dtype)


def sparsemax(logits, axis=-1):
    # Sort logits along the specified axis in descending order.
    logits = convert_to_tensor(logits)
    logits_sorted, _ = torch.sort(logits, dim=axis, descending=True)
    logits_cumsum = torch.cumsum(logits_sorted, dim=axis)
    r = torch.arange(
        1, logits.size(axis) + 1, device=logits.device, dtype=logits.dtype
    )
    r_shape = [1] * logits.ndim
    r_shape[axis] = -1  # Broadcast to match the target axis.
    r = r.view(r_shape)
    support = logits_sorted - (logits_cumsum - 1) / r > 0
    # Threshold: tau = (sum of the supported logits - 1) / k.
    k = torch.sum(support, dim=axis, keepdim=True)
    logits_sorted_safe = torch.where(
        support, logits_sorted, torch.zeros_like(logits_sorted)
    )
    tau = (torch.sum(logits_sorted_safe, dim=axis, keepdim=True) - 1) / k
    output = torch.clamp(logits - tau, min=0.0)
    return output
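
Sparsemax is piecewise linear, so gradients flow only through the support; a quick autograd sanity check, assuming the `sparsemax` function above is in scope:

import torch

z = torch.tensor([0.5, 1.0], requires_grad=True)
p = sparsemax(z)  # both entries in the support: [0.25, 0.75]
p[1].backward()
print(z.grad)     # tensor([-0.5000, 0.5000]), a row of I - (1/k) * 1 1^T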


def _compute_padding_length(
    input_length, kernel_length, stride, dilation_rate=1
):
42 changes: 42 additions & 0 deletions keras/src/ops/nn.py
@@ -951,6 +951,42 @@ def log_softmax(x, axis=-1):
    return backend.nn.log_softmax(x, axis=axis)


class Sparsemax(Operation):
    def __init__(self, axis=-1):
        super().__init__()
        self.axis = axis

    def call(self, x):
        return backend.nn.sparsemax(x, axis=self.axis)

    def compute_output_spec(self, x):
        return KerasTensor(x.shape, dtype=x.dtype)


@keras_export(["keras.ops.sparsemax", "keras.ops.nn.sparsemax"])
def sparsemax(x, axis=-1):
    """Sparsemax activation function.

    For each batch `i` and class `j`, the sparsemax activation function
    is defined as:

    `sparsemax(x)[i, j] = max(x[i, j] - τ(x[i, :]), 0)`

    Args:
        x: Input tensor.
        axis: `int`, axis along which the sparsemax operation is applied.

    Returns:
        A tensor, output of sparsemax transformation. Has the same type and
        shape as `x`.

    Example:

    >>> x = np.array([-1., 0., 1.])
    >>> x_sparsemax = keras.ops.sparsemax(x)
    >>> print(x_sparsemax)
    array([0., 0., 1.], shape=(3,), dtype=float64)
    """
    if any_symbolic_tensors((x,)):
        return Sparsemax(axis).symbolic_call(x)
    return backend.nn.sparsemax(x, axis=axis)
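
Because `compute_output_spec` preserves shape and dtype, the op also works on symbolic inputs; a small illustrative sketch, not part of the diff:

from keras import KerasTensor, ops

x = KerasTensor((None, 10))
y = ops.sparsemax(x)     # routes through Sparsemax(axis).symbolic_call(x)
print(y.shape, y.dtype)  # (None, 10) float32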


class MaxPool(Operation):
    def __init__(
        self,
15 changes: 15 additions & 0 deletions keras/src/ops/nn_test.py
@@ -200,6 +200,10 @@ def test_log_softmax(self):
        self.assertEqual(knn.log_softmax(x, axis=1).shape, (None, 2, 3))
        self.assertEqual(knn.log_softmax(x, axis=-1).shape, (None, 2, 3))

    def test_sparsemax(self):
        x = KerasTensor([None, 2, 3])
        self.assertEqual(knn.sparsemax(x).shape, (None, 2, 3))

    def test_max_pool(self):
        data_format = backend.config.image_data_format()
        if data_format == "channels_last":

@@ -861,6 +865,10 @@ def test_log_softmax(self):
        self.assertEqual(knn.log_softmax(x, axis=1).shape, (1, 2, 3))
        self.assertEqual(knn.log_softmax(x, axis=-1).shape, (1, 2, 3))

    def test_sparsemax(self):
        x = KerasTensor([1, 2, 3])
        self.assertEqual(knn.sparsemax(x).shape, (1, 2, 3))

    def test_max_pool(self):
        data_format = backend.config.image_data_format()
        if data_format == "channels_last":

@@ -1487,6 +1495,13 @@ def test_log_softmax_correctness_with_axis_tuple(self):
        )
        self.assertAllClose(normalized_sum_by_axis, 1.0)

    def test_sparsemax(self):
        x = np.array([-0.5, 0, 1, 2, 3], dtype=np.float32)
        self.assertAllClose(
            knn.sparsemax(x),
            [0.0, 0.0, 0.0, 0.0, 1.0],
        )

    def test_max_pool(self):
        data_format = backend.config.image_data_format()
        # Test 1D max pooling.
