Skip to content

Commit f209471

Browse files
authored
Fix for CLI Text Classification Benchmark / Quantize (#404)
* added an env variable to fix benchmark / quantize * added env var to hf model
1 parent d38778c commit f209471

File tree

5 files changed

+10
-6
lines changed

5 files changed

+10
-6
lines changed

tlt/datasets/text_classification/tf_custom_text_classification_dataset.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ class TFCustomTextClassificationDataset(TextClassificationDataset, TFDataset):
7575
7676
"""
7777

78-
def __init__(self, dataset_dir, dataset_name, csv_file_name, class_names, label_map_func=None,
78+
def __init__(self, dataset_dir, dataset_name, csv_file_name, class_names=[], label_map_func=None,
7979
defaults=[tf.string, tf.string], delimiter=",", header=False, select_cols=None, exclude_cols=None,
8080
shuffle_files=True, seed=None, **kwargs):
8181
"""
@@ -85,11 +85,6 @@ def __init__(self, dataset_dir, dataset_name, csv_file_name, class_names, label_
8585
if not os.path.exists(dataset_file):
8686
raise FileNotFoundError("The dataset file ({}) does not exist".format(dataset_file))
8787

88-
if not isinstance(class_names, list):
89-
raise TypeError("The class_names is expected to be a list, but found a {}", type(class_names))
90-
if len(class_names) == 0:
91-
raise ValueError("The class_names list cannot be empty.")
92-
9388
if label_map_func and not callable(label_map_func):
9489
raise TypeError("The label_map_func is expected to be a function, but found a {}", type(label_map_func))
9590

tlt/models/hf_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -176,6 +176,8 @@ def benchmark(self, dataset, saved_model_dir=None, warmup=10, iteration=100, cor
176176
FileNotFoundError: if a model.pt is not found in the saved_model_dir or if the inc_config_path file
177177
is not found
178178
"""
179+
os.environ["NC_ENV_CONF"] = "True"
180+
179181
# Verify dataset is of the right type
180182
if not isinstance(dataset, self._inc_compatible_dataset):
181183
raise NotImplementedError('Quantization has only been implemented for TLT datasets, and type '

tlt/models/pytorch_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -275,6 +275,8 @@ def benchmark(self, dataset, saved_model_dir=None, warmup=10, iteration=100, cor
275275
FileNotFoundError: if a model.pt is not found in the saved_model_dir or if the inc_config_path file
276276
is not found
277277
"""
278+
os.environ["NC_ENV_CONF"] = "True"
279+
278280
# Verify dataset is of the right type
279281
if not isinstance(dataset, self._inc_compatible_dataset):
280282
raise NotImplementedError('Quantization has only been implemented for TLT datasets, and type '

tlt/models/tf_model.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -409,6 +409,8 @@ def benchmark(self, dataset, saved_model_dir=None, warmup=10, iteration=100, cor
409409
FileNotFoundError: if a saved_model.pb is not found in the saved_model_dir or if the inc_config_path file
410410
is not found
411411
"""
412+
os.environ["NC_ENV_CONF"] = "True"
413+
412414
# If provided, the saved model directory should exist and contain a saved_model.pb file
413415
if saved_model_dir is not None:
414416
if not os.path.isdir(saved_model_dir):

tlt/tools/cli/commands/train.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,9 @@ def train(framework, model_name, use_case, output_dir, dataset_dir, dataset_file
276276
if not class_names:
277277
raise ValueError("Loading a text classification dataset requires --class-names to specify a list "
278278
"of the class labels for the dataset.")
279+
elif len(class_names) == 0:
280+
raise ValueError("Loading a text classification dataset requires --class-names to specify a list "
281+
"of the class labels of which the len > 0")
279282
dataset = dataset_factory.load_dataset(dataset_dir, model.use_case, model.framework, dataset_name,
280283
class_names=class_names, csv_file_name=dataset_file,
281284
delimiter=delimiter)

0 commit comments

Comments
 (0)