Skip to content

Commit bce0f5b

Browse files
authored
Add preliminary support of OpenVINO as Keras 3 backend (#19727)
* [POC][OV] Support OpenVINO as Keras 3 backend Signed-off-by: Kazantsev, Roman <[email protected]> * Mark all unsupported ops from numpy space Signed-off-by: Kazantsev, Roman <[email protected]> * Mark unsupported ops in core, image, and linalg spaces Signed-off-by: Kazantsev, Roman <[email protected]> * Mark unsupported ops in math, nn, random, and rnn spaces Signed-off-by: Kazantsev, Roman <[email protected]> * Fix sorting imports Signed-off-by: Kazantsev, Roman <[email protected]> * Format imports Signed-off-by: Kazantsev, Roman <[email protected]> * Fix sorting imports Signed-off-by: Kazantsev, Roman <[email protected]> * Fix sorting imports Signed-off-by: Kazantsev, Roman <[email protected]> * Fix inference Signed-off-by: Kazantsev, Roman <[email protected]> * Remove openvino specific code in common part Signed-off-by: Kazantsev, Roman <[email protected]> * Fix typo * Clean-up code Signed-off-by: Kazantsev, Roman <[email protected]> * Recover imports Signed-off-by: Kazantsev, Roman <[email protected]> * Sort imports properly Signed-off-by: Kazantsev, Roman <[email protected]> * Format source code Signed-off-by: Kazantsev, Roman <[email protected]> * Format the rest of source code Signed-off-by: Kazantsev, Roman <[email protected]> * Continue format adjustment Signed-off-by: Kazantsev, Roman <[email protected]> * Add OpenVINO dependency Signed-off-by: Kazantsev, Roman <[email protected]> * Fix inference using OV backend Signed-off-by: Kazantsev, Roman <[email protected]> * Support bert_base_en_uncased and mobilenet_v3_small from Keras Hub Signed-off-by: Kazantsev, Roman <[email protected]> * Remove extra openvino specific code from layer.py Signed-off-by: Kazantsev, Roman <[email protected]> * Apply code-style formatting Signed-off-by: Kazantsev, Roman <[email protected]> * Apply code-style formatting Signed-off-by: Kazantsev, Roman <[email protected]> * Fix remained code-style issue Signed-off-by: Kazantsev, Roman <[email protected]> * Run tests for OpenVINO backend in GHA Signed-off-by: Kazantsev, Roman <[email protected]> * Add config file for openvino backend validation Signed-off-by: Kazantsev, Roman <[email protected]> * Add import test for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Fix error in import_test.py Signed-off-by: Kazantsev, Roman <[email protected]> * Add import_test for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Add openvino specific integration tests in GHA Signed-off-by: Kazantsev, Roman <[email protected]> * Exclude coverage for OpenVINO Signed-off-by: Kazantsev, Roman <[email protected]> * remove coverage for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Try layer tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Run layer tests for openvino backend selectively Signed-off-by: Kazantsev, Roman <[email protected]> * Mark enabled tests for openvino backend in a different way Signed-off-by: Kazantsev, Roman <[email protected]> * Update .github/workflows/actions.yml * Fix import for BackendVariable Signed-off-by: Kazantsev, Roman <[email protected]> * Fix errors in layer tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Add test for Elu via openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Fix sorted imports Signed-off-by: Kazantsev, Roman <[email protected]> * Extend testing for attention Signed-off-by: Kazantsev, Roman <[email protected]> * Update keras/src/layers/attention/attention_test.py * Switch on activation tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Switch on attention tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Update keras/src/layers/attention/additive_attention_test.py * Update keras/src/layers/attention/grouped_query_attention_test.py * Run conv tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Fix convolution in openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Work around constant creation for tuple Signed-off-by: Kazantsev, Roman <[email protected]> * Work around constant creation in reshape Signed-off-by: Kazantsev, Roman <[email protected]> * Run depthwise conv tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Fix get_ov_output for other x types Signed-off-by: Kazantsev, Roman <[email protected]> * Fix elu translation Signed-off-by: Kazantsev, Roman <[email protected]> * Fix softmax and log_softmax for None axis Signed-off-by: Kazantsev, Roman <[email protected]> * Run nn tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Fix numpy operations for axis to be None Signed-off-by: Kazantsev, Roman <[email protected]> * Run operation_test for openvino_backend Signed-off-by: Kazantsev, Roman <[email protected]> * Switch on math_test for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Switch on image tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Switch on linalg test for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Extend OpenVINOKerasTensor with new built-in methods and fix shape op Signed-off-by: Kazantsev, Roman <[email protected]> * Switch on core tests for openvino backend Signed-off-by: Kazantsev, Roman <[email protected]> * Use different way of OpenVINO model creation that supports call method Signed-off-by: Kazantsev, Roman <[email protected]> * Unify integration test for openvino Signed-off-by: Kazantsev, Roman <[email protected]> * Support new operations abs, mod, etc. Signed-off-by: Kazantsev, Roman <[email protected]> * Add support for more operations like squeeze, max Signed-off-by: Kazantsev, Roman <[email protected]> * Try to use excluded test files list Signed-off-by: Kazantsev, Roman <[email protected]> * Apply formatting for normalization_test.py Signed-off-by: Kazantsev, Roman <[email protected]> * Correct GHA yml file Signed-off-by: Kazantsev, Roman <[email protected]> * Test that openvino backend is used Signed-off-by: Kazantsev, Roman <[email protected]> * Revert testing change in excluded test files list Signed-off-by: Kazantsev, Roman <[email protected]> * Include testing group Signed-off-by: Kazantsev, Roman <[email protected]> * Include legacy test group Signed-off-by: Kazantsev, Roman <[email protected]> * Exclude legacy group of tests Signed-off-by: Kazantsev, Roman <[email protected]> * Include initializers tests Signed-off-by: Kazantsev, Roman <[email protected]> * Skip tests for initializers group Signed-off-by: Kazantsev, Roman <[email protected]> * Remove export test group from ignore Signed-off-by: Kazantsev, Roman <[email protected]> * Include dtype_policies test group Signed-off-by: Kazantsev, Roman <[email protected]> * Reduce ignored tests Signed-off-by: Kazantsev, Roman <[email protected]> * Fix ops.cast Signed-off-by: Kazantsev, Roman <[email protected]> * Add decorator for custom_gradient Signed-off-by: Kazantsev, Roman <[email protected]> * Shorten line in custom_gradient Signed-off-by: Kazantsev, Roman <[email protected]> * Ignore dtype_policy_map test Signed-off-by: Kazantsev, Roman <[email protected]> * Include callback tests Signed-off-by: Kazantsev, Roman <[email protected]> * Switch on backend tests Signed-off-by: Kazantsev, Roman <[email protected]> * Exclude failing tests Signed-off-by: Kazantsev, Roman <[email protected]> * Correct paths to excluded tests Signed-off-by: Kazantsev, Roman <[email protected]> * Switch on some layers tests Signed-off-by: Kazantsev, Roman <[email protected]> * Remove pytest.mark.openvino_backend Signed-off-by: Kazantsev, Roman <[email protected]> * Register mark requires_trainable_backend Signed-off-by: Kazantsev, Roman <[email protected]> * Ignore test files in a different way Signed-off-by: Kazantsev, Roman <[email protected]> * Try different way to ignore test files Signed-off-by: Kazantsev, Roman <[email protected]> * Fix GHA yml Signed-off-by: Kazantsev, Roman <[email protected]> * Support tuple axis for logsumexp Signed-off-by: Kazantsev, Roman <[email protected]> * Switch on some ops tests Signed-off-by: Kazantsev, Roman <[email protected]> * Switch on some callbacks tests Signed-off-by: Kazantsev, Roman <[email protected]> * Add openvino export Signed-off-by: Kazantsev, Roman <[email protected]> * Update sklearn tests Signed-off-by: Kazantsev, Roman <[email protected]> * Add a comment to skipp numerical_test Signed-off-by: Kazantsev, Roman <[email protected]> * Add custom requirements file for OpenVINO Signed-off-by: Kazantsev, Roman <[email protected]> * Add reqs of openvino installation for api changes check Signed-off-by: Kazantsev, Roman <[email protected]> * Fix types of Variables and switch on some variables tests Signed-off-by: Kazantsev, Roman <[email protected]> * Fix nightly code check Signed-off-by: Kazantsev, Roman <[email protected]> --------- Signed-off-by: Kazantsev, Roman <[email protected]>
1 parent a0f8922 commit bce0f5b

34 files changed

+2974
-11
lines changed

Diff for: .github/workflows/actions.yml

+15-3
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ jobs:
1616
fail-fast: false
1717
matrix:
1818
python-version: [3.9]
19-
backend: [tensorflow, jax, torch, numpy]
19+
backend: [tensorflow, jax, torch, numpy, openvino]
2020
name: Run tests
2121
runs-on: ubuntu-latest
2222
env:
@@ -47,7 +47,12 @@ jobs:
4747
key: ${{ runner.os }}-pip-${{ hashFiles('pyproject.toml') }}-${{ hashFiles('requirements.txt') }}
4848
- name: Install dependencies
4949
run: |
50-
pip install -r requirements.txt --progress-bar off --upgrade
50+
if [ "${{ matrix.backend }}" == "openvino" ]; then
51+
REQUIREMENTS_FILE="requirements-openvino.txt"
52+
else
53+
REQUIREMENTS_FILE="requirements.txt"
54+
fi
55+
pip install -r $REQUIREMENTS_FILE --progress-bar off --upgrade
5156
pip uninstall -y keras keras-nightly
5257
pip install tf_keras==2.16.0 --progress-bar off --upgrade
5358
pip install -e "." --progress-bar off --upgrade
@@ -86,7 +91,13 @@ jobs:
8691
python integration_tests/torch_custom_fit_test.py
8792
- name: Test with pytest
8893
run: |
89-
pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml
94+
if [ "${{ matrix.backend }}" == "openvino" ]; then
95+
IGNORE_FILE="keras/src/backend/openvino/excluded_tests.txt"
96+
IGNORE_ARGS=$(awk '{print "--ignore=" $0}' "$IGNORE_FILE")
97+
else
98+
IGNORE_ARGS=""
99+
fi
100+
pytest keras --ignore keras/src/applications --cov=keras --cov-config=pyproject.toml $IGNORE_ARGS
90101
coverage xml --omit='keras/src/applications/*,keras/api' -o core-coverage.xml
91102
- name: Codecov keras
92103
uses: codecov/codecov-action@v5
@@ -119,6 +130,7 @@ jobs:
119130
- name: Install dependencies
120131
run: |
121132
pip install -r requirements.txt --progress-bar off --upgrade
133+
pip install -r requirements-openvino.txt --progress-bar off --upgrade
122134
pip uninstall -y keras keras-nightly
123135
pip install -e "." --progress-bar off --upgrade
124136
- name: Lint

Diff for: .github/workflows/config/openvino/keras.json

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
{
2+
"floatx": "float32",
3+
"epsilon": 1e-07,
4+
"backend": "openvino",
5+
"image_data_format": "channels_last"
6+
}

Diff for: .github/workflows/nightly.yml

+1
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,7 @@ jobs:
7979
- name: Install dependencies
8080
run: |
8181
pip install -r requirements.txt --progress-bar off --upgrade
82+
pip install -r requirements-openvino.txt --progress-bar off --upgrade
8283
pip uninstall -y keras keras-nightly
8384
pip install -e "." --progress-bar off --upgrade
8485
- name: Lint

Diff for: conftest.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,13 @@ def pytest_configure(config):
2626

2727
def pytest_collection_modifyitems(config, items):
2828
requires_trainable_backend = pytest.mark.skipif(
29-
backend() == "numpy",
30-
reason="Trainer not implemented for NumPy backend.",
29+
backend() == "numpy" or backend() == "openvino",
30+
reason="Trainer not implemented for NumPy and OpenVINO backend.",
3131
)
3232
for item in items:
3333
if "requires_trainable_backend" in item.keywords:
3434
item.add_marker(requires_trainable_backend)
35+
36+
37+
def skip_if_backend(given_backend, reason):
38+
return pytest.mark.skipif(backend() == given_backend, reason=reason)

Diff for: integration_tests/basic_full_flow.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -24,8 +24,8 @@ def call(self, x):
2424
return self.dense3(x)
2525

2626

27-
@pytest.mark.requires_trainable_backend
2827
class BasicFlowTest(testing.TestCase):
28+
@pytest.mark.requires_trainable_backend
2929
def test_basic_fit(self):
3030
model = MyModel(hidden_dim=2, output_dim=1)
3131

@@ -46,3 +46,9 @@ def test_basic_fit(self):
4646
output_after_fit = model(x)
4747

4848
self.assertNotAllClose(output_before_fit, output_after_fit)
49+
50+
def test_basic_fit_no_training(self):
51+
model = MyModel(hidden_dim=2, output_dim=1)
52+
x = np.random.random((128, 4))
53+
model.predict(x)
54+
model(x)

Diff for: integration_tests/import_test.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"--extra-index-url https://download.pytorch.org/whl/cpu ",
1313
),
1414
"jax": ("jax[cpu]", ""),
15+
"openvino": ("openvino", ""),
1516
}
1617

1718

@@ -57,7 +58,9 @@ def manage_venv_installs(whl_path):
5758
"pip uninstall -y "
5859
+ BACKEND_REQ[other_backends[0]][0]
5960
+ " "
60-
+ BACKEND_REQ[other_backends[1]][0],
61+
+ BACKEND_REQ[other_backends[1]][0]
62+
+ " "
63+
+ BACKEND_REQ[other_backends[2]][0],
6164
# Install `.whl` package
6265
"pip install " + whl_path,
6366
]

Diff for: integration_tests/numerical_test.py

+5
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
import keras # isort: skip, keep it on top for torch test
22

3+
import sys
4+
35
import numpy as np
46
import tf_keras
57

@@ -137,6 +139,9 @@ def numerical_test():
137139

138140

139141
if __name__ == "__main__":
142+
if keras.backend.backend() == "openvino":
143+
# this test requires trainable backend
144+
sys.exit(0)
140145
keras.utils.set_random_seed(1337)
141146
tf_keras.utils.set_random_seed(1337)
142147
numerical_test()

Diff for: keras/src/backend/__init__.py

+5
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,11 @@
4949
from keras.src.backend.numpy import * # noqa: F403
5050
from keras.src.backend.numpy.core import Variable as BackendVariable
5151

52+
distribution_lib = None
53+
elif backend() == "openvino":
54+
from keras.src.backend.openvino import * # noqa: F403
55+
from keras.src.backend.openvino.core import Variable as BackendVariable
56+
5257
distribution_lib = None
5358
else:
5459
raise ValueError(f"Unable to import backend : {backend()}")

Diff for: keras/src/backend/common/dtypes_test.py

+6
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,12 @@ class DtypesTest(test_case.TestCase):
2323
if x not in ALL_DTYPES: # skip duplicates created by remapping
2424
ALL_DTYPES.append(x)
2525
ALL_DTYPES += [None]
26+
elif backend.backend() == "openvino":
27+
ALL_DTYPES = [
28+
x
29+
for x in dtypes.ALLOWED_DTYPES
30+
if x not in ["string", "complex64", "complex128"]
31+
] + [None]
2632
else:
2733
ALL_DTYPES = [x for x in dtypes.ALLOWED_DTYPES if x != "string"] + [
2834
None

Diff for: keras/src/backend/common/variables_test.py

+25
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import pytest
55
from absl.testing import parameterized
66

7+
from conftest import skip_if_backend
78
from keras.src import backend
89
from keras.src import initializers
910
from keras.src import ops
@@ -143,6 +144,9 @@ def test_variable_without_shape_from_callable_initializer(self):
143144
class VariablePropertiesTest(test_case.TestCase):
144145
"""Tests for Variable._deferred_initialize Variable._maybe_autocast"""
145146

147+
@skip_if_backend(
148+
"openvino", "Can not constant fold eltwise node by CPU plugin"
149+
)
146150
def test_deferred_assignment(self):
147151
"""Tests deferred assignment to variables."""
148152
with backend.StatelessScope() as scope:
@@ -246,6 +250,12 @@ def test_standardize_dtype(self, dtype):
246250
f"jax backend does not support {dtype} without x64 enabled"
247251
)
248252

253+
if backend.backend() == "openvino" and dtype in (
254+
"complex64",
255+
"complex128",
256+
):
257+
self.skipTest(f"openvino backend does not support dtype {dtype}")
258+
249259
x = backend.convert_to_tensor(np.zeros(()), dtype)
250260
actual = standardize_dtype(x.dtype)
251261
self.assertEqual(actual, dtype)
@@ -603,12 +613,18 @@ def test__rtruediv__(self):
603613
v2 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))
604614
self.assertAllClose(v1.__rtruediv__(v2), np.array([0.25, 0.4, 0.5]))
605615

616+
@skip_if_backend(
617+
"openvino", "`floor_divide` is not supported with openvino backend"
618+
)
606619
def test__floordiv__(self):
607620
"""Test floordiv operation on a variable."""
608621
v1 = backend.Variable(initializer=np.array([1.0, 2.0, 3.0]))
609622
v2 = backend.Variable(initializer=np.array([-4.0, 5.0, 6.0]))
610623
self.assertAllClose(v1.__floordiv__(v2), np.array([-1.0, 0.0, 0.0]))
611624

625+
@skip_if_backend(
626+
"openvino", "`floor_divide` is not supported with openvino backend"
627+
)
612628
def test__rfloordiv__(self):
613629
"""Test reverse floordiv operation on a variable."""
614630
v1 = backend.Variable(initializer=np.array([-4.0, 5.0, 6.0]))
@@ -734,6 +750,9 @@ def test_variable_rpow(self):
734750
result = v2**v1
735751
self.assertAllClose(result, np.array([4.0, 25.0, 216.0]))
736752

753+
@skip_if_backend(
754+
"openvino", "`round` is not supported with openvino backend"
755+
)
737756
def test_round(self):
738757
v = backend.Variable(initializer=np.array([1.1, 2.2, 3.3]))
739758
self.assertAllClose(round(v), np.array([1.0, 2.0, 3.0]))
@@ -783,6 +802,9 @@ def test_invalid_float(self):
783802
INT_DTYPES = [
784803
x for x in INT_DTYPES if x not in ["uint16", "uint32", "uint64"]
785804
]
805+
elif backend.backend() == "openvino":
806+
# TODO: openvino doesn't support complex
807+
ALL_DTYPES = [x for x in ALL_DTYPES if x not in ["complex128", "complex64"]]
786808
# Remove float8 dtypes for the following tests
787809
ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]
788810
NON_COMPLEX_DTYPES = [x for x in ALL_DTYPES if x and x not in COMPLEX_DTYPES]
@@ -976,6 +998,9 @@ def test_truediv(self, dtypes):
976998
@parameterized.named_parameters(
977999
named_product(dtypes=itertools.combinations(NON_COMPLEX_DTYPES, 2))
9781000
)
1001+
@skip_if_backend(
1002+
"openvino", "`floor_divide` is not supported with openvino backend"
1003+
)
9791004
def test_floordiv(self, dtypes):
9801005
import jax.numpy as jnp
9811006

Diff for: keras/src/backend/openvino/__init__.py

+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
from keras.src.backend.common.name_scope import name_scope
2+
from keras.src.backend.openvino import core
3+
from keras.src.backend.openvino import image
4+
from keras.src.backend.openvino import linalg
5+
from keras.src.backend.openvino import math
6+
from keras.src.backend.openvino import nn
7+
from keras.src.backend.openvino import numpy
8+
from keras.src.backend.openvino import random
9+
from keras.src.backend.openvino.core import IS_THREAD_SAFE
10+
from keras.src.backend.openvino.core import SUPPORTS_SPARSE_TENSORS
11+
from keras.src.backend.openvino.core import Variable
12+
from keras.src.backend.openvino.core import cast
13+
from keras.src.backend.openvino.core import compute_output_spec
14+
from keras.src.backend.openvino.core import cond
15+
from keras.src.backend.openvino.core import convert_to_numpy
16+
from keras.src.backend.openvino.core import convert_to_tensor
17+
from keras.src.backend.openvino.core import is_tensor
18+
from keras.src.backend.openvino.core import random_seed_dtype
19+
from keras.src.backend.openvino.core import shape
20+
from keras.src.backend.openvino.core import vectorized_map
21+
from keras.src.backend.openvino.rnn import cudnn_ok
22+
from keras.src.backend.openvino.rnn import gru
23+
from keras.src.backend.openvino.rnn import lstm
24+
from keras.src.backend.openvino.rnn import rnn

0 commit comments

Comments
 (0)