Skip to content

Commit

Permalink
feat(graphs,plots): expand support for multi-dimensional node attribu…
Browse files Browse the repository at this point in the history
…tes (#48)

* feat: expand plot_nodes to multi-dimensional attributes

* udpate changelog
  • Loading branch information
JPXKQX authored Jan 15, 2025
1 parent 2ebcc7d commit 17af5a0
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 19 deletions.
4 changes: 4 additions & 0 deletions graphs/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ Keep it human-readable, your future self will thank you!

## [Unreleased](https://github.com/ecmwf/anemoi-graphs/compare/0.4.2...HEAD)

### Added

- feat: Support for multi-dimensional node attributes in plots (#48)

## [0.4.2 - Optimisations and lat-lon](https://github.com/ecmwf/anemoi-graphs/compare/0.4.1...0.4.2) - 2024-12-19

### Added
Expand Down
40 changes: 21 additions & 19 deletions graphs/src/anemoi/graphs/plotting/interactive_html.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objects as go
import torch
from matplotlib.colors import rgb2hex
from torch_geometric.data import HeteroData

Expand Down Expand Up @@ -197,25 +198,26 @@ def plot_interactive_nodes(graph: HeteroData, nodes_name: str, out_file: Optiona
for node_attr in node_attrs:
node_attr_values = graph[nodes_name][node_attr].float().numpy()

# Skip multi-dimensional attributes. Supported only: (N, 1) or (N,) tensors
if node_attr_values.ndim > 1 and node_attr_values.shape[1] > 1:
continue

node_traces[node_attr] = go.Scattergeo(
lat=node_latitudes,
lon=node_longitudes,
name=" ".join(node_attr.split("_")).capitalize(),
mode="markers",
hoverinfo="text",
marker={
"color": node_attr_values.squeeze().tolist(),
"showscale": True,
"colorscale": "RdBu",
"colorbar": {"thickness": 15, "title": node_attr, "xanchor": "left"},
"size": 5,
},
visible=False,
)
if node_attr_values.ndim == 1:
node_attr_values = torch.unsqueeze(node_attr_values, -1)

for attr_dim in range(node_attr_values.shape[1]):
suffix = "" if node_attr_values.shape[1] == 1 else f"_[{attr_dim}]"
node_traces[node_attr + suffix] = go.Scattergeo(
lat=node_latitudes,
lon=node_longitudes,
name=" ".join((node_attr + suffix).split("_")).capitalize(),
mode="markers",
hoverinfo="text",
marker={
"color": node_attr_values[:, attr_dim].squeeze().tolist(),
"showscale": True,
"colorscale": "RdBu",
"colorbar": {"thickness": 15, "title": node_attr + suffix, "xanchor": "left"},
"size": 5,
},
visible=False,
)

# Create and add slider
slider_steps = []
Expand Down

0 comments on commit 17af5a0

Please sign in to comment.