Replace separate SoftmaxOutput constructor parameters with Config struct
This should make it much easier to customize one or two of the softmax
config parameters, for example:

```D
SoftmaxOutput.Config config = { preserve_shape: true };
auto s = new SoftmaxOutput(input, label, config);
```

A usage example unittest has been added.
joseph-wakeling-sociomantic committed Dec 5, 2017
1 parent c6da7fd commit fe344ee
Showing 2 changed files with 188 additions and 50 deletions.
25 changes: 25 additions & 0 deletions relnotes/softmaxoutput.migration.md
@@ -0,0 +1,25 @@
* `mxnet.Symbol.SoftmaxOutput`

A new nested `Config` struct has been added in order to store the various
optional settings that can be used with this symbol. The constructor has
been rewritten to use an instance of this struct (initialized to default
values) rather than to require each parameter separately. This should
make it easier to instantiate `SoftmaxOutput` instances with only one or
a few custom settings, for example:

```D
SoftmaxOutput.Config config = {
    grad_scale: 1.5,
    preserve_shape: true
};
auto softmax_custom = new SoftmaxOutput(input, label, config);
```

The config parameter is entirely optional: instances created using the
default configuration via

```D
auto softmax_default = new SoftmaxOutput(input, label);
```

will behave exactly the same as before.
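
For code that previously passed non-default positional arguments, migration
amounts to moving those values into named `Config` fields. A hypothetical
before/after sketch (the argument values here are invented for illustration):

```D
// before: positional parameters in the old constructor's order
// auto s = new SoftmaxOutput(input, label,
//     SoftmaxOutputNormalization.valid, 2.0f, true, 0.0f);

// after: the same settings as named Config fields
// (listed here in field declaration order)
SoftmaxOutput.Config config = {
    grad_scale: 2.0f,
    ignore_label: 0.0f,
    use_ignore: true,
    normalization: SoftmaxOutputNormalization.valid
};
auto s = new SoftmaxOutput(input, label, config);
```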
213 changes: 163 additions & 50 deletions src/mxnet/Symbol.d
@@ -763,81 +763,171 @@ public enum SoftmaxOutputNormalization

 public class SoftmaxOutput : Symbol
 {
+    import ocean.core.Traits;
+
+    /***************************************************************************
+
+        Struct containing configuration options for the `SoftmaxOutput`
+        symbol
+
+        By default, all values will match the defaults used by the upstream
+        MXNet library. This struct only needs to be used directly if one or
+        more of these parameters is to be set to a non-default value.
+
+    ***************************************************************************/
+
+    public struct Config
+    {
+        /***********************************************************************
+
+            Scale factor for scaling the gradient
+
+        ***********************************************************************/
+
+        public float grad_scale = 1;
+
+        /***********************************************************************
+
+            Label to ignore during the backward pass, if `use_ignore` is `true`
+
+        ***********************************************************************/
+
+        public float ignore_label = -1;
+
+        /***********************************************************************
+
+            If `true`, the softmax function will be computed along the second
+            axis, i.e. `axis == 1`
+
+        ***********************************************************************/
+
+        public bool multi_output = false;
+
+        /***********************************************************************
+
+            If `true`, the `ignore_label` field can be used to specify a label
+            to ignore during the backward pass
+
+        ***********************************************************************/
+
+        public bool use_ignore = false;
+
+        /***********************************************************************
+
+            If `true`, the softmax function will be computed along the last
+            axis, i.e. `axis == -1`
+
+        ***********************************************************************/
+
+        public bool preserve_shape = false;
+
+        /***********************************************************************
+
+            Normalization applied to the gradient
+
+        ***********************************************************************/
+
+        public SoftmaxOutputNormalization normalization =
+            SoftmaxOutputNormalization.batch;
+
+        /***********************************************************************
+
+            If `true`, the gradient will be multiplied elementwise by the
+            output gradient
+
+        ***********************************************************************/
+
+        public bool out_grad = false;
+    }
+
+    ///
+    unittest
+    {
+        // initialize a `Config` struct instance with only a couple of
+        // non-default values
+        Config config = {
+            ignore_label: 1,
+            use_ignore: true
+        };
+
+        // check values of custom fields
+        assert(config.ignore_label == 1.0f);
+        assert(config.ignore_label != Config.init.ignore_label);
+
+        assert(config.use_ignore);
+        assert(config.use_ignore != Config.init.use_ignore);
+
+        // check values of non-customized fields
+        assert(config.grad_scale == Config.init.grad_scale);
+        assert(config.multi_output == Config.init.multi_output);
+        assert(config.preserve_shape == Config.init.preserve_shape);
+        assert(config.normalization == Config.init.normalization);
+        assert(config.out_grad == Config.init.out_grad);
+    }
+
     /***************************************************************************
 
         Constructs `SoftmaxOutput` symbol
 
         Params:
             input = input symbol to apply softmax to
             label = ground truth to compare against the output of softmax
-            normalization = normalization applied to the gradient; defaults to
-                batch
-            grad_scale = scale factor for scaling the gradient; defaults to 1
-            use_ignore = use ignore_label; defaults to false
-            ignore_label = all labels with this label will be ignored during
-                the backward pass; defaults to -1
-            multi_output = softmax applied to axis 1, if set to true; defaults
-                to false
-            preserve_shape = softmax will applied on the last axis, if true;
-                defaults to false
-            out_grad = apply weighting to output gradient
+            config = optional instance of Config struct containing extra
+                configuration parameters (if not specified, all will
+                be set to their default values)
 
     ***************************************************************************/

     public this (Symbol input,
                  Symbol label,
-                 SoftmaxOutputNormalization normalization = SoftmaxOutputNormalization.batch,
-                 float grad_scale = 1,
-                 bool use_ignore = false,
-                 float ignore_label = -1,
-                 bool multi_output = false,
-                 bool preserve_shape = false,
-                 bool out_grad = false)
+                 Config config = Config.init)
     in
     {
         assert(input !is null);
         assert(label !is null);
     }
     body
     {
-        char[16] buf_grad_scale = void;
-        cstring grad_scale_str = toNoLossString(grad_scale, buf_grad_scale);
-        buf_grad_scale[grad_scale_str.length] = '\0';
-
-        istring use_ignore_str = use_ignore ? "true" : "false";
-
-        char[16] buf_ignore_label = void;
-        cstring ignore_label_str = toNoLossString(ignore_label, buf_ignore_label);
-        buf_ignore_label[ignore_label_str.length] = '\0';
-
-        istring multi_output_str = multi_output ? "true" : "false";
-
-        istring preserve_shape_str = preserve_shape ? "true" : "false";
-
-        istring out_grad_str = out_grad ? "true" : "false";
-
-        istring[7] keys;
-        keys[0] = "grad_scale";
-        keys[1] = "ignore_label";
-        keys[2] = "multi_output";
-        keys[3] = "use_ignore";
-        keys[4] = "preserve_shape";
-        keys[5] = "normalization";
-        keys[6] = "out_grad";
-
-        Immut!(char)*[7] c_keys;
-        foreach (i, ref key; keys) c_keys[i] = key.ptr;
-
-        const istring[] softmax_normalizations = ["batch", "null", "valid"];
-
-        Const!(char)*[7] c_values;
-        c_values[0] = grad_scale_str.ptr;
-        c_values[1] = ignore_label_str.ptr;
-        c_values[2] = multi_output_str.ptr;
-        c_values[3] = use_ignore_str.ptr;
-        c_values[4] = preserve_shape_str.ptr;
-        c_values[5] = softmax_normalizations[normalization].ptr;
-        c_values[6] = out_grad_str.ptr;
+        const istring[] softmax_normalizations = ["batch", "null", "valid"];
+
+        alias typeof(Config.tupleof) ConfigTuple;
+
+        istring[ConfigTuple.length] keys;
+        Const!(char)*[ConfigTuple.length] c_keys;
+        Const!(char)*[ConfigTuple.length] c_values;
+
+        foreach (i, T; ConfigTuple)
+        {
+            keys[i] = FieldName!(i, Config);
+            c_keys[i] = keys[i].ptr;
+
+            auto field_value = config.tupleof[i];
+
+            static if (isFloatingPointType!(T))
+            {
+                char[16] float_buf = void;
+                cstring float_string = toNoLossString(field_value, float_buf);
+                float_buf[float_string.length] = '\0';
+                c_values[i] = float_string.ptr;
+            }
+            else static if (is(T == bool))
+            {
+                istring bool_string = field_value ? "true" : "false";
+                c_values[i] = bool_string.ptr;
+            }
+            else static if (is(T == SoftmaxOutputNormalization))
+            {
+                istring norm_string = softmax_normalizations[field_value];
+                c_values[i] = norm_string.ptr;
+            }
+            else
+            {
+                static assert(false, "Unsupported SoftmaxOutput config type "
+                    ~ T.stringof);
+            }
+        }
 
         super("SoftmaxOutput", c_keys, c_values);

@@ -860,6 +950,29 @@ public class SoftmaxOutput : Symbol
     }
 }
 
+///
+unittest
+{
+    // set up input and label variables
+    scope input = new Variable("x");
+    scope (exit) input.freeHandle();
+    scope label = new Variable("y");
+    scope (exit) label.freeHandle();
+
+    // create SoftmaxOutput symbol using default configuration
+    scope softmax_default = new SoftmaxOutput(input, label);
+    scope (exit) softmax_default.freeHandle();
+
+    // create SoftmaxOutput symbol with custom configuration
+    SoftmaxOutput.Config config = {
+        grad_scale: 1.5,
+        preserve_shape: true
+    };
+    scope softmax_custom = new SoftmaxOutput(input, label, config);
+    scope (exit) softmax_custom.freeHandle();
+}
+
 
 /*******************************************************************************
 
     A symbol representing linear regression using least squares
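
The rewritten constructor no longer spells out each key by hand: it iterates
over `Config.tupleof`, recovers each field's name with ocean's `FieldName`,
and dispatches on the field's type to produce the value string. Below is a
minimal self-contained sketch of the same technique using only standard
(Phobos) D; the struct, the `configToStrings` name, and the plain `string`
output are assumptions for illustration, since the real code emits
NUL-terminated C strings for the MXNet C API:

```D
import std.conv : to;
import std.traits : FieldNameTuple;

// two-field stand-in for SoftmaxOutput.Config (illustrative only)
struct Config
{
    float grad_scale = 1;
    bool use_ignore = false;
}

// Build parallel key/value string arrays from the struct's fields,
// mirroring the tupleof-driven approach of the new constructor
// (minus the NUL termination needed by the MXNet C API)
void configToStrings(Config config, out string[] keys, out string[] values)
{
    // this foreach is unrolled at compile time, so i is a constant
    // and can index the FieldNameTuple
    foreach (i, field; config.tupleof)
    {
        keys ~= FieldNameTuple!Config[i];
        values ~= to!string(field);
    }
}

unittest
{
    string[] keys, values;
    Config config = { grad_scale: 1.5 };
    configToStrings(config, keys, values);
    assert(keys == ["grad_scale", "use_ignore"]);
    assert(values == ["1.5", "false"]);
}
```

Because the loop is unrolled at compile time, adding a field to `Config`
automatically adds the corresponding key/value pair with no further changes.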
