|
9 | 9 | from enum import Enum |
10 | 10 | from itertools import repeat |
11 | 11 | from pathlib import Path |
12 | | -from typing import Any, Callable, Dict, Iterable, List, Optional, Set, Tuple, TypeVar, Union, cast |
| 12 | +from typing import ( |
| 13 | + Any, |
| 14 | + Callable, |
| 15 | + Dict, |
| 16 | + Iterable, |
| 17 | + List, |
| 18 | + Optional, |
| 19 | + Set, |
| 20 | + Tuple, |
| 21 | + TypeVar, |
| 22 | + Union, |
| 23 | + cast, |
| 24 | +) |
13 | 25 |
|
14 | 26 | import datasets.utils.logging as datasets_logging |
15 | 27 | import evaluate |
|
24 | 36 | from machine.scripture import ORIGINAL_VERSIFICATION, VerseRef |
25 | 37 | from sacremoses import MosesPunctNormalizer |
26 | 38 | from tokenizers import AddedToken, NormalizedString, Regex |
27 | | -from tokenizers.implementations import SentencePieceBPETokenizer, SentencePieceUnigramTokenizer |
| 39 | +from tokenizers.implementations import ( |
| 40 | + SentencePieceBPETokenizer, |
| 41 | + SentencePieceUnigramTokenizer, |
| 42 | +) |
28 | 43 | from tokenizers.normalizers import Normalizer |
29 | 44 | from torch import Tensor, TensorType, nn, optim |
30 | 45 | from torch.utils.data import Sampler |
|
73 | 88 | from ..common.corpus import Term, count_lines, get_terms |
74 | 89 | from ..common.environment import SIL_NLP_ENV |
75 | 90 | from ..common.translator import DraftGroup, TranslationGroup |
76 | | -from ..common.utils import NoiseMethod, ReplaceRandomToken, Side, create_noise_methods, get_mt_exp_dir, merge_dict |
| 91 | +from ..common.utils import ( |
| 92 | + NoiseMethod, |
| 93 | + ReplaceRandomToken, |
| 94 | + Side, |
| 95 | + create_noise_methods, |
| 96 | + get_mt_exp_dir, |
| 97 | + merge_dict, |
| 98 | +) |
77 | 99 | from .config import CheckpointType, Config, DataFile, NMTModel |
78 | 100 | from .tokenizer import NullTokenizer, Tokenizer |
79 | 101 |
|
@@ -1185,13 +1207,16 @@ def translate( |
1185 | 1207 | ckpt: Union[CheckpointType, str, int] = CheckpointType.LAST, |
1186 | 1208 | ) -> Iterable[TranslationGroup]: |
1187 | 1209 | tokenizer = self._config.get_tokenizer() |
1188 | | - if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)): |
1189 | | - tokenizer = PunctuationNormalizingTokenizer(tokenizer) |
1190 | | - |
1191 | 1210 | model = self._create_inference_model(ckpt, tokenizer) |
1192 | 1211 | if model.config.max_length is not None and model.config.max_length < 512: |
1193 | 1212 | model.config.max_length = 512 |
1194 | 1213 | lang_codes: Dict[str, str] = self._config.data["lang_codes"] |
| 1214 | + |
| 1215 | + # The tokenizer isn't wrapped until after calling _create_inference_model, |
| 1216 | + # because the tokenizer's input/output language codes are set there |
| 1217 | + if isinstance(tokenizer, (NllbTokenizer, NllbTokenizerFast)): |
| 1218 | + tokenizer = PunctuationNormalizingTokenizer(tokenizer) |
| 1219 | + |
1195 | 1220 | pipeline = TranslationPipeline( |
1196 | 1221 | model=model, |
1197 | 1222 | tokenizer=tokenizer, |
|
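Note on the reordering in the hunk above: the following is a minimal, hypothetical sketch (NaivePunctWrapper is a stand-in, not the repository's PunctuationNormalizingTokenizer) of why the language codes assigned inside _create_inference_model should reach the underlying NLLB tokenizer before it is wrapped. If a wrapper only forwards attribute reads, assignments made after wrapping would land on the wrapper instead of the real tokenizer.

    # Hedged illustration only; wrapper internals are assumed, not taken from the repo.
    class NaivePunctWrapper:
        """Hypothetical delegating wrapper around an NLLB tokenizer."""

        def __init__(self, inner):
            self.inner = inner

        def __getattr__(self, name):
            # Called only when normal lookup fails, so attribute *reads*
            # (e.g. wrapper.src_lang) fall through to the wrapped tokenizer.
            return getattr(self.inner, name)

    # Ordering pitfall this change avoids (illustrative values):
    # wrapped = NaivePunctWrapper(nllb_tokenizer)
    # wrapped.src_lang = "eng_Latn"   # stored on the wrapper only...
    # nllb_tokenizer.src_lang         # ...the real tokenizer keeps its old code.
    #
    # Wrapping only after _create_inference_model has set the codes on the
    # real tokenizer sidesteps this.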