import torch
from torch import nn
import torchvision
from torch.nn.utils.weight_norm import weight_norm

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class Attention(nn.Module):
    """
    Attention Network.
    """

    def __init__(self, features_dim, decoder_dim, attention_dim, dropout=0.5):
        """
        :param features_dim: feature size of encoded images
        :param decoder_dim: size of decoder's RNN
        :param attention_dim: size of the attention network
        :param dropout: dropout applied inside the attention network
        """
        super(Attention, self).__init__()
        self.features_att = weight_norm(nn.Linear(features_dim, attention_dim))  # linear layer to transform encoded image
        self.decoder_att = weight_norm(nn.Linear(decoder_dim, attention_dim))  # linear layer to transform decoder's output
        self.full_att = weight_norm(nn.Linear(attention_dim, 1))  # linear layer to calculate values to be softmax-ed
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(p=dropout)
        self.softmax = nn.Softmax(dim=1)  # softmax layer to calculate weights

    def forward(self, image_features, decoder_hidden):
        """
        Forward propagation.
        :param image_features: encoded images, a tensor of dimension (batch_size, 36, features_dim)
        :param decoder_hidden: previous decoder output, a tensor of dimension (batch_size, decoder_dim)
        :return: attention weighted encoding
        """
        att1 = self.features_att(image_features)  # (batch_size, 36, attention_dim)
        att2 = self.decoder_att(decoder_hidden)  # (batch_size, attention_dim)
        att = self.full_att(self.dropout(self.relu(att1 + att2.unsqueeze(1)))).squeeze(2)  # (batch_size, 36)
        alpha = self.softmax(att)  # (batch_size, 36)
        attention_weighted_encoding = (image_features * alpha.unsqueeze(2)).sum(dim=1)  # (batch_size, features_dim)

        return attention_weighted_encoding
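
# Shape walk-through for Attention.forward (illustrative only; B and the dimension names below are
# generic placeholders, not values fixed anywhere in this file):
#   image_features (B, 36, features_dim) --features_att--> att1 (B, 36, attention_dim)
#   decoder_hidden (B, decoder_dim)      --decoder_att---> att2 (B, attention_dim)
#   relu(att1 + att2.unsqueeze(1)) -> (B, 36, attention_dim) --full_att, squeeze(2)--> att (B, 36)
#   softmax over the 36 regions -> alpha (B, 36)
#   (image_features * alpha.unsqueeze(2)).sum(dim=1) -> attention_weighted_encoding (B, features_dim)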


class DecoderWithAttention(nn.Module):
    """
    Decoder.
    """

    def __init__(self, attention_dim, embed_dim, decoder_dim, vocab_size, features_dim=2048, dropout=0.5):
        """
        :param attention_dim: size of attention network
        :param embed_dim: embedding size
        :param decoder_dim: size of decoder's RNN
        :param vocab_size: size of vocabulary
        :param features_dim: feature size of encoded images
        :param dropout: dropout
        """
        super(DecoderWithAttention, self).__init__()

        self.features_dim = features_dim
        self.attention_dim = attention_dim
        self.embed_dim = embed_dim
        self.decoder_dim = decoder_dim
        self.vocab_size = vocab_size
        self.dropout = dropout

        self.attention = Attention(features_dim, decoder_dim, attention_dim)  # attention network

        self.embedding = nn.Embedding(vocab_size, embed_dim)  # embedding layer
        self.dropout = nn.Dropout(p=self.dropout)
        self.top_down_attention = nn.LSTMCell(embed_dim + features_dim + decoder_dim, decoder_dim, bias=True)  # top-down attention LSTMCell
        self.language_model = nn.LSTMCell(features_dim + decoder_dim, decoder_dim, bias=True)  # language model LSTMCell
        self.fc1 = weight_norm(nn.Linear(decoder_dim, vocab_size))  # linear layer to find scores over vocabulary from the top-down attention LSTM
        self.fc = weight_norm(nn.Linear(decoder_dim, vocab_size))  # linear layer to find scores over vocabulary from the language model LSTM
        self.init_weights()  # initialize some layers with the uniform distribution

    def init_weights(self):
        """
        Initializes some parameters with values from the uniform distribution, for easier convergence.
        """
        self.embedding.weight.data.uniform_(-0.1, 0.1)
        self.fc.bias.data.fill_(0)
        self.fc.weight.data.uniform_(-0.1, 0.1)

    def init_hidden_state(self, batch_size):
        """
        Creates the initial (zero) hidden and cell states for one of the decoder's LSTMs.
        :param batch_size: batch size of the current forward pass
        :return: hidden state, cell state
        """
        h = torch.zeros(batch_size, self.decoder_dim).to(device)  # (batch_size, decoder_dim)
        c = torch.zeros(batch_size, self.decoder_dim).to(device)
        return h, c

    def forward(self, image_features, encoded_captions, caption_lengths):
        """
        Forward propagation.
        :param image_features: encoded bottom-up image features, a tensor of dimension (batch_size, 36, features_dim)
        :param encoded_captions: encoded captions, a tensor of dimension (batch_size, max_caption_length)
        :param caption_lengths: caption lengths, a tensor of dimension (batch_size, 1)
        :return: scores over the vocabulary from both LSTMs, sorted encoded captions, decode lengths, sort indices
        """
        batch_size = image_features.size(0)
        vocab_size = self.vocab_size

        # Mean-pool the bottom-up features over the image regions
        image_features_mean = image_features.mean(1).to(device)  # (batch_size, features_dim)

        # Sort input data by decreasing caption lengths, so that at each time-step
        # only the captions that have not yet ended need to be processed
        caption_lengths, sort_ind = caption_lengths.squeeze(1).sort(dim=0, descending=True)
        image_features = image_features[sort_ind]
        image_features_mean = image_features_mean[sort_ind]
        encoded_captions = encoded_captions[sort_ind]

        # Embedding
        embeddings = self.embedding(encoded_captions)  # (batch_size, max_caption_length, embed_dim)

        # Initialize LSTM states
        h1, c1 = self.init_hidden_state(batch_size)  # (batch_size, decoder_dim)
        h2, c2 = self.init_hidden_state(batch_size)  # (batch_size, decoder_dim)

        # We won't decode at the <end> position, since we've finished generating as soon as we generate <end>
        # So, decoding lengths are actual lengths - 1
        decode_lengths = (caption_lengths - 1).tolist()

        # Create tensors to hold word prediction scores
        predictions = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)
        predictions1 = torch.zeros(batch_size, max(decode_lengths), vocab_size).to(device)

        # At each time-step, pass the language model's previous hidden state, the mean-pooled bottom-up features and
        # the word embeddings to the top-down attention LSTM. Then pass the hidden state of the top-down LSTM and the
        # bottom-up features to the attention block. The attention-weighted bottom-up features and the hidden state of
        # the top-down attention LSTM are then passed to the language model LSTM.
        for t in range(max(decode_lengths)):
            batch_size_t = sum([l > t for l in decode_lengths])
            h1, c1 = self.top_down_attention(
                torch.cat([h2[:batch_size_t], image_features_mean[:batch_size_t], embeddings[:batch_size_t, t, :]], dim=1),
                (h1[:batch_size_t], c1[:batch_size_t]))
            attention_weighted_encoding = self.attention(image_features[:batch_size_t], h1[:batch_size_t])
            preds1 = self.fc1(self.dropout(h1))  # (batch_size_t, vocab_size)
            h2, c2 = self.language_model(
                torch.cat([attention_weighted_encoding[:batch_size_t], h1[:batch_size_t]], dim=1),
                (h2[:batch_size_t], c2[:batch_size_t]))
            preds = self.fc(self.dropout(h2))  # (batch_size_t, vocab_size)
            predictions[:batch_size_t, t, :] = preds
            predictions1[:batch_size_t, t, :] = preds1

        return predictions, predictions1, encoded_captions, decode_lengths, sort_ind
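

# Minimal usage sketch (not part of the original training pipeline): it only exercises the shapes of
# DecoderWithAttention.forward with random inputs. The hyperparameters and vocabulary size below are
# assumed example values, not the ones used to train the model.
if __name__ == "__main__":
    vocab_size = 100  # hypothetical vocabulary size
    batch_size, num_regions, max_caption_length = 4, 36, 20

    decoder = DecoderWithAttention(attention_dim=512, embed_dim=300, decoder_dim=512,
                                   vocab_size=vocab_size, features_dim=2048).to(device)

    image_features = torch.randn(batch_size, num_regions, 2048).to(device)  # dummy bottom-up features
    encoded_captions = torch.randint(0, vocab_size, (batch_size, max_caption_length)).to(device)  # dummy token ids
    caption_lengths = torch.randint(5, max_caption_length + 1, (batch_size, 1)).to(device)  # lengths incl. <start>/<end>

    predictions, predictions1, sorted_captions, decode_lengths, sort_ind = decoder(
        image_features, encoded_captions, caption_lengths)
    print(predictions.shape)   # (batch_size, max(decode_lengths), vocab_size), scores from the language model LSTM
    print(predictions1.shape)  # same shape, scores from the top-down attention LSTM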