added channel mixer
ssmmnn11 committed Jan 29, 2025
1 parent ffbe713 commit e5c7b15
Showing 4 changed files with 693 additions and 32 deletions.
229 changes: 226 additions & 3 deletions models/src/anemoi/models/layers/block.py
@@ -472,9 +472,6 @@ def __init__(
**kwargs,
)

self.lin_x0_skip = nn.Linear(in_channels, out_channels) # match x_skip.shape[1] with the output shape
self.lin_x1_skip = nn.Linear(in_channels, out_channels) # match x_skip.shape[1] with the output shape

self.layer_norm2 = nn.LayerNorm(in_channels)

def forward(
@@ -645,3 +642,229 @@ def forward(
nodes_new = self.node_dst_mlp(out) + out

return nodes_new, edge_attr


class GraphTransformerBaseBlockAttention(BaseBlock, ABC):
"""Message passing block with MLPs for node embeddings."""

def __init__(
self,
in_channels: int,
out_channels: int,
edge_dim: int,
num_heads: int = 16,
bias: bool = True,
num_chunks: int = 1,
**kwargs,
) -> None:
"""Initialize GraphTransformerBlock.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
edge_dim : int,
Edge dimension
num_heads : int,
Number of heads
bias : bool, by default True,
Add bias or not
"""
super().__init__(**kwargs)

self.out_channels_conv = out_channels // num_heads
self.num_heads = num_heads

self.num_chunks = num_chunks

self.lin_key = nn.Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_query = nn.Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_value = nn.Linear(in_channels, num_heads * self.out_channels_conv)
self.lin_self = nn.Linear(in_channels, num_heads * self.out_channels_conv, bias=bias)
self.lin_edge = nn.Linear(edge_dim, num_heads * self.out_channels_conv)

self.conv = GraphTransformerConv(out_channels=self.out_channels_conv)

self.projection = nn.Linear(out_channels, out_channels)

self.layer_norm1 = nn.LayerNorm(in_channels)

def shard_qkve_heads(
self,
query: Tensor,
key: Tensor,
value: Tensor,
edges: Tensor,
shapes: tuple,
batch_size: int,
model_comm_group: Optional[ProcessGroup] = None,
) -> tuple[Tensor, Tensor, Tensor, Tensor]:
"""Shards qkv and edges along head dimension."""
shape_src_nodes, shape_dst_nodes, shape_edges = shapes

query, key, value, edges = (
einops.rearrange(
t,
"(batch grid) (heads vars) -> batch heads grid vars",
heads=self.num_heads,
vars=self.out_channels_conv,
batch=batch_size,
)
for t in (query, key, value, edges)
)
query = shard_heads(query, shapes=shape_dst_nodes, mgroup=model_comm_group)
key = shard_heads(key, shapes=shape_src_nodes, mgroup=model_comm_group)
value = shard_heads(value, shapes=shape_src_nodes, mgroup=model_comm_group)
edges = shard_heads(edges, shapes=shape_edges, mgroup=model_comm_group)

query, key, value, edges = (
einops.rearrange(t, "batch heads grid vars -> (batch grid) heads vars") for t in (query, key, value, edges)
)

return query, key, value, edges

def shard_output_seq(
self,
out: Tensor,
shapes: tuple,
batch_size: int,
model_comm_group: Optional[ProcessGroup] = None,
) -> Tensor:
"""Shards Tensor sequence dimension."""
shape_dst_nodes = shapes[1]

out = einops.rearrange(out, "(batch grid) heads vars -> (batch grid) (heads vars)", batch=batch_size)
out = shard_tensor(out, dim=0, shapes=shape_dst_nodes, mgroup=model_comm_group)

return out

@abstractmethod
def forward(
self,
x: OptPairTensor,
edge_attr: Tensor,
edge_index: Adj,
shapes: tuple,
batch_size: int,
model_comm_group: Optional[ProcessGroup] = None,
size: Optional[Size] = None,
): ...
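
The shard_qkve_heads helper above rearranges the flat (batch · grid, heads · vars) projections into an explicit per-head layout before shard_heads distributes the head dimension across the model communication group, then flattens the batch and grid axes again. A minimal standalone sketch of that einops round trip, using hypothetical sizes (batch=1, grid=8, heads=4, vars=16) and omitting the distributed sharding:

import einops
import torch

batch, grid, heads, channels = 1, 8, 4, 16
flat = torch.randn(batch * grid, heads * channels)  # (batch grid, heads vars)
per_head = einops.rearrange(
    flat,
    "(batch grid) (heads vars) -> batch heads grid vars",
    batch=batch,
    heads=heads,
    vars=channels,
)  # shape (1, 4, 8, 16); shard_heads would split the heads axis here
back = einops.rearrange(per_head, "batch heads grid vars -> (batch grid) heads vars")
assert back.shape == (batch * grid, heads, channels)
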


class GraphTransformerMapperBlockAttention(GraphTransformerBaseBlockAttention):
"""Graph Transformer Block for node embeddings."""

def __init__(
self,
in_channels: int,
hidden_dim: int, # not used
out_channels: int,
edge_dim: int,
num_heads: int = 16,
bias: bool = True,
activation: str = "GELU", # not used
num_chunks: int = 1,
update_src_nodes: bool = False, # not used
**kwargs,
) -> None:
"""Initialize GraphTransformerBlock.
Parameters
----------
in_channels : int
Number of input channels.
out_channels : int
Number of output channels.
edge_dim : int,
Edge dimension
num_heads : int,
Number of heads
bias : bool, by default True,
Add bias or not
activation : str, optional
Activation function, by default "GELU"
update_src_nodes: bool, by default False
Update src if src and dst nodes are given
"""
super().__init__(
in_channels=in_channels,
out_channels=out_channels,
edge_dim=edge_dim,
num_heads=num_heads,
bias=bias,
num_chunks=num_chunks,
**kwargs,
)

self.layer_norm2 = nn.LayerNorm(in_channels)

def forward(
self,
x: OptPairTensor,
edge_attr: Tensor,
edge_index: Adj,
shapes: tuple,
batch_size: int,
model_comm_group: Optional[ProcessGroup] = None,
size: Optional[Size] = None,
):
x_skip = x # TODO: check if this is correct

x = (
self.layer_norm1(x[0]),
self.layer_norm2(x[1]),
)

x_r = self.lin_self(x[1])
query = self.lin_query(x[1])
key = self.lin_key(x[0])
value = self.lin_value(x[0])
edges = self.lin_edge(edge_attr)

# sync node sharded q k v
query = sync_tensor(query, 0, shapes[1], model_comm_group)
key = sync_tensor(key, 0, shapes[0], model_comm_group)
value = sync_tensor(value, 0, shapes[0], model_comm_group)

# expand head dimension
query = query.view(-1, self.num_heads, self.out_channels_conv)
key = key.view(-1, self.num_heads, self.out_channels_conv)
value = value.view(-1, self.num_heads, self.out_channels_conv)
edges = edges.view(-1, self.num_heads, self.out_channels_conv)

if model_comm_group is not None:
assert (
model_comm_group.size() == 1 or batch_size == 1
), "Only batch size of 1 is supported when model is sharded across GPUs"

num_chunks = self.num_chunks if self.training else NUM_CHUNKS_INFERENCE

if num_chunks > 1:
# split 1-hop edges into chunks, compute self.conv chunk-wise and aggregate
edge_attr_list, edge_index_list = sort_edges_1hop_chunks(
num_nodes=size, edge_attr=edges, edge_index=edge_index, num_chunks=num_chunks
)
out = torch.zeros((query.shape[0], self.num_heads, self.out_channels_conv), device=query.device)
for i in range(num_chunks):
out += self.conv(
query=query,
key=key,
value=value,
edge_attr=edge_attr_list[i],
edge_index=edge_index_list[i],
size=size,
)
else:
out = self.conv(query=query, key=key, value=value, edge_attr=edges, edge_index=edge_index, size=size)

# go back to original shape and shard nodes again
out = self.shard_output_seq(out, shapes, batch_size, model_comm_group)

# compute out = self.projection(out + x_r) in chunks:
out = torch.cat([self.projection(chunk) for chunk in torch.tensor_split(out + x_r, num_chunks, dim=0)], dim=0)

out = out + x_skip[1]

return (x_skip[0], out), edge_attr
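
The final projection in forward is applied chunk-wise: out + x_r is split along the node dimension with torch.tensor_split, projected per chunk, and concatenated, trading a single large matmul for lower peak memory. Because nn.Linear acts row-wise, the result matches a single full-tensor projection; a small self-contained check of that equivalence with hypothetical sizes (not taken from the commit):

import torch
import torch.nn as nn

proj = nn.Linear(32, 32)
x = torch.randn(100, 32)  # stand-in for out + x_r
chunked = torch.cat([proj(c) for c in torch.tensor_split(x, 4, dim=0)], dim=0)
assert torch.allclose(chunked, proj(x), atol=1e-6)
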