|
137 | 137 | "    objects for train and evaluation splits, along with helper functions for preprocessing the dataset.\n",
138 | 138 | " \"\"\"\n", |
139 | 139 | "\n", |
140 | | - " def __init__(self, dataset_name, tokenizer, sentence1_key, sentence2_key, label_key):\n", |
| 140 | + " def __init__(self, dataset_name, tokenizer, sentence1_key, sentence2_key, label_key, max_length=None):\n", |
141 | 141 | " self.tokenizer = tokenizer\n", |
142 | 142 | " self.dataset_name = dataset_name\n", |
143 | 143 | " self.class_labels = None\n", |
|
150 | 150 | " self.sentence1_key = sentence1_key\n", |
151 | 151 | " self.sentence2_key = sentence2_key\n", |
152 | 152 | " self.label_key = label_key\n", |
| 153 | + "\n", |
| 154 | + " # Max sequence length\n", |
| 155 | + " self.max_length = max_length\n", |
153 | 156 | " \n", |
154 | 157 | " def tokenize_function(self, examples):\n", |
155 | 158 | "        # Define the tokenizer args, depending on whether the data has 2 sentences or just 1\n",
156 | 159 | " args = ((examples[self.sentence1_key],) if self.sentence2_key is None \\\n", |
157 | 160 | " else (examples[self.sentence1_key], examples[self.sentence2_key]))\n", |
158 | | - " return self.tokenizer(*args, padding=\"max_length\", truncation=True)\n", |
| 161 | + " return self.tokenizer(*args, padding=\"max_length\", truncation=True, max_length=self.max_length)\n", |
159 | 162 | " \n", |
160 | 163 | " def tokenize_dataset(self, dataset):\n", |
161 | 164 | " # Apply the tokenize function to the dataset\n", |
|
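For context on the change above: with Hugging Face tokenizers, padding="max_length" combined with max_length=None pads every sequence to the tokenizer's model_max_length, while an explicit max_length pads and truncates to that value instead. A minimal sketch of that behavior (the bert-base-uncased checkpoint is an assumption for illustration):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint

# max_length=None: pads to the tokenizer's model_max_length (512 for this model)
full = tokenizer("a short example", padding="max_length", truncation=True, max_length=None)
print(len(full["input_ids"]))  # 512

# explicit max_length: pads/truncates to 64 tokens instead
short = tokenizer("a short example", padding="max_length", truncation=True, max_length=64)
print(len(short["input_ids"]))  # 64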
232 | 235 | " \"\"\"\n", |
233 | 236 | " \n", |
234 | 237 | " def __init__(self, tokenizer, dataset_dir, dataset_name, train_size, eval_size, train_split_name,\n", |
235 | | - " eval_split_name, sentence1_key, sentence2_key, label_key):\n", |
| 238 | + " eval_split_name, sentence1_key, sentence2_key, label_key, max_length=None):\n", |
236 | 239 | " \"\"\"\n", |
237 | 240 | " Initialize the HFDSTextClassificationData class for a text classification dataset from Hugging Face.\n", |
238 | 241 | " \n", |
|
249 | 252 | " :param sentence1_key: Name of the sentence1 column\n", |
250 | 253 | " :param sentence2_key: Name of the sentence2 column or `None` if there's only one text column\n", |
251 | 254 | " :param label_key: Name of the label column\n", |
| 255 | + "        :param max_length: Optional max sequence length (if None, the tokenizer's maximum sequence length is used)\n",
252 | 256 | " \"\"\"\n", |
253 | 257 | "\n", |
254 | 258 | " # Init base class\n", |
255 | | - " TextClassificationData.__init__(self, dataset_name, tokenizer, sentence1_key, sentence2_key, label_key) \n", |
| 259 | + "        TextClassificationData.__init__(self, dataset_name, tokenizer, sentence1_key, sentence2_key, label_key,\n",
| 260 | + "                                        max_length)\n",
256 | 261 | " \n", |
257 | 262 | " # Load the dataset from the Hugging Face dataset API\n", |
258 | 263 | " self.dataset = load_dataset(dataset_name, cache_dir=dataset_dir)\n", |
|
279 | 284 | "sentence2_key = None\n", |
280 | 285 | "label_key = \"label\"\n", |
281 | 286 | "\n", |
| 287 | + "# Max sequence length\n", |
| 288 | + "max_length = None\n", |
| 289 | + "\n", |
282 | 290 | "dataset = HFDSTextClassificationData(tokenizer, dataset_dir, dataset_name, train_dataset_size, eval_dataset_size,\n", |
283 | | - " Split.TRAIN, Split.TEST, sentence1_key, sentence2_key, label_key)\n", |
| 291 | + " Split.TRAIN, Split.TEST, sentence1_key, sentence2_key, label_key, max_length)\n", |
284 | 292 | "\n", |
285 | 293 | "# Print a sample of the data\n", |
286 | 294 | "dataset.display_sample(Split.TRAIN, sample_size=5)" |
|
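With max_length = None as above, every example is padded to the tokenizer's full maximum. If memory or throughput matters, a shorter explicit value can be passed instead; a sketch reusing the variables already defined in the notebook cell above (the value 128 is an arbitrary assumption):

max_length = 128  # hypothetical shorter limit; any value up to the model maximum works

dataset = HFDSTextClassificationData(tokenizer, dataset_dir, dataset_name, train_dataset_size, eval_dataset_size,
                                     Split.TRAIN, Split.TEST, sentence1_key, sentence2_key, label_key, max_length)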
326 | 334 | " \"\"\"\n", |
327 | 335 | " \n", |
328 | 336 | " def __init__(self, tokenizer, dataset_name, dataset_dir, data_files, delimiter, label_names, sentence1_key, sentence2_key,\n", |
329 | | - " label_key, train_percent=0.8, eval_percent=0.2, train_size=None, eval_size=None, map_function=None):\n", |
| 337 | + " label_key, train_percent=0.8, eval_percent=0.2, train_size=None, eval_size=None, map_function=None,\n", |
| 338 | + " max_length=None):\n", |
330 | 339 | " \"\"\"\n", |
331 | 340 | "        Initialize the CustomCsvTextClassificationData class for a text classification\n",
332 | 341 | "        dataset. The class uses the Hugging Face datasets API to load the CSV file,\n",
|
352 | 361 | " :param eval_size: Size of the eval dataset. Set to `None` to use all the data.\n", |
353 | 362 | " :param map_function: (Optional) Map function to apply to the dataset. For example, if the csv file has string\n", |
354 | 363 | " labels instead of numerical values, map function can do the conversion.\n", |
| 364 | + "        :param max_length: Optional max sequence length (if None, the tokenizer's maximum sequence length is used)\n",
355 | 365 | " \"\"\"\n", |
356 | 366 | " # Init base class\n", |
357 | | - " TextClassificationData.__init__(self, dataset_name, tokenizer, sentence1_key, sentence2_key, label_key)\n", |
| 367 | + " TextClassificationData.__init__(self, dataset_name, tokenizer, sentence1_key, sentence2_key, label_key, max_length)\n", |
358 | 368 | " \n", |
359 | 369 | " if (train_percent + eval_percent) > 1:\n", |
360 | 370 | " raise ValueError(\"The combined value of the train percentage and eval percentage \" \\\n", |
|
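The percent-based train/eval split itself falls outside this hunk; one plausible implementation uses the datasets API's train_test_split, sketched below under that assumption (the toy rows are invented for illustration):

from datasets import Dataset

# Toy data standing in for the loaded CSV; train_test_split is the real datasets API.
raw = Dataset.from_dict({"sms": ["a", "b", "c", "d", "e"], "label": [0, 1, 0, 1, 0]})
split = raw.train_test_split(train_size=0.8, test_size=0.2, seed=42)
print(len(split["train"]), len(split["test"]))  # 4 1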
408 | 418 | "sentence2_key = None\n", |
409 | 419 | "label_key = \"label\"\n", |
410 | 420 | "\n", |
| 421 | + "# Max sequence length\n", |
| 422 | + "max_length = None\n", |
| 423 | + "\n", |
411 | 424 | "# Map function to translate labels in the csv file to numerical values when loading the dataset\n", |
412 | 425 | "def map_spam(example):\n", |
413 | 426 | " example[\"label\"] = int(example[\"label\"] == \"spam\")\n", |
414 | 427 | " return example\n", |
415 | 428 | "\n", |
416 | 429 | "dataset = CustomCsvTextClassificationData(tokenizer, \"smsspamcollection\", dataset_dir, [renamed_csv], delimiter,\n", |
417 | 430 | " label_names, sentence1_key, sentence2_key, label_key, train_size=1000,\n", |
418 | | - " eval_size=1000, map_function=map_spam)\n", |
| 431 | + " eval_size=1000, map_function=map_spam, max_length=max_length)\n", |
419 | 432 | "\n", |
420 | 433 | "# Print a sample of the data\n", |
421 | 434 | "dataset.display_sample(Split.TRAIN, 10)" |
|
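For reference, the map_spam function can be checked in isolation with the datasets API; the class is assumed to apply map_function the same way internally. A minimal, self-contained sketch with invented sample rows:

from datasets import Dataset

# Same mapping as the notebook cell above: "spam" -> 1, anything else -> 0
def map_spam(example):
    example["label"] = int(example["label"] == "spam")
    return example

sample = Dataset.from_dict({"sms": ["Win a prize now!", "See you at 5"], "label": ["spam", "ham"]})
mapped = sample.map(map_spam)
print(mapped["label"])  # [1, 0]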