diff --git a/sparselm/tools.py b/sparselm/tools.py index d6103f7..98df200 100644 --- a/sparselm/tools.py +++ b/sparselm/tools.py @@ -7,7 +7,7 @@ import numpy as np -def constrain_coefficients(indices, high, low=0.0): +def constrain_coefficients(indices, high=None, low=None): """Constrain a fit method to keep coefficients within a specified range. Decorator to enforce that a fit method fitting a cluster expansion that @@ -30,8 +30,14 @@ def your_fit_method(X, y): """ indices = np.array(indices) - high = high * np.ones(len(indices)) if isinstance(high, float) else np.array(high) - low = low * np.ones(len(indices)) if isinstance(low, float) else np.array(low) + if high is not None: + high = high * np.ones(len(indices)) if isinstance(high, float) else np.array(high) + else: + high = np.inf * np.ones(len(indices)) + if low is not None: + low = low * np.ones(len(indices)) if isinstance(low, float) else np.array(low) + else: + low = -np.inf * np.ones(len(indices)) def decorate_fit_method(fit_method): """Decorate a fit method to constrain "dielectric constant".