Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Code callbacks and optional console output suppression while training #160

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion aitextgen/aitextgen.py
Original file line number Diff line number Diff line change
@@ -578,6 +578,9 @@ def train(
freeze_layers: bool = False,
num_layers_freeze: int = None,
use_deepspeed: bool = False,
print_generated: bool = True,
print_saved: bool = True,
callbacks: dict = {},
**kwargs,
) -> None:
"""
@@ -611,6 +614,15 @@ def train(
:param run_id: Run identifier; used for save_gdrive
:param progress_bar_refresh_rate: How often to update
the progress bar while training.
:param print_generated: Whether to print generated sample texts. If this is set to
False, sample texts will still be generated, but will not be displayed in the
console — useful for applications where console output is unnecessary.
:param print_saved: Whether to print a message when the model is being saved
:param callbacks: A dictionary containing callbacks for training events. Supported
callbacks are 'on_train_start' and 'on_train_end' with no arguments, 'on_batch_end'
with arguments (current_steps, total_steps, current_loss, avg_loss, trainer),
'on_sample_text_generated' with argument (texts), a list of the generated texts,
and 'on_model_saved' with arguments (current_steps, max_steps, output_dir).
"""

if not os.path.exists(output_dir):
@@ -726,6 +738,9 @@ def train(
progress_bar_refresh_rate,
freeze_layers,
num_layers_freeze,
print_generated,
print_saved,
callbacks
)
],
plugins=deepspeed_plugin,
@@ -751,7 +766,8 @@ def train(
trainer = pl.Trainer(**train_params)
trainer.fit(train_model)

logger.info(f"Saving trained model pytorch_model.bin to /{output_dir}")
if print_saved:
logger.info(f"Saving trained model pytorch_model.bin to /{output_dir}")

self.model.save_pretrained(output_dir)

4 changes: 2 additions & 2 deletions aitextgen/tokenizers.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from tokenizers import ByteLevelBPETokenizer
from typing import Union, List

from os.path import join

def train_tokenizer(
files: Union[str, List[str]],
@@ -53,6 +53,6 @@ def train_tokenizer(
)

if serialize:
tokenizer.save(f"{prefix}.tokenizer.json")
tokenizer.save(join(save_path, f"{prefix}.tokenizer.json"))
else:
tokenizer.save_model(save_path, prefix)
37 changes: 30 additions & 7 deletions aitextgen/train.py
Original file line number Diff line number Diff line change
@@ -99,6 +99,9 @@ def __init__(
progress_bar_refresh_rate,
train_transformers_only,
num_layers_freeze,
print_generated,
print_saved,
callbacks
):
super().__init__()
self.enabled = True
@@ -115,6 +118,9 @@ def __init__(
self.progress_bar_refresh_rate = progress_bar_refresh_rate
self.train_transformers_only = train_transformers_only
self.num_layers_freeze = num_layers_freeze
self.print_generated = print_generated
self.print_saved = print_saved
self.callbacks = callbacks

@property
def save_every_check(self):
@@ -138,10 +144,16 @@ def on_train_start(self, trainer, pl_module):
)
self.freeze_layers(pl_module)

# Call the on_train_start callback, if it exists
self.callbacks.get('on_train_start', lambda: None)()

def on_train_end(self, trainer, pl_module):
self.main_progress_bar.close()
self.unfreeze_layers(pl_module)

# Call the on_train_end callback, if it exists
self.callbacks.get('on_train_end', lambda: None)()

def on_batch_end(self, trainer, pl_module):
super().on_batch_end(trainer, pl_module)

@@ -161,6 +173,10 @@ def on_batch_end(self, trainer, pl_module):

desc = f"Loss: {current_loss:.3f} — Avg: {avg_loss:.3f}"

# Call the on_batch_end callback, if it exists
trainer: pl.Trainer
self.callbacks.get('on_batch_end', lambda steps, max, curr, avg, trainer: None)(self.steps, trainer.max_steps, current_loss, avg_loss, trainer)

if self.steps % self.progress_bar_refresh_rate == 0:
if self.gpu:
# via pytorch-lightning's get_gpu_memory_map()
@@ -205,9 +221,10 @@ def on_batch_end(self, trainer, pl_module):
self.freeze_layers(pl_module)

def generate_sample_text(self, trainer, pl_module):
self.main_progress_bar.write(
f"\033[1m{self.steps:,} steps reached: generating sample texts.\033[0m"
)
if self.print_generated:
self.main_progress_bar.write(
f"\033[1m{self.steps:,} steps reached: generating sample texts.\033[0m"
)

gen_length_max = getattr(
pl_module.model.config, "n_positions", None
@@ -229,15 +246,19 @@ def generate_sample_text(self, trainer, pl_module):

gen_texts = pl_module.tokenizer.batch_decode(outputs, skip_special_tokens=True)

for text in gen_texts:
if self.print_generated:
for text in gen_texts:
self.main_progress_bar.write("=" * 10)
self.main_progress_bar.write(text)

self.main_progress_bar.write("=" * 10)
self.main_progress_bar.write(text)

self.main_progress_bar.write("=" * 10)
# Call the on_sample_text_generated callback, if it exists
self.callbacks.get('on_sample_text_generated', lambda texts: None)(gen_texts)

def save_pytorch_model(self, trainer, pl_module, tpu=False):

if self.enabled:
if self.enabled and self.print_saved:
self.main_progress_bar.write(
f"\033[1m{self.steps:,} steps reached: saving model to /{self.output_dir}\033[0m"
)
@@ -254,6 +275,8 @@ def save_pytorch_model(self, trainer, pl_module, tpu=False):
os.path.join("/content/drive/My Drive/", self.run_id, pt_file),
)

self.callbacks.get('on_model_saved', lambda current_steps, max_steps, output: None)(self.steps, self.trainer.max_steps, self.output_dir)

def average_loss(self, current_loss, prev_avg_loss, smoothing):
if prev_avg_loss is None:
return current_loss