diff --git a/.github/workflows/main.yaml b/.github/workflows/main.yaml index 1618bc4..9d3fe1b 100644 --- a/.github/workflows/main.yaml +++ b/.github/workflows/main.yaml @@ -33,3 +33,9 @@ jobs: - name: Test run: | ./tests/scripts/task_build.sh + - name: Install Python Packages + run: | + python3 -m pip install -U pytest numpy torch tensorflow mxnet + - name: Python Tests + run: | + pytest -v diff --git a/.gitignore b/.gitignore index 21c857e..16da57a 100644 --- a/.gitignore +++ b/.gitignore @@ -30,3 +30,6 @@ *~ build bin + +# Python +__pycache__ \ No newline at end of file diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..34be6d6 --- /dev/null +++ b/tests/__init__.py @@ -0,0 +1,39 @@ +import tensorflow as tf +import mxnet as mx + +dtypes = [ + 'uint8', 'uint16', 'uint32', 'uint64', + 'int8', 'int16', 'int32', 'int64', + 'float16', 'float32', 'float64', + 'complex64', 'complex128' +] + +arrays = [ + [1, 2, 3], # 1D array + 1, # ndim = 0 array + [], # empty array + [[1, 2, 3], + [4, 5, 6]], # multi-dimensional array +] + + +class TfTensor: + def __init__(self, tensor): + self.tensor = tensor + + def __dlpack__(self, stream=0): + return tf.experimental.dlpack.to_dlpack(self.tensor) + + def __dlpack_device__(self): + return (1, 0) # we only test CPU tensors for now. + + +class MxArray: + def __init__(self, array): + self.array = array + + def __dlpack__(self, stream=0): + return mx.nd.to_dlpack_for_read(self.array) + + def __dlpack_device__(self): + return (1, 0) diff --git a/tests/test_mxnet.py b/tests/test_mxnet.py new file mode 100644 index 0000000..d34c7f3 --- /dev/null +++ b/tests/test_mxnet.py @@ -0,0 +1,63 @@ +import pytest +import numpy as np +import torch +import tensorflow as tf +import mxnet as mx +from . import dtypes, arrays, TfTensor, MxArray + + +def mxnet_assert_equal(x, y): + x, y = x.asnumpy(), y.asnumpy() + np.testing.assert_array_equal(x, y) + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_fromself(data, dtype): + try: + x = mx.nd.array(data, dtype=dtype) + except KeyError: + pytest.skip(f"mxnet doesn't support {dtype}.") + y = mx.nd.from_dlpack(MxArray(x).__dlpack__()) + mxnet_assert_equal(x, y) + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_fromnumpy(data, dtype): + x = np.array(data, dtype=dtype) + try: + expected_y = mx.nd.array(data, dtype=dtype) + except KeyError: + pytest.skip(f"mxnet doesn't support {dtype}.") + y = mx.nd.from_dlpack(x.__dlpack__()) + mxnet_assert_equal(y, expected_y) + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_fromtorch(data, dtype): + dt = getattr(torch, dtype, None) + if dt is None: + pytest.skip(f"torch doesn't support {dtype}.") + x = torch.tensor(data, dtype=dt) + try: + expected_y = mx.nd.array(data, dtype=dtype) + except KeyError: + pytest.skip(f"mxnet doesn't support {dtype}.") + y = mx.nd.from_dlpack(x.__dlpack__()) + mxnet_assert_equal(y, expected_y) + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_fromtensorflow(data, dtype): + if 'complex' in dtype: + pytest.xfail("tensorflow currently doesn't support complex dtypes.") + x = tf.constant(data, dtype=dtype) + try: + expected_y = mx.nd.array(data, dtype=dtype) + except KeyError: + pytest.skip(f"mxnet doesn't support {dtype}.") + y = mx.nd.from_dlpack(TfTensor(x).__dlpack__()) + mxnet_assert_equal(y, expected_y) diff --git a/tests/test_numpy.py b/tests/test_numpy.py new file mode 100644 index 0000000..546425d --- /dev/null +++ b/tests/test_numpy.py @@ -0,0 +1,97 @@ +import pytest +import numpy as np +import torch +import tensorflow as tf +import mxnet as mx +from . import dtypes, arrays, TfTensor, MxArray + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_fromself(data, dtype): + x = np.array(data, dtype=dtype) + y = np._from_dlpack(x) + np.testing.assert_array_equal(x, y) + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_fromtorch(data, dtype): + dt = getattr(torch, dtype, None) + if dt is None: + pytest.skip(f"torch doesn't support {dtype}.") + x = torch.tensor(data, dtype=dt) + y = np._from_dlpack(x) + expected_y = x.numpy() + np.testing.assert_array_equal(y, expected_y) + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_fromtensorflow(data, dtype): + if 'complex' in dtype: + pytest.xfail("tensorflow currently doesn't support complex dtypes.") + x = tf.constant(data, dtype=dtype) + y = np._from_dlpack(TfTensor(x)) + expected_y = x.numpy() + np.testing.assert_array_equal(y, expected_y) + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_frommxnet(data, dtype): + try: + x = mx.nd.array(data, dtype=dtype) + except KeyError: + pytest.skip(f"mxnet doesn't support {dtype}.") + y = np._from_dlpack(MxArray(x)) + expected_y = x.asnumpy() + np.testing.assert_array_equal(y, expected_y) + + +def test_byteswapped(): + dt = np.dtype('=i8').newbyteorder() + x = np.arange(5, dtype=dt) + + with pytest.raises(TypeError): + np._from_dlpack(x) + + +def test_invalid_dtype(): + x = np.asarray(np.datetime64('2021-05-27')) + + with pytest.raises(TypeError): + np._from_dlpack(x) + + +def non_contiguous_testcases(x, wrapper=None, to_numpy=lambda x: x): + if wrapper is None: + wrapper = lambda x: x + + y1 = x[0] + np.testing.assert_array_equal(to_numpy(y1), np._from_dlpack(wrapper(y1))) + + y2 = x[:, 0] + np.testing.assert_array_equal(to_numpy(y2), np._from_dlpack(wrapper(y2))) + + y3 = x[1, :] + np.testing.assert_array_equal(to_numpy(y3), np._from_dlpack(wrapper(y3))) + + y4 = x[1:3, 3:5] + np.testing.assert_array_equal(to_numpy(y4), np._from_dlpack(wrapper(y4))) + + +def test_non_contiguous(): + x = np.arange(25).reshape((5, 5)) + non_contiguous_testcases(x) + + # test against torch + x = torch.arange(25).reshape((5, 5)) + non_contiguous_testcases(x, to_numpy=lambda x: x.numpy()) + + # test against tensorflow + x = tf.constant(range(25), shape=(5, 5)) + non_contiguous_testcases(x, TfTensor, lambda x: x.numpy()) + + # test against mxnet + x = mx.nd.arange(25).reshape((5, 5)) + non_contiguous_testcases(x, MxArray, lambda x: x.asnumpy()) diff --git a/tests/test_tensorflow.py b/tests/test_tensorflow.py new file mode 100644 index 0000000..ca1fa40 --- /dev/null +++ b/tests/test_tensorflow.py @@ -0,0 +1,66 @@ +import pytest +import numpy as np +import torch +import tensorflow as tf +from tensorflow.python.ops.numpy_ops import np_config +import mxnet as mx +from . import dtypes, arrays, TfTensor, MxArray + +np_config.enable_numpy_behavior() + + +def tensorflow_assert_equal(x, y): + assert tf.reduce_all(x == y) + assert x.dtype == y.dtype + assert x.device == y.device + assert x.shape == y.shape + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_fromself(data, dtype): + if 'complex' in dtype: + pytest.xfail("tensorflow currently doesn't support complex dtypes.") + x = tf.constant(data, dtype=dtype) + y = tf.experimental.dlpack.from_dlpack(TfTensor(x).__dlpack__()) + tensorflow_assert_equal(x, y) + + +@pytest.mark.skip(reason="tensorflow crashes when importing NumPy arrays.") +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_fromnumpy(data, dtype): + if 'complex' in dtype: + pytest.xfail("tensorflow currently doesn't support complex dtypes.") + x = np.array(data, dtype=dtype) + y = tf.experimental.dlpack.from_dlpack(x.__dlpack__()) + expected_y = tf.constant(data, dtype=dtype) + tensorflow_assert_equal(y, expected_y) + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_fromtorch(data, dtype): + if 'complex' in dtype: + pytest.xfail("tensorflow currently doesn't support complex dtypes.") + dt = getattr(torch, dtype, None) + if dt is None: + pytest.skip(f"torch doesn't support {dtype}.") + x = torch.tensor(data, dtype=dt) + y = tf.experimental.dlpack.from_dlpack(x.__dlpack__()) + expected_y = tf.constant(data, dtype=dtype) + tensorflow_assert_equal(y, expected_y) + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_frommxnet(data, dtype): + if 'complex' in dtype: + pytest.xfail("tensorflow currently doesn't support complex dtypes.") + try: + x = mx.nd.array(data, dtype=dtype) + except KeyError: + pytest.skip(f"mxnet doesn't support {dtype}.") + y = tf.experimental.dlpack.from_dlpack(MxArray(x).__dlpack__()) + expected_y = tf.constant(data, dtype=dtype) + tensorflow_assert_equal(y, expected_y) diff --git a/tests/test_torch.py b/tests/test_torch.py new file mode 100644 index 0000000..6e4f23a --- /dev/null +++ b/tests/test_torch.py @@ -0,0 +1,66 @@ +import pytest +import numpy as np +import torch +import tensorflow as tf +import mxnet as mx +from . import dtypes, arrays, TfTensor, MxArray + + +def torch_assert_equal(x, y): + assert torch.all(x == y) + assert x.dtype == y.dtype + assert x.device == y.device + assert x.shape == y.shape + assert x.stride() == y.stride() + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_fromself(data, dtype): + dt = getattr(torch, dtype, None) + if dt is None: + pytest.skip(f"torch doesn't support {dtype}.") + x = torch.tensor(data, dtype=dt) + y = torch.from_dlpack(x) + torch_assert_equal(x, y) + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_fromnumpy(data, dtype): + dt = getattr(torch, dtype, None) + if dt is None: + pytest.skip(f"torch doesn't support {dtype}.") + x = np.array(data, dtype=dtype) + y = torch.from_dlpack(x) + expected_y = torch.tensor(data, dtype=dt) + torch_assert_equal(y, expected_y) + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_fromtensorflow(data, dtype): + if 'complex' in dtype: + pytest.xfail("tensorflow currently doesn't support complex dtypes.") + dt = getattr(torch, dtype, None) + if dt is None: + pytest.skip(f"torch doesn't support {dtype}.") + x = tf.constant(data, dtype=dtype) + y = torch.from_dlpack(TfTensor(x)) + expected_y = torch.tensor(data, dtype=dt) + torch_assert_equal(y, expected_y) + + +@pytest.mark.parametrize('data', arrays) +@pytest.mark.parametrize('dtype', dtypes) +def test_frommxnet(data, dtype): + dt = getattr(torch, dtype, None) + if dt is None: + pytest.skip(f"torch doesn't support {dtype}.") + try: + x = mx.nd.array(data, dtype=dtype) + except KeyError: + pytest.skip(f"mxnet doesn't support {dtype}.") + y = torch.from_dlpack(MxArray(x)) + expected_y = torch.tensor(data, dtype=dt) + torch_assert_equal(y, expected_y)