@@ -22,44 +22,89 @@
 
 
 def drop_block_2d(
-        x, drop_prob: float = 0.1, training: bool = False, block_size: int = 7,
-        gamma_scale: float = 1.0, drop_with_noise: bool = False):
+        x, drop_prob: float = 0.1, block_size: int = 7, gamma_scale: float = 1.0,
+        with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
     """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
 
     DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
     runs with success, but needs further validation and possibly optimization for lower runtime impact.
-
     """
-    if drop_prob == 0. or not training:
-        return x
-    _, _, height, width = x.shape
-    total_size = width * height
-    clipped_block_size = min(block_size, min(width, height))
+    B, C, H, W = x.shape
+    total_size = W * H
+    clipped_block_size = min(block_size, min(W, H))
     # seed_drop_rate, the gamma parameter
-    seed_drop_rate = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
-        (width - block_size + 1) *
-        (height - block_size + 1))
+    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
+        (W - block_size + 1) * (H - block_size + 1))
 
     # Forces the block to be inside the feature map.
-    w_i, h_i = torch.meshgrid(torch.arange(width).to(x.device), torch.arange(height).to(x.device))
-    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < width - (clipped_block_size - 1) // 2)) & \
-                  ((h_i >= clipped_block_size // 2) & (h_i < height - (clipped_block_size - 1) // 2))
-    valid_block = torch.reshape(valid_block, (1, 1, height, width)).float()
-
-    uniform_noise = torch.rand_like(x, dtype=torch.float32)
-    block_mask = ((2 - seed_drop_rate - valid_block + uniform_noise) >= 1).float()
+    w_i, h_i = torch.meshgrid(torch.arange(W).to(x.device), torch.arange(H).to(x.device))
+    valid_block = ((w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)) & \
+                  ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
+    valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
+
+    if batchwise:
+        # one mask for whole batch, quite a bit faster
+        uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
+    else:
+        uniform_noise = torch.rand_like(x)
+    block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
     block_mask = -F.max_pool2d(
         -block_mask,
-        kernel_size=clipped_block_size,  # block_size, ???
+        kernel_size=clipped_block_size,  # block_size,
         stride=1,
         padding=clipped_block_size // 2)
 
-    if drop_with_noise:
-        normal_noise = torch.randn_like(x)
-        x = x * block_mask + normal_noise * (1 - block_mask)
+    if with_noise:
+        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
+        if inplace:
+            x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
+        else:
+            x = x * block_mask + normal_noise * (1 - block_mask)
+    else:
+        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(x.dtype)
+        if inplace:
+            x.mul_(block_mask * normalize_scale)
+        else:
+            x = x * block_mask * normalize_scale
+    return x
+
+
+def drop_block_fast_2d(
+        x: torch.Tensor, drop_prob: float = 0.1, block_size: int = 7,
+        gamma_scale: float = 1.0, with_noise: bool = False, inplace: bool = False, batchwise: bool = False):
+    """ DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
+
+    DropBlock with an experimental gaussian noise option. Simplified from the above, without concern for
+    the valid block mask at edges.
+    """
+    B, C, H, W = x.shape
+    total_size = W * H
+    clipped_block_size = min(block_size, min(W, H))
+    gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
+        (W - block_size + 1) * (H - block_size + 1))
+
+    if batchwise:
+        # one mask for whole batch, quite a bit faster
+        block_mask = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device) < gamma
+    else:
+        # mask per batch element
+        block_mask = torch.rand_like(x) < gamma
+    block_mask = F.max_pool2d(
+        block_mask.to(x.dtype), kernel_size=clipped_block_size, stride=1, padding=clipped_block_size // 2)
+
+    if with_noise:
+        normal_noise = torch.randn((1, C, H, W), dtype=x.dtype, device=x.device) if batchwise else torch.randn_like(x)
+        if inplace:
+            x.mul_(1. - block_mask).add_(normal_noise * block_mask)
+        else:
+            x = x * (1. - block_mask) + normal_noise * block_mask
     else:
-        normalize_scale = block_mask.numel() / (torch.sum(block_mask) + 1e-7)
-        x = x * block_mask * normalize_scale
+        block_mask = 1 - block_mask
+        normalize_scale = (block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)).to(dtype=x.dtype)
+        if inplace:
+            x.mul_(block_mask * normalize_scale)
+        else:
+            x = x * block_mask * normalize_scale
     return x
 
 
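Note on the seed rate: both functions compute the same `gamma`, which is the DropBlock paper's gamma = (1 - keep_prob) / block_size**2 * feat_area / valid_seed_positions, with `drop_prob` standing in for 1 - keep_prob and `gamma_scale` as an extra scaling knob. A quick arithmetic sketch, with hypothetical sizes not taken from this commit:

# Hypothetical check of the gamma (seed rate) arithmetic used above.
# drop_prob plays the role of 1 - keep_prob from the paper; sizes are made up.
drop_prob, gamma_scale, block_size = 0.1, 1.0, 7
H = W = 28
total_size = W * H                                # 784
clipped_block_size = min(block_size, min(W, H))   # 7
gamma = gamma_scale * drop_prob * total_size / clipped_block_size ** 2 / (
    (W - block_size + 1) * (H - block_size + 1))  # ~0.0033
# Each of the (W - 7 + 1) * (H - 7 + 1) = 484 seed positions fires with
# probability gamma, and each seed wipes a 7x7 block, so on average
# gamma * 484 * 49 ~= 78 ~= drop_prob * total_size activations are dropped.
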
@@ -70,15 +115,28 @@ def __init__(self,
                  drop_prob=0.1,
                  block_size=7,
                  gamma_scale=1.0,
-                 with_noise=False):
+                 with_noise=False,
+                 inplace=False,
+                 batchwise=False,
+                 fast=True):
         super(DropBlock2d, self).__init__()
         self.drop_prob = drop_prob
         self.gamma_scale = gamma_scale
         self.block_size = block_size
         self.with_noise = with_noise
+        self.inplace = inplace
+        self.batchwise = batchwise
+        self.fast = fast  # FIXME finish comparisons of fast vs not
 
     def forward(self, x):
-        return drop_block_2d(x, self.drop_prob, self.training, self.block_size, self.gamma_scale, self.with_noise)
+        if not self.training or not self.drop_prob:
+            return x
+        if self.fast:
+            return drop_block_fast_2d(
+                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
+        else:
+            return drop_block_2d(
+                x, self.drop_prob, self.block_size, self.gamma_scale, self.with_noise, self.inplace, self.batchwise)
 
 
 def drop_path(x, drop_prob: float = 0., training: bool = False):
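
For reference, a minimal usage sketch of the updated module. This is a hedged illustration, not part of the commit; it assumes the DropBlock2d class defined above is in scope, and the layer sizes are made up:

import torch
import torch.nn as nn

# Illustrative only: DropBlock2d refers to the module defined in this diff.
block = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(inplace=True),
    DropBlock2d(drop_prob=0.1, block_size=7, fast=True),
)
block.train()                          # DropBlock is active only in train mode
y = block(torch.randn(8, 3, 56, 56))   # contiguous blocks zeroed, rest rescaled
block.eval()
y = block(torch.randn(8, 3, 56, 56))   # identity: forward() returns x unchanged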
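
The `FIXME finish comparisons of fast vs not` note suggests the two code paths have not yet been validated against each other. A rough empirical comparison might look like the sketch below; it is illustrative only, not part of the commit, and assumes both functions from the diff are importable:

import torch

torch.manual_seed(0)
x = torch.ones(16, 32, 56, 56)
# Both calls use the functions defined in the diff above (with_noise=False default).
out_full = drop_block_2d(x.clone(), drop_prob=0.1, block_size=7)
out_fast = drop_block_fast_2d(x.clone(), drop_prob=0.1, block_size=7)
for name, out in [("full", out_full), ("fast", out_fast)]:
    frac_zeroed = (out == 0).float().mean().item()
    print(f"{name}: fraction zeroed ~ {frac_zeroed:.3f}")  # roughly drop_prob; fast drops a bit more at edges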