diff --git a/README.md b/README.md index 5639fb87..44357d6b 100644 --- a/README.md +++ b/README.md @@ -19,11 +19,11 @@ from luchador.episode_runner import EpisodeRunner def main(env, agent, episodes, steps): # Create environment - Environment = get_env(env['name']) - env = Environment(**env['args']) + Environment = get_env(env['typename']) + env = Environment(**env['args']) # Create agent - Agent = get_agent(agent['name']) + Agent = get_agent(agent['typename']) agent = Agent(**agent['args']) agent.init(env) @@ -109,7 +109,7 @@ $ python To use this new agent, we need to create a configuration file. As this agent does not take any constructor arguments, the configuration file is as simple as follows. ```yaml -name: MyRandomAgent +typename: MyRandomAgent args: {} ``` @@ -163,23 +163,20 @@ Network architecture can be described using a set of layer configurations, and w model_type: Sequential layer_configs: - scope: layer1 - layer: - name: Conv2D - args: - n_filters: 32 - filter_width: 8 - filter_height: 8 - strides: 4 - padding: valid + typename: Conv2D + args: + n_filters: 32 + filter_width: 8 + filter_height: 8 + strides: 4 + padding: valid - scope: layer2 - layer: - name: ReLU - args: {} + typename: ReLU + args: {} - scope: layer3 - layer: - name: Dense - args: - n_nodes: 3 + typename: Dense + args: + n_nodes: 3 ``` You can feed this configuration to `luchador.nn.util.make_model`, and the function will return the corresponding network architecture. @@ -190,23 +187,20 @@ But having static parameters is sometimes inconvenient. For example, although th model_type: Sequential layer_configs: - scope: layer1 - layer: - name: Conv2D - args: - n_filters: 32 - filter_width: 8 - filter_height: 8 - strides: 4 - padding: valid + typename: Conv2D + args: + n_filters: 32 + filter_width: 8 + filter_height: 8 + strides: 4 + padding: valid - scope: layer2 - layer: - name: ReLU - args: {{}} + typename: ReLU + args: {{}} - scope: layer3 - layer: - name: Dense - args: - n_nodes: {n_actions} + typename: Dense + args: + n_nodes: {n_actions} ``` When you load this file with `luchador.nn.util.make_model('model.yml', n_actions=5)`, 5 is substituted for `{n_actions}`. Notice that `ReLU`'s `args` parameter changed from `{}` to `{{}}` so that Python's `format` function will replace it with `{}`.
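The brace-escaping rule described above is plain Python string formatting, nothing luchador-specific. Here is a minimal standalone sketch (the two-layer template is hypothetical, for illustration only) showing why `{{}}` is needed:

```python
# str.format treats '{{' and '}}' as escaped literal braces, so '{{}}'
# renders as '{}' while '{n_actions}' is substituted with the given value.
template = """\
- scope: layer2
  typename: ReLU
  args: {{}}
- scope: layer3
  typename: Dense
  args:
    n_nodes: {n_actions}
"""
print(template.format(n_actions=5))
# args: {}      <- escaped braces restored to a literal '{}'
# n_nodes: 5    <- placeholder substituted
```

This is the substitution that `luchador.nn.util.make_model('model.yml', n_actions=5)` relies on when it formats the model file before parsing the YAML.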
diff --git a/example/ALEEnvironment_test.yml b/example/ALEEnvironment_test.yml index aae07ded..0f87e767 100644 --- a/example/ALEEnvironment_test.yml +++ b/example/ALEEnvironment_test.yml @@ -1,4 +1,4 @@ -name: ALEEnvironment +typename: ALEEnvironment args: rom: space_invaders diff --git a/example/ALEEnvironment_train.yml b/example/ALEEnvironment_train.yml index 47412716..6a688606 100644 --- a/example/ALEEnvironment_train.yml +++ b/example/ALEEnvironment_train.yml @@ -1,4 +1,4 @@ -name: ALEEnvironment +typename: ALEEnvironment args: rom: space_invaders diff --git a/example/CartPole_agent.yml b/example/CartPole_agent.yml index 0cc746a9..dc0deaf0 100644 --- a/example/CartPole_agent.yml +++ b/example/CartPole_agent.yml @@ -1 +1 @@ -name: CartPoleAgent +typename: CartPoleAgent diff --git a/example/CartPole_env.yml b/example/CartPole_env.yml index d53ea5c3..36576feb 100644 --- a/example/CartPole_env.yml +++ b/example/CartPole_env.yml @@ -1,4 +1,4 @@ -name: CartPole +typename: CartPole args: angle_limit: 12 # Degree distance_limit: 2.4 # meter diff --git a/example/DQNAgent_test.yml b/example/DQNAgent_test.yml index a799c201..428f2a60 100644 --- a/example/DQNAgent_test.yml +++ b/example/DQNAgent_test.yml @@ -7,7 +7,7 @@ alias: save_prefix: &save_prefix DQN_integration_test initial_parameter: &initial_parameter example/space_invaders_vanilla_dqn_99000.h5 -name: DQNAgent +typename: DQNAgent args: recorder_config: memory_size: 100 @@ -27,26 +27,27 @@ args: terminal: dtype: bool + model_config: + model_file: example/vanilla_dqn.yml + initial_parameter: *initial_parameter + input_channel: *stack + input_height: *height + input_width: *width + q_network_config: - model_config: - name: example/vanilla_dqn.yml - initial_parameter: *initial_parameter - input_channel: *stack - input_height: *height - input_width: *width q_learning_config: discount_rate: 0.99 # reward is clipped between the following min and max min_reward: -1.0 max_reward: 1.0 cost_config: - name: SSE2 + typename: SSE2 args: # error between predicted Q value and target Q value is clipped by the following min and max min_delta: -1.0 max_delta: 1.0 optimizer_config: - name: NeonRMSProp + typename: NeonRMSProp args: decay: 0.95 epsilon: 0.000001 diff --git a/example/DQNAgent_train.yml b/example/DQNAgent_train.yml index 55ad3a7e..f06110ed 100644 --- a/example/DQNAgent_train.yml +++ b/example/DQNAgent_train.yml @@ -7,7 +7,7 @@ alias: save_prefix: &save_prefix DQN initial_parameter: &initial_parameter null -name: DQNAgent +typename: DQNAgent args: recorder_config: memory_size: 1000000 @@ -27,26 +27,27 @@ args: terminal: dtype: bool + model_config: + model_file: vanilla_dqn + initial_parameter: *initial_parameter + input_channel: *stack + input_height: *height + input_width: *width + q_network_config: - model_config: - name: vanilla_dqn - initial_parameter: *initial_parameter - input_channel: *stack - input_height: *height - input_width: *width q_learning_config: discount_rate: 0.99 # reward is clipped between the following min and max min_reward: -1.0 max_reward: 1.0 cost_config: - name: SSE2 + typename: SSE2 args: # error between predicted Q value and target Q value is clipped by the following min and max min_delta: -1.0 max_delta: 1.0 optimizer_config: - name: NeonRMSProp + typename: NeonRMSProp args: decay: 0.95 epsilon: 0.000001 diff --git a/example/FlappyBirdEnv.yml b/example/FlappyBirdEnv.yml index e884adfe..ba0227cb 100644 --- a/example/FlappyBirdEnv.yml +++ b/example/FlappyBirdEnv.yml @@ -1,4 +1,4 @@ -name: FlappyBird +typename: 
FlappyBird args: random_seed: null diff --git a/example/MyRandomAgent.yml b/example/MyRandomAgent.yml index 20aae7dd..818c5c34 100644 --- a/example/MyRandomAgent.yml +++ b/example/MyRandomAgent.yml @@ -1 +1 @@ -name: MyRandomAgent +typename: MyRandomAgent diff --git a/example/RemoteEnv.yml b/example/RemoteEnv.yml index 4b608095..7bdf91c3 100644 --- a/example/RemoteEnv.yml +++ b/example/RemoteEnv.yml @@ -1,4 +1,4 @@ -name: RemoteEnv +typename: RemoteEnv args: host: 0.0.0.0 port: 12345 diff --git a/example/vanilla_dqn.yml b/example/vanilla_dqn.yml index 7d3b399b..73d417f1 100644 --- a/example/vanilla_dqn.yml +++ b/example/vanilla_dqn.yml @@ -5,100 +5,89 @@ input: name: state layer_configs: - scope: layer0/preprocessing - layer: - name: TrueDiv - args: - denom: 255 + typename: TrueDiv + args: + denom: 255 - scope: layer1/conv2D - layer: - name: Conv2D - args: - n_filters: 32 - filter_width: 8 - filter_height: 8 - strides: 4 - padding: valid - initializers: - bias: &initializer1 - name: Uniform - args: - # 1 / sqrt(8 * 8 * 32) = 0.022097 - maxval: 0.022 - minval: -0.022 - weight: *initializer1 + typename: Conv2D + args: + n_filters: 32 + filter_width: 8 + filter_height: 8 + strides: 4 + padding: valid + initializers: + bias: &initializer1 + typename: Uniform + args: + # 1 / sqrt(8 * 8 * 32) = 0.022097 + maxval: 0.022 + minval: -0.022 + weight: *initializer1 - scope: layer1/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer2/conv2D - layer: - name: Conv2D - args: - n_filters: 64 - filter_width: 4 - filter_height: 4 - strides: 2 - padding: valid - initializers: - bias: &initializer2 - name: Uniform - args: - # 1 / sqrt(4 * 4 * 64) = 0.03125 - maxval: 0.031 - minval: -0.031 - weight: *initializer2 + typename: Conv2D + args: + n_filters: 64 + filter_width: 4 + filter_height: 4 + strides: 2 + padding: valid + initializers: + bias: &initializer2 + typename: Uniform + args: + # 1 / sqrt(4 * 4 * 64) = 0.03125 + maxval: 0.031 + minval: -0.031 + weight: *initializer2 - scope: layer2/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer3/conv2D - layer: - name: Conv2D - args: - filter_width: 3 - filter_height: 3 - n_filters: 64 - strides: 1 - padding: valid - initializers: - bias: &initializer3 - name: Uniform - args: - # 1 / sqrt(3 * 3 * 64) = 0.04166 - maxval: 0.042 - minval: -0.042 - weight: *initializer3 + typename: Conv2D + args: + filter_width: 3 + filter_height: 3 + n_filters: 64 + strides: 1 + padding: valid + initializers: + bias: &initializer3 + typename: Uniform + args: + # 1 / sqrt(3 * 3 * 64) = 0.04166 + maxval: 0.042 + minval: -0.042 + weight: *initializer3 - scope: layer3/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer4/flatten - layer: - name: Flatten + typename: Flatten - scope: layer5/dense - layer: - name: Dense - args: - n_nodes: 512 - initializers: - bias: &initializer5 - name: Uniform - args: - # 1 / sqrt(3136) = 0.01785 - # 3136 is expected #inputs to this layer when the input size to layer0 is 84 * 84 * 4 - maxval: 0.018 - minval: -0.018 - weight: *initializer5 + typename: Dense + args: + n_nodes: 512 + initializers: + bias: &initializer5 + typename: Uniform + args: + # 1 / sqrt(3136) = 0.01785 + # 3136 is expected #inputs to this layer when the input size to layer0 is 84 * 84 * 4 + maxval: 0.018 + minval: -0.018 + weight: *initializer5 - scope: layer5/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer6/dense - layer: - name: Dense - args: - n_nodes: {n_actions} - initializers: - bias: &initializer6 - name: Uniform - args: - # 1 / sqrt(512) = 
0.04419 - maxval: 0.044 - minval: -0.044 - weight: *initializer6 + typename: Dense + args: + n_nodes: {n_actions} + initializers: + bias: &initializer6 + typename: Uniform + args: + # 1 / sqrt(512) = 0.04419 + maxval: 0.044 + minval: -0.044 + weight: *initializer6 diff --git a/luchador/agent/base.py b/luchador/agent/base.py index 65db5dcb..bb6ef599 100644 --- a/luchador/agent/base.py +++ b/luchador/agent/base.py @@ -96,12 +96,12 @@ def __str__(self): } -def get_agent(name): - """Retrieve Agent class by name +def get_agent(typename): + """Retrieve Agent class by typename Parameters ---------- - name : str + typename : str Name of Agent to retrieve Returns ------- @@ -114,11 +114,11 @@ ValueError - When Agent with the given name is not found + When Agent with the given typename is not found """ - if name in _AGENT_MODULE_MAPPING: - module = 'luchador.agent.{:s}'.format(_AGENT_MODULE_MAPPING[name]) + if typename in _AGENT_MODULE_MAPPING: + module = 'luchador.agent.{:s}'.format(_AGENT_MODULE_MAPPING[typename]) importlib.import_module(module) for class_ in luchador.util.get_subclasses(BaseAgent): - if class_.__name__ == name: + if class_.__name__ == typename: return class_ - raise ValueError('Unknown Agent: {}'.format(name)) + raise ValueError('Unknown Agent: {}'.format(typename)) diff --git a/luchador/agent/dqn.py b/luchador/agent/dqn.py index b73d5ad8..42df6189 100644 --- a/luchador/agent/dqn.py +++ b/luchador/agent/dqn.py @@ -40,6 +40,16 @@ class DQNAgent(luchador.util.StoreMixin, BaseAgent): Constructor arguments for :class:`luchador.agent.recorder.TransitionRecorder` + model_config : dict + Configuration for model definition. + + model_file : str + The name of a network model or the path to a model definition file. + initial_parameter : str + Path to HDF5 file which contains the initial parameter values. + input_channel, input_height, input_width : int + The shape of the input to the network + q_network_config : dict Constructor arguments for :class:`luchador.agent.rl.q_learning.DeepQLearning` @@ -75,6 +85,7 @@ class DQNAgent(luchador.util.StoreMixin, BaseAgent): def __init__( self, recorder_config, + model_config, q_network_config, action_config, training_config, @@ -86,6 +97,7 @@ def __init__( super(DQNAgent, self).__init__() self._store_args( recorder_config=recorder_config, + model_config=model_config, q_network_config=q_network_config, action_config=action_config, training_config=training_config, @@ -122,15 +134,28 @@ def init(self, env): self._init_summary_writer() self._summarize_layer_params() + def _gen_model_def(self, n_actions): + cfg = self.args['model_config'] + fmt = luchador.get_nn_conv_format() + w, h, c = cfg['input_width'], cfg['input_height'], cfg['input_channel'] + shape = ( + '[null, {}, {}, {}]'.format(h, w, c) if fmt == 'NHWC' else + '[null, {}, {}, {}]'.format(c, h, w) + ) + return nn.get_model_config( + cfg['model_file'], n_actions=n_actions, input_shape=shape) + def _init_network(self, n_actions): cfg = self.args['q_network_config'] self._ql = DeepQLearning( - model_config=cfg['model_config'], q_learning_config=cfg['q_learning_config'], cost_config=cfg['cost_config'], optimizer_config=cfg['optimizer_config'], ) - self._ql.build(n_actions=n_actions) + + model_def = self._gen_model_def(n_actions) + initial_parameter = self.args['model_config']['initial_parameter'] + self._ql.build(model_def, initial_parameter) self._ql.sync_network() def _init_saver(self): diff --git a/luchador/agent/rl/q_learning.py b/luchador/agent/rl/q_learning.py index fc393677..4f6af09f 100644 --- a/luchador/agent/rl/q_learning.py +++ b/luchador/agent/rl/q_learning.py @@ -46,16 +46,6 @@ 
class DeepQLearning(luchador.util.StoreMixin, object): Parameters ---------- - model_config : dict - Configuration for model definition. - - name : str - The name of network model or path to model definition file. - initial_parameter : str - Path to HDF5 file which contain the initial parameter values. - input_channel, input_height, input_width : int - The shape of input to the network - q_learning_config : dict Configuration for building target Q value. @@ -90,10 +80,8 @@ class DeepQLearning(luchador.util.StoreMixin, object): """ # pylint: disable=too-many-instance-attributes def __init__( - self, model_config, q_learning_config, cost_config, - optimizer_config): + self, q_learning_config, cost_config, optimizer_config): self._store_args( - model_config=model_config, q_learning_config=q_learning_config, cost_config=cost_config, optimizer_config=optimizer_config, @@ -117,7 +105,7 @@ def __call__(self, q_network_maker): """ self.build(q_network_maker) - def build(self, n_actions): + def build(self, model_def, initial_parameter): """Build computation graph (error and sync ops) for Q learning Parameters ---------- @@ -126,7 +114,6 @@ The number of available actions in the environment. """ # pylint: disable=too-many-locals - model_def = self._gen_model_def(n_actions) model_0, state_0, action_value_0 = _make_model(model_def, 'pre_trans') model_1, state_1, action_value_1 = _make_model(model_def, 'post_trans') sync_op = _build_sync_op(model_0, model_1, 'sync') @@ -144,7 +131,7 @@ self._init_optimizer() optimize_op = self.optimizer.minimize( error, wrt=model_0.get_parameter_variables()) - self._init_session() + self._init_session(initial_parameter) self.models = { 'model_0': model_0, @@ -166,17 +153,6 @@ 'optimize': optimize_op, } - def _gen_model_def(self, n_actions): - cfg = self.args['model_config'] - fmt = luchador.get_nn_conv_format() - w, h, c = cfg['input_width'], cfg['input_height'], cfg['input_channel'] - shape = ( - '[null, {}, {}, {}]'.format(h, w, c) if fmt == 'NHWC' else - '[null, {}, {}, {}]'.format(c, h, w) - ) - return nn.get_model_config( - cfg['name'], n_actions=n_actions, input_shape=shape) - def _build_target_q_value(self, action_value_1, reward, terminal): config = self.args['q_learning_config'] - # Clip rewrads + # Clip rewards @@ -197,22 +173,22 @@ def _build_error(self, target_q, action_value_0, action): config = self.args['cost_config'] - sse2 = nn.get_cost(config['name'])(elementwise=True, **config['args']) - error = sse2(target_q, action_value_0) + args = config['args'] + cost = nn.get_cost(config['typename'])(elementwise=True, **args) + error = cost(target_q, action_value_0) mask = action.one_hot(n_classes=action_value_0.shape[1]) return (mask * error).mean() ########################################################################### def _init_optimizer(self): cfg = self.args['optimizer_config'] - self.optimizer = nn.get_optimizer(cfg['name'])(**cfg['args']) + self.optimizer = nn.get_optimizer(cfg['typename'])(**cfg['args']) - def _init_session(self): - cfg = self.args['model_config'] + def _init_session(self, initial_parameter=None): self.session = nn.Session() - if cfg.get('initial_parameter'): - _LG.info('Loading parameters from %s', cfg['initial_parameter']) - self.session.load_from_file(cfg['initial_parameter']) + if initial_parameter: + _LG.info('Loading parameters from %s', initial_parameter) + self.session.load_from_file(initial_parameter) 
else: self.session.initialize() diff --git a/luchador/command/exercise.py b/luchador/command/exercise.py index 1d7d7ad3..eb47714e 100644 --- a/luchador/command/exercise.py +++ b/luchador/command/exercise.py @@ -62,19 +62,19 @@ def _load_additional_sources(*files): def _make_agent(config_file): config = ( load_config(config_file) if config_file else - {'name': 'NoOpAgent', 'args': {}} + {'typename': 'NoOpAgent', 'args': {}} ) - return get_agent(config['name'])(**config.get('args', {})) + return get_agent(config['typename'])(**config.get('args', {})) def _make_env(config_file, host, port): config = load_config(config_file) - if config['name'] == 'RemoteEnv': + if config['typename'] == 'RemoteEnv': if port: config['args']['port'] = port if host: config['args']['host'] = host - return get_env(config['name'])(**config.get('args', {})) + return get_env(config['typename'])(**config.get('args', {})) def entry_point(args): diff --git a/luchador/command/serve.py b/luchador/command/serve.py index c4c336be..9dbfd269 100644 --- a/luchador/command/serve.py +++ b/luchador/command/serve.py @@ -27,7 +27,7 @@ def entry_point_env(args): if args.environment is None: raise ValueError('Environment config is not given') env_config = luchador.util.load_config(args.environment) - env = luchador.env.get_env(env_config['name'])(**env_config['args']) + env = luchador.env.get_env(env_config['typename'])(**env_config['args']) app = luchador.env.remote.create_env_app(env) _run_server(app, args.port) diff --git a/luchador/env/base.py b/luchador/env/base.py index 0eb4da41..53620c63 100644 --- a/luchador/env/base.py +++ b/luchador/env/base.py @@ -74,12 +74,12 @@ def step(self, action): } -def get_env(name): - """Retrieve Environment class by name +def get_env(typename): + """Retrieve Environment class by typename Parameters ---------- - name : str + typename : str Name of Environment to retrieve Returns @@ -97,12 +97,13 @@ def get_env(name): # Ubuntu 14.04. # TLS Error is avoidable only by upgrading underlying libc version, which # is not easy. So we import such environments on-demand. 
- if name in _ENVIRONMENT_MODULE_MAPPING: - module = 'luchador.env.{:s}'.format(_ENVIRONMENT_MODULE_MAPPING[name]) + if typename in _ENVIRONMENT_MODULE_MAPPING: + module = 'luchador.env.{:s}'.format( + _ENVIRONMENT_MODULE_MAPPING[typename]) importlib.import_module(module) for class_ in luchador.util.get_subclasses(BaseEnvironment): - if class_.__name__ == name: + if class_.__name__ == typename: return class_ - raise ValueError('Unknown Environment: {}'.format(name)) + raise ValueError('Unknown Environment: {}'.format(typename)) diff --git a/luchador/nn/base/cost.py b/luchador/nn/base/cost.py index 6a69b5dd..e6c614f6 100644 --- a/luchador/nn/base/cost.py +++ b/luchador/nn/base/cost.py @@ -52,13 +52,13 @@ def _build(self, target, prediction): pass -def get_cost(name): - """Retrieve Cost class by name +def get_cost(typename): + """Retrieve Cost class by type Parameters ---------- - name : str - Name of Cost to retrieve + typename : str + Type of Cost to retrieve Returns ------- @@ -68,12 +68,12 @@ def get_cost(name): Raises ------ ValueError - When Cost with the given name is not found + When Cost with the given type is not found """ for class_ in luchador.util.get_subclasses(BaseCost): - if class_.__name__ == name: + if class_.__name__ == typename: return class_ - raise ValueError('Unknown Cost: {}'.format(name)) + raise ValueError('Unknown Cost: {}'.format(typename)) ############################################################################### diff --git a/luchador/nn/base/initializer.py b/luchador/nn/base/initializer.py index 61f22cc2..01f66d71 100644 --- a/luchador/nn/base/initializer.py +++ b/luchador/nn/base/initializer.py @@ -48,13 +48,13 @@ def _sample(self, shape): pass -def get_initializer(name): - """Retrieve Initializer class by name +def get_initializer(typename): + """Retrieve Initializer class by type Parameters ---------- - name : str - Name of Initializer to retrieve + typename : str + Type of Initializer to retrieve Returns ------- @@ -64,12 +64,12 @@ def get_initializer(name): Raises ------ ValueError - When Initializer with the given name is not found + When Initializer with the given type is not found """ for class_ in luchador.util.get_subclasses(BaseInitializer): - if class_.__name__ == name: + if class_.__name__ == typename: return class_ - raise ValueError('Unknown Initializer: {}'.format(name)) + raise ValueError('Unknown Initializer: {}'.format(typename)) ############################################################################### diff --git a/luchador/nn/base/layer.py b/luchador/nn/base/layer.py index dfa4f01c..d1eececd 100644 --- a/luchador/nn/base/layer.py +++ b/luchador/nn/base/layer.py @@ -84,13 +84,13 @@ def _build(self, input_tensor): ) -def get_layer(name): - """Retrieve Layer class by name +def get_layer(typename): + """Retrieve Layer class by type Parameters ---------- - name : str - Name of Layer to retrieve + typename : str + Type of Layer to retrieve Returns ------- @@ -100,12 +100,12 @@ def get_layer(name): Raises ------ ValueError - When Layer with the given name is not found + When Layer with the given type is not found """ for class_ in luchador.util.get_subclasses(BaseLayer): - if class_.__name__ == name: + if class_.__name__ == typename: return class_ - raise ValueError('Unknown Layer: {}'.format(name)) + raise ValueError('Unknown Layer: {}'.format(typename)) ############################################################################### diff --git a/luchador/nn/base/optimizer.py b/luchador/nn/base/optimizer.py index 497138ec..7a76d690 100644 
--- a/luchador/nn/base/optimizer.py +++ b/luchador/nn/base/optimizer.py @@ -113,13 +113,13 @@ def get_parameter_variables(self): return self.slot -def get_optimizer(name): - """Retrieve Optimizer class by name +def get_optimizer(typename): + """Retrieve Optimizer class by type Parameters ---------- - name : str - Name of Optimizer to retrieve + typename : str + Type of Optimizer to retrieve Returns ------- @@ -129,12 +129,12 @@ Raises ------ ValueError - When Optimizer with the given name is not found + When Optimizer with the given type is not found """ for class_ in luchador.util.get_subclasses(BaseOptimizer): - if class_.__name__ == name: + if class_.__name__ == typename: return class_ - raise ValueError('Unknown Optimizer: {}'.format(name)) + raise ValueError('Unknown Optimizer: {}'.format(typename)) ############################################################################### @@ -264,6 +264,14 @@ def __init__(self, learning_rate, class BaseAdam(BaseOptimizer): + """Adam optimizer [1]_ + + References + ---------- + .. [1] Kingma, D. Ba, J 2014 + Adam: A Method for Stochastic Optimization + https://arxiv.org/abs/1412.6980 + """ def __init__(self, learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-08, name='Adam', **kwargs): @@ -273,6 +281,14 @@ def __init__(self, learning_rate, class BaseAdamax(BaseOptimizer): + """Adamax optimizer [1]_ + + References + ---------- + .. [1] Kingma, D. Ba, J 2014 + Adam: A Method for Stochastic Optimization + https://arxiv.org/abs/1412.6980 + """ def __init__(self, learning_rate, beta1=0.9, beta2=0.999, epsilon=1e-8, name='Adamax', **kwargs): diff --git a/luchador/nn/model/data/vanilla_dqn.yml b/luchador/nn/model/data/vanilla_dqn.yml index d0e8a5ba..2baeb609 100644 --- a/luchador/nn/model/data/vanilla_dqn.yml +++ b/luchador/nn/model/data/vanilla_dqn.yml @@ -1,65 +1,54 @@ model_type: Sequential input: - name: Input + typename: Input args: dtype: uint8 shape: {input_shape} name: state layer_configs: - scope: layer0/preprocessing - layer: - name: TrueDiv - args: - denom: 255 + typename: TrueDiv + args: + denom: 255 - scope: layer1/conv2D - layer: - name: Conv2D - args: - n_filters: 32 - filter_width: 8 - filter_height: 8 - strides: 4 - padding: valid + typename: Conv2D + args: + n_filters: 32 + filter_width: 8 + filter_height: 8 + strides: 4 + padding: valid - scope: layer1/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer2/conv2D - layer: - name: Conv2D - args: - n_filters: 64 - filter_width: 4 - filter_height: 4 - strides: 2 - padding: valid + typename: Conv2D + args: + n_filters: 64 + filter_width: 4 + filter_height: 4 + strides: 2 + padding: valid - scope: layer2/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer3/conv2D - layer: - name: Conv2D - args: - filter_width: 3 - filter_height: 3 - n_filters: 64 - strides: 1 - padding: valid + typename: Conv2D + args: + filter_width: 3 + filter_height: 3 + n_filters: 64 + strides: 1 + padding: valid - scope: layer3/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer4/flatten - layer: - name: Flatten + typename: Flatten - scope: layer5/dense - layer: - name: Dense - args: - n_nodes: 512 + typename: Dense + args: + n_nodes: 512 - scope: layer5/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer6/dense - layer: - name: Dense - args: - n_nodes: {n_actions} + typename: Dense + args: + n_nodes: {n_actions} diff --git a/luchador/nn/model/data/vanilla_dqn_bn.yml b/luchador/nn/model/data/vanilla_dqn_bn.yml index a9845e45..74674ec0 100644 --- 
a/luchador/nn/model/data/vanilla_dqn_bn.yml +++ b/luchador/nn/model/data/vanilla_dqn_bn.yml @@ -1,79 +1,78 @@ model_type: Sequential input: - dtype: uint8 - shape: {input_shape} - name: state + typename: Input + args: + dtype: uint8 + shape: {input_shape} + name: state layer_configs: - scope: layer0/preprocessing - layer: - name: TrueDiv - args: - denom: 255 + typename: TrueDiv + args: + denom: 255 - scope: layer1/conv2D - layer: - name: Conv2D - args: - n_filters: 32 - filter_width: 8 - filter_height: 8 - strides: 4 - padding: valid - with_bias: False + typename: Conv2D + args: + n_filters: 32 + filter_width: 8 + filter_height: 8 + strides: 4 + padding: valid + with_bias: False - scope: layer1/BN - layer: &BN - name: BatchNormalization - args: - learn: True - decay: 0.999 + typename: BatchNormalization + args: + learn: True + decay: 0.999 - scope: layer1/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer2/conv2D - layer: - name: Conv2D - args: - n_filters: 64 - filter_width: 4 - filter_height: 4 - strides: 2 - padding: valid - with_bias: False + typename: Conv2D + args: + n_filters: 64 + filter_width: 4 + filter_height: 4 + strides: 2 + padding: valid + with_bias: False - scope: layer2/BN - layer: *BN + typename: BatchNormalization + args: + learn: True + decay: 0.999 - scope: layer2/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer3/conv2D - layer: - name: Conv2D - args: - filter_width: 3 - filter_height: 3 - n_filters: 64 - strides: 1 - padding: valid - with_bias: False + typename: Conv2D + args: + filter_width: 3 + filter_height: 3 + n_filters: 64 + strides: 1 + padding: valid + with_bias: False - scope: layer3/BN - layer: *BN + typename: BatchNormalization + args: + learn: True + decay: 0.999 - scope: layer3/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer4/flatten - layer: - name: Flatten + typename: Flatten - scope: layer5/dense - layer: - name: Dense - args: - n_nodes: 512 - with_bias: False + typename: Dense + args: + n_nodes: 512 + with_bias: False - scope: layer5/BN - layer: *BN + typename: BatchNormalization + args: + learn: True + decay: 0.999 - scope: layer5/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer6/dense - layer: - name: Dense - args: - n_nodes: {n_actions} + typename: Dense + args: + n_nodes: {n_actions} diff --git a/luchador/nn/model/sequential.py b/luchador/nn/model/sequential.py index 7e5edbff..30c5304e 100644 --- a/luchador/nn/model/sequential.py +++ b/luchador/nn/model/sequential.py @@ -36,6 +36,11 @@ def __repr__(self): 'output': self.output, }) + def serialize(self): + ret = self.layer.serialize() + ret['scope'] = self.scope + return ret + class Sequential(BaseModel): """Network architecture which produces single output from single input @@ -243,10 +248,7 @@ def serialize(self): """ return { 'model_type': self.__class__.__name__, - 'layer_configs': [{ - 'scope': cfg.scope, - 'layer': cfg.layer.serialize(), - } for cfg in self.layer_configs] + 'layer_configs': [cfg.serialize() for cfg in self.layer_configs] } ########################################################################### @@ -271,12 +273,12 @@ def make_sequential_model(layer_configs): Resulting model """ model = Sequential() - for layer_config in layer_configs: - layer_cfg = layer_config['layer'] - if 'name' not in layer_cfg: + for config in layer_configs: + if 'typename' not in config: - raise RuntimeError('Layer name is not given') + raise RuntimeError('Layer typename is not given') - args = layer_cfg.get('args', {}) - _LG.debug(' Constructing: %s: %s', layer_cfg['name'], args) - layer = 
get_layer(layer_cfg['name'])(**args) - model.add_layer(layer=layer, scope=layer_config.get('scope', '')) + args = config.get('args', {}) + + _LG.debug(' Constructing: %s: %s', config['typename'], args) + layer = get_layer(config['typename'])(**args) + model.add_layer(layer=layer, scope=config.get('scope', '')) return model diff --git a/luchador/nn/model/util.py b/luchador/nn/model/util.py index f242a272..d0db23e6 100644 --- a/luchador/nn/model/util.py +++ b/luchador/nn/model/util.py @@ -30,7 +30,7 @@ def _get_input(): def _make_input(input_config): - if input_config['name'] == 'Input': + if input_config['typename'] == 'Input': return _get_input()(**input_config['args']) diff --git a/luchador/util/mixin.py b/luchador/util/mixin.py index aaea1702..d37730d0 100644 --- a/luchador/util/mixin.py +++ b/luchador/util/mixin.py @@ -133,6 +133,6 @@ def serialize(self): for key, val in self.args.items(): args[key] = val.serialize() if hasattr(val, 'serialize') else val return { - 'name': self.__class__.__name__, + 'typename': self.__class__.__name__, 'args': args } diff --git a/tests/integration/data/dqn/ALEEnvironment_train.yml b/tests/integration/data/dqn/ALEEnvironment_train.yml index 47412716..6a688606 100644 --- a/tests/integration/data/dqn/ALEEnvironment_train.yml +++ b/tests/integration/data/dqn/ALEEnvironment_train.yml @@ -1,4 +1,4 @@ -name: ALEEnvironment +typename: ALEEnvironment args: rom: space_invaders diff --git a/tests/integration/data/dqn/DQNAgent_train.yml b/tests/integration/data/dqn/DQNAgent_train.yml index 342e43b7..a6497c48 100644 --- a/tests/integration/data/dqn/DQNAgent_train.yml +++ b/tests/integration/data/dqn/DQNAgent_train.yml @@ -7,7 +7,7 @@ alias: save_prefix: &save_prefix DQN_integration_test initial_parameter: &initial_parameter tests/integration/data/dqn/dqn_integration_test_initial.h5 -name: DQNAgent +typename: DQNAgent args: recorder_config: memory_size: 100 @@ -27,26 +27,27 @@ args: terminal: dtype: bool + model_config: + model_file: vanilla_dqn + initial_parameter: *initial_parameter + input_channel: *stack + input_height: *height + input_width: *width + q_network_config: - model_config: - name: vanilla_dqn - initial_parameter: *initial_parameter - input_channel: *stack - input_height: *height - input_width: *width q_learning_config: discount_rate: 0.99 # reward is clipped between the following min and max min_reward: -1.0 max_reward: 1.0 cost_config: - name: SSE2 + typename: SSE2 args: # error between predicted Q value and target Q value is clipped by the following min and max min_delta: -1.0 max_delta: 1.0 optimizer_config: - name: RMSProp + typename: RMSProp args: decay: 0.95 epsilon: 0.000001 diff --git a/tests/integration/data/env/cartpole/agent.yml b/tests/integration/data/env/cartpole/agent.yml index 0cc746a9..dc0deaf0 100644 --- a/tests/integration/data/env/cartpole/agent.yml +++ b/tests/integration/data/env/cartpole/agent.yml @@ -1 +1 @@ -name: CartPoleAgent +typename: CartPoleAgent diff --git a/tests/integration/data/env/cartpole/env.yml b/tests/integration/data/env/cartpole/env.yml index bc5ddc11..67015163 100644 --- a/tests/integration/data/env/cartpole/env.yml +++ b/tests/integration/data/env/cartpole/env.yml @@ -1,4 +1,4 @@ -name: CartPole +typename: CartPole args: angle_limit: 12 # Degree distance_limit: 2.4 # meter diff --git a/tests/integration/data/initializer/constant.yml b/tests/integration/data/initializer/constant.yml index 561323c0..306d2daf 100644 --- a/tests/integration/data/initializer/constant.yml +++ 
b/tests/integration/data/initializer/constant.yml @@ -1,5 +1,5 @@ initializer: - name: Constant + typename: Constant args: value: 3.2 diff --git a/tests/integration/data/initializer/kaiming_conv2d_normal.yml b/tests/integration/data/initializer/kaiming_conv2d_normal.yml index bc1699cc..304f6312 100644 --- a/tests/integration/data/initializer/kaiming_conv2d_normal.yml +++ b/tests/integration/data/initializer/kaiming_conv2d_normal.yml @@ -1,5 +1,5 @@ initializer: - name: Kaiming + typename: Kaiming args: uniform: False diff --git a/tests/integration/data/initializer/kaiming_conv2d_uniform.yml b/tests/integration/data/initializer/kaiming_conv2d_uniform.yml index ec19facf..637cb427 100644 --- a/tests/integration/data/initializer/kaiming_conv2d_uniform.yml +++ b/tests/integration/data/initializer/kaiming_conv2d_uniform.yml @@ -1,5 +1,5 @@ initializer: - name: Kaiming + typename: Kaiming args: uniform: True diff --git a/tests/integration/data/initializer/kaiming_normal.yml b/tests/integration/data/initializer/kaiming_normal.yml index 473edb74..bbe17a8d 100644 --- a/tests/integration/data/initializer/kaiming_normal.yml +++ b/tests/integration/data/initializer/kaiming_normal.yml @@ -1,5 +1,5 @@ initializer: - name: Kaiming + typename: Kaiming args: uniform: False diff --git a/tests/integration/data/initializer/kaiming_uniform.yml b/tests/integration/data/initializer/kaiming_uniform.yml index 9968844f..e2198908 100644 --- a/tests/integration/data/initializer/kaiming_uniform.yml +++ b/tests/integration/data/initializer/kaiming_uniform.yml @@ -1,5 +1,5 @@ initializer: - name: Kaiming + typename: Kaiming args: uniform: True diff --git a/tests/integration/data/initializer/normal.yml b/tests/integration/data/initializer/normal.yml index d7edcc4c..8b9f9e61 100644 --- a/tests/integration/data/initializer/normal.yml +++ b/tests/integration/data/initializer/normal.yml @@ -1,5 +1,5 @@ initializer: - name: Normal + typename: Normal args: mean: &mean 5.3 stddev: &stddev 9.0 diff --git a/tests/integration/data/initializer/uniform.yml b/tests/integration/data/initializer/uniform.yml index ffe6b3fb..aad2be74 100644 --- a/tests/integration/data/initializer/uniform.yml +++ b/tests/integration/data/initializer/uniform.yml @@ -1,5 +1,5 @@ initializer: - name: Uniform + typename: Uniform args: minval: -2.0 maxval: 6.0 diff --git a/tests/integration/data/initializer/xavier_conv2d_normal.yml b/tests/integration/data/initializer/xavier_conv2d_normal.yml index aa18b971..9f703d0d 100644 --- a/tests/integration/data/initializer/xavier_conv2d_normal.yml +++ b/tests/integration/data/initializer/xavier_conv2d_normal.yml @@ -1,5 +1,5 @@ initializer: - name: Xavier + typename: Xavier args: uniform: False diff --git a/tests/integration/data/initializer/xavier_conv2d_uniform.yml b/tests/integration/data/initializer/xavier_conv2d_uniform.yml index 579129e4..a6135d99 100644 --- a/tests/integration/data/initializer/xavier_conv2d_uniform.yml +++ b/tests/integration/data/initializer/xavier_conv2d_uniform.yml @@ -1,5 +1,5 @@ initializer: - name: Xavier + typename: Xavier args: uniform: True diff --git a/tests/integration/data/initializer/xavier_normal.yml b/tests/integration/data/initializer/xavier_normal.yml index 0ee59e11..d8adf64e 100644 --- a/tests/integration/data/initializer/xavier_normal.yml +++ b/tests/integration/data/initializer/xavier_normal.yml @@ -1,5 +1,5 @@ initializer: - name: Xavier + typename: Xavier args: uniform: False diff --git a/tests/integration/data/initializer/xavier_uniform.yml 
b/tests/integration/data/initializer/xavier_uniform.yml index 2a637551..f3101fdf 100644 --- a/tests/integration/data/initializer/xavier_uniform.yml +++ b/tests/integration/data/initializer/xavier_uniform.yml @@ -1,5 +1,5 @@ initializer: - name: Xavier + typename: Xavier args: uniform: True diff --git a/tests/integration/data/layer/batch_normalization_2d_learn.yml b/tests/integration/data/layer/batch_normalization_2d_learn.yml index 285be4a6..909db718 100644 --- a/tests/integration/data/layer/batch_normalization_2d_learn.yml +++ b/tests/integration/data/layer/batch_normalization_2d_learn.yml @@ -2,7 +2,7 @@ run: iteration: 10 layer: - name: BatchNormalization + typename: BatchNormalization args: scale: 2.0 offset: 0.5 diff --git a/tests/integration/data/layer/batch_normalization_2d_not_learn.yml b/tests/integration/data/layer/batch_normalization_2d_not_learn.yml index 2960c83e..504ec7bf 100644 --- a/tests/integration/data/layer/batch_normalization_2d_not_learn.yml +++ b/tests/integration/data/layer/batch_normalization_2d_not_learn.yml @@ -2,7 +2,7 @@ run: iteration: 10 layer: - name: BatchNormalization + typename: BatchNormalization args: scale: 2.0 offset: 0.5 diff --git a/tests/integration/data/layer/batch_normalization_4d_learn.yml b/tests/integration/data/layer/batch_normalization_4d_learn.yml index 285be4a6..909db718 100644 --- a/tests/integration/data/layer/batch_normalization_4d_learn.yml +++ b/tests/integration/data/layer/batch_normalization_4d_learn.yml @@ -2,7 +2,7 @@ run: iteration: 10 layer: - name: BatchNormalization + typename: BatchNormalization args: scale: 2.0 offset: 0.5 diff --git a/tests/integration/data/layer/batch_normalization_4d_not_learn.yml b/tests/integration/data/layer/batch_normalization_4d_not_learn.yml index 2960c83e..504ec7bf 100644 --- a/tests/integration/data/layer/batch_normalization_4d_not_learn.yml +++ b/tests/integration/data/layer/batch_normalization_4d_not_learn.yml @@ -2,7 +2,7 @@ run: iteration: 10 layer: - name: BatchNormalization + typename: BatchNormalization args: scale: 2.0 offset: 0.5 diff --git a/tests/integration/data/layer/conv2d_same.yml b/tests/integration/data/layer/conv2d_same.yml index 6932c976..10989564 100644 --- a/tests/integration/data/layer/conv2d_same.yml +++ b/tests/integration/data/layer/conv2d_same.yml @@ -1,5 +1,5 @@ layer: - name: Conv2D + typename: Conv2D args: filter_height: 7 filter_width: 5 diff --git a/tests/integration/data/layer/conv2d_valid.yml b/tests/integration/data/layer/conv2d_valid.yml index fed44ee3..26c2f02c 100644 --- a/tests/integration/data/layer/conv2d_valid.yml +++ b/tests/integration/data/layer/conv2d_valid.yml @@ -1,5 +1,5 @@ layer: - name: Conv2D + typename: Conv2D args: filter_height: 7 filter_width: 5 diff --git a/tests/integration/data/layer/conv2d_without_bias.yml b/tests/integration/data/layer/conv2d_without_bias.yml index dbf55104..9ded5730 100644 --- a/tests/integration/data/layer/conv2d_without_bias.yml +++ b/tests/integration/data/layer/conv2d_without_bias.yml @@ -1,5 +1,5 @@ layer: - name: Conv2D + typename: Conv2D args: filter_height: 7 filter_width: 5 diff --git a/tests/integration/data/layer/dense.yml b/tests/integration/data/layer/dense.yml index 8d88c5f9..e9dab38f 100644 --- a/tests/integration/data/layer/dense.yml +++ b/tests/integration/data/layer/dense.yml @@ -1,5 +1,5 @@ layer: - name: Dense + typename: Dense args: n_nodes: 7 with_bias: True diff --git a/tests/integration/data/layer/dense_without_bias.yml b/tests/integration/data/layer/dense_without_bias.yml index f319b608..430d34dc 
100644 --- a/tests/integration/data/layer/dense_without_bias.yml +++ b/tests/integration/data/layer/dense_without_bias.yml @@ -1,5 +1,5 @@ layer: - name: Dense + typename: Dense args: n_nodes: 7 with_bias: False diff --git a/tests/integration/data/layer/flatten.yml b/tests/integration/data/layer/flatten.yml index 19192d9e..8dec35c0 100644 --- a/tests/integration/data/layer/flatten.yml +++ b/tests/integration/data/layer/flatten.yml @@ -1,4 +1,4 @@ layer: - name: Flatten + typename: Flatten input: input_mnist_10x4x28x27.h5 diff --git a/tests/integration/data/layer/relu.yml b/tests/integration/data/layer/relu.yml index 9726c362..1be06380 100644 --- a/tests/integration/data/layer/relu.yml +++ b/tests/integration/data/layer/relu.yml @@ -1,4 +1,4 @@ layer: - name: ReLU + typename: ReLU input: input_randn_5x3.h5 diff --git a/tests/integration/data/layer/sigmoid.yml b/tests/integration/data/layer/sigmoid.yml index 23185f07..dc195a5b 100644 --- a/tests/integration/data/layer/sigmoid.yml +++ b/tests/integration/data/layer/sigmoid.yml @@ -1,4 +1,4 @@ layer: - name: Sigmoid + typename: Sigmoid input: input_randn_5x3.h5 diff --git a/tests/integration/data/layer/softmax.yml b/tests/integration/data/layer/softmax.yml index f6d5f006..ff623fc9 100644 --- a/tests/integration/data/layer/softmax.yml +++ b/tests/integration/data/layer/softmax.yml @@ -1,4 +1,4 @@ layer: - name: Softmax + typename: Softmax input: input_randn_5x3.h5 diff --git a/tests/integration/data/layer/softplus.yml b/tests/integration/data/layer/softplus.yml index b71ce69b..375bdbef 100644 --- a/tests/integration/data/layer/softplus.yml +++ b/tests/integration/data/layer/softplus.yml @@ -1,4 +1,4 @@ layer: - name: Softplus + typename: Softplus input: input_randn_5x3.h5 diff --git a/tests/integration/data/layer/tanh.yml b/tests/integration/data/layer/tanh.yml index 5040bb43..466cb50f 100644 --- a/tests/integration/data/layer/tanh.yml +++ b/tests/integration/data/layer/tanh.yml @@ -1,4 +1,4 @@ layer: - name: Tanh + typename: Tanh input: input_randn_5x3.h5 diff --git a/tests/integration/data/layer/true_div.yml b/tests/integration/data/layer/true_div.yml index 9b172d11..efab0b3f 100644 --- a/tests/integration/data/layer/true_div.yml +++ b/tests/integration/data/layer/true_div.yml @@ -1,5 +1,5 @@ layer: - name: TrueDiv + typename: TrueDiv args: denom: 255 diff --git a/tests/integration/data/model/dqn.yml b/tests/integration/data/model/dqn.yml index 275f1783..6530524c 100644 --- a/tests/integration/data/model/dqn.yml +++ b/tests/integration/data/model/dqn.yml @@ -1,71 +1,59 @@ model_type: Sequential input: - name: Input + typename: Input args: dtype: uint8 shape: {input_shape} name: state layer_configs: - scope: layer0/preprocessing - layer: - name: TrueDiv - args: - denom: 255 + typename: TrueDiv + args: + denom: 255 - scope: layer1/conv2D - layer: - name: Conv2D - args: - n_filters: 32 - filter_width: 8 - filter_height: 8 - strides: 4 - padding: valid + typename: Conv2D + args: + n_filters: 32 + filter_width: 8 + filter_height: 8 + strides: 4 + padding: valid - scope: layer1/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer2/conv2D - layer: - name: Conv2D - args: - n_filters: 64 - filter_width: 4 - filter_height: 4 - strides: 2 - padding: valid - with_bias: False + typename: Conv2D + args: + n_filters: 64 + filter_width: 4 + filter_height: 4 + strides: 2 + padding: valid + with_bias: False - scope: layer2/BatchNormalization - layer: - name: BatchNormalization - args: - learn: True + typename: BatchNormalization + args: + learn: True - 
scope: layer2/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer3/flatten - layer: - name: Flatten + typename: Flatten - scope: layer4/dense - layer: - name: Dense - args: - n_nodes: 512 - with_bias: False + typename: Dense + args: + n_nodes: 512 + with_bias: False - scope: layer4/BatchNormalization - layer: - name: BatchNormalization - args: - learn: True + typename: BatchNormalization + args: + learn: True - scope: layer4/ReLU - layer: - name: ReLU + typename: ReLU - scope: layer5/dense - layer: - name: Dense - args: - n_nodes: {n_actions} - with_bias: False + typename: Dense + args: + n_nodes: {n_actions} + with_bias: False - scope: layer5/BatchNormalization - layer: - name: BatchNormalization - args: - learn: True + typename: BatchNormalization + args: + learn: True diff --git a/tests/integration/data/optimizer/Adam.yml b/tests/integration/data/optimizer/Adam.yml index 025cc0d8..3e26f38d 100644 --- a/tests/integration/data/optimizer/Adam.yml +++ b/tests/integration/data/optimizer/Adam.yml @@ -1,4 +1,4 @@ -name: Adam +typename: Adam args: learning_rate: 0.001 beta1: 0.9 diff --git a/tests/integration/data/optimizer/Adamax.yml b/tests/integration/data/optimizer/Adamax.yml index ef37afd2..8c104de1 100644 --- a/tests/integration/data/optimizer/Adamax.yml +++ b/tests/integration/data/optimizer/Adamax.yml @@ -1,4 +1,4 @@ -name: Adamax +typename: Adamax args: learning_rate: 0.001 beta1: 0.9 diff --git a/tests/integration/data/optimizer/GravesRMSProp.yml b/tests/integration/data/optimizer/GravesRMSProp.yml index 647bfcfe..14b82ab9 100644 --- a/tests/integration/data/optimizer/GravesRMSProp.yml +++ b/tests/integration/data/optimizer/GravesRMSProp.yml @@ -1,4 +1,4 @@ -name: GravesRMSProp +typename: GravesRMSProp args: learning_rate: 0.001 decay1: 0.95 diff --git a/tests/integration/data/optimizer/NeonRMSProp.yml b/tests/integration/data/optimizer/NeonRMSProp.yml index 95792970..3fa43ff9 100644 --- a/tests/integration/data/optimizer/NeonRMSProp.yml +++ b/tests/integration/data/optimizer/NeonRMSProp.yml @@ -1,4 +1,4 @@ -name: NeonRMSProp +typename: NeonRMSProp args: learning_rate: 0.001 decay: 0.95 diff --git a/tests/integration/data/optimizer/RMSProp_with_moment.yml b/tests/integration/data/optimizer/RMSProp_with_moment.yml index 194e26c5..27cc0cf8 100644 --- a/tests/integration/data/optimizer/RMSProp_with_moment.yml +++ b/tests/integration/data/optimizer/RMSProp_with_moment.yml @@ -1,4 +1,4 @@ -name: RMSProp +typename: RMSProp args: learning_rate: 0.001 decay: 0.95 diff --git a/tests/integration/data/optimizer/RMSProp_without_moment.yml b/tests/integration/data/optimizer/RMSProp_without_moment.yml index 27c64dd8..06a2c56d 100644 --- a/tests/integration/data/optimizer/RMSProp_without_moment.yml +++ b/tests/integration/data/optimizer/RMSProp_without_moment.yml @@ -1,4 +1,4 @@ -name: RMSProp +typename: RMSProp args: learning_rate: 0.001 decay: 0.95 diff --git a/tests/integration/data/optimizer/SGD.yml b/tests/integration/data/optimizer/SGD.yml index 671f2391..3dd2451d 100644 --- a/tests/integration/data/optimizer/SGD.yml +++ b/tests/integration/data/optimizer/SGD.yml @@ -1,3 +1,3 @@ -name: SGD +typename: SGD args: learning_rate: 0.001 diff --git a/tests/integration/test_initializer_compatibility/run_initializer.py b/tests/integration/test_initializer_compatibility/run_initializer.py index c47e692c..100a7307 100644 --- a/tests/integration/test_initializer_compatibility/run_initializer.py +++ b/tests/integration/test_initializer_compatibility/run_initializer.py @@ -30,8 +30,8 @@ def 
_parse_command_line_args(): return ap.parse_args() -def _create_initializer(name, args): - return nn.get_initializer(name)(**args) +def _create_initializer(typename, args): + return nn.get_initializer(typename)(**args) def _transpose_needed(initializer, shape): diff --git a/tests/integration/test_layer_numerical_compatibility/run_layer.py b/tests/integration/test_layer_numerical_compatibility/run_layer.py index 5c226063..84c03ff2 100644 --- a/tests/integration/test_layer_numerical_compatibility/run_layer.py +++ b/tests/integration/test_layer_numerical_compatibility/run_layer.py @@ -93,8 +93,8 @@ def _run_forward_prop(layer, input_value, parameter_file, iteration=1): return output -def _load_layer(cfg): - return nn.get_layer(cfg['name'])(**cfg.get('args', {})) +def _create_layer(typename, args=None): + return nn.get_layer(typename)(**(args or {})) def _load_input_value(filepath): @@ -150,7 +150,7 @@ def _main(): cfg, input_file, param_file = _load_config(args.config) output = _run_forward_prop( - layer=_load_layer(cfg['layer']), + layer=_create_layer(**cfg['layer']), input_value=_load_input_value(input_file), parameter_file=param_file, **cfg.get('run', {}) diff --git a/tests/integration/test_optimizer_numerical_compatibility/run_optimizer.py b/tests/integration/test_optimizer_numerical_compatibility/run_optimizer.py index 6ded8085..80cd2f13 100644 --- a/tests/integration/test_optimizer_numerical_compatibility/run_optimizer.py +++ b/tests/integration/test_optimizer_numerical_compatibility/run_optimizer.py @@ -65,7 +65,7 @@ def _optimize(optimizer, loss, wrt, n_ite): def _load_optimizer(filepath): cfg = load_config(filepath) - return get_optimizer(cfg['name'])(**cfg.get('args', {})) + return get_optimizer(cfg['typename'])(**cfg.get('args', {})) def _save_result(filepath, result): diff --git a/tests/integration/test_serialization/serialize_model.py b/tests/integration/test_serialization/serialize_model.py index e10d4a6a..e018043d 100644 --- a/tests/integration/test_serialization/serialize_model.py +++ b/tests/integration/test_serialization/serialize_model.py @@ -55,26 +55,30 @@ def _parse_command_line_args(): def _make_optimizer(filepath): cfg = load_config(filepath) - return nn.get_optimizer(cfg['name'])(**cfg['args']) + return nn.get_optimizer(cfg['typename'])(**cfg['args']) + + +def _gen_model_def(model_file): + fmt = luchador.get_nn_conv_format() + w, h, c = WIDTH, HEIGHT, CHANNEL + shape = ( + '[null, {}, {}, {}]'.format(h, w, c) if fmt == 'NHWC' else + '[null, {}, {}, {}]'.format(c, h, w) + ) + return nn.get_model_config( + model_file, n_actions=N_ACTIONS, input_shape=shape) def _build_network(model_filepath, optimizer_filepath, initial_parameter): _LG.info('Building Q networks') dql = DeepQLearning( - model_config={ - 'name': model_filepath, - 'initial_parameter': initial_parameter, - 'input_channel': CHANNEL, - 'input_height': HEIGHT, - 'input_width': WIDTH, - }, q_learning_config={ 'discount_rate': 0.99, 'min_reward': -1.0, 'max_reward': 1.0, }, cost_config={ - 'name': 'SSE2', + 'typename': 'SSE2', 'args': { 'min_delta': -1.0, 'max_delta': 1.0 @@ -82,7 +86,8 @@ def _build_network(model_filepath, optimizer_filepath, initial_parameter): }, optimizer_config=load_config(optimizer_filepath), ) - dql.build(n_actions=N_ACTIONS) + model_def = _gen_model_def(model_filepath) + dql.build(model_def, initial_parameter) _LG.info('Syncing models') dql.sync_network() return dql diff --git a/tests/integration/test_server_client/launch_remote_env.py 
b/tests/integration/test_server_client/launch_remote_env.py index 82e1a535..bec29892 100644 --- a/tests/integration/test_server_client/launch_remote_env.py +++ b/tests/integration/test_server_client/launch_remote_env.py @@ -22,7 +22,7 @@ def _main(): 'http://localhost:{}/create'.format(args.man_port), json={ 'environment': { - 'name': 'ALEEnvironment', + 'typename': 'ALEEnvironment', 'args': { 'rom': 'breakout.bin', 'display_screen': True, diff --git a/tests/unit/nn/model_test.py b/tests/unit/nn/model_test.py index 747397f7..1c65f89f 100644 --- a/tests/unit/nn/model_test.py +++ b/tests/unit/nn/model_test.py @@ -1,3 +1,4 @@ +"""Test nn.model.util module""" from __future__ import absolute_import import unittest @@ -7,6 +8,7 @@ class UtilTest(unittest.TestCase): + """Test model [de]serialization""" longMessage = True maxDiff = None
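All of the `get_*` retrieval helpers touched by this rename (`get_agent`, `get_env`, `get_layer`, `get_cost`, `get_optimizer`, `get_initializer`) share one pattern: scan the subclasses of a base class and return the one whose `__name__` matches `typename`. A minimal self-contained sketch of that pattern follows (illustrative only; `BaseLayer` and `ReLU` here are stand-in classes, and the real helpers use `luchador.util.get_subclasses`):

```python
def get_subclasses(cls):
    """Recursively collect direct and indirect subclasses of `cls`."""
    found = []
    for subclass in cls.__subclasses__():
        found.append(subclass)
        found.extend(get_subclasses(subclass))
    return found


class BaseLayer(object):
    """Stand-in for the library's layer base class."""


class ReLU(BaseLayer):
    """Stand-in concrete layer."""


def get_layer(typename):
    """Return the Layer subclass whose class name matches `typename`."""
    for class_ in get_subclasses(BaseLayer):
        if class_.__name__ == typename:
            return class_
    raise ValueError('Unknown Layer: {}'.format(typename))


# This is how a config entry like {'typename': 'ReLU', 'args': {}} becomes
# an object: look the class up by typename, then instantiate it with args.
layer = get_layer('ReLU')()
```

Because lookup is by class name, the configuration key really identifies a *type* rather than an instance name, which is what the new `typename` key reflects.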