2626from nemo_curator .utils .script_utils import ArgumentHelper
2727
2828
def load_dataset(input_data_dir: str) -> DocumentDataset:
    """Build a ``DocumentDataset`` from the JSONL files of a directory.

    Args:
        input_data_dir: Directory handed to ``get_all_files_paths_under``;
            only paths kept by ``keep_extensions="jsonl"`` are read.

    Returns:
        A ``DocumentDataset`` wrapping whatever ``read_data`` produced for
        those files.
    """
    jsonl_paths = list(get_all_files_paths_under(input_data_dir, keep_extensions="jsonl"))
    # add_filename=True is forwarded as-is; presumably it records each row's
    # source file -- confirm against read_data's documentation.
    frame = read_data(
        jsonl_paths,
        file_type="jsonl",
        backend="pandas",
        add_filename=True,
    )
    return DocumentDataset(frame)
3534
36-
def create_samples(data_path: str, label: str, num_samples: int) -> list[str]:
    """Load a dataset, tag it with a fastText label, and draw text samples.

    The dataset at ``data_path`` is loaded with ``load_dataset``, passed
    through ``nc.Modify(FastTextLabelModifier(label))``, and then sampled by
    fraction. Because sampling is fraction-based, the number of returned
    documents is presumably only approximately ``num_samples`` -- confirm
    against the ``sample`` implementation in use.

    Args:
        data_path: Directory containing the JSONL documents to sample from.
        label: fastText label passed to ``FastTextLabelModifier``
            (e.g. ``"__label__hq"``).
        num_samples: Desired number of samples. If it exceeds the number of
            available rows, the whole dataset is used.

    Returns:
        The ``"text"`` column of the sampled rows as a plain list.

    Raises:
        ValueError: If the dataset loaded from ``data_path`` has no rows.
    """
    raw_dataset = load_dataset(data_path)
    label_quality = nc.Modify(FastTextLabelModifier(label))

    labeled_dataset = label_quality(raw_dataset)

    total_rows = len(labeled_dataset.df)
    if total_rows == 0:
        # Without this guard the division below fails with an opaque
        # ZeroDivisionError that does not name the offending path.
        msg = f"Dataset loaded from {data_path!r} is empty; cannot draw samples."
        raise ValueError(msg)

    # Clamp to 1.0: pandas-style sample() rejects a fraction above 1 when
    # sampling without replacement, so asking for more rows than exist
    # falls back to taking every row.
    frac = min(1.0, num_samples / total_rows)
    labeled_samples = labeled_dataset.df.sample(frac=frac)

    return labeled_samples["text"].compute().values.tolist()
4743
4844
49- def main (args ) :
45+ def main (args : argparse . Namespace ) -> None :
5046 # Params
5147 low_quality_data_path = "/path/to/low_quality"
5248 high_quality_data_path = "/path/to/high_quality"
@@ -55,13 +51,9 @@ def main(args):
5551 filtered_output = "/path/to/output"
5652
5753 # Prepare samples for the classifier
58- client = get_client (** ArgumentHelper .parse_client_args (args ))
59- low_quality_samples = create_samples (
60- low_quality_data_path , "__label__lq" , num_low_quality_samples
61- )
62- high_quality_samples = create_samples (
63- high_quality_data_path , "__label__hq" , num_high_quality_samples
64- )
54+ client = get_client (** ArgumentHelper .parse_client_args (args )) # noqa: F841
55+ low_quality_samples = create_samples (low_quality_data_path , "__label__lq" , num_low_quality_samples )
56+ high_quality_samples = create_samples (high_quality_data_path , "__label__hq" , num_high_quality_samples )
6557
6658 train_samples = low_quality_samples + high_quality_samples
6759 random .shuffle (train_samples )
@@ -96,12 +88,10 @@ def main(args):
9688
9789
def attach_args(parser: argparse.ArgumentParser) -> argparse.ArgumentParser:
    """Register the distributed-execution arguments on ``parser``.

    Args:
        parser: The argument parser to wrap in an ``ArgumentHelper``.

    Returns:
        Whatever ``ArgumentHelper.add_distributed_args`` returns (annotated
        as an ``argparse.ArgumentParser``).
    """
    helper = ArgumentHelper(parser)
    return helper.add_distributed_args()
10494
10595
if __name__ == "__main__":
    # Build the parser explicitly, let attach_args add the distributed
    # options, then hand the parsed namespace to main().
    cli_parser = argparse.ArgumentParser(
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    cli_args = attach_args(cli_parser).parse_args()
    main(cli_args)
0 commit comments