Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 131 additions & 86 deletions test/test_flatten_layer.f90
Original file line number Diff line number Diff line change
@@ -1,126 +1,171 @@
program test_flatten_layer

use iso_fortran_env, only: stderr => error_unit
use nf, only: dense, flatten, input, layer, network
use nf_flatten_layer, only: flatten_layer
use nf_input2d_layer, only: input2d_layer
use nf_input3d_layer, only: input3d_layer
use nf, only: dense, flatten, input, layer, network
use nf_flatten_layer, only: flatten_layer
use nf_input2d_layer, only: input2d_layer
use nf_input3d_layer, only: input3d_layer

implicit none

type(layer) :: test_layer, input_layer
type(layer) :: test_layer, input_layer
type(network) :: net
real, allocatable :: gradient_3d(:,:,:), gradient_2d(:,:)
real, allocatable :: output(:)
logical :: ok = .true.

! Test 3D input
call banner('TEST FLATTEN')

! ---------- 3D INPUT ----------
test_layer = flatten()

if (.not. test_layer % name == 'flatten') then
ok = .false.
write(stderr, '(a)') 'flatten layer has its name set correctly.. failed'
end if
call assert_true(trim(test_layer%name) == 'flatten', &
"flatten layer has its name set correctly.. failed", ok)

if (test_layer % initialized) then
ok = .false.
write(stderr, '(a)') 'flatten layer is not initialized yet.. failed'
end if
call assert_true(.not. test_layer%initialized, &
"flatten layer is not initialized yet.. failed", ok)

input_layer = input(1, 2, 2)
call test_layer % init(input_layer)

if (.not. test_layer % initialized) then
ok = .false.
write(stderr, '(a)') 'flatten layer is now initialized.. failed'
end if

if (.not. all(test_layer % layer_shape == [4])) then
ok = .false.
write(stderr, '(a)') 'flatten layer has an incorrect output shape.. failed'
end if

! Test forward pass - reshaping from 3-d to 1-d

select type(this_layer => input_layer % p); type is(input3d_layer)
call this_layer % set(reshape(real([1, 2, 3, 4]), [1, 2, 2]))
end select
call test_layer%init(input_layer)

call test_layer % forward(input_layer)
call test_layer % get_output(output)
call assert_true(test_layer%initialized, &
"flatten layer is now initialized.. failed", ok)

if (.not. all(output == [1, 2, 3, 4])) then
ok = .false.
write(stderr, '(a)') 'flatten layer correctly propagates forward.. failed'
end if
call assert_true(all(test_layer%layer_shape == [4]), &
"flatten layer has an incorrect output shape.. failed", ok)

! Test backward pass - reshaping from 1-d to 3-d
! Forward 3D -> 1D
call set_input3d(input_layer, reshape(real([1,2,3,4]), [1,2,2]))
call test_layer%forward(input_layer)
call test_layer%get_output(output)

! Calling backward() will set the values on the gradient component
! input_layer is used only to determine shape
call test_layer % backward(input_layer, real([1, 2, 3, 4]))
call assert_true(size(output) == 4, &
"flatten forward output size (3D) mismatch.. failed", ok)
call assert_true(all(output == [1,2,3,4]), &
"flatten layer correctly propagates forward.. failed", ok)

select type(this_layer => test_layer % p); type is(flatten_layer)
gradient_3d = this_layer % gradient_3d
end select
! Backward 1D -> 3D
call test_layer%backward(input_layer, real([1,2,3,4]))
call grab_flatten_gradients3d(test_layer, gradient_3d)

if (.not. all(gradient_3d == reshape(real([1, 2, 3, 4]), [1, 2, 2]))) then
ok = .false.
write(stderr, '(a)') 'flatten layer correctly propagates backward.. failed'
end if
call assert_true(allocated(gradient_3d), &
"gradient_3d not allocated after backward.. failed", ok)
call assert_true(all(gradient_3d == reshape(real([1,2,3,4]), [1,2,2])), &
"flatten layer correctly propagates backward.. failed", ok)

! Test 2D input
! ---------- 2D INPUT ----------
test_layer = flatten()
input_layer = input(2, 3)
call test_layer % init(input_layer)
call test_layer%init(input_layer)

if (.not. all(test_layer % layer_shape == [6])) then
ok = .false.
write(stderr, '(a)') 'flatten layer has an incorrect output shape for 2D input.. failed'
end if
call assert_true(all(test_layer%layer_shape == [6]), &
"flatten layer has an incorrect output shape for 2D input.. failed", ok)

! Test forward pass - reshaping from 2-d to 1-d
select type(this_layer => input_layer % p); type is(input2d_layer)
call this_layer % set(reshape(real([1, 2, 3, 4, 5, 6]), [2, 3]))
end select
! Forward 2D -> 1D
call set_input2d(input_layer, reshape(real([1,2,3,4,5,6]), [2,3]))
call test_layer%forward(input_layer)
call test_layer%get_output(output)

call test_layer % forward(input_layer)
call test_layer % get_output(output)
call assert_true(size(output) == 6, &
"flatten forward output size (2D) mismatch.. failed", ok)
call assert_true(all(output == [1,2,3,4,5,6]), &
"flatten layer correctly propagates forward for 2D input.. failed", ok)

if (.not. all(output == [1, 2, 3, 4, 5, 6])) then
ok = .false.
write(stderr, '(a)') 'flatten layer correctly propagates forward for 2D input.. failed'
end if
! Backward 1D -> 2D
call test_layer%backward(input_layer, real([1,2,3,4,5,6]))
call grab_flatten_gradients2d(test_layer, gradient_2d)

! Test backward pass - reshaping from 1-d to 2-d
call test_layer % backward(input_layer, real([1, 2, 3, 4, 5, 6]))
call assert_true(allocated(gradient_2d), &
"gradient_2d not allocated after backward.. failed", ok)
call assert_true(all(gradient_2d == reshape(real([1,2,3,4,5,6]), [2,3])), &
"flatten layer correctly propagates backward for 2D input.. failed", ok)

select type(this_layer => test_layer % p); type is(flatten_layer)
gradient_2d = this_layer % gradient_2d
end select
! ---------- CHAIN TO DENSE ----------
net = network([ input(1,28,28), flatten(), dense(10) ])

if (.not. all(gradient_2d == reshape(real([1, 2, 3, 4, 5, 6]), [2, 3]))) then
ok = .false.
write(stderr, '(a)') 'flatten layer correctly propagates backward for 2D input.. failed'
end if

net = network([ &
input(1, 28, 28), &
flatten(), &
dense(10) &
])

! Test that the output layer receives 784 elements in the input
if (.not. all(net % layers(3) % input_layer_shape == [784])) then
ok = .false.
write(stderr, '(a)') 'flatten layer correctly chains input3d to dense.. failed'
end if
call assert_true(all(net%layers(3)%input_layer_shape == [784]), &
"flatten layer correctly chains input3d to dense.. failed", ok)

if (ok) then
print '(a)', 'test_flatten_layer: All tests passed.'
else
write(stderr, '(a)') 'test_flatten_layer: One or more tests failed.'
write(stderr,'(a)') 'test_flatten_layer: One or more tests failed.'
stop 1
end if

contains

subroutine banner(s)
character(*), intent(in) :: s
print '(a)', repeat('-', 8)//' '//trim(s)//' '//repeat('-', 8)
end subroutine banner

subroutine assert_true(cond, msg, ok)
logical, intent(in) :: cond
character(*), intent(in) :: msg
logical, intent(inout) :: ok
if (.not. cond) then
ok = .false.
write(stderr,'(a)') trim(msg)
end if
end subroutine assert_true

subroutine set_input3d(lay, x)
type(layer), intent(inout) :: lay
real, intent(in) :: x(:,:,:)
select type(p => lay%p)
type is (input3d_layer)
call p%set(x)
class default
call bail("expected input3d_layer in set_input3d")
end select
end subroutine set_input3d

subroutine set_input2d(lay, x)
type(layer), intent(inout) :: lay
real, intent(in) :: x(:,:)
select type(p => lay%p)
type is (input2d_layer)
call p%set(x)
class default
call bail("expected input2d_layer in set_input2d")
end select
end subroutine set_input2d

subroutine grab_flatten_gradients3d(lay, g)
type(layer), intent(in) :: lay
real, allocatable, intent(out) :: g(:,:,:)
select type(p => lay%p)
type is (flatten_layer)
if (allocated(p%gradient_3d)) then
g = p%gradient_3d
else
call bail("flatten_layer%gradient_3d is not allocated")
end if
class default
call bail("expected flatten_layer in grab_flatten_gradients3d")
end select
end subroutine grab_flatten_gradients3d

subroutine grab_flatten_gradients2d(lay, g)
type(layer), intent(in) :: lay
real, allocatable, intent(out) :: g(:,:)
select type(p => lay%p)
type is (flatten_layer)
if (allocated(p%gradient_2d)) then
g = p%gradient_2d
else
call bail("flatten_layer%gradient_2d is not allocated")
end if
class default
call bail("expected flatten_layer in grab_flatten_gradients2d")
end select
end subroutine grab_flatten_gradients2d

subroutine bail(msg)
character(*), intent(in) :: msg
write(stderr,'(a)') trim(msg)
error stop 2
end subroutine bail

end program test_flatten_layer