-
Notifications
You must be signed in to change notification settings - Fork 64
/
Copy pathdecode_beam.py
184 lines (142 loc) · 6.32 KB
/
decode_beam.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
import operator
try:
    from queue import PriorityQueue  # Python 3 module name
except ImportError:
    from Queue import PriorityQueue  # Python 2 module name

import torch
import torch.nn as nn
import torch.nn.functional as F
# Prefer the GPU when torch can see one; otherwise everything stays on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Vocabulary ids: SOS_token is fed as the first decoder input, EOS_token marks
# a finished hypothesis in beam_decode().
SOS_token, EOS_token = 0, 1

# Number of decoding steps (and output columns) used by greedy_decode().
MAX_LENGTH = 50
class DecoderRNN(nn.Module):
    '''
    Illustrative single-layer GRU decoder that scores the next token.

    One call to forward() consumes one token per batch element and returns
    log-probabilities over the vocabulary plus the updated hidden state.
    '''

    def __init__(self, embedding_size, hidden_size, output_size, cell_type, dropout=0.1):
        '''
        :param embedding_size: dimensionality of the token embeddings
        :param hidden_size: dimensionality of the GRU hidden state
        :param output_size: vocabulary size (embedding rows and output logits)
        :param cell_type: stored on the instance only; it does not select the
            recurrent cell, which is always a GRU
        :param dropout: dropout probability applied to the embedded input
        '''
        super(DecoderRNN, self).__init__()
        self.hidden_size = hidden_size
        self.cell_type = cell_type
        self.embedding = nn.Embedding(num_embeddings=output_size,
                                      embedding_dim=embedding_size)
        # The GRU must be unidirectional: a bidirectional GRU emits
        # 2 * hidden_size features, which cannot feed Linear(hidden_size, ...)
        # below and does not accept a [1, B, H] initial hidden state.
        # The GRU's own ``dropout`` argument is omitted because it only acts
        # between stacked layers and this GRU has a single layer.
        self.rnn = nn.GRU(embedding_size, hidden_size, bidirectional=False, batch_first=False)
        self.dropout_rate = dropout
        self.out = nn.Linear(hidden_size, output_size)

    def forward(self, input, hidden, not_used):
        '''
        :param input: LongTensor of token ids, [B, 1]
        :param hidden: GRU hidden state, [1, B, H]
        :param not_used: ignored; keeps the call signature used by the decode
            functions, which pass encoder outputs here
        :return: (log-probabilities [B, output_size], new hidden state [1, B, H])
        '''
        embedded = self.embedding(input).transpose(0, 1)  # [B, 1] -> [1, B, D]
        # Tie dropout to the module's train/eval mode; F.dropout defaults to
        # training=True and would otherwise keep dropping units at inference.
        embedded = F.dropout(embedded, self.dropout_rate, training=self.training)
        output, hidden = self.rnn(embedded, hidden)
        out = self.out(output.squeeze(0))
        output = F.log_softmax(out, dim=1)
        return output, hidden
class BeamSearchNode(object):
    '''
    One hypothesis step in the beam search: a token plus a link to its parent.
    '''

    def __init__(self, hiddenstate, previousNode, wordId, logProb, length):
        '''
        :param hiddenstate: decoder hidden state to resume decoding from
        :param previousNode: parent node, or None for the start node
        :param wordId: token id carried by this node
        :param logProb: accumulated log-probability of the path ending here
        :param length: number of nodes on the path, start node included
        '''
        self.h = hiddenstate
        self.prevNode = previousNode
        self.wordid = wordId
        self.logp = logProb
        self.leng = length

    def eval(self, alpha=1.0):
        '''
        Return the length-normalised score of this path (higher is better).

        :param alpha: weight of the reward term, which is currently always 0
        '''
        reward = 0
        # Add here a function for shaping a reward

        # The 1e-6 keeps the divisor non-zero for the start node (leng == 1).
        return self.logp / float(self.leng - 1 + 1e-6) + alpha * reward

    def __lt__(self, other):
        '''
        Order nodes the same way the search queue orders them (by -eval()).

        Nodes are queued as (score, node) tuples; when two scores tie, tuple
        comparison falls through to the nodes themselves, which raises
        TypeError on Python 3 unless the nodes are orderable.
        '''
        return -self.eval() < -other.eval()
# Module-level decoder instance called by beam_decode() and greedy_decode().
# NOTE(review): DecoderRNN.__init__ requires embedding_size, hidden_size,
# output_size and cell_type, so this zero-argument call raises TypeError as
# written -- presumably a placeholder for a configured/trained decoder; confirm.
decoder = DecoderRNN()
def beam_decode(target_tensor, decoder_hiddens, encoder_outputs=None, beam_width=10, topk=1):
    '''
    Decode every sentence of a batch with best-first beam search.

    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence; only B is read here
    :param decoder_hiddens: initial decoder state, [1, B, H], or a tuple of two such tensors in the LSTM case
    :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence; may be None
    :param beam_width: number of continuations expanded from each node
    :param topk: how many sentences to generate per batch element
    :return: decoded_batch, a list with one entry per batch element; each entry
        is a list of up to topk utterances (best score first), and each
        utterance is the list of word-id tensors from the start token to the
        last decoded token
    :raises ValueError: if beam_width or topk is smaller than 1
    '''
    if beam_width < 1 or topk < 1:
        raise ValueError("beam_width and topk must both be >= 1")

    decoded_batch = []

    # decoding goes sentence by sentence
    for idx in range(target_tensor.size(0)):
        # Slice with idx:idx + 1 so the batch axis is kept; for a [1, B, H]
        # state this yields the same [1, 1, H] tensor as before.
        if isinstance(decoder_hiddens, tuple):  # LSTM case
            decoder_hidden = (decoder_hiddens[0][:, idx:idx + 1, :],
                              decoder_hiddens[1][:, idx:idx + 1, :])
        else:
            decoder_hidden = decoder_hiddens[:, idx:idx + 1, :]

        # encoder_outputs is optional, so only slice it when it was given
        encoder_output = None
        if encoder_outputs is not None:
            encoder_output = encoder_outputs[:, idx, :].unsqueeze(1)

        # Start with the start of the sentence token. torch.tensor is used
        # because the legacy LongTensor constructor rejects a CUDA device.
        decoder_input = torch.tensor([[SOS_token]], dtype=torch.long, device=device)

        # Number of sentence to generate
        endnodes = []
        number_required = topk

        # starting node - hidden vector, previous node, word id, logp, length
        node = BeamSearchNode(decoder_hidden, None, decoder_input, 0, 1)
        nodes = PriorityQueue()

        # Queue entries are (score, insertion order, node); the insertion
        # counter settles score ties so nodes themselves are never compared.
        tie_breaker = 0
        nodes.put((-node.eval(), tie_breaker, node))
        qsize = 1

        # start beam search; stop if the queue runs dry instead of blocking
        while not nodes.empty():
            # give up when decoding takes too long
            if qsize > 2000:
                break

            # fetch the best node
            score, _, n = nodes.get()
            decoder_input = n.wordid
            decoder_hidden = n.h

            if n.wordid.item() == EOS_token and n.prevNode is not None:
                endnodes.append((score, n))
                # if we reached maximum # of sentences required
                if len(endnodes) >= number_required:
                    break
                continue

            # decode for one step using decoder; no gradients are needed
            with torch.no_grad():
                decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_output)

            # never ask topk for more candidates than the vocabulary holds
            width = min(beam_width, decoder_output.size(-1))
            log_prob, indexes = torch.topk(decoder_output, width)

            for new_k in range(width):
                decoded_t = indexes[0][new_k].view(1, -1)
                log_p = log_prob[0][new_k].item()
                child = BeamSearchNode(decoder_hidden, n, decoded_t, n.logp + log_p, n.leng + 1)
                tie_breaker += 1
                nodes.put((-child.eval(), tie_breaker, child))

            # increase qsize
            qsize += width - 1

        # No hypothesis reached EOS: fall back to the best unfinished ones,
        # taking only as many as the queue actually holds.
        if not endnodes:
            while len(endnodes) < topk and not nodes.empty():
                score, _, n = nodes.get()
                endnodes.append((score, n))

        # choose nbest paths, back trace them
        utterances = []
        for _score, n in sorted(endnodes, key=operator.itemgetter(0)):
            utterance = [n.wordid]
            # back trace
            while n.prevNode is not None:
                n = n.prevNode
                utterance.append(n.wordid)
            utterances.append(utterance[::-1])

        decoded_batch.append(utterances)

    return decoded_batch
def greedy_decode(decoder_hidden, encoder_outputs, target_tensor):
    '''
    Decode a batch by taking the single most likely token at every step.

    :param target_tensor: target indexes tensor of shape [B, T] where B is the batch size and T is the maximum length of the output sentence; only B is read here
    :param decoder_hidden: input tensor of shape [1, B, H] for start of the decoding
    :param encoder_outputs: if you are using attention mechanism you can pass encoder outputs, [T, B, H] where T is the maximum length of input sentence
    :return: decoded_batch, a float tensor of shape [B, MAX_LENGTH] holding the
        chosen token id for each of the MAX_LENGTH decoding steps
    '''
    batch_size = target_tensor.size(0)
    # Kept as a default (float) tensor so callers see the same dtype as before.
    decoded_batch = torch.zeros((batch_size, MAX_LENGTH))
    # torch.full is used because the legacy LongTensor constructor rejects a
    # CUDA device; it also gives a [0, 1] tensor for an empty batch.
    decoder_input = torch.full((batch_size, 1), SOS_token, dtype=torch.long, device=device)

    # Inference only: skip autograd bookkeeping for the whole rollout.
    with torch.no_grad():
        for t in range(MAX_LENGTH):
            decoder_output, decoder_hidden = decoder(decoder_input, decoder_hidden, encoder_outputs)

            _, topi = decoder_output.topk(1)  # get candidates
            topi = topi.view(-1)
            decoded_batch[:, t] = topi

            # feed the chosen tokens back in as the next input, [B, 1]
            decoder_input = topi.view(-1, 1)

    return decoded_batch