Skip to content

Commit 824bb13

Browse files
committed
Fix conv1d with stride
1 parent eb7e112 commit 824bb13

File tree

4 files changed

+34
-9
lines changed

4 files changed

+34
-9
lines changed

src/nf/nf_conv1d_layer.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ module subroutine init(self, input_shape)
6363
!! Input layer dimensions
6464
end subroutine init
6565

66-
pure module subroutine forward(self, input)
66+
module subroutine forward(self, input)
6767
!! Apply a forward pass on the `conv1d` layer.
6868
class(conv1d_layer), intent(in out) :: self
6969
!! A `conv1d_layer` instance

src/nf/nf_conv1d_layer_submodule.f90

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ module subroutine init(self, input_shape)
2727
integer, intent(in) :: input_shape(:)
2828

2929
self % channels = input_shape(1)
30-
self % width = (input_shape(2) - self % kernel_size + 1) / self % stride
30+
self % width = (input_shape(2) - self % kernel_size) / self % stride +1
31+
32+
if (mod(input_shape(2) - self % kernel_size , self % stride) /= 0) self % width = self % width + 1
3133

3234
! Output of shape: filters x width
3335
allocate(self % output(self % filters, self % width))
@@ -55,7 +57,7 @@ module subroutine init(self, input_shape)
5557

5658
end subroutine init
5759

58-
pure module subroutine forward(self, input)
60+
module subroutine forward(self, input)
5961
implicit none
6062
class(conv1d_layer), intent(in out) :: self
6163
real, intent(in) :: input(:,:)
@@ -71,7 +73,7 @@ pure module subroutine forward(self, input)
7173
! Compute the input window corresponding to output index j.
7274
! In forward: center index = j + half_window, so window = indices j to j+kernel_size-1.
7375
iws = self % stride * (j-1) + 1
74-
iwe = max(iws + self % kernel_size - 1, input_width)
76+
iwe = min(iws + self % kernel_size - 1, input_width)
7577

7678
! For each filter, compute the convolution (inner product over channels and kernel width).
7779
do concurrent (n = 1:self % filters)

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,6 @@ module function conv1d(filters, kernel_width, activation, stride) result(res)
3333
integer :: stride_tmp
3434
class(activation_function), allocatable :: activation_tmp
3535

36-
if (stride < 1) &
37-
error stop 'stride must be >= 1 in a conv1d layer'
38-
3936
res % name = 'conv1d'
4037

4138
if (present(activation)) then
@@ -52,6 +49,9 @@ module function conv1d(filters, kernel_width, activation, stride) result(res)
5249
stride_tmp = 1
5350
endif
5451

52+
if (stride_tmp < 1) &
53+
error stop 'stride must be >= 1 in a conv1d layer'
54+
5555
allocate( &
5656
res % p, &
5757
source=conv1d_layer(filters, kernel_width, activation_tmp, stride_tmp) &

test/test_conv1d_layer.f90

Lines changed: 25 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ program test_conv1d_layer
5858
select type(this_layer => input_layer % p); type is(input2d_layer)
5959
call this_layer % set(sample_input)
6060
end select
61+
deallocate(sample_input)
6162

6263
call conv1d_layer % forward(input_layer)
6364
call conv1d_layer % get_output(output)
@@ -67,11 +68,33 @@ program test_conv1d_layer
6768
write(stderr, '(a)') 'conv1d layer with zero input and sigmoid function must forward to all 0.5.. failed'
6869
end if
6970

71+
! Minimal conv1d layer: 1 channel, 3x3 pixel image, stride = 3;
72+
allocate(sample_input(1, 17))
73+
sample_input = 0
74+
75+
input_layer = input(1, 17)
76+
conv1d_layer = conv(filters, kernel_size, stride = 3)
77+
call conv1d_layer % init(input_layer)
78+
79+
select type(this_layer => input_layer % p); type is(input2d_layer)
80+
call this_layer % set(sample_input)
81+
end select
82+
deallocate(sample_input)
83+
84+
call conv1d_layer % forward(input_layer)
85+
call conv1d_layer % get_output(output)
86+
87+
if (.not. all(abs(output) < tolerance)) then
88+
ok = .false.
89+
write(stderr, '(a)') 'conv1d layer with zero input and sigmoid function must forward to all 0.5.. failed'
90+
end if
91+
92+
!Final
7093
if (ok) then
7194
print '(a)', 'test_conv1d_layer: All tests passed.'
7295
else
7396
write(stderr, '(a)') 'test_conv1d_layer: One or more tests failed.'
74-
stop 1
97+
stop 2
7598
end if
76-
99+
77100
end program test_conv1d_layer

0 commit comments

Comments
 (0)