66import shutil
77from contextlib import ExitStack
88from copy import deepcopy
9+ from dataclasses import dataclass
910from enum import Enum
1011from itertools import repeat
1112from math import exp , prod
@@ -832,6 +833,20 @@ def get_token_scores(self) -> List[float]:
def get_sequence_score(self) -> List[float]:
    """Return the sequence-level score of each output, in output order.

    Reads the "sequence_score" entry of every element of ``self.outputs``.
    """
    scores: List[float] = []
    for candidate in self.outputs:
        scores.append(candidate["sequence_score"])
    return scores
834835
@dataclass
class InferenceModelParams:
    """Value object identifying an inference model configuration.

    Compared by value (dataclass-generated ``__eq__``), which lets callers
    detect whether a previously built inference model can be reused for the
    same checkpoint and language pair.
    """

    # Checkpoint selector: a CheckpointType, a named checkpoint, or a step number.
    checkpoint: Union[CheckpointType, str, int]
    # Source language code.
    src_lang: str
    # Target language code.
    trg_lang: str

    def __post_init__(self):
        # Fail fast on wrongly-typed fields instead of at model-load time.
        # str/int are tested first; only non-trivial selectors fall through
        # to the CheckpointType check.
        if not (isinstance(self.checkpoint, (str, int)) or isinstance(self.checkpoint, CheckpointType)):
            raise ValueError("checkpoint must be a CheckpointType, string, or integer")
        for field_name, value in (("src_lang", self.src_lang), ("trg_lang", self.trg_lang)):
            if not isinstance(value, str):
                raise ValueError(f"{field_name} must be a string")
835850
836851class HuggingFaceNMTModel (NMTModel ):
837852 def __init__ (self , config : HuggingFaceConfig , mixed_precision : bool , num_devices : int ) -> None :
@@ -841,6 +856,8 @@ def __init__(self, config: HuggingFaceConfig, mixed_precision: bool, num_devices
841856 self ._dictionary : Optional [Dict [VerseRef , Set [str ]]] = None
842857 self ._is_t5 = self ._config .model_prefix in SUPPORTED_T5_MODELS
843858 self ._num_devices = num_devices
859+ self ._cached_inference_model : Optional [PreTrainedModel ] = None
860+ self ._inference_model_params : Optional [InferenceModelParams ] = None
844861
845862 def train (self ) -> None :
846863 training_args = self ._create_training_arguments ()
@@ -1218,8 +1235,13 @@ def translate(
12181235 ) -> Iterable [TranslationGroup ]:
12191236 src_lang = self ._config .data ["lang_codes" ].get (src_iso , src_iso )
12201237 trg_lang = self ._config .data ["lang_codes" ].get (trg_iso , trg_iso )
1238+ inference_model_params = InferenceModelParams (ckpt , src_lang , trg_lang )
12211239 tokenizer = self ._config .get_tokenizer ()
1222- model = self ._create_inference_model (ckpt , tokenizer , src_lang , trg_lang )
1240+ if self ._inference_model_params == inference_model_params and self ._cached_inference_model is not None :
1241+ model = self ._cached_inference_model
1242+ else :
1243+ model = self ._cached_inference_model = self ._create_inference_model (ckpt , tokenizer , src_lang , trg_lang )
1244+ self ._inference_model_params = inference_model_params
12231245 if model .config .max_length is not None and model .config .max_length < 512 :
12241246 model .config .max_length = 512
12251247
@@ -1304,6 +1326,10 @@ def get_checkpoint_path(self, ckpt: Union[CheckpointType, str, int]) -> Tuple[Pa
13041326 raise ValueError (f"Unsupported checkpoint type: { ckpt } ." )
13051327 return ckpt_path , step
13061328
def clear_cache(self) -> None:
    """Drop the cached inference model and its identifying parameters.

    After this call, the next translation request will rebuild the
    inference model instead of reusing a cached one.
    """
    # Reset both fields so the cache can never be half-valid: the params
    # are cleared along with the model they describe.
    self._inference_model_params = None
    self._cached_inference_model = None
1332+
13071333 def _create_training_arguments (self ) -> Seq2SeqTrainingArguments :
13081334 parser = HfArgumentParser (Seq2SeqTrainingArguments )
13091335 args : dict = {}
0 commit comments