From c3ecf2ec308077fb9ef85cfa5920e7ea3d69429d Mon Sep 17 00:00:00 2001 From: Riccardo Orsi Date: Fri, 5 Sep 2025 10:28:28 +0200 Subject: [PATCH] Fixing segmentation error on test_flatten_layer --- test/test_flatten_layer.f90 | 217 ++++++++++++++++++++++-------------- 1 file changed, 131 insertions(+), 86 deletions(-) diff --git a/test/test_flatten_layer.f90 b/test/test_flatten_layer.f90 index 0dca4c1f..cc147d95 100644 --- a/test/test_flatten_layer.f90 +++ b/test/test_flatten_layer.f90 @@ -1,126 +1,171 @@ program test_flatten_layer use iso_fortran_env, only: stderr => error_unit - use nf, only: dense, flatten, input, layer, network - use nf_flatten_layer, only: flatten_layer - use nf_input2d_layer, only: input2d_layer - use nf_input3d_layer, only: input3d_layer + use nf, only: dense, flatten, input, layer, network + use nf_flatten_layer, only: flatten_layer + use nf_input2d_layer, only: input2d_layer + use nf_input3d_layer, only: input3d_layer implicit none - type(layer) :: test_layer, input_layer + type(layer) :: test_layer, input_layer type(network) :: net real, allocatable :: gradient_3d(:,:,:), gradient_2d(:,:) real, allocatable :: output(:) logical :: ok = .true. - ! Test 3D input + call banner('TEST FLATTEN') + + ! ---------- 3D INPUT ---------- test_layer = flatten() - if (.not. test_layer % name == 'flatten') then - ok = .false. - write(stderr, '(a)') 'flatten layer has its name set correctly.. failed' - end if + call assert_true(trim(test_layer%name) == 'flatten', & + "flatten layer has its name set correctly.. failed", ok) - if (test_layer % initialized) then - ok = .false. - write(stderr, '(a)') 'flatten layer is not initialized yet.. failed' - end if + call assert_true(.not. test_layer%initialized, & + "flatten layer is not initialized yet.. failed", ok) input_layer = input(1, 2, 2) - call test_layer % init(input_layer) - - if (.not. test_layer % initialized) then - ok = .false. - write(stderr, '(a)') 'flatten layer is now initialized.. failed' - end if - - if (.not. all(test_layer % layer_shape == [4])) then - ok = .false. - write(stderr, '(a)') 'flatten layer has an incorrect output shape.. failed' - end if - - ! Test forward pass - reshaping from 3-d to 1-d - - select type(this_layer => input_layer % p); type is(input3d_layer) - call this_layer % set(reshape(real([1, 2, 3, 4]), [1, 2, 2])) - end select + call test_layer%init(input_layer) - call test_layer % forward(input_layer) - call test_layer % get_output(output) + call assert_true(test_layer%initialized, & + "flatten layer is now initialized.. failed", ok) - if (.not. all(output == [1, 2, 3, 4])) then - ok = .false. - write(stderr, '(a)') 'flatten layer correctly propagates forward.. failed' - end if + call assert_true(all(test_layer%layer_shape == [4]), & + "flatten layer has an incorrect output shape.. failed", ok) - ! Test backward pass - reshaping from 1-d to 3-d + ! Forward 3D -> 1D + call set_input3d(input_layer, reshape(real([1,2,3,4]), [1,2,2])) + call test_layer%forward(input_layer) + call test_layer%get_output(output) - ! Calling backward() will set the values on the gradient component - ! input_layer is used only to determine shape - call test_layer % backward(input_layer, real([1, 2, 3, 4])) + call assert_true(size(output) == 4, & + "flatten forward output size (3D) mismatch.. failed", ok) + call assert_true(all(output == [1,2,3,4]), & + "flatten layer correctly propagates forward.. failed", ok) - select type(this_layer => test_layer % p); type is(flatten_layer) - gradient_3d = this_layer % gradient_3d - end select + ! Backward 1D -> 3D + call test_layer%backward(input_layer, real([1,2,3,4])) + call grab_flatten_gradients3d(test_layer, gradient_3d) - if (.not. all(gradient_3d == reshape(real([1, 2, 3, 4]), [1, 2, 2]))) then - ok = .false. - write(stderr, '(a)') 'flatten layer correctly propagates backward.. failed' - end if + call assert_true(allocated(gradient_3d), & + "gradient_3d not allocated after backward.. failed", ok) + call assert_true(all(gradient_3d == reshape(real([1,2,3,4]), [1,2,2])), & + "flatten layer correctly propagates backward.. failed", ok) - ! Test 2D input + ! ---------- 2D INPUT ---------- test_layer = flatten() input_layer = input(2, 3) - call test_layer % init(input_layer) + call test_layer%init(input_layer) - if (.not. all(test_layer % layer_shape == [6])) then - ok = .false. - write(stderr, '(a)') 'flatten layer has an incorrect output shape for 2D input.. failed' - end if + call assert_true(all(test_layer%layer_shape == [6]), & + "flatten layer has an incorrect output shape for 2D input.. failed", ok) - ! Test forward pass - reshaping from 2-d to 1-d - select type(this_layer => input_layer % p); type is(input2d_layer) - call this_layer % set(reshape(real([1, 2, 3, 4, 5, 6]), [2, 3])) - end select + ! Forward 2D -> 1D + call set_input2d(input_layer, reshape(real([1,2,3,4,5,6]), [2,3])) + call test_layer%forward(input_layer) + call test_layer%get_output(output) - call test_layer % forward(input_layer) - call test_layer % get_output(output) + call assert_true(size(output) == 6, & + "flatten forward output size (2D) mismatch.. failed", ok) + call assert_true(all(output == [1,2,3,4,5,6]), & + "flatten layer correctly propagates forward for 2D input.. failed", ok) - if (.not. all(output == [1, 2, 3, 4, 5, 6])) then - ok = .false. - write(stderr, '(a)') 'flatten layer correctly propagates forward for 2D input.. failed' - end if + ! Backward 1D -> 2D + call test_layer%backward(input_layer, real([1,2,3,4,5,6])) + call grab_flatten_gradients2d(test_layer, gradient_2d) - ! Test backward pass - reshaping from 1-d to 2-d - call test_layer % backward(input_layer, real([1, 2, 3, 4, 5, 6])) + call assert_true(allocated(gradient_2d), & + "gradient_2d not allocated after backward.. failed", ok) + call assert_true(all(gradient_2d == reshape(real([1,2,3,4,5,6]), [2,3])), & + "flatten layer correctly propagates backward for 2D input.. failed", ok) - select type(this_layer => test_layer % p); type is(flatten_layer) - gradient_2d = this_layer % gradient_2d - end select + ! ---------- CHAIN TO DENSE ---------- + net = network([ input(1,28,28), flatten(), dense(10) ]) - if (.not. all(gradient_2d == reshape(real([1, 2, 3, 4, 5, 6]), [2, 3]))) then - ok = .false. - write(stderr, '(a)') 'flatten layer correctly propagates backward for 2D input.. failed' - end if - - net = network([ & - input(1, 28, 28), & - flatten(), & - dense(10) & - ]) - - ! Test that the output layer receives 784 elements in the input - if (.not. all(net % layers(3) % input_layer_shape == [784])) then - ok = .false. - write(stderr, '(a)') 'flatten layer correctly chains input3d to dense.. failed' - end if + call assert_true(all(net%layers(3)%input_layer_shape == [784]), & + "flatten layer correctly chains input3d to dense.. failed", ok) if (ok) then print '(a)', 'test_flatten_layer: All tests passed.' else - write(stderr, '(a)') 'test_flatten_layer: One or more tests failed.' + write(stderr,'(a)') 'test_flatten_layer: One or more tests failed.' stop 1 end if +contains + + subroutine banner(s) + character(*), intent(in) :: s + print '(a)', repeat('-', 8)//' '//trim(s)//' '//repeat('-', 8) + end subroutine banner + + subroutine assert_true(cond, msg, ok) + logical, intent(in) :: cond + character(*), intent(in) :: msg + logical, intent(inout) :: ok + if (.not. cond) then + ok = .false. + write(stderr,'(a)') trim(msg) + end if + end subroutine assert_true + + subroutine set_input3d(lay, x) + type(layer), intent(inout) :: lay + real, intent(in) :: x(:,:,:) + select type(p => lay%p) + type is (input3d_layer) + call p%set(x) + class default + call bail("expected input3d_layer in set_input3d") + end select + end subroutine set_input3d + + subroutine set_input2d(lay, x) + type(layer), intent(inout) :: lay + real, intent(in) :: x(:,:) + select type(p => lay%p) + type is (input2d_layer) + call p%set(x) + class default + call bail("expected input2d_layer in set_input2d") + end select + end subroutine set_input2d + + subroutine grab_flatten_gradients3d(lay, g) + type(layer), intent(in) :: lay + real, allocatable, intent(out) :: g(:,:,:) + select type(p => lay%p) + type is (flatten_layer) + if (allocated(p%gradient_3d)) then + g = p%gradient_3d + else + call bail("flatten_layer%gradient_3d is not allocated") + end if + class default + call bail("expected flatten_layer in grab_flatten_gradients3d") + end select + end subroutine grab_flatten_gradients3d + + subroutine grab_flatten_gradients2d(lay, g) + type(layer), intent(in) :: lay + real, allocatable, intent(out) :: g(:,:) + select type(p => lay%p) + type is (flatten_layer) + if (allocated(p%gradient_2d)) then + g = p%gradient_2d + else + call bail("flatten_layer%gradient_2d is not allocated") + end if + class default + call bail("expected flatten_layer in grab_flatten_gradients2d") + end select + end subroutine grab_flatten_gradients2d + + subroutine bail(msg) + character(*), intent(in) :: msg + write(stderr,'(a)') trim(msg) + error stop 2 + end subroutine bail + end program test_flatten_layer