
Adding DPT to SMP #1073

Open

vedantdalimkar opened this issue Feb 23, 2025 · 6 comments

@vedantdalimkar (Contributor) commented Feb 23, 2025

The Dense Prediction Transformer (DPT) was introduced in the paper Vision Transformers for Dense Prediction.

It was also used in UniMatch V2: Pushing the Limit of Semi-Supervised Semantic Segmentation, which is the current SOTA in semi-supervised semantic segmentation.

I think this would be a good addition to the library, and I would like to contribute this architecture, if that's fine.

@qubvel (Collaborator) commented Feb 24, 2025

Hey @vedantdalimkar, sounds super cool 👍 I would appreciate the contribution!

The simplest way to start is to clone the library, copy /decoders/unet to /decoders/dpt, and start drafting your code in the decoder.py file. See similar recent PRs like #941 and #906.
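For orientation, a minimal sketch of what such a starting skeleton might look like (the class name and structure below are illustrative, mirroring the decoders/unet layout, not a final design):

    # hypothetical decoders/dpt/decoder.py starting point
    import torch.nn as nn

    class DPTDecoder(nn.Module):
        def __init__(self, encoder_channels, decoder_channels=256):
            super().__init__()
            # DPT-style "reassemble": project each encoder feature map
            # to a common channel width before fusion
            self.projections = nn.ModuleList(
                [
                    nn.Conv2d(ch, decoder_channels, kernel_size=1)
                    for ch in encoder_channels
                ]
            )

        def forward(self, *features):
            # fusion blocks omitted; just project each stage for now
            projected = [proj(f) for proj, f in zip(self.projections, features)]
            return projected[-1]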

@vedantdalimkar (Contributor, Author) commented Feb 25, 2025

Hey @qubvel. DPT uses ViT-Base and ViT-Large models as encoders. However, those encoders are not supported by SMP as of now. Would adding them be outside the scope of the library? If not, let me know; I can help with extending the library so that ViT-based encoders are supported.

My 2 cents: adding ViT-based encoders would be valuable for the community, as it would allow users to use SOTA pretraining weights (like DINO and DINOv2) for segmentation.

@qubvel (Collaborator) commented Feb 25, 2025

Hey @vedantdalimkar, ViT encoders are supported via timm, so you can just use tu-vit.... For example, for the checkpoint timm/vit_base_patch16_224.augreg2_in21k_ft_in1k, the encoder name would be tu-vit_base_patch16_224.augreg2_in21k_ft_in1k.

ViT encoders are not listed because they are not compatible with models such as Unet, but you can still instantiate them and use them for newer models.
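To make the incompatibility with Unet-style models concrete, here is a quick check of the feature strides (assuming timm >= 1.0, where plain ViTs support features_only):

    import timm

    # CNN backbones expose a feature pyramid with progressively
    # increasing strides, which Unet-style decoders expect
    cnn = timm.create_model("resnet34", features_only=True)
    print(cnn.feature_info.reduction())  # [2, 4, 8, 16, 32]

    # a plain ViT keeps a single resolution (the patch stride) at every
    # stage, so there is no pyramid for a Unet decoder to consume
    vit = timm.create_model("vit_base_patch16_224", features_only=True)
    print(vit.feature_info.reduction())  # [16, 16, 16, ...]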

@vedantdalimkar (Contributor, Author) commented Feb 26, 2025

Hey @qubvel, I think ViT encoders can't be used with SMP, as they don't follow the required downsampling pattern. I tried loading the timm model tu-vit_base_patch16_224.augreg2_in21k_ft_in1k using the segmentation_models_pytorch.encoders.get_encoder method and got an error.

The error traceback:

    Cell In[3], line 1
    ----> 1 model = get_encoder(name="tu-vit_base_patch16_224.augreg2_in21k_ft_in1k")

    File c:\Open_source\smp\segmentation_models.pytorch\segmentation_models_pytorch\encoders\__init__.py:86, in get_encoder(name, in_channels, depth, weights, output_stride, **kwargs)
         84 if name.startswith("tu-"):
         85     name = name[3:]
    ---> 86     encoder = TimmUniversalEncoder(
         87         name=name,
         88         in_channels=in_channels,
         89         depth=depth,
         90         output_stride=output_stride,
         91         pretrained=weights is not None,
         92         **kwargs,
         93     )
         94     return encoder
         96 if name not in encoders:

    File c:\Open_source\smp\segmentation_models.pytorch\segmentation_models_pytorch\encoders\timm_universal.py:121, in TimmUniversalEncoder.__init__(self, name, pretrained, in_channels, depth, output_stride, **kwargs)
        119     self._is_vgg_style = True
        120 else:
    --> 121     raise ValueError("Unsupported model downsampling pattern.")
        123 if self._is_transformer_style:
        124     # Transformer-like models (start at scale 4)
        125     if "tresnet" in name:
        126         # 'tresnet' models start feature extraction at stage 1,
        127         # so out_indices=(1, 2, 3, 4) for depth=5.

    ValueError: Unsupported model downsampling pattern.

So, if this is not an error on my side: how should I go about extending support for ViT encoders? Should I make a new module, timm_vit.py, inside the encoders folder, or should I modify the TimmUniversalEncoder class in smp.encoders.timm_universal so that loading ViT encoders is enabled?

@qubvel (Collaborator) commented Feb 26, 2025

Ahh, you are right, my bad! Let's have another encoder then (e.g. TimmViTLikeEncoder); it would be safer and will not break other models.
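For reference, a rough sketch of what such a separate encoder could look like (assuming timm >= 1.0, where ViT models expose forward_intermediates(); the interface below is hypothetical, not the final SMP API):

    import timm
    import torch.nn as nn

    class TimmViTLikeEncoder(nn.Module):
        # hypothetical sketch of a ViT-only encoder wrapper
        def __init__(self, name, pretrained=True, in_channels=3, depth=4):
            super().__init__()
            self.model = timm.create_model(
                name, pretrained=pretrained, in_chans=in_channels, num_classes=0
            )
            self.depth = depth

        def forward(self, x):
            # return the last `depth` block outputs as NCHW feature maps,
            # all at the patch stride; a DPT-style decoder reassembles
            # them into a multi-scale pyramid
            return self.model.forward_intermediates(
                x, indices=self.depth, intermediates_only=True
            )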

@vedantdalimkar (Contributor, Author)

Hey @qubvel, I have opened an initial PR for this issue. It would be great if you could check it out and let me know if it requires any changes.
