
Commit

Merge pull request #37 from jiahao42/pytorch
Update telegram bots
jiahao42 authored Dec 2, 2018
2 parents 7d4177e + 63dddd5 commit 4afb2b1
Showing 7 changed files with 82 additions and 36 deletions.
21 changes: 21 additions & 0 deletions src/chatbot.py
@@ -15,6 +15,7 @@
from seq2seq import trainIters
from evaluate import evaluateInput
from embedding_map import EmbeddingMap
import telegram

DIR_PATH = os.path.dirname(__file__)
USE_CUDA = torch.cuda.is_available()
@@ -193,5 +194,25 @@ def main():
    else:
        print('Invalid speaker. Possible speakers:', person_map.tokens)

def telegram_init(speaker_name):
    parser = argparse.ArgumentParser()
    parser.add_argument('--checkpoint')
    args = parser.parse_args()
    config.USE_PERSONA = True
    encoder, decoder, embedding, personas, word_map, person_map, _ = build_model(args.checkpoint)
    if person_map.has(speaker_name):
        print('Selected speaker:', speaker_name)
        speaker_id = person_map.get_index(speaker_name)
        # Set dropout layers to eval mode
        encoder.eval()
        decoder.eval()
        # Initialize search module
        if config.BEAM_SEARCH_ON:
            searcher = BeamSearchDecoder(encoder, decoder)
        else:
            searcher = GreedySearchDecoder(encoder, decoder)
        return searcher, word_map, speaker_id


if __name__ =='__main__':
    main()
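
The new telegram_init mirrors the setup in main(): it rebuilds the model from the --checkpoint argument, looks up the persona index for the requested speaker, switches the encoder and decoder to eval mode, and returns a searcher (beam or greedy, depending on config.BEAM_SEARCH_ON) along with the word map and speaker id. A minimal sketch of how a caller might use the returned triple, assuming a trained checkpoint is supplied on the command line and that 'cartman' is a persona known to person_map (both are assumptions, not part of this diff):

# Sketch only: the invoking script must be run with --checkpoint <path>;
# the persona name and the prompt text below are illustrative.
from chatbot import telegram_init
from evaluate import evaluateExample

searcher, word_map, speaker_id = telegram_init('cartman')
reply = evaluateExample('how are you today ?', searcher, word_map, speaker_id)
print(reply)
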
4 changes: 2 additions & 2 deletions src/config.py
@@ -22,14 +22,14 @@
# Configure models - chat relevant
BEAM_SEARCH_ON = True # use Beam Search or Greedy Search
BEAM_WIDTH = 10 # size of beam
-BEAM_CANDIDATE_NUM = 3 # number of sentences generated by beam search
+BEAM_CANDIDATE_NUM = 2 # number of sentences generated by beam search

# Configure models - training relevant
RNN_TYPE = 'LSTM' # use LSTM or GRU as RNN
ATTN_MODEL = 'dot' # type of the attention model: dot/general/concat
TRAIN_EMBEDDING = True # whether to update the word embedding during training
HIDDEN_SIZE = 300 # size of the word embedding & number of hidden units in GRU
-PERSONA_SIZE = 100 # size of the persona embedding
+PERSONA_SIZE = 30 # size of the persona embedding
ENCODER_N_LAYERS = 2 # number of layers in bi-GRU encoder
DECODER_N_LAYERS = 2 # number of layers in GRU decoder
ENCODER_DROPOUT_RATE = 0.1 # dropout rate in bi-GRU encoder
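
For context on the two changed values: with BEAM_SEARCH_ON = True the decoder keeps BEAM_WIDTH partial hypotheses per step and collects BEAM_CANDIDATE_NUM finished sentences, from which search_decoder.py appears to pick one at random (random_pick, visible in the hunk further below). This commit shrinks the candidate pool from 3 to 2 and the persona embedding from 100 to 30 dimensions. A toy sketch of that final random pick, with made-up candidate sentences (only the config names come from the repo):

import random
import config

# Illustrative only: pretend these are the finished beam-search hypotheses.
candidates = [
    ['hey', 'dude', '.'],
    ['what', 'is', 'going', 'on', '?'],
]
assert len(candidates) <= config.BEAM_CANDIDATE_NUM   # 2 after this commit
reply = ' '.join(random.choice(candidates))           # mimic the random pick among candidates
print(reply)
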
2 changes: 1 addition & 1 deletion src/data_util.py
@@ -12,7 +12,7 @@ def normalizeString(s):
    # give a leading & ending spaces to punctuations
    s = re.sub(r'([.!?,])', r' \1 ', s)
    # purge unrecognized token with space
-   s = re.sub(r'[^a-z.!?,]+', r' ', s)
+   s = re.sub(r'[^0-9a-z.!?,]+', r' ', s)
    # squeeze multiple spaces
    s = re.sub(r'([ ]+)', r' ', s)
    # remove extra leading & ending space
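
The widened character class means digits now survive normalizeString instead of being collapsed into spaces. A quick before/after check of just that substitution (the lowercased input string is illustrative; the punctuation-spacing and space-squeezing steps around it are unchanged):

import re

s = 'i am 8 years old'
old = re.sub(r'[^a-z.!?,]+', r' ', s)      # previous pattern: the '8' is purged -> 'i am years old'
new = re.sub(r'[^0-9a-z.!?,]+', r' ', s)   # updated pattern: the '8' is kept -> 'i am 8 years old'
print(old)
print(new)
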
9 changes: 7 additions & 2 deletions src/evaluate.py
@@ -5,6 +5,8 @@
import torch
import config
from data_util import normalizeString, indexes_from_sentence
import re
import datetime

USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
@@ -55,5 +57,8 @@ def evaluateExample(sentence, searcher, word_map, speaker_id):
    # evaluate sentence
    output_words = evaluate(searcher, word_map, input_sentence, speaker_id)
    output_words = [x for x in output_words if x not in config.SPECIAL_WORD_EMBEDDING_TOKENS.values()]
-
-   print('Bot:', ' '.join(output_words))
+   output_words[0] = output_words[0].capitalize()
+   res = ' '.join(output_words)
+   res = re.sub(r'\s+([.!?,])', r'\1', res)
+   print('Timestamp: {:%Y-%m-%d %H:%M:%S}> '.format(datetime.datetime.now()) + res)
+   return res
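
Instead of printing the raw token list, evaluateExample now capitalizes the first word, joins the tokens, strips the space that normalizeString inserts before punctuation, prints the reply with a timestamp, and returns it so the Telegram bot can send it. A small illustration of that cleanup on a made-up token list:

import re
import datetime

output_words = ['well', ',', 'screw', 'you', 'guys', '!']   # illustrative model output
output_words[0] = output_words[0].capitalize()
res = ' '.join(output_words)                # 'Well , screw you guys !'
res = re.sub(r'\s+([.!?,])', r'\1', res)    # 'Well, screw you guys!'
print('Timestamp: {:%Y-%m-%d %H:%M:%S}> '.format(datetime.datetime.now()) + res)
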
4 changes: 2 additions & 2 deletions src/search_decoder.py
@@ -177,8 +177,8 @@ def forward(self, input_seq, input_length, speaker_id, sos, eos):
                n = n.prevNode
                tokens.append(n.wordid)
                scores.append(n.logp)
-           all_tokens.append(tokens[::-1][1:])
-           all_scores.append(scores[::-1][1:])
+           all_tokens.append(tokens[::-1])
+           all_scores.append(scores[::-1])
        idx = self.random_pick(len(all_tokens))
        return all_tokens[idx], all_scores[idx]

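The change here is in the beam-search backtracking: each finished hypothesis is rebuilt by walking prevNode links from its end node and then reversing the collected tokens. The old code also dropped the first element of the reversed list ([1:]); the new code keeps the whole sequence (special tokens are filtered out later in evaluateExample in any case). A toy illustration of the difference, using a plain list in place of the backtracked word ids:

# Illustrative only: tokens were appended while walking prevNode links, last word first.
tokens = ['!', 'dude', 'sweet']

old_hypothesis = tokens[::-1][1:]   # previous behaviour: reverse, then drop the first item -> ['dude', '!']
new_hypothesis = tokens[::-1]       # new behaviour: keep the full reversed sequence -> ['sweet', 'dude', '!']
print(old_hypothesis, new_hypothesis)
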
29 changes: 0 additions & 29 deletions src/telegram.py

This file was deleted.

49 changes: 49 additions & 0 deletions src/telegram_bot.py
@@ -0,0 +1,49 @@
from telegram.ext import Updater, CommandHandler, MessageHandler, Filters
from chatbot import telegram_init
from evaluate import evaluateExample

speakers = {
    'cartman': '771434365:AAE8bpXZ6QMizyZxmvr6Mke6pOUFW0suN9E',
    'stan': '797666745:AAG5L9qHQgQIERXFf47mD1QdatLEfXH0p2c',
    'kyle': '798154398:AAGHAUZ8l-Zi_bMr8KrRwkZ1svju1_fW2S8',
    'randy': '764465736:AAHeOYQMHStGNPZ1gNaPl2dZEiuorvBi_fI',
    '<none>': '773092951:AAHMltKlernAmXHvO_TQ3B6mzkY1mv61rQc'
}


class TeleBot:
    def __init__(self, speaker, token):
        self.speaker = speaker
        self.token = token
    def run(self):
        searcher, word_map, speaker_id = telegram_init(self.speaker)
        updater = Updater(token=self.token)
        dispatcher = updater.dispatcher
        def start(bot, update):
            bot.send_message(chat_id=update.message.chat_id, text=evaluateExample(
                "",
                searcher,
                word_map,
                speaker_id))
        start_handler = CommandHandler('start', start)
        dispatcher.add_handler(start_handler)
        def response(bot, update):
            bot.send_message(
                chat_id=update.message.chat_id,
                text=evaluateExample(
                    update.message.text,
                    searcher,
                    word_map,
                    speaker_id))
        message_handler = MessageHandler(Filters.text, response)
        dispatcher.add_handler(message_handler)
        updater.start_polling()
        print(f'Telegram bot {self.speaker} has started')

if __name__ == '__main__':
    for speaker, token in speakers.items():
        bot = TeleBot(speaker, token)
        bot.run()
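
The __main__ block starts one TeleBot per entry in speakers; this works because Updater.start_polling() in the python-telegram-bot API used here (the pre-v12 style with (bot, update) handler callbacks) returns after spawning its polling threads rather than blocking. A hedged sketch for running a single persona bot instead of all five; the chosen persona is illustrative, and the explicit keep-alive loop is only a precaution in case the polling threads do not keep the process alive on their own:

import time
from telegram_bot import TeleBot, speakers

if __name__ == '__main__':
    bot = TeleBot('kyle', speakers['kyle'])
    bot.run()        # start_polling() returns immediately; handlers run in background threads
    while True:      # keep the main thread around so the bot keeps serving messages
        time.sleep(60)
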


