From 09c5bf7b8a1b9c64fb9669110ab4f1d8f2dbb106 Mon Sep 17 00:00:00 2001 From: Jiahao Cai Date: Sat, 1 Dec 2018 17:22:51 -0500 Subject: [PATCH 1/8] update telegram bots --- src/chatbot.py | 18 +++++++++++++++++ src/config.py | 2 +- src/evaluate.py | 5 +++-- src/telegram.py | 29 ---------------------------- src/telegram_bot.py | 47 +++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 69 insertions(+), 32 deletions(-) delete mode 100644 src/telegram.py create mode 100644 src/telegram_bot.py diff --git a/src/chatbot.py b/src/chatbot.py index 84c9c80..0c37067 100644 --- a/src/chatbot.py +++ b/src/chatbot.py @@ -15,6 +15,7 @@ from seq2seq import trainIters from evaluate import evaluateInput from embedding_map import EmbeddingMap +import telegram DIR_PATH = os.path.dirname(__file__) USE_CUDA = torch.cuda.is_available() @@ -194,5 +195,22 @@ def main(): else: print('Invalid speaker. Possible speakers:', person_map.tokens) +def telegram_init(speaker_name): + config.USE_PERSONA = True + encoder, decoder, embedding, personas, word_map, person_map, _, _ = build_model(load_checkpoint=True) + if person_map.has(speaker_name): + print('Selected speaker:', speaker_name) + speaker_id = person_map.get_index(speaker_name) + # Set dropout layers to eval mode + encoder.eval() + decoder.eval() + # Initialize search module + if config.BEAM_SEARCH_ON: + searcher = BeamSearchDecoder(encoder, decoder) + else: + searcher = GreedySearchDecoder(encoder, decoder) + return searcher, word_map, speaker_id + + if __name__ =='__main__': main() diff --git a/src/config.py b/src/config.py index 5859834..db86639 100644 --- a/src/config.py +++ b/src/config.py @@ -33,7 +33,7 @@ ATTN_MODEL = 'dot' # type of the attention model: dot/general/concat TRAIN_EMBEDDING = True # whether to update the word embeddding during training HIDDEN_SIZE = 300 # size of the word embedding & number of hidden units in GRU -PERSONA_SIZE = 100 # size of the persona embedding +PERSONA_SIZE = 30 # size of the persona embedding ENCODER_N_LAYERS = 2 # number of layers in bi-GRU encoder DECODER_N_LAYERS = 2 # number of layers in GRU decoder ENCODER_DROPOUT_RATE = 0.1 # dropout rate in bi-GRU encoder diff --git a/src/evaluate.py b/src/evaluate.py index febd1ed..353722e 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -55,5 +55,6 @@ def evaluateExample(sentence, searcher, word_map, speaker_id): # evaluate sentence output_words = evaluate(searcher, word_map, input_sentence, speaker_id) output_words = [x for x in output_words if x not in config.SPECIAL_WORD_EMBEDDING_TOKENS.values()] - - print('Bot:', ' '.join(output_words)) + res = ' '.join(output_words) + print('Bot:', res) + return res \ No newline at end of file diff --git a/src/telegram.py b/src/telegram.py deleted file mode 100644 index 539ad17..0000000 --- a/src/telegram.py +++ /dev/null @@ -1,29 +0,0 @@ -from telegram.ext import Updater, CommandHandler, MessageHandler, Filters - -speakers = { - 'cartman': '771434365:AAE8bpXZ6QMizyZxmvr6Mke6pOUFW0suN9E', - 'stan': '797666745:AAG5L9qHQgQIERXFf47mD1QdatLEfXH0p2c', - 'kyle': '798154398:AAGHAUZ8l-Zi_bMr8KrRwkZ1svju1_fW2S8', - 'randy': '764465736:AAHeOYQMHStGNPZ1gNaPl2dZEiuorvBi_fI', - 'none': '773092951:AAHMltKlernAmXHvO_TQ3B6mzkY1mv61rQc' -} - -updater = Updater(token='771434365:AAE8bpXZ6QMizyZxmvr6Mke6pOUFW0suN9E') - -dispatcher = updater.dispatcher - -def start(bot, update): - bot.send_message(chat_id=update.message.chat_id, text="I'm Cartman!") - -def response(bot, update): - bot.send_message(chat_id=update.message.chat_id, text=update.message.text) - -start_handler = CommandHandler('start', start) -dispatcher.add_handler(start_handler) - -message_handler = MessageHandler(Filters.text, response) -dispatcher.add_handler(message_handler) - -updater.start_polling() - -print('Telegram bot has started') diff --git a/src/telegram_bot.py b/src/telegram_bot.py new file mode 100644 index 0000000..53c2df6 --- /dev/null +++ b/src/telegram_bot.py @@ -0,0 +1,47 @@ +from telegram.ext import Updater, CommandHandler, MessageHandler, Filters +from chatbot import telegram_init +from evaluate import evaluateExample + +speakers = { + 'cartman': '771434365:AAE8bpXZ6QMizyZxmvr6Mke6pOUFW0suN9E', + 'stan': '797666745:AAG5L9qHQgQIERXFf47mD1QdatLEfXH0p2c', + 'kyle': '798154398:AAGHAUZ8l-Zi_bMr8KrRwkZ1svju1_fW2S8', + 'randy': '764465736:AAHeOYQMHStGNPZ1gNaPl2dZEiuorvBi_fI', + # 'none': '773092951:AAHMltKlernAmXHvO_TQ3B6mzkY1mv61rQc' +} + + +class TeleBot: + def __init__(self, speaker, token): + self.speaker = speaker + self.token = token + def run(self): + searcher, word_map, speaker_id = telegram_init(self.speaker) + updater = Updater(token=self.token) + dispatcher = updater.dispatcher + def start(bot, update): + bot.send_message(chat_id=update.message.chat_id, text=f"Hi, I'm {self.speaker.capitalize()}!") + start_handler = CommandHandler('start', start) + dispatcher.add_handler(start_handler) + def response(bot, update): + bot.send_message( + chat_id=update.message.chat_id, + text=evaluateExample( + update.message.text, + searcher, + word_map, + speaker_id)) + message_handler = MessageHandler(Filters.text, response) + dispatcher.add_handler(message_handler) + updater.start_polling() + print(f'Telegram bot {self.speaker} has started') + +if __name__ == '__main__': + for speaker, token in speakers.items(): + bot = TeleBot(speaker, token) + bot.run() + +# evaluateExample(sentence, searcher, word_map, speaker_id) + + + From 3babcdae6faa2070bf6f5e2e216f6cf56256db9c Mon Sep 17 00:00:00 2001 From: Jiahao Cai Date: Sat, 1 Dec 2018 17:32:28 -0500 Subject: [PATCH 2/8] fix merged conflict --- src/chatbot.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/chatbot.py b/src/chatbot.py index 70357ec..7592924 100644 --- a/src/chatbot.py +++ b/src/chatbot.py @@ -195,8 +195,11 @@ def main(): print('Invalid speaker. Possible speakers:', person_map.tokens) def telegram_init(speaker_name): + parser = argparse.ArgumentParser() + parser.add_argument('--checkpoint') + args = parser.parse_args() config.USE_PERSONA = True - encoder, decoder, embedding, personas, word_map, person_map, _, _ = build_model(load_checkpoint=True) + encoder, decoder, embedding, personas, word_map, person_map, _ = build_model(args.checkpoint) if person_map.has(speaker_name): print('Selected speaker:', speaker_name) speaker_id = person_map.get_index(speaker_name) From 18da9ab93d8be62179be95f87f3a11445f26ab5c Mon Sep 17 00:00:00 2001 From: Jiahao Cai Date: Sat, 1 Dec 2018 19:50:23 -0500 Subject: [PATCH 3/8] add none, fix response --- src/evaluate.py | 3 +++ src/telegram_bot.py | 4 +--- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/evaluate.py b/src/evaluate.py index 353722e..7f458e5 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -5,6 +5,7 @@ import torch import config from data_util import normalizeString, indexes_from_sentence +import re USE_CUDA = torch.cuda.is_available() device = torch.device("cuda" if USE_CUDA else "cpu") @@ -55,6 +56,8 @@ def evaluateExample(sentence, searcher, word_map, speaker_id): # evaluate sentence output_words = evaluate(searcher, word_map, input_sentence, speaker_id) output_words = [x for x in output_words if x not in config.SPECIAL_WORD_EMBEDDING_TOKENS.values()] + output_words[0] = output_words[0].capitalize() res = ' '.join(output_words) + res = re.sub(r'\s+([.!?,])', r'\1', res) print('Bot:', res) return res \ No newline at end of file diff --git a/src/telegram_bot.py b/src/telegram_bot.py index 53c2df6..53166e2 100644 --- a/src/telegram_bot.py +++ b/src/telegram_bot.py @@ -7,7 +7,7 @@ 'stan': '797666745:AAG5L9qHQgQIERXFf47mD1QdatLEfXH0p2c', 'kyle': '798154398:AAGHAUZ8l-Zi_bMr8KrRwkZ1svju1_fW2S8', 'randy': '764465736:AAHeOYQMHStGNPZ1gNaPl2dZEiuorvBi_fI', - # 'none': '773092951:AAHMltKlernAmXHvO_TQ3B6mzkY1mv61rQc' + '': '773092951:AAHMltKlernAmXHvO_TQ3B6mzkY1mv61rQc' } @@ -41,7 +41,5 @@ def response(bot, update): bot = TeleBot(speaker, token) bot.run() -# evaluateExample(sentence, searcher, word_map, speaker_id) - From 033f1aaba5ae978d9763d0ae34e309402782fbd5 Mon Sep 17 00:00:00 2001 From: Jiahao Cai Date: Sat, 1 Dec 2018 19:54:39 -0500 Subject: [PATCH 4/8] add timestamp to logs --- src/evaluate.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/evaluate.py b/src/evaluate.py index 7f458e5..0b23f28 100644 --- a/src/evaluate.py +++ b/src/evaluate.py @@ -6,6 +6,7 @@ import config from data_util import normalizeString, indexes_from_sentence import re +import datetime USE_CUDA = torch.cuda.is_available() device = torch.device("cuda" if USE_CUDA else "cpu") @@ -59,5 +60,5 @@ def evaluateExample(sentence, searcher, word_map, speaker_id): output_words[0] = output_words[0].capitalize() res = ' '.join(output_words) res = re.sub(r'\s+([.!?,])', r'\1', res) - print('Bot:', res) + print('Timestamp: {:%Y-%m-%d %H:%M:%S}> '.format(datetime.datetime.now()) + res) return res \ No newline at end of file From 570562c88453ffd3d349b4d85b35db9eccb7172b Mon Sep 17 00:00:00 2001 From: Jiahao Cai Date: Sat, 1 Dec 2018 20:02:26 -0500 Subject: [PATCH 5/8] add empty str --- src/telegram_bot.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/telegram_bot.py b/src/telegram_bot.py index 53166e2..7c2227f 100644 --- a/src/telegram_bot.py +++ b/src/telegram_bot.py @@ -20,7 +20,11 @@ def run(self): updater = Updater(token=self.token) dispatcher = updater.dispatcher def start(bot, update): - bot.send_message(chat_id=update.message.chat_id, text=f"Hi, I'm {self.speaker.capitalize()}!") + bot.send_message(chat_id=update.message.chat_id, text=evaluateExample( + "", + searcher, + word_map, + speaker_id)) start_handler = CommandHandler('start', start) dispatcher.add_handler(start_handler) def response(bot, update): From 89aa005c7c6b3ddd5ff4d31ab40aabd217583e55 Mon Sep 17 00:00:00 2001 From: Jiahao Cai Date: Sat, 1 Dec 2018 20:27:09 -0500 Subject: [PATCH 6/8] fix lost token --- src/search_decoder.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/search_decoder.py b/src/search_decoder.py index 6884baa..5c74e5f 100644 --- a/src/search_decoder.py +++ b/src/search_decoder.py @@ -177,8 +177,8 @@ def forward(self, input_seq, input_length, speaker_id, sos, eos): n = n.prevNode tokens.append(n.wordid) scores.append(n.logp) - all_tokens.append(tokens[::-1][1:]) - all_scores.append(scores[::-1][1:]) + all_tokens.append(tokens[::-1]) + all_scores.append(scores[::-1]) idx = self.random_pick(len(all_tokens)) return all_tokens[idx], all_scores[idx] From 99906156ac823c3f74f345d1b20d1004c05faa41 Mon Sep 17 00:00:00 2001 From: Jiahao Cai Date: Sat, 1 Dec 2018 20:44:29 -0500 Subject: [PATCH 7/8] fix regex 0-9 bug --- src/data_util.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/data_util.py b/src/data_util.py index fe8d65b..443d240 100644 --- a/src/data_util.py +++ b/src/data_util.py @@ -12,7 +12,7 @@ def normalizeString(s): # give a leading & ending spaces to punctuations s = re.sub(r'([.!?,])', r' \1 ', s) # purge unrecognized token with space - s = re.sub(r'[^a-z.!?,]+', r' ', s) + s = re.sub(r'[^0-9a-z.!?,]+', r' ', s) # squeeze multiple spaces s = re.sub(r'([ ]+)', r' ', s) # remove extra leading & ending space From 63dddd5be854508fa9c8153c61f2a483520929fc Mon Sep 17 00:00:00 2001 From: Jiahao Cai Date: Sat, 1 Dec 2018 22:37:15 -0500 Subject: [PATCH 8/8] change beam candidate num --- src/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/config.py b/src/config.py index ae6b148..7d7fdca 100644 --- a/src/config.py +++ b/src/config.py @@ -22,7 +22,7 @@ # Configure models - chat relevant BEAM_SEARCH_ON = True # use Beam Search or Greedy Search BEAM_WIDTH = 10 # size of beam -BEAM_CANDIDATE_NUM = 3 # number of sentences generated by beam search +BEAM_CANDIDATE_NUM = 2 # number of sentences generated by beam search # Configure models - training relevant RNN_TYPE = 'LSTM' # use LSTM or GRU as RNN