Skip to content

Commit

Permalink
feat(training_configuration_t): add type, JSON I/O
Browse files Browse the repository at this point in the history
  • Loading branch information
rouson committed Nov 1, 2023
1 parent de39ecb commit 628de78
Show file tree
Hide file tree
Showing 7 changed files with 182 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/inference_engine/hyperparameters_s.f90
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,9 @@

lines = [ &
string_t(indent // '"hyperparameters": {'), &
string_t(indent // indent // '"' // mini_batches_key // '" : ' // mini_batches_string ), &
string_t(indent // indent // '"' // learning_rate_key // '" : ' // learning_rate_string ), &
string_t(indent // indent // '"' // optimizer_key // '" : "' // self%optimizer_ // '"'), &
string_t(indent // indent // '"' // mini_batches_key // '" : ' // trim(adjustl(mini_batches_string)) // "," ), &
string_t(indent // indent // '"' // learning_rate_key // '" : ' // trim(adjustl(learning_rate_string)) // "," ), &
string_t(indent // indent // '"' // optimizer_key // '" : "' // trim(adjustl(self%optimizer_ )) // '"'), &
string_t(indent // '}') &
]
end procedure
Expand Down
6 changes: 3 additions & 3 deletions src/inference_engine/network_configuration_s.f90
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,9 @@

lines = [ &
string_t(indent // '"network configuration": {'), &
string_t(indent // indent // '"' // skip_connections_key // '" : ' // skip_connections_string ), &
string_t(indent // indent // '"' // nodes_per_layer_key // '" : [' // trim(nodes_per_layer_string) // ']' ), &
string_t(indent // indent // '"' // activation_function_key // '" : "' // self%activation_function_ // '"'), &
string_t(indent // indent // '"' // skip_connections_key // '" : ' // trim(adjustl(skip_connections_string )) // ','), &
string_t(indent // indent // '"' // nodes_per_layer_key // '" : [' // trim(adjustl(nodes_per_layer_string )) // '],'), &
string_t(indent // indent // '"' // activation_function_key // '" : "' // trim(adjustl(self%activation_function_)) // '"' ), &
string_t(indent // '}') &
]
end procedure
Expand Down
53 changes: 53 additions & 0 deletions src/inference_engine/training_configuration_m.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
module training_configuration_m
use sourcery_m, only : string_t, file_t
use hyperparameters_m, only : hyperparameters_t
use network_configuration_m, only : network_configuration_t
implicit none

private
public :: training_configuration_t

type, extends(file_t) :: training_configuration_t
private
type(hyperparameters_t) hyperparameters_
type(network_configuration_t) network_configuration_
contains
procedure :: to_json
procedure :: equals
generic :: operator(==) => equals
end type

interface training_configuration_t

module function from_components(hyperparameters, network_configuration) result(training_configuration)
implicit none
type(hyperparameters_t), intent(in) :: hyperparameters
type(network_configuration_t), intent(in) :: network_configuration
type(training_configuration_t) training_configuration
end function

module function from_file(file_object) result(training_configuration)
implicit none
type(file_t), intent(in) :: file_object
type(training_configuration_t) training_configuration
end function

end interface

interface

pure module function to_json(self) result(json_lines)
implicit none
class(training_configuration_t), intent(in) :: self
type(string_t), allocatable :: json_lines(:)
end function

elemental module function equals(lhs, rhs) result(lhs_eq_rhs)
implicit none
class(training_configuration_t), intent(in) :: lhs, rhs
logical lhs_eq_rhs
end function

end interface

end module
49 changes: 49 additions & 0 deletions src/inference_engine/training_configuration_s.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
submodule(training_configuration_m) training_configuration_s
use assert_m, only : assert
implicit none

character(len=*), parameter :: header="{", footer="}", separator = ","

contains

module procedure from_components

training_configuration%hyperparameters_ = hyperparameters
training_configuration%network_configuration_ = network_configuration
training_configuration%file_t = file_t([ &
string_t(header), &
training_configuration%hyperparameters_%to_json(), &
string_t(separator), &
training_configuration%network_configuration_%to_json(), &
string_t(footer) &
])
end procedure

module procedure from_file
integer, parameter :: hyperparameters_start=2, hyperparameters_end=6, separator_line=7 ! line numbers
integer, parameter :: net_config_start=8, net_config_end=12 ! line numbers
integer, parameter :: file_start=hyperparameters_start-1, file_end=net_config_end+1 ! line numbers

training_configuration%file_t = file_object

associate(lines => training_configuration%file_t%lines())
call assert(trim(adjustl(lines(file_start)%string()))==header,"training_configuration_s(from_file): header",lines(file_start))
training_configuration%hyperparameters_ = hyperparameters_t(lines(hyperparameters_start:hyperparameters_end))
call assert(trim(adjustl(lines(separator_line)%string()))==separator,"training_configuration_s(from_file): separator", &
lines(file_start))
training_configuration%network_configuration_= network_configuration_t(lines(net_config_start:net_config_end))
call assert(trim(adjustl(lines(file_end)%string()))==footer, "training_configuration_s(from_file): footer", lines(file_end))
end associate
end procedure

module procedure to_json
json_lines = self%lines()
end procedure

module procedure equals
lhs_eq_rhs = &
lhs%hyperparameters_ == rhs%hyperparameters_ .and. &
lhs%network_configuration_ == rhs%network_configuration_
end procedure

end submodule training_configuration_s
5 changes: 3 additions & 2 deletions src/inference_engine_m.f90
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,19 @@ module inference_engine_m
!! Specify the user-facing modules, derived types, and type parameters
use activation_strategy_m, only : activation_strategy_t
use differentiable_activation_strategy_m, only : differentiable_activation_strategy_t
use hyperparameters_m, only : hyperparameters_t
use input_output_pair_m, only : input_output_pair_t, shuffle
use inference_engine_m_, only : inference_engine_t, difference_t
use kind_parameters_m, only : rkind
use mini_batch_m, only : mini_batch_t
use NetCDF_file_m, only : NetCDF_file_t
use network_configuration_m, only : network_configuration_t
use relu_m, only : relu_t
use sigmoid_m, only : sigmoid_t
use step_m, only : step_t
use swish_m, only : swish_t
use tensor_m, only : tensor_t
use trainable_engine_m, only : trainable_engine_t
use hyperparameters_m, only : hyperparameters_t
use network_configuration_m, only : network_configuration_t
use training_configuration_m, only : training_configuration_t
implicit none
end module
3 changes: 3 additions & 0 deletions test/main.f90
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,15 @@ program main
use trainable_engine_test_m, only : trainable_engine_test_t
use hyperparameters_test_m, only : hyperparameters_test_t
use network_configuration_test_m, only : network_configuration_test_t
use training_configuration_test_m, only : training_configuration_test_t
implicit none

type(inference_engine_test_t) inference_engine_test
type(asymmetric_engine_test_t) asymmetric_engine_test
type(trainable_engine_test_t) trainable_engine_test
type(hyperparameters_test_t) hyperparameters_test
type(network_configuration_test_t) network_configuration_test
type(training_configuration_test_t) training_configuration_test
real t_start, t_finish

integer :: passes=0, tests=0
Expand All @@ -24,6 +26,7 @@ program main
call trainable_engine_test%report(passes, tests)
call hyperparameters_test%report(passes, tests)
call network_configuration_test%report(passes, tests)
call training_configuration_test%report(passes, tests)
#ifndef __INTEL_FORTRAN
block
use netCDF_file_test_m, only : netCDF_file_test_t
Expand Down
68 changes: 68 additions & 0 deletions test/training_configuration_test_m.f90
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
! Copyright (c), The Regents of the University of California
! Terms of use are as specified in LICENSE.txt
module training_configuration_test_m
!! Test training_configuration_t object I/O and construction

! External dependencies
use assert_m, only : assert
use sourcery_m, only : string_t, test_t, test_result_t, file_t
use inference_engine_m, only : training_configuration_t, hyperparameters_t, network_configuration_t

! Internal dependencies
use training_configuration_m, only : training_configuration_t

implicit none

private
public :: training_configuration_test_t

type, extends(test_t) :: training_configuration_test_t
contains
procedure, nopass :: subject
procedure, nopass :: results
end type

contains

pure function subject() result(specimen)
character(len=:), allocatable :: specimen
specimen = "A training_configuration_t object"
end function

function results() result(test_results)
type(test_result_t), allocatable :: test_results(:)

character(len=*), parameter :: longest_description = &
"component-wise construction followed by conversion to and from JSON"

associate( &
descriptions => &
[ character(len=len(longest_description)) :: &
"component-wise construction followed by conversion to and from JSON" &
], &
outcomes => &
[ construct_and_convert_to_and_from_json() &
] &
)
call assert(size(descriptions) == size(outcomes),"training_configuration_test_m(results): size(descriptions)==size(outcomes)")
test_results = test_result_t(descriptions, outcomes)
end associate

end function

function construct_and_convert_to_and_from_json() result(test_passes)
logical test_passes


associate(training_configuration => training_configuration_t( &
hyperparameters_t(mini_batches=5, learning_rate=1., optimizer = "adam"), &
network_configuration_t(skip_connections=.false., nodes_per_layer=[2,72,2], activation_function="sigmoid") &
))
associate(from_json => training_configuration_t(file_t(training_configuration%to_json())))
test_passes = training_configuration == from_json
end associate
end associate

end function

end module training_configuration_test_m

0 comments on commit 628de78

Please sign in to comment.