Skip to content

Commit

Permalink
Merge pull request #38 from aobo-y/aobo-pytorch
Browse files Browse the repository at this point in the history
implement Trainer
  • Loading branch information
aobo-y authored Dec 2, 2018
2 parents 4afb2b1 + 9494740 commit da1c171
Show file tree
Hide file tree
Showing 4 changed files with 220 additions and 179 deletions.
22 changes: 5 additions & 17 deletions src/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,13 @@
import os
import argparse
import torch
from torch import optim

import config
from data_util import trimRareWords, load_pairs
from search_decoder import GreedySearchDecoder, BeamSearchDecoder
from seq_encoder import EncoderRNN
from seq_decoder_persona import DecoderRNN
from seq2seq import trainIters
from trainer import Trainer
from evaluate import evaluateInput
from embedding_map import EmbeddingMap
import telegram
Expand Down Expand Up @@ -136,23 +135,12 @@ def train(pairs, encoder, decoder, embedding, personas, word_map, person_map, ch
encoder.train()
decoder.train()

# Initialize optimizers
print('Building optimizers ...')
encoder_optimizer = optim.Adam(encoder.parameters(), lr=config.LR)
decoder_optimizer = optim.Adam(decoder.parameters(), lr=config.LR * config.DECODER_LR)

iteration = 0
trainer = Trainer(encoder, decoder, word_map, person_map, embedding, personas)

if checkpoint:
encoder_optimizer.load_state_dict(checkpoint['en_opt'])
decoder_optimizer.load_state_dict(checkpoint['de_opt'])
iteration = checkpoint['iteration']

# Run training iterations
iteration += 1
print(f'Starting Training from iteration {iteration}!')
trainIters(word_map, person_map, pairs, encoder, decoder, encoder_optimizer, decoder_optimizer,
embedding, personas, config.N_ITER, iteration)
trainer.load(checkpoint)

trainer.train(pairs, config.N_ITER, config.BATCH_SIZE)


def chat(encoder, decoder, word_map, speaker_id):
Expand Down
15 changes: 7 additions & 8 deletions src/data_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,13 @@ def readVocs(datafile):
pairs = [[normalizeString(s) for s in l.split('\t')] for l in lines]
return pairs

def filter_pair(pair):
'''
Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold
'''

# Returns True iff both sentences in a pair 'p' are under the MAX_LENGTH threshold
def filter_pair(p):
# Input sequences need to preserve the last word for EOS token
if len(p) == 3 and config.USE_PERSONA:
return len(p[0].split(' ')) < config.MAX_LENGTH and len(p[1].split(' ')) < config.MAX_LENGTH and len(p[2]) > 1
elif len(p) == 2 and config.USE_PERSONA is not True:
return len(p[0].split(' ')) < config.MAX_LENGTH and len(p[1].split(' ')) < config.MAX_LENGTH
if len(pair) <= 3 and len(pair) >= 2:
return len(pair[0].split(' ')) < config.MAX_LENGTH and len(pair[1].split(' ')) < config.MAX_LENGTH
else:
return False

Expand Down Expand Up @@ -168,7 +167,7 @@ def batch2TrainData(pair_batch, word_map):


def data_2_indexes(pair, word_map, person_map):
speaker = pair[2] if len(pair) == 3 and config.USE_PERSONA else config.NONE_PERSONA
speaker = pair[2] if len(pair) == 3 else config.NONE_PERSONA

return [
indexes_from_sentence(pair[0], word_map),
Expand Down
154 changes: 0 additions & 154 deletions src/seq2seq.py

This file was deleted.

Loading

0 comments on commit da1c171

Please sign in to comment.