Fix graph LAM GNN kwarg handling#688
Conversation
Co-authored-by: Hurricane <hurricane@hermes.local>
There was a problem hiding this comment.
Appreciate the fast fix and the regression tests for checkpoint loading.
cc: @joeloskarsson
joeloskarsson
left a comment
There was a problem hiding this comment.
Thanks, a few comments.
| "neural_lam.train_model.torch.load", | ||
| return_value={"hyper_parameters": {"args": args}}, | ||
| ), | ||
| patch("neural_lam.train_model.MODELS", {"hi_lam": DummyPredictor}), |
There was a problem hiding this comment.
Can we test this also for graph_lam, as that would have caught the original crash? (now that load_forecaster_module_from_checkpoint actually uses all args).
| output_clamping_upper: dict[str, float] | None = None, | ||
| g2m_gnn_type: str = "InteractionNet", | ||
| m2g_gnn_type: str = "InteractionNet", | ||
| **_kwargs: object, |
There was a problem hiding this comment.
To be consistent with rest of codebase
| **_kwargs: object, | |
| **kwargs: object, |
joeloskarsson
left a comment
There was a problem hiding this comment.
Oops, hit the wrong box.
sadamov
left a comment
There was a problem hiding this comment.
I would avoid **kwargs on the constructors here. It swallows typos in programmatic use, removes the very error that just caught this, and contradicts requirement #1 of #672 (explicit APIs). It also doesn't fix the loader.
Instead one small build_predictor helper that adds the hierarchical kwargs only for hierarchical models. Signatures stay explicit, both call sites de-duplicated, both bugs fixed. Good enough until #672 lands, without running the risk of having lingering **kwargs in the codebase. What do you think @joeloskarsson @gitcommit90?
Make the kwargs explicit like this:
def build_predictor(predictor_class, args, config, datastore):
kwargs = dict(datastore=datastore, graph_name=args.graph, ...,
g2m_gnn_type=args.g2m_gnn_type, m2g_gnn_type=args.m2g_gnn_type)
if issubclass(predictor_class, BaseHiGraphModel):
kwargs["mesh_up_gnn_type"] = args.mesh_up_gnn_type
kwargs["mesh_down_gnn_type"] = args.mesh_down_gnn_type
return predictor_class(**kwargs)
|
I am also noting that this was a clear gap in our test-suite that we should definitely fill with this PR. |
|
To me the specialized instantiation for different model classes doesn't feel so maintainable, but I can also see the potential issues with the kwargs approach. Overall the way we instantiate models in For a fix right now I don't have strong opinions, so we can go with something like @sadamov's suggestion. |
- Route training and checkpoint reload through build_predictor - Omit hierarchical mesh GNN kwargs for graph_lam - Remove GraphLAM **_kwargs swallow - Add regression test for graph_lam kwargs
Describe your changes
Fix
graph_lamstartup by allowingGraphLAMto ignore hierarchical-only GNN kwargs.Restore saved GNN type kwargs when loading checkpoints, with defaults for older checkpoints.
No new dependencies.
Issue Link
closes #686
Type of change
Checklist before requesting a review
pullwith--rebaseoption if possible).Checklist for reviewers
Each PR comes with its own improvements and flaws. The reviewer should check the following:
Author checklist after completed review
reflecting type of change (add section where missing):
Checklist for assignee