31 changes: 30 additions & 1 deletion README.md
@@ -1,2 +1,31 @@
# machine-translation
Replication of the "Attention Is All you Need" machine translation model using AttentionSmithy.
Replication of the "[Attention Is All you Need](https://arxiv.org/abs/1706.03762)" machine translation model using [AttentionSmithy](https://github.com/xomicsdatascience/AttentionSmithy), a package for creating transformer models.

# Main Files
## scripts/0_data_prep.py
This file downloads the WMT-14 German-English dataset and processes it for loading into the model. This is also where the train/val/test split occurs.

Each dataset (train/val/test) consists of two files, one for English (en) and one for German (de), matched by line index. For example, line 5 of `train_en.txt` is the English translation of the German sentence on line 5 of `train_de.txt`.

The downloaded dataset consists of sentences. This script converts each sentence into tokens, then writes them as a comma-delimited line to the relevant file.
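
For illustration, a minimal sketch of that tokenize-and-write step might look like the following. The tokenizer matches the English tokenizer used elsewhere in this repository; the helper name and exact flow are assumptions rather than the script's actual code.

```python
from transformers import AutoTokenizer

def write_tokenized_lines(sentences, tokenizer, output_path):
    """Tokenize each sentence and write its token IDs as one comma-delimited line."""
    with open(output_path, 'w') as f:
        for sentence in sentences:
            token_ids = tokenizer.encode(sentence)  # includes special tokens such as [CLS]/[SEP]
            f.write(','.join(str(token_id) for token_id in token_ids) + '\n')

# Example usage with the English tokenizer used by the evaluation callback
en_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
write_tokenized_lines(["Hello world.", "How are you?"], en_tokenizer, 'train_en.txt')
```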

## scripts/1_train_model.py
This script has much in common with `scripts/model_script_for_nas.py`, which was written specifically for use with a neural architecture search (NAS). It assembles and trains the machine translation model. The script accepts several arguments; below is an example usage.

`python 1_train_model.py --loss_type custom --label_smoothing 0.9 --embed_dim 512 --dim_feedforward 2048 --number_of_layers=6`

## src/machine_translation/MachineTranslationModel.py
The code for the machine translation model. It is written using PyTorch Lightning for readability, and therefore outlines the construction of the model, the forward pass, and how those are used in the training and validation steps.
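
As a rough orientation, a simplified and self-contained skeleton of such a Lightning module is sketched below. The class name, internal layers, and optimizer are illustrative assumptions; the actual `MachineTranslationModel` wires in AttentionSmithy components and takes many more constructor arguments (see the training scripts for the full list).

```python
import torch
import torch.nn as nn
import pytorch_lightning as pl

class TranslationModelSketch(pl.LightningModule):
    """Hypothetical skeleton; attention masks and learning-rate scheduling omitted for brevity."""
    def __init__(self, src_vocab_size, tgt_vocab_size, embed_dim=512, pad_token_id=0):
        super().__init__()
        self.src_embedding = nn.Embedding(src_vocab_size, embed_dim)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, embed_dim)
        self.transformer = nn.Transformer(d_model=embed_dim, batch_first=True)
        self.output_projection = nn.Linear(embed_dim, tgt_vocab_size)
        self.loss_fn = nn.CrossEntropyLoss(ignore_index=pad_token_id)

    def forward(self, src, tgt):
        hidden = self.transformer(self.src_embedding(src), self.tgt_embedding(tgt))
        return self.output_projection(hidden)

    def _loss(self, batch):
        src, tgt = batch
        logits = self(src, tgt[:, :-1])  # teacher forcing: predict the next target token
        return self.loss_fn(logits.reshape(-1, logits.size(-1)), tgt[:, 1:].reshape(-1))

    def training_step(self, batch, batch_idx):
        loss = self._loss(batch)
        self.log('train_loss', loss)
        return loss

    def validation_step(self, batch, batch_idx):
        self.log('val_loss', self._loss(batch))

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-4)
```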

## src/machine_translation/data/MachineTranslationDataModule.py
The code for the data module used in training and validating the machine translation model. It is designed to be used with the PyTorch Lightning `Trainer` class, as called in the model training scripts.
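
A minimal sketch of how such a data module plugs into the Lightning workflow is shown below, using stand-in random tensors. The real `MachineTranslationDataModule` instead builds its datasets from the token files produced by `scripts/0_data_prep.py` and exposes details such as vocabulary sizes and padding tokens.

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset

class TranslationDataModuleSketch(pl.LightningDataModule):
    """Hypothetical sketch; the real data module reads the prepared token files."""
    def __init__(self, batch_size=64):
        super().__init__()
        self.batch_size = batch_size

    def setup(self, stage=None):
        # Stand-in random data so the sketch runs; real datasets come from the train/val/test files.
        src = torch.randint(0, 1000, (256, 20))
        tgt = torch.randint(0, 1000, (256, 20))
        self.train_dataset = TensorDataset(src, tgt)
        self.val_dataset = TensorDataset(src, tgt)

    def train_dataloader(self):
        return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True)

    def val_dataloader(self):
        return DataLoader(self.val_dataset, batch_size=self.batch_size)
```

A `pytorch_lightning.Trainer` is then given the model and the data module via `trainer.fit(model, data_module)`, as the training scripts do.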

# Additional Files for Interested Readers
## scripts/run_nas.py
This code runs a neural architecture search (NAS). It is based on the [Multi-Objective NAS with Ax](https://pytorch.org/tutorials/intermediate/ax_multiobjective_nas_tutorial.html) tutorial, and calls `scripts/model_script_for_nas.py` on each trial with new parameters selected during the search.
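
Purely to illustrate how a trial hands new parameters to the model script, a hypothetical launcher is sketched below. The argument names mirror those in the training scripts, but the actual `run_nas.py` drives its trials through Ax rather than a hand-rolled loop like this.

```python
import subprocess
import sys

def run_trial(parameters):
    # Launch one NAS trial: invoke the model script with this trial's parameter choices.
    command = [
        sys.executable, 'scripts/model_script_for_nas.py',
        '--embedding_dimension', str(parameters['embedding_dimension']),
        '--feedforward_dimension', str(parameters['feedforward_dimension']),
        '--number_of_layers', str(parameters['number_of_layers']),
    ]
    subprocess.run(command, check=True)

run_trial({'embedding_dimension': 512, 'feedforward_dimension': 2048, 'number_of_layers': 6})
```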

## src/machine_translation/data/LineIndexDataset.py
This code is used to extract specific lines from the train, val, or test datasets when forming a batch. Using this class allows the user to reference data efficiently without holding the entire dataset in memory.
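
One common way to do this is to record the byte offset of each line once, then seek directly to the requested line when a sample is needed. The sketch below illustrates that pattern under the assumption of comma-delimited token IDs per line; it is not the actual `LineIndexDataset` implementation.

```python
import torch
from torch.utils.data import Dataset

class LineIndexDatasetSketch(Dataset):
    """Hypothetical sketch: index a token file by line without loading it all into memory."""
    def __init__(self, filepath):
        self.filepath = filepath
        self.offsets = []
        offset = 0
        with open(filepath, 'rb') as f:
            for line in f:
                self.offsets.append(offset)
                offset += len(line)

    def __len__(self):
        return len(self.offsets)

    def __getitem__(self, index):
        with open(self.filepath, 'rb') as f:
            f.seek(self.offsets[index])
            line = f.readline().decode('utf-8').strip()
        token_ids = [int(token) for token in line.split(',')]  # comma-delimited token IDs
        return torch.tensor(token_ids, dtype=torch.long)
```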

## src/machine_translation/data/LengthBatchSampler.py
This code groups samples of similar context window length into the same batch, which reduces padding and makes training more efficient. A similar strategy was employed in the original Attention Is All You Need paper.
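
The core idea can be sketched as a batch sampler that sorts sample indices by length and cuts the sorted order into batches, so each batch holds sequences of similar length and spends little compute on padding. The actual `LengthBatchSampler` may shuffle buckets or differ in other details.

```python
from torch.utils.data import Sampler

class LengthBucketingBatchSampler(Sampler):
    """Hypothetical sketch: yield batches of indices whose samples have similar lengths."""
    def __init__(self, lengths, batch_size):
        self.lengths = lengths
        self.batch_size = batch_size

    def __iter__(self):
        # Sort sample indices by length, then chunk the sorted order into batches.
        sorted_indices = sorted(range(len(self.lengths)), key=lambda i: self.lengths[i])
        for start in range(0, len(sorted_indices), self.batch_size):
            yield sorted_indices[start:start + self.batch_size]

    def __len__(self):
        return (len(self.lengths) + self.batch_size - 1) // self.batch_size
```

Such a sampler would be handed to a `DataLoader` through its `batch_sampler` argument.
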
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -23,7 +23,7 @@ dependencies = [
'scikit-learn>=1.1.2',
'pandas>=1.2.2',
'sacrebleu>=2.4.3',
'wandb>=0.19.0',
'wandb>=0.19.2',
]

[tool.setuptools.packages.find]
6 changes: 3 additions & 3 deletions scripts/1_train_model.py
@@ -6,7 +6,7 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from machine_translation import MachineTranslationModel
from machine_translation.data import MachineTranslationDataModule
from attention_smithy.numeric_embeddings import SinusoidalPositionEmbedding, NumericEmbeddingFacade
from attention_smithy.numeric_embeddings import SinusoidalPositionEmbedding, NumericEmbeddingManager
from attention_smithy.components import MultiheadAttention, FeedForwardNetwork
from attention_smithy.attention import StandardAttentionMethod
from attention_smithy.utils import seed_everything
@@ -64,7 +64,7 @@ def train_model(
)

sinusoidal_position_embedding = SinusoidalPositionEmbedding(embed_dim)
numeric_embedding_facade = NumericEmbeddingFacade(sinusoidal_position=sinusoidal_position_embedding)
numeric_embedding_manager = NumericEmbeddingManager(sinusoidal_position=sinusoidal_position_embedding)
generic_attention = MultiheadAttention(embedding_dimension = embed_dim, number_of_heads = num_heads, attention_method = StandardAttentionMethod(dropout))
decoder_self_attention = MultiheadAttention(embedding_dimension = embed_dim, number_of_heads = num_heads, attention_method = StandardAttentionMethod(dropout, is_causal_masking=True))
feedforward_network = FeedForwardNetwork(embed_dim, dim_feedforward, 'relu', dropout)
@@ -76,7 +76,7 @@
decoder_self_attention=decoder_self_attention,
decoder_cross_attention=generic_attention,
feedforward_network=feedforward_network,
numeric_embedding_facade=numeric_embedding_facade,
numeric_embedding_manager=numeric_embedding_manager,
tgt_padding_token=data_module.en_pad_token,
embedding_dimension=embed_dim,
num_encoder_layers=number_of_layers,
66 changes: 36 additions & 30 deletions scripts/1_train_model__command_lines.py
@@ -12,10 +12,7 @@

from machine_translation import MachineTranslationModel
from machine_translation.data import MachineTranslationDataModule
from attention_smithy.numeric_embeddings import SinusoidalPositionEmbedding, LearnedPositionEmbedding, RotaryPositionEmbedding, ALiBiPositionEmbedding, NumericEmbeddingFacade, NoAddEmbedding, PassthroughEmbedding
from attention_smithy.components import MultiheadAttention, FeedForwardNetwork
from attention_smithy.attention import StandardAttentionMethod
from attention_smithy.utils import seed_everything
from attention_smithy.utils import seed_everything, get_available_gpu_count
from attention_smithy.generators import GeneratorContext
from transformers import AutoTokenizer
from sacrebleu.metrics import BLEU
@@ -24,53 +21,61 @@ def run_training_job(parsed_args):
seed_everything(parsed_args.random_seed)
torch.set_float32_matmul_precision('medium')

num_gpus = get_available_gpu_count()
effective_batch_size = parsed_args.batch_size
per_gpu_batch_size = effective_batch_size // num_gpus if num_gpus > 1 else effective_batch_size

data_module = MachineTranslationDataModule(
en_filepath_suffix='_en.txt',
de_filepath_suffix='_de.txt',
maximum_length=parsed_args.maximum_length,
batch_size=parsed_args.batch_size,
batch_size=per_gpu_batch_size,
num_training_samples=parsed_args.num_training_samples,
)
data_module.setup()

run_name_prefix = f'sinusoid-{parsed_args.sinusoidal_position}_learned-{parsed_args.learned_position}_rotary-{parsed_args.rotary_position}_alibi-{parsed_args.alibi_position}_dropout-{parsed_args.dropout}_activation-{parsed_args.activation}'
logger = WandbLogger(project='NAS optimized vs. original', name=run_name_prefix)

# Create strategies config for multi-GPU training
strategy = 'ddp' if num_gpus > 1 else 'auto'

bleu_callback = BleuScoreValidationCallback()

trainer = pl.Trainer(
max_epochs=40,
max_epochs=1,
logger=logger,
log_every_n_steps=500,
callbacks=[
bleu_callback,
],
log_every_n_steps=500,
strategy=strategy,
accelerator='auto',
devices='auto'
)

sinusoidal_position_embedding = SinusoidalPositionEmbedding(parsed_args.embedding_dimension) if parsed_args.sinusoidal_position else NoAddEmbedding()
learned_position_embedding = LearnedPositionEmbedding(max_sequence_length=3_000, embedding_dimension=parsed_args.embedding_dimension) if parsed_args.learned_position else NoAddEmbedding()
rotary_position_embedding = RotaryPositionEmbedding(parsed_args.embedding_dimension // parsed_args.number_of_heads) if parsed_args.rotary_position else PassthroughEmbedding()
alibi_position_embedding = ALiBiPositionEmbedding(parsed_args.number_of_heads) if parsed_args.alibi_position else NoAddEmbedding()
model_kwargs = {
'embedding_dimension': parsed_args.embedding_dimension,
'number_of_heads': parsed_args.number_of_heads,
'dropout': parsed_args.dropout,
'activation': parsed_args.activation,
'feedforward_dimension': parsed_args.feedforward_dimension,
'num_encoder_layers': parsed_args.number_of_layers,
'num_decoder_layers': parsed_args.number_of_layers,
'scheduler_warmup_steps': parsed_args.scheduler_warmup_steps,
'loss_type': parsed_args.loss_type,
'label_smoothing': parsed_args.label_smoothing,
'use_sinusoidal': parsed_args.sinusoidal_position,
'use_learned': parsed_args.learned_position,
'use_rotary': parsed_args.rotary_position,
'use_alibi': parsed_args.alibi_position,
}

numeric_embedding_facade = NumericEmbeddingFacade(sinusoidal_position=sinusoidal_position_embedding, learned_position=learned_position_embedding, rotary_position=rotary_position_embedding, alibi_position=alibi_position_embedding)
generic_attention = MultiheadAttention(embedding_dimension= parsed_args.embedding_dimension, number_of_heads= parsed_args.number_of_heads, attention_method= StandardAttentionMethod(parsed_args.dropout))
decoder_self_attention = MultiheadAttention(embedding_dimension= parsed_args.embedding_dimension, number_of_heads= parsed_args.number_of_heads, attention_method= StandardAttentionMethod(parsed_args.dropout, is_causal_masking=True))
feedforward_network = FeedForwardNetwork(parsed_args.embedding_dimension, parsed_args.feedforward_dimension, parsed_args.activation, parsed_args.dropout)
model = MachineTranslationModel(
src_vocab_size=data_module.de_vocab_size,
tgt_vocab_size=data_module.en_vocab_size,
encoder_self_attention=generic_attention,
decoder_self_attention=decoder_self_attention,
decoder_cross_attention=generic_attention,
feedforward_network=feedforward_network,
numeric_embedding_facade=numeric_embedding_facade,
tgt_padding_token=data_module.en_pad_token,
embedding_dimension=parsed_args.embedding_dimension,
num_encoder_layers=parsed_args.number_of_layers,
num_decoder_layers=parsed_args.number_of_layers,
scheduler_warmup_steps = parsed_args.scheduler_warmup_steps,
loss_type= parsed_args.loss_type,
label_smoothing = parsed_args.label_smoothing,
**model_kwargs
)

trainer.fit(model, data_module)
@@ -79,9 +84,10 @@ def run_training_job(parsed_args):
bleu_score = bleu_callback.bleu_score
return bleu_score


class BleuScoreValidationCallback(pl.Callback):
def __init__(self):
self.generator = GeneratorContext(method='beam_batch')
self.generator = GeneratorContext(method='beam')
self.de_tokenizer = AutoTokenizer.from_pretrained('bert-base-german-cased')
self.en_tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

@@ -136,8 +142,8 @@ def on_train_epoch_end(self, trainer, pl_module, **kwargs):
self.bleu_score = bleu_score.score

def parse_args():
parser = argparse.ArgumentParser(description="generformer-nas")
parser.add_argument("--log_path", type=str, required=True, help="dir to place tensorboard logs from all trials")
parser = argparse.ArgumentParser(description="machine-translation")
parser.add_argument("--log_path", type=str, required=True, help="dir to place logs from all trials")
parser.add_argument('--sinusoidal_position', action='store_true', default=False)
parser.add_argument('--rotary_position', action='store_true', default=False)
parser.add_argument('--alibi_position', action='store_true', default=False)
@@ -148,7 +154,7 @@ def parse_args():
parser.add_argument('--label_smoothing', type=float, default=0.9, help='Label smoothing value')
parser.add_argument('--scheduler_warmup_steps', type=int, default=4000, help='Number of warmup steps for scheduler')
parser.add_argument('--maximum_length', type=int, default=100, help='Maximum sequence length')
parser.add_argument('--batch_size', type=int, default=64, help='Batch size')
parser.add_argument('--batch_size', type=int, default=128, help='Batch size')
parser.add_argument('--embedding_dimension', type=int, default=512, help='Embedding dimension. Original model used 512')
parser.add_argument('--number_of_heads', type=int, default=8, help='Number of attention heads. Original model used 8')
parser.add_argument('--feedforward_dimension', type=int, default=2048, help='Feedforward dimension. Original model used 2048.')
6 changes: 3 additions & 3 deletions scripts/2_evaluate_model__small.py
@@ -7,7 +7,7 @@
from pytorch_lightning.callbacks import ModelCheckpoint
from machine_translation import MachineTranslationModel
from machine_translation.data import MachineTranslationDataModule
from attention_smithy.numeric_embeddings import SinusoidalPositionEmbedding, NumericEmbeddingFacade
from attention_smithy.numeric_embeddings import SinusoidalPositionEmbedding, NumericEmbeddingManager
from attention_smithy.components import MultiheadAttention, FeedForwardNetwork
from attention_smithy.attention import StandardAttentionMethod
from attention_smithy.utils import seed_everything
@@ -108,7 +108,7 @@ def on_train_epoch_end(self, pl_module, val_dataloader, **kwargs):
print(bleu_score)

sinusoidal_position_embedding = SinusoidalPositionEmbedding(embed_dim)
numeric_embedding_facade = NumericEmbeddingFacade(sinusoidal_position=sinusoidal_position_embedding)
numeric_embedding_manager = NumericEmbeddingManager(sinusoidal_position=sinusoidal_position_embedding)
generic_attention = MultiheadAttention(embedding_dimension = embed_dim, number_of_heads = num_heads, attention_method = StandardAttentionMethod(dropout))
decoder_self_attention = MultiheadAttention(embedding_dimension = embed_dim, number_of_heads = num_heads, attention_method = StandardAttentionMethod(dropout, is_causal_masking=True))
feedforward_network = FeedForwardNetwork(embed_dim, dim_feedforward, 'relu', dropout)
@@ -122,7 +122,7 @@ def on_train_epoch_end(self, pl_module, val_dataloader, **kwargs):
decoder_self_attention=decoder_self_attention,
decoder_cross_attention=generic_attention,
feedforward_network=feedforward_network,
numeric_embedding_facade=numeric_embedding_facade,
numeric_embedding_manager=numeric_embedding_manager,
tgt_padding_token=data_module.en_pad_token,
embedding_dimension=embed_dim,
num_encoder_layers=number_of_layers,
6 changes: 3 additions & 3 deletions scripts/model_script_for_nas.py
@@ -42,7 +42,7 @@ def write(self, buf):

from machine_translation import MachineTranslationModel
from machine_translation.data import MachineTranslationDataModule
from attention_smithy.numeric_embeddings import SinusoidalPositionEmbedding, LearnedPositionEmbedding, RotaryPositionEmbedding, ALiBiPositionEmbedding, NumericEmbeddingFacade, NoAddEmbedding, PassthroughEmbedding
from attention_smithy.numeric_embeddings import SinusoidalPositionEmbedding, LearnedPositionEmbedding, RotaryPositionEmbedding, ALiBiPositionEmbedding, NumericEmbeddingManager, NoAddEmbedding, PassthroughEmbedding
from attention_smithy.components import MultiheadAttention, FeedForwardNetwork
from attention_smithy.attention import StandardAttentionMethod
from attention_smithy.utils import seed_everything
@@ -84,7 +84,7 @@ def run_training_job(parsed_args):
rotary_position_embedding = RotaryPositionEmbedding(parsed_args.embedding_dimension // parsed_args.number_of_heads) if parsed_args.rotary_position else PassthroughEmbedding()
alibi_position_embedding = ALiBiPositionEmbedding(parsed_args.number_of_heads) if parsed_args.alibi_position else NoAddEmbedding()

numeric_embedding_facade = NumericEmbeddingFacade(sinusoidal_position=sinusoidal_position_embedding, learned_position=learned_position_embedding, rotary_position=rotary_position_embedding, alibi_position=alibi_position_embedding)
numeric_embedding_manager = NumericEmbeddingManager(sinusoidal_position=sinusoidal_position_embedding, learned_position=learned_position_embedding, rotary_position=rotary_position_embedding, alibi_position=alibi_position_embedding)
generic_attention = MultiheadAttention(embedding_dimension= parsed_args.embedding_dimension, number_of_heads= parsed_args.number_of_heads, attention_method= StandardAttentionMethod(parsed_args.dropout))
decoder_self_attention = MultiheadAttention(embedding_dimension= parsed_args.embedding_dimension, number_of_heads= parsed_args.number_of_heads, attention_method= StandardAttentionMethod(parsed_args.dropout, is_causal_masking=True))
feedforward_network = FeedForwardNetwork(parsed_args.embedding_dimension, parsed_args.feedforward_dimension, parsed_args.activation, parsed_args.dropout)
@@ -95,7 +95,7 @@ def run_training_job(parsed_args):
decoder_self_attention=decoder_self_attention,
decoder_cross_attention=generic_attention,
feedforward_network=feedforward_network,
numeric_embedding_facade=numeric_embedding_facade,
numeric_embedding_manager=numeric_embedding_manager,
tgt_padding_token=data_module.en_pad_token,
embedding_dimension=parsed_args.embedding_dimension,
num_encoder_layers=parsed_args.number_of_layers,