@@ -640,6 +640,8 @@ end function get_num_params
640
640
module function get_params (self ) result(params)
641
641
class(layer), intent (in ) :: self
642
642
real , allocatable :: params(:)
643
+ real , pointer :: w_ptr(:)
644
+ real , pointer :: b_ptr(:)
643
645
644
646
select type (this_layer = > self % p)
645
647
type is (input1d_layer)
@@ -649,15 +651,27 @@ module function get_params(self) result(params)
649
651
type is (input3d_layer)
650
652
! No parameters to get.
651
653
type is (dense_layer)
652
- params = this_layer % get_params()
654
+ call this_layer % get_params_ptr(w_ptr, b_ptr)
655
+ allocate (params(size (w_ptr) + size (b_ptr)))
656
+ params(1 :size (w_ptr)) = w_ptr
657
+ params(size (w_ptr)+ 1 :) = b_ptr
653
658
type is (dropout_layer)
654
659
! No parameters to get.
655
660
type is (conv1d_layer)
656
- params = this_layer % get_params()
661
+ call this_layer % get_params_ptr(w_ptr, b_ptr)
662
+ allocate (params(size (w_ptr) + size (b_ptr)))
663
+ params(1 :size (w_ptr)) = w_ptr
664
+ params(size (w_ptr)+ 1 :) = b_ptr
657
665
type is (conv2d_layer)
658
- params = this_layer % get_params()
666
+ call this_layer % get_params_ptr(w_ptr, b_ptr)
667
+ allocate (params(size (w_ptr) + size (b_ptr)))
668
+ params(1 :size (w_ptr)) = w_ptr
669
+ params(size (w_ptr)+ 1 :) = b_ptr
659
670
type is (locally_connected2d_layer)
660
- params = this_layer % get_params()
671
+ call this_layer % get_params_ptr(w_ptr, b_ptr)
672
+ allocate (params(size (w_ptr) + size (b_ptr)))
673
+ params(1 :size (w_ptr)) = w_ptr
674
+ params(size (w_ptr)+ 1 :) = b_ptr
661
675
type is (maxpool1d_layer)
662
676
! No parameters to get.
663
677
type is (maxpool2d_layer)
@@ -669,7 +683,10 @@ module function get_params(self) result(params)
669
683
type is (reshape3d_layer)
670
684
! No parameters to get.
671
685
type is (linear2d_layer)
672
- params = this_layer % get_params()
686
+ call this_layer % get_params_ptr(w_ptr, b_ptr)
687
+ allocate (params(size (w_ptr) + size (b_ptr)))
688
+ params(1 :size (w_ptr)) = w_ptr
689
+ params(size (w_ptr)+ 1 :) = b_ptr
673
690
type is (self_attention_layer)
674
691
params = this_layer % get_params()
675
692
type is (embedding_layer)
0 commit comments