Skip to content

Commit

Permalink
Add inner op (#20532)
Browse files Browse the repository at this point in the history
* add inner op

* Fix tensorflow implementation

* fix

* api

* fix lint

* format
  • Loading branch information
IMvision12 authored Nov 24, 2024
1 parent a503a16 commit 0078f24
Show file tree
Hide file tree
Showing 10 changed files with 126 additions and 0 deletions.
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@
from keras.src.ops.numpy import hstack
from keras.src.ops.numpy import identity
from keras.src.ops.numpy import imag
from keras.src.ops.numpy import inner
from keras.src.ops.numpy import isclose
from keras.src.ops.numpy import isfinite
from keras.src.ops.numpy import isinf
Expand Down
1 change: 1 addition & 0 deletions keras/api/_tf_keras/keras/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from keras.src.ops.numpy import hstack
from keras.src.ops.numpy import identity
from keras.src.ops.numpy import imag
from keras.src.ops.numpy import inner
from keras.src.ops.numpy import isclose
from keras.src.ops.numpy import isfinite
from keras.src.ops.numpy import isinf
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,7 @@
from keras.src.ops.numpy import hstack
from keras.src.ops.numpy import identity
from keras.src.ops.numpy import imag
from keras.src.ops.numpy import inner
from keras.src.ops.numpy import isclose
from keras.src.ops.numpy import isfinite
from keras.src.ops.numpy import isinf
Expand Down
1 change: 1 addition & 0 deletions keras/api/ops/numpy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
from keras.src.ops.numpy import hstack
from keras.src.ops.numpy import identity
from keras.src.ops.numpy import imag
from keras.src.ops.numpy import inner
from keras.src.ops.numpy import isclose
from keras.src.ops.numpy import isfinite
from keras.src.ops.numpy import isinf
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/jax/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1102,6 +1102,12 @@ def vdot(x1, x2):
return jnp.vdot(x1, x2)


def inner(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
return jnp.inner(x1, x2)


def vstack(xs):
return jnp.vstack(xs)

Expand Down
9 changes: 9 additions & 0 deletions keras/src/backend/numpy/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,15 @@ def vdot(x1, x2):
return np.vdot(x1, x2)


def inner(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
dtype = dtypes.result_type(x1.dtype, x2.dtype)
x1 = x1.astype(dtype)
x2 = x2.astype(dtype)
return np.inner(x1, x2)


def vstack(xs):
dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
if len(dtype_set) > 1:
Expand Down
18 changes: 18 additions & 0 deletions keras/src/backend/tensorflow/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -2285,6 +2285,24 @@ def vdot(x1, x2):
return tf.cast(dot(x1, x2), result_dtype)


def inner(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
result_dtype = dtypes.result_type(x1.dtype, x2.dtype)
compute_dtype = dtypes.result_type(result_dtype, float)
x1 = tf.cast(x1, compute_dtype)
x2 = tf.cast(x2, compute_dtype)
x = tf.cond(
tf.math.logical_or(
tf.math.equal(tf.rank(x1), 0),
tf.math.equal(tf.rank(x2), 0),
),
lambda: x1 * x2,
lambda: tf.tensordot(x1, x2, axes=[[-1], [-1]]),
)
return tf.cast(x, result_dtype)


def vstack(xs):
dtype_set = set([getattr(x, "dtype", type(x)) for x in xs])
if len(dtype_set) > 1:
Expand Down
14 changes: 14 additions & 0 deletions keras/src/backend/torch/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1487,6 +1487,20 @@ def vdot(x1, x2):
return cast(torch.vdot(x1, x2), result_dtype)


def inner(x1, x2):
x1 = convert_to_tensor(x1)
x2 = convert_to_tensor(x2)
result_dtype = dtypes.result_type(x1.dtype, x2.dtype)
compute_dtype = dtypes.result_type(result_dtype, float)

if get_device() == "cpu" and compute_dtype == "float16":
compute_dtype = "float32"

x1 = cast(x1, compute_dtype)
x2 = cast(x2, compute_dtype)
return cast(torch.inner(x1, x2), result_dtype)


def vstack(xs):
xs = [convert_to_tensor(x) for x in xs]
return torch.vstack(xs)
Expand Down
39 changes: 39 additions & 0 deletions keras/src/ops/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5678,6 +5678,45 @@ def vdot(x1, x2):
return backend.numpy.vdot(x1, x2)


class Inner(Operation):
def call(self, x1, x2):
return backend.numpy.inner(x1, x2)

def compute_output_spec(self, x1, x2):
dtype = dtypes.result_type(
getattr(x1, "dtype", type(x1)),
getattr(x2, "dtype", type(x2)),
)
return KerasTensor([], dtype=dtype)


@keras_export(["keras.ops.inner", "keras.ops.numpy.inner"])
def inner(x1, x2):
"""Return the inner product of two tensors.
Ordinary inner product of vectors for 1-D tensors
(without complex conjugation), in higher dimensions
a sum product over the last axes.
Multidimensional arrays are treated as vectors by flattening
all but their last axes. The resulting dot product is performed
over their last axes.
Args:
x1: First input tensor.
x2: Second input tensor. The last dimension of `x1` and `x2`
must match.
Returns:
Output tensor. The shape of the output is determined by
broadcasting the shapes of `x1` and `x2` after removing
their last axes.
"""
if any_symbolic_tensors((x1, x2)):
return Inner().symbolic_call(x1, x2)
return backend.numpy.inner(x1, x2)


@keras_export(["keras.ops.vectorize", "keras.ops.numpy.vectorize"])
def vectorize(pyfunc, *, excluded=None, signature=None):
"""Turn a function into a vectorized function.
Expand Down
36 changes: 36 additions & 0 deletions keras/src/ops/numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,6 +299,11 @@ def test_vdot(self):
y = KerasTensor((None, 3, 3))
self.assertEqual(knp.vdot(x, y).shape, ())

def test_inner(self):
x = KerasTensor((None,))
y = KerasTensor((3,))
self.assertEqual(knp.inner(x, y).shape, ())

def test_where(self):
condition = KerasTensor((2, None, 1))
x = KerasTensor((None, 1))
Expand Down Expand Up @@ -875,6 +880,11 @@ def test_vdot(self):
y = KerasTensor((2, 3))
self.assertEqual(knp.vdot(x, y).shape, ())

def test_inner(self):
x = KerasTensor((2, 3))
y = KerasTensor((2, 3))
self.assertEqual(knp.inner(x, y).shape, ())

def test_where(self):
condition = KerasTensor((2, 3))
x = KerasTensor((2, 3))
Expand Down Expand Up @@ -2975,6 +2985,12 @@ def test_vdot(self):
self.assertAllClose(knp.vdot(x, y), np.vdot(x, y))
self.assertAllClose(knp.Vdot()(x, y), np.vdot(x, y))

def test_inner(self):
x = np.array([1.0, 2.0, 3.0])
y = np.array([4.0, 5.0, 6.0])
self.assertAllClose(knp.inner(x, y), np.inner(x, y))
self.assertAllClose(knp.Inner()(x, y), np.inner(x, y))

def test_where(self):
x = np.array([1, 2, 3])
y = np.array([4, 5, 6])
Expand Down Expand Up @@ -8249,6 +8265,26 @@ def test_vdot(self, dtypes):
)
self.assertEqual(knp.Vdot().symbolic_call(x1, x2).dtype, expected_dtype)

@parameterized.named_parameters(
named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
)
def test_inner(self, dtypes):
import jax.numpy as jnp

dtype1, dtype2 = dtypes
x1 = knp.ones((1,), dtype=dtype1)
x2 = knp.ones((1,), dtype=dtype2)
x1_jax = jnp.ones((1,), dtype=dtype1)
x2_jax = jnp.ones((1,), dtype=dtype2)
expected_dtype = standardize_dtype(jnp.inner(x1_jax, x2_jax).dtype)

self.assertEqual(
standardize_dtype(knp.inner(x1, x2).dtype), expected_dtype
)
self.assertEqual(
knp.Inner().symbolic_call(x1, x2).dtype, expected_dtype
)

@parameterized.named_parameters(
named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))
)
Expand Down

0 comments on commit 0078f24

Please sign in to comment.