Skip to content

Commit

Permalink
Rewrite triu and tril in tensorflow backend (#19046)
Browse files Browse the repository at this point in the history
* fix: rewrite triu and tril in tensorflow backend

* Use a new way to handle them.
* Add a new unit test.

* chore: remove docs

* refactor: use a more rapid implementation
  • Loading branch information
dugujiujian1999 authored Jan 10, 2024
1 parent e5dad39 commit dfec3ef
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 12 deletions.
32 changes: 20 additions & 12 deletions keras/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1568,26 +1568,34 @@ def tri(N, M=None, k=0, dtype=None):

def tril(x, k=0):
x = convert_to_tensor(x)

if k >= 0:
return tf.linalg.band_part(x, -1, k)

# deal with negative k using mask
k = -k - 1
mask = tf.ones_like(x, dtype="bool")
mask = tf.logical_not(tf.linalg.band_part(mask, k, -1))
return tf.where(mask, x, tf.constant(0, x.dtype))
shape = tf.shape(x)
rows, cols = shape[-2], shape[-1]

i, j = tf.meshgrid(tf.range(rows), tf.range(cols), indexing="ij")

mask = i >= j - k

return tf.where(tf.broadcast_to(mask, shape), x, tf.zeros_like(x))


def triu(x, k=0):
x = convert_to_tensor(x)
if k >= 0:
return tf.linalg.band_part(x, k, -1)

# deal with negative k using mask
k = -k
mask = tf.ones_like(x, dtype="bool")
mask = tf.logical_not(tf.linalg.band_part(mask, k, -1))
return tf.where(mask, tf.constant(0, x.dtype), x)
if k <= 0:
return tf.linalg.band_part(x, -k, -1)

shape = tf.shape(x)
rows, cols = shape[-2], shape[-1]

i, j = tf.meshgrid(tf.range(rows), tf.range(cols), indexing="ij")

mask = i <= j - k

return tf.where(tf.broadcast_to(mask, shape), x, tf.zeros_like(x))


def vdot(x1, x2):
Expand Down
10 changes: 10 additions & 0 deletions keras/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3881,6 +3881,11 @@ def test_tril(self):
self.assertAllClose(knp.tril(x, -1), np.tril(x, -1))
self.assertAllClose(knp.Tril(-1)(x), np.tril(x, -1))

x = np.ones([5, 5])
self.assertAllClose(knp.tril(x), np.tril(x))
self.assertAllClose(knp.tril(x, -1), np.tril(x, -1))
self.assertAllClose(knp.Tril(-1)(x), np.tril(x, -1))

def test_tril_in_layer(self):
# https://github.com/keras-team/keras/issues/18890
x = keras.Input((None, 3))
Expand Down Expand Up @@ -3908,6 +3913,11 @@ def test_triu(self):
self.assertAllClose(knp.triu(x, -1), np.triu(x, -1))
self.assertAllClose(knp.Triu(-1)(x), np.triu(x, -1))

x = np.ones([5, 5])
self.assertAllClose(knp.triu(x), np.triu(x))
self.assertAllClose(knp.triu(x, -1), np.triu(x, -1))
self.assertAllClose(knp.Triu(-1)(x), np.triu(x, -1))

def test_triu_in_layer(self):
# https://github.com/keras-team/keras/issues/18890
x = keras.Input((None, 3))
Expand Down

0 comments on commit dfec3ef

Please sign in to comment.