Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update telegram bots #37

Merged
merged 9 commits into from
Dec 2, 2018
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions src/chatbot.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from seq2seq import trainIters
from evaluate import evaluateInput
from embedding_map import EmbeddingMap
import telegram

DIR_PATH = os.path.dirname(__file__)
USE_CUDA = torch.cuda.is_available()
Expand Down Expand Up @@ -193,5 +194,25 @@ def main():
else:
print('Invalid speaker. Possible speakers:', person_map.tokens)

def telegram_init(speaker_name):
parser = argparse.ArgumentParser()
parser.add_argument('--checkpoint')
args = parser.parse_args()
config.USE_PERSONA = True
encoder, decoder, embedding, personas, word_map, person_map, _ = build_model(args.checkpoint)
if person_map.has(speaker_name):
print('Selected speaker:', speaker_name)
speaker_id = person_map.get_index(speaker_name)
# Set dropout layers to eval mode
encoder.eval()
decoder.eval()
# Initialize search module
if config.BEAM_SEARCH_ON:
searcher = BeamSearchDecoder(encoder, decoder)
else:
searcher = GreedySearchDecoder(encoder, decoder)
return searcher, word_map, speaker_id


if __name__ =='__main__':
main()
2 changes: 1 addition & 1 deletion src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@
ATTN_MODEL = 'dot' # type of the attention model: dot/general/concat
TRAIN_EMBEDDING = True # whether to update the word embeddding during training
HIDDEN_SIZE = 300 # size of the word embedding & number of hidden units in GRU
PERSONA_SIZE = 100 # size of the persona embedding
PERSONA_SIZE = 30 # size of the persona embedding
ENCODER_N_LAYERS = 2 # number of layers in bi-GRU encoder
DECODER_N_LAYERS = 2 # number of layers in GRU decoder
ENCODER_DROPOUT_RATE = 0.1 # dropout rate in bi-GRU encoder
Expand Down
9 changes: 7 additions & 2 deletions src/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,8 @@
import torch
import config
from data_util import normalizeString, indexes_from_sentence
import re
import datetime

USE_CUDA = torch.cuda.is_available()
device = torch.device("cuda" if USE_CUDA else "cpu")
Expand Down Expand Up @@ -55,5 +57,8 @@ def evaluateExample(sentence, searcher, word_map, speaker_id):
# evaluate sentence
output_words = evaluate(searcher, word_map, input_sentence, speaker_id)
output_words = [x for x in output_words if x not in config.SPECIAL_WORD_EMBEDDING_TOKENS.values()]

print('Bot:', ' '.join(output_words))
output_words[0] = output_words[0].capitalize()
res = ' '.join(output_words)
res = re.sub(r'\s+([.!?,])', r'\1', res)
print('Timestamp: {:%Y-%m-%d %H:%M:%S}> '.format(datetime.datetime.now()) + res)
return res
29 changes: 0 additions & 29 deletions src/telegram.py

This file was deleted.

49 changes: 49 additions & 0 deletions src/telegram_bot.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
from telegram.ext import Updater, CommandHandler, MessageHandler, Filters
from chatbot import telegram_init
from evaluate import evaluateExample

speakers = {
'cartman': '771434365:AAE8bpXZ6QMizyZxmvr6Mke6pOUFW0suN9E',
'stan': '797666745:AAG5L9qHQgQIERXFf47mD1QdatLEfXH0p2c',
'kyle': '798154398:AAGHAUZ8l-Zi_bMr8KrRwkZ1svju1_fW2S8',
'randy': '764465736:AAHeOYQMHStGNPZ1gNaPl2dZEiuorvBi_fI',
'<none>': '773092951:AAHMltKlernAmXHvO_TQ3B6mzkY1mv61rQc'
}


class TeleBot:
def __init__(self, speaker, token):
self.speaker = speaker
self.token = token
def run(self):
searcher, word_map, speaker_id = telegram_init(self.speaker)
updater = Updater(token=self.token)
dispatcher = updater.dispatcher
def start(bot, update):
bot.send_message(chat_id=update.message.chat_id, text=evaluateExample(
"",
searcher,
word_map,
speaker_id))
start_handler = CommandHandler('start', start)
dispatcher.add_handler(start_handler)
def response(bot, update):
bot.send_message(
chat_id=update.message.chat_id,
text=evaluateExample(
update.message.text,
searcher,
word_map,
speaker_id))
message_handler = MessageHandler(Filters.text, response)
dispatcher.add_handler(message_handler)
updater.start_polling()
print(f'Telegram bot {self.speaker} has started')

if __name__ == '__main__':
for speaker, token in speakers.items():
bot = TeleBot(speaker, token)
bot.run()