From 8dfa9b01d3202d994041dccbb80aa8d50b50081d Mon Sep 17 00:00:00 2001 From: moto Date: Thu, 27 Oct 2016 12:58:22 -0700 Subject: [PATCH] Add initializer compatibility test (#67) * Add initializer test * Fix initializers in Theano * Tweak layer and optimizer test * Add version print --- circle.yml | 1 + luchador/nn/core/__init__.py | 6 +- luchador/nn/core/theano/initializer.py | 47 +++---- tests/integration/README.md | 4 + .../integration/data/initializer/constant.yml | 12 ++ tests/integration/data/initializer/normal.yml | 13 ++ .../integration/data/initializer/uniform.yml | 14 +++ .../data/initializer/xavier_conv2d_normal.yml | 20 +++ .../initializer/xavier_conv2d_uniform.yml | 19 +++ .../data/initializer/xavier_normal.yml | 20 +++ .../data/initializer/xavier_uniform.yml | 19 +++ .../layer/batch_normalization_2d_learn.yml | 2 - .../batch_normalization_2d_not_learn.yml | 2 - .../layer/batch_normalization_4d_learn.yml | 3 - .../batch_normalization_4d_not_learn.yml | 2 - tests/integration/data/layer/conv2d_same.yml | 2 - tests/integration/data/layer/conv2d_valid.yml | 2 - .../data/layer/conv2d_without_bias.yml | 2 - tests/integration/data/layer/dense.yml | 2 - .../data/layer/dense_without_bias.yml | 2 - tests/integration/data/layer/flatten.yml | 2 - tests/integration/data/layer/relu.yml | 2 - tests/integration/data/layer/sigmoid.yml | 2 - tests/integration/data/layer/softmax.yml | 2 - tests/integration/data/layer/true_div.yml | 2 - .../run_initializer_compatibility_test.sh | 18 +++ ...run_layer_numerical_compatibility_tests.sh | 10 +- ...optimizer_numerical_compatibility_tests.sh | 6 + .../run_initializer.py | 119 ++++++++++++++++++ .../test_initializer_compatibility.sh | 40 ++++++ .../test_layer_numerical_compatibility.sh | 41 ++---- .../test_optimizer_numerical_compatibility.sh | 2 + 32 files changed, 355 insertions(+), 85 deletions(-) create mode 100644 tests/integration/data/initializer/constant.yml create mode 100644 tests/integration/data/initializer/normal.yml create mode 100644 tests/integration/data/initializer/uniform.yml create mode 100644 tests/integration/data/initializer/xavier_conv2d_normal.yml create mode 100644 tests/integration/data/initializer/xavier_conv2d_uniform.yml create mode 100644 tests/integration/data/initializer/xavier_normal.yml create mode 100644 tests/integration/data/initializer/xavier_uniform.yml create mode 100755 tests/integration/run_initializer_compatibility_test.sh create mode 100644 tests/integration/test_initializer_compatibility/run_initializer.py create mode 100755 tests/integration/test_initializer_compatibility/test_initializer_compatibility.sh diff --git a/circle.yml b/circle.yml index 2f7a70a5..6474cf21 100644 --- a/circle.yml +++ b/circle.yml @@ -35,6 +35,7 @@ test: - LUCHADOR_NN_BACKEND=tensorflow LUCHADOR_NN_CONV_FORMAT=NHWC coverage run --source luchador -a setup.py test # Integration tests - ./tests/integration/run_serialization_tests.sh + - ./tests/integration/run_initializer_compatibility_test.sh - ./tests/integration/run_layer_numerical_compatibility_tests.sh - ./tests/integration/run_optimizer_numerical_compatibility_tests.sh - LUCHADOR_NN_BACKEND=theano LUCHADOR_NN_CONV_FORMAT=NCHW ./tests/integration/run_dqn.sh diff --git a/luchador/nn/core/__init__.py b/luchador/nn/core/__init__.py index e092b850..e9646579 100644 --- a/luchador/nn/core/__init__.py +++ b/luchador/nn/core/__init__.py @@ -6,9 +6,9 @@ from .base import * # noqa: F401, F403 -logging.getLogger(__name__).info( - 'Using %s backend', luchador.get_nn_backend() -) +_LG = logging.getLogger(__name__) +_LG.info('Luchador Version: %s', luchador.__version__) +_LG.info('Luchador NN backend: %s', luchador.get_nn_backend()) if luchador.get_nn_backend() == 'tensorflow': from .tensorflow import * # noqa: F401, F403 diff --git a/luchador/nn/core/theano/initializer.py b/luchador/nn/core/theano/initializer.py index 3a04166d..a78feb70 100644 --- a/luchador/nn/core/theano/initializer.py +++ b/luchador/nn/core/theano/initializer.py @@ -7,6 +7,7 @@ import numpy as np from numpy.random import RandomState +from scipy.stats import truncnorm as tnorm from theano import config @@ -64,31 +65,24 @@ class Xavier(InitializerMixin, base_initializer.BaseXavier): def _sample(self, shape): if not len(shape) == 2: raise ValueError( - 'Xavier initializer expects the shape to have 2 elements, ' - 'e.g. [fan_in, fan_out]. Found: {}'.format(shape) + 'Xavier initializer expects the shape to have 2 elements.' + 'Found: {}'.format(shape) ) - fan_in, fan_out = shape - param = self._compute_param(fan_in, fan_out) - return self._sample_value(shape, param) - - def _compute_param(self, fan_in, fan_out): + fan_out, fan_in = shape[0], shape[1] + scale = np.sqrt(6. / (fan_in + fan_out)) if self.args['uniform']: - x = np.sqrt(6. / (fan_in + fan_out)) - return {'low': -x, 'high': x} + value = self._sample_uniform(scale, shape) else: - scale = np.sqrt(3. / (fan_in + fan_out)) - return {'loc': 0., 'scale': scale} + value = self._sample_truncated_normal(scale, shape) + return value.astype(self.args['dtype'] or config.floatX) - def _sample_value(self, shape, param): - if self.args['uniform']: - values = self._rng.uniform( - low=param['low'], high=param['high'], size=shape) - else: - values = self._rng.normal( - loc=param['loc'], scale=param['scale'], size=shape) - dtype = self.args['dtype'] or config.floatX - return values.astype(dtype) + def _sample_uniform(self, scale, shape): + return self._rng.uniform(low=-scale, high=scale, size=shape) + + def _sample_truncated_normal(self, scale, shape): + return tnorm.rvs( + -1, 1, scale=scale, size=shape, random_state=self._rng) class XavierConv2D(Xavier): @@ -97,9 +91,18 @@ class XavierConv2D(Xavier): See :any:`BaseXavierConv2D` for detail. """ def _sample(self, shape): + if not len(shape) == 4: + raise ValueError( + 'Xavier conv2d initializer expects the shape with 4 elements.' + 'Found: {}'.format(shape) + ) # theano's filter shape is # (output_channels, input_channels, filter_rows, filter_columns) fan_in = shape[1] * shape[2] * shape[3] fan_out = shape[0] * shape[2] * shape[3] - param = self._compute_param(fan_in, fan_out) - return self._sample_value(shape, param) + scale = np.sqrt(6. / (fan_in + fan_out)) + if self.args['uniform']: + value = self._sample_uniform(scale, shape) + else: + value = self._sample_truncated_normal(scale, shape) + return value.astype(self.args['dtype'] or config.floatX) diff --git a/tests/integration/README.md b/tests/integration/README.md index 0934519e..2a4a8de0 100644 --- a/tests/integration/README.md +++ b/tests/integration/README.md @@ -4,6 +4,10 @@ This directory contains the list of intergation tests. This test builds and run DQN against ALEEnvironment so as to verify that it is not broken. +* `run_initializer_compatibility_tests.sh` + +This test runs initializers and check if the distribution is correct. + * `run_layer_numerical_compatibility_tests.sh` This test compares the outputs from fixed layer configuration/parameter and input so as to ensure layers' behavior is same across backends. diff --git a/tests/integration/data/initializer/constant.yml b/tests/integration/data/initializer/constant.yml new file mode 100644 index 00000000..561323c0 --- /dev/null +++ b/tests/integration/data/initializer/constant.yml @@ -0,0 +1,12 @@ +initializer: + name: Constant + args: + value: 3.2 + +test_config: + shape: [4, 3, 2, 1] + +compare_config: + threshold: !!float 1e-4 + mean: 3.2 + std: 0 diff --git a/tests/integration/data/initializer/normal.yml b/tests/integration/data/initializer/normal.yml new file mode 100644 index 00000000..d7edcc4c --- /dev/null +++ b/tests/integration/data/initializer/normal.yml @@ -0,0 +1,13 @@ +initializer: + name: Normal + args: + mean: &mean 5.3 + stddev: &stddev 9.0 + +test_config: + shape: [32, 16, 4, 4] + +compare_config: + mean: *mean + std: *stddev + threshold: 0.05 diff --git a/tests/integration/data/initializer/uniform.yml b/tests/integration/data/initializer/uniform.yml new file mode 100644 index 00000000..ffe6b3fb --- /dev/null +++ b/tests/integration/data/initializer/uniform.yml @@ -0,0 +1,14 @@ +initializer: + name: Uniform + args: + minval: -2.0 + maxval: 6.0 + +test_config: + shape: [16, 16, 8, 8] + +compare_config: + threshold: 0.03 + mean: 2.0 + std: 2.309 # (maxval - min_val) / sqrt(12) + diff --git a/tests/integration/data/initializer/xavier_conv2d_normal.yml b/tests/integration/data/initializer/xavier_conv2d_normal.yml new file mode 100644 index 00000000..e12baf6a --- /dev/null +++ b/tests/integration/data/initializer/xavier_conv2d_normal.yml @@ -0,0 +1,20 @@ +initializer: + name: XavierConv2D + args: + uniform: False + +test_config: + shape: [64, 32, 8, 8] + +compare_config: + threshold: 0.1 + mean: 0.0 + std: 0.0168 + # Standard deviation here is not that of normal distribution, but of truncated normal distribution + # with bound of `2 * standard deviation of normal distribution`. + # To get this you can use scipy.stats.truncnorm class. + # + # fan_out, fan_in = 64 * 8 * 8, 32 * 8 * 8 + # scale = np.sqrt(3. / (fan_in + fan_out)) + # variance = truncnorm.stats(-2, 2, scale=scale, moments='v') + # std = np.sqrt(variance) diff --git a/tests/integration/data/initializer/xavier_conv2d_uniform.yml b/tests/integration/data/initializer/xavier_conv2d_uniform.yml new file mode 100644 index 00000000..561bb8aa --- /dev/null +++ b/tests/integration/data/initializer/xavier_conv2d_uniform.yml @@ -0,0 +1,19 @@ +initializer: + name: XavierConv2D + args: + uniform: True + +test_config: + shape: [64, 32, 8, 8] + +compare_config: + threshold: 0.03 + mean: 0.0 + std: 0.01804 + # Standard deviation here is not that of normal distribution, but of truncated normal distribution + # with bound of `2 * standard deviation of normal distribution`. + # To get this you can use scipy.stats.truncnorm class. + # + # fan_out, fan_in = 64 * 8 * 8, 32 * 8 * 8 + # scale = np.sqrt(6. / (fan_in + fan_out)) + # std = (scale + scale) / sqrt(12) diff --git a/tests/integration/data/initializer/xavier_normal.yml b/tests/integration/data/initializer/xavier_normal.yml new file mode 100644 index 00000000..a5168efb --- /dev/null +++ b/tests/integration/data/initializer/xavier_normal.yml @@ -0,0 +1,20 @@ +initializer: + name: Xavier + args: + uniform: False + +test_config: + shape: [1000, 100] + +compare_config: + threshold: 0.1 + mean: 0.0 + std: 0.03984 + # Standard deviation here is not that of normal distribution, but of truncated normal distribution + # with bound of `2 * standard deviation of normal distribution`. + # To get this you can use scipy.stats.tnorm class. + # + # fan_out, fan_in = 1000, 100 + # scale = np.sqrt(3. / (fan_in + fan_out)) + # variance = tnorm.stats(-2, 2, scale=scale, moments='v') + # std = np.sqrt(variance) diff --git a/tests/integration/data/initializer/xavier_uniform.yml b/tests/integration/data/initializer/xavier_uniform.yml new file mode 100644 index 00000000..8433a1c5 --- /dev/null +++ b/tests/integration/data/initializer/xavier_uniform.yml @@ -0,0 +1,19 @@ +initializer: + name: Xavier + args: + uniform: True + +test_config: + shape: [64, 32] + +compare_config: + threshold: 0.03 + mean: 0.0 + std: 0.1443 + # Standard deviation here is not that of normal distribution, but of truncated normal distribution + # with bound of `2 * standard deviation of normal distribution`. + # To get this you can use scipy.stats.truncnorm class. + # + # fan_out, fan_in = 16, 32 + # scale = np.sqrt(6. / (fan_in + fan_out)) + # std = (scale + scale) / sqrt(12) diff --git a/tests/integration/data/layer/batch_normalization_2d_learn.yml b/tests/integration/data/layer/batch_normalization_2d_learn.yml index 1d932a57..285be4a6 100644 --- a/tests/integration/data/layer/batch_normalization_2d_learn.yml +++ b/tests/integration/data/layer/batch_normalization_2d_learn.yml @@ -1,4 +1,3 @@ - run: iteration: 10 @@ -13,4 +12,3 @@ layer: input: input_randn_3x5_offset_3.h5 parameter: parameter_bn.h5 - diff --git a/tests/integration/data/layer/batch_normalization_2d_not_learn.yml b/tests/integration/data/layer/batch_normalization_2d_not_learn.yml index 67653c26..2960c83e 100644 --- a/tests/integration/data/layer/batch_normalization_2d_not_learn.yml +++ b/tests/integration/data/layer/batch_normalization_2d_not_learn.yml @@ -1,4 +1,3 @@ - run: iteration: 10 @@ -13,4 +12,3 @@ layer: input: input_randn_3x5_offset_3.h5 parameter: parameter_bn.h5 - diff --git a/tests/integration/data/layer/batch_normalization_4d_learn.yml b/tests/integration/data/layer/batch_normalization_4d_learn.yml index fa9fd4fb..285be4a6 100644 --- a/tests/integration/data/layer/batch_normalization_4d_learn.yml +++ b/tests/integration/data/layer/batch_normalization_4d_learn.yml @@ -1,4 +1,3 @@ - run: iteration: 10 @@ -10,8 +9,6 @@ layer: learn: True decay: 0.999 - input: input_randn_3x5_offset_3.h5 parameter: parameter_bn.h5 - diff --git a/tests/integration/data/layer/batch_normalization_4d_not_learn.yml b/tests/integration/data/layer/batch_normalization_4d_not_learn.yml index 67653c26..2960c83e 100644 --- a/tests/integration/data/layer/batch_normalization_4d_not_learn.yml +++ b/tests/integration/data/layer/batch_normalization_4d_not_learn.yml @@ -1,4 +1,3 @@ - run: iteration: 10 @@ -13,4 +12,3 @@ layer: input: input_randn_3x5_offset_3.h5 parameter: parameter_bn.h5 - diff --git a/tests/integration/data/layer/conv2d_same.yml b/tests/integration/data/layer/conv2d_same.yml index 38d1fe5f..6932c976 100644 --- a/tests/integration/data/layer/conv2d_same.yml +++ b/tests/integration/data/layer/conv2d_same.yml @@ -1,4 +1,3 @@ - layer: name: Conv2D args: @@ -12,4 +11,3 @@ layer: input: input_mnist_10x4x28x27.h5 parameter: parameter_randn_3x4x7x5.h5 - diff --git a/tests/integration/data/layer/conv2d_valid.yml b/tests/integration/data/layer/conv2d_valid.yml index b3c2a738..fed44ee3 100644 --- a/tests/integration/data/layer/conv2d_valid.yml +++ b/tests/integration/data/layer/conv2d_valid.yml @@ -1,4 +1,3 @@ - layer: name: Conv2D args: @@ -12,4 +11,3 @@ layer: input: input_mnist_10x4x28x27.h5 parameter: parameter_randn_3x4x7x5.h5 - diff --git a/tests/integration/data/layer/conv2d_without_bias.yml b/tests/integration/data/layer/conv2d_without_bias.yml index 851ca420..dbf55104 100644 --- a/tests/integration/data/layer/conv2d_without_bias.yml +++ b/tests/integration/data/layer/conv2d_without_bias.yml @@ -1,4 +1,3 @@ - layer: name: Conv2D args: @@ -12,4 +11,3 @@ layer: input: input_mnist_10x4x28x27.h5 parameter: parameter_randn_3x4x7x5.h5 - diff --git a/tests/integration/data/layer/dense.yml b/tests/integration/data/layer/dense.yml index 75a5144f..8d88c5f9 100644 --- a/tests/integration/data/layer/dense.yml +++ b/tests/integration/data/layer/dense.yml @@ -1,4 +1,3 @@ - layer: name: Dense args: @@ -8,4 +7,3 @@ layer: input: input_randn_5x3.h5 parameter: parameter_randn_3x7.h5 - diff --git a/tests/integration/data/layer/dense_without_bias.yml b/tests/integration/data/layer/dense_without_bias.yml index f3aca1cd..f319b608 100644 --- a/tests/integration/data/layer/dense_without_bias.yml +++ b/tests/integration/data/layer/dense_without_bias.yml @@ -1,4 +1,3 @@ - layer: name: Dense args: @@ -8,4 +7,3 @@ layer: input: input_randn_5x3.h5 parameter: parameter_randn_3x7.h5 - diff --git a/tests/integration/data/layer/flatten.yml b/tests/integration/data/layer/flatten.yml index 787f5508..3696f84b 100644 --- a/tests/integration/data/layer/flatten.yml +++ b/tests/integration/data/layer/flatten.yml @@ -1,7 +1,5 @@ - layer: name: Flatten args: {} input: input_mnist_10x4x28x27.h5 - diff --git a/tests/integration/data/layer/relu.yml b/tests/integration/data/layer/relu.yml index 832f0018..ae1bbd91 100644 --- a/tests/integration/data/layer/relu.yml +++ b/tests/integration/data/layer/relu.yml @@ -1,7 +1,5 @@ - layer: name: ReLU args: {} input: input_randn_5x3.h5 - diff --git a/tests/integration/data/layer/sigmoid.yml b/tests/integration/data/layer/sigmoid.yml index fdd67488..3ad85b33 100644 --- a/tests/integration/data/layer/sigmoid.yml +++ b/tests/integration/data/layer/sigmoid.yml @@ -1,7 +1,5 @@ - layer: name: Sigmoid args: {} input: input_randn_5x3.h5 - diff --git a/tests/integration/data/layer/softmax.yml b/tests/integration/data/layer/softmax.yml index bb406b98..d40a6ec1 100644 --- a/tests/integration/data/layer/softmax.yml +++ b/tests/integration/data/layer/softmax.yml @@ -1,7 +1,5 @@ - layer: name: Softmax args: {} input: input_randn_5x3.h5 - diff --git a/tests/integration/data/layer/true_div.yml b/tests/integration/data/layer/true_div.yml index 592c44e3..9b172d11 100644 --- a/tests/integration/data/layer/true_div.yml +++ b/tests/integration/data/layer/true_div.yml @@ -1,8 +1,6 @@ - layer: name: TrueDiv args: denom: 255 input: input_randint_1x3x5x7.h5 - diff --git a/tests/integration/run_initializer_compatibility_test.sh b/tests/integration/run_initializer_compatibility_test.sh new file mode 100755 index 00000000..b9d157ea --- /dev/null +++ b/tests/integration/run_initializer_compatibility_test.sh @@ -0,0 +1,18 @@ +#!/bin/bash +set -u + +BASE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +DATA_DIR="${BASE_DIR}/data/initializer" +TEST_DIR="${BASE_DIR}/test_initializer_compatibility" +TEST_COMMAND="${TEST_DIR}/test_initializer_compatibility.sh" + +RETURN=0 +for FILE in ${DATA_DIR}/*.yml +do + "${TEST_COMMAND}" "${FILE}" + if [[ ! $? = 0 ]]; then + RETURN=1 + fi +done + +exit ${RETURN} diff --git a/tests/integration/run_layer_numerical_compatibility_tests.sh b/tests/integration/run_layer_numerical_compatibility_tests.sh index 60c423e1..d144fcf9 100755 --- a/tests/integration/run_layer_numerical_compatibility_tests.sh +++ b/tests/integration/run_layer_numerical_compatibility_tests.sh @@ -1,12 +1,18 @@ #!/bin/bash -set -eu +set -u BASE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" DATA_DIR="${BASE_DIR}/data/layer" TEST_DIR="${BASE_DIR}/test_layer_numerical_compatibility" TEST_COMMAND="${TEST_DIR}/test_layer_numerical_compatibility.sh" +RETURN=0 for FILE in ${DATA_DIR}/*.yml do - "${TEST_COMMAND}" --config "${FILE}" + "${TEST_COMMAND}" "${FILE}" + if [[ ! $? = 0 ]]; then + RETURN=1 + fi done + +exit ${RETURN} diff --git a/tests/integration/run_optimizer_numerical_compatibility_tests.sh b/tests/integration/run_optimizer_numerical_compatibility_tests.sh index f8b52985..976e00c8 100755 --- a/tests/integration/run_optimizer_numerical_compatibility_tests.sh +++ b/tests/integration/run_optimizer_numerical_compatibility_tests.sh @@ -12,11 +12,17 @@ do done # Run each optimizer on formulae +RETURN=0 FORMULAE="$(python ${TEST_DIR}/formula.py)" for FORMULA in ${FORMULAE} do for OPTIMIZER in "${OPTIMIZERS[@]}" do ${TEST_DIR}/test_optimizer_numerical_compatibility.sh --optimizer ${OPTIMIZER} --formula ${FORMULA} + if [[ ! $? = 0 ]]; then + RETURN=1 + fi done done + +exit ${RETURN} diff --git a/tests/integration/test_initializer_compatibility/run_initializer.py b/tests/integration/test_initializer_compatibility/run_initializer.py new file mode 100644 index 00000000..d5ba469b --- /dev/null +++ b/tests/integration/test_initializer_compatibility/run_initializer.py @@ -0,0 +1,119 @@ +from __future__ import division +from __future__ import print_function +from __future__ import absolute_import + +import os + +import h5py +import numpy as np + +import luchador.util +from luchador import nn + + +def parse_command_line_args(): + from argparse import ArgumentParser as AP + ap = AP( + description=( + 'Run Initializer and check if the distribution of ' + 'the resulting values is desired one') + ) + ap.add_argument( + 'config', help='YAML file with initializer config and test config' + ) + ap.add_argument( + '--output', help='Name of output HDF5 file.' + ) + ap.add_argument( + '--key', help='Name of dataset in the output file.', default='input' + ) + return ap.parse_args() + + +def create_initializer(name, args): + return nn.get_initializer(name)(**args) + + +def transpose_needed(initializer): + return ( + initializer.__class__.__name__ == 'XavierConv2D' and + luchador.get_nn_backend() == 'tensorflow' + ) + + +def run_initializer(initializer, shape): + if transpose_needed(initializer): + # Shape is given in Theano's filter order, which is + # [#out-channel, #in-channel, height, width]. + # So as to compute fan-in and fan-out correctly in Tensorflow, + # we reorder this to + # [height, width, #in-channel, #out-channel], + shape = [shape[2], shape[3], shape[1], shape[0]] + + variable = nn.scope.get_variable( + shape=shape, name='input', initializer=initializer) + session = nn.Session() + session.initialize() + value = session.run(outputs=variable) + + if transpose_needed(initializer): + # So as to make the output comarison easy, we revert the oreder. + shape = [shape[3], shape[2], shape[0], shape[1]] + return value + + +def print_stats(*arrs): + print('{sum:>10} {max:>10} {min:>10} {mean:>10} {std:>10}' + .format(sum='sum', max='max', min='min', mean='mean', std='std')) + for arr in arrs: + sum_, max_, min_, mean = arr.sum(), arr.max(), arr.min(), arr.mean() + std = arr.std() + print('{sum:10.3E} {max:10.3E} {min:10.3E} {mean:10.3E} {std:10.3E}' + .format(sum=sum_, max=max_, min=min_, mean=mean, std=std)) + print('') + + +def is_moment_different(data, mean, std, threshold): + mean_diff = abs(mean - np.mean(data)) / (mean or 1.0) + std_diff = abs(std - np.std(data)) / (std or 1.0) + print('mean diff: {} [%]'.format(100 * mean_diff)) + print('std diff: {} [%]'.format(100 * std_diff)) + return (mean_diff > threshold) or (std_diff > threshold) + + +def check_dist(value, mean, std, threshold): + print_stats(value) + print('Given mean: {}'.format(mean)) + print('Given stddev: {}'.format(std)) + print('Checking (threshold: {} [%])'.format(100 * threshold)) + if is_moment_different(value, mean, std, threshold): + raise ValueError('Data are different') + print('Okay') + + +def save_output(filepath, data, key): + directory = os.path.dirname(filepath) + if not os.path.exists(directory): + os.makedirs(directory) + + print('Saving output value to {}'.format(filepath)) + print(' Shape {}'.format(data.shape)) + print(' Dtype {}'.format(data.dtype)) + f = h5py.File(filepath, 'w') + f.create_dataset(key, data=data) + f.close() + + +def main(): + args = parse_command_line_args() + cfg = luchador.util.load_config(args.config) + initializer = create_initializer(**cfg['initializer']) + value = run_initializer(initializer, **cfg['test_config']) + + if args.output: + save_output(args.output, value, args.key) + + check_dist(value, **cfg['compare_config']) + +if __name__ == '__main__': + main() diff --git a/tests/integration/test_initializer_compatibility/test_initializer_compatibility.sh b/tests/integration/test_initializer_compatibility/test_initializer_compatibility.sh new file mode 100755 index 00000000..671a4328 --- /dev/null +++ b/tests/integration/test_initializer_compatibility/test_initializer_compatibility.sh @@ -0,0 +1,40 @@ +#!/bin/bash +# This script runs the initialization of tensorflow and theano backend separately and write the result to files. +# Then check if the difference between the results are within threshold +set -u + +CONFIG=$1 +if [[ ! -f "${CONFIG}" ]]; then + echo "Argument must be a YAML file" + exit 1 +fi + +COUNT_INTEGRATION_COVERAGE=${COUNT_INTEGRATION_COVERAGE:-false} +if [ "${COUNT_INTEGRATION_COVERAGE}" = true ]; then + TEST_COMMAND="coverage run --source luchador -a" +else + TEST_COMMAND="python" +fi + +BASE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" +TEST_COMMAND="${TEST_COMMAND} ${BASE_DIR}/run_initializer.py ${CONFIG}" + +CONFIG_NAME=$(basename ${CONFIG%.*}) +echo "*** Checking numerical compatibility of ${CONFIG_NAME} ***" +echo "" +cat ${CONFIG} +echo "" + +RETURN=0 +echo "* Running ${CONFIG_NAME} with Theano backend" +LUCHADOR_NN_BACKEND=theano LUCHADOR_NN_CONV_FORMAT=NCHW ${TEST_COMMAND} +if [[ ! $? = 0 ]]; then RETURN=1; fi +echo "" + +echo "* Running ${CONFIG_NAME} with Tensorflow backend" +LUCHADOR_NN_BACKEND=tensorflow LUCHADOR_NN_CONV_FORMAT=NHWC ${TEST_COMMAND} +if [[ ! $? = 0 ]]; then RETURN=1; fi +echo "" + + +exit ${RETURN} diff --git a/tests/integration/test_layer_numerical_compatibility/test_layer_numerical_compatibility.sh b/tests/integration/test_layer_numerical_compatibility/test_layer_numerical_compatibility.sh index 04cbdfc6..68860f23 100755 --- a/tests/integration/test_layer_numerical_compatibility/test_layer_numerical_compatibility.sh +++ b/tests/integration/test_layer_numerical_compatibility/test_layer_numerical_compatibility.sh @@ -1,50 +1,33 @@ #!/bin/bash # This script runs the layer IO of tensorflow and theano backend separately and write the result to files. -# Then check if the difference between the results are within threshold -# -# Arguments: -# --dir: Path to the layer configuration directory. "config.yml", "parameter.h5", "input.h5" must be present +# Then check if the difference between the results are within threshold. set -eu -COUNT_INTEGRATION_COVERAGE=${COUNT_INTEGRATION_COVERAGE:-false} -BASE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" - -CONFIG= -while [[ $# -gt 0 ]] -do - key="$1" - case $key in - --config) - CONFIG="$2" - shift - ;; - *) - echo "Unexpected option ${key} was given" - exit 1 - ;; - esac - shift -done - -if [[ -z "${CONFIG}" ]]; then - echo "--config must be given" +CONFIG=${1} +if [[ ! -f "${CONFIG}" ]]; then + echo "Argument must be YAML file" exit 1 fi -LAYER_NAME="$( basename ${CONFIG%.*} )" +COUNT_INTEGRATION_COVERAGE=${COUNT_INTEGRATION_COVERAGE:-false} if [ "${COUNT_INTEGRATION_COVERAGE}" = true ]; then TEST_COMMAND="coverage run --source luchador -a" else TEST_COMMAND="python" fi + +LAYER_NAME="$( basename ${CONFIG%.*} )" +BASE_DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" && pwd )" TEST_COMMAND="${TEST_COMMAND} ${BASE_DIR}/run_layer.py ${CONFIG}" COMPARE_COMMAND="python ${BASE_DIR}/compare_result.py" -FILE1="tmp/test_layer_numerical_comparitbility/${LAYER_NAME}/theano.h5" -FILE2="tmp/test_layer_numerical_comparitbility/${LAYER_NAME}/tensorflow.h5" +FILE1="tmp/test_layer_numerical_compatibility/${LAYER_NAME}/theano.h5" +FILE2="tmp/test_layer_numerical_compatibility/${LAYER_NAME}/tensorflow.h5" echo "*** Checking numerical compatibility of ${LAYER_NAME} ***" +echo "" cat ${CONFIG} +echo "" echo "* Running ${LAYER_NAME} with Theano backend" LUCHADOR_NN_BACKEND=theano LUCHADOR_NN_CONV_FORMAT=NCHW ${TEST_COMMAND} --output ${FILE1} echo "* Running ${LAYER_NAME} with Tensorflow backend" diff --git a/tests/integration/test_optimizer_numerical_compatibility/test_optimizer_numerical_compatibility.sh b/tests/integration/test_optimizer_numerical_compatibility/test_optimizer_numerical_compatibility.sh index 480645f6..26a0e04b 100755 --- a/tests/integration/test_optimizer_numerical_compatibility/test_optimizer_numerical_compatibility.sh +++ b/tests/integration/test_optimizer_numerical_compatibility/test_optimizer_numerical_compatibility.sh @@ -60,7 +60,9 @@ TEST_COMMAND="${TEST_COMMAND} ${BASE_DIR}/run_optimizer.py ${FORMULA} ${OPTIMIZE COMPARE_COMMAND="python ${BASE_DIR}/compare_result.py" echo "*** Checking numerical compatibility of ${OPTIMIZER_NAME} on ${FORMULA} ***" +echo "" cat ${OPTIMIZER} +echo "" echo "* Running $(basename ${OPTIMIZER}) with Theano backend" LUCHADOR_NN_BACKEND=theano LUCHADOR_NN_CONV_FORMAT=NCHW ${TEST_COMMAND} --output ${FILE1} --iterations ${ITERATIONS} echo "* Running $(basename ${OPTIMIZER}) with Tensorflow backend"