diff --git a/bubblewrap.py b/bubblewrap.py index 85071bf..e47ca17 100644 --- a/bubblewrap.py +++ b/bubblewrap.py @@ -141,6 +141,7 @@ def init_nodes(self): self.time_grad_Q = [] self.time_pred = [] self.entropy_list = [] + self.entropy_far = [] self.loss = [] self.t = 1 @@ -185,8 +186,10 @@ def single_e_step(self, x): if not self.go_fast: new_log_pred = self.log_pred_prob(self.B, self.A, self.alpha) self.pred.append(new_log_pred) - ent = entropy(self.A, self.alpha) - self.entropy_list.append(ent) + ent_present = entropy(self.A, self.alpha, future_distance=1) + ent_future = entropy(self.A, self.alpha, self.future_distance) + self.entropy_list.append(ent_present) + self.entropy_far.append(ent_future) if self.future_x is not None: future_B = self.logB_jax(self.future_x, self.mu, self.L, self.L_diag) pred_far = self.pred_ahead(future_B, self.A, self.alpha, self.future_distance) @@ -390,9 +393,9 @@ def pred_ahead(B, A, alpha, future_distance): AT = np.linalg.matrix_power(A,future_distance) return np.log(alpha @ AT @ np.exp(B) + 1e-16) -@jit -def entropy(A, alpha): - AT = np.linalg.matrix_power(A,1) +# @jit +def entropy(A, alpha, future_distance): + AT = np.linalg.matrix_power(A,future_distance) one = alpha @ AT return - np.sum(one.dot(np.log2(alpha @ AT)))