diff --git a/rim/rim.py b/rim/rim.py index 37717b5..aadcc6f 100644 --- a/rim/rim.py +++ b/rim/rim.py @@ -143,16 +143,17 @@ def model_score_fn(x, y, *args, **kwargs): self.model_score_fn = model_score_fn self.energy_fn = energy_fn - def initialization(self, observation) -> tuple[list[Tensor], Tensor, Tensor]: + def initialization(self, observation, *args, **kwargs) -> tuple[list[Tensor], Tensor, Tensor]: """ From an observation, initialize the parameters to be inferred, x, and the - hidden states of the recurrent neural net. - + hidden states of the recurrent neural net. + Args: observation (Tensor): The observation used for parameter initialization. - + *args, **kwargs: Additional arguments passed to approximate_inverse_fn + Returns: - tuple[list[Tensor], Tensor, Tensor]: A list for the optimization trajectories, + tuple[list[Tensor], Tensor, Tensor]: A list for the optimization trajectories, the initialized parameters x, and the initialized hidden states h. """ batch_size = observation.shape[0] @@ -161,7 +162,7 @@ def initialization(self, observation) -> tuple[list[Tensor], Tensor, Tensor]: if self.initialization_method == "zeros": x = torch.zeros((batch_size, self.C, *self.dimensions)).to(self.device) elif self.initialization_method == "approximate_inverse": - x_param = self.approximate_inverse_fn(observation) + x_param = self.approximate_inverse_fn(observation, *args, **kwargs) # Pass through args x = self.inverse_link_function(x_param) elif self.initialization_method == "model": x = torch.zeros((batch_size, self.C, *self.dimensions)).to(self.device) @@ -179,15 +180,15 @@ def initialization(self, observation) -> tuple[list[Tensor], Tensor, Tensor]: def forward(self, y: Tensor, *args, **kwargs) -> list[Tensor]: """ Perform the forward pass of the RIM optimization. - + Args: y (Tensor): The observation used for the optimization. args and kwargs: Additional arguments and keyword arguments for the score function. - + Returns: list[Tensor]: The RIM optimization trajectories, represented as a list of parameter x at every iteration of the recurrent series. """ - out, x, h = self.initialization(y) + out, x, h = self.initialization(y, *args, **kwargs) # Pass args to initialization for t in range(self.T): with torch.no_grad(): score = self.model_score_fn(x, y, *args, **kwargs) @@ -200,13 +201,13 @@ def forward(self, y: Tensor, *args, **kwargs) -> list[Tensor]: def predict(self, y: Tensor, *args, **kwargs) -> Tensor: """ Perform a prediction using the RIM optimization. - + Args: y (Tensor): The observation used for the prediction. args and kwargs: Additional arguments and keyword arguments for the score function. - + Returns: - Tensor: The predicted parameter x after the RIM optimization (in physical space). + Tensor: The predicted parameter x after the RIM optimization (in physical space). """ return self.link_function(self.forward(y, *args, **kwargs)[-1])