diff --git a/transformers/runtime/run_nl2cst_transformers.py b/transformers/runtime/run_nl2cst_transformers.py index ec5dd59..60a012c 100644 --- a/transformers/runtime/run_nl2cst_transformers.py +++ b/transformers/runtime/run_nl2cst_transformers.py @@ -9,7 +9,7 @@ from tqdm import tqdm import traceback -from transformers import EncoderDecoderModel, BertTokenizerFast +from transformers import EncoderDecoderModel, BertTokenizerFast, GenerationConfig from arsenal_tokenizer import PreTrainedArsenalTokenizer @@ -28,7 +28,7 @@ def get_env(argname, default=None): return default def get_bool_opt(var_name : str): - if var_name in os.environ and int(get_env(var_name)) is not 0: + if var_name in os.environ and int(get_env(var_name)) != 0: return True else: return False @@ -82,7 +82,7 @@ def process(): data["sentences"] = data.pop("msg") - global bert2arsenal + global bert2arsenal, generation_config if 'checkpoint_id' in data: ckpt_id = data['checkpoint_id'] print(f"using {ckpt_id} to generate translations") @@ -90,6 +90,20 @@ def process(): else: bert2arsenal = EncoderDecoderModel.from_pretrained(MODEL_ROOT) + generation_config = GenerationConfig( + num_beams=NUM_BEAMS, + num_return_sequences=NUM_OUTPUTS, + no_repeat_ngram_size=0, + decoder_start_token_id=target_tokenizer.cls_token_id, + eos_token_id=bert2arsenal.config.eos_token_id, + pad_token_id=bert2arsenal.config.pad_token_id, + max_new_tokens=bert2arsenal.config.max_length, + output_scores=True, + return_dict_in_generate=True, + early_stopping=False, + length_penalty=1.0 + ) + print(f'Initializing {N_WORKERS} workers with chunksize {BATCH_SIZE}') # Instantiate queues and worker processes @@ -137,13 +151,7 @@ def process_batch(data): generated = bert2arsenal.generate( input_ids=input_tokens.input_ids, attention_mask=input_tokens.attention_mask, - decoder_start_token_id=target_tokenizer.cls_token_id, - num_beams=NUM_BEAMS, - num_return_sequences=NUM_OUTPUTS, - type_forcing_vocab=type_forcing_vocab, - no_repeat_ngram_size=0, # default was 3, but this punishes desired translations -> figure out if/what setting we want to use here - output_scores = True, - return_dict_in_generate=True, + generation_config=generation_config ) # iterate over instances in batch to prepare outputs