diff --git a/tests/test_balance.py b/tests/test_balance.py index 0b795b0..825d0f0 100644 --- a/tests/test_balance.py +++ b/tests/test_balance.py @@ -141,8 +141,8 @@ def test_layerwise_sandbox(device): assert layer.training assert all(p.device.type == device for p in layer.parameters()) - assert all(not l.training for l in model) - assert all(p.device.type == 'cpu' for p in model.parameters()) + assert all(not layer.training for layer in model) + assert all(param.device.type == 'cpu' for param in model.parameters()) @pytest.mark.parametrize('device', devices) diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py index b058e76..23935f5 100644 --- a/tests/test_pipeline.py +++ b/tests/test_pipeline.py @@ -13,19 +13,19 @@ def test_clock_cycles(): assert list(clock_cycles(3, 1)) == [[(0, 0)], [(1, 0)], [(2, 0)]] assert list(clock_cycles(3, 3)) == [ # noqa - [(0, 0)], - [(1, 0), (0, 1)], - [(2, 0), (1, 1), (0, 2)], - [(2, 1), (1, 2)], - [(2, 2)], + [(0, 0)], # noqa + [(1, 0), (0, 1)], # noqa + [(2, 0), (1, 1), (0, 2)], # noqa + [(2, 1), (1, 2)], # noqa + [(2, 2)], # noqa ] assert list(clock_cycles(4, 2)) == [ # noqa - [(0, 0)], - [(1, 0), (0, 1)], - [(2, 0), (1, 1)], - [(3, 0), (2, 1)], - [(3, 1)], + [(0, 0)], # noqa + [(1, 0), (0, 1)], # noqa + [(2, 0), (1, 1)], # noqa + [(3, 0), (2, 1)], # noqa + [(3, 1)], # noqa ]