Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also compare across forks. Learn more about diff comparisons here.
base repository: lisa-lab/pylearn2
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: b1f8029188a1ef26b5546b70301ce601302478ec
Choose a base ref
..
head repository: lisa-lab/pylearn2
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: bfd1d94087e71e0f9364cb2c0ec4a10da6cd2532
Choose a head ref
4 changes: 0 additions & 4 deletions pylearn2/datasets/adult.py
Original file line number Diff line number Diff line change
@@ -98,7 +98,3 @@ def adult(which_set):
X = np.concatenate(pieces, axis=1)

return DenseDesignMatrix(X=X, y=y)

if __name__ == "__main__":
adult(which_set='train')
adult(which_set='test')
3 changes: 3 additions & 0 deletions pylearn2/datasets/tests/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""
Dataset testing classes
"""
24 changes: 24 additions & 0 deletions pylearn2/datasets/tests/test_adult.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
"""
Test code for adult.py
"""
import numpy
from pylearn2.datasets.adult import adult
from pylearn2.testing.skip import skip_if_no_data


def test_adult():
    """
    Load the train and test splits of the adult dataset and check them.

    For each split the features must be non-negative, the targets must be
    boolean, and the design matrix / target shapes must match the
    expected sizes.
    """
    skip_if_no_data()
    expected_x_shapes = (
        ('train', (30162, 104)),
        ('test', (15060, 103)),
    )
    for split_name, x_shape in expected_x_shapes:
        split = adult(which_set=split_name)
        assert (split.X >= 0.).all()
        assert split.y.dtype == bool
        assert split.X.shape == x_shape
        # One target row per example.
        assert split.y.shape == (x_shape[0], 1)
32 changes: 32 additions & 0 deletions pylearn2/datasets/tests/test_avicenna.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""module for testing datasets.avicenna"""
import unittest
import numpy as np
from pylearn2.datasets.avicenna import Avicenna
from pylearn2.testing.skip import skip_if_no_data


def test_avicenna():
    """
    Test that the train/valid/test sets load, with and without
    standardization.

    Each split is loaded with ``standardize=False`` and then with
    ``standardize=True``; in both cases the design matrix shape is
    checked. The standardized splits are finally concatenated and their
    overall mean and standard deviation are compared to 0 and 1.
    """
    skip_if_no_data()
    expected_shapes = (
        ('train', (150205, 120)),
        ('valid', (4096, 120)),
        ('test', (4096, 120)),
    )

    for which_set, shape in expected_shapes:
        data = Avicenna(which_set=which_set, standardize=False)
        assert data.X.shape == shape

    # Check the shape of each *standardized* split itself (not a dataset
    # left over from the loop above) and keep its features for the
    # global statistics below.
    standardized = []
    for which_set, shape in expected_shapes:
        data = Avicenna(which_set=which_set, standardize=True)
        assert data.X.shape == shape
        standardized.append(data.X)

    dt = np.concatenate(standardized, axis=0)
    assert np.allclose(dt.mean(), 0)
    assert np.allclose(dt.std(), 1.)
124 changes: 124 additions & 0 deletions pylearn2/datasets/tests/test_cifar100.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,124 @@
"""Test for cifar100 dataset module"""

import unittest
import numpy as np
from pylearn2.datasets.cifar100 import CIFAR100
from pylearn2.space import Conv2DSpace
from pylearn2.testing.skip import skip_if_no_data


class TestCIFAR100(unittest.TestCase):
    """
    Unit tests for the CIFAR100 dataset class.

    ``setUp`` calls ``skip_if_no_data()`` and then loads the train and
    test sets that the individual tests operate on.
    """

    def setUp(self):
        """Load the train and test sets; check for nan and inf."""
        skip_if_no_data()
        self.train_set = CIFAR100(which_set='train')
        self.test_set = CIFAR100(which_set='test')
        for dataset in (self.train_set, self.test_set):
            assert not np.any(np.isnan(dataset.X))
            assert not np.any(np.isinf(dataset.X))

    def _first_topo_batch(self, dataset, axes, batch_size):
        """
        Return the first batch of a sequential iterator over `dataset`
        whose data_specs request a 32x32, 3-channel Conv2DSpace laid out
        with the given `axes`.
        """
        iterator = dataset.iterator(
            mode='sequential',
            batch_size=batch_size,
            data_specs=(Conv2DSpace(shape=(32, 32),
                                    num_channels=3,
                                    axes=axes),
                        'features'))
        # Use the builtin next() rather than the .next() method: the
        # builtin works on both Python 2 and Python 3 iterators.
        return next(iterator)

    def test_adjust_for_viewer(self):
        """Check that adjust_for_viewer runs on the training features."""
        self.train_set.adjust_for_viewer(self.train_set.X)

    def test_adjust_to_be_viewed_with(self):
        """Check that adjust_to_be_viewed_with runs on the train set."""
        self.train_set.adjust_to_be_viewed_with(
            self.train_set.X,
            np.ones(self.train_set.X.shape))

    def test_get_test_set(self):
        """
        Check that the train and test sets'
        get_test_set methods return same thing.
        """
        train_test_set = self.train_set.get_test_set()
        test_test_set = self.test_set.get_test_set()
        assert np.all(train_test_set.X == test_test_set.X)
        assert np.all(train_test_set.X == self.test_set.X)

    def test_topo(self):
        """Tests that a topological batch has 4 dimensions"""
        topo = self.train_set.get_batch_topo(1)
        assert topo.ndim == 4

    def test_topo_c01b(self):
        """
        Tests that a topological batch with axes ('c',0,1,'b')
        can be dimshuffled back to match the standard ('b',0,1,'c')
        format.
        """
        batch_size = 100
        c01b_test = CIFAR100(which_set='test', axes=('c', 0, 1, 'b'))
        c01b_X = c01b_test.X[0:batch_size, :]
        c01b = c01b_test.get_topological_view(c01b_X)
        assert c01b.shape == (3, 32, 32, batch_size)
        b01c = c01b.transpose(3, 1, 2, 0)
        b01c_X = self.test_set.X[0:batch_size, :]
        assert c01b_X.shape == b01c_X.shape
        assert np.all(c01b_X == b01c_X)
        b01c_direct = self.test_set.get_topological_view(b01c_X)
        assert b01c_direct.shape == b01c.shape
        assert np.all(b01c_direct == b01c)

    def test_iterator(self):
        """
        Tests that batches returned by an iterator with topological
        data_specs are the same as the ones returned by calling
        get_topological_view on the dataset with the corresponding order
        """
        batch_size = 100
        b01c_axes = ('b', 0, 1, 'c')
        c01b_axes = ('c', 0, 1, 'b')

        b01c_X = self.test_set.X[0:batch_size, :]
        b01c_topo = self.test_set.get_topological_view(b01c_X)
        b01c_b01c = self._first_topo_batch(self.test_set, b01c_axes,
                                           batch_size)
        assert np.all(b01c_topo == b01c_b01c)

        c01b_test = CIFAR100(which_set='test', axes=c01b_axes)
        c01b_X = c01b_test.X[0:batch_size, :]
        c01b_topo = c01b_test.get_topological_view(c01b_X)
        c01b_c01b = self._first_topo_batch(c01b_test, c01b_axes, batch_size)
        assert np.all(c01b_topo == c01b_c01b)

        # Also check that samples from iterators with the same data_specs
        # with Conv2DSpace do not depend on the axes of the dataset
        b01c_c01b = self._first_topo_batch(self.test_set, c01b_axes,
                                           batch_size)
        assert np.all(b01c_c01b == c01b_c01b)

        c01b_b01c = self._first_topo_batch(c01b_test, b01c_axes, batch_size)
        assert np.all(c01b_b01c == b01c_b01c)
15 changes: 15 additions & 0 deletions pylearn2/datasets/tests/test_cos_dataset.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
"""Test code for cos dataset."""
import numpy
from pylearn2.datasets.cos_dataset import CosDataset
from pylearn2.testing.skip import skip_if_no_data


def test_cos_dataset():
    """
    Draw a design-matrix batch from CosDataset and check that it has two
    columns and that its first column stays within the dataset's
    [min_x, max_x] bounds.
    """
    skip_if_no_data()
    n_samples = 10000
    generator = CosDataset()

    batch = generator.get_batch_design(batch_size=n_samples)
    assert batch.shape == (n_samples, 2)

    first_column = batch[:, 0]
    assert first_column.min() >= generator.min_x
    assert first_column.max() <= generator.max_x
13 changes: 13 additions & 0 deletions pylearn2/datasets/tests/test_hepatitis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""module for testing datasets.hepatitis"""
import numpy as np
import pylearn2.datasets.hepatitis as hepatitis
from pylearn2.testing.skip import skip_if_no_data


def test_hepatitis():
    """
    Load the hepatitis dataset and check that its design matrix exists
    and contains no infinite or NaN entries.
    """
    skip_if_no_data()
    data = hepatitis.Hepatitis()
    assert data.X is not None
    # np.isinf catches both +inf and -inf.
    assert not np.any(np.isinf(data.X))
    # NaN never compares equal to anything, so `X != np.nan` is always
    # True and cannot detect NaNs; np.isnan must be used instead.
    assert not np.any(np.isnan(data.X))
13 changes: 13 additions & 0 deletions pylearn2/datasets/tests/test_iris.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
"""module for testing datasets.iris"""
import numpy as np
import pylearn2.datasets.iris as iris
from pylearn2.testing.skip import skip_if_no_data


def test_iris():
    """
    Load the iris dataset and check that its design matrix exists and
    contains no infinite or NaN entries.
    """
    skip_if_no_data()
    data = iris.Iris()
    assert data.X is not None
    # np.isinf catches both +inf and -inf.
    assert not np.any(np.isinf(data.X))
    # NaN never compares equal to anything, so `X != np.nan` is always
    # True and cannot detect NaNs; np.isnan must be used instead.
    assert not np.any(np.isnan(data.X))
27 changes: 27 additions & 0 deletions pylearn2/datasets/tests/test_mnist_rotated.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
"""
Testing class that simply checks to see if the class is
loadable
"""
from pylearn2.datasets.mnist import MNIST_rotated_background
from pylearn2.datasets.tests.test_mnist import TestMNIST
from pylearn2.testing.skip import skip_if_no_data


class TestMNIST_rotated(TestMNIST):
    """
    Runs the tests inherited from TestMNIST against the
    MNIST_rotated_background dataset, by loading that dataset in setUp.
    """

    def setUp(self):
        """
        Load the 'train' and 'test' splits of MNIST_rotated_background
        and store them as ``self.train`` and ``self.test``.
        """
        skip_if_no_data()
        for split in ('train', 'test'):
            setattr(self, split, MNIST_rotated_background(which_set=split))
100 changes: 100 additions & 0 deletions pylearn2/datasets/tests/test_ocr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""module for testing datasets.ocr"""
import unittest
import numpy as np
from pylearn2.datasets.ocr import OCR
from pylearn2.space import Conv2DSpace
from pylearn2.testing.skip import skip_if_no_data


class TestOCR(unittest.TestCase):
    """
    Unit tests for the OCR dataset class.

    ``setUp`` calls ``skip_if_no_data()`` and then loads the train,
    valid and test sets that the individual tests operate on.
    """

    def setUp(self):
        """Load train, test, valid sets"""
        skip_if_no_data()
        self.train = OCR(which_set='train')
        self.valid = OCR(which_set='valid')
        self.test = OCR(which_set='test')

    def _first_topo_batch(self, dataset, axes, batch_size):
        """
        Return the first batch of a sequential iterator over `dataset`
        whose data_specs request a 16x8, 1-channel Conv2DSpace laid out
        with the given `axes`.
        """
        iterator = dataset.iterator(
            mode='sequential',
            batch_size=batch_size,
            data_specs=(Conv2DSpace(shape=(16, 8),
                                    num_channels=1,
                                    axes=axes),
                        'features'))
        # Use the builtin next() rather than the .next() method: the
        # builtin works on both Python 2 and Python 3 iterators.
        return next(iterator)

    def test_topo(self):
        """Tests that a topological batch has 4 dimensions"""
        topo = self.train.get_batch_topo(1)
        assert topo.ndim == 4

    def test_topo_c01b(self):
        """
        Tests that a topological batch with axes ('c',0,1,'b')
        can be dimshuffled back to match the standard ('b',0,1,'c')
        format.
        """
        batch_size = 100
        c01b_test = OCR(which_set='test', axes=('c', 0, 1, 'b'))
        c01b_X = c01b_test.X[0:batch_size, :]
        c01b = c01b_test.get_topological_view(c01b_X)
        assert c01b.shape == (1, 16, 8, batch_size)
        b01c = c01b.transpose(3, 1, 2, 0)
        b01c_X = self.test.X[0:batch_size, :]
        assert c01b_X.shape == b01c_X.shape
        assert np.all(c01b_X == b01c_X)
        b01c_direct = self.test.get_topological_view(b01c_X)
        assert b01c_direct.shape == b01c.shape
        assert np.all(b01c_direct == b01c)

    def test_iterator(self):
        """
        Tests that batches returned by an iterator with topological
        data_specs are the same as the ones returned by calling
        get_topological_view on the dataset with the corresponding order
        """
        batch_size = 100
        b01c_axes = ('b', 0, 1, 'c')
        c01b_axes = ('c', 0, 1, 'b')

        b01c_X = self.test.X[0:batch_size, :]
        b01c_topo = self.test.get_topological_view(b01c_X)
        b01c_b01c = self._first_topo_batch(self.test, b01c_axes, batch_size)
        assert np.all(b01c_topo == b01c_b01c)

        c01b_test = OCR(which_set='test', axes=c01b_axes)
        c01b_X = c01b_test.X[0:batch_size, :]
        c01b_topo = c01b_test.get_topological_view(c01b_X)
        c01b_c01b = self._first_topo_batch(c01b_test, c01b_axes, batch_size)
        assert np.all(c01b_topo == c01b_c01b)

        # Also check that samples from iterators with the same data_specs
        # with Conv2DSpace do not depend on the axes of the dataset
        b01c_c01b = self._first_topo_batch(self.test, c01b_axes, batch_size)
        assert np.all(b01c_c01b == c01b_c01b)

        c01b_b01c = self._first_topo_batch(c01b_test, b01c_axes, batch_size)
        assert np.all(c01b_b01c == b01c_b01c)
1 change: 1 addition & 0 deletions pylearn2/models/dbm/layer.py
Original file line number Diff line number Diff line change
@@ -2062,6 +2062,7 @@ def __init__(self,
bias_from_marginals = None,
beta_lr_scale = 'by_sharing',
axes = ('b', 0, 1, 'c')):
super(type(self), self).__init__()

warnings.warn("GaussianVisLayer math very faith based, need to finish working through gaussian.lyx")

11 changes: 9 additions & 2 deletions pylearn2/models/mlp.py
Original file line number Diff line number Diff line change
@@ -296,9 +296,16 @@ def set_biases(self, biases):

def get_weights_format(self):
    """
    Returns a description of how to interpret the weights of the layer.

    Returns
    -------
    format : tuple
        Either ('v', 'h') or ('h', 'v').
        ('v', 'h') means a weight matrix of shape
        (num visible units, num hidden units),
        while ('h', 'v') means the transpose of it.

    Raises
    ------
    NotImplementedError
        Always raised by this implementation.
    """
    raise NotImplementedError

Loading