-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTalkToCharacter.py
77 lines (71 loc) · 3.26 KB
/
TalkToCharacter.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
#Start by suppressing warnings - end users do not need to see these
import os
import sys
import warnings
if not sys.warnoptions:
    #Respect any -W flags given on the command line; otherwise hide all Python warnings
    warnings.simplefilter("ignore")
#TensorFlow's native (C++) logger ignores Python's warnings/logging settings, so its
#INFO and WARNING start-up messages are only hidden by this environment variable,
#which must be set before TensorFlow is imported. '2' still lets real errors through,
#and setdefault keeps any value the user configured themselves.
os.environ.setdefault('TF_CPP_MIN_LOG_LEVEL', '2')
import logging
logging.getLogger('tensorflow').disabled = True
import gpt_2_simple as gpt2  #imported last so all of the suppression above applies to it
#Tunable settings. The *token values must match exactly how each speaker's
#lines begin in the fine-tuning data: replies are picked out by checking
#which generated lines start with defaultnametoken.
maxprefix = 20 #default history window in lines; the user is asked for their own value later
tokensafterprefix = 50 #added to the prefix's word count to get the generation length
defaultname = 'Character' #the AI character's name as spoken in the training data (not otherwise used in this script)
defaultnametoken = 'CHARACTER - ' #how the AI's lines begin in the training data
defaultuser = 'User' #the user's name as spoken in the training data; swapped for the player's name on output
defaultusertoken = 'USER - ' #how the user's lines begin in the training data

#Restore the fine-tuned checkpoint for this run into a fresh TensorFlow session
run_name = 'sampledata'
print('Loading AI...')
sess = gpt2.start_tf_sess()
gpt2.load_gpt2(sess, run_name=run_name)
print('Welcome to TalkToCharacter')
print('This program lets you talk to a finetuned AI that resembles a specific character')
print('Please check the readme for usage guidelines')
print('---')
print('Configure the maximum prefix length')
print('This decides how many previous lines of conversation the AI will think about')
print('A larger number will give the AI better short term memory, but will take longer to respond')
print('5 = fast, 10 = normal, 15 = slow, 20 = very slow')
#Keep asking until we get a usable whole number. A non-number used to crash the
#program, and zero or negative values silently broke the history window (the
#slice previousmessages[-maxprefix:] stops meaning "the last maxprefix lines").
while True:
    answer = input('Please choose the maximum prefix length:')
    try:
        maxprefix = int(answer)
    except ValueError:
        print('Please enter a whole number, such as 10')
        continue
    if maxprefix >= 1:
        break
    print('The maximum prefix length must be at least 1')
#A blank name would leave the prompt as just ' - ' and erase the user's name from
#the AI's replies, so fall back to the training-data user name instead
playername = input('Please choose a name:').strip() or defaultuser
print(defaultnametoken + 'Hello ' + playername + '!')
#Conversation history fed to the model as the prompt prefix. Only the newest
#`maxprefix` lines are ever used, so a bounded deque drops older lines
#automatically instead of letting the history grow for the whole session.
from collections import deque
maxattempts = 10 #how many generations to try before giving up on a reply
#max(1, ...) guards against a zero/negative choice: the user's latest line must always be in the prompt
previousmessages = deque([defaultnametoken + 'Hello ' + defaultuser + '!\n'],
                         maxlen=max(1, maxprefix))
while True:
    try:
        playertext = input(playername.upper() + ' - ')
    except (EOFError, KeyboardInterrupt):
        #Ctrl-D / Ctrl-C ends the chat quietly instead of showing a traceback
        print()
        break
    #Label the line with the training-data user token so the model sees its usual user
    previousmessages.append(defaultusertoken + playertext + '\n')
    #The deque already holds at most maxprefix lines, so the prefix is its whole contents
    prefix = ''.join(previousmessages)
    prefixlinecount = len(previousmessages)
    #Figure out how long the output needs to be (prefix word count plus room for a reply).
    #The prefix does not change between retries, so this is computed once.
    #NOTE(review): gpt-2-simple appears to treat `length` as newly generated tokens only,
    #in which case adding the prefix word count just slows generation - confirm before changing.
    length = len(prefix.split()) + tokensafterprefix
    for _ in range(maxattempts):
        generated = gpt2.generate(sess,
                                  run_name=run_name,
                                  prefix=prefix,
                                  length=length,
                                  return_as_list=True
                                  )[0]
        #The output repeats the prefix first: skip those lines, then keep only
        #the lines spoken by the AI character
        replies = [line for line in generated.splitlines()[prefixlinecount:]
                   if line.startswith(defaultnametoken)]
        if replies:
            #Use the first reply, most likely to answer the user's line
            aitext = replies[0]
            break
    else:
        #No usable reply after every attempt (this used to retry forever).
        #Forget the unanswered line so the user can try again with a fresh prompt.
        previousmessages.pop()
        print('(The AI could not come up with a reply - please try saying something else)')
        continue
    previousmessages.append(aitext + '\n')
    #Make AI use the player's name
    print(aitext.replace(defaultuser, playername))