Commit 31465ac: Refactor vocabulary info in recipes (#1094)

1 parent 5d0b40a

87 files changed (+1008 −944 lines)

doc/source/basics/cli.rst (+4 −4)

@@ -79,7 +79,7 @@ or add, delete values:
     fairseq2 lm instruction_finetune <OUTPUT_DIR> --config del:common.metric_recorders.tensorboard
 
     # Add a configuration key
-    fairseq2 lm instruction_finetune <OUTPUT_DIR> --config add:common.metric_recorders.tensorboard="{enabled: true}"
+    fairseq2 lm instruction_finetune <OUTPUT_DIR> --config set:common.metric_recorders.tensorboard="{enabled: true}"
 
 .. note::
 
@@ -88,12 +88,12 @@ or add, delete values:
 3. Adding and Removing Values
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
-Use ``add:`` and ``del:`` directives for more advanced configuration:
+Use ``set:`` and ``del:`` directives for more advanced configuration:
 
 .. code-block:: bash
 
     # Add a new configuration value
-    fairseq2 lm instruction_finetune <OUTPUT_DIR> --config add:new_param=value
+    fairseq2 lm instruction_finetune <OUTPUT_DIR> --config set:new_param=value
 
     # Remove a configuration value
     fairseq2 lm instruction_finetune <OUTPUT_DIR> --config del:unwanted_param
@@ -110,7 +110,7 @@ You can combine all these methods, with later values taking precedence:
     --config-file override.yaml \
     --config max_num_tokens=512 \
              optimizer_config.lr=4e-5 \
-             add:custom_param=value
+             set:custom_param=value
 
 Asset Management
 ----------------
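
The user-visible change in this file is the directive rename from add: to set:. Below is a minimal sketch of the documented set:/del: semantics on an unstructured, nested-dict config; it illustrates the behavior only and is not fairseq2's actual override parser.

    from typing import Any


    def apply_directive(
        config: dict[str, Any], directive: str, path: str, value: Any = None
    ) -> None:
        """Apply a dotted-path override, mimicking the documented semantics."""
        *parents, leaf = path.split(".")

        node = config
        for key in parents:
            node = node.setdefault(key, {})

        if directive == "set":
            node[leaf] = value  # create the key or overwrite an existing one
        elif directive == "del":
            node.pop(leaf, None)  # drop the key if present


    config: dict[str, Any] = {"common": {"metric_recorders": {}}}

    apply_directive(config, "set", "common.metric_recorders.tensorboard", {"enabled": True})
    apply_directive(config, "del", "common.metric_recorders.tensorboard")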

src/fairseq2/assets/cards/models/llama.yaml (−4)

@@ -75,10 +75,6 @@ use_v2_tokenizer: true
 
 name: llama3_instruct
 base: llama3
-model_config:
-  vocab_info:
-    _set_:
-      eos_idx: 128009 # EOT (end-of-turn)
 use_eot: true # instruct tokenizer to use EOT instead of EOS
 
 ---
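
Dropping the eos_idx override is safe because the instruct tokenizer itself reports EOT as end-of-sequence when use_eot is true. A toy sketch of that selection; the helper and the base EOS id (128001) are assumptions, and only 128009 comes from the diff.

    # Hypothetical illustration, not the fairseq2 tokenizer.
    LLAMA3_EOS_IDX = 128001  # assumed id of the base end-of-text token
    LLAMA3_EOT_IDX = 128009  # end-of-turn id, as in the removed card override


    def effective_eos_idx(use_eot: bool) -> int:
        # With use_eot=true, the tokenizer's vocab_info advertises EOT as EOS,
        # so the model card no longer patches model_config.vocab_info.
        return LLAMA3_EOT_IDX if use_eot else LLAMA3_EOS_IDX


    assert effective_eos_idx(use_eot=True) == 128009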

src/fairseq2/assets/cards/models/s2t_transformer.yaml (+3 −8)

@@ -8,8 +8,7 @@ name: s2t_transformer_mustc_asr_de_s
 model_family: s2t_transformer
 model_arch: small
 model_config:
-  target_vocab_info:
-    size: 5000
+  target_vocab_size: 5000
 task: transcription
 target_langs: [en]
 checkpoint: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_asr_transformer_s.pt"
@@ -23,9 +22,7 @@ name: s2t_transformer_mustc_asr_es_s
 model_family: s2t_transformer
 model_arch: small
 model_config:
-  target_vocab_info:
-    _set_:
-      size: 5000
+  target_vocab_size: 5000
 task: transcription
 target_langs: [en]
 checkpoint: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_es_asr_transformer_s.pt"
@@ -51,9 +48,7 @@ name: s2t_transformer_mustc_st_de_s
 model_family: s2t_transformer
 model_arch: small
 model_config:
-  target_vocab_info:
-    _set_:
-      size: 8000
+  target_vocab_size: 8000
 task: translation
 target_langs: [de]
 checkpoint: "https://dl.fbaipublicfiles.com/fairseq/s2t/mustc_de_st_transformer_s.pt"

src/fairseq2/checkpoint/_metadata_provider.py (+2 −9)

@@ -28,9 +28,7 @@
 
 class CheckpointMetadataSaver(ABC):
     @abstractmethod
-    def save(
-        self, model_family: str, model_config: object, tokenizer_name: str | None = None
-    ) -> None: ...
+    def save(self, model_family: str, model_config: object) -> None: ...
 
 
 @final
@@ -52,9 +50,7 @@ def __init__(
         self._file_system = file_system
         self._yaml_dumper = yaml_dumper
 
-    def save(
-        self, model_family: str, model_config: object, tokenizer_name: str | None = None
-    ) -> None:
+    def save(self, model_family: str, model_config: object) -> None:
         if self._gangs.root.rank == 0:
             unstructured_config = unstructure(model_config)
 
@@ -66,9 +62,6 @@ def save(
                 },
             }
 
-            if tokenizer_name is not None:
-                metadata["tokenizer_ref"] = tokenizer_name
-
             if self._gangs.tp.size != 1:
                 metadata["num_shards"] = self._gangs.tp.size
 
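
With the tokenizer reference gone, checkpoint metadata records only the model family and config, and save loses its tokenizer_name keyword. A hedged usage sketch; saver is assumed to be a fully constructed CheckpointMetadataSaver implementation.

    # Hypothetical call site of the narrowed interface.
    saver.save(model_family="llama", model_config=model_config)

    # Stale call sites fail loudly rather than silently writing tokenizer_ref:
    # saver.save("llama", model_config, tokenizer_name="llama3")  # TypeError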

src/fairseq2/cli/commands/chatbot.py (+29 −18)

@@ -38,7 +38,12 @@
     setup_reference_model,
     setup_torch,
 )
-from fairseq2.recipes.config import CommonSection, GangSection, ReferenceModelSection
+from fairseq2.recipes.config import (
+    CommonSection,
+    GangSection,
+    ReferenceModelSection,
+    TextTokenizerSection,
+)
 from fairseq2.typing import CPU
 from fairseq2.utils.rng import RngBag
 
@@ -113,33 +118,33 @@ def run(
 
         view = CliChatbotView(args.model_name, console)
 
-        args.gang = GangSection(
-            tensor_parallel_size=args.tensor_parallel_size, timeout=999
-        )
+        set_torch_distributed_variables(context, args.cluster)
 
-        args.model = ReferenceModelSection(name=args.model_name)
+        common_section = CommonSection()
 
-        args.common = CommonSection()
+        setup_torch(context, common_section, output_dir=None)
 
-        set_torch_distributed_variables(context, args.cluster)
-
-        setup_torch(context, args)
+        gang_section = GangSection(
+            tensor_parallel_size=args.tensor_parallel_size, timeout=999
+        )
 
         try:
-            gangs = setup_gangs(context, args)
+            gangs = setup_gangs(context, gang_section)
         except RecipeError as ex:
             raise CliCommandError(
                 "The chatbot setup has failed. See the nested exception for details."
            ) from ex
 
         if gangs.dp.size > 1:
-            log.warning("Using redundant data parallelism which may reduce throughput. It is recommended to use one device per model (shard).")  # fmt: skip
+            log.warning("Using redundant data parallelism which may reduce throughput.")  # fmt: skip
+
+        model_section = ReferenceModelSection(name=args.model_name)
 
         try:
             model = setup_reference_model(
                 DecoderModel,
                 context,
-                args.model_name,
+                model_section,
                 gangs,
                 args.dtype,
                 mp=False,
@@ -152,19 +157,25 @@ def run(
 
         module = cast(DecoderModel, model.module)
 
-        sampler = TopPSampler(p=args.top_p)
-
-        generator = SamplingSequenceGenerator(
-            module, sampler, temperature=args.temperature, max_gen_len=args.max_gen_len
-        )
+        tokenizer_section = TextTokenizerSection(name=args.model_name)
 
         try:
-            tokenizer = load_text_tokenizer(context, args)
+            tokenizer = load_text_tokenizer(context, tokenizer_section)
         except RecipeError as ex:
             raise CliCommandError(
                 "The chatbot setup has failed. See the nested exception for details."
             ) from ex
 
+        sampler = TopPSampler(p=args.top_p)
+
+        generator = SamplingSequenceGenerator(
+            module,
+            tokenizer.vocab_info,
+            sampler,
+            temperature=args.temperature,
+            max_gen_len=args.max_gen_len,
+        )
+
         card = context.asset_store.retrieve_card(args.model_name)
 
         family = card.field("model_family").as_(str)
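
Two things change here: configuration travels in explicit section objects instead of being patched onto the argparse namespace, and SamplingSequenceGenerator now receives the tokenizer's vocab_info rather than inferring the vocabulary from the model. A toy, runnable illustration of the injection pattern; the class and ids below are stand-ins, not the fairseq2 API.

    from dataclasses import dataclass


    @dataclass
    class VocabInfo:
        size: int
        eos_idx: int | None


    class ToyGenerator:
        """Stand-in for SamplingSequenceGenerator (illustrative only)."""

        def __init__(self, model: object, vocab_info: VocabInfo) -> None:
            # The vocabulary arrives from the tokenizer, so a model config
            # that no longer embeds vocab metadata can still be decoded.
            self._model = model
            self._eos_idx = vocab_info.eos_idx


    gen = ToyGenerator(object(), VocabInfo(size=128256, eos_idx=128009))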

src/fairseq2/data/text/tokenizers/nllb.py (+2 −2)

@@ -23,7 +23,7 @@
     SentencePieceDecoder,
     SentencePieceEncoder,
     SentencePieceModel,
-    vocab_info_from_sentencepiece,
+    get_sentencepiece_vocabulary_info,
 )
 from fairseq2.typing import Device
 
@@ -63,7 +63,7 @@ def __init__(self, path: Path, langs: Sequence[str], default_lang: str) -> None:
 
         self._default_lang = default_lang
 
-        self._vocab_info = vocab_info_from_sentencepiece(self._model)
+        self._vocab_info = get_sentencepiece_vocabulary_info(self._model)
 
     @override
     def create_encoder(

src/fairseq2/data/text/tokenizers/s2t_transformer.py (+2 −2)

@@ -22,7 +22,7 @@
     SentencePieceDecoder,
     SentencePieceEncoder,
     SentencePieceModel,
-    vocab_info_from_sentencepiece,
+    get_sentencepiece_vocabulary_info,
 )
 from fairseq2.typing import Device
 
@@ -62,7 +62,7 @@ def __init__(
         self._target_langs = target_langs
         self._default_target_lang = default_target_lang
 
-        self._vocab_info = vocab_info_from_sentencepiece(self._model)
+        self._vocab_info = get_sentencepiece_vocabulary_info(self._model)
 
     @override
     def create_encoder(

src/fairseq2/data/text/tokenizers/sentencepiece.py (+3 −3)

@@ -125,7 +125,7 @@ class BasicSentencePieceTokenizer(TextTokenizer):
     def __init__(self, path: Path) -> None:
         self._model = SentencePieceModel(path)
 
-        self._vocab_info = vocab_info_from_sentencepiece(self._model)
+        self._vocab_info = get_sentencepiece_vocabulary_info(self._model)
 
     @override
     def create_encoder(
@@ -216,7 +216,7 @@ class RawSentencePieceTokenizer(TextTokenizer):
     def __init__(self, path: Path) -> None:
         self._model = SentencePieceModel(path)
 
-        self._vocab_info = vocab_info_from_sentencepiece(self._model)
+        self._vocab_info = get_sentencepiece_vocabulary_info(self._model)
 
     @override
     def create_encoder(
@@ -277,7 +277,7 @@ def load_raw_sentencepiece_tokenizer(path: Path, card: AssetCard) -> TextTokenizer:
     ) from ex
 
 
-def vocab_info_from_sentencepiece(model: SentencePieceModel) -> VocabularyInfo:
+def get_sentencepiece_vocabulary_info(model: SentencePieceModel) -> VocabularyInfo:
     """Return the vocabulary information of ``model``."""
     return VocabularyInfo(
         model.vocabulary_size,
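
The rename from vocab_info_from_sentencepiece to get_sentencepiece_vocabulary_info is cosmetic; the call sites in nllb.py and s2t_transformer.py above are updated mechanically. A hedged usage sketch; the module path is inferred from the file location and the model path is a placeholder.

    from pathlib import Path

    from fairseq2.data.text.tokenizers.sentencepiece import (
        SentencePieceModel,
        get_sentencepiece_vocabulary_info,
    )

    # Placeholder path; point this at a real SentencePiece model file.
    model = SentencePieceModel(Path("/path/to/tokenizer.model"))

    vocab_info = get_sentencepiece_vocabulary_info(model)

    print(vocab_info.size, vocab_info.pad_idx)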

src/fairseq2/datasets/_utils.py (+1 −1)

@@ -55,7 +55,7 @@ def _load_files_and_weights(
     manifest_file = path.joinpath("MANIFEST")
 
     try:
-        with manifest_file.open() as fp:
+        with manifest_file.open(encoding="utf-8") as fp:
             content = list(fp)
     except FileNotFoundError:
         content = None
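
The commit also pins file encodings: Path.open() without an encoding argument falls back to the platform's locale encoding (e.g. cp1252 on Windows), so UTF-8 manifests and JSONL files could fail to decode. The same one-line fix recurs in asr.py, instruction.py, preference.py, and parallel_text.py below. A minimal, runnable demonstration:

    from pathlib import Path

    path = Path("manifest.tsv")
    path.write_text("café\t1.0\n", encoding="utf-8")

    # Without encoding=..., this read would use locale.getpreferredencoding(),
    # which is not UTF-8 on every platform and can raise UnicodeDecodeError
    # or silently mojibake the text.
    with path.open(encoding="utf-8") as fp:
        print(fp.read())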

src/fairseq2/datasets/asr.py (+1 −1)

@@ -289,7 +289,7 @@ def _retrieve_data_directory(self, split: str) -> Path:
         manifest_file = self._manifest_dir.joinpath(f"{split}.tsv")
 
         try:
-            with manifest_file.open() as fp:
+            with manifest_file.open(encoding="utf-8") as fp:
                 line = fp.readline().rstrip()
         except OSError as ex:
             raise DataReadError(

src/fairseq2/datasets/instruction.py (+1 −1)

@@ -414,7 +414,7 @@ def _read_jsonl(self, path: Path, tokenizer: TextTokenizer) -> DataPipelineBuilder:
         lines = []
 
         # TODO(balioglu): Do in C++.
-        with path.open() as fp:
+        with path.open(encoding="utf-8") as fp:
             for line in fp:
                 lines.append(line)
 

src/fairseq2/datasets/parallel_text.py (+28 −20)

@@ -14,6 +14,7 @@
 from typing_extensions import override
 
 from fairseq2.data import (
+    CollateOptionsOverride,
     Collater,
     DataPipeline,
     DataPipelineBuilder,
@@ -77,28 +78,23 @@ class ParallelTextDataset(ABC):
     def create_reader(
         self,
         split: str,
-        tokenizer: TextTokenizer,
+        source_tokenizer: TextTokenizer,
+        target_tokenizer: TextTokenizer,
         gang: Gang,
         min_seq_len: int,
         max_seq_len: int,
         options: ParallelTextReadOptions | None = None,
     ) -> DataReader[Seq2SeqBatch]:
         """Create a dataset reader.
 
-        :param split:
-            The split to read.
-        :param tokenizer:
-            The tokenizer to encode text.
-        :param gang:
-            The gang over which to shard the dataset.
-        :param min_seq_len:
-            The minimum sequence length of each example. Examples shorter than
-            this value will be dropped.
-        :param max_seq_len:
-            The maximum sequence length of each example. Examples longer than
-            this value will be dropped.
-        :param options:
-            The read options.
+        :param split: The split to read.
+        :param source_tokenizer: The tokenizer to encode source text.
+        :param gang: The gang over which to shard the dataset.
+        :param min_seq_len: The minimum sequence length of each example.
+            Examples shorter than this value will be dropped.
+        :param max_seq_len: The maximum sequence length of each example.
+            Examples longer than this value will be dropped.
+        :param options: The read options.
         """
 
     @abstractmethod
@@ -171,7 +167,7 @@ def from_path(cls, path: Path, name: str) -> GenericParallelTextDataset:
             manifest_file = path.joinpath(split).joinpath("MANIFEST")
 
             try:
-                with manifest_file.open() as fp:
+                with manifest_file.open(encoding="utf-8") as fp:
                     content = list(fp)
             except OSError as ex:
                 raise DatasetLoadError(
@@ -252,7 +248,8 @@ def value_error() -> ValueError:
     def create_reader(
         self,
         split: str,
-        tokenizer: TextTokenizer,
+        source_tokenizer: TextTokenizer,
+        target_tokenizer: TextTokenizer,
         gang: Gang,
         min_seq_len: int,
         max_seq_len: int,
@@ -289,11 +286,11 @@ def create_reader(
         if direction.origin:
             source_mode = f"{source_mode}_{direction.origin}"
 
-        source_encoder = tokenizer.create_encoder(
+        source_encoder = source_tokenizer.create_encoder(
             task="translation", lang=direction.source_lang, mode=source_mode
         )
 
-        target_encoder = tokenizer.create_encoder(
+        target_encoder = target_tokenizer.create_encoder(
             task="translation", lang=direction.target_lang, mode="target"
         )
 
@@ -384,7 +381,18 @@ def skip(example: dict[str, Any]) -> bool:
         seed += 1
 
         # Collate bucketed examples into a batch.
-        collater = Collater(pad_value=tokenizer.vocab_info.pad_idx)
+        collater = Collater(
+            overrides=[
+                CollateOptionsOverride(
+                    selector="source_indices",
+                    pad_value=source_tokenizer.vocab_info.pad_idx,
+                ),
+                CollateOptionsOverride(
+                    selector="target_indices",
+                    pad_value=target_tokenizer.vocab_info.pad_idx,
+                ),
+            ]
+        )
 
         builder.map(collater, num_parallel_calls=npc)
 
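
Since source and target sides may now come from different tokenizers with different pad_idx values, a single pad_value no longer suffices; CollateOptionsOverride pads each field with its own index. A toy, plain-Python illustration of per-field padding (not the fairseq2 Collater):

    def pad(seqs: list[list[int]], pad_value: int) -> list[list[int]]:
        width = max(len(s) for s in seqs)
        return [s + [pad_value] * (width - len(s)) for s in seqs]


    examples = [
        {"source_indices": [5, 7], "target_indices": [9]},
        {"source_indices": [5], "target_indices": [9, 11, 13]},
    ]

    # The two vocabularies may disagree on pad_idx (say, 0 vs. 3), which is
    # exactly why the collater now takes per-selector overrides.
    batch = {
        "source_indices": pad([e["source_indices"] for e in examples], pad_value=0),
        "target_indices": pad([e["target_indices"] for e in examples], pad_value=3),
    }

    print(batch)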

src/fairseq2/datasets/preference.py (+1 −1)

@@ -377,7 +377,7 @@ def _read_jsonl(self, path: Path, tokenizer: TextTokenizer) -> DataPipelineBuilder:
         lines = []
 
         # TODO(balioglu): Do in C++.
-        with path.open() as fp:
+        with path.open(encoding="utf-8") as fp:
             for line in fp:
                 lines.append(line)
 
