Skip to content

feat: add latent encoder/decoder infrastructure for Graph-EFM port#648

Draft
Sir-Sloth-The-Lazy wants to merge 30 commits into
mllam:mainfrom
Sir-Sloth-The-Lazy:feat/latent-encoder-decoder-infra
Draft

feat: add latent encoder/decoder infrastructure for Graph-EFM port#648
Sir-Sloth-The-Lazy wants to merge 30 commits into
mllam:mainfrom
Sir-Sloth-The-Lazy:feat/latent-encoder-decoder-infra

Conversation

@Sir-Sloth-The-Lazy

@Sir-Sloth-The-Lazy Sir-Sloth-The-Lazy commented May 27, 2026

Copy link
Copy Markdown
Contributor

Describe your changes

Adds neural_lam/models/latent/ the encoder and decoder submodules that the probabilistic Graph-EFM model needs. This is infrastructure-only: no model uses these classes yet. They are consumed by the upcoming GraphEFMPredictor (StepPredictor subclass) which will close #62.

New modules in neural_lam/models/latent/:

File What it adds
base_encoder.py BaseLatentEncoder : abstract base; handles isotropic / diagonal Gaussian output
base_decoder.py BaseGraphLatentDecoder : abstract base; residual grid MLP + latent embedder + param map
constant_encoder.py ConstantLatentEncoder : input-independent prior (used when learn_prior=False)
graph_encoder.py GraphLatentEncoder : flat graph: grid → mesh via PropagationNet + InteractionNet stack
graph_decoder.py GraphLatentDecoder : flat graph: grid + latent → grid via g2m / processor / m2g
hi_graph_encoder.py HiGraphLatentEncoder : hierarchical mesh: propagates up to top level, reads out latent dist
hi_graph_decoder.py HiGraphLatentDecoder : hierarchical mesh: up + latent fusion + down pass back to grid

Adaptations from prob_model_lam for the current main architecture:

  • constants.GRID_STATE_DIM (removed) → num_state_vars constructor arg on all decoders
  • from neural_lam.interaction_net import ...from neural_lam.gnn_layers import ...
  • GraphLatentDecoder.processor unified with the other four GNN-seq constructions to use utils.make_gnn_seq, which also handles processor_layers=0 gracefully
  • HiGraph{Encoder,Decoder} raise ValueError for single-level meshes (where the latent would be silently ignored); points users to the flat variants

Also adds to neural_lam/utils.py:

  • IdentityModule pass-through nn.Module for multi-arg pyg.nn.Sequential pipelines
  • make_gnn_seq builds a pyg.nn.Sequential of InteractionNet layers, or IdentityModule when num_gnn_layers=0; lazy-imports gnn_layers to avoid the existing gnn_layers → utils circular dependency

Open question flagged in constant_encoder.py docstring: the static prior returns Normal(mean=1, std=1) faithful to prob_model_lam but the --learn_prior CLI help on that branch describes it as "mean 0". One of the two is wrong; will raise it separately with @joeloskarsson.

Dependencies:

Issue Link

Partially addresses #62 (prerequisite for the GraphEFMPredictor PR).

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe its purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form
  • I have requested a reviewer and an assignee

Checklist for reviewers

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • the PR is assigned to the next milestone
  • author has added an entry to the changelog

@Sir-Sloth-The-Lazy

Copy link
Copy Markdown
Contributor Author

Pinging @joeloskarsson !

Comment thread neural_lam/models/forecaster_module.py Outdated
Comment thread neural_lam/models/latent/__init__.py
@joeloskarsson joeloskarsson added this to the v0.7.0 (proposed) milestone Jun 4, 2026

@joeloskarsson joeloskarsson left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Had a look over the encoders now, decoders still TODO :)

):
super().__init__(latent_dim, output_dist)

self.g2m_gnn = PropagationNet(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think also here the gnn type should be set with the corresponding argparse flag.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've switched it to resolve the GNN via get_gnn_class() and added a g2m_gnn_type parameter wired to the existing --g2m_gnn_type flag, matching the deterministic models. Two things I'd like your opinion on before I finalize:

Default value. The original prob_model_lam encoder used PropagationNet for the g2m step, but --g2m_gnn_type defaults to InteractionNet. I've defaulted the new parameter to InteractionNet for consistency with the flag and the rest of the codebase, which does change the encoder's default behavior vs. the original port. Are you happy with that, or would you prefer keeping PropagationNet as the default here?

Shared vs. separate flag. Wiring this to --g2m_gnn_type means the latent encoder's g2m GNN and the step predictor's g2m GNN share a single flag and can't be configured independently. In the original Graph-EFM these differed (predictor g2m = InteractionNet, encoder g2m = PropagationNet). Do you think sharing the one flag is fine, or should the encoder get its own flag (e.g. --latent_g2m_gnn_type) to preserve that independence?

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, these are all very good questions. I feel like we are now hitting the point where the way we configure the GNN types is quite limited. But I am reluctant to here introduce some more complex system for configuring this.
In general all configurations are fine, the only constraint is that the mesh up-edges in the hierarchical encoders should really use propnets (they need to push information into the latent Z) and the up-edges in the decoder should really use inet (otherwise there is no use of Z at initialization).

I think the best solution might be to hard-code the GNN types mentioned above (this only affects the mesh-up edges), and then reuse the existing flags for all choices of which type to use for g2m/m2g. We can keep the InteractionNet default also for these encoders/decoder, to be consistent with the rest of the codebase.

@Sir-Sloth-The-Lazy Sir-Sloth-The-Lazy Jun 12, 2026

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've implemented this in 1396357:

Hard-coded GNN types (no longer parameters):

  • HiGraphLatentEncoder mesh-up edges: PropagationNet as you said, these need to push grid information up into the latent readout.
  • HiGraphLatentDecoder mesh-up edges: InteractionNet with a PropagationNet the residual path would carry the aggregated messages instead of the top-level receiver rep (= Z), leaving Z unused at initialization.
  • HiGraphLatentDecoder mesh-down edges: PropagationNet I hard-coded these too, by the symmetric argument: the downward pass has to push Z from the top level down through the hierarchy so it reaches the grid output. This matches the original Graph-EFM. Let me know if you'd rather keep these configurable since your comment only strictly required fixing the up-edges.
    I added short comments at each hard-coded site explaining the constraint, so future readers don't "helpfully" parameterize them again.

Everything else reuses the existing flags with the codebase-standard defaults:

  • g2m in both encoders and the decoders, and m2g in the decoders, remain configurable via g2m_gnn_type/m2g_gnn_type, now all defaulting to InteractionNet for consistency with the deterministic models (this changes the flat encoder's g2m and the decoders' m2g defaults vs. the original port, per your suggestion).
  • GraphEFM accepts g2m_gnn_type/m2g_gnn_type and passes them through to the prior, encoder and decoder, so they'll wire directly to the existing --g2m_gnn_type/--m2g_gnn_type flags once the model is registered. There's no mesh_up_gnn_type/mesh_down_gnn_type on GraphEFM since those are now fixed; the existing flags keep affecting Hi-LAM only.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great, that's perfect! Can we also expand on the description for these flags in train_model.py so that the user understand what the different flags impacts and not. Sufficient to add a note for mesh_up_gnn_type/mesh_down_gnn_type that these do not apply to the probabilistic Graph-EFM model.

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget this thread, still think we should add this to train_model.py

Comment thread neural_lam/models/latent/graph_encoder.py Outdated
Comment thread neural_lam/models/latent/constant_encoder.py Outdated
Comment thread neural_lam/models/latent/constant_encoder.py Outdated
Comment thread neural_lam/utils.py Outdated
Comment thread neural_lam/utils.py Outdated
Comment thread neural_lam/models/latent/hi_graph_encoder.py Outdated
@joeloskarsson joeloskarsson self-assigned this Jun 4, 2026
@Sir-Sloth-The-Lazy

Sir-Sloth-The-Lazy commented Jun 5, 2026

Copy link
Copy Markdown
Contributor Author

On it !

@Sir-Sloth-The-Lazy

Copy link
Copy Markdown
Contributor Author

I have pushed the latest changes , please have a look whenever you have time

@joeloskarsson joeloskarsson left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good work with fixing the earlier things! I now did a complete readthrough of all of this, which resulted in a number of comments, some small, some more substantial. There are also a few that require more discussion, so do expect us to have some back and forth around these.

As you are fixing these things, please leave me some comments very briefly explaining the fix and link to the commit that fixes it (a good reason to fix one thing in each commit :)).

Comment thread neural_lam/models/forecaster_module.py Outdated
):
super().__init__(latent_dim, output_dist)

self.g2m_gnn = PropagationNet(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, these are all very good questions. I feel like we are now hitting the point where the way we configure the GNN types is quite limited. But I am reluctant to here introduce some more complex system for configuring this.
In general all configurations are fine, the only constraint is that the mesh up-edges in the hierarchical encoders should really use propnets (they need to push information into the latent Z) and the up-edges in the decoder should really use inet (otherwise there is no use of Z at initialization).

I think the best solution might be to hard-code the GNN types mentioned above (this only affects the mesh-up edges), and then reuse the existing flags for all choices of which type to use for g2m/m2g. We can keep the InteractionNet default also for these encoders/decoder, to be consistent with the rest of the codebase.

Comment thread neural_lam/models/latent/constant_encoder.py Outdated
Comment thread neural_lam/utils.py
Comment thread neural_lam/models/latent/graph_encoder.py Outdated
Comment thread neural_lam/models/step_predictors/graph/graph_efm.py Outdated
Comment thread neural_lam/models/step_predictors/graph/graph_efm.py Outdated
Comment thread neural_lam/models/step_predictors/graph/graph_efm.py Outdated
Comment thread neural_lam/models/step_predictors/graph/graph_efm.py Outdated
Comment thread neural_lam/models/step_predictors/graph/graph_efm.py

@joeloskarsson joeloskarsson left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good job with all the nice fixes! I went over some of the previous comments and resolved most of them. Still have some left, but submitting these comments for now just so you can see them.

Comment thread neural_lam/models/step_predictors/graph/graph_efm.py Outdated
Comment thread neural_lam/models/step_predictors/graph/graph_efm.py Outdated
Comment thread neural_lam/models/step_predictors/graph/graph_efm.py Outdated
Comment thread neural_lam/models/step_predictors/graph/graph_efm.py Outdated
Comment thread neural_lam/models/step_predictors/graph/graph_efm.py
Comment thread neural_lam/models/step_predictors/graph/graph_efm.py
Comment thread neural_lam/utils.py Outdated

@joeloskarsson joeloskarsson left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now I have gone over everything, and resolved or commented on things needing further attention 😄This is coming along super nicely I think, and especially nice that we are finding some shortcomings of both the earlier graph-efm code and in the repo, and able to resolve much of that.

Most things left are smaller fixes. The bigger question is how this will fit with the larger Module/Forecaster framework. This should be the priority to hammer out now. I have also not yet looked much at the methods in GraphEFM that deal with loss computation, simply because I expect these might change a bit once we figure out the details of the probabilistic model training interface.

Comment thread neural_lam/utils.py Outdated
Comment thread neural_lam/models/step_predictors/graph/graph_efm.py Outdated
@Sir-Sloth-The-Lazy

Sir-Sloth-The-Lazy commented Jun 18, 2026

Copy link
Copy Markdown
Contributor Author

Thank you @joeloskarsson for the review 🤩. I would be happy to take the work of utils refactor. I will be opening another PR soon solving that issue

Sir-Sloth-The-Lazy and others added 11 commits June 21, 2026 17:08
Adds neural_lam/models/latent/ with the encoder and decoder submodules
needed by the probabilistic GraphEFM model (issue mllam#62). Ported from the
prob_model_lam branch with adaptations for the current main architecture:

- constants.GRID_STATE_DIM replaced by a num_state_vars constructor arg
- interaction_net imports updated to neural_lam.gnn_layers
- GraphLatentDecoder.processor unified with the other four GNN-seq
  constructions to use utils.make_gnn_seq (handles processor_layers=0)
- HiGraph{Encoder,Decoder} guard against single-level meshes where the
  latent variable would be silently ignored
- ConstantLatentEncoder docstring documents the N(1,1) vs N(0,1)
  discrepancy with the prob_model_lam CLI help (open question upstream)

Also adds to neural_lam/utils.py:
- IdentityModule: pass-through nn.Module for multi-arg sequential GNNs
- make_gnn_seq: builds a pyg.nn.Sequential of InteractionNets, or an
  IdentityModule when num_gnn_layers=0; lazy-imports gnn_layers to
  avoid the existing gnn_layers -> utils circular dependency

17 tests in tests/test_latent_modules.py cover output shapes,
distribution properties, backprop to every parameter, 2- and 3-level
hierarchical graphs, intra_level_layers=0, and the single-level guard.
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
Make GNN types configurable and tidy up the latent modules per PR review:

- make_gnn_seq: accept a gnn_type arg (resolved via get_gnn_class) so it is
  not limited to InteractionNet, and make it strict (raise on
  num_gnn_layers < 1) instead of silently returning an IdentityModule;
  callers now own the no-op (identity) case explicitly.
- graph/hi encoders and decoders: expose g2m/m2g/mesh_up/mesh_down gnn_type
  parameters wired to get_gnn_class, with defaults matching prob_model_lam.
- graph encoder/decoder: rename processor_layers -> m2m_layers (and the
  self.processor attribute -> self.m2m_gnns); "processor" was misleading in
  an encoder/decoder context.
- ConstantLatentEncoder: return zeros instead of ones so the static prior is
  mean 0 (fixes the prob_model_lam mean-1 bug; matches its own CLI help).
- tests: update for the renamed arg and strict make_gnn_seq, add coverage for
  the flat zero-m2m identity path, and assert the constant prior is N(0, 1).
Port prob_model_lam's GraphEFM single-step half onto the StepPredictor
interface, reusing the latent encoder/decoder infra. The predictor owns
its conditional prior, variational encoder, and latent decoder, plus the
per-step ELBO pieces (compute_step_loss) and sampling helpers; rollout,
ELBO assembly, ensemble logic, and logging stay outside it.

- forward is source's predict_step (prior rsample -> decode -> sampled
  next state); no rescaling/clamping
- loss_fn and interior_mask are threaded parameters, not predictor state;
  compute_step_loss takes compute_kl (kl_term=None when off)
- per_var_std mirrors ForecasterModule's formula, hence the config arg
- one class for flat + hierarchical meshes, resolved from self.hierarchical
- not registered in MODELS yet (needs config / no mesh_aggr); config-aware
  assembly deferred to the ensemble-forecaster PR

Adds tests/test_graph_efm_predictor.py covering forward shapes, output_std,
compute_step_loss + KL toggle, differentiability, member stochasticity,
sample_obs_noise, and the per_var_std formula (flat + hierarchical).
…s for the rest

Per review discussion: the architecturally constrained edge sets in the
hierarchical latent modules get fixed GNN types instead of parameters:
- HiGraphLatentEncoder mesh-up: PropagationNet (must push grid info up
  into the latent readout)
- HiGraphLatentDecoder mesh-up: InteractionNet (PropagationNet residual
  would bypass Z at the top level, leaving it unused at initialization)
- HiGraphLatentDecoder mesh-down: PropagationNet (must push Z down the
  hierarchy to reach the grid output)

All remaining choices (g2m/m2g) stay configurable and default to
InteractionNet for consistency with the rest of the codebase. GraphEFM
now accepts g2m_gnn_type/m2g_gnn_type and passes them through to the
prior, encoder and decoder, ready for wiring to the existing argparse
flags.
Upstream main added an interrogate pre-commit hook requiring 100%
docstring coverage, which failed on this branch's CI after merging.

- Remove the branch's pre-reorganization duplicates (forecaster.py,
  ar_forecaster.py, step_predictor.py, forecaster_module.py); main
  carries the same code under models/forecasters/, models/
  step_predictors/ and models/module.py, and all imports already go
  through the new layout.
- Add the missing module and __init__ docstrings (numpy style) in the
  latent modules, GraphEFM and utils.IdentityModule.forward.
Remove references to the original prob_model_lam implementation and
other work meta-information from docstrings and comments, per review.
Docstrings now describe what each class/function does; usage context
is left to call sites.
Add proper Parameters/Returns sections following the numpydoc
convention, per review.
…dentityModule

When m2m_layers / intra_level_layers is 0, the latent modules now set
the corresponding GNN attribute to None and skip the update in the
forward pass, instead of routing representations through a no-op
IdentityModule. This makes it clear from the forward code that no
processing happens in that case. IdentityModule is removed from utils.

The hierarchical up/down loops index levels explicitly to accommodate
the conditional; outputs are unchanged (verified bit-identical against
the previous implementation).
…base class

Use the base class summary and expand on it with the constant-specific
behavior, per review.
Sir-Sloth-The-Lazy and others added 12 commits June 21, 2026 17:08
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
Describe the role of each representation in the message passing
(sender/receiver, where the latent enters, purpose of the residual
grid rep) in BaseGraphLatentDecoder and the inheriting decoders, per
review.
Replace the single GraphEFM class that resolved flat vs hierarchical at
construction with an explicit class per graph type, per review:

- BaseGraphEFM: graph-type independent setup (graph loading, grid and
  grid-mesh edge embedders, per-variable std) and all shared behavior
  (forward, compute_step_loss, estimate_likelihood, sampling helpers).
  Validates the loaded graph against the subclass's
  requires_hierarchical and exposes an embedd_mesh hook used by
  embedd_all.
- GraphEFM: hierarchical mesh graphs; builds per-level mesh embedders
  and HiGraphLatentEncoder/HiGraphLatentDecoder modules.
- GraphEFMMS: flat (e.g. multi-scale) mesh graphs; builds the flat mesh
  embedders and GraphLatentEncoder/GraphLatentDecoder modules.

learn_prior remains a constructor flag on both subclasses. Tests select
the class per graph type and cover the graph-type mismatch error.
Rename the layer-count parameters to say what the layers are, matching
the latent module parameter names: prior/encoder/decoder_intra_level_
layers on GraphEFM (hierarchical) and prior/encoder/decoder_m2m_layers
on GraphEFMMS (flat), per review.
Make the GraphEFM and GraphEFMMS __init__ docstrings self-contained
with the full parameter list, instead of pointing at the base class for
the shared ones, per review.
Sampling uncorrelated Gaussian observation noise per grid node is not
useful in practice, so the option is removed entirely rather than left
to tempt users, per review. forward now returns the decoder mean
directly (the prediction is stochastic only through the latent sample)
and the trivial sample_next_state helper is dropped.
BaseGraphModel and BaseGraphEFM duplicated their graph loading,
buffer registration and grid-input-dim computation. Factor these into
two utils helpers used by both:

- utils.load_and_register_graph(module, datastore, graph_name): loads
  the graph and registers its tensors/BufferLists on the module,
  returning whether it is hierarchical.
- utils.grid_input_dim(datastore, grid_static_dim, num_past_forcing_
  steps, num_future_forcing_steps): the total grid input dimensionality.

This keeps the two model families' grid-feature setup in one place
(e.g. for a future boundary-forcing input) without coupling their
differing forward passes or submodule sets via inheritance.
Rename embedd_all -> embedd_grid_and_graph (embeds the grid for states
up to t-1 plus the full graph) and embedd_current ->
embedd_grid_with_target (embeds the grid including the target state, for
the encoder), per review.
In forward, prev_state is already X_t, so pass it directly to the
decoder instead of aliasing it to last_state, per review.
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
Move the constant-prior branch (identical across subclasses) into
BaseGraphEFM.build_prior and delegate the graph-specific learnable prior
to build_learnable_prior, which GraphEFM and GraphEFMMS implement.
Spell out the multi-scale variant's name instead of the EFMMS
abbreviation, keeping the GraphEFM prefix shared with the hierarchical
variant.
…_dim

The static feature count is already available from the datastore via
get_num_data_vars(category="static"), so drop the redundant
grid_static_dim argument and query it inside the function.
get_num_data_vars("static") can report a nonzero count even when the
datastore provides no static dataarray, in which case the grid static
buffer is empty. Mirror the buffer construction by treating a None static
dataarray as zero static features, fixing the no-static-features case.

@joeloskarsson joeloskarsson left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice work on this! Only one small thing with the build_prior setup, and a reminders about the two unresolved threads from above.

The thing left now is just the loss/likelihood part. But I think that will have to wait until we have started to settle #685. I will look at that next.

# inert -- accepted for interface parity with other StepPredictors.
self.prepare_clamping_params(datastore)

def build_prior(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we actually need this function? It seems to be called identically by both subclasses, so could we not just do the prior setup immediately in the base class constructor, calling self.build_learnable_prior just as here, which will be the part handled by the subclasses.

):
super().__init__(latent_dim, output_dist)

self.g2m_gnn = PropagationNet(

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't forget this thread, still think we should add this to train_model.py

Comment thread neural_lam/utils.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Merge Graph-EFM model from prob_model_lam branch

2 participants