From 98979e42b9325742509dd08903f8f49f89462abf Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Fri, 8 Mar 2019 15:42:20 +0100 Subject: [PATCH] WIP: allow the loss to be a tuple in the Learner1D --- adaptive/learner/learner1D.py | 39 +++++++++++++++++++++++++++++------ 1 file changed, 33 insertions(+), 6 deletions(-) diff --git a/adaptive/learner/learner1D.py b/adaptive/learner/learner1D.py index 27af1c184..228d969a3 100644 --- a/adaptive/learner/learner1D.py +++ b/adaptive/learner/learner1D.py @@ -69,6 +69,28 @@ def _wrapped(loss_per_interval): return _wrapped +def loss_returns(return_type, return_length): + def _wrapped(loss_per_interval): + loss_per_interval.return_type = return_type + loss_per_interval.return_length = return_length + return loss_per_interval + return _wrapped + + +def inf_format(return_type, return_len=None): + is_iterable = hasattr(return_type, '__iter__') + if is_iterable: + return return_type(return_len * [np.inf]) + else: + return return_type(np.inf) + + +def ensure_tuple(x): + if not isinstance(x, Iterable): + x = (x,) + return x + + @uses_nth_neighbors(0) def uniform_loss(xs, ys): """Loss function that samples the domain uniformly. @@ -287,7 +309,8 @@ def npoints(self): def loss(self, real=True): losses = self.losses if real else self.losses_combined if not losses: - return np.inf + return inf_format(self.loss_per_interval.return_type, + self.loss_per_interval.return_length) max_interval, max_loss = losses.peekitem(0) return max_loss @@ -325,7 +348,7 @@ def _get_loss_in_interval(self, x_left, x_right): ys_scaled = tuple(self._scale_y(y) for y in ys) # we need to compute the loss for this interval - return self.loss_per_interval(xs_scaled, ys_scaled) + return ensure_tuple(self.loss_per_interval(xs_scaled, ys_scaled)) def _update_interpolated_loss_in_interval(self, x_left, x_right): if x_left is None or x_right is None: @@ -379,13 +402,17 @@ def _update_losses(self, x, real=True): left_loss_is_unknown = ((x_left is None) or (not real and x_right is None)) if (a is not None) and left_loss_is_unknown: - self.losses_combined[a, x] = float('inf') + self.losses_combined[a, x] = inf_format( + self.loss_per_interval.return_type, + self.loss_per_interval.return_length) # (no real point right of x) or (no real point left of b) right_loss_is_unknown = ((x_right is None) or (not real and x_left is None)) if (b is not None) and right_loss_is_unknown: - self.losses_combined[x, b] = float('inf') + self.losses_combined[x, b] = inf_format( + self.loss_per_interval.return_type, + self.loss_per_interval.return_length) @staticmethod def _find_neighbors(x, neighbors): @@ -660,8 +687,8 @@ def _set_data(self, data): def loss_manager(x_scale): def sort_key(ival, loss): - loss, ival = finite_loss(ival, loss, x_scale) - return -loss, ival + loss = [-finite_loss(ival, l, x_scale)[0] for l in loss] + return loss, ival sorted_dict = sortedcollections.ItemSortedDict(sort_key) return sorted_dict