Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
80e38da
WIP: median support for OpenVINO back-end
alungriffith Jun 10, 2025
98f982e
Finished median(), linted
alungriffith Jun 13, 2025
9fd5282
Added comments
alungriffith Jun 13, 2025
93797be
Fixed k_value not scalar
alungriffith Jun 13, 2025
97e3852
Fixed squeeze missing axis
alungriffith Jun 13, 2025
4922cee
Fixed missing output and dtype issues
alungriffith Jun 13, 2025
3f56a13
Fixed final median shape issue
alungriffith Jun 13, 2025
77f7859
Fix tuple convert to OpenVINO
alungriffith Jun 13, 2025
fbb727a
Fix missing gather axis arguements
alungriffith Jun 13, 2025
e66b943
Fix missing range dtype
alungriffith Jun 13, 2025
45ed5bf
Fix range start is scalar
alungriffith Jun 13, 2025
073a48a
NumpyDtypeTest::test_median reinserted
alungriffith Jun 13, 2025
2a926e9
Revert to state of 'Fix missing range dtype'
alungriffith Jun 13, 2025
916faed
2nd fix for range start with scalar
alungriffith Jun 13, 2025
69268d9
Fix missing rank scalar
alungriffith Jun 13, 2025
d92b60a
fixed and passed local testing. Submit median for PR
alungriffith Jun 18, 2025
77dc326
test
alungriffith Jul 13, 2025
4e74f2e
fixed x_flatten_rank calculation
alungriffith Jul 13, 2025
8786cfc
added comments following review
alungriffith Jul 13, 2025
04e72cf
comments added and variables renamed to improve clarity. med_1 calcul…
alungriffith Jul 14, 2025
0eeaa5e
added comment to median function
alungriffith Jul 14, 2025
8b0976d
Merge branch 'master' into median_support_ov
alungriffith Sep 6, 2025
2da2263
Changed median comments to docstring.
alungriffith Sep 6, 2025
282dadb
Solved the code formatting error.
alungriffith Sep 6, 2025
25df359
Addressed code review comments on openvino numpy median function.
alungriffith Sep 11, 2025
da1bf1e
Merge remote-tracking branch 'upstream/master' into median_support_ov
alungriffith Sep 11, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions keras/src/backend/openvino/excluded_concrete_tests.txt
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ NumpyDtypeTest::test_logspace
NumpyDtypeTest::test_matmul_
NumpyDtypeTest::test_max
NumpyDtypeTest::test_mean
NumpyDtypeTest::test_median
NumpyDtypeTest::test_minimum_python_types
NumpyDtypeTest::test_multiply
NumpyDtypeTest::test_power
Expand Down Expand Up @@ -99,7 +98,6 @@ NumpyOneInputOpsCorrectnessTest::test_isneginf
NumpyOneInputOpsCorrectnessTest::test_isposinf
NumpyOneInputOpsCorrectnessTest::test_max
NumpyOneInputOpsCorrectnessTest::test_mean
NumpyOneInputOpsCorrectnessTest::test_median
NumpyOneInputOpsCorrectnessTest::test_pad_float16_constant_2
NumpyOneInputOpsCorrectnessTest::test_pad_float32_constant_2
NumpyOneInputOpsCorrectnessTest::test_pad_float64_constant_2
Expand Down
166 changes: 165 additions & 1 deletion keras/src/backend/openvino/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,171 @@ def maximum(x1, x2):


def median(x, axis=None, keepdims=False):
raise NotImplementedError("`median` is not supported with openvino backend")
"""
The median algorithm follows numpy's method;
if axis is None, flatten all dimensions of the array and find the
median value.
if axis is single int or list/tuple of multiple values, re-order x array
to move those axis dims to the right, flatten the multiple axis dims
then calculate median values along the flattened axis.
"""
if np.isscalar(x):
x = get_ov_output(x)
return OpenVINOKerasTensor(x)

x = get_ov_output(x)
x_type = x.get_element_type()
if x_type == Type.boolean or x_type.is_integral():
x_type = OPENVINO_DTYPES[config.floatx()]
x = ov_opset.convert(x, x_type).output(0)

x_shape_original = ov_opset.shape_of(x, Type.i32).output(0)
x_rank_original = ov_opset.shape_of(x_shape_original, Type.i32).output(0)
x_rank_original_scalar = ov_opset.squeeze(
x_rank_original, ov_opset.constant(0, Type.i32).output(0)
).output(0)

if axis is None:
flatten_shape = ov_opset.constant([-1], Type.i32).output(0)
x = ov_opset.reshape(x, flatten_shape, False).output(0)
flattened = True
else:
# move axis dims to the rightmost positions.
flattened = False
if isinstance(axis, int):
axis = [axis]
if isinstance(axis, (tuple, list)):
axis = list(axis)
ov_axis = ov_opset.constant(axis, Type.i32).output(0)
# normalise any negative axes to their positive equivalents by gathering
# the indices from axis range.
axis_as_range = ov_opset.range(
ov_opset.constant(0, Type.i32).output(0),
x_rank_original_scalar,
ov_opset.constant(1, Type.i32).output(0),
Type.i32,
).output(0)
flatten_axes = ov_opset.gather(
axis_as_range, ov_axis, ov_opset.constant([0], Type.i32)
).output(0)

# right (flatten) axis dims are defined,
# now define the left (remaining) axis dims.

# to find remaining axes, use not_equal comparison between flatten_axes
# and axis_as_range.
# reshape axis_as_range to suit not_equal broadcasting rules for
# comparison.
axis_comparison_shape = ov_opset.concat(
[
ov_opset.shape_of(flatten_axes, Type.i32).output(0),
ov_opset.shape_of(axis_as_range, Type.i32).output(0),
],
0,
).output(0)
reshaped_axis_range = ov_opset.broadcast(
axis_as_range, axis_comparison_shape
).output(0)
axis_compare = ov_opset.not_equal(
reshaped_axis_range,
ov_opset.unsqueeze(
flatten_axes, ov_opset.constant(1, Type.i32).output(0)
).output(0),
).output(0)
axis_compare = ov_opset.reduce_logical_and(
axis_compare, ov_opset.constant(0, Type.i32).output(0)
).output(0)
nz = ov_opset.non_zero(axis_compare, Type.i32).output(0)
nz = ov_opset.squeeze(
nz, ov_opset.constant(0, Type.i32).output(0)
).output(0)
remaining_axes = ov_opset.gather(
axis_as_range, nz, ov_opset.constant(0, Type.i32).output(0)
).output(0)
# concat to place flatten axes on the right and remaining axes on the
# left.
reordered_axes = ov_opset.concat(
[remaining_axes, flatten_axes], 0
).output(0)
x_transposed = ov_opset.transpose(x, reordered_axes).output(0)

# flatten the axis dims if more than 1 axis in input.
if len(axis) > 1:
x_flatten_rank = ov_opset.subtract(
x_rank_original,
ov_opset.constant([len(axis) - 1], Type.i32).output(0),
).output(0)
# create flatten shape of 0's (keep axes)
# and -1 at the end (flattened axis)
x_flatten_shape = ov_opset.broadcast(
ov_opset.constant([0], Type.i32).output(0), x_flatten_rank
).output(0)
x_flatten_shape = ov_opset.scatter_elements_update(
x_flatten_shape,
ov_opset.constant([-1], Type.i32).output(0),
ov_opset.constant([-1], Type.i32).output(0),
0,
"sum",
).output(0)

x_transposed = ov_opset.reshape(
x_transposed, x_flatten_shape, True
).output(0)

x = x_transposed

k_value = ov_opset.gather(
ov_opset.shape_of(x, Type.i32).output(0),
ov_opset.constant(-1, Type.i32).output(0),
ov_opset.constant(0, Type.i32).output(0),
).output(0)

x_sorted = ov_opset.topk(
x, k_value, -1, "min", "value", stable=True
).output(0)

half_index = ov_opset.floor(
ov_opset.divide(k_value, ov_opset.constant(2, Type.i32)).output(0)
).output(0)

# for odd length dimension, select the middle value as median.
# for even length dimension, calculate the mean between the 2 middle values.
x_mod = ov_opset.mod(k_value, ov_opset.constant(2, Type.i32)).output(0)
is_even = ov_opset.equal(x_mod, ov_opset.constant(0, Type.i32)).output(0)

med_0 = ov_opset.gather(
x_sorted, half_index, ov_opset.constant(-1, Type.i32).output(0)
).output(0)
med_1 = ov_opset.gather(
x_sorted,
ov_opset.subtract(half_index, ov_opset.constant(1, Type.i32)).output(0),
ov_opset.constant(-1, Type.i32).output(0),
).output(0)

median_odd = med_0
median_even = ov_opset.divide(
ov_opset.add(med_1, med_0).output(0),
ov_opset.constant(2, x_type),
).output(0)

median_eval = ov_opset.select(is_even, median_even, median_odd).output(0)

if keepdims:
# reshape median_eval to original rank of x.
if flattened:
# create a tensor of ones for reshape, the original rank of x.
median_shape = ov_opset.divide(
x_shape_original, x_shape_original, "none"
).output(0)
median_eval = ov_opset.reshape(
median_eval, median_shape, False
).output(0)
else:
median_eval = ov_opset.unsqueeze(median_eval, flatten_axes).output(
0
)

return OpenVINOKerasTensor(median_eval)


def meshgrid(*x, indexing="xy"):
Expand Down