|
8 | 8 | use nf_input3d_layer, only: input3d_layer
|
9 | 9 | use nf_maxpool2d_layer, only: maxpool2d_layer
|
10 | 10 | use nf_reshape_layer, only: reshape3d_layer
|
| 11 | + use nf_optimizers, only: optimizer_base_type |
11 | 12 |
|
12 | 13 | contains
|
13 | 14 |
|
@@ -382,15 +383,54 @@ module subroutine set_params(self, params)
|
382 | 383 | end subroutine set_params
|
383 | 384 |
|
384 | 385 |
|
impure elemental module subroutine update(self, optimizer, batch_size)
  ! Apply one optimizer step to the trainable parameters of this layer.
  !
  ! For the layer types that carry gradients here (dense_layer and
  ! conv2d_layer) the routine:
  !   1. sums the accumulated gradients dw and db over all images (co_sum),
  !   2. hands each parameter array, together with its gradient divided by
  !      the batch size, to optimizer % minimize,
  !   3. resets dw and db to zero.
  ! Any other layer type matches no branch below and is left untouched.
  !
  ! Arguments:
  !   self       - layer wrapper; the concrete layer is self % p.
  !   optimizer  - optimizer whose minimize procedure receives
  !                (parameters, scaled gradient). It is defined in
  !                nf_optimizers and is not visible here; presumably it
  !                updates the parameter argument in place.
  !   batch_size - optional divisor applied to the summed gradients.
  !                When absent, 1 is used (no rescaling). Must be >= 1.
  class(layer), intent(in out) :: self
  class(optimizer_base_type), intent(in) :: optimizer
  integer, intent(in), optional :: batch_size
  integer :: batch_size_

  ! An absent optional may only be tested with present(), so resolve the
  ! default into a local before using it in any expression.
  batch_size_ = 1
  if (present(batch_size)) batch_size_ = batch_size

  ! The gradients are divided by batch_size_ below. Zero would fill the
  ! parameters with Inf/NaN and a negative value would flip the sign of the
  ! gradient, so reject both up front instead of corrupting the layer.
  if (batch_size_ < 1) &
    error stop 'layer % update: batch_size must be a positive integer'

  select type (this_layer => self % p)

    type is (dense_layer)

      ! Sum weight and bias gradients across images, if any.
      call co_sum(this_layer % dw)
      call co_sum(this_layer % db)

      call optimizer % minimize( &
        this_layer % weights, &
        this_layer % dw / batch_size_ &
      )
      call optimizer % minimize( &
        this_layer % biases, &
        this_layer % db / batch_size_ &
      )

      ! Reset gradients so the next batch starts accumulating from zero.
      this_layer % dw = 0
      this_layer % db = 0

    type is (conv2d_layer)

      ! Sum kernel and bias gradients across images, if any.
      call co_sum(this_layer % dw)
      call co_sum(this_layer % db)

      ! Same sequence as the dense branch; the parameter array of a
      ! conv2d_layer is named kernel rather than weights.
      call optimizer % minimize( &
        this_layer % kernel, &
        this_layer % dw / batch_size_ &
      )
      call optimizer % minimize( &
        this_layer % biases, &
        this_layer % db / batch_size_ &
      )

      ! Reset gradients so the next batch starts accumulating from zero.
      this_layer % dw = 0
      this_layer % db = 0

  end select

end subroutine update
|
|
0 commit comments