Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import functools
import unittest

from absl.testing import absltest
import einops
Expand Down Expand Up @@ -214,7 +215,9 @@ def setUp(self):
)

def test_sparse_dense_matmul_one_chip_unsharded(self):
devices = jax.devices()[:1]
devices = jax.devices()
if len(devices) != 1:
raise unittest.SkipTest("Unsupported topology.")
mesh = jax.sharding.Mesh(devices, "x")
feature_specs = {
"feature_spec_a": self.feature_spec_a,
Expand Down Expand Up @@ -495,7 +498,9 @@ def test_sparse_dense_matmul_one_chip_unsharded(self):
)

def test_tpu_sparse_dense_matmul_grad_sharded_two_tables(self):
devices = jax.devices()[:2]
devices = jax.devices()
if len(devices) != 2:
raise unittest.SkipTest("Unsupported topology.")
num_sc_per_device = utils.num_sparsecores_per_device(devices[0])
num_devices = len(devices)
mesh = jax.sharding.Mesh(devices, "x")
Expand Down Expand Up @@ -775,7 +780,9 @@ def test_tpu_sparse_dense_matmul_grad_sharded_two_tables(self):
)

def test_tpu_sparse_dense_matmul_grad_sharded_two_tables_stacked(self):
devices = jax.devices()[:2]
devices = jax.devices()
if len(devices) != 2:
raise unittest.SkipTest("Unsupported topology.")
num_sc_per_device = utils.num_sparsecores_per_device(devices[0])
num_devices = len(devices)
mesh = jax.sharding.Mesh(devices, "x")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
import collections
import functools
import unittest

from absl.testing import absltest
from absl.testing import parameterized
Expand Down Expand Up @@ -105,6 +106,11 @@ class ErrorHandlingTest(absltest.TestCase):

# Tests that even if static buffer size is too small, the matmul can proceed.
def test_static_buffer_size_was_too_small(self):
global_devices = jax.devices()
if len(global_devices) != 1:
self.skipTest("Unsupported topology.")
first_device = global_devices[0]

long_feature = np.arange(800, dtype=np.int32).reshape(8, -1)
long_weights = np.ones(long_feature.shape, dtype=np.float32)

Expand All @@ -126,8 +132,6 @@ def test_static_buffer_size_was_too_small(self):
output_shape=[8, 8],
name="feature",
)
global_devices = jax.devices()
first_device = global_devices[0]
num_sc_per_device = utils.num_sparsecores_per_device(first_device)
mesh = jax.sharding.Mesh([first_device], "x")
feature_specs = {
Expand Down Expand Up @@ -359,7 +363,9 @@ def setUp(self):

@parameterized.parameters(False, True)
def test_sparse_dense_matmul_two_chips_sharded(self, using_pmap):
devices = jax.devices()[:2]
devices = jax.devices()
if len(devices) != 2:
raise unittest.SkipTest("Unsupported topology.")
num_sc_per_device = utils.num_sparsecores_per_device(devices[0])
mesh = jax.sharding.Mesh(devices, "x")
feature_specs = {
Expand Down Expand Up @@ -484,7 +490,9 @@ def test_sparse_dense_matmul_two_chips_sharded(self, using_pmap):

@parameterized.parameters(False, True)
def test_sparse_dense_matmul_two_chips_sharded_stacked(self, using_pmap):
devices = jax.devices()[:2]
devices = jax.devices()
if len(devices) != 2:
raise unittest.SkipTest("Unsupported topology.")
num_sc_per_device = utils.num_sparsecores_per_device(devices[0])
mesh = jax.sharding.Mesh(devices, "x")
feature_specs = {
Expand Down Expand Up @@ -633,8 +641,9 @@ def test_sparse_dense_matmul_two_chips_sharded_stacked(self, using_pmap):

@parameterized.parameters(False, True)
def test_sparse_dense_matmul_single_chip(self, using_pmap):
global_devices = jax.devices()
devices = [global_devices[0]]
devices = jax.devices()
if len(devices) != 1:
raise unittest.SkipTest("Unsupported topology.")
num_sc_per_device = utils.num_sparsecores_per_device(devices[0])
mesh = jax.sharding.Mesh(devices, "x")
feature_specs = {
Expand Down Expand Up @@ -741,7 +750,9 @@ def test_sparse_dense_matmul_single_chip(self, using_pmap):

@parameterized.parameters(False, True)
def test_sparse_dense_matmul_two_tables(self, using_pmap):
devices = jax.devices()[:2]
devices = jax.devices()
if len(devices) != 2:
raise unittest.SkipTest("Unsupported topology.")
num_sc_per_device = utils.num_sparsecores_per_device(devices[0])
mesh = jax.sharding.Mesh(devices, "x")
feature_specs = {
Expand Down Expand Up @@ -1168,6 +1179,8 @@ def test_sparse_dense_matmul_two_tables(self, using_pmap):
@parameterized.parameters(False, True)
def test_sparse_dense_matmul_four_chips_complex_stacked(self, using_pmap):
devices = jax.devices()
if len(devices) != 4:
self.skipTest("Unsupported topology.")
num_sc_per_device = utils.num_sparsecores_per_device(devices[0])
mesh = jax.sharding.Mesh(devices, "x")
country_table = embedding_spec.TableSpec(
Expand Down