Skip to content

Fix graph LAM GNN kwarg handling#688

Open
gitcommit90 wants to merge 2 commits into
mllam:mainfrom
gitcommit90:fix/issue-686-gnn-kwargs
Open

Fix graph LAM GNN kwarg handling#688
gitcommit90 wants to merge 2 commits into
mllam:mainfrom
gitcommit90:fix/issue-686-gnn-kwargs

Conversation

@gitcommit90

Copy link
Copy Markdown

Describe your changes

Fix graph_lam startup by allowing GraphLAM to ignore hierarchical-only GNN kwargs.

Restore saved GNN type kwargs when loading checkpoints, with defaults for older checkpoints.

No new dependencies.

Issue Link

closes #686

Type of change

  • 🐛 Bug fix (non-breaking change that fixes an issue)
  • ✨ New feature (non-breaking change that adds functionality)
  • 💥 Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • 📖 Documentation (Addition or improvements to documentation)

Checklist before requesting a review

  • My branch is up-to-date with the target branch - if not update your fork with the changes from the target branch (use pull with --rebase option if possible).
  • I have performed a self-review of my code
  • For any new/modified functions/classes I have added docstrings that clearly describe its purpose, expected inputs and returned values
  • I have placed in-line comments to clarify the intent of any hard-to-understand passages of my code
  • I have updated the README to cover introduced code changes
  • I have added tests that prove my fix is effective or that my feature works
  • I have given the PR a name that clearly describes the change, written in imperative form (context).
  • I have requested a reviewer and an assignee (assignee is responsible for merging). This applies only if you have write access to the repo, otherwise feel free to tag a maintainer to add a reviewer and assignee.

Checklist for reviewers

Each PR comes with its own improvements and flaws. The reviewer should check the following:

  • the code is readable
  • the code is well tested
  • the code is documented (including return types and parameters)
  • the code is easy to maintain

Author checklist after completed review

  • I have added a line to the CHANGELOG describing this change, in a section
    reflecting type of change (add section where missing):
    • added: when you have added new functionality
    • changed: when default behaviour of the code has been changed
    • fixes: when your contribution fixes a bug
    • maintenance: when your contribution is relates to repo maintenance, e.g. CI/CD or documentation

Checklist for assignee

  • PR is up to date with the base branch
  • the tests pass
  • (if the PR is not just maintenance/bugfix) the PR is assigned to the next milestone. If it is not, propose it for a future milestone.
  • author has added an entry to the changelog (and designated the change as added, changed, fixed or maintenance)

Co-authored-by: Hurricane <hurricane@hermes.local>

@GiGiKoneti GiGiKoneti left a comment

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Appreciate the fast fix and the regression tests for checkpoint loading.
cc: @joeloskarsson

@sadamov sadamov requested a review from joeloskarsson June 29, 2026 08:03

@joeloskarsson joeloskarsson left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, a few comments.

"neural_lam.train_model.torch.load",
return_value={"hyper_parameters": {"args": args}},
),
patch("neural_lam.train_model.MODELS", {"hi_lam": DummyPredictor}),

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we test this also for graph_lam, as that would have caught the original crash? (now that load_forecaster_module_from_checkpoint actually uses all args).

output_clamping_upper: dict[str, float] | None = None,
g2m_gnn_type: str = "InteractionNet",
m2g_gnn_type: str = "InteractionNet",
**_kwargs: object,

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To be consistent with rest of codebase

Suggested change
**_kwargs: object,
**kwargs: object,

Comment thread neural_lam/models/step_predictors/graph/graph_lam.py Outdated

@joeloskarsson joeloskarsson left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops, hit the wrong box.

@sadamov sadamov self-requested a review June 30, 2026 08:23

@sadamov sadamov left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would avoid **kwargs on the constructors here. It swallows typos in programmatic use, removes the very error that just caught this, and contradicts requirement #1 of #672 (explicit APIs). It also doesn't fix the loader.

Instead one small build_predictor helper that adds the hierarchical kwargs only for hierarchical models. Signatures stay explicit, both call sites de-duplicated, both bugs fixed. Good enough until #672 lands, without running the risk of having lingering **kwargs in the codebase. What do you think @joeloskarsson @gitcommit90?

Make the kwargs explicit like this:

def build_predictor(predictor_class, args, config, datastore):
    kwargs = dict(datastore=datastore, graph_name=args.graph, ...,
                  g2m_gnn_type=args.g2m_gnn_type, m2g_gnn_type=args.m2g_gnn_type)
    if issubclass(predictor_class, BaseHiGraphModel):
        kwargs["mesh_up_gnn_type"] = args.mesh_up_gnn_type
        kwargs["mesh_down_gnn_type"] = args.mesh_down_gnn_type
    return predictor_class(**kwargs)

@sadamov

sadamov commented Jun 30, 2026

Copy link
Copy Markdown
Collaborator

I am also noting that this was a clear gap in our test-suite that we should definitely fill with this PR.

@joeloskarsson

Copy link
Copy Markdown
Collaborator

To me the specialized instantiation for different model classes doesn't feel so maintainable, but I can also see the potential issues with the kwargs approach. Overall the way we instantiate models in train_model.py is no good, but this should probably be solved with #672.

For a fix right now I don't have strong opinions, so we can go with something like @sadamov's suggestion.

- Route training and checkpoint reload through build_predictor
- Omit hierarchical mesh GNN kwargs for graph_lam
- Remove GraphLAM **_kwargs swallow
- Add regression test for graph_lam kwargs
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[Bug] TypeError: GraphLAM.__init__() got an unexpected keyword argument 'mesh_up_gnn_type'

4 participants