Skip to content

Commit e34a157

Browse files
author
Ben King
committed
Wait to wrap the tokenizer with Moses punctuation normalizer until it's been fully initialized
1 parent 2669e2f commit e34a157

File tree

1 file changed

+29
-6
lines changed

1 file changed

+29
-6
lines changed

silnlp/nmt/hugging_face_config.py

Lines changed: 29 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -9,7 +9,19 @@
99
from enum import Enum
1010
from itertools import repeat
1111
from pathlib import Path
12-
from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast
12+
from typing import (
13+
Any,
14+
Callable,
15+
Dict,
16+
Iterable,
17+
List,
18+
Optional,
19+
Set,
20+
Tuple,
21+
TypeVar,
22+
Union,
23+
cast,
24+
)
1325

1426
import datasets.utils.logging as datasets_logging
1527
import evaluate
@@ -24,7 +36,10 @@
2436
from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef
2537
from sacremoses import MosesPunctNormalizer
2638
from tokenizers import AddedToken, NormalizedString, Regex
27-
from tokenizers.implementations import SentencePieceBPETokenizer, SentencePieceUnigramTokenizer
39+
from tokenizers.implementations import (
40+
SentencePieceBPETokenizer,
41+
SentencePieceUnigramTokenizer,
42+
)
2843
from tokenizers.normalizers import Normalizer
2944
from torch import Tensor, TensorType, nn, optim
3045
from torch.utils.data import Sampler
@@ -73,7 +88,14 @@
7388
from ..common.corpus import Term, count_lines, get_terms
7489
from ..common.environment import SIL_NLP_ENV, download_if_s3_paths
7590
from ..common.translator import DraftGroup, TranslationGroup
76-
from ..common.utils import NoiseMethod, ReplaceRandomToken, Side, create_noise_methods, get_mt_exp_dir, merge_dict
91+
from ..common.utils import (
92+
NoiseMethod,
93+
ReplaceRandomToken,
94+
Side,
95+
create_noise_methods,
96+
get_mt_exp_dir,
97+
merge_dict,
98+
)
7799
from .config import CheckpointType, Config, DataFile, NMTModel
78100
from .tokenizer import NullTokenizer, Tokenizer
79101

@@ -1185,13 +1207,14 @@ def translate(
11851207
ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST,
11861208
) -> Iterable[TranslationGroup]:
11871209
tokenizer = self._config.get_tokenizer()
1188-
if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
1189-
tokenizer = PunctuationNormalizingTokenizer(tokenizer)
1190-
11911210
model = self._create_inference_model(ckpt, tokenizer)
11921211
if model.config.max_length is not None and model.config.max_length < 512:
11931212
model.config.max_length = 512
11941213
lang_codes: Dict[str, str] = self._config.data["lang_codes"]
1214+
1215+
if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)):
1216+
tokenizer = PunctuationNormalizingTokenizer(tokenizer)
1217+
11951218
pipeline = TranslationPipeline(
11961219
model=model,
11971220
tokenizer=tokenizer,

0 commit comments

Comments (0)