Commit
updated nl2cst runtime to match latest transformers
kmartiny committed Dec 8, 2023
1 parent 948beaf commit 1daa859
Showing 1 changed file with 18 additions and 10 deletions.
28 changes: 18 additions & 10 deletions transformers/runtime/run_nl2cst_transformers.py
@@ -9,7 +9,7 @@
from tqdm import tqdm
import traceback

-from transformers import EncoderDecoderModel, BertTokenizerFast
+from transformers import EncoderDecoderModel, BertTokenizerFast, GenerationConfig

from arsenal_tokenizer import PreTrainedArsenalTokenizer
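
Note: GenerationConfig is only exposed from the top-level transformers namespace in newer releases, so an older installation fails at this import. The guard below is purely illustrative and not part of this commit; the error message wording is an assumption.

# Illustration only: fail early with a clear message on transformers
# releases that do not yet ship GenerationConfig.
try:
    from transformers import GenerationConfig
except ImportError as exc:
    raise RuntimeError(
        "this runtime needs a transformers release that provides "
        "GenerationConfig; please upgrade the transformers package"
    ) from exc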

@@ -28,7 +28,7 @@ def get_env(argname, default=None):
return default

def get_bool_opt(var_name : str):
-if var_name in os.environ and int(get_env(var_name)) is not 0:
+if var_name in os.environ and int(get_env(var_name)) != 0:
return True
else:
return False
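
Note: the change above is a genuine bug fix, not a style tweak. "is not" compares object identity rather than value, so "int(...) is not 0" only behaves as intended because CPython happens to cache small integers, and Python 3.8+ emits a SyntaxWarning for "is" comparisons against literals. A standalone illustration, with made-up values that are not part of the runtime:

# Illustration only: identity vs. value comparison for ints.
x = int("1000") + 1      # 1001, a fresh int object built at runtime
print(x == 1001)         # True: value comparison, what "==" / "!=" check
print(x is 1001)         # False here (identity check), plus a SyntaxWarning on 3.8+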
@@ -82,14 +82,28 @@ def process():
data["sentences"] = data.pop("msg")


-global bert2arsenal
+global bert2arsenal, generation_config
if 'checkpoint_id' in data:
ckpt_id = data['checkpoint_id']
print(f"using {ckpt_id} to generate translations")
bert2arsenal = EncoderDecoderModel.from_pretrained(os.path.join(MODEL_ROOT, ckpt_id))
else:
bert2arsenal = EncoderDecoderModel.from_pretrained(MODEL_ROOT)

+generation_config = GenerationConfig(
+num_beams=NUM_BEAMS,
+num_return_sequences=NUM_OUTPUTS,
+no_repeat_ngram_size=0,
+decoder_start_token_id=target_tokenizer.cls_token_id,
+eos_token_id=bert2arsenal.config.eos_token_id,
+pad_token_id=bert2arsenal.config.pad_token_id,
+max_new_tokens=bert2arsenal.config.max_length,
+output_scores=True,
+return_dict_in_generate=True,
+early_stopping=False,
+length_penalty=1.0
+)

print(f'Initializing {N_WORKERS} workers with chunksize {BATCH_SIZE}')

# Instantiate queues and worker processes
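
Note: the block added above follows the pattern recent transformers releases encourage, with decoding parameters collected in a GenerationConfig object instead of being passed to generate() as ad-hoc keyword arguments or read off model.config. One practical consequence, sketched below with placeholder values and paths that are not the runtime's real settings, is that the configuration can be saved next to a checkpoint and reloaded later:

# Sketch with placeholder values/paths (not this runtime's actual configuration).
from transformers import GenerationConfig

gen_cfg = GenerationConfig(
    num_beams=4,
    num_return_sequences=2,
    output_scores=True,
    return_dict_in_generate=True,
)
gen_cfg.save_pretrained("path/to/checkpoint")             # writes generation_config.json
reloaded = GenerationConfig.from_pretrained("path/to/checkpoint")
assert reloaded.num_beams == 4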
@@ -137,13 +151,7 @@ def process_batch(data):
generated = bert2arsenal.generate(
input_ids=input_tokens.input_ids,
attention_mask=input_tokens.attention_mask,
-decoder_start_token_id=target_tokenizer.cls_token_id,
-num_beams=NUM_BEAMS,
-num_return_sequences=NUM_OUTPUTS,
-type_forcing_vocab=type_forcing_vocab,
-no_repeat_ngram_size=0, # default was 3, but this punishes desired translations -> figure out if/what setting we want to use here
-output_scores = True,
-return_dict_in_generate=True,
+generation_config=generation_config
)

# iterate over instances in batch to prepare outputs
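
Note: because the configuration keeps output_scores=True and return_dict_in_generate=True, generate() does not return a plain tensor; under beam search the result carries the generated token ids in .sequences and one aggregate score per returned sequence in .sequences_scores. The actual post-processing code sits outside this diff; the lines below are only a rough sketch of how such an output could be unpacked, assuming target_tokenizer provides the usual batch_decode method.

# Rough sketch only; the repository's real output preparation is not shown here.
# NUM_OUTPUTS sequences are returned per input sentence.
sequences = generated.sequences          # token ids, one row per returned beam
scores = generated.sequences_scores      # final beam score per returned sequence
decoded = target_tokenizer.batch_decode(sequences, skip_special_tokens=True)
for text, score in zip(decoded, scores.tolist()):
    print(f"{score:.3f}\t{text}")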