Skip to content

Commit

Permalink
Add preliminary support of OpenVINO as Keras 3 backend (#19727)
Browse files Browse the repository at this point in the history
* [POC][OV] Support OpenVINO as Keras 3 backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Mark all unsupported ops from numpy space

Signed-off-by: Kazantsev, Roman <[email protected]>

* Mark unsupported ops in core, image, and linalg spaces

Signed-off-by: Kazantsev, Roman <[email protected]>

* Mark unsupported ops in math, nn, random, and rnn spaces

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix sorting imports

Signed-off-by: Kazantsev, Roman <[email protected]>

* Format imports

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix sorting imports

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix sorting imports

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix inference

Signed-off-by: Kazantsev, Roman <[email protected]>

* Remove openvino specific code in common part

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix typo

* Clean-up code

Signed-off-by: Kazantsev, Roman <[email protected]>

* Recover imports

Signed-off-by: Kazantsev, Roman <[email protected]>

* Sort imports properly

Signed-off-by: Kazantsev, Roman <[email protected]>

* Format source code

Signed-off-by: Kazantsev, Roman <[email protected]>

* Format the rest of source code

Signed-off-by: Kazantsev, Roman <[email protected]>

* Continue format adjustment

Signed-off-by: Kazantsev, Roman <[email protected]>

* Add OpenVINO dependency

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix inference using OV backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Support bert_base_en_uncased and mobilenet_v3_small from Keras Hub

Signed-off-by: Kazantsev, Roman <[email protected]>

* Remove extra openvino specific code from layer.py

Signed-off-by: Kazantsev, Roman <[email protected]>

* Apply code-style formatting

Signed-off-by: Kazantsev, Roman <[email protected]>

* Apply code-style formatting

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix remained code-style issue

Signed-off-by: Kazantsev, Roman <[email protected]>

* Run tests for OpenVINO backend in GHA

Signed-off-by: Kazantsev, Roman <[email protected]>

* Add config file for openvino backend validation

Signed-off-by: Kazantsev, Roman <[email protected]>

* Add import test for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix error in import_test.py

Signed-off-by: Kazantsev, Roman <[email protected]>

* Add import_test for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Add openvino specific integration tests in GHA

Signed-off-by: Kazantsev, Roman <[email protected]>

* Exclude coverage for OpenVINO

Signed-off-by: Kazantsev, Roman <[email protected]>

* remove coverage for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Try layer tests for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Run layer tests for openvino backend selectively

Signed-off-by: Kazantsev, Roman <[email protected]>

* Mark enabled tests for openvino backend in a different way

Signed-off-by: Kazantsev, Roman <[email protected]>

* Update .github/workflows/actions.yml

* Fix import for BackendVariable

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix errors in layer tests for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Add test for Elu via openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix sorted imports

Signed-off-by: Kazantsev, Roman <[email protected]>

* Extend testing for attention

Signed-off-by: Kazantsev, Roman <[email protected]>

* Update keras/src/layers/attention/attention_test.py

* Switch on activation tests for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Switch on attention tests for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Update keras/src/layers/attention/additive_attention_test.py

* Update keras/src/layers/attention/grouped_query_attention_test.py

* Run conv tests for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix convolution in openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Work around constant creation for tuple

Signed-off-by: Kazantsev, Roman <[email protected]>

* Work around constant creation in reshape

Signed-off-by: Kazantsev, Roman <[email protected]>

* Run depthwise conv tests for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix get_ov_output for other x types

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix elu translation

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix softmax and log_softmax for None axis

Signed-off-by: Kazantsev, Roman <[email protected]>

* Run nn tests for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix numpy operations for axis to be None

Signed-off-by: Kazantsev, Roman <[email protected]>

* Run operation_test for openvino_backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Switch on math_test for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Switch on image tests for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Switch on linalg test for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Extend OpenVINOKerasTensor with new built-in methods and fix shape op

Signed-off-by: Kazantsev, Roman <[email protected]>

* Switch on core tests for openvino backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Use different way of OpenVINO model creation that supports call method

Signed-off-by: Kazantsev, Roman <[email protected]>

* Unify integration test for openvino

Signed-off-by: Kazantsev, Roman <[email protected]>

* Support new operations abs, mod, etc.

Signed-off-by: Kazantsev, Roman <[email protected]>

* Add support for more operations like squeeze, max

Signed-off-by: Kazantsev, Roman <[email protected]>

* Try to use excluded test files list

Signed-off-by: Kazantsev, Roman <[email protected]>

* Apply formatting for normalization_test.py

Signed-off-by: Kazantsev, Roman <[email protected]>

* Correct GHA yml file

Signed-off-by: Kazantsev, Roman <[email protected]>

* Test that openvino backend is used

Signed-off-by: Kazantsev, Roman <[email protected]>

* Revert testing change in excluded test files list

Signed-off-by: Kazantsev, Roman <[email protected]>

* Include testing group

Signed-off-by: Kazantsev, Roman <[email protected]>

* Include legacy test group

Signed-off-by: Kazantsev, Roman <[email protected]>

* Exclude legacy group of tests

Signed-off-by: Kazantsev, Roman <[email protected]>

* Include initializers tests

Signed-off-by: Kazantsev, Roman <[email protected]>

* Skip tests for initializers group

Signed-off-by: Kazantsev, Roman <[email protected]>

* Remove export test group from ignore

Signed-off-by: Kazantsev, Roman <[email protected]>

* Include dtype_policies test group

Signed-off-by: Kazantsev, Roman <[email protected]>

* Reduce ignored tests

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix ops.cast

Signed-off-by: Kazantsev, Roman <[email protected]>

* Add decorator for custom_gradient

Signed-off-by: Kazantsev, Roman <[email protected]>

* Shorten line in custom_gradient

Signed-off-by: Kazantsev, Roman <[email protected]>

* Ignore dtype_policy_map test

Signed-off-by: Kazantsev, Roman <[email protected]>

* Include callback tests

Signed-off-by: Kazantsev, Roman <[email protected]>

* Switch on backend tests

Signed-off-by: Kazantsev, Roman <[email protected]>

* Exclude failing tests

Signed-off-by: Kazantsev, Roman <[email protected]>

* Correct paths to excluded tests

Signed-off-by: Kazantsev, Roman <[email protected]>

* Switch on some layers tests

Signed-off-by: Kazantsev, Roman <[email protected]>

* Remove pytest.mark.openvino_backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Register mark requires_trainable_backend

Signed-off-by: Kazantsev, Roman <[email protected]>

* Ignore test files in a different way

Signed-off-by: Kazantsev, Roman <[email protected]>

* Try different way to ignore test files

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix GHA yml

Signed-off-by: Kazantsev, Roman <[email protected]>

* Support tuple axis for logsumexp

Signed-off-by: Kazantsev, Roman <[email protected]>

* Switch on some ops tests

Signed-off-by: Kazantsev, Roman <[email protected]>

* Switch on some callbacks tests

Signed-off-by: Kazantsev, Roman <[email protected]>

* Add openvino export

Signed-off-by: Kazantsev, Roman <[email protected]>

* Update sklearn tests

Signed-off-by: Kazantsev, Roman <[email protected]>

* Add a comment to skipp numerical_test

Signed-off-by: Kazantsev, Roman <[email protected]>

* Add custom requirements file for OpenVINO

Signed-off-by: Kazantsev, Roman <[email protected]>

* Add reqs of openvino installation for api changes check

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix types of Variables and switch on some variables tests

Signed-off-by: Kazantsev, Roman <[email protected]>

* Fix nightly code check

Signed-off-by: Kazantsev, Roman <[email protected]>

---------

Signed-off-by: Kazantsev, Roman <[email protected]>
  • Loading branch information
rkazants authored Dec 18, 2024
1 parent a0f8922 commit bce0f5b
Show file tree
Hide file tree
Showing 34 changed files with 2,974 additions and 11 deletions.
18 changes: 15 additions & 3 deletions .github/workflows/actions.yml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ jobs:
fail-fast: false
matrix:
python-version: [3.9]
backend: [tensorflow, jax, torch, numpy]
backend: [tensorflow, jax, torch, numpy, openvino]
name: Run tests
runs-on: ubuntu-latest
env:
Expand Down Expand Up @@ -47,7 +47,12 @@ jobs:
key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
- name: Install dependencies
run: |
pip install -r requirements.txt --progress-bar off --upgrade
if [ "${{ matrix.backend }}" == "openvino" ]; then
REQUIREMENTS_FILE="requirements-openvino.txt"
else
REQUIREMENTS_FILE="requirements.txt"
fi
pip install -r $REQUIREMENTS_FILE --progress-bar off --upgrade
pip uninstall -y keras keras-nightly
pip install tf_keras==2.16.0 --progress-bar off --upgrade
pip install -e "." --progress-bar off --upgrade
Expand Down Expand Up @@ -86,7 +91,13 @@ jobs:
python integration_tests/torch_custom_fit_test.py
- name: Test with pytest
run: |
pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml
if [ "${{ matrix.backend }}" == "openvino" ]; then
IGNORE_FILE="keras/src/backend/openvino/excluded_tests.txt"
IGNORE_ARGS=$(awk '{print "--ignore=" $0}' "$IGNORE_FILE")
else
IGNORE_ARGS=""
fi
pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml $IGNORE_ARGS
coverage xml --omit='keras/src/applications/*,keras/api' -o core-coverage.xml
- name: Codecov keras
uses: codecov/codecov-action@v5
Expand Down Expand Up @@ -119,6 +130,7 @@ jobs:
- name: Install dependencies
run: |
pip install -r requirements.txt --progress-bar off --upgrade
pip install -r requirements-openvino.txt --progress-bar off --upgrade
pip uninstall -y keras keras-nightly
pip install -e "." --progress-bar off --upgrade
- name: Lint
Expand Down
6 changes: 6 additions & 0 deletions .github/workflows/config/openvino/keras.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
{
"floatx": "float32",
"epsilon": 1e-07,
"backend": "openvino",
"image_data_format": "channels_last"
}
1 change: 1 addition & 0 deletions .github/workflows/nightly.yml
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@ jobs:
- name: Install dependencies
run: |
pip install -r requirements.txt --progress-bar off --upgrade
pip install -r requirements-openvino.txt --progress-bar off --upgrade
pip uninstall -y keras keras-nightly
pip install -e "." --progress-bar off --upgrade
- name: Lint
Expand Down
8 changes: 6 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,13 @@ def pytest_configure(config):

def pytest_collection_modifyitems(config, items):
requires_trainable_backend = pytest.mark.skipif(
backend() == "numpy",
reason="Trainer not implemented for NumPy backend.",
backend() == "numpy" or backend() == "openvino",
reason="Trainer not implemented for NumPy and OpenVINO backend.",
)
for item in items:
if "requires_trainable_backend" in item.keywords:
item.add_marker(requires_trainable_backend)


def skip_if_backend(given_backend, reason):
return pytest.mark.skipif(backend() == given_backend, reason=reason)
8 changes: 7 additions & 1 deletion integration_tests/basic_full_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ def call(self, x):
return self.dense3(x)


@pytest.mark.requires_trainable_backend
class BasicFlowTest(testing.TestCase):
@pytest.mark.requires_trainable_backend
def test_basic_fit(self):
model = MyModel(hidden_dim=2, output_dim=1)

Expand All @@ -46,3 +46,9 @@ def test_basic_fit(self):
output_after_fit = model(x)

self.assertNotAllClose(output_before_fit, output_after_fit)

def test_basic_fit_no_training(self):
model = MyModel(hidden_dim=2, output_dim=1)
x = np.random.random((128, 4))
model.predict(x)
model(x)
5 changes: 4 additions & 1 deletion integration_tests/import_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
"--extra-index-url https://download.pytorch.org/whl/cpu ",
),
"jax": ("jax[cpu]", ""),
"openvino": ("openvino", ""),
}


Expand Down Expand Up @@ -57,7 +58,9 @@ def manage_venv_installs(whl_path):
"pip uninstall -y "
+ BACKEND_REQ[other_backends[0]][0]
+ " "
+ BACKEND_REQ[other_backends[1]][0],
+ BACKEND_REQ[other_backends[1]][0]
+ " "
+ BACKEND_REQ[other_backends[2]][0],
# Install `.whl` package
"pip install " + whl_path,
]
Expand Down
5 changes: 5 additions & 0 deletions integration_tests/numerical_test.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import keras # isort: skip, keep it on top for torch test

import sys

import numpy as np
import tf_keras

Expand Down Expand Up @@ -137,6 +139,9 @@ def numerical_test():


if __name__ == "__main__":
if keras.backend.backend() == "openvino":
# this test requires trainable backend
sys.exit(0)
keras.utils.set_random_seed(1337)
tf_keras.utils.set_random_seed(1337)
numerical_test()
5 changes: 5 additions & 0 deletions keras/src/backend/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@
from keras.src.backend.numpy import * # noqa: F403
from keras.src.backend.numpy.core import Variable as BackendVariable

distribution_lib = None
elif backend() == "openvino":
from keras.src.backend.openvino import * # noqa: F403
from keras.src.backend.openvino.core import Variable as BackendVariable

distribution_lib = None
else:
raise ValueError(f"Unable to import backend : {backend()}")
Expand Down
6 changes: 6 additions & 0 deletions keras/src/backend/common/dtypes_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,12 @@ class DtypesTest(test_case.TestCase):
if x not in ALL_DTYPES: # skip duplicates created by remapping
ALL_DTYPES.append(x)
ALL_DTYPES += [None]
elif backend.backend() == "openvino":
ALL_DTYPES = [
x
for x in dtypes.ALLOWED_DTYPES
if x not in ["string", "complex64", "complex128"]
] + [None]
else:
ALL_DTYPES = [x for x in dtypes.ALLOWED_DTYPES if x != "string"] + [
None
Expand Down
25 changes: 25 additions & 0 deletions keras/src/backend/common/variables_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import pytest
from absl.testing import parameterized

from conftest import skip_if_backend
from keras.src import backend
from keras.src import initializers
from keras.src import ops
Expand Down Expand Up @@ -143,6 +144,9 @@ def test_variable_without_shape_from_callable_initializer(self):
class VariablePropertiesTest(test_case.TestCase):
"""Tests for Variable._deferred_initialize Variable._maybe_autocast"""

@skip_if_backend(
"openvino", "Can not constant fold eltwise node by CPU plugin"
)
def test_deferred_assignment(self):
"""Tests deferred assignment to variables."""
with backend.StatelessScope() as scope:
Expand Down Expand Up @@ -246,6 +250,12 @@ def test_standardize_dtype(self, dtype):
f"jax backend does not support {dtype} without x64 enabled"
)

if backend.backend() == "openvino" and dtype in (
"complex64",
"complex128",
):
self.skipTest(f"openvino backend does not support dtype {dtype}")

x = backend.convert_to_tensor(np.zeros(()), dtype)
actual = standardize_dtype(x.dtype)
self.assertEqual(actual, dtype)
Expand Down Expand Up @@ -603,12 +613,18 @@ def test__rtruediv__(self):
v2 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))
self.assertAllClose(v1.__rtruediv__(v2), np.array([0.25, 0.4, 0.5]))

@skip_if_backend(
"openvino", "`floor_divide` is not supported with openvino backend"
)
def test__floordiv__(self):
"""Test floordiv operation on a variable."""
v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))
v2 = backend.Variable(initializer=np.array([-4.0, 5.0, 6.0]))
self.assertAllClose(v1.__floordiv__(v2), np.array([-1.0, 0.0, 0.0]))

@skip_if_backend(
"openvino", "`floor_divide` is not supported with openvino backend"
)
def test__rfloordiv__(self):
"""Test reverse floordiv operation on a variable."""
v1 = backend.Variable(initializer=np.array([-4.0, 5.0, 6.0]))
Expand Down Expand Up @@ -734,6 +750,9 @@ def test_variable_rpow(self):
result = v2**v1
self.assertAllClose(result, np.array([4.0, 25.0, 216.0]))

@skip_if_backend(
"openvino", "`round` is not supported with openvino backend"
)
def test_round(self):
v = backend.Variable(initializer=np.array([1.1, 2.2, 3.3]))
self.assertAllClose(round(v), np.array([1.0, 2.0, 3.0]))
Expand Down Expand Up @@ -783,6 +802,9 @@ def test_invalid_float(self):
INT_DTYPES = [
x for x in INT_DTYPES if x not in ["uint16", "uint32", "uint64"]
]
elif backend.backend() == "openvino":
# TODO: openvino doesn't support complex
ALL_DTYPES = [x for x in ALL_DTYPES if x not in ["complex128", "complex64"]]
# Remove float8 dtypes for the following tests
ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]
NON_COMPLEX_DTYPES = [x for x in ALL_DTYPES if x and x not in COMPLEX_DTYPES]
Expand Down Expand Up @@ -976,6 +998,9 @@ def test_truediv(self, dtypes):
@parameterized.named_parameters(
named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2))
)
@skip_if_backend(
"openvino", "`floor_divide` is not supported with openvino backend"
)
def test_floordiv(self, dtypes):
import jax.numpy as jnp

Expand Down
24 changes: 24 additions & 0 deletions keras/src/backend/openvino/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from keras.src.backend.common.name_scope import name_scope
from keras.src.backend.openvino import core
from keras.src.backend.openvino import image
from keras.src.backend.openvino import linalg
from keras.src.backend.openvino import math
from keras.src.backend.openvino import nn
from keras.src.backend.openvino import numpy
from keras.src.backend.openvino import random
from keras.src.backend.openvino.core import IS_THREAD_SAFE
from keras.src.backend.openvino.core import SUPPORTS_SPARSE_TENSORS
from keras.src.backend.openvino.core import Variable
from keras.src.backend.openvino.core import cast
from keras.src.backend.openvino.core import compute_output_spec
from keras.src.backend.openvino.core import cond
from keras.src.backend.openvino.core import convert_to_numpy
from keras.src.backend.openvino.core import convert_to_tensor
from keras.src.backend.openvino.core import is_tensor
from keras.src.backend.openvino.core import random_seed_dtype
from keras.src.backend.openvino.core import shape
from keras.src.backend.openvino.core import vectorized_map
from keras.src.backend.openvino.rnn import cudnn_ok
from keras.src.backend.openvino.rnn import gru
from keras.src.backend.openvino.rnn import lstm
from keras.src.backend.openvino.rnn import rnn
Loading

0 comments on commit bce0f5b

Please sign in to comment.