Skip to content

Conditional Sequence Generative Adversarial Network trained with policy gradient, Implementation in Tensorflow

License

Notifications You must be signed in to change notification settings

andi611/Conditional-SeqGAN-Tensorflow

Folders and files

NameName
Last commit message
Last commit date

Latest commit

416bdfc · Dec 12, 2018

History

18 Commits
Sep 21, 2018
Sep 21, 2018
Sep 21, 2018
Sep 21, 2018
Sep 21, 2018
Sep 21, 2018
Dec 12, 2018
Oct 4, 2018

Repository files navigation

Daisy: Dialog Analogous Intellectual System

Machine Learning: Chatbot

Conditional Sequence Generative Adversarial Network Trained with Policy Gradient, Implementation in Tensorflow

Requirements:

  • Tensorflow r1.5.1
  • Python 3.6
  • Numpy 1.13.3
  • tempfile (Optional)
  • gtts (Optional)
  • pygame (Optional)

Introduction

Apply Policy Gradient to Generative Adversarial Nets to improve ChatBot dialog generation quality.

  • Generator: seq2seq model with attention and embedding
  • Discriminator: hierarchical encoder and two class classifier
  • Policy Gradient Update: use reward computed by the discriminator to update the generator
  • Monte Carlo Rollout: to compute reward for every generation step
  • Teacher Forcing: add maximum likelihood estimation training to GAN training
  • sampled softmax: to reduce computation complexity from regular softmax
  • model checkpoint supported: training early stopping and save model with best loss

Preperation

  1. Download corpus from:
1) https://www.cs.cornell.edu/%7Ecristian/Cornell_Movie-Dialogs_Corpus.html (official site)
2) https://drive.google.com/file/d/13TvKXXrKVg9X7IEayu7eYpRbxfZ5GDE0/view?usp=sharing (backup link)
  1. extract the .zip file, and save them to the path below:
config.corpus_path = '../cornell_movie_dialog_corpus'

This default path can be modified by changing the '--corpus_dir' option in 'config.py'.

  1. Run Preprocessing on the raw corpus file:
python3 preprocess_data.py

This saves the pre-processed files to the path below:

config.data_dir = '../data'

This default path can be modified by changing the '--data_dir' option in 'config.py'.

  1. After running pre-processing, 5 files will be created:
1. processed_corpus.txt : Encoder input for training
2. word2idx.pkl : word to idx dictionary mapping saved by pickle
3. idx2word.pkl : idx to word dictionary mapping saved by pickle
4. train_encode.pkl: model-ready training data: tokenized and index labeled text ready to be fed to the encoder
5. train_decode.pkl: model-ready training data: tokenized and index labeled text ready to be fed to the decoder

Training

  1. To run MLE pre train on the Generator model:
python3 run.py --pre_train --force_save

(optional: '--force_save' saves model every epoch, otherwise the default model checkpoint is activated and only the model with minimum loss will be saved.)

  1. To run generative adversarial training with policy gradient:
python3 run.py --gan_train --force_save

(optional: '--force_save' saves model every epoch, otherwise the default model checkpoint is activated and only the model with minimum loss will be saved.)

Inference

  1. To chat with the pre-trained model:
python3 run.py --pre_chat --speak

(optional: '--speak' enables audio response)

  1. To chat with the GAN-trained model:
python3 run.py --gan_chat --speak

(optional: '--speak' enables audio response)

  1. To evaluate the BLEU score of the pre-trained model:
python3 run.py --pre_evaluate
  1. To evaluate the BLEU score of the GAN-trained model:
python3 run.py --gan_evaluate

Trained Models

  1. To use exsiting models, place model files under the 'model' directory. Trained Models (pre-trained / gan-trained) can be download from:
https://drive.google.com/drive/folders/1C0A53qqQpdGyKIEB7HVb6BqlBD5gvEJy?usp=sharing

Other

  1. Source code:
1. configuration.py
2. data_loader.py
3. discriminator.py
4. generator.py
5. train.py
6. test.py
7. run.py
8. tf_seq2seq_model.py

About

Conditional Sequence Generative Adversarial Network trained with policy gradient, Implementation in Tensorflow

Topics

Resources

License

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published