From e6e60847a3857164d405f0e305c140b875feeda5 Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Sat, 1 Dec 2018 22:34:07 -0500 Subject: [PATCH 1/6] log time during training --- src/seq2seq.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/seq2seq.py b/src/seq2seq.py index 50ac7c5..379ed39 100644 --- a/src/seq2seq.py +++ b/src/seq2seq.py @@ -5,6 +5,7 @@ import os import math import random +from datetime import datetime import torch import config @@ -13,6 +14,10 @@ USE_CUDA = torch.cuda.is_available() device = torch.device("cuda" if USE_CUDA else "cpu") +def training_log(string): + time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + print(f'{time}\t{string}') + # Inverse sigmoid decay def teacher_forcing_rate(idx): k_factor = config.TF_RATE_DECAY_FACTOR @@ -130,7 +135,7 @@ def trainIters(word_map, person_map, pairs, encoder, decoder, encoder_optimizer, # Print progress if iteration % config.PRINT_EVERY == 0: print_loss_avg = print_loss / config.PRINT_EVERY - print("Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}".format(iteration, iteration / n_iteration * 100, print_loss_avg)) + training_log('Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}'.format(iteration, iteration / n_iteration * 100, print_loss_avg)) print_loss = 0 # Save checkpoint @@ -139,7 +144,9 @@ def trainIters(word_map, person_map, pairs, encoder, decoder, encoder_optimizer, if not os.path.exists(directory): os.makedirs(directory) - filepath = os.path.join(directory, f'{config.TRAIN_MODE}_{iteration}.tar') + filename = f'{config.TRAIN_MODE}_{iteration}.tar' + filepath = os.path.join(directory, filename) + torch.save({ 'iteration': iteration, 'en': encoder.state_dict(), @@ -152,3 +159,6 @@ def trainIters(word_map, person_map, pairs, encoder, decoder, encoder_optimizer, 'embedding': embedding.state_dict(), 'persona': personas.state_dict(), }, filepath) + + training_log('Save checkpoin {filename}') + From ef1b9b42a10dcc6aa099a813f27e3f8c027b1c13 Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Sun, 2 Dec 2018 00:11:04 -0500 Subject: [PATCH 2/6] replace seq2seq by trainer --- src/chatbot.py | 9 +-- src/seq2seq.py | 164 -------------------------------------------- src/trainer.py | 180 +++++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 185 insertions(+), 168 deletions(-) delete mode 100644 src/seq2seq.py create mode 100644 src/trainer.py diff --git a/src/chatbot.py b/src/chatbot.py index 3a164bf..cf1b8f9 100644 --- a/src/chatbot.py +++ b/src/chatbot.py @@ -12,7 +12,7 @@ from search_decoder import GreedySearchDecoder, BeamSearchDecoder from seq_encoder import EncoderRNN from seq_decoder_persona import DecoderRNN -from seq2seq import trainIters +from trainer import Trainer from evaluate import evaluateInput from embedding_map import EmbeddingMap @@ -149,9 +149,10 @@ def train(pairs, encoder, decoder, embedding, personas, word_map, person_map, ch # Run training iterations iteration += 1 - print(f'Starting Training from iteration {iteration}!') - trainIters(word_map, person_map, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, - embedding, personas, config.N_ITER, iteration) + + trainer = Trainer(encoder, decoder, word_map, person_map, embedding, personas, encoder_optimizer, decoder_optimizer) + + trainer.train(pairs, config.N_ITER, config.BATCH_SIZE, iteration) def chat(encoder, decoder, word_map, speaker_id): diff --git a/src/seq2seq.py b/src/seq2seq.py deleted file mode 100644 index 379ed39..0000000 --- a/src/seq2seq.py +++ /dev/null @@ -1,164 +0,0 @@ -""" -Train seq2seq -""" - -import os -import math -import random -from datetime import datetime -import torch -import config - -from data_util import batch2TrainData, data_2_indexes - -USE_CUDA = torch.cuda.is_available() -device = torch.device("cuda" if USE_CUDA else "cpu") - -def training_log(string): - time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') - print(f'{time}\t{string}') - -# Inverse sigmoid decay -def teacher_forcing_rate(idx): - k_factor = config.TF_RATE_DECAY_FACTOR - rate = k_factor / (k_factor + math.exp(idx / k_factor)) - return rate - -def maskNLLLoss(inp, target, mask): - # Calculate our loss based on our decoder’s output tensor, the target tensor, - # and a binary mask tensor describing the padding of the target tensor. - n_total = mask.sum().float() - cross_entropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1))) - loss = cross_entropy.masked_select(mask).mean() - loss = loss.to(device) - return loss, n_total.mean() - - -def train(word_map, input_variable, lengths, target_variable, mask, max_target_len, speaker_variable, - encoder, decoder, encoder_optimizer, decoder_optimizer, batch_size, iteration): - # Zero gradients - encoder_optimizer.zero_grad() - decoder_optimizer.zero_grad() - - # Set device options - input_variable = input_variable.to(device) - lengths = lengths.to(device) - target_variable = target_variable.to(device) - mask = mask.to(device) - speaker_variable = speaker_variable.to(device) - - # Initialize variables - loss = 0. - print_loss = [] - n_totals = 0. - - # Forward pass through encoder - encoder_outputs, encoder_hidden = encoder(input_variable, lengths) - - # Create initial decoder input - sos = word_map.get_index(config.SPECIAL_WORD_EMBEDDING_TOKENS['SOS']) - decoder_input = torch.LongTensor([[sos for _ in range(batch_size)]]) - decoder_input = decoder_input.to(device) - - # Set initial decoder hidden state to the encoder's final hidden state - if config.RNN_TYPE == 'LSTM': - decoder_hidden = (encoder_hidden[0][:decoder.n_layers], # hidden state - encoder_hidden[1][:decoder.n_layers]) # cell state - else: - decoder_hidden = encoder_hidden[:decoder.n_layers] - - # Forward batch of sequences through decoder one time step at a time - for t in range(max_target_len): - decoder_output, decoder_hidden = decoder(decoder_input, speaker_variable, decoder_hidden, encoder_outputs) - - if random.random() < teacher_forcing_rate(iteration): - # Teacher forcing: next input is current target - decoder_input = target_variable[t].view(1, -1) - else: - # No teacher forcing: next input is decoder's own current output - _, topi = decoder_output.topk(1) - decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]]) - decoder_input = decoder_input.to(device) - - # Calculate and accumulate loss - mask_loss, n_total = maskNLLLoss(decoder_output, target_variable[t], mask[t]) - loss += mask_loss - print_loss.append(mask_loss.item() * n_total) - n_totals += n_total - - - # Perform backpropagation - loss.backward() - - # Clip gradients: gradients are modified in place - _ = torch.nn.utils.clip_grad_norm_(encoder.parameters(), config.CLIP) - _ = torch.nn.utils.clip_grad_norm_(decoder.parameters(), config.CLIP) - - # Adjust model weights - encoder_optimizer.step() - decoder_optimizer.step() - - return sum(print_loss)/n_totals - - -def trainIters(word_map, person_map, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer, - embedding, personas, n_iteration, start_iteration): - """ - When we save our model, we save a tarball containing the encoder and decoder state_dicts (parameters), - the optimizers’ state_dicts, the loss, the iteration, etc. - After loading a checkpoint, we will be able to use the model parameters to run inference, - or we can continue training right where we left off. - """ - # convert sentence & speaker name to indexes - index_pair = [data_2_indexes(pair, word_map, person_map) for pair in pairs] - - batch_size = config.BATCH_SIZE - - # Load batches for each iteration - training_batches = [batch2TrainData([random.choice(index_pair) for _ in range(batch_size)], word_map) - for _ in range(n_iteration)] - - # Initializations - print_loss = 0 - - # Training loop - for iteration in range(start_iteration, n_iteration + 1): - training_batch = training_batches[iteration - 1] - # extract fields from batch - input_variable, lengths, target_variable, mask, max_target_len, speaker_variable = training_batch - - # run a training iteration with batch - loss = train(word_map, input_variable, lengths, target_variable, mask, max_target_len, speaker_variable, - encoder, decoder, encoder_optimizer, decoder_optimizer, batch_size, iteration) - print_loss += loss - - # Print progress - if iteration % config.PRINT_EVERY == 0: - print_loss_avg = print_loss / config.PRINT_EVERY - training_log('Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}'.format(iteration, iteration / n_iteration * 100, print_loss_avg)) - print_loss = 0 - - # Save checkpoint - if iteration % config.SAVE_EVERY == 0: - directory = os.path.join(config.SAVE_DIR, config.MODEL_NAME, '{}-{}_{}'.format(config.ENCODER_N_LAYERS, config.DECODER_N_LAYERS, config.HIDDEN_SIZE)) - if not os.path.exists(directory): - os.makedirs(directory) - - filename = f'{config.TRAIN_MODE}_{iteration}.tar' - filepath = os.path.join(directory, filename) - - torch.save({ - 'iteration': iteration, - 'en': encoder.state_dict(), - 'de': decoder.state_dict(), - 'en_opt': encoder_optimizer.state_dict(), - 'de_opt': decoder_optimizer.state_dict(), - 'loss': loss, - 'word_map_dict': word_map.__dict__, - 'person_map_dict': person_map.__dict__, - 'embedding': embedding.state_dict(), - 'persona': personas.state_dict(), - }, filepath) - - training_log('Save checkpoin {filename}') - diff --git a/src/trainer.py b/src/trainer.py new file mode 100644 index 0000000..1fac73c --- /dev/null +++ b/src/trainer.py @@ -0,0 +1,180 @@ +""" +Train seq2seq +""" + +import os +import math +import random +from datetime import datetime +import torch +import config + +from data_util import batch2TrainData, data_2_indexes + +USE_CUDA = torch.cuda.is_available() +DEVICE = torch.device("cuda" if USE_CUDA else "cpu") + +# Inverse sigmoid decay +def teacher_forcing_rate(idx): + k_factor = config.TF_RATE_DECAY_FACTOR + rate = k_factor / (k_factor + math.exp(idx / k_factor)) + return rate + +# TODO consider use nn.CrossEntropyLoss +def mask_nll_loss(inp, target, mask): + # Calculate our loss based on our decoder’s output tensor, the target tensor, + # and a binary mask tensor describing the padding of the target tensor. + n_total = mask.sum().float() + cross_entropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1))) + loss = cross_entropy.masked_select(mask).mean() + loss = loss.to(DEVICE) + return loss, n_total.mean() + +class Trainer: + '''Trainer to train the seq2seq model''' + + def __init__(self, encoder, decoder, word_map, person_map, embedding, personas, encoder_optimizer, decoder_optimizer): + self.encoder = encoder + self.decoder = decoder + self.word_map = word_map + self.person_map = person_map + self.embedding = embedding + self.personas = personas + self.encoder_optimizer = encoder_optimizer + self.decoder_optimizer = decoder_optimizer + + def log(self, string): + time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') + print(f'{time}\t{string}') + + + def train_batch(self, training_batch, tf_rate=1): + # extract fields from batch + input_variable, lengths, target_variable, mask, max_target_len, speaker_variable = training_batch + + # Zero gradients + self.encoder_optimizer.zero_grad() + self.decoder_optimizer.zero_grad() + + # Set DEVICE options + input_variable = input_variable.to(DEVICE) + lengths = lengths.to(DEVICE) + target_variable = target_variable.to(DEVICE) + mask = mask.to(DEVICE) + speaker_variable = speaker_variable.to(DEVICE) + + # Initialize variables + loss = 0. + print_loss = [] + n_totals = 0. + + # Forward pass through encoder + encoder_outputs, encoder_hidden = self.encoder(input_variable, lengths) + + # Create initial decoder input + sos = self.word_map.get_index(config.SPECIAL_WORD_EMBEDDING_TOKENS['SOS']) + + batch_size = input_variable.size(1) + decoder_input = torch.LongTensor([[sos for _ in range(batch_size)]]) + decoder_input = decoder_input.to(DEVICE) + + decoder_layers = self.decoder.n_layers + # Set initial decoder hidden state to the encoder's final hidden state + if config.RNN_TYPE == 'LSTM': + decoder_hidden = (encoder_hidden[0][:decoder_layers], # hidden state + encoder_hidden[1][:decoder_layers]) # cell state + else: + decoder_hidden = encoder_hidden[:decoder_layers] + + # Forward batch of sequences through decoder one time step at a time + for t in range(max_target_len): + decoder_output, decoder_hidden = self.decoder(decoder_input, speaker_variable, decoder_hidden, encoder_outputs) + + if random.random() < tf_rate: + # Teacher forcing: next input is current target + decoder_input = target_variable[t].view(1, -1) + else: + # No teacher forcing: next input is decoder's own current output + _, topi = decoder_output.topk(1) + decoder_input = torch.LongTensor([[topi[i][0] for i in range(batch_size)]]) + decoder_input = decoder_input.to(DEVICE) + + # Calculate and accumulate loss + mask_loss, n_total = mask_nll_loss(decoder_output, target_variable[t], mask[t]) + loss += mask_loss + print_loss.append(mask_loss.item() * n_total) + n_totals += n_total + + + # Perform backpropagation + loss.backward() + + # Clip gradients: gradients are modified in place + _ = torch.nn.utils.clip_grad_norm_(self.encoder.parameters(), config.CLIP) + _ = torch.nn.utils.clip_grad_norm_(self.decoder.parameters(), config.CLIP) + + # Adjust model weights + self.encoder_optimizer.step() + self.decoder_optimizer.step() + + return sum(print_loss)/n_totals + + + def train(self, pairs, n_iteration, batch_size=1, start_iteration=1): + """ + When we save our model, we save a tarball containing the encoder and decoder state_dicts (parameters), + the optimizers’ state_dicts, the loss, the iteration, etc. + After loading a checkpoint, we will be able to use the model parameters to run inference, + or we can continue training right where we left off. + """ + # convert sentence & speaker name to indexes + index_pair = [data_2_indexes(pair, self.word_map, self.person_map) for pair in pairs] + + batch_size = config.BATCH_SIZE + + # Load batches for each iteration + training_batches = [batch2TrainData([random.choice(index_pair) for _ in range(batch_size)], self.word_map) + for _ in range(n_iteration)] + + # Initializations + print_loss = 0 + + # Training loop + for iteration in range(start_iteration, n_iteration + 1): + training_batch = training_batches[iteration - 1] + + tf_rate = teacher_forcing_rate(iteration) + # run a training iteration with batch + loss = self.train_batch(training_batch, tf_rate) + print_loss += loss + + # Print progress + if iteration % config.PRINT_EVERY == 0: + print_loss_avg = print_loss / config.PRINT_EVERY + self.log('Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}'.format(iteration, iteration / n_iteration * 100, print_loss_avg)) + print_loss = 0 + + # Save checkpoint + if iteration % config.SAVE_EVERY == 0: + directory = os.path.join(config.SAVE_DIR, config.MODEL_NAME, '{}-{}_{}'.format(config.ENCODER_N_LAYERS, config.DECODER_N_LAYERS, config.HIDDEN_SIZE)) + if not os.path.exists(directory): + os.makedirs(directory) + + filename = f'{config.TRAIN_MODE}_{iteration}.tar' + filepath = os.path.join(directory, filename) + + torch.save({ + 'iteration': iteration, + 'loss': loss, + 'en': self.encoder.state_dict(), + 'de': self.decoder.state_dict(), + 'en_opt': self.encoder_optimizer.state_dict(), + 'de_opt': self.decoder_optimizer.state_dict(), + 'word_map_dict': self.word_map.__dict__, + 'person_map_dict': self.person_map.__dict__, + 'embedding': self.embedding.state_dict(), + 'persona': self.personas.state_dict(), + }, filepath) + + self.log(f'Save checkpoin {filename}') + From aa8555f173cc80e26ed319332129fd60ff65dc59 Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Sun, 2 Dec 2018 00:59:53 -0500 Subject: [PATCH 3/6] trainer to init optim & load checkpoint --- src/chatbot.py | 19 +++---------------- src/trainer.py | 38 +++++++++++++++++++++++++++++++++----- 2 files changed, 36 insertions(+), 21 deletions(-) diff --git a/src/chatbot.py b/src/chatbot.py index cf1b8f9..3593cdd 100644 --- a/src/chatbot.py +++ b/src/chatbot.py @@ -5,7 +5,6 @@ import os import argparse import torch -from torch import optim import config from data_util import trimRareWords, load_pairs @@ -135,24 +134,12 @@ def train(pairs, encoder, decoder, embedding, personas, word_map, person_map, ch encoder.train() decoder.train() - # Initialize optimizers - print('Building optimizers ...') - encoder_optimizer = optim.Adam(encoder.parameters(), lr=config.LR) - decoder_optimizer = optim.Adam(decoder.parameters(), lr=config.LR * config.DECODER_LR) - - iteration = 0 + trainer = Trainer(encoder, decoder, word_map, person_map, embedding, personas) if checkpoint: - encoder_optimizer.load_state_dict(checkpoint['en_opt']) - decoder_optimizer.load_state_dict(checkpoint['de_opt']) - iteration = checkpoint['iteration'] - - # Run training iterations - iteration += 1 - - trainer = Trainer(encoder, decoder, word_map, person_map, embedding, personas, encoder_optimizer, decoder_optimizer) + trainer.load(checkpoint) - trainer.train(pairs, config.N_ITER, config.BATCH_SIZE, iteration) + trainer.train(pairs, config.N_ITER, config.BATCH_SIZE) def chat(encoder, decoder, word_map, speaker_id): diff --git a/src/trainer.py b/src/trainer.py index 1fac73c..04a6e32 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -7,6 +7,7 @@ import random from datetime import datetime import torch +from torch import optim import config from data_util import batch2TrainData, data_2_indexes @@ -33,22 +34,44 @@ def mask_nll_loss(inp, target, mask): class Trainer: '''Trainer to train the seq2seq model''' - def __init__(self, encoder, decoder, word_map, person_map, embedding, personas, encoder_optimizer, decoder_optimizer): + def __init__(self, encoder, decoder, word_map, person_map, embedding, personas): self.encoder = encoder self.decoder = decoder self.word_map = word_map self.person_map = person_map self.embedding = embedding self.personas = personas - self.encoder_optimizer = encoder_optimizer - self.decoder_optimizer = decoder_optimizer + + self.encoder_optimizer = optim.Adam(encoder.parameters(), lr=config.LR) + self.decoder_optimizer = optim.Adam(decoder.parameters(), lr=config.LR * config.DECODER_LR) + + # trained iteration + self.trained_iteration = 0 def log(self, string): + '''formatted log output for training''' + time = datetime.now().strftime('%Y-%m-%d %H:%M:%S') print(f'{time}\t{string}') + def load(self, checkpoint): + '''load checkpoint''' + + self.trained_iteration = checkpoint['iteration'] + + self.encoder_optimizer.load_state_dict(checkpoint['en_opt']) + self.decoder_optimizer.load_state_dict(checkpoint['de_opt']) + def train_batch(self, training_batch, tf_rate=1): + ''' + train a batch of any batch size + + Inputs: + training_batch: train data batch created by batch2TrainData + tf_rate: teacher forcing rate, the smaller the rate the higher the scheduled sampling + ''' + # extract fields from batch input_variable, lengths, target_variable, mask, max_target_len, speaker_variable = training_batch @@ -117,10 +140,10 @@ def train_batch(self, training_batch, tf_rate=1): self.encoder_optimizer.step() self.decoder_optimizer.step() - return sum(print_loss)/n_totals + return sum(print_loss) / n_totals - def train(self, pairs, n_iteration, batch_size=1, start_iteration=1): + def train(self, pairs, n_iteration, batch_size=1): """ When we save our model, we save a tarball containing the encoder and decoder state_dicts (parameters), the optimizers’ state_dicts, the loss, the iteration, etc. @@ -140,6 +163,9 @@ def train(self, pairs, n_iteration, batch_size=1, start_iteration=1): print_loss = 0 # Training loop + start_iteration = self.trained_iteration + 1 + + self.log(f'Start training from iteration {start_iteration} to {n_iteration}...') for iteration in range(start_iteration, n_iteration + 1): training_batch = training_batches[iteration - 1] @@ -148,6 +174,8 @@ def train(self, pairs, n_iteration, batch_size=1, start_iteration=1): loss = self.train_batch(training_batch, tf_rate) print_loss += loss + self.trained_iteration = iteration + # Print progress if iteration % config.PRINT_EVERY == 0: print_loss_avg = print_loss / config.PRINT_EVERY From 2d0d64568590d62f6d9d394748bf149d0ef6f1c4 Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Sun, 2 Dec 2018 01:09:58 -0500 Subject: [PATCH 4/6] log tf rate --- src/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/trainer.py b/src/trainer.py index 04a6e32..5101ce7 100644 --- a/src/trainer.py +++ b/src/trainer.py @@ -179,7 +179,7 @@ def train(self, pairs, n_iteration, batch_size=1): # Print progress if iteration % config.PRINT_EVERY == 0: print_loss_avg = print_loss / config.PRINT_EVERY - self.log('Iteration: {}; Percent complete: {:.1f}%; Average loss: {:.4f}'.format(iteration, iteration / n_iteration * 100, print_loss_avg)) + self.log('Iter: {}; Percent: {:.1f}%; Avg loss: {:.4f}; TF rate: {:.4f}'.format(iteration, iteration / n_iteration * 100, print_loss_avg, tf_rate)) print_loss = 0 # Save checkpoint From 46c95212569171bf5d3599422f477dc44a41a39a Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Sun, 2 Dec 2018 02:00:04 -0500 Subject: [PATCH 5/6] do not filter corpus by person --- src/data_util.py | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/data_util.py b/src/data_util.py index fe8d65b..f11660a 100644 --- a/src/data_util.py +++ b/src/data_util.py @@ -38,14 +38,13 @@ def readVocs(datafile): pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines] return pairs +def filter_pair(pair): + ''' + Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold + ''' -# Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold -def filter_pair(p): - # Input sequences need to preserve the last word for EOS token - if len(p) == 3 and config.USE_PERSONA: - return len(p[0].split(' ')) < config.MAX_LENGTH and len(p[1].split(' ')) < config.MAX_LENGTH and len(p[2]) > 1 - elif len(p) == 2 and config.USE_PERSONA is not True: - return len(p[0].split(' ')) < config.MAX_LENGTH and len(p[1].split(' ')) < config.MAX_LENGTH + if len(pair) <= 3 and len(pair) >= 2: + return len(pair[0].split(' ')) < config.MAX_LENGTH and len(pair[1].split(' ')) < config.MAX_LENGTH else: return False From 9494740fa39d96a0f39a1fe0e89b33e4578b8031 Mon Sep 17 00:00:00 2001 From: Aobo Yang Date: Sun, 2 Dec 2018 02:01:32 -0500 Subject: [PATCH 6/6] default general persona to data --- src/data_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data_util.py b/src/data_util.py index f11660a..6cae3d0 100644 --- a/src/data_util.py +++ b/src/data_util.py @@ -167,7 +167,7 @@ def batch2TrainData(pair_batch, word_map): def data_2_indexes(pair, word_map, person_map): - speaker = pair[2] if len(pair) == 3 and config.USE_PERSONA else config.NONE_PERSONA + speaker = pair[2] if len(pair) == 3 else config.NONE_PERSONA return [ indexes_from_sentence(pair[0], word_map),