Skip to content

Commit ab2f103

Browse files
authored
Prevent model from being loaded multiple times if translate method is called multiple times (#801)
* cache model during translate step
* add clear_cache; add dataclass; rename variables to be more precise
* add clear_cache to NMTModel interface
* fix typo from rebase
1 parent 6d64de6 commit ab2f103

File tree

3 files changed

+206
-172
lines changed

3 files changed

+206
-172
lines changed

silnlp/nmt/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ def translate(
9191
@abstractmethod
9292
def get_checkpoint_path(self, ckpt: Union[CheckpointType, str, int]) -> Tuple[Path, int]: ...
9393

94+
@abstractmethod
95+
def clear_cache(self) -> None: ...
96+
9497

9598
class Config(ABC):
9699
def __init__(self, exp_dir: Path, config: dict) -> None:

silnlp/nmt/hugging_face_config.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import shutil
77
from contextlib import ExitStack
88
from copy import deepcopy
9+
from dataclasses import dataclass
910
from enum import Enum
1011
from itertools import repeat
1112
from math import exp, prod
@@ -832,6 +833,20 @@ def get_token_scores(self) -> List[float]:
832833
def get_sequence_score(self) -> List[float]:
833834
return [output["sequence_score"] for output in self.outputs]
834835

836+
@dataclass
837+
class InferenceModelParams:
838+
checkpoint: Union[CheckpointType, str, int]
839+
src_lang: str
840+
trg_lang: str
841+
842+
def __post_init__(self):
843+
if not isinstance(self.checkpoint, (CheckpointType, str, int)):
844+
raise ValueError("checkpoint must be a CheckpointType, string, or integer")
845+
if not isinstance(self.src_lang, str):
846+
raise ValueError("src_lang must be a string")
847+
if not isinstance(self.trg_lang, str):
848+
raise ValueError("trg_lang must be a string")
849+
835850

836851
class HuggingFaceNMTModel(NMTModel):
837852
def __init__(self, config: HuggingFaceConfig, mixed_precision: bool, num_devices: int) -> None:
@@ -841,6 +856,8 @@ def __init__(self, config: HuggingFaceConfig, mixed_precision: bool, num_devices
841856
self._dictionary: Optional[Dict[VerseRef, Set[str]]] = None
842857
self._is_t5 = self._config.model_prefix in SUPPORTED_T5_MODELS
843858
self._num_devices = num_devices
859+
self._cached_inference_model: Optional[PreTrainedModel] = None
860+
self._inference_model_params: Optional[InferenceModelParams] = None
844861

845862
def train(self) -> None:
846863
training_args = self._create_training_arguments()
@@ -1218,8 +1235,13 @@ def translate(
12181235
) -> Iterable[TranslationGroup]:
12191236
src_lang = self._config.data["lang_codes"].get(src_iso, src_iso)
12201237
trg_lang = self._config.data["lang_codes"].get(trg_iso, trg_iso)
1238+
inference_model_params = InferenceModelParams(ckpt, src_lang, trg_lang)
12211239
tokenizer = self._config.get_tokenizer()
1222-
model = self._create_inference_model(ckpt, tokenizer, src_lang, trg_lang)
1240+
if self._inference_model_params == inference_model_params and self._cached_inference_model is not None:
1241+
model = self._cached_inference_model
1242+
else:
1243+
model = self._cached_inference_model = self._create_inference_model(ckpt, tokenizer, src_lang, trg_lang)
1244+
self._inference_model_params = inference_model_params
12231245
if model.config.max_length is not None and model.config.max_length < 512:
12241246
model.config.max_length = 512
12251247

@@ -1304,6 +1326,10 @@ def get_checkpoint_path(self, ckpt: Union[CheckpointType, str, int]) -> Tuple[Pa
13041326
raise ValueError(f"Unsupported checkpoint type: {ckpt}.")
13051327
return ckpt_path, step
13061328

1329+
def clear_cache(self) -> None:
1330+
self._cached_inference_model = None
1331+
self._inference_model_params = None
1332+
13071333
def _create_training_arguments(self) -> Seq2SeqTrainingArguments:
13081334
parser = HfArgumentParser(Seq2SeqTrainingArguments)
13091335
args: dict = {}

0 commit comments

Comments (0)