diff --git a/monai/networks/nets/__init__.py b/monai/networks/nets/__init__.py
index c1917e5293..0d02f95efa 100644
--- a/monai/networks/nets/__init__.py
+++ b/monai/networks/nets/__init__.py
@@ -144,3 +144,4 @@
 from .vnet import VNet
 from .voxelmorph import VoxelMorph, VoxelMorphUNet
 from .vqvae import VQVAE
+from .u_mamba import UMambaUNet
diff --git a/monai/networks/nets/u_mamba.py b/monai/networks/nets/u_mamba.py
new file mode 100644
index 0000000000..4697ad4a66
--- /dev/null
+++ b/monai/networks/nets/u_mamba.py
@@ -0,0 +1,110 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+# http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# Simple placeholder for the SSM (Mamba-like block)
+class SSMBlock(nn.Module):
+    def __init__(self, dim):
+        super().__init__()
+        self.linear1 = nn.Linear(dim, dim)
+        self.linear2 = nn.Linear(dim, dim)
+
+    def forward(self, x):
+        # x: (B, L, C); torch.silu does not exist, use F.silu
+        return self.linear2(F.silu(self.linear1(x)))
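+
+
+# NOTE: in the original U-Mamba, this stage is a selective state space model
+# (Mamba) with input-dependent parameters evaluated via a hardware-aware scan.
+# The two-layer MLP above is only a stand-in with the same interface:
+# a length-preserving map over (B, L, C) token sequences.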
+
+
+class UMambaBlock(nn.Module):
+    def __init__(self, in_channels, hidden_channels):
+        super().__init__()
+        self.conv_res1 = nn.Sequential(
+            nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1),
+            nn.InstanceNorm3d(in_channels),
+            nn.LeakyReLU(),
+        )
+        self.conv_res2 = nn.Sequential(
+            nn.Conv3d(in_channels, in_channels, kernel_size=3, padding=1),
+            nn.InstanceNorm3d(in_channels),
+            nn.LeakyReLU(),
+        )
+
+        # the norm acts on (B, L, C) tokens whose channel dim is in_channels,
+        # so it must be sized by in_channels, not hidden_channels
+        self.layernorm = nn.LayerNorm(in_channels)
+        self.linear1 = nn.Linear(in_channels, hidden_channels)
+        self.linear2 = nn.Linear(hidden_channels, in_channels)
+        self.conv1d = nn.Conv1d(hidden_channels, hidden_channels, kernel_size=3, padding=1)
+        self.ssm = SSMBlock(hidden_channels)
+
+    def forward(self, x):
+        # x: (B, C, H, W, D)
+        residual = x
+        x = self.conv_res1(x)
+        x = self.conv_res2(x) + residual
+
+        B, C, H, W, D = x.shape
+        x_flat = x.view(B, C, -1).permute(0, 2, 1)  # (B, L, C), L = H*W*D
+        x_norm = self.layernorm(x_flat)
+        x_proj = self.linear1(x_norm)
+
+        x_ssm = self.ssm(F.silu(x_proj))
+        x_conv1d = self.conv1d(x_proj.permute(0, 2, 1)).permute(0, 2, 1)
+
+        # gated fusion of the conv1d and SSM branches
+        x_combined = F.silu(x_conv1d) * F.silu(x_ssm)
+        x_out = self.linear2(x_combined)
+        # reshape() rather than view(): the permuted tensor is non-contiguous
+        x_out = x_out.permute(0, 2, 1).reshape(B, C, H, W, D)
+
+        return x + x_out  # residual connection
+
+
+class ResidualBlock(nn.Module):
+    def __init__(self, channels):
+        super().__init__()
+        self.block = nn.Sequential(
+            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
+            nn.BatchNorm3d(channels),
+            nn.ReLU(),
+            nn.Conv3d(channels, channels, kernel_size=3, padding=1),
+            nn.BatchNorm3d(channels),
+        )
+
+    def forward(self, x):
+        return F.relu(x + self.block(x))
+
+
+class UMambaUNet(nn.Module):
+    def __init__(self, in_channels=1, out_channels=1, base_channels=32):
+        super().__init__()
+        # lift the input to base_channels first: UMambaBlock preserves its
+        # channel count, so the encoder blocks cannot change channels themselves
+        self.stem = nn.Conv3d(in_channels, base_channels, kernel_size=3, padding=1)
+
+        self.enc1 = UMambaBlock(base_channels, base_channels)
+        self.down1 = nn.Conv3d(base_channels, base_channels * 2, kernel_size=3, stride=2, padding=1)
+
+        self.enc2 = UMambaBlock(base_channels * 2, base_channels * 2)
+        self.down2 = nn.Conv3d(base_channels * 2, base_channels * 4, kernel_size=3, stride=2, padding=1)
+
+        self.bottleneck = UMambaBlock(base_channels * 4, base_channels * 4)
+
+        self.up2 = nn.ConvTranspose3d(base_channels * 4, base_channels * 2, kernel_size=2, stride=2)
+        # 1x1 convs fuse the concatenated skip + upsampled features back down
+        self.reduce2 = nn.Conv3d(base_channels * 4, base_channels * 2, kernel_size=1)
+        self.dec2 = ResidualBlock(base_channels * 2)
+
+        self.up1 = nn.ConvTranspose3d(base_channels * 2, base_channels, kernel_size=2, stride=2)
+        self.reduce1 = nn.Conv3d(base_channels * 2, base_channels, kernel_size=1)
+        self.dec1 = ResidualBlock(base_channels)
+
+        self.final = nn.Conv3d(base_channels, out_channels, kernel_size=1)
+
+    def forward(self, x):
+        x1 = self.enc1(self.stem(x))
+        x2 = self.enc2(self.down1(x1))
+        x3 = self.bottleneck(self.down2(x2))
+
+        x = self.up2(x3)
+        x = self.dec2(self.reduce2(torch.cat([x, x2], dim=1)))
+        x = self.up1(x)
+        x = self.dec1(self.reduce1(torch.cat([x, x1], dim=1)))
+        return self.final(x)
diff --git a/tests/test_networks_u_mamba.py b/tests/test_networks_u_mamba.py
new file mode 100644
index 0000000000..02f75226da
--- /dev/null
+++ b/tests/test_networks_u_mamba.py
@@ -0,0 +1,22 @@
+import unittest
+
+import torch
+
+from monai.networks.nets import UMambaUNet
+
+
+class TestUMamba(unittest.TestCase):
+    def test_forward_shape(self):
+        # spatial sides must be divisible by 4 (two stride-2 downsamplings)
+        input_tensor = torch.randn(2, 1, 16, 64, 64)
+        model = UMambaUNet(in_channels=1, out_channels=2)
+        output = model(input_tensor)
+        self.assertEqual(output.shape, (2, 2, 16, 64, 64))
+
+    def test_script(self):
+        # Test JIT scripting; the input must be 5D (B, C, H, W, D) for Conv3d
+        model = UMambaUNet(in_channels=1, out_channels=2)
+        scripted = torch.jit.script(model)
+        x = torch.randn(1, 1, 16, 16, 16)
+        out = scripted(x)
+        self.assertEqual(out.shape, (1, 2, 16, 16, 16))
+
+
+if __name__ == "__main__":
+    unittest.main()
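# Usage sketch (illustrative, outside the diff): assuming the files above are
# placed in an importable MONAI tree, the network can be smoke-tested on a
# small random volume; base_channels is the constructor argument shown above.
import torch

from monai.networks.nets import UMambaUNet

net = UMambaUNet(in_channels=1, out_channels=2, base_channels=16)
with torch.no_grad():
    logits = net(torch.randn(1, 1, 16, 32, 32))
print(logits.shape)  # torch.Size([1, 2, 16, 32, 32])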