Skip to content
This repository was archived by the owner on Dec 6, 2023. It is now read-only.

Commit 3fc0fc5

Browse files
committed
improve coverage and support older numpy tests
1 parent 93fefc6 commit 3fc0fc5

File tree

2 files changed

+51
-8
lines changed

2 files changed

+51
-8
lines changed

polylearn/tests/test_adagrad.py

Lines changed: 20 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
from nose.tools import assert_less_equal
22

33
import numpy as np
4-
from numpy.testing import assert_array_almost_equal, assert_raises_regex
4+
from numpy.testing import assert_array_almost_equal, assert_raises
5+
6+
try:
7+
from numpy.testing import assert_raises_regex
8+
has_assert_raises_regex = True
9+
except ImportError:
10+
has_assert_raises_regex = False
511

612
import scipy.sparse as sp
713

@@ -191,11 +197,18 @@ def test_predict_sensible_error():
191197
fit_linear=False, fit_lower=None,
192198
max_iter=3, random_state=0)
193199
reg.fit(X, y)
194-
assert_raises_regex(ValueError,
195-
"Incompatible dimensions",
196-
reg.predict,
197-
X[:, :2])
198-
reg.P_ = np.transpose(reg.P_, [1, 2, 0])
199-
assert_raises_regex(ValueError, "wrong order", reg.predict, X)
200+
if has_assert_raises_regex:
201+
assert_raises_regex(ValueError,
202+
"Incompatible dimensions",
203+
reg.predict,
204+
X[:, :2])
205+
reg.P_ = np.transpose(reg.P_, [1, 2, 0])
206+
assert_raises_regex(ValueError, "wrong order", reg.predict, X)
207+
else:
208+
# if assert_raises_regex is not available, use looser test
209+
assert_raises(ValueError, reg.predict, X[:, :2])
210+
reg.P_ = np.transpose(reg.P_, [1, 2, 0])
211+
assert_raises(ValueError, reg.predict, X)
212+
200213

201214

polylearn/tests/test_factorization_machine.py

Lines changed: 31 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from nose.tools import assert_less_equal, assert_equal
77

88
import numpy as np
9-
from numpy.testing import assert_array_almost_equal
9+
from numpy.testing import assert_array_almost_equal, assert_raises
1010

1111
from sklearn.metrics import mean_squared_error
1212
from sklearn.utils.testing import assert_warns_message
@@ -340,3 +340,33 @@ def check_warm_start(degree):
340340
def test_warm_start():
341341
yield check_warm_start, 2
342342
yield check_warm_start, 3
343+
344+
345+
def test_lambdas():
346+
"""Check that +/-1 lambdas lead to better train error for even degree."""
347+
y = _poly_predict(X, P, lams, kernel="anova", degree=2)
348+
349+
est = FactorizationMachineRegressor(degree=2, n_components=5,
350+
fit_linear=False, fit_lower=None,
351+
beta=0.1, random_state=0)
352+
y_pred_ones = est.fit(X, y).predict(X)
353+
err_ones = mean_squared_error(y, y_pred_ones)
354+
355+
est.set_params(init_lambdas='random_signs')
356+
y_pred_signs = est.fit(X, y).predict(X)
357+
err_signs = mean_squared_error(y, y_pred_signs)
358+
359+
assert_less_equal(err_signs, err_ones)
360+
361+
362+
def test_unsupported_errors():
363+
y = _poly_predict(X, P, lams, kernel="anova", degree=2)
364+
365+
est = FactorizationMachineRegressor(degree=10, n_components=5,
366+
fit_linear=False, fit_lower=None,
367+
beta=0.1, random_state=0)
368+
369+
assert_raises(NotImplementedError, est.fit, X, y)
370+
371+
est.set_params(solver='adagrad', init_lambdas='random_signs')
372+
assert_raises(NotImplementedError, est.fit, X, y)

0 commit comments

Comments
 (0)