Skip to content

Commit 074a33e

Browse files
committed
lint
1 parent 078822d commit 074a33e

File tree

3 files changed

+3
-12
lines changed

3 files changed

+3
-12
lines changed

merlin/dataloader/utils/tf/tf_trainer.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,9 @@
11
# External dependencies
22
import argparse
3-
import glob
43
import logging
54
import os
65

76
from merlin.core.compat import cupy, numpy
8-
from merlin.io import Dataset
9-
from nvtabular.ops import TagAsItemID, TagAsUserID, JoinExternal
10-
from merlin.core.dispatch import get_lib
117
from merlin.schema import Tags
128

139
# we can control how much memory to give tensorflow with this environment variable
@@ -76,8 +72,8 @@ def seed_fn():
7672
EMBEDDING_TABLE_SHAPES, MH_EMBEDDING_TABLE_SHAPES = nvt.ops.get_embedding_sizes(proc)
7773
EMBEDDING_TABLE_SHAPES.update(MH_EMBEDDING_TABLE_SHAPES)
7874

79-
train_ds = nvt.Dataset(f'{BASE_DIR}/train', engine='parquet', dtypes={'rating': xp.int8})
80-
train_ds.schema = train_ds.schema.remove_col('genres')
75+
train_ds = nvt.Dataset(f"{BASE_DIR}/train", engine="parquet", dtypes={"rating": xp.int8})
76+
train_ds.schema = train_ds.schema.remove_col("genres")
8177

8278
target_column = train_ds.schema.select_by_tag(Tags.TARGET).column_names[0]
8379

tests/unit/dataloader/test_tf_dataloader.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -536,10 +536,6 @@ def test_multigpu_partitioning(dataset, batch_size, global_rank):
536536

537537

538538
@pytest.mark.multigpu
539-
@pytest.mark.skipif(
540-
os.environ.get("NR_USER") is not None,
541-
reason="not working correctly in ci environment",
542-
)
543539
@pytest.mark.skipif(importlib.util.find_spec("horovod") is None, reason="needs horovod")
544540
@pytest.mark.skipif(
545541
HAS_GPU and cupy and cupy.cuda.runtime.getDeviceCount() <= 1,
@@ -722,7 +718,7 @@ def test_wrong_batch_size_raises_warning():
722718
_ = tf_loader(dataset, batch_size=batch_size)
723719

724720
for power in range(4, 10):
725-
batch_size = 2**power
721+
batch_size = 2 ** power
726722
# warning not raised for power of two
727723
with warnings.catch_warnings():
728724
warnings.simplefilter("error")

tox.ini

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,6 @@ commands =
8282
passenv =
8383
CUDA_VISIBLE_DEVICES
8484
OPAL_PREFIX
85-
NR_USER
8685
sitepackages=true
8786
; Runs in: 1GPU Github Actions runners.
8887
; Runs GPU-based tests.

0 commit comments

Comments
 (0)