Right way to create Masked CNN #895

Open
ashutosh-dwivedi-e3502 opened this issue Nov 14, 2024 · 1 comment
Labels
question User queries

Comments

@ashutosh-dwivedi-e3502

I am creating a masked convolution in Equinox as follows:

```python
import equinox as eqx
import jax
import jax.numpy as jnp

from typing import Optional
from jaxtyping import Array, Float

class MaskedConv(eqx.Module):
    """A masked convolution module using Equinox"""

    in_channels: int
    out_channels: int
    mask: Float[Array, "kernel_h kernel_w"]
    dilation: int = 1
    conv: eqx.nn.Conv2d
    key: Optional[jax.random.PRNGKey] = None

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mask: Float[Array, "kernel_h kernel_w"],
        dilation: int = 1,
        key: Optional[jax.random.PRNGKey] = None
    ) -> None:
        # Set the output channels, input channels, and dilation
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.dilation = dilation
        self.key = key

        # Ensure mask is a JAX array and has the right dims
        self.mask = jnp.array(mask)
        assert self.mask.ndim == 2, "Mask must be a 2D array."

        # Initialize the convolution layer
        kernel_height, kernel_width = self.mask.shape

        pad = kernel_height // 2
        padding = ((pad, pad), (pad, pad))

        self.conv = eqx.nn.Conv2d(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            kernel_size=(kernel_height, kernel_width),
            stride=1,
            padding=padding,
            use_bias=True,
            dilation=self.dilation,
            key=self.key
        )
        self.mask = self.mask.reshape(1, 1, kernel_height, kernel_width)

    def __call__(self, x: Float[Array, "batch h w c_in"]) -> Float[Array, "batch h w c_out"]:
        masked_weights = self.conv.weight * self.mask
        masked_conv = eqx.tree_at(
            where=lambda conv: conv.weight,
            pytree=self.conv,
            replace=masked_weights
        )
        return masked_conv(x)


class VerticalStackConvolution(eqx.Module):
    in_channels: int
    out_channels: int
    kernel_size: int
    mask_center: bool = False
    dilation: int = 1
    conv: MaskedConv

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        mask_center: bool = False,
        dilation: int = 1,
        key: Optional[jax.random.PRNGKey] = None
    ):
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.mask_center = mask_center
        self.dilation = dilation

        # Create the mask
        self.kernel_size = kernel_size
        mask = jnp.ones((self.kernel_size, self.kernel_size), dtype=jnp.float32)
        # Mask out all pixels below the center row
        mask.at[self.kernel_size // 2 + 1:, :].set(0)
        # Optionally mask out the center row
        if self.mask_center:
            mask.at[self.kernel_size // 2, :].set(0)

        # Initialize the MaskedConv module
        self.conv = MaskedConv(
            in_channels=self.in_channels,
            out_channels=self.out_channels,
            mask=mask,
            dilation=self.dilation,
            key=key
        )

    def __call__(self, x: Float[Array, "batch h w c_in"]) -> Float[Array, "batch h w c_out"]:
        return self.conv(x)
```
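For reference, `eqx.nn.Conv2d` acts on a single unbatched, channel-first input of shape `(in_channels, height, width)`, so the `batch h w c_in` annotations in the code above describe a different layout from the one the layer accepts; a batch is normally handled by mapping with `jax.vmap`. The sketch below shows the same weight-masking idea (multiply the kernel by the mask, then convolve) in plain JAX, assuming `jax.lax.conv_general_dilated` with an `OIHW` kernel matches `eqx.nn.Conv2d`'s weight layout. `masked_conv2d` and `batched_masked_conv2d` are hypothetical names for illustration, not Equinox API.

```python
import jax
import jax.numpy as jnp


def masked_conv2d(x, weight, bias, mask):
    """Convolve one unbatched, channel-first image with a masked kernel.

    x:      (c_in, h, w)
    weight: (c_out, c_in, k, k), same layout as eqx.nn.Conv2d.weight
    bias:   (c_out,)
    mask:   (k, k), broadcast over both channel axes
    """
    k = mask.shape[0]
    pad = k // 2  # "same" padding for an odd, square kernel
    masked_weight = weight * mask[None, None, :, :]  # zero the masked taps
    # conv_general_dilated expects a leading batch axis, so add and drop one.
    out = jax.lax.conv_general_dilated(
        x[None],
        masked_weight,
        window_strides=(1, 1),
        padding=((pad, pad), (pad, pad)),
        dimension_numbers=("NCHW", "OIHW", "NCHW"),
    )[0]
    return out + bias[:, None, None]


# Batch by mapping over the leading axis; weights, bias and mask are shared.
batched_masked_conv2d = jax.vmap(masked_conv2d, in_axes=(0, None, None, None))

# Usage: a batch of 2 images with 3 channels, 5x5 pixels, 4 output channels.
x = jnp.ones((2, 3, 5, 5))
weight = jnp.ones((4, 3, 3, 3))
bias = jnp.zeros((4,))
mask = jnp.ones((3, 3)).at[2:, :].set(0.0)  # drop the bottom kernel row
out = batched_masked_conv2d(x, weight, bias, mask)
```

With all-ones inputs and weights, an interior output pixel sums 3 channels times the 6 unmasked taps, so the masked bottom row visibly changes the result.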

However, when I compute the gradient with respect to the center pixel using `VerticalStackConvolution`, I can see that the mask is not applied. What am I doing wrong?
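A minimal sketch of the check described above, assuming the layer is called on one unbatched, channel-first image as `eqx.nn.Conv2d` expects: differentiate the centre output pixel with respect to the input and see which input pixels receive a non-zero gradient. For a correctly masked vertical-stack layer, no pixel below the centre row should appear. `center_receptive_field` is a hypothetical helper, and the toy layer stands in for the real module.

```python
import jax
import jax.numpy as jnp


def center_receptive_field(f, x):
    """Return a boolean (h, w) map of input pixels that affect the centre output.

    f: maps one unbatched, channel-first input (c_in, h, w) to (c_out, h, w).
    x: an example input; for a linear f only its shape matters.
    """
    _, h, w = x.shape

    def center_output(inp):
        # Sum over output channels at the centre pixel to get a scalar.
        return f(inp)[:, h // 2, w // 2].sum()

    g = jax.grad(center_output)(x)     # same shape as x
    return jnp.abs(g).sum(axis=0) > 0  # collapse the input-channel axis


def toy_layer(x):
    # Each output pixel adds the input pixel directly above it (row i - 1).
    return x + jnp.roll(x, 1, axis=1)


rf = center_receptive_field(toy_layer, jnp.ones((1, 5, 5)))
```

Here `rf` is `True` only at the centre pixel and the one directly above it. Passing a model's `__call__` instead of `toy_layer` shows whether rows below the centre leak into the output; for a nonlinear model, a zero gradient at one input point does not by itself prove independence.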

@patrick-kidger
Owner

You want `foo = foo.at[...].set(...)`. See the JAX "Sharp Bits" documentation on in-place updates.
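To illustrate the answer: JAX arrays are immutable, so `.at[...].set(...)` returns an updated copy instead of modifying the array in place. In `VerticalStackConvolution.__init__` above, both `mask.at[...].set(0)` calls discard their result, so the mask handed to `MaskedConv` is still all ones. A minimal sketch of the corrected mask construction follows; `vertical_stack_mask` is a hypothetical helper, and the same two reassignments can equally be written inline in the original `__init__`.

```python
import jax.numpy as jnp

k = 3
mask = jnp.ones((k, k), dtype=jnp.float32)

# The result of .at[...].set(...) is a new array; without assignment it is lost
# and `mask` keeps its original values.
mask.at[k // 2 + 1:, :].set(0)
assert float(mask.sum()) == 9.0  # still all ones


def vertical_stack_mask(kernel_size: int, mask_center: bool = False) -> jnp.ndarray:
    """Build the vertical-stack mask, keeping each .at[...].set(...) result."""
    mask = jnp.ones((kernel_size, kernel_size), dtype=jnp.float32)
    # Mask out all rows below the centre row.
    mask = mask.at[kernel_size // 2 + 1:, :].set(0)
    # Optionally mask out the centre row as well.
    if mask_center:
        mask = mask.at[kernel_size // 2, :].set(0)
    return mask
```

For `kernel_size=3`, `vertical_stack_mask(3)` keeps the top two rows and zeroes the bottom one; with `mask_center=True`, only the top row survives.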

@patrick-kidger patrick-kidger added the question User queries label Nov 14, 2024