Skip to content

Commit 3658b8a

Browse files
authored
Merge pull request #227
This merging is a refactoring of the neural-fortran code. Instead of the functions get_params and set_params, a function get_params_ptr is implemented. All of the instances of the old two functions have been replaced with the new one.
2 parents b2073fa + 226b030 commit 3658b8a

16 files changed

+151
-342
lines changed

src/nf/nf_conv1d_layer.f90

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,8 @@ module nf_conv1d_layer
3333
procedure :: backward
3434
procedure :: get_gradients_ptr
3535
procedure :: get_num_params
36-
procedure :: get_params
3736
procedure :: get_params_ptr
3837
procedure :: init
39-
procedure :: set_params
4038

4139
end type conv1d_layer
4240

@@ -89,15 +87,6 @@ pure module function get_num_params(self) result(num_params)
8987
!! Number of parameters
9088
end function get_num_params
9189

92-
module function get_params(self) result(params)
93-
!! Return the parameters (weights and biases) of this layer.
94-
!! The parameters are ordered as weights first, biases second.
95-
class(conv1d_layer), intent(in), target :: self
96-
!! A `conv1d_layer` instance
97-
real, allocatable :: params(:)
98-
!! Parameters to get
99-
end function get_params
100-
10190
module subroutine get_params_ptr(self, w_ptr, b_ptr)
10291
!! Return pointers to the parameters (weights and biases) of this layer.
10392
class(conv1d_layer), intent(in), target :: self
@@ -118,14 +107,6 @@ module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
118107
!! Pointer to the bias gradients
119108
end subroutine get_gradients_ptr
120109

121-
module subroutine set_params(self, params)
122-
!! Set the parameters of the layer.
123-
class(conv1d_layer), intent(in out) :: self
124-
!! A `conv1d_layer` instance
125-
real, intent(in) :: params(:)
126-
!! Parameters to set
127-
end subroutine set_params
128-
129110
end interface
130111

131112
end module nf_conv1d_layer

src/nf/nf_conv1d_layer_submodule.f90

Lines changed: 0 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -144,14 +144,6 @@ pure module function get_num_params(self) result(num_params)
144144
num_params = product(shape(self % kernel)) + size(self % biases)
145145
end function get_num_params
146146

147-
module function get_params(self) result(params)
148-
class(conv1d_layer), intent(in), target :: self
149-
real, allocatable :: params(:)
150-
real, pointer :: w_(:) => null()
151-
w_(1:size(self % kernel)) => self % kernel
152-
params = [ w_, self % biases]
153-
end function get_params
154-
155147
module subroutine get_params_ptr(self, w_ptr, b_ptr)
156148
class(conv1d_layer), intent(in), target :: self
157149
real, pointer, intent(out) :: w_ptr(:)
@@ -168,19 +160,4 @@ module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
168160
db_ptr => self % db
169161
end subroutine get_gradients_ptr
170162

171-
module subroutine set_params(self, params)
172-
class(conv1d_layer), intent(in out) :: self
173-
real, intent(in) :: params(:)
174-
175-
if (size(params) /= self % get_num_params()) then
176-
error stop 'conv1d_layer % set_params: Number of parameters does not match'
177-
end if
178-
179-
self % kernel = reshape(params(:product(shape(self % kernel))), shape(self % kernel))
180-
associate(n => product(shape(self % kernel)))
181-
self % biases = params(n + 1 : n + self % filters)
182-
end associate
183-
184-
end subroutine set_params
185-
186163
end submodule nf_conv1d_layer_submodule

src/nf/nf_conv2d_layer.f90

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,8 @@ module nf_conv2d_layer
3434
procedure :: backward
3535
procedure :: get_gradients_ptr
3636
procedure :: get_num_params
37-
procedure :: get_params
3837
procedure :: get_params_ptr
3938
procedure :: init
40-
procedure :: set_params
4139

4240
end type conv2d_layer
4341

@@ -90,15 +88,6 @@ pure module function get_num_params(self) result(num_params)
9088
!! Number of parameters
9189
end function get_num_params
9290

93-
module function get_params(self) result(params)
94-
!! Return the parameters (weights and biases) of this layer.
95-
!! The parameters are ordered as weights first, biases second.
96-
class(conv2d_layer), intent(in), target :: self
97-
!! A `conv2d_layer` instance
98-
real, allocatable :: params(:)
99-
!! Parameters to get
100-
end function get_params
101-
10291
module subroutine get_params_ptr(self, w_ptr, b_ptr)
10392
!! Return pointers to the parameters (weights and biases) of this layer.
10493
class(conv2d_layer), intent(in), target :: self
@@ -119,14 +108,6 @@ module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
119108
!! Pointer to the bias gradients
120109
end subroutine get_gradients_ptr
121110

122-
module subroutine set_params(self, params)
123-
!! Set the parameters of the layer.
124-
class(conv2d_layer), intent(in out) :: self
125-
!! A `conv2d_layer` instance
126-
real, intent(in) :: params(:)
127-
!! Parameters to set
128-
end subroutine set_params
129-
130111
end interface
131112

132113
end module nf_conv2d_layer

src/nf/nf_conv2d_layer_submodule.f90

Lines changed: 0 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -188,22 +188,6 @@ pure module function get_num_params(self) result(num_params)
188188
num_params = product(shape(self % kernel)) + size(self % biases)
189189
end function get_num_params
190190

191-
192-
module function get_params(self) result(params)
193-
class(conv2d_layer), intent(in), target :: self
194-
real, allocatable :: params(:)
195-
196-
real, pointer :: w_(:) => null()
197-
198-
w_(1:size(self % kernel)) => self % kernel
199-
200-
params = [ &
201-
w_, &
202-
self % biases &
203-
]
204-
205-
end function get_params
206-
207191

208192
module subroutine get_params_ptr(self, w_ptr, b_ptr)
209193
class(conv2d_layer), intent(in), target :: self
@@ -222,27 +206,4 @@ module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
222206
db_ptr => self % db
223207
end subroutine get_gradients_ptr
224208

225-
226-
module subroutine set_params(self, params)
227-
class(conv2d_layer), intent(in out) :: self
228-
real, intent(in) :: params(:)
229-
230-
! Check that the number of parameters is correct.
231-
if (size(params) /= self % get_num_params()) then
232-
error stop 'conv2d % set_params: Number of parameters does not match'
233-
end if
234-
235-
! Reshape the kernel.
236-
self % kernel = reshape( &
237-
params(:product(shape(self % kernel))), &
238-
shape(self % kernel) &
239-
)
240-
241-
! Reshape the biases.
242-
associate(n => product(shape(self % kernel)))
243-
self % biases = params(n + 1 : n + self % filters)
244-
end associate
245-
246-
end subroutine set_params
247-
248209
end submodule nf_conv2d_layer_submodule

src/nf/nf_dense_layer.f90

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -35,10 +35,8 @@ module nf_dense_layer
3535
procedure :: forward
3636
procedure :: get_gradients_ptr
3737
procedure :: get_num_params
38-
procedure :: get_params
3938
procedure :: get_params_ptr
4039
procedure :: init
41-
procedure :: set_params
4240

4341
end type dense_layer
4442

@@ -88,15 +86,6 @@ pure module function get_num_params(self) result(num_params)
8886
!! Number of parameters in this layer
8987
end function get_num_params
9088

91-
module function get_params(self) result(params)
92-
!! Return the parameters (weights and biases) of this layer.
93-
!! The parameters are ordered as weights first, biases second.
94-
class(dense_layer), intent(in), target :: self
95-
!! Dense layer instance
96-
real, allocatable :: params(:)
97-
!! Parameters of this layer
98-
end function get_params
99-
10089
module subroutine get_params_ptr(self, w_ptr, b_ptr)
10190
class(dense_layer), intent(in), target :: self
10291
real, pointer, intent(out) :: w_ptr(:)

src/nf/nf_dense_layer_submodule.f90

Lines changed: 0 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -61,22 +61,6 @@ pure module function get_num_params(self) result(num_params)
6161
end function get_num_params
6262

6363

64-
module function get_params(self) result(params)
65-
class(dense_layer), intent(in), target :: self
66-
real, allocatable :: params(:)
67-
68-
real, pointer :: w_(:) => null()
69-
70-
w_(1:size(self % weights)) => self % weights
71-
72-
params = [ &
73-
w_, &
74-
self % biases &
75-
]
76-
77-
end function get_params
78-
79-
8064
module subroutine get_params_ptr(self, w_ptr, b_ptr)
8165
class(dense_layer), intent(in), target :: self
8266
real, pointer, intent(out) :: w_ptr(:)
@@ -94,30 +78,6 @@ module subroutine get_gradients_ptr(self, dw_ptr, db_ptr)
9478
db_ptr => self % db
9579
end subroutine get_gradients_ptr
9680

97-
98-
module subroutine set_params(self, params)
99-
class(dense_layer), intent(in out) :: self
100-
real, intent(in), target :: params(:)
101-
102-
real, pointer :: p_(:,:) => null()
103-
104-
! check if the number of parameters is correct
105-
if (size(params) /= self % get_num_params()) then
106-
error stop 'Error: number of parameters does not match'
107-
end if
108-
109-
associate(n => self % input_size * self % output_size)
110-
! reshape the weights
111-
p_(1:self % input_size, 1:self % output_size) => params(1 : n)
112-
self % weights = p_
113-
114-
! reshape the biases
115-
self % biases = params(n + 1 : n + self % output_size)
116-
end associate
117-
118-
end subroutine set_params
119-
120-
12181
module subroutine init(self, input_shape)
12282
class(dense_layer), intent(in out) :: self
12383
integer, intent(in) :: input_shape(:)

src/nf/nf_layer.f90

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,17 @@ module function get_params(self) result(params)
160160
!! Parameters of this layer
161161
end function get_params
162162

163+
module subroutine get_params_ptr(self, w_ptr, b_ptr)
164+
!! Returns the parameters of this layer as pointers.
165+
!! This is used for layers that have weights and biases.
166+
class(layer), intent(in) :: self
167+
!! Layer instance
168+
real, pointer :: w_ptr(:)
169+
!! Pointer to weights of this layer
170+
real, pointer :: b_ptr(:)
171+
!! Pointer to biases of this layer
172+
end subroutine get_params_ptr
173+
163174
module subroutine set_params(self, params)
164175
!! Returns the parameters of this layer.
165176
class(layer), intent(in out) :: self

0 commit comments

Comments
 (0)