@@ -14,6 +14,7 @@
 from typing_extensions import override
 
 from fairseq2.data import (
+    CollateOptionsOverride,
     Collater,
     DataPipeline,
     DataPipelineBuilder,
@@ -77,28 +78,24 @@ class ParallelTextDataset(ABC):
     def create_reader(
         self,
         split: str,
-        tokenizer: TextTokenizer,
+        source_tokenizer: TextTokenizer,
+        target_tokenizer: TextTokenizer,
         gang: Gang,
         min_seq_len: int,
         max_seq_len: int,
         options: ParallelTextReadOptions | None = None,
     ) -> DataReader[Seq2SeqBatch]:
         """Create a dataset reader.
 
-        :param split:
-            The split to read.
-        :param tokenizer:
-            The tokenizer to encode text.
-        :param gang:
-            The gang over which to shard the dataset.
-        :param min_seq_len:
-            The minimum sequence length of each example. Examples shorter than
-            this value will be dropped.
-        :param max_seq_len:
-            The maximum sequence length of each example. Examples longer than
-            this value will be dropped.
-        :param options:
-            The read options.
+        :param split: The split to read.
+        :param source_tokenizer: The tokenizer to encode source text.
+        :param target_tokenizer: The tokenizer to encode target text.
+        :param gang: The gang over which to shard the dataset.
+        :param min_seq_len: The minimum sequence length of each example.
+            Examples shorter than this value will be dropped.
+        :param max_seq_len: The maximum sequence length of each example.
+            Examples longer than this value will be dropped.
+        :param options: The read options.
         """
 
     @abstractmethod
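For context, a minimal call-site sketch of the new dual-tokenizer signature. Everything here is illustrative: `dataset`, `src_tokenizer`, `tgt_tokenizer`, and `gang` are assumed to be constructed elsewhere and are not part of this change.

```python
# Hypothetical call site; `dataset`, `src_tokenizer`, `tgt_tokenizer`, and
# `gang` are assumed to be constructed elsewhere.
reader = dataset.create_reader(
    split="train",
    source_tokenizer=src_tokenizer,
    target_tokenizer=tgt_tokenizer,
    gang=gang,
    min_seq_len=1,
    max_seq_len=512,
)

# A `DataReader` yields a list of `Seq2SeqBatch` objects per step.
for batches in reader:
    for batch in batches:
        ...
```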
@@ -171,7 +168,7 @@ def from_path(cls, path: Path, name: str) -> GenericParallelTextDataset:
         manifest_file = path.joinpath(split).joinpath("MANIFEST")
 
         try:
-            with manifest_file.open() as fp:
+            with manifest_file.open(encoding="utf-8") as fp:
                 content = list(fp)
         except OSError as ex:
             raise DatasetLoadError(
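The `encoding="utf-8"` fix matters because `Path.open()` without an explicit encoding falls back to `locale.getpreferredencoding(False)`, so a UTF-8 MANIFEST could fail to decode on machines with a non-UTF-8 locale. A standalone sketch (the path is illustrative):

```python
from pathlib import Path

manifest_file = Path("train/MANIFEST")  # illustrative path

# Pin the encoding so reads behave identically across platforms instead of
# depending on the machine's locale default.
with manifest_file.open(encoding="utf-8") as fp:
    content = list(fp)
```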
@@ -252,7 +249,8 @@ def value_error() -> ValueError:
     def create_reader(
         self,
         split: str,
-        tokenizer: TextTokenizer,
+        source_tokenizer: TextTokenizer,
+        target_tokenizer: TextTokenizer,
         gang: Gang,
         min_seq_len: int,
         max_seq_len: int,
@@ -289,11 +287,11 @@ def create_reader(
         if direction.origin:
             source_mode = f"{source_mode}_{direction.origin}"
 
-        source_encoder = tokenizer.create_encoder(
+        source_encoder = source_tokenizer.create_encoder(
            task="translation", lang=direction.source_lang, mode=source_mode
         )
 
-        target_encoder = tokenizer.create_encoder(
+        target_encoder = target_tokenizer.create_encoder(
             task="translation", lang=direction.target_lang, mode="target"
         )
 
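As a sketch of what this change enables, the two sides can now come from entirely different tokenizers. The language codes, sentences, and mode strings below are placeholder assumptions; the tokenizers are presumed to support the `"translation"` task (e.g. NLLB-style tokenizers).

```python
# Illustrative only: `source_tokenizer` and `target_tokenizer` are assumed
# to be loaded elsewhere; language codes are placeholders.
source_encoder = source_tokenizer.create_encoder(
    task="translation", lang="eng_Latn", mode="source"
)
target_encoder = target_tokenizer.create_encoder(
    task="translation", lang="fra_Latn", mode="target"
)

source_indices = source_encoder("Hello, world!")
target_indices = target_encoder("Bonjour le monde !")
```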
@@ -384,7 +382,18 @@ def skip(example: dict[str, Any]) -> bool:
         seed += 1
 
         # Collate bucketed examples into a batch.
-        collater = Collater(pad_value=tokenizer.vocab_info.pad_idx)
+        collater = Collater(
+            overrides=[
+                CollateOptionsOverride(
+                    selector="source_indices",
+                    pad_value=source_tokenizer.vocab_info.pad_idx,
+                ),
+                CollateOptionsOverride(
+                    selector="target_indices",
+                    pad_value=target_tokenizer.vocab_info.pad_idx,
+                ),
+            ]
+        )
 
         builder.map(collater, num_parallel_calls=npc)
 
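A self-contained sketch of what the per-selector overrides buy. Previously a single `pad_value` was shared by both sides; with separate tokenizers the two pad indices can differ, so each field gets its own. The pad values and toy tensors below are made up for illustration.

```python
import torch

from fairseq2.data import CollateOptionsOverride, Collater

# Toy pad indices standing in for the two tokenizers' vocab_info.pad_idx.
collater = Collater(
    overrides=[
        CollateOptionsOverride(selector="source_indices", pad_value=0),
        CollateOptionsOverride(selector="target_indices", pad_value=3),
    ]
)

batch = collater(
    [
        {"source_indices": torch.tensor([4, 5, 6]), "target_indices": torch.tensor([7])},
        {"source_indices": torch.tensor([8]), "target_indices": torch.tensor([9, 10])},
    ]
)

# Each ragged field is padded with its own pad value, e.g.:
#   batch["source_indices"]["seqs"] -> [[4, 5, 6], [8, 0, 0]]
#   batch["target_indices"]["seqs"] -> [[7, 3], [9, 10]]
```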