Skip to content

Commit 82f9bcf

Browse files
author
Paul Masurel
committed
1549: Added quantile regression as possible final MLP Layer.
closes #1549
1 parent a303ec1 commit 82f9bcf

File tree

2 files changed

+83
-2
lines changed

2 files changed

+83
-2
lines changed

pylearn2/models/mlp.py

+46
Original file line number | Diff line number | Diff line change
@@ -3824,6 +3824,52 @@ def L1WeightDecay(*args, **kwargs):
38243824
return _L1WD(*args, **kwargs)
38253825

38263826

3827+
class QuantileRegression(Linear):
    """
    A linear layer for quantile regression.

    A QuantileRegression (http://en.wikipedia.org/wiki/Quantile_regression)
    is a linear layer that uses a specific cost that makes it possible to get
    an estimator of a specific percentile of a posterior distribution.

    The layer always has a single output unit. With the residual
    r = Y - Y_hat, the elementwise cost is

        abs(r) * (0.5 + (percentile - 0.5) * sgn(r))

    which equals percentile * r when r > 0 and (1 - percentile) * abs(r)
    when r < 0.

    Parameters
    ----------
    layer_name : str
        The layer name
    percentile : float (0 < percentile < 1)
        Percentile being estimated.
    kargs : dict
        Remaining keyword arguments, forwarded unchanged to
        `Linear.__init__`.
    """

    def __init__(self,
                 layer_name,
                 percentile=0.2,
                 **kargs):
        # The output dimension is fixed to 1: one scalar estimate per example.
        Linear.__init__(self, 1, layer_name, **kargs)
        self.percentile = percentile

    @wraps(Layer.get_layer_monitoring_channels)
    def get_layer_monitoring_channels(self,
                                      state_below=None,
                                      state=None,
                                      targets=None):
        # Start from the channels reported by the parent Linear layer and
        # add the mean quantile cost under the 'qcost' key.
        rval = Linear.get_layer_monitoring_channels(
            self,
            state_below,
            state,
            targets)
        assert isinstance(rval, OrderedDict)
        # Test for presence explicitly: `targets` defaults to None, and the
        # truth value of a tensor/array object is not a presence check.
        if targets is not None:
            # Reuse cost_matrix so the monitored value always matches the
            # cost that is actually optimised.
            rval['qcost'] = self.cost_matrix(targets, state).mean()
        return rval

    @wraps(Layer.cost_matrix)
    def cost_matrix(self, Y, Y_hat):
        # Asymmetric absolute error: positive residuals (under-prediction)
        # are weighted by `percentile`, negative ones by `1 - percentile`.
        residual = Y - Y_hat
        weight = 0.5 + (self.percentile - 0.5) * T.sgn(residual)
        return T.abs_(residual) * weight
3871+
3872+
38273873
class LinearGaussian(Linear):
38283874

38293875
"""

pylearn2/models/tests/test_mlp.py

+37-2
Original file line number | Diff line number | Diff line change
@@ -24,12 +24,12 @@
2424
exhaustive_dropout_average,
2525
sampled_dropout_average, CompositeLayer,
2626
max_pool, mean_pool, pool_dnn,
27-
SigmoidConvNonlinearity, ConvElemwise)
27+
SigmoidConvNonlinearity, ConvElemwise,
28+
QuantileRegression)
2829
from pylearn2.space import VectorSpace, CompositeSpace, Conv2DSpace
2930
from pylearn2.utils import is_iterable, sharedX
3031
from pylearn2.expr.nnet import pseudoinverse_softmax_numpy
3132

32-
3333
class IdentityLayer(Linear):
3434
dropout_input_mask_value = -np.inf
3535

@@ -1389,3 +1389,38 @@ def test_pooling_with_anon_variable():
13891389
image_shape=im_shp, try_dnn=False)
13901390
pool_1 = mean_pool(X_sym, pool_shape=shp, pool_stride=strd,
13911391
image_shape=im_shp)
1392+
1393+
1394+
def test_quantile_regression():
    """
    Check that a QuantileRegression layer recovers the requested quantile.

    The targets are a fixed linear function of two inputs (weights 3 and 4)
    plus noise drawn with `np.random.rand`. For each tested percentile a
    single-layer MLP is trained for 100 epochs, and the test asserts that
    the learned weights are close to the generating coefficients and that
    the learned bias is close to the percentile itself.
    """
    np.random.seed(2)
    nb_rows = 1000
    X = np.random.normal(size=(nb_rows, 2)).astype(theano.config.floatX)
    noise = np.random.rand(nb_rows, 1)
    coeffs = np.array([[3.], [4.]])
    y_0 = np.dot(X, coeffs)
    y = y_0 + noise
    dataset = DenseDesignMatrix(X=X, y=y)
    for percentile in [0.22, 0.5, 0.65]:
        mlp = MLP(
            nvis=2,
            layers=[
                QuantileRegression('quantile_regression_layer',
                                   init_bias=0.0,
                                   percentile=percentile,
                                   irange=0.1)
            ]
        )
        train = Train(dataset, mlp, SGD(0.05, batch_size=100))
        train.algorithm.termination_criterion = EpochCounter(100)
        train.main_loop()

        # Make sure the trained model compiles and yields one prediction
        # per input row.
        inputs = mlp.get_input_space().make_theano_batch()
        outputs = mlp.fprop(inputs)
        predict = theano.function([inputs], outputs,
                                  allow_input_downcast=True)
        predictions = predict(X)
        assert predictions.shape == (nb_rows, 1)

        layer = mlp.layers[0]
        assert np.allclose(layer.get_weights(), coeffs, rtol=0.05)
        assert np.allclose(layer.get_biases(), np.array(percentile),
                           rtol=0.05)

0 commit comments

Comments
 (0)