feat: add latent encoder/decoder infrastructure for Graph-EFM port#648
feat: add latent encoder/decoder infrastructure for Graph-EFM port#648Sir-Sloth-The-Lazy wants to merge 30 commits into
Conversation
|
Pinging @joeloskarsson ! |
joeloskarsson
left a comment
There was a problem hiding this comment.
Had a look over the encoders now, decoders still TODO :)
| ): | ||
| super().__init__(latent_dim, output_dist) | ||
|
|
||
| self.g2m_gnn = PropagationNet( |
There was a problem hiding this comment.
I think also here the gnn type should be set with the corresponding argparse flag.
There was a problem hiding this comment.
I've switched it to resolve the GNN via get_gnn_class() and added a g2m_gnn_type parameter wired to the existing --g2m_gnn_type flag, matching the deterministic models. Two things I'd like your opinion on before I finalize:
Default value. The original prob_model_lam encoder used PropagationNet for the g2m step, but --g2m_gnn_type defaults to InteractionNet. I've defaulted the new parameter to InteractionNet for consistency with the flag and the rest of the codebase, which does change the encoder's default behavior vs. the original port. Are you happy with that, or would you prefer keeping PropagationNet as the default here?
Shared vs. separate flag. Wiring this to --g2m_gnn_type means the latent encoder's g2m GNN and the step predictor's g2m GNN share a single flag and can't be configured independently. In the original Graph-EFM these differed (predictor g2m = InteractionNet, encoder g2m = PropagationNet). Do you think sharing the one flag is fine, or should the encoder get its own flag (e.g. --latent_g2m_gnn_type) to preserve that independence?
There was a problem hiding this comment.
Hmm, these are all very good questions. I feel like we are now hitting the point where the way we configure the GNN types is quite limited. But I am reluctant to here introduce some more complex system for configuring this.
In general all configurations are fine, the only constraint is that the mesh up-edges in the hierarchical encoders should really use propnets (they need to push information into the latent Z) and the up-edges in the decoder should really use inet (otherwise there is no use of Z at initialization).
I think the best solution might be to hard-code the GNN types mentioned above (this only affects the mesh-up edges), and then reuse the existing flags for all choices of which type to use for g2m/m2g. We can keep the InteractionNet default also for these encoders/decoder, to be consistent with the rest of the codebase.
There was a problem hiding this comment.
I've implemented this in 1396357:
Hard-coded GNN types (no longer parameters):
HiGraphLatentEncodermesh-up edges:PropagationNetas you said, these need to push grid information up into the latent readout.HiGraphLatentDecodermesh-up edges:InteractionNetwith a PropagationNet the residual path would carry the aggregated messages instead of the top-level receiver rep (= Z), leaving Z unused at initialization.HiGraphLatentDecodermesh-down edges:PropagationNetI hard-coded these too, by the symmetric argument: the downward pass has to push Z from the top level down through the hierarchy so it reaches the grid output. This matches the original Graph-EFM. Let me know if you'd rather keep these configurable since your comment only strictly required fixing the up-edges.
I added short comments at each hard-coded site explaining the constraint, so future readers don't "helpfully" parameterize them again.
Everything else reuses the existing flags with the codebase-standard defaults:
- g2m in both encoders and the decoders, and m2g in the decoders, remain configurable via
g2m_gnn_type/m2g_gnn_type, now all defaulting toInteractionNetfor consistency with the deterministic models (this changes the flat encoder's g2m and the decoders' m2g defaults vs. the original port, per your suggestion). GraphEFMacceptsg2m_gnn_type/m2g_gnn_typeand passes them through to the prior, encoder and decoder, so they'll wire directly to the existing--g2m_gnn_type/--m2g_gnn_typeflags once the model is registered. There's nomesh_up_gnn_type/mesh_down_gnn_typeonGraphEFMsince those are now fixed; the existing flags keep affecting Hi-LAM only.
There was a problem hiding this comment.
Great, that's perfect! Can we also expand on the description for these flags in train_model.py so that the user understand what the different flags impacts and not. Sufficient to add a note for mesh_up_gnn_type/mesh_down_gnn_type that these do not apply to the probabilistic Graph-EFM model.
There was a problem hiding this comment.
Don't forget this thread, still think we should add this to train_model.py
|
On it ! |
|
I have pushed the latest changes , please have a look whenever you have time |
joeloskarsson
left a comment
There was a problem hiding this comment.
Good work with fixing the earlier things! I now did a complete readthrough of all of this, which resulted in a number of comments, some small, some more substantial. There are also a few that require more discussion, so do expect us to have some back and forth around these.
As you are fixing these things, please leave me some comments very briefly explaining the fix and link to the commit that fixes it (a good reason to fix one thing in each commit :)).
| ): | ||
| super().__init__(latent_dim, output_dist) | ||
|
|
||
| self.g2m_gnn = PropagationNet( |
There was a problem hiding this comment.
Hmm, these are all very good questions. I feel like we are now hitting the point where the way we configure the GNN types is quite limited. But I am reluctant to here introduce some more complex system for configuring this.
In general all configurations are fine, the only constraint is that the mesh up-edges in the hierarchical encoders should really use propnets (they need to push information into the latent Z) and the up-edges in the decoder should really use inet (otherwise there is no use of Z at initialization).
I think the best solution might be to hard-code the GNN types mentioned above (this only affects the mesh-up edges), and then reuse the existing flags for all choices of which type to use for g2m/m2g. We can keep the InteractionNet default also for these encoders/decoder, to be consistent with the rest of the codebase.
joeloskarsson
left a comment
There was a problem hiding this comment.
Good job with all the nice fixes! I went over some of the previous comments and resolved most of them. Still have some left, but submitting these comments for now just so you can see them.
joeloskarsson
left a comment
There was a problem hiding this comment.
Now I have gone over everything, and resolved or commented on things needing further attention 😄This is coming along super nicely I think, and especially nice that we are finding some shortcomings of both the earlier graph-efm code and in the repo, and able to resolve much of that.
Most things left are smaller fixes. The bigger question is how this will fit with the larger Module/Forecaster framework. This should be the priority to hammer out now. I have also not yet looked much at the methods in GraphEFM that deal with loss computation, simply because I expect these might change a bit once we figure out the details of the probabilistic model training interface.
|
Thank you @joeloskarsson for the review 🤩. I would be happy to take the work of utils refactor. I will be opening another PR soon solving that issue |
Adds neural_lam/models/latent/ with the encoder and decoder submodules needed by the probabilistic GraphEFM model (issue mllam#62). Ported from the prob_model_lam branch with adaptations for the current main architecture: - constants.GRID_STATE_DIM replaced by a num_state_vars constructor arg - interaction_net imports updated to neural_lam.gnn_layers - GraphLatentDecoder.processor unified with the other four GNN-seq constructions to use utils.make_gnn_seq (handles processor_layers=0) - HiGraph{Encoder,Decoder} guard against single-level meshes where the latent variable would be silently ignored - ConstantLatentEncoder docstring documents the N(1,1) vs N(0,1) discrepancy with the prob_model_lam CLI help (open question upstream) Also adds to neural_lam/utils.py: - IdentityModule: pass-through nn.Module for multi-arg sequential GNNs - make_gnn_seq: builds a pyg.nn.Sequential of InteractionNets, or an IdentityModule when num_gnn_layers=0; lazy-imports gnn_layers to avoid the existing gnn_layers -> utils circular dependency 17 tests in tests/test_latent_modules.py cover output shapes, distribution properties, backprop to every parameter, 2- and 3-level hierarchical graphs, intra_level_layers=0, and the single-level guard.
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
Make GNN types configurable and tidy up the latent modules per PR review: - make_gnn_seq: accept a gnn_type arg (resolved via get_gnn_class) so it is not limited to InteractionNet, and make it strict (raise on num_gnn_layers < 1) instead of silently returning an IdentityModule; callers now own the no-op (identity) case explicitly. - graph/hi encoders and decoders: expose g2m/m2g/mesh_up/mesh_down gnn_type parameters wired to get_gnn_class, with defaults matching prob_model_lam. - graph encoder/decoder: rename processor_layers -> m2m_layers (and the self.processor attribute -> self.m2m_gnns); "processor" was misleading in an encoder/decoder context. - ConstantLatentEncoder: return zeros instead of ones so the static prior is mean 0 (fixes the prob_model_lam mean-1 bug; matches its own CLI help). - tests: update for the renamed arg and strict make_gnn_seq, add coverage for the flat zero-m2m identity path, and assert the constant prior is N(0, 1).
Port prob_model_lam's GraphEFM single-step half onto the StepPredictor interface, reusing the latent encoder/decoder infra. The predictor owns its conditional prior, variational encoder, and latent decoder, plus the per-step ELBO pieces (compute_step_loss) and sampling helpers; rollout, ELBO assembly, ensemble logic, and logging stay outside it. - forward is source's predict_step (prior rsample -> decode -> sampled next state); no rescaling/clamping - loss_fn and interior_mask are threaded parameters, not predictor state; compute_step_loss takes compute_kl (kl_term=None when off) - per_var_std mirrors ForecasterModule's formula, hence the config arg - one class for flat + hierarchical meshes, resolved from self.hierarchical - not registered in MODELS yet (needs config / no mesh_aggr); config-aware assembly deferred to the ensemble-forecaster PR Adds tests/test_graph_efm_predictor.py covering forward shapes, output_std, compute_step_loss + KL toggle, differentiability, member stochasticity, sample_obs_noise, and the per_var_std formula (flat + hierarchical).
…s for the rest Per review discussion: the architecturally constrained edge sets in the hierarchical latent modules get fixed GNN types instead of parameters: - HiGraphLatentEncoder mesh-up: PropagationNet (must push grid info up into the latent readout) - HiGraphLatentDecoder mesh-up: InteractionNet (PropagationNet residual would bypass Z at the top level, leaving it unused at initialization) - HiGraphLatentDecoder mesh-down: PropagationNet (must push Z down the hierarchy to reach the grid output) All remaining choices (g2m/m2g) stay configurable and default to InteractionNet for consistency with the rest of the codebase. GraphEFM now accepts g2m_gnn_type/m2g_gnn_type and passes them through to the prior, encoder and decoder, ready for wiring to the existing argparse flags.
Upstream main added an interrogate pre-commit hook requiring 100% docstring coverage, which failed on this branch's CI after merging. - Remove the branch's pre-reorganization duplicates (forecaster.py, ar_forecaster.py, step_predictor.py, forecaster_module.py); main carries the same code under models/forecasters/, models/ step_predictors/ and models/module.py, and all imports already go through the new layout. - Add the missing module and __init__ docstrings (numpy style) in the latent modules, GraphEFM and utils.IdentityModule.forward.
Remove references to the original prob_model_lam implementation and other work meta-information from docstrings and comments, per review. Docstrings now describe what each class/function does; usage context is left to call sites.
Add proper Parameters/Returns sections following the numpydoc convention, per review.
…dentityModule When m2m_layers / intra_level_layers is 0, the latent modules now set the corresponding GNN attribute to None and skip the update in the forward pass, instead of routing representations through a no-op IdentityModule. This makes it clear from the forward code that no processing happens in that case. IdentityModule is removed from utils. The hierarchical up/down loops index levels explicitly to accommodate the conditional; outputs are unchanged (verified bit-identical against the previous implementation).
…base class Use the base class summary and expand on it with the constant-specific behavior, per review.
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
Describe the role of each representation in the message passing (sender/receiver, where the latent enters, purpose of the residual grid rep) in BaseGraphLatentDecoder and the inheriting decoders, per review.
Replace the single GraphEFM class that resolved flat vs hierarchical at construction with an explicit class per graph type, per review: - BaseGraphEFM: graph-type independent setup (graph loading, grid and grid-mesh edge embedders, per-variable std) and all shared behavior (forward, compute_step_loss, estimate_likelihood, sampling helpers). Validates the loaded graph against the subclass's requires_hierarchical and exposes an embedd_mesh hook used by embedd_all. - GraphEFM: hierarchical mesh graphs; builds per-level mesh embedders and HiGraphLatentEncoder/HiGraphLatentDecoder modules. - GraphEFMMS: flat (e.g. multi-scale) mesh graphs; builds the flat mesh embedders and GraphLatentEncoder/GraphLatentDecoder modules. learn_prior remains a constructor flag on both subclasses. Tests select the class per graph type and cover the graph-type mismatch error.
Rename the layer-count parameters to say what the layers are, matching the latent module parameter names: prior/encoder/decoder_intra_level_ layers on GraphEFM (hierarchical) and prior/encoder/decoder_m2m_layers on GraphEFMMS (flat), per review.
Make the GraphEFM and GraphEFMMS __init__ docstrings self-contained with the full parameter list, instead of pointing at the base class for the shared ones, per review.
Sampling uncorrelated Gaussian observation noise per grid node is not useful in practice, so the option is removed entirely rather than left to tempt users, per review. forward now returns the decoder mean directly (the prediction is stochastic only through the latent sample) and the trivial sample_next_state helper is dropped.
BaseGraphModel and BaseGraphEFM duplicated their graph loading, buffer registration and grid-input-dim computation. Factor these into two utils helpers used by both: - utils.load_and_register_graph(module, datastore, graph_name): loads the graph and registers its tensors/BufferLists on the module, returning whether it is hierarchical. - utils.grid_input_dim(datastore, grid_static_dim, num_past_forcing_ steps, num_future_forcing_steps): the total grid input dimensionality. This keeps the two model families' grid-feature setup in one place (e.g. for a future boundary-forcing input) without coupling their differing forward passes or submodule sets via inheritance.
Rename embedd_all -> embedd_grid_and_graph (embeds the grid for states up to t-1 plus the full graph) and embedd_current -> embedd_grid_with_target (embeds the grid including the target state, for the encoder), per review.
In forward, prev_state is already X_t, so pass it directly to the decoder instead of aliasing it to last_state, per review.
Co-authored-by: Joel Oskarsson <joel.oskarsson@outlook.com>
d0ef03a to
99d86c8
Compare
Move the constant-prior branch (identical across subclasses) into BaseGraphEFM.build_prior and delegate the graph-specific learnable prior to build_learnable_prior, which GraphEFM and GraphEFMMS implement.
Spell out the multi-scale variant's name instead of the EFMMS abbreviation, keeping the GraphEFM prefix shared with the hierarchical variant.
…_dim The static feature count is already available from the datastore via get_num_data_vars(category="static"), so drop the redundant grid_static_dim argument and query it inside the function.
get_num_data_vars("static") can report a nonzero count even when the
datastore provides no static dataarray, in which case the grid static
buffer is empty. Mirror the buffer construction by treating a None static
dataarray as zero static features, fixing the no-static-features case.
joeloskarsson
left a comment
There was a problem hiding this comment.
Nice work on this! Only one small thing with the build_prior setup, and a reminders about the two unresolved threads from above.
The thing left now is just the loss/likelihood part. But I think that will have to wait until we have started to settle #685. I will look at that next.
| # inert -- accepted for interface parity with other StepPredictors. | ||
| self.prepare_clamping_params(datastore) | ||
|
|
||
| def build_prior( |
There was a problem hiding this comment.
Do we actually need this function? It seems to be called identically by both subclasses, so could we not just do the prior setup immediately in the base class constructor, calling self.build_learnable_prior just as here, which will be the part handled by the subclasses.
| ): | ||
| super().__init__(latent_dim, output_dist) | ||
|
|
||
| self.g2m_gnn = PropagationNet( |
There was a problem hiding this comment.
Don't forget this thread, still think we should add this to train_model.py
Describe your changes
Adds
neural_lam/models/latent/the encoder and decoder submodules that the probabilistic Graph-EFM model needs. This is infrastructure-only: no model uses these classes yet. They are consumed by the upcomingGraphEFMPredictor(StepPredictorsubclass) which will close #62.New modules in
neural_lam/models/latent/:base_encoder.pyBaseLatentEncoder: abstract base; handles isotropic / diagonal Gaussian outputbase_decoder.pyBaseGraphLatentDecoder: abstract base; residual grid MLP + latent embedder + param mapconstant_encoder.pyConstantLatentEncoder: input-independent prior (used whenlearn_prior=False)graph_encoder.pyGraphLatentEncoder: flat graph: grid → mesh via PropagationNet + InteractionNet stackgraph_decoder.pyGraphLatentDecoder: flat graph: grid + latent → grid via g2m / processor / m2ghi_graph_encoder.pyHiGraphLatentEncoder: hierarchical mesh: propagates up to top level, reads out latent disthi_graph_decoder.pyHiGraphLatentDecoder: hierarchical mesh: up + latent fusion + down pass back to gridAdaptations from
prob_model_lamfor the currentmainarchitecture:constants.GRID_STATE_DIM(removed) →num_state_varsconstructor arg on all decodersfrom neural_lam.interaction_net import ...→from neural_lam.gnn_layers import ...GraphLatentDecoder.processorunified with the other four GNN-seq constructions to useutils.make_gnn_seq, which also handlesprocessor_layers=0gracefullyHiGraph{Encoder,Decoder}raiseValueErrorfor single-level meshes (where the latent would be silently ignored); points users to the flat variantsAlso adds to
neural_lam/utils.py:IdentityModulepass-throughnn.Modulefor multi-argpyg.nn.Sequentialpipelinesmake_gnn_seqbuilds apyg.nn.SequentialofInteractionNetlayers, orIdentityModulewhennum_gnn_layers=0; lazy-importsgnn_layersto avoid the existinggnn_layers → utilscircular dependencyOpen question flagged in
constant_encoder.pydocstring: the static prior returnsNormal(mean=1, std=1)faithful toprob_model_lambut the--learn_priorCLI help on that branch describes it as "mean 0". One of the two is wrong; will raise it separately with @joeloskarsson.Dependencies:
PropagationNetandInteractionNetare already onmain(AddsPropagationNetGNN layer and makes it optionally usable in existing deterministic models #507 merged).Issue Link
Partially addresses #62 (prerequisite for the
GraphEFMPredictorPR).Type of change
Checklist before requesting a review
Checklist for reviewers
Author checklist after completed review
Checklist for assignee