forked from k2-fsa/icefall
-
Notifications
You must be signed in to change notification settings - Fork 0
/
model.py
477 lines (415 loc) · 15.3 KB
/
model.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
# Copyright 2021 Xiaomi Corp. (authors: Fangjun Kuang
# Mingshuang Luo)
#
# See ../../../../LICENSE for clarification regarding multiple authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Optional
import torch
import torch.nn as nn
from torch import Tensor
class TdnnLiGRU(nn.Module):
    """TDNN front-end followed by four residual bidirectional LiGRU blocks.

    The front-end is four Conv1d -> ReLU -> BatchNorm1d triples producing
    512 channels; only the last convolution has a stride
    (``subsampling_factor``), so it alone shortens the time axis. Each
    recurrent block is LiGRU -> Linear(1024 -> 512) -> BatchNorm1d ->
    Dropout, added back onto the block input. A final Linear layer and
    log-softmax produce per-frame class log-probabilities.
    """

    def __init__(
        self, num_features: int, num_classes: int, subsampling_factor: int = 3
    ) -> None:
        """
        Args:
          num_features:
            The input dimension of the model.
          num_classes:
            The output dimension of the model.
          subsampling_factor:
            It reduces the number of output frames by this factor.
        """
        super().__init__()

        self.num_features = num_features
        self.num_classes = num_classes
        self.subsampling_factor = subsampling_factor

        # Build the TDNN triples in order, so nn.Sequential indices (and thus
        # state_dict keys) are identical to a flat, hand-written listing.
        tdnn_layers = []
        channels_in = num_features
        for idx in range(4):
            is_last = idx == 3
            tdnn_layers += [
                nn.Conv1d(
                    in_channels=channels_in,
                    out_channels=512,
                    kernel_size=3,
                    # Only the final convolution subsamples the time axis.
                    stride=self.subsampling_factor if is_last else 1,
                    padding=1,
                ),
                nn.ReLU(inplace=True),
                nn.BatchNorm1d(num_features=512, affine=False),
            ]
            channels_in = 512
        self.tdnn = nn.Sequential(*tdnn_layers)

        num_blocks = 4
        self.ligrus = nn.ModuleList(
            LiGRU(
                input_shape=[None, None, 512],
                hidden_size=512,
                num_layers=1,
                bidirectional=True,
            )
            for _ in range(num_blocks)
        )
        # The bidirectional LiGRU emits 2 * 512 = 1024 features; project them
        # back to 512 so the residual sum in forward() is shape-compatible.
        self.linears = nn.ModuleList(
            nn.Linear(in_features=1024, out_features=512)
            for _ in range(num_blocks)
        )
        self.bnorms = nn.ModuleList(
            nn.BatchNorm1d(num_features=512, affine=False)
            for _ in range(num_blocks)
        )
        self.dropout = nn.Dropout(0.2)
        self.linear = nn.Linear(in_features=512, out_features=self.num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Args:
          x:
            Its shape is [N, C, T]
        Returns:
          The output tensor has shape [N, T, C]
        """
        # (N, C, T) -> TDNN -> (N, 512, T') -> (N, T', 512)
        hidden = self.tdnn(x).permute(0, 2, 1)

        for ligru, proj, norm in zip(self.ligrus, self.linears, self.bnorms):
            update, _ = ligru(hidden)
            update = proj(update)
            # BatchNorm1d normalizes dim 1: move features there and back.
            update = norm(update.permute(0, 2, 1)).permute(0, 2, 1)
            hidden = self.dropout(update) + hidden  # residual connection

        logits = self.linear(hidden)
        return nn.functional.log_softmax(logits, dim=-1)
class LiGRU(torch.nn.Module):
    """This class implements a Light GRU (liGRU).

    This LiGRU model is from speechbrain, please see
    https://github.com/speechbrain/speechbrain/blob/develop/speechbrain/nnet/RNN.py

    LiGRU is a single-gate GRU model based on batch-norm + relu
    activations + recurrent dropout. For more info see:

    "M. Ravanelli, P. Brakel, M. Omologo, Y. Bengio,
    Light Gated Recurrent Units for Speech Recognition,
    in IEEE Transactions on Emerging Topics in Computational Intelligence,
    2018" (https://arxiv.org/abs/1803.10225)

    This is a custom RNN; the speechbrain original recommends compiling it
    with the TorchScript JIT for speed:
    compiled_model = torch.jit.script(model)
    NOTE(review): scriptability of this copy has not been verified.

    It accepts input tensors formatted as (batch, time, fea).
    In the case of 4d inputs like (batch, time, fea, channel) the tensor is
    flattened as (batch, time, fea*channel).

    Arguments
    ---------
    hidden_size : int
        Number of output neurons per direction (the output has
        ``2 * hidden_size`` features when ``bidirectional=True``).
    input_shape : tuple
        The shape of an example input. ``input_shape[0]`` is stored as the
        initial batch size (may be None) and the product of
        ``input_shape[2:]`` is used as the input feature dimension.
    nonlinearity : str
        Type of nonlinearity ("tanh", "sin", "leaky_relu"; any other value
        selects ReLU).
    normalization : str
        Type of normalization for the ligru model (batchnorm, layernorm).
        Every string different from batchnorm and layernorm will result
        in no normalization.
    num_layers : int
        Number of layers to employ in the RNN architecture.
    bias : bool
        Stored only; the layers' projections are created without bias.
    dropout : float
        Recurrent dropout factor (must be between 0 and 1).
    bidirectional : bool
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used.

    Example
    -------
    >>> inp_tensor = torch.rand([4, 10, 20])
    >>> net = LiGRU(input_shape=inp_tensor.shape, hidden_size=5)
    >>> out_tensor, _ = net(inp_tensor)
    >>> out_tensor.shape
    torch.Size([4, 10, 5])
    """

    def __init__(
        self,
        hidden_size,
        input_shape,
        nonlinearity="relu",
        normalization="batchnorm",
        num_layers=1,
        bias=True,
        dropout=0.0,
        bidirectional=False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.nonlinearity = nonlinearity
        self.num_layers = num_layers
        self.normalization = normalization
        self.bias = bias
        self.dropout = dropout
        self.bidirectional = bidirectional
        self.reshape = False

        # Computing the feature dimensionality
        if len(input_shape) > 3:
            self.reshape = True
        self.fea_dim = float(torch.prod(torch.tensor(input_shape[2:])))
        self.batch_size = input_shape[0]
        self.rnn = self._init_layers()

    def _init_layers(self):
        """Initializes the layers of the liGRU."""
        rnn = torch.nn.ModuleList([])
        current_dim = self.fea_dim

        for _ in range(self.num_layers):
            rnn_lay = LiGRU_Layer(
                current_dim,
                self.hidden_size,
                self.num_layers,
                self.batch_size,
                dropout=self.dropout,
                nonlinearity=self.nonlinearity,
                normalization=self.normalization,
                bidirectional=self.bidirectional,
            )
            rnn.append(rnn_lay)

            # The next layer consumes this layer's (possibly concatenated)
            # output features.
            if self.bidirectional:
                current_dim = self.hidden_size * 2
            else:
                current_dim = self.hidden_size
        return rnn

    def forward(self, x, hx: Optional[Tensor] = None):
        """Returns the output of the liGRU.

        Arguments
        ---------
        x : torch.Tensor
            The input tensor, (batch, time, fea) or, when the module was
            built from a 4-d ``input_shape``, (batch, time, fea, channel).
        hx : torch.Tensor, optional
            Starting hidden state, in the same layout as the returned state:
            (num_layers * num_directions, batch, hidden_size).

        Returns
        -------
        output : torch.Tensor
            Output of the last layer, (batch, time, hidden_size * num_directions).
        hh : torch.Tensor
            Per-layer states, (num_layers * num_directions, batch, hidden_size).
        """
        # Reshaping input tensors for 4d inputs
        if self.reshape:
            if x.ndim == 4:
                x = x.reshape(x.shape[0], x.shape[1], x.shape[2] * x.shape[3])

        # run ligru
        output, hh = self._forward_ligru(x, hx=hx)

        return output, hh

    def _forward_ligru(self, x, hx: Optional[Tensor]):
        """Returns the output of the vanilla liGRU.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor.
        hx : torch.Tensor, optional
            Starting hidden state, (num_layers * num_directions, batch, hidden).
        """
        finals = []
        if hx is not None and self.bidirectional:
            # (num_layers * 2, batch, H) -> (num_layers, 2 * batch, H): each
            # layer gets its forward rows followed by its backward rows, which
            # is how LiGRU_Layer stacks the two directions along the batch.
            # The batch size is inferred because self.batch_size may be None.
            hx = hx.reshape(self.num_layers, -1, self.hidden_size)

        # Processing the different layers
        for i, ligru_lay in enumerate(self.rnn):
            if hx is not None:
                x = ligru_lay(x, hx=hx[i])
            else:
                x = ligru_lay(x, hx=None)
            # NOTE(review): for the backward direction the state after the
            # whole (reversed) sequence sits at time 0, not -1; this keeps
            # the upstream speechbrain choice of taking the last frame.
            finals.append(x[:, -1, :])

        h = self._stack_final_states(finals, self.bidirectional, self.hidden_size)
        return x, h

    @staticmethod
    def _stack_final_states(finals, bidirectional: bool, hidden_size: int) -> Tensor:
        """Stack per-layer final states into (num_layers * num_dirs, batch, H).

        ``finals`` holds one (batch, hidden_size * num_dirs) tensor per layer.
        For bidirectional layers the result is ordered layer-major,
        direction-minor: row ``2 * layer + direction`` (0 = forward,
        1 = backward), matching the layout ``_forward_ligru`` accepts as hx.
        """
        h = torch.stack(finals, dim=0)  # (num_layers, batch, H * num_dirs)
        if bidirectional:
            num_layers, batch = h.shape[0], h.shape[1]
            # Split the feature axis into (direction, H) and move direction
            # next to the layer axis before flattening; a bare reshape would
            # interleave different batch elements.
            h = (
                h.reshape(num_layers, batch, 2, hidden_size)
                .permute(0, 2, 1, 3)
                .reshape(num_layers * 2, batch, hidden_size)
            )
        return h
class LiGRU_Layer(torch.nn.Module):
    """This class implements a single Light-Gated Recurrent Units (ligru) layer.

    Arguments
    ---------
    input_size : int
        Feature dimensionality of the input tensors.
    hidden_size : int
        Number of output neurons (per direction).
    num_layers : int
        Number of layers in the enclosing RNN. Accepted for interface
        compatibility; this layer does not use it.
    batch_size : int or None
        Initial batch size. It is replaced by the actual batch size (doubled
        when bidirectional) whenever a forward call sees a different one.
    dropout : float
        Recurrent dropout factor (must be between 0 and 1).
    nonlinearity : str
        Activation for the candidate state: "tanh", "sin" or "leaky_relu";
        any other value selects ReLU.
    normalization : str
        Normalization applied to the input projection (batchnorm, layernorm).
        Every string different from batchnorm and layernorm will result
        in no normalization.
    bidirectional : bool
        If True, a bidirectional model that scans the sequence both
        right-to-left and left-to-right is used.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        num_layers,
        batch_size,
        dropout=0.0,
        nonlinearity="relu",
        normalization="batchnorm",
        bidirectional=False,
    ):
        super().__init__()
        self.hidden_size = int(hidden_size)
        self.input_size = int(input_size)
        self.batch_size = batch_size
        self.bidirectional = bidirectional
        self.dropout = dropout

        self.drop = torch.nn.Dropout(p=self.dropout, inplace=False)
        self.N_drop_masks = 16000
        self.drop_mask_cnt = 0
        # Pool of training-time recurrent dropout masks. It is sampled lazily
        # on first use (see _sample_drop_mask / _change_batch_size) and kept
        # as a plain attribute, not a buffer, so the state_dict is unchanged.
        self.drop_masks = None
        self.drop_mask_te = torch.tensor([1.0]).float()

        self.w = nn.Linear(self.input_size, 2 * self.hidden_size, bias=False)
        self.u = nn.Linear(self.hidden_size, 2 * self.hidden_size, bias=False)

        # Initializing the normalization of the input projection
        self.normalize = False
        if normalization == "batchnorm":
            self.norm = nn.BatchNorm1d(2 * self.hidden_size, momentum=0.05)
            self.normalize = True
        elif normalization == "layernorm":
            self.norm = torch.nn.LayerNorm(2 * self.hidden_size)
            self.normalize = True
        else:
            # Normalization is disabled here. self.norm is only formally
            # initialized to avoid jit issues (and keeps the parameter set
            # stable); forward() skips it because normalize is False.
            self.norm = torch.nn.LayerNorm(2 * self.hidden_size)
            self.normalize = False

        # Initial state, broadcast over the batch in _ligru_cell.
        self.register_buffer("h_init", torch.zeros(1, self.hidden_size))

        # Setting the activation function
        if nonlinearity == "tanh":
            self.act = torch.nn.Tanh()
        elif nonlinearity == "sin":
            self.act = torch.sin
        elif nonlinearity == "leaky_relu":
            self.act = torch.nn.LeakyReLU()
        else:
            self.act = torch.nn.ReLU()

    def forward(self, x, hx: Optional[Tensor] = None):
        # type: (Tensor, Optional[Tensor]) -> Tensor # noqa F821
        """Returns the output of the liGRU layer.

        Arguments
        ---------
        x : torch.Tensor
            Input tensor of shape (batch, time, input_size).
        hx : torch.Tensor, optional
            Initial hidden state; ``h_init`` (zeros) is used when None. For a
            bidirectional layer it covers both directions: ``2 * batch`` rows,
            forward rows first.

        Returns
        -------
        torch.Tensor
            (batch, time, hidden_size), or (batch, time, 2 * hidden_size)
            when bidirectional (forward features first).
        """
        if self.bidirectional:
            # Run both directions in one pass by stacking the time-reversed
            # sequence along the batch axis.
            x_flip = x.flip(1)
            x = torch.cat([x, x_flip], dim=0)

        # Change batch size if needed
        self._change_batch_size(x)

        # Feed-forward affine transformations (all steps in parallel)
        w = self.w(x)

        # Apply normalization over (batch * time, features)
        if self.normalize:
            w_bn = self.norm(w.reshape(w.shape[0] * w.shape[1], w.shape[2]))
            w = w_bn.reshape(w.shape[0], w.shape[1], w.shape[2])

        # Processing time steps
        if hx is not None:
            h = self._ligru_cell(w, hx)
        else:
            h = self._ligru_cell(w, self.h_init)

        if self.bidirectional:
            h_f, h_b = h.chunk(2, dim=0)
            h_b = h_b.flip(1)  # re-align backward outputs with input time
            h = torch.cat([h_f, h_b], dim=2)

        return h

    def _ligru_cell(self, w, ht):
        """Returns the hidden states for each time step.

        Arguments
        ---------
        w : torch.Tensor
            Linearly transformed (and optionally normalized) input,
            (batch, time, 2 * hidden_size).
        ht : torch.Tensor
            Initial hidden state, (batch, hidden_size) or broadcastable to it.
        """
        hiddens = []

        # Sampling dropout mask (one mask reused for every time step)
        drop_mask = self._sample_drop_mask(w)

        # Loop over time axis
        for k in range(w.shape[1]):
            gates = w[:, k] + self.u(ht)
            at, zt = gates.chunk(2, 1)
            zt = torch.sigmoid(zt)
            hcand = self.act(at) * drop_mask
            ht = zt * ht + (1 - zt) * hcand
            hiddens.append(ht)

        # Stacking hidden states
        h = torch.stack(hiddens, dim=1)
        return h

    def _init_drop(self, batch_size):
        """(Re)initializes the recurrent dropout state and samples a fresh
        pool of masks in advance to speed up training.

        ``batch_size`` is accepted for compatibility and not used. The masks
        are stored as plain attributes (not buffers) so this can be called
        on an already-constructed layer without altering its state_dict.
        """
        self.N_drop_masks = 16000
        self.drop_mask_cnt = 0
        self.drop_masks = self.drop(
            torch.ones(
                self.N_drop_masks, self.hidden_size, device=self.h_init.device
            )
        ).data
        self.drop_mask_te = torch.tensor([1.0]).float()

    def _sample_drop_mask(self, w):
        """Selects one of the pre-sampled dropout masks.

        In training mode, returns the next (batch_size, hidden_size) slice of
        the mask pool, re-sampling the pool when it does not exist yet, lives
        on a different device than ``w``, or has too few rows left. In eval
        mode returns a constant ``[1.0]`` mask (no dropout).
        """
        if self.training:
            pool = getattr(self, "drop_masks", None)
            need_new_pool = (
                pool is None
                or pool.device != w.device
                or self.drop_mask_cnt + self.batch_size > self.N_drop_masks
            )
            if need_new_pool:
                self.drop_mask_cnt = 0
                self.drop_masks = self.drop(
                    torch.ones(self.N_drop_masks, self.hidden_size, device=w.device)
                ).data

            # Sampling the mask
            left_boundary = self.drop_mask_cnt
            right_boundary = self.drop_mask_cnt + self.batch_size
            drop_mask = self.drop_masks[left_boundary:right_boundary]
            self.drop_mask_cnt = right_boundary
        else:
            self.drop_mask_te = self.drop_mask_te.to(w.device)
            drop_mask = self.drop_mask_te

        return drop_mask

    def _change_batch_size(self, x):
        """Updates ``batch_size`` when it differs from ``x.shape[0]``.

        This might happen in the case of multi-gpu or when we have different
        batch sizes in train and test. In training mode the dropout-mask pool
        is also re-sampled. (``h_init`` needs no update: it has a single row
        that broadcasts over any batch size.)
        """
        if self.batch_size != x.shape[0]:
            self.batch_size = x.shape[0]

            if self.training:
                self.drop_masks = self.drop(
                    torch.ones(
                        self.N_drop_masks,
                        self.hidden_size,
                        device=x.device,
                    )
                ).data