Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 13 additions & 12 deletions rim/rim.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,17 @@ def model_score_fn(x, y, *args, **kwargs):
self.model_score_fn = model_score_fn
self.energy_fn = energy_fn

def initialization(self, observation) -> tuple[list[Tensor], Tensor, Tensor]:
def initialization(self, observation, *args, **kwargs) -> tuple[list[Tensor], Tensor, Tensor]:
"""
From an observation, initialize the parameters to be inferred, x, and the
hidden states of the recurrent neural net.
hidden states of the recurrent neural net.

Args:
observation (Tensor): The observation used for parameter initialization.

*args, **kwargs: Additional arguments passed to approximate_inverse_fn

Returns:
tuple[list[Tensor], Tensor, Tensor]: A list for the optimization trajectories,
tuple[list[Tensor], Tensor, Tensor]: A list for the optimization trajectories,
the initialized parameters x, and the initialized hidden states h.
"""
batch_size = observation.shape[0]
Expand All @@ -161,7 +162,7 @@ def initialization(self, observation) -> tuple[list[Tensor], Tensor, Tensor]:
if self.initialization_method == "zeros":
x = torch.zeros((batch_size, self.C, *self.dimensions)).to(self.device)
elif self.initialization_method == "approximate_inverse":
x_param = self.approximate_inverse_fn(observation)
x_param = self.approximate_inverse_fn(observation, *args, **kwargs) # Pass through args
x = self.inverse_link_function(x_param)
elif self.initialization_method == "model":
x = torch.zeros((batch_size, self.C, *self.dimensions)).to(self.device)
Expand All @@ -179,15 +180,15 @@ def initialization(self, observation) -> tuple[list[Tensor], Tensor, Tensor]:
def forward(self, y: Tensor, *args, **kwargs) -> list[Tensor]:
"""
Perform the forward pass of the RIM optimization.

Args:
y (Tensor): The observation used for the optimization.
args and kwargs: Additional arguments and keyword arguments for the score function.

Returns:
list[Tensor]: The RIM optimization trajectories, represented as a list of parameter x at every iteration of the recurrent series.
"""
out, x, h = self.initialization(y)
out, x, h = self.initialization(y, *args, **kwargs) # Pass args to initialization
for t in range(self.T):
with torch.no_grad():
score = self.model_score_fn(x, y, *args, **kwargs)
Expand All @@ -200,13 +201,13 @@ def forward(self, y: Tensor, *args, **kwargs) -> list[Tensor]:
def predict(self, y: Tensor, *args, **kwargs) -> Tensor:
"""
Perform a prediction using the RIM optimization.

Args:
y (Tensor): The observation used for the prediction.
args and kwargs: Additional arguments and keyword arguments for the score function.

Returns:
Tensor: The predicted parameter x after the RIM optimization (in physical space).
Tensor: The predicted parameter x after the RIM optimization (in physical space).
"""
return self.link_function(self.forward(y, *args, **kwargs)[-1])

Expand Down