diff --git a/mindone/transformers/__init__.py b/mindone/transformers/__init__.py index 2cda1be9ad..a4eb304032 100644 --- a/mindone/transformers/__init__.py +++ b/mindone/transformers/__init__.py @@ -483,6 +483,7 @@ from .models.emu3 import Emu3ForCausalLM, Emu3ForConditionalGeneration, Emu3PreTrainedModel, Emu3TextModel, Emu3VQVAE from .models.encodec import EncodecModel, EncodecPreTrainedModel from .models.encoder_decoder import EncoderDecoderModel +from .models.eomt import EomtForUniversalSegmentation, EomtImageProcessor, EomtImageProcessorFast, EomtPreTrainedModel from .models.ernie import ( ErnieForCausalLM, ErnieForMaskedLM, @@ -1331,6 +1332,7 @@ TapasPreTrainedModel, ) from .models.textnet import TextNetBackbone, TextNetForImageClassification, TextNetModel, TextNetPreTrainedModel +from .models.timesfm import TimesFmModel, TimesFmModelForPrediction, TimesFmPreTrainedModel from .models.timesformer import TimesformerForVideoClassification, TimesformerModel, TimesformerPreTrainedModel from .models.trocr import TrOCRForCausalLM, TrOCRPreTrainedModel from .models.tvp import TvpForVideoGrounding, TvpModel, TvpPreTrainedModel diff --git a/mindone/transformers/mindspore_adapter/utils.py b/mindone/transformers/mindspore_adapter/utils.py index 9ba2f68980..fbd2706a4b 100644 --- a/mindone/transformers/mindspore_adapter/utils.py +++ b/mindone/transformers/mindspore_adapter/utils.py @@ -29,7 +29,10 @@ "bool": ms.bool_, } - +_MIN_INT8 = ms.tensor(np.iinfo(np.int8).min, dtype=ms.int8) +_MIN_INT16 = ms.tensor(np.iinfo(np.int16).min, dtype=ms.int16) +_MIN_INT32 = ms.tensor(np.iinfo(np.int32).min, dtype=ms.int32) +_MIN_INT64 = ms.tensor(np.iinfo(np.int64).min, dtype=ms.int64) _MIN_FP16 = ms.tensor(np.finfo(np.float16).min, dtype=ms.float16) _MIN_FP32 = ms.tensor(np.finfo(np.float32).min, dtype=ms.float32) _MIN_FP64 = ms.tensor(np.finfo(np.float64).min, dtype=ms.float64) @@ -41,6 +44,10 @@ _DTYPE_2_MIN = { + ms.int8: _MIN_INT8, + ms.int16: _MIN_INT16, + ms.int32: _MIN_INT32, + ms.int64: _MIN_INT64, ms.float16: _MIN_FP16, ms.float32: _MIN_FP32, ms.float64: _MIN_FP64, diff --git a/mindone/transformers/models/__init__.py b/mindone/transformers/models/__init__.py index 28e7ced270..76d789fc15 100644 --- a/mindone/transformers/models/__init__.py +++ b/mindone/transformers/models/__init__.py @@ -72,6 +72,7 @@ emu3, encodec, encoder_decoder, + eomt, ernie, esm, falcon, @@ -224,6 +225,7 @@ table_transformer, tapas, textnet, + timesfm, timesformer, trocr, tvp, diff --git a/mindone/transformers/models/auto/configuration_auto.py b/mindone/transformers/models/auto/configuration_auto.py index b1d25744a3..4ccf288024 100644 --- a/mindone/transformers/models/auto/configuration_auto.py +++ b/mindone/transformers/models/auto/configuration_auto.py @@ -95,6 +95,7 @@ ("emu3", "Emu3Config"), ("encodec", "EncodecConfig"), ("encoder-decoder", "EncoderDecoderConfig"), + ("eomt", "EomtConfig"), ("esm", "EsmConfig"), ("falcon", "FalconConfig"), ("falcon_mamba", "FalconMambaConfig"), @@ -253,6 +254,7 @@ ("table-transformer", "TableTransformerConfig"), ("tapas", "TapasConfig"), ("textnet", "TextNetConfig"), + ("timesfm", "TimesFmConfig"), ("timesformer", "TimesformerConfig"), ("trocr", "TrOCRConfig"), ("tvp", "TvpConfig"), @@ -363,6 +365,7 @@ ("emu3", "Emu3"), ("encodec", "Encodec"), ("encoder-decoder", "Encoder decoder"), + ("eomt", "EoMT"), ("esm", "ESM"), ("falcon", "Falcon"), ("falcon_mamba", "FalconMamba"), @@ -525,6 +528,7 @@ ("table-transformer", "Table Transformer"), ("tapas", "TAPAS"), ("textnet", "TextNet"), + ("timesfm", 
"TimesFm"), ("timesformer", "TimeSformer"), ("trocr", "TrOCR"), ("tvp", "TVP"), diff --git a/mindone/transformers/models/auto/image_processing_auto.py b/mindone/transformers/models/auto/image_processing_auto.py index baea350d79..c1c3688572 100644 --- a/mindone/transformers/models/auto/image_processing_auto.py +++ b/mindone/transformers/models/auto/image_processing_auto.py @@ -63,6 +63,7 @@ ("dinov2", ("BitImageProcessor",)), ("dpt", ("DPTImageProcessor",)), ("efficientnet", ("EfficientNetImageProcessor",)), + ("eomt", ("EomtImageProcessor", "EomtImageProcessorFast")), ("flava", ("FlavaImageProcessor",)), ("llava_next", ("LlavaNextImageProcessor",)), ("llava_next_video", ("LlavaNextVideoImageProcessor",)), diff --git a/mindone/transformers/models/auto/modeling_auto.py b/mindone/transformers/models/auto/modeling_auto.py index c89e50bc47..f49f84ea06 100644 --- a/mindone/transformers/models/auto/modeling_auto.py +++ b/mindone/transformers/models/auto/modeling_auto.py @@ -230,6 +230,7 @@ ("table-transformer", "TableTransformerModel"), ("tapas", "TapasModel"), ("textnet", "TextNetModel"), + ("timesfm", "TimesFmModel"), ("timesformer", "TimesformerModel"), ("tvp", "TvpModel"), ("udop", "UdopModel"), @@ -653,6 +654,7 @@ [ # Model for Universal Segmentation mapping ("detr", "DetrForSegmentation"), + ("eomt", "EomtForUniversalSegmentation"), ("mask2former", "Mask2FormerForUniversalSegmentation"), ("maskformer", "MaskFormerForInstanceSegmentation"), ("oneformer", "OneFormerForUniversalSegmentation"), @@ -1268,6 +1270,12 @@ MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING_NAMES = OrderedDict() +MODEL_FOR_TIME_SERIES_PREDICTION_MAPPING_NAMES = OrderedDict( + [ + ("timesfm", "TimesFmModelForPrediction"), + ] +) + MODEL_FOR_IMAGE_TO_IMAGE_MAPPING_NAMES = OrderedDict( [ ("swin2sr", "Swin2SRForImageSuperResolution"), diff --git a/mindone/transformers/models/eomt/__init__.py b/mindone/transformers/models/eomt/__init__.py new file mode 100644 index 0000000000..9bea4de8ae --- /dev/null +++ b/mindone/transformers/models/eomt/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .image_processing_eomt import * +from .image_processing_eomt_fast import * +from .modeling_eomt import * diff --git a/mindone/transformers/models/eomt/image_processing_eomt.py b/mindone/transformers/models/eomt/image_processing_eomt.py new file mode 100644 index 0000000000..913ed4e810 --- /dev/null +++ b/mindone/transformers/models/eomt/image_processing_eomt.py @@ -0,0 +1,971 @@ +# coding=utf-8 +# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Image processor class for EoMT.""" + +import math +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BaseImageProcessor, BatchFeature, get_size_dict +from ...image_transforms import PaddingMode, pad, resize +from ...image_utils import ( + ChannelDimension, + ImageInput, + PILImageResampling, + get_image_size, + infer_channel_dimension_format, + make_flat_list_of_images, + make_list_of_images, + to_numpy_array, + valid_images, + validate_preprocess_arguments, +) +from ...utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + TensorType, + filter_out_non_signature_kwargs, + is_mindspore_available, + logging, +) + +logger = logging.get_logger(__name__) + +if is_mindspore_available(): + import mindspore as ms + import mindspore.mint.nn.functional as F + from mindspore import mint + + +# Adapted from transformers.models.maskformer.image_processing_maskformer.convert_segmentation_map_to_binary_masks +def convert_segmentation_map_to_binary_masks( + segmentation_map: "np.ndarray", + instance_id_to_semantic_id: Optional[dict[int, int]] = None, + ignore_index: Optional[int] = None, +): + if ignore_index is not None: + segmentation_map = np.where(segmentation_map == 0, ignore_index, segmentation_map - 1) + + # Get unique ids (class or instance ids based on input) + all_labels = np.unique(segmentation_map) + + # Drop background label if applicable + if ignore_index is not None: + all_labels = all_labels[all_labels != ignore_index] + + # Generate a binary mask for each object instance + binary_masks = [(segmentation_map == i) for i in all_labels] + + # Stack the binary masks + if binary_masks: + binary_masks = np.stack(binary_masks, axis=0) + else: + binary_masks = np.zeros((0, *segmentation_map.shape)) + + # Convert instance ids to class ids + if instance_id_to_semantic_id is not None: + labels = np.zeros(all_labels.shape[0]) + + for label in all_labels: + class_id = instance_id_to_semantic_id[label + 1 if ignore_index is not None else label] + labels[all_labels == label] = class_id - 1 if ignore_index is not None else class_id + else: + labels = all_labels + + return binary_masks.astype(np.float32), labels.astype(np.int64) + + +def get_size_with_aspect_ratio(image_size, size, max_size=None) -> tuple[int, int]: + """ + Computes the output image size given the input image size and the desired output size. + + Args: + image_size (`tuple[int, int]`): + The input image size. + size (`int`): + The desired output size. + max_size (`int`, *optional*): + The maximum allowed output size. 
+ """ + height, width = image_size + raw_size = None + if max_size is not None: + min_original_size = float(min((height, width))) + max_original_size = float(max((height, width))) + if max_original_size / min_original_size * size > max_size: + raw_size = max_size * min_original_size / max_original_size + size = int(round(raw_size)) + + if (height <= width and height == size) or (width <= height and width == size): + oh, ow = height, width + elif width < height: + ow = size + if max_size is not None and raw_size is not None: + oh = round(raw_size * height / width) + else: + oh = round(size * height / width) + else: + oh = size + if max_size is not None and raw_size is not None: + ow = round(raw_size * width / height) + else: + ow = round(size * width / height) + + return (oh, ow) + + +# Copied from transformers.models.detr.image_processing_detr.remove_low_and_no_objects +def remove_low_and_no_objects(masks, scores, labels, object_mask_threshold, num_labels): + """ + Binarize the given masks using `object_mask_threshold`, it returns the associated values of `masks`, `scores` and + `labels`. + + Args: + masks (`ms.Tensor`): + A tensor of shape `(num_queries, height, width)`. + scores (`ms.Tensor`): + A tensor of shape `(num_queries)`. + labels (`ms.Tensor`): + A tensor of shape `(num_queries)`. + object_mask_threshold (`float`): + A number between 0 and 1 used to binarize the masks. + Raises: + `ValueError`: Raised when the first dimension doesn't match in all input tensors. + Returns: + `tuple[`ms.Tensor`, `ms.Tensor`, `ms.Tensor`]`: The `masks`, `scores` and `labels` without the region + < `object_mask_threshold`. + """ + if not (masks.shape[0] == scores.shape[0] == labels.shape[0]): + raise ValueError("mask, scores and labels must have the same shape!") + + to_keep = labels.ne(num_labels) & (scores > object_mask_threshold) + + return masks[to_keep], scores[to_keep], labels[to_keep] + + +def check_segment_validity(mask_labels, mask_probs, k, mask_threshold=0.5, overlap_mask_area_threshold=0.8): + # Get the mask associated with the k class + mask_k = mask_labels == k + mask_k_area = mask_k.sum() + + # Compute the area of all the stuff in query k + original_mask = mask_probs[k] >= mask_threshold + original_area = original_mask.sum() + + final_mask = mask_k & original_mask + final_mask_area = final_mask.sum() + + mask_exists = mask_k_area > 0 and original_area > 0 and final_mask_area > 0 + + if mask_exists: + area_ratio = mask_k_area / original_area + if not area_ratio.item() > overlap_mask_area_threshold: + mask_exists = False + + return mask_exists, final_mask + + +def compute_segments( + mask_probs, + pred_scores, + pred_labels, + stuff_classes, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + target_size: Optional[tuple[int, int]] = None, +): + height = mask_probs.shape[1] if target_size is None else target_size[0] + width = mask_probs.shape[2] if target_size is None else target_size[1] + + segmentation = mint.zeros((height, width), dtype=ms.int64) - 1 + segments: list[dict] = [] + + # Compute per-pixel assignment based on weighted mask scores + mask_probs = mask_probs.sigmoid() + mask_labels = (pred_scores[:, None, None] * mask_probs).argmax(0) + + # Keep track of instances of each class + current_segment_id = 0 + stuff_memory_list: dict[str, int] = {} + + for k in range(pred_labels.shape[0]): + pred_class = pred_labels[k].item() + + # Check if mask exists and large enough to be a segment + mask_exists, final_mask = check_segment_validity( + mask_labels, 
mask_probs, k, mask_threshold, overlap_mask_area_threshold
+        )
+
+        if not mask_exists:
+            continue
+
+        if stuff_classes and pred_class in stuff_classes:
+            if pred_class in stuff_memory_list:
+                segmentation[final_mask] = stuff_memory_list[pred_class]
+                continue
+            else:
+                stuff_memory_list[pred_class] = current_segment_id
+
+        segmentation[final_mask] = current_segment_id
+        segment_score = round(pred_scores[k].item(), 6)
+        segments.append(
+            {
+                "id": current_segment_id,
+                "label_id": pred_class,
+                "score": segment_score,
+            }
+        )
+        current_segment_id += 1
+    return segmentation, segments
+
+
+def get_target_size(size_dict: dict[str, int]) -> tuple[int, int]:
+    """Returns the height and width from a size dict."""
+    target_height = size_dict["shortest_edge"]
+    target_width = size_dict.get("longest_edge", None) or target_height
+
+    return target_height, target_width
+
+
+class EomtImageProcessor(BaseImageProcessor):
+    r"""
+    Constructs an EoMT image processor. The image processor can be used to prepare image(s) and optional targets
+    for the model.
+
+    This image processor inherits from [`BaseImageProcessor`] which contains most of the main methods. Users should
+    refer to this superclass for more information regarding those methods.
+
+    Args:
+        do_resize (`bool`, *optional*, defaults to `True`):
+            Whether to resize the input to a certain `size`.
+        size (`dict[str, int]`, *optional*, defaults to `{"shortest_edge": 640, "longest_edge": 640}`):
+            Resize the input to the given size. Only has an effect if `do_resize` is set to `True`. The shortest
+            edge of the image is resized to `size["shortest_edge"]` while keeping the aspect ratio, and the longest
+            edge is capped at `size["longest_edge"]`.
+        resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+            An optional resampling filter. This can be one of `PIL.Image.Resampling.NEAREST`,
+            `PIL.Image.Resampling.BOX`, `PIL.Image.Resampling.BILINEAR`, `PIL.Image.Resampling.HAMMING`,
+            `PIL.Image.Resampling.BICUBIC` or `PIL.Image.Resampling.LANCZOS`. Only has an effect if `do_resize` is set
+            to `True`.
+        do_rescale (`bool`, *optional*, defaults to `True`):
+            Whether to rescale the input to a certain `scale`.
+        rescale_factor (`float`, *optional*, defaults to `1 / 255`):
+            Rescale the input by the given factor. Only has an effect if `do_rescale` is set to `True`.
+        do_normalize (`bool`, *optional*, defaults to `True`):
+            Whether or not to normalize the input with mean and standard deviation.
+        do_split_image (`bool`, *optional*, defaults to `False`):
+            Whether to split the input images into overlapping patches for semantic segmentation. If set to `True`, the
+            input images will be split into patches of size `size["shortest_edge"]` with an overlap between patches.
+            Otherwise, the input images will be padded to the target size.
+        do_pad (`bool`, *optional*, defaults to `False`):
+            Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest
+            number of patches in the batch. Padding will be applied to the bottom and right with zeros.
+        image_mean (`float` or `list[float]`, *optional*, defaults to `[0.485, 0.456, 0.406]`):
+            The sequence of means for each channel, to be used when normalizing images. Defaults to the ImageNet mean.
+        image_std (`float` or `list[float]`, *optional*, defaults to `[0.229, 0.224, 0.225]`):
+            The sequence of standard deviations for each channel, to be used when normalizing images. Defaults to the
+            ImageNet std.
+        ignore_index (`int`, *optional*):
+            Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels
+            denoted with 0 (background) will be replaced with `ignore_index`.
+        num_labels (`int`, *optional*):
+            The number of labels in the segmentation map.
+    """
+
+    model_input_names = ["pixel_values"]
+
+    def __init__(
+        self,
+        do_resize: bool = True,
+        size: Optional[dict[str, int]] = None,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        do_rescale: bool = True,
+        rescale_factor: float = 1 / 255,
+        do_normalize: bool = True,
+        do_split_image: bool = False,
+        do_pad: bool = False,
+        image_mean: Optional[Union[float, list[float]]] = None,
+        image_std: Optional[Union[float, list[float]]] = None,
+        ignore_index: Optional[int] = None,
+        num_labels: Optional[int] = None,
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        size = size if size is not None else {"shortest_edge": 640, "longest_edge": 640}
+        size = get_size_dict(size, default_to_square=False)
+
+        self.do_resize = do_resize
+        self.size = size
+        self.resample = resample
+        self.do_rescale = do_rescale
+        self.rescale_factor = rescale_factor
+        self.do_normalize = do_normalize
+        self.do_split_image = do_split_image
+        self.do_pad = do_pad
+        self.image_mean = image_mean if image_mean is not None else IMAGENET_DEFAULT_MEAN
+        self.image_std = image_std if image_std is not None else IMAGENET_DEFAULT_STD
+        self.ignore_index = ignore_index
+        self.num_labels = num_labels
+
+    def resize(
+        self,
+        image: np.ndarray,
+        size: dict,
+        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        data_format=None,
+        input_data_format: Optional[Union[str, ChannelDimension]] = None,
+        **kwargs,
+    ) -> np.ndarray:
+        """
+        Resize an image. The shortest edge of the image is resized to size["shortest_edge"], with the longest edge
+        resized to keep the input aspect ratio.
+
+        Args:
+            image (`np.ndarray`):
+                Image to resize.
+            size (`dict[str, int]`):
+                Size dict with `"shortest_edge"` and `"longest_edge"` keys describing the target size.
+            resample (`PILImageResampling`, *optional*, defaults to `PILImageResampling.BILINEAR`):
+                Resampling filter to use when resizing the image.
+            data_format (`str` or `ChannelDimension`, *optional*):
+                The channel dimension format of the image. If not provided, it will be the same as the input image.
+            input_data_format (`ChannelDimension` or `str`, *optional*):
+                The channel dimension format of the input image. If not provided, it will be inferred.
+ """ + image_size = get_image_size(image) + output_size = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"]) + + image = resize( + image=image, + size=output_size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + return_numpy=True, + **kwargs, + ) + + return image + + def _split_image(self, image: ImageInput, size: dict, image_index: int) -> tuple[list, list]: + """Slices an image into overlapping patches for semantic segmentation.""" + + patches, patch_offsets = [], [] + + image_size = get_image_size(image) + patch_size = size["shortest_edge"] + + longer_side = max(image_size) + num_patches = math.ceil(longer_side / patch_size) + total_overlap = num_patches * patch_size - longer_side + overlap_per_patch = total_overlap / (num_patches - 1) if num_patches > 1 else 0 + + for i in range(num_patches): + start = int(i * (patch_size - overlap_per_patch)) + end = start + patch_size + + if image_size[0] > image_size[1]: + patch = image[:, start:end, :] + else: + patch = image[:, :, start:end] + + patches.append(patch) + patch_offsets.append([image_index, start, end]) + + return patches, patch_offsets + + def _pad(self, image: ImageInput, size: dict) -> np.ndarray: + """Pads the image to the target size using zero padding.""" + height, width = get_image_size(image) + + target_height, target_width = get_target_size(size) + pad_h = max(0, target_height - height) + pad_w = max(0, target_width - width) + + padding = ((0, pad_h), (0, pad_w)) + + # Channel axis is last; default padding format is compatible + padded_image = pad(image=image, padding=padding, mode=PaddingMode.CONSTANT, constant_values=0.0) + return padded_image + + def _preprocess_images( + self, + images: ImageInput, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = None, + do_split_image: Optional[bool] = None, + do_pad: Optional[bool] = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + data_format: Optional[Union[str, ChannelDimension]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a batch of images.""" + images = [to_numpy_array(image) for image in images] + + if do_resize: + images = [ + self.resize( + image, + size=size, + resample=resample, + data_format=data_format, + input_data_format=input_data_format, + ) + for image in images + ] + + processed_images, patch_offsets = [], [] + + if do_split_image: + for idx, img in enumerate(images): + patches, offsets = self._split_image(img, size, idx) + processed_images.extend(patches) + patch_offsets.extend(offsets) + + images = processed_images + + if do_pad: + images = [self._pad(img, size) for img in images] + + if do_rescale: + images = [self.rescale(img, scale=rescale_factor, input_data_format=input_data_format) for img in images] + + if do_normalize: + images = [ + self.normalize( + image, + mean=image_mean, + std=image_std, + input_data_format=input_data_format, + ) + for image in images + ] + + return images, patch_offsets + + def _preprocess_mask( + self, + segmentation_map: ImageInput, + do_resize: Optional[bool] = False, + do_pad: Optional[bool] = False, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = None, + data_format: Union[str, ChannelDimension] = None, + 
input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> np.ndarray: + """Preprocesses a single mask.""" + # Add channel dimension if missing - needed for certain transformations + if segmentation_map.ndim == 2: + added_channel_dim = True + segmentation_map = segmentation_map[None, ...] + input_data_format = ChannelDimension.FIRST + else: + added_channel_dim = False + if input_data_format is None: + input_data_format = infer_channel_dimension_format(segmentation_map) + + if do_resize: + segmentation_map = self.resize( + segmentation_map, + size=size, + resample=resample, + data_format=data_format, + ) + + if do_pad: + segmentation_map = self._pad(segmentation_map, size) + + # Remove extra channel dimension if added for processing + if added_channel_dim: + segmentation_map = segmentation_map.squeeze(0) + return ms.Tensor.from_numpy(segmentation_map) + + @filter_out_non_signature_kwargs() + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[Union[list[dict[int, int]], dict[int, int]]] = None, + instance_id_to_semantic_id: Optional[dict[int, int]] = None, + do_split_image: Optional[bool] = None, + do_resize: Optional[bool] = None, + size: Optional[dict[str, int]] = None, + resample: PILImageResampling = None, + do_rescale: Optional[bool] = None, + rescale_factor: Optional[float] = None, + do_normalize: Optional[bool] = None, + do_pad: Optional[bool] = None, + image_mean: Optional[Union[float, list[float]]] = None, + image_std: Optional[Union[float, list[float]]] = None, + ignore_index: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + data_format: Union[str, ChannelDimension] = ChannelDimension.FIRST, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ) -> BatchFeature: + """ + Preprocesses images or a batch of images. + + Args: + images (`ImageInput`): + Image or batch of images to preprocess. + segmentation_maps (`ImageInput`, *optional*): + The corresponding semantic segmentation maps with the pixel-wise annotations. + instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*): + A mapping between object instance ids and class ids. + do_split_image (`bool`, *optional*, defaults to `self.do_split_image`): + Whether to split the input images into overlapping patches for semantic segmentation. + do_resize (`bool`, *optional*, defaults to `self.do_resize`): + Whether to resize the input images. + size (`dict[str, int]`, *optional*, defaults to `self.size`): + Target size as a dictionary with `"shortest_edge"` and `"longest_edge"` keys. + resample (`PILImageResampling`, *optional*, defaults to `self.resample`): + Resampling filter to use when resizing. + do_rescale (`bool`, *optional*, defaults to `self.do_rescale`): + Whether to rescale the input images by `rescale_factor`. + rescale_factor (`float`, *optional*, defaults to `self.rescale_factor`): + Factor to scale image pixel values. + do_normalize (`bool`, *optional*, defaults to `self.do_normalize`): + Whether to normalize the input images. + do_pad (`bool`, *optional*, defaults to `False`): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + image_mean (`float` or `list[float]`, *optional*, defaults to `self.image_mean`): + Mean for normalization. Single value or list for each channel. 
+ image_std (`float` or `list[float]`, *optional*, defaults to `self.image_std`): + Standard deviation for normalization. Single value or list for each channel. + ignore_index (`int`, *optional*): + Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels + denoted with 0 (background) will be replaced with `ignore_index`. + return_tensors (`str` or `TensorType`, *optional*): + The type of tensors to return. Can be `"ms"` or `"np"`. + data_format (`ChannelDimension` or `str`, *optional*, defaults to `ChannelDimension.FIRST`): + Channel format of the output image. Either `"channels_first"` or `"channels_last"`. + input_data_format (`ChannelDimension` or `str`, *optional*): + Channel format of the input image. + """ + + do_split_image = do_split_image if do_split_image is not None else self.do_split_image + do_resize = do_resize if do_resize is not None else self.do_resize + size = size if size is not None else self.size + size = get_size_dict(size, default_to_square=False) + resample = resample if resample is not None else self.resample + do_rescale = do_rescale if do_rescale is not None else self.do_rescale + rescale_factor = rescale_factor if rescale_factor is not None else self.rescale_factor + do_normalize = do_normalize if do_normalize is not None else self.do_normalize + do_pad = do_pad if do_pad is not None else self.do_pad + image_mean = image_mean if image_mean is not None else self.image_mean + image_std = image_std if image_std is not None else self.image_std + ignore_index = ignore_index if ignore_index is not None else self.ignore_index + + images = make_flat_list_of_images(images) + + if not valid_images(images): + raise ValueError( + "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, " + "ms.Tensor, tf.Tensor or jax.ndarray." 
+ ) + + validate_preprocess_arguments( + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + do_resize=do_resize, + size=size, + resample=resample, + ) + + pixel_values_list, patch_offsets = self._preprocess_images( + images=images, + do_resize=do_resize, + size=size, + resample=resample, + do_split_image=do_split_image, + do_pad=do_pad, + do_rescale=do_rescale, + rescale_factor=rescale_factor, + do_normalize=do_normalize, + image_mean=image_mean, + image_std=image_std, + data_format=data_format, + input_data_format=input_data_format, + ) + + if segmentation_maps is not None: + segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2) + segmentation_maps = [to_numpy_array(mask) for mask in segmentation_maps] + + segmentation_maps = [ + self._preprocess_mask( + segmentation_map, + do_resize=do_resize, + do_pad=do_pad, + size=size, + resample=PILImageResampling.NEAREST, + data_format=data_format, + input_data_format=input_data_format, + ) + for segmentation_map in segmentation_maps + ] + + encoded_inputs = self.encode_inputs( + pixel_values_list, + segmentation_maps, + instance_id_to_semantic_id, + ignore_index, + return_tensors, + input_data_format=data_format, + ) + + if do_split_image and patch_offsets: + encoded_inputs["patch_offsets"] = [ms.tensor(offsets) for offsets in patch_offsets] + + return encoded_inputs + + def encode_inputs( + self, + pixel_values_list: list[ImageInput], + segmentation_maps: ImageInput = None, + instance_id_to_semantic_id: Optional[Union[list[dict[int, int]], dict[int, int]]] = None, + ignore_index: Optional[int] = None, + return_tensors: Optional[Union[str, TensorType]] = None, + input_data_format: Optional[Union[str, ChannelDimension]] = None, + ): + """ + Pad images up to the largest image in a batch and create a corresponding `pixel_mask`. + + EoMT addresses semantic segmentation with a mask classification paradigm, thus input segmentation maps + will be converted to lists of binary masks and their respective labels. Let's see an example, assuming + `segmentation_maps = [[2,6,7,9]]`, the output will contain `mask_labels = + [[1,0,0,0],[0,1,0,0],[0,0,1,0],[0,0,0,1]]` (four binary masks) and `class_labels = [2,6,7,9]`, the labels for + each mask. + + Args: + pixel_values_list (`list[ImageInput]`): + list of images (pixel values) to be padded. Each image should be a tensor of shape `(channels, height, + width)`. + + segmentation_maps (`ImageInput`, *optional*): + The corresponding semantic segmentation maps with the pixel-wise annotations. + + (`bool`, *optional*, defaults to `True`): + Whether or not to pad images up to the largest image in a batch and create a pixel mask. + + If left to the default, will return a pixel mask that is: + + - 1 for pixels that are real (i.e. **not masked**), + - 0 for pixels that are padding (i.e. **masked**). + + instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*): + A mapping between object instance ids and class ids. If passed, `segmentation_maps` is treated as an + instance segmentation map where each pixel represents an instance id. Can be provided as a single + dictionary with a global/dataset-level mapping or as a list of dictionaries (one per image), to map + instance ids in each image separately. + + return_tensors (`str` or [`~file_utils.TensorType`], *optional*): + If set, will return tensors instead of NumPy arrays. If set to `'ms'`, return MindSpore `ms.Tensor` + objects. 
+ + input_data_format (`ChannelDimension` or `str`, *optional*): + The channel dimension format of the input image. If not provided, it will be inferred. + + Returns: + [`BatchFeature`]: A [`BatchFeature`] with the following fields: + + - **pixel_values** -- Pixel values to be fed to a model. + - **mask_labels** -- Optional list of mask labels of shape `(labels, height, width)` to be fed to a model + (when `annotations` are provided). + - **class_labels** -- Optional list of class labels of shape `(labels)` to be fed to a model (when + `annotations` are provided). They identify the labels of `mask_labels`, e.g. the label of + `mask_labels[i][j]` if `class_labels[i][j]`. + """ + ignore_index = self.ignore_index if ignore_index is None else ignore_index + + pixel_values_list = [to_numpy_array(pixel_values) for pixel_values in pixel_values_list] + + if input_data_format is None: + input_data_format = infer_channel_dimension_format(pixel_values_list[0]) + + encoded_inputs = BatchFeature({"pixel_values": pixel_values_list}, tensor_type=return_tensors) + + if segmentation_maps is not None: + mask_labels = [] + class_labels = [] + # Convert to list of binary masks and labels + for idx, segmentation_map in enumerate(segmentation_maps): + segmentation_map = to_numpy_array(segmentation_map) + if isinstance(instance_id_to_semantic_id, list): + instance_id = instance_id_to_semantic_id[idx] + else: + instance_id = instance_id_to_semantic_id + # Use instance2class_id mapping per image + masks, classes = convert_segmentation_map_to_binary_masks( + segmentation_map, + instance_id, + ignore_index=ignore_index, + ) + + mask_labels.append(ms.Tensor.from_numpy(masks)) + class_labels.append(ms.Tensor.from_numpy(classes)) + + # we cannot batch them since they don't share a common class size + encoded_inputs["mask_labels"] = mask_labels + encoded_inputs["class_labels"] = class_labels + + return encoded_inputs + + def merge_image_patches( + self, + segmentation_logits: ms.Tensor, + patch_offsets: list[tuple[int, int, int]], + target_sizes: list[tuple[int, int]], + size: dict[str, int], + ) -> list[ms.Tensor]: + """ + Reconstructs full-size semantic segmentation logits from patch predictions. + + Args: + segmentation_logits (`ms.Tensor`): + A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits + for each image patch. + patch_offsets (`list[tuple[int, int, int]]`): + A list of tuples where each tuple contains: + - `image_index` (int): Index of the original image this patch belongs to. + - `start` (int): Start pixel index of the patch along the long dimension (height or width). + - `end` (int): End pixel index of the patch along the long dimension. + target_sizes (`list[tuple[int, int]]`): + list of original (height, width) dimensions for each image before preprocessing. + size (`dict[str, int]`): + A size dict which was used to resize. 
+ """ + num_classes = segmentation_logits.shape[1] + aggregated_logits = [] + patch_counts = [] + + for image_size in target_sizes: + height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"]) + aggregated_logits.append(mint.zeros((num_classes, height, width))) + patch_counts.append(mint.zeros((num_classes, height, width))) + + # Stitch patches back into full-sized logit maps + for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets): + if target_sizes[image_idx][0] > target_sizes[image_idx][1]: + aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx] + patch_counts[image_idx][:, patch_start:patch_end, :] += 1 + else: + aggregated_logits[image_idx][:, :, patch_start:patch_end] += segmentation_logits[patch_idx] + patch_counts[image_idx][:, :, patch_start:patch_end] += 1 + + # Normalize and resize logits to original image size + reconstructed_logits = [] + for idx, (logit_sum, count) in enumerate(zip(aggregated_logits, patch_counts)): + averaged_logits = logit_sum / count.clamp(min=1) + resized_logits = F.interpolate( + averaged_logits[None, ...], + size=target_sizes[idx], + mode="bilinear", + align_corners=False, + )[0] + + reconstructed_logits.append(resized_logits) + + return reconstructed_logits + + def unpad_image( + self, + segmentation_logits: ms.Tensor, + target_sizes: list[tuple[int, int]], + size: dict[str, int], + ) -> list[ms.Tensor]: + """Restores panoptic segmentation logits to their original image resolutions.""" + + resized_logits = [] + + for idx, original_size in enumerate(target_sizes): + target_height, target_width = get_size_with_aspect_ratio( + original_size, size["shortest_edge"], size["longest_edge"] + ) + cropped_logits = segmentation_logits[idx][:, :target_height, :target_width] + upsampled_logits = F.interpolate( + cropped_logits[None, ...], size=original_size, mode="bilinear", align_corners=False + )[0] + resized_logits.append(upsampled_logits) + return resized_logits + + def post_process_semantic_segmentation( + self, + outputs, + target_sizes: list[tuple[int, int]], + size: Optional[dict[str, int]] = None, + ) -> np.ndarray: + """Post-processes model outputs into final semantic segmentation prediction.""" + + size = size if size is not None else self.size + + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + patch_offsets = outputs.patch_offsets + + output_size = get_target_size(size) + masks_queries_logits = F.interpolate( + masks_queries_logits, + size=output_size, + mode="bilinear", + ) + + # Remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(axis=-1)[..., :-1] + masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + segmentation_logits = mint.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + + output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size) + + preds = [logit.argmax(dim=0) for logit in output_logits] + return preds + + def post_process_panoptic_segmentation( + self, + outputs, + target_sizes: list[tuple[int, int]], + threshold: float = 0.8, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + stuff_classes: Optional[list[int]] = None, + size: Optional[dict[str, int]] = None, + ): + """Post-processes model outputs into final panoptic segmentation prediction.""" + + size = 
size if size is not None else self.size + + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + output_size = get_target_size(size) + masks_queries_logits = F.interpolate( + masks_queries_logits, + size=output_size, + mode="bilinear", + ) + + mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size) + pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(axis=-1).max(-1) + + results: list = [] + + for i in range(batch_size): + mask_probs, pred_scores, pred_labels = remove_low_and_no_objects( + mask_probs_batch[i], pred_scores_batch[i], pred_labels_batch[i], threshold, num_labels + ) + + # No mask found + if mask_probs.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:] + segmentation = mint.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + segmentation, segments = compute_segments( + mask_probs=mask_probs, + pred_scores=pred_scores, + pred_labels=pred_labels, + stuff_classes=stuff_classes, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + target_size=target_sizes[i] if target_sizes is not None else None, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + @filter_out_non_signature_kwargs() + def post_process_instance_segmentation( + self, + outputs, + target_sizes: list[tuple[int, int]], + threshold: float = 0.5, + size: Optional[dict[str, int]] = None, + ): + """Post-processes model outputs into Instance Segmentation Predictions.""" + + size = size if size is not None else self.size + + class_queries_logits = outputs.class_queries_logits + masks_queries_logits = outputs.masks_queries_logits + + output_size = get_target_size(size) + masks_queries_logits = F.interpolate( + masks_queries_logits, + size=output_size, + mode="bilinear", + ) + + mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size) + + batch_size = class_queries_logits.shape[0] + num_queries = class_queries_logits.shape[-2] + + results = [] + + for i in range(batch_size): + mask_pred = mask_probs_batch[i] + mask_class = class_queries_logits[i] + + # Remove the null class `[..., :-1]` + scores, pred_classes = mask_class.softmax(axis=-1)[..., :-1].max(-1) + pred_masks = (mask_pred > 0).float() + + # Calculate average mask prob + mask_scores = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / ( + pred_masks.flatten(1).sum(1) + 1e-6 + ) + pred_scores = scores * mask_scores + + segmentation = mint.zeros(target_sizes[i]) - 1 + + instance_maps, segments = [], [] + current_segment_id = 0 + for j in range(num_queries): + score = pred_scores[j].item() + + if not mint.all(pred_masks[j] == 0) and score >= threshold: + segmentation[pred_masks[j] == 1] = current_segment_id + segments.append( + { + "id": current_segment_id, + "label_id": pred_classes[j].item(), + "score": round(score, 6), + } + ) + current_segment_id += 1 + instance_maps.append(pred_masks[j]) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + +__all__ = ["EomtImageProcessor"] diff --git a/mindone/transformers/models/eomt/image_processing_eomt_fast.py b/mindone/transformers/models/eomt/image_processing_eomt_fast.py 
new file mode 100644 index 0000000000..584f3abb43 --- /dev/null +++ b/mindone/transformers/models/eomt/image_processing_eomt_fast.py @@ -0,0 +1,543 @@ +# coding=utf-8 +# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Fast Image processor class for EoMT.""" + +import math +from typing import Optional, Union + +import numpy as np + +from ...image_processing_utils import BatchFeature +from ...image_processing_utils_fast import ( + BaseImageProcessorFast, + DefaultFastImageProcessorKwargs, + group_images_by_shape, + reorder_images, +) +from ...image_utils import ( + IMAGENET_DEFAULT_MEAN, + IMAGENET_DEFAULT_STD, + ChannelDimension, + ImageInput, + PILImageResampling, + SizeDict, + pil_mindspore_interpolation_mapping, +) +from ...processing_utils import Unpack +from ...utils import TensorType, filter_out_non_signature_kwargs, is_mindspore_available +from .image_processing_eomt import ( + compute_segments, + convert_segmentation_map_to_binary_masks, + get_size_with_aspect_ratio, + remove_low_and_no_objects, +) + +if is_mindspore_available(): + import mindspore as ms + from mindspore import mint + from mindspore.dataset.vision import Inter as InterpolationMode + + +class EomtImageProcessorFastKwargs(DefaultFastImageProcessorKwargs): + """ + do_split_image (`bool`, *optional*, defaults to `False`): + Whether to split the input images into overlapping patches for semantic segmentation. If set to `True`, the + input images will be split into patches of size `size["shortest_edge"]` with an overlap between patches. + Otherwise, the input images will be padded to the target size. + do_pad (`bool`, *optional*, defaults to `False`): + Whether to pad the image. If `True`, will pad the patch dimension of the images in the batch to the largest + number of patches in the batch. Padding will be applied to the bottom and right with zeros. + ignore_index (`int`, *optional*): + Label to be assigned to background pixels in segmentation maps. If provided, segmentation map pixels + denoted with 0 (background) will be replaced with `ignore_index`. 
+ """ + + do_split_image: bool + do_pad: bool + ignore_index: Optional[int] = None + + +def get_target_size(size_dict: dict[str, int]) -> tuple[int, int]: + """Returns the height and width from a size dict.""" + target_height = size_dict["shortest_edge"] + target_width = size_dict["longest_edge"] or target_height + + return target_height, target_width + + +def reorder_patches_and_offsets( + patches: list[ms.Tensor], offsets: list[list[int]] +) -> tuple[list[ms.Tensor], list[list[int]]]: + """Sorts patches and offsets according to the original image index.""" + + combined = list(zip(offsets, patches)) + combined.sort(key=lambda x: x[0][0]) + sorted_offsets, sorted_patches = zip(*combined) + + return list(sorted_patches), list(sorted_offsets) + + +class EomtImageProcessorFast(BaseImageProcessorFast): + resample = PILImageResampling.BILINEAR + image_mean = IMAGENET_DEFAULT_MEAN + image_std = IMAGENET_DEFAULT_STD + size = {"shortest_edge": 640, "longest_edge": 640} + default_to_square = False + do_resize = True + do_rescale = True + do_normalize = True + do_split_image = False + do_pad = False + ignore_index = None + valid_kwargs = EomtImageProcessorFastKwargs + + def __init__(self, **kwargs: Unpack[EomtImageProcessorFastKwargs]): + super().__init__(**kwargs) + + def _split_image(self, images: ms.Tensor, size: dict, image_indices: int) -> tuple[list, list]: + """Slices an image into overlapping patches for semantic segmentation.""" + + patches, patch_offsets = [], [] + + _, _, height, width = images.shape + patch_size = size["shortest_edge"] + + longer_side = max(height, width) + num_patches = math.ceil(longer_side / patch_size) + total_overlap = num_patches * patch_size - longer_side + overlap_per_patch = total_overlap / (num_patches - 1) if num_patches > 1 else 0 + + for i in range(num_patches): + start = int(i * (patch_size - overlap_per_patch)) + end = start + patch_size + + if height > width: + batch_patch = images[:, :, start:end, :] + else: + batch_patch = images[:, :, :, start:end] + + for batch_idx, single in enumerate(mint.unbind(batch_patch, dim=0)): + patches.append(single) + patch_offsets.append([image_indices[batch_idx], start, end]) + + return patches, patch_offsets + + def _pad(self, images: ms.Tensor, size: dict) -> ms.Tensor: + """Pads the image to the target size using zero padding.""" + _, _, height, width = images.shape + + target_height, target_width = get_target_size(size) + pad_h = max(0, target_height - height) + pad_w = max(0, target_width - width) + padding = (0, pad_w, 0, pad_h) + + padded_images = mint.nn.functional.pad(images, padding, mode="constant", value=0.0) + return padded_images + + def preprocess( + self, + images: ImageInput, + segmentation_maps: Optional[list[ms.Tensor]] = None, + instance_id_to_semantic_id: Optional[dict[int, int]] = None, + **kwargs: Unpack[EomtImageProcessorFastKwargs], + ) -> BatchFeature: + r""" + segmentation_maps (`ImageInput`, *optional*): + The segmentation maps to preprocess for corresponding images. + instance_id_to_semantic_id (`list[dict[int, int]]` or `dict[int, int]`, *optional*): + A mapping between object instance ids and class ids. 
+ """ + return super().preprocess(images, segmentation_maps, instance_id_to_semantic_id, **kwargs) + + def _preprocess_image_like_inputs( + self, + images: ImageInput, + segmentation_maps: Optional[ImageInput], + instance_id_to_semantic_id: Optional[dict[int, int]], + do_convert_rgb: bool, + input_data_format: ChannelDimension, + **kwargs: Unpack[EomtImageProcessorFastKwargs], + ) -> BatchFeature: + """ + Preprocess image-like inputs. + """ + images = self._prepare_image_like_inputs( + images=images, do_convert_rgb=do_convert_rgb, input_data_format=input_data_format + ) + ignore_index = kwargs.pop("ignore_index", None) + images_kwargs = kwargs.copy() + processed_images, patch_offsets = self._preprocess(images, **images_kwargs) + outputs = BatchFeature({"pixel_values": processed_images}) + + if segmentation_maps is not None: + processed_segmentation_maps = self._prepare_image_like_inputs( + images=segmentation_maps, + expected_ndims=2, + do_convert_rgb=False, + input_data_format=ChannelDimension.FIRST, + ) + + segmentation_maps_kwargs = kwargs.copy() + segmentation_maps_kwargs.update( + { + "do_normalize": False, + "do_rescale": False, + # Nearest interpolation is used for segmentation maps instead of BILINEAR. + "interpolation": pil_mindspore_interpolation_mapping[PILImageResampling.NEAREST], + } + ) + + processed_segmentation_maps, _ = self._preprocess( + images=processed_segmentation_maps, **segmentation_maps_kwargs + ) + processed_segmentation_maps = processed_segmentation_maps.squeeze(1).to(ms.int64) + # Convert to list of binary masks and labels + mask_labels, class_labels = [], [] + for idx, segmentation_map in enumerate(processed_segmentation_maps): + if isinstance(instance_id_to_semantic_id, list): + instance_id = instance_id_to_semantic_id[idx] + else: + instance_id = instance_id_to_semantic_id + # Use instance2class_id mapping per image + masks, classes = convert_segmentation_map_to_binary_masks( + segmentation_map, + instance_id, + ignore_index=ignore_index, + ) + + mask_labels.append(ms.Tensor.from_numpy(masks)) + class_labels.append(ms.Tensor.from_numpy(classes)) + + # we cannot batch them since they don't share a common class size + outputs["mask_labels"] = mask_labels + outputs["class_labels"] = class_labels + + if patch_offsets: + outputs["patch_offsets"] = [ms.tensor(offsets) for offsets in patch_offsets] + + return outputs + + def _preprocess( + self, + images: list["ms.Tensor"], + do_resize: bool, + size: SizeDict, + interpolation: Optional["InterpolationMode"], + do_rescale: bool, + rescale_factor: float, + do_normalize: bool, + do_split_image: bool, + do_pad: bool, + image_mean: Optional[Union[float, list[float]]], + image_std: Optional[Union[float, list[float]]], + return_tensors: Optional[Union[str, TensorType]], + **kwargs, + ): + """Preprocesses the input images and masks if provided.""" + processed_images, patch_offsets = [], [] + + grouped_images, grouped_images_index = group_images_by_shape(images) + resized_images_grouped = {} + + for shape, stacked_images in grouped_images.items(): + if do_resize: + stacked_images_updated = [] + for i in range(len(stacked_images)): + stacked_images_updated.append( + self.resize( + image=stacked_images[i], + size=size, + interpolation=interpolation, + ) + ) + stacked_images_updated = mint.stack(stacked_images_updated) + resized_images_grouped[shape] = stacked_images_updated + images = reorder_images(resized_images_grouped, grouped_images_index) + + # Group images by size for batched resizing, Needed in case do_resize is False. 
+ grouped_images, grouped_images_index = group_images_by_shape(images) + processed_images_grouped = {} + + for shape, stacked_images in grouped_images.items(): + original_indices = [ + original_idx for original_idx, (img_shape, _) in grouped_images_index.items() if img_shape == shape + ] + + if do_split_image: + patches, offsets = self._split_image(stacked_images, size, original_indices) + processed_images.extend(patches) + patch_offsets.extend(offsets) + + if do_pad: + stacked_images = self._pad(stacked_images, size) + processed_images_grouped[shape] = stacked_images + + if do_split_image: + images, patch_offsets = reorder_patches_and_offsets(processed_images, patch_offsets) + + if do_pad: + images = reorder_images(processed_images_grouped, grouped_images_index) + + grouped_images, grouped_images_index = group_images_by_shape(images) + processed_images_grouped = {} + + for shape, stacked_images in grouped_images.items(): + stacked_images = self.rescale_and_normalize( + stacked_images, do_rescale, rescale_factor, do_normalize, image_mean, image_std + ) + processed_images_grouped[shape] = stacked_images + images = reorder_images(processed_images_grouped, grouped_images_index) + + processed_images = mint.stack(images, dim=0) if return_tensors else images + + return processed_images, patch_offsets + + def merge_image_patches( + self, + segmentation_logits: ms.Tensor, + patch_offsets: list[tuple[int, int, int]], + target_sizes: list[tuple[int, int]], + size: dict[str, int], + ) -> list[ms.Tensor]: + """ + Reconstructs full-size semantic segmentation logits from patch predictions. + + Args: + segmentation_logits (`ms.Tensor`): + A tensor of shape `(num_patches, num_classes, patch_height, patch_width)` representing predicted logits + for each image patch. + patch_offsets (`list[tuple[int, int, int]]`): + A list of tuples where each tuple contains: + - `image_index` (int): Index of the original image this patch belongs to. + - `start` (int): Start pixel index of the patch along the long dimension (height or width). + - `end` (int): End pixel index of the patch along the long dimension. + target_sizes (`list[tuple[int, int]]`): + list of original (height, width) dimensions for each image before preprocessing. + size (`dict[str, int]`): + A size dict which was used to resize. 
+ """ + num_classes = segmentation_logits.shape[1] + aggregated_logits = [] + patch_counts = [] + + for image_size in target_sizes: + height, width = get_size_with_aspect_ratio(image_size, size["shortest_edge"], size["longest_edge"]) + aggregated_logits.append(mint.zeros((num_classes, height, width))) + patch_counts.append(mint.zeros((num_classes, height, width))) + + # Stitch patches back into full-sized logit maps + for patch_idx, (image_idx, patch_start, patch_end) in enumerate(patch_offsets): + if target_sizes[image_idx][0] > target_sizes[image_idx][1]: + aggregated_logits[image_idx][:, patch_start:patch_end, :] += segmentation_logits[patch_idx] + patch_counts[image_idx][:, patch_start:patch_end, :] += 1 + else: + aggregated_logits[image_idx][:, :, patch_start:patch_end] += segmentation_logits[patch_idx] + patch_counts[image_idx][:, :, patch_start:patch_end] += 1 + + # Normalize and resize logits to original image size + reconstructed_logits = [] + for idx, (logit_sum, count) in enumerate(zip(aggregated_logits, patch_counts)): + averaged_logits = logit_sum / count.clamp(min=1) + resized_logits = mint.nn.functional.interpolate( + averaged_logits[None, ...], + size=target_sizes[idx], + mode="bilinear", + align_corners=False, + )[0] + + reconstructed_logits.append(resized_logits) + + return reconstructed_logits + + def unpad_image( + self, + segmentation_logits: ms.Tensor, + target_sizes: list[tuple[int, int]], + size: dict[str, int], + ) -> list[ms.Tensor]: + """Restores panoptic segmentation logits to their original image resolutions.""" + + resized_logits = [] + + for idx, original_size in enumerate(target_sizes): + target_height, target_width = get_size_with_aspect_ratio( + original_size, size["shortest_edge"], size["longest_edge"] + ) + cropped_logits = segmentation_logits[idx][:, :target_height, :target_width] + upsampled_logits = mint.nn.functional.interpolate( + cropped_logits[None, ...], size=original_size, mode="bilinear", align_corners=False + )[0] + resized_logits.append(upsampled_logits) + return resized_logits + + def post_process_semantic_segmentation( + self, + outputs, + target_sizes: list[tuple[int, int]], + size: Optional[dict[str, int]] = None, + ) -> np.ndarray: + """Post-processes model outputs into final semantic segmentation prediction.""" + + size = size if size is not None else self.size + + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + patch_offsets = outputs.patch_offsets + + output_size = get_target_size(size) + masks_queries_logits = mint.nn.functional.interpolate( + masks_queries_logits, + size=output_size, + mode="bilinear", + ) + + # Remove the null class `[..., :-1]` + masks_classes = class_queries_logits.softmax(axis=-1)[..., :-1] + masks_probs = masks_queries_logits.sigmoid() # [batch_size, num_queries, height, width] + + segmentation_logits = mint.einsum("bqc, bqhw -> bchw", masks_classes, masks_probs) + + output_logits = self.merge_image_patches(segmentation_logits, patch_offsets, target_sizes, size) + + preds = [logit.argmax(dim=0) for logit in output_logits] + return preds + + def post_process_panoptic_segmentation( + self, + outputs, + target_sizes: list[tuple[int, int]], + threshold: float = 0.8, + mask_threshold: float = 0.5, + overlap_mask_area_threshold: float = 0.8, + stuff_classes: Optional[list[int]] = None, + size: Optional[dict[str, int]] = None, + ): + """Post-processes model outputs into 
final panoptic segmentation prediction.""" + + size = size if size is not None else self.size + + masks_queries_logits = outputs.masks_queries_logits # [batch_size, num_queries, height, width] + class_queries_logits = outputs.class_queries_logits # [batch_size, num_queries, num_classes+1] + + batch_size = class_queries_logits.shape[0] + num_labels = class_queries_logits.shape[-1] - 1 + + output_size = get_target_size(size) + masks_queries_logits = mint.nn.functional.interpolate( + masks_queries_logits, + size=output_size, + mode="bilinear", + ) + + mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size) + pred_scores_batch, pred_labels_batch = class_queries_logits.softmax(axis=-1).max(-1) + + results: list = [] + + for i in range(batch_size): + mask_probs, pred_scores, pred_labels = remove_low_and_no_objects( + mask_probs_batch[i], pred_scores_batch[i], pred_labels_batch[i], threshold, num_labels + ) + + # No mask found + if mask_probs.shape[0] <= 0: + height, width = target_sizes[i] if target_sizes is not None else mask_probs.shape[1:] + segmentation = mint.zeros((height, width)) - 1 + results.append({"segmentation": segmentation, "segments_info": []}) + continue + + segmentation, segments = compute_segments( + mask_probs=mask_probs, + pred_scores=pred_scores, + pred_labels=pred_labels, + stuff_classes=stuff_classes, + mask_threshold=mask_threshold, + overlap_mask_area_threshold=overlap_mask_area_threshold, + target_size=target_sizes[i] if target_sizes is not None else None, + ) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + @filter_out_non_signature_kwargs() + def post_process_instance_segmentation( + self, + outputs, + target_sizes: list[tuple[int, int]], + threshold: float = 0.8, + size: Optional[dict[str, int]] = None, + ): + """Post-processes model outputs into Instance Segmentation Predictions.""" + + size = size if size is not None else self.size + + masks_queries_logits = outputs.masks_queries_logits + class_queries_logits = outputs.class_queries_logits + + output_size = get_target_size(size) + masks_queries_logits = mint.nn.functional.interpolate( + masks_queries_logits, + size=output_size, + mode="bilinear", + ) + + mask_probs_batch = self.unpad_image(masks_queries_logits, target_sizes, size) + + batch_size = class_queries_logits.shape[0] + num_queries = class_queries_logits.shape[-2] + + results = [] + + for i in range(batch_size): + mask_pred = mask_probs_batch[i] + mask_class = class_queries_logits[i] + + # Remove the null class `[..., :-1]` + scores, pred_classes = mask_class.softmax(axis=-1)[..., :-1].max(-1) + pred_masks = (mask_pred > 0).float() + + # Calculate average mask prob + mask_scores = (mask_pred.sigmoid().flatten(1) * pred_masks.flatten(1)).sum(1) / ( + pred_masks.flatten(1).sum(1) + 1e-6 + ) + pred_scores = scores * mask_scores + + segmentation = mint.zeros(target_sizes[i]) - 1 + + instance_maps, segments = [], [] + current_segment_id = 0 + for j in range(num_queries): + score = pred_scores[j].item() + + if not mint.all(pred_masks[j] == 0) and score >= threshold: + segmentation[pred_masks[j] == 1] = current_segment_id + segments.append( + { + "id": current_segment_id, + "label_id": pred_classes[j].item(), + "score": round(score, 6), + } + ) + current_segment_id += 1 + instance_maps.append(pred_masks[j]) + + results.append({"segmentation": segmentation, "segments_info": segments}) + return results + + +__all__ = ["EomtImageProcessorFast"] diff --git 
a/mindone/transformers/models/eomt/modeling_eomt.py b/mindone/transformers/models/eomt/modeling_eomt.py new file mode 100644 index 0000000000..6dc4ef8b25 --- /dev/null +++ b/mindone/transformers/models/eomt/modeling_eomt.py @@ -0,0 +1,1193 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/eomt/modular_eomt.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_eomt.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Mobile Perception Systems Lab at TU/e and The HuggingFace Inc. team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections.abc +from dataclasses import dataclass +from typing import Callable, Optional, Union + +import numpy as np +from transformers.models.eomt.configuration_eomt import EomtConfig + +import mindspore as ms +import mindspore.mint.nn.functional as F +from mindspore import Tensor, mint, nn + +from ...activations import ACT2FN +from ...modeling_layers import GradientCheckpointingLayer +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...utils import ModelOutput, can_return_tuple, is_scipy_available, requires_backends + +if is_scipy_available(): + from scipy.optimize import linear_sum_assignment + + +@dataclass +class EomtForUniversalSegmentationOutput(ModelOutput): + r""" + loss (`torch.Tensor`, *optional*): + The computed loss, returned when labels are present. + class_queries_logits (`torch.FloatTensor`): + A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each + query. Note the `+ 1` is needed because we incorporate the null class. + masks_queries_logits (`torch.FloatTensor`): + A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each + query. + last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + Last hidden states (final feature map) of the last layer. + hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`): + Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of + shape `(batch_size, sequence_length, hidden_size)`. Hidden-states all layers of the model. + attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`): + Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length, + sequence_length)`. Self and Cross Attentions weights from transformer decoder. 
+ patch_offsets (`list[torch.Tensor]`, *optional*): + list of tuples indicating the image index and start and end positions of patches for semantic segementation. + """ + + loss: Optional[ms.Tensor] = None + class_queries_logits: Optional[ms.Tensor] = None + masks_queries_logits: Optional[ms.Tensor] = None + last_hidden_state: Optional[ms.Tensor] = None + hidden_states: Optional[tuple[ms.Tensor]] = None + attentions: Optional[tuple[ms.Tensor]] = None + patch_offsets: Optional[list[ms.Tensor]] = None + + +# Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py +def sample_point(input_features: ms.Tensor, point_coordinates: ms.Tensor, add_dim=False, **kwargs) -> ms.Tensor: + """ + A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors. + + Args: + input_features (`torch.Tensor` of shape (batch_size, channels, height, width)): + A tensor that contains features map on a height * width grid + point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width,: + 2)): + A tensor that contains [0, 1] * [0, 1] normalized point coordinates + add_dim (`bool`): + boolean value to keep track of added dimension + + Returns: + point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels, + height_grid, width_grid): + A tensor that contains features for points in `point_coordinates`. + """ + if point_coordinates.dim() == 3: + add_dim = True + point_coordinates = point_coordinates.unsqueeze(2) + + # use mint.nn.functional.grid_sample to get features for points in `point_coordinates` via bilinear interpolation + point_features = mint.nn.functional.grid_sample(input_features.float(), 2.0 * point_coordinates - 1.0, **kwargs).to( + input_features.dtype + ) + if add_dim: + point_features = point_features.squeeze(3) + + return point_features + + +def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor: + """ + A pair wise version of the dice loss, see `dice_loss` for usage. + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + `torch.Tensor`: The computed loss between each pairs. + """ + inputs = inputs.sigmoid().flatten(1) + numerator = 2 * mint.matmul(inputs, labels.T) + # using broadcasting to get a [num_queries, NUM_CLASSES] matrix + denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :] + loss = 1 - (numerator + 1) / (denominator + 1) + return loss + + +def pair_wise_sigmoid_cross_entropy_loss(inputs: ms.Tensor, labels: ms.Tensor) -> ms.Tensor: + r""" + A pair wise version of the cross entropy loss, see `sigmoid_cross_entropy_loss` for usage. + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + loss (`torch.Tensor`): The computed loss between each pairs. 
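+
+    Example (a shape-only sketch with random tensors; the sizes are illustrative, not taken from a real model):
+
+        >>> from mindspore import mint
+        >>> inputs = mint.rand(100, 12544)                  # (num_queries, num_points) mask logits
+        >>> labels = (mint.rand(7, 12544) > 0.5).float()    # (num_targets, num_points) binary masks
+        >>> pair_wise_sigmoid_cross_entropy_loss(inputs, labels).shape
+        (100, 7)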
+ """ + + height_and_width = inputs.shape[1] + + criterion = mint.nn.BCEWithLogitsLoss(reduction="none") + cross_entropy_loss_pos = criterion(inputs, mint.ones_like(inputs)) + cross_entropy_loss_neg = criterion(inputs, mint.zeros_like(inputs)) + + loss_pos = mint.matmul(cross_entropy_loss_pos / height_and_width, labels.T) + loss_neg = mint.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T) + loss = loss_pos + loss_neg + return loss + + +# Adapted from https://github.com/facebookresearch/Eomt/blob/main/eomt/modeling/matcher.py +class EomtHungarianMatcher(nn.Cell): + """This class computes an assignment between the labels and the predictions of the network. + + For efficiency reasons, the labels don't include the no_object. Because of this, in general, there are more + predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are + un-matched (and thus treated as non-objects). + """ + + def __init__( + self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544 + ): + """Creates the matcher + + Params: + cost_class (`float`, *optional*, defaults to 1.0): + Relative weight of the classification error in the matching cost. + cost_mask (`float`, *optional*, defaults to 1.0): + This is the relative weight of the focal loss of the binary mask in the matching cost. + cost_dice (`float`, *optional*, defaults to 1.0): + This is the relative weight of the dice loss of the binary mask in the matching cost. + num_points (`int`, *optional*, defaults to 12544): + No. of points to sample on which the mask loss will be calculated. The same set of K points are + uniformly sampled for all prediction and ground truth masks to construct the cost matrix for bipartite + matching. + """ + super().__init__() + if cost_class == 0 and cost_mask == 0 and cost_dice == 0: + raise ValueError("All costs can't be 0") + + self.num_points = num_points + self.cost_class = cost_class + self.cost_mask = cost_mask + self.cost_dice = cost_dice + + def construct( + self, + masks_queries_logits: ms.Tensor, + class_queries_logits: ms.Tensor, + mask_labels: ms.Tensor, + class_labels: ms.Tensor, + ) -> list[tuple[Tensor]]: + """ + Params: + masks_queries_logits (`torch.Tensor`): + A tensor of dim `batch_size, num_queries, num_labels` with the classification logits. + class_queries_logits (`torch.Tensor`): + A tensor of dim `batch_size, num_queries, height, width` with the predicted masks. + class_labels (`torch.Tensor`): + A tensor of dim `num_target_boxes` (where num_target_boxes is the number of ground-truth objects in the + target) containing the class labels. + mask_labels (`torch.Tensor`): + A tensor of dim `num_target_boxes, height, width` containing the target masks. + + Returns: + matched_indices (`list[tuple[Tensor]]`): A list of size batch_size, containing tuples of (index_i, index_j) + where: + - index_i is the indices of the selected predictions (in order) + - index_j is the indices of the corresponding selected labels (in order) + For each batch element, it holds: + len(index_i) = len(index_j) = min(num_queries, num_target_boxes). + """ + indices: list[tuple[np.array]] = [] + + # iterate through batch size + batch_size = masks_queries_logits.shape[0] + for i in range(batch_size): + pred_probs = class_queries_logits[i].softmax(-1) + pred_mask = masks_queries_logits[i] + + # Compute the classification cost. Contrary to the loss, we don't use the NLL, but approximate it in 1 - proba[target class]. 
The 1 is a constant that doesn't change the matching, it can be omitted. # noqa: E501 + cost_class = -pred_probs[:, class_labels[i]] + target_mask = mask_labels[i].to(pred_mask.dtype) + target_mask = target_mask[:, None] + pred_mask = pred_mask[:, None] + + # Sample ground truth and predicted masks + point_coordinates = mint.rand(1, self.num_points, 2) + + target_coordinates = point_coordinates.tile((target_mask.shape[0], 1, 1)) + target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1) + + pred_coordinates = point_coordinates.tile((pred_mask.shape[0], 1, 1)) + pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1) + + # compute the cross entropy loss between each mask pairs -> shape (num_queries, num_labels) + cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask) + # Compute the dice loss between each mask pairs -> shape (num_queries, num_labels) + cost_dice = pair_wise_dice_loss(pred_mask, target_mask) + # final cost matrix + cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice + # eliminate infinite values in cost_matrix to avoid the error ``ValueError: cost matrix is infeasible`` + cost_matrix = mint.minimum(cost_matrix, ms.tensor(1e10)) + cost_matrix = mint.maximum(cost_matrix, ms.tensor(-1e10)) + cost_matrix = mint.nan_to_num(cost_matrix, 0) + # do the assignment using the hungarian algorithm in scipy + assigned_indices: tuple[np.array] = linear_sum_assignment(cost_matrix) + indices.append(assigned_indices) + + # It could be stacked in one tensor + matched_indices = [(ms.tensor(i, dtype=ms.int64), ms.tensor(j, dtype=ms.int64)) for i, j in indices] + return matched_indices + + +def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor: + r""" + Compute the DICE loss, similar to generalized IOU for masks as follows: + + $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x \cap y }{x \cup y + 1}} $$ + + In practice, since `labels` is a binary mask, (only 0s and 1s), dice can be computed as follow + + $$ \mathcal{L}_{\text{dice}(x, y) = 1 - \frac{2 * x * y }{x + y + 1}} $$ + + Args: + inputs (`torch.Tensor`): + A tensor representing a mask. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + num_masks (`int`): + The number of masks present in the current batch, used for normalization. + + Returns: + `torch.Tensor`: The computed loss. + """ + probs = inputs.sigmoid().flatten(1) + numerator = 2 * (probs * labels).sum(-1) + denominator = probs.sum(-1) + labels.sum(-1) + loss = 1 - (numerator + 1) / (denominator + 1) + loss = loss.sum() / num_masks + return loss + + +def sigmoid_cross_entropy_loss(inputs: ms.Tensor, labels: ms.Tensor, num_masks: int) -> ms.Tensor: + r""" + Args: + inputs (`torch.Tensor`): + A float tensor of arbitrary shape. + labels (`torch.Tensor`): + A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs + (0 for the negative class and 1 for the positive class). + + Returns: + loss (`torch.Tensor`): The computed loss. 
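+
+    Note:
+        The additional `num_masks` argument is the number of ground-truth masks in the batch: the per-point
+        losses are averaged over the point dimension, summed over masks and divided by `num_masks`.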
+ """ + criterion = mint.nn.BCEWithLogitsLoss(reduction="none") + cross_entropy_loss = criterion(inputs, labels) + + loss = cross_entropy_loss.mean(1).sum() / num_masks + return loss + + +# Adapted from https://github.com/facebookresearch/Eomt/blob/main/eomt/modeling/criterion.py +class EomtLoss(nn.Cell): + def __init__(self, config: EomtConfig, weight_dict: dict[str, float]): + """ + The Eomt Loss. The loss is computed very similar to DETR. The process happens in two steps: 1) we + compute hungarian assignment between ground truth masks and the outputs of the model 2) we supervise each pair + of matched ground-truth / prediction (supervise class and mask) + + Args: + config (`EomtConfig`): + The configuration for Eomt model also containing loss calculation specific parameters. + weight_dict (`dict[str, float]`): + A dictionary of weights to be applied to the different losses. + """ + super().__init__() + requires_backends(self, ["scipy"]) + self.num_labels = config.num_labels + self.weight_dict = weight_dict + + # Weight to apply to the null class + self.eos_coef = config.no_object_weight + empty_weight = mint.ones(self.num_labels + 1) + empty_weight[-1] = self.eos_coef + self.empty_weight = ms.Parameter(empty_weight, requires_grad=False, name="empty_weight") + + # pointwise mask loss parameters + self.num_points = config.train_num_points + self.oversample_ratio = config.oversample_ratio + self.importance_sample_ratio = config.importance_sample_ratio + + self.matcher = EomtHungarianMatcher( + cost_class=config.class_weight, + cost_dice=config.dice_weight, + cost_mask=config.mask_weight, + num_points=self.num_points, + ) + + def _max_by_axis(self, sizes: list[list[int]]) -> list[int]: + maxes = sizes[0] + for sublist in sizes[1:]: + for index, item in enumerate(sublist): + maxes[index] = max(maxes[index], item) + return maxes + + # Adapted from nested_tensor_from_tensor_list() in original implementation + def _pad_images_to_max_in_batch(self, tensors: list[Tensor]) -> tuple[Tensor, Tensor]: + # get the maximum size in the batch + max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors]) + # compute final size + batch_shape = [len(tensors)] + max_size + batch_size, _, height, width = batch_shape + dtype = tensors[0].dtype + padded_tensors = mint.zeros(batch_shape, dtype=dtype) + padding_masks = mint.ones((batch_size, height, width), dtype=ms.bool_) + # pad the tensors to the size of the biggest one + for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks): + padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor) + padding_mask[: tensor.shape[1], : tensor.shape[2]] = False + + return padded_tensors, padding_masks + + def loss_labels( + self, class_queries_logits: Tensor, class_labels: list[Tensor], indices: tuple[np.array] + ) -> dict[str, Tensor]: + """Compute the losses related to the labels using cross entropy. + + Args: + class_queries_logits (`torch.Tensor`): + A tensor of shape `batch_size, num_queries, num_labels` + class_labels (`list[torch.Tensor]`): + List of class labels of shape `(labels)`. + indices (`tuple[np.array])`: + The indices computed by the Hungarian matcher. + + Returns: + `dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key: + - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. 
+ """ + pred_logits = class_queries_logits + batch_size, num_queries, _ = pred_logits.shape + criterion = mint.nn.CrossEntropyLoss(weight=self.empty_weight) + idx = self._get_predictions_permutation_indices(indices) # shape of (batch_size, num_queries) + target_classes_o = mint.cat( + [target[j] for target, (_, j) in zip(class_labels, indices)] + ) # shape of (batch_size, num_queries) + target_classes = mint.full((batch_size, num_queries), fill_value=self.num_labels, dtype=ms.int64) + target_classes[idx] = target_classes_o + # Permute target_classes (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries) + pred_logits_transposed = pred_logits.swapaxes(1, 2) + loss_ce = criterion(pred_logits_transposed, target_classes) + losses = {"loss_cross_entropy": loss_ce} + return losses + + def loss_masks( + self, + masks_queries_logits: ms.Tensor, + mask_labels: list[ms.Tensor], + indices: tuple[np.array], + num_masks: int, + ) -> dict[str, ms.Tensor]: + """Compute the losses related to the masks using sigmoid_cross_entropy_loss and dice loss. + + Args: + masks_queries_logits (`torch.Tensor`): + A tensor of shape `(batch_size, num_queries, height, width)`. + mask_labels (`torch.Tensor`): + List of mask labels of shape `(labels, height, width)`. + indices (`tuple[np.array])`: + The indices computed by the Hungarian matcher. + num_masks (`int)`: + The number of masks, used for normalization. + + Returns: + losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing two keys: + - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth. + masks. + - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth, + masks. + """ + src_idx = self._get_predictions_permutation_indices(indices) + tgt_idx = self._get_targets_permutation_indices(indices) + # shape (batch_size * num_queries, height, width) + pred_masks = masks_queries_logits[src_idx] + # shape (batch_size, num_queries, height, width) + # pad all and stack the targets to the num_labels dimension + target_masks, _ = self._pad_images_to_max_in_batch(mask_labels) + target_masks = target_masks[tgt_idx] + + # No need to upsample predictions as we are using normalized coordinates + pred_masks = pred_masks[:, None] + target_masks = target_masks[:, None] + + # Sample point coordinates + point_coordinates = self.sample_points_using_uncertainty( + pred_masks, + lambda logits: self.calculate_uncertainty(logits), + self.num_points, + self.oversample_ratio, + self.importance_sample_ratio, + ) + + point_labels = sample_point(target_masks, point_coordinates, align_corners=False).squeeze(1) + + point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1) + + losses = { + "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks), + "loss_dice": dice_loss(point_logits, point_labels, num_masks), + } + + del pred_masks + del target_masks + return losses + + def _get_predictions_permutation_indices(self, indices): + # Permute predictions following indices + batch_indices = mint.cat([mint.full_like(src, i) for i, (src, _) in enumerate(indices)]) + predictions_indices = mint.cat([src for (src, _) in indices]) + return batch_indices, predictions_indices + + def _get_targets_permutation_indices(self, indices): + # Permute labels following indices + batch_indices = mint.cat([mint.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)]) + target_indices = mint.cat([tgt for (_, tgt) in indices]) + return 
batch_indices, target_indices + + def calculate_uncertainty(self, logits: ms.Tensor) -> ms.Tensor: + """ + In Eomt paper, uncertainty is estimated as L1 distance between 0.0 and the logit prediction in 'logits' + for the foreground class in `classes`. + + Args: + logits (`torch.Tensor`): + A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is: + the number of foreground classes. The values are logits. + + Returns: + scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most + uncertain locations having the highest uncertainty score. + """ + uncertainty_scores = -(mint.abs(logits)) + return uncertainty_scores + + def sample_points_using_uncertainty( + self, + logits: ms.Tensor, + uncertainty_function, + num_points: int, + oversample_ratio: int, + importance_sample_ratio: float, + ) -> ms.Tensor: + """ + This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The + uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit + prediction as input. + + Args: + logits (`float`): + Logit predictions for P points. + uncertainty_function: + A function that takes logit predictions for P points and returns their uncertainties. + num_points (`int`): + The number of points P to sample. + oversample_ratio (`int`): + Oversampling parameter. + importance_sample_ratio (`float`): + Ratio of points that are sampled via importance sampling. + + Returns: + point_coordinates (`torch.Tensor`): + Coordinates for P sampled points. + """ + + num_boxes = logits.shape[0] + num_points_sampled = int(num_points * oversample_ratio) + + # Get random point coordinates + point_coordinates = mint.rand(num_boxes, num_points_sampled, 2) + # Get sampled prediction value for the point coordinates + point_logits = sample_point(logits, point_coordinates, align_corners=False) + # Calculate the uncertainties based on the sampled prediction values of the points + point_uncertainties = uncertainty_function(point_logits) + + num_uncertain_points = int(importance_sample_ratio * num_points) + num_random_points = num_points - num_uncertain_points + + idx = mint.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1] + shift = num_points_sampled * mint.arange(num_boxes, dtype=ms.int64) + idx += shift[:, None] + point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2) + + if num_random_points > 0: + point_coordinates = mint.cat([point_coordinates, mint.rand(num_boxes, num_random_points, 2)], dim=1) + return point_coordinates + + def construct( + self, + masks_queries_logits: ms.Tensor, + class_queries_logits: ms.Tensor, + mask_labels: list[ms.Tensor], + class_labels: list[ms.Tensor], + auxiliary_predictions: Optional[dict[str, ms.Tensor]] = None, + ) -> dict[str, ms.Tensor]: + """ + This performs the loss computation. + + Args: + masks_queries_logits (`torch.Tensor`): + A tensor of shape `(batch_size, num_queries, height, width)`. + class_queries_logits (`torch.Tensor`): + A tensor of shape `(batch_size, num_queries, num_labels)`. + mask_labels (`torch.Tensor`): + List of mask labels of shape `(labels, height, width)`. + class_labels (`list[torch.Tensor]`): + List of class labels of shape `(labels)`. 
+ auxiliary_predictions (`dict[str, torch.Tensor]`, *optional*): + if `use_auxiliary_loss` was set to `true` in [`EomtConfig`], then it contains the logits from + the inner layers of the EomtMaskedAttentionDecoder. + + Returns: + losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing three keys: + - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels. + - **loss_mask** -- The loss computed using sigmoid cross_entropy loss on the predicted and ground truth + masks. + - **loss_dice** -- The loss computed using dice loss on the predicted on the predicted and ground truth + masks. + if `use_auxiliary_loss` was set to `true` in [`EomtConfig`], the dictionary contains additional + losses for each auxiliary predictions. + """ + + # retrieve the matching between the outputs of the last layer and the labels + indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels) + # compute the average number of target masks for normalization purposes + num_masks = self.get_num_masks(class_labels) + # get all the losses + losses: dict[str, Tensor] = { + **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks), + **self.loss_labels(class_queries_logits, class_labels, indices), + } + # in case of auxiliary losses, we repeat this process with the output of each intermediate layer. + if auxiliary_predictions is not None: + for idx, aux_outputs in enumerate(auxiliary_predictions): + masks_queries_logits = aux_outputs["masks_queries_logits"] + class_queries_logits = aux_outputs["class_queries_logits"] + loss_dict = self.construct(masks_queries_logits, class_queries_logits, mask_labels, class_labels) + loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()} + losses.update(loss_dict) + + return losses + + def get_num_masks(self, class_labels: ms.Tensor) -> ms.Tensor: + """ + Computes the average number of target masks across the batch, for normalization purposes. + """ + num_masks = sum([len(classes) for classes in class_labels]) + num_masks = ms.tensor(num_masks, dtype=ms.float32) + world_size = 1 + + num_masks = mint.clamp(num_masks / world_size, min=1) + return num_masks + + +class EomtPatchEmbeddings(nn.Cell): + """ + This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial + `hidden_states` (patch embeddings) of shape `(batch_size, seq_length, hidden_size)` to be consumed by a + Transformer. + """ + + def __init__(self, config): + super().__init__() + image_size, patch_size = config.image_size, config.patch_size + num_channels, hidden_size = config.num_channels, config.hidden_size + + image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size) + patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size) + num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0]) + self.image_size = image_size + self.patch_size = patch_size + self.num_channels = num_channels + self.num_patches = num_patches + + self.projection = mint.nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size) + + def construct(self, pixel_values: ms.Tensor) -> ms.Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + f" Expected {self.num_channels} but got {num_channels}." 
+ ) + embeddings = self.projection(pixel_values).flatten(2).swapaxes(1, 2) + return embeddings + + +class EomtEmbeddings(nn.Cell): + """ + Construct the CLS token, mask token, position and patch embeddings. + """ + + def __init__(self, config: EomtConfig) -> None: + super().__init__() + + self.config = config + self.patch_size = config.patch_size + + self.cls_token = ms.Parameter(mint.randn(1, 1, config.hidden_size), name="cls_token") + self.register_tokens = ms.Parameter( + mint.zeros((1, config.num_register_tokens, config.hidden_size)), name="register_tokens" + ) + + self.patch_embeddings = EomtPatchEmbeddings(config) + num_patches = self.patch_embeddings.num_patches + self.dropout = mint.nn.Dropout(config.hidden_dropout_prob) + self.num_prefix_tokens = 1 + config.num_register_tokens # 1 for [CLS] + self.position_embeddings = mint.nn.Embedding(num_patches, config.hidden_size) + self.position_ids = mint.arange(num_patches).broadcast_to((1, -1)) + + def construct(self, pixel_values: ms.Tensor) -> ms.Tensor: + batch_size, _, _, _ = pixel_values.shape + target_dtype = self.patch_embeddings.projection.weight.dtype + embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype)) + + cls_tokens = self.cls_token.broadcast_to((batch_size, -1, -1)) + register_tokens = self.register_tokens.broadcast_to((batch_size, -1, -1)) + + embeddings = embeddings + self.position_embeddings(self.position_ids) + embeddings = mint.cat([cls_tokens, register_tokens, embeddings], dim=1) + + embeddings = self.dropout(embeddings) + + return embeddings + + +def eager_attention_forward( + module: nn.Cell, + query: ms.Tensor, + key: ms.Tensor, + value: ms.Tensor, + attention_mask: Optional[ms.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs, +): + attn_weights = mint.matmul(query, key.swapaxes(-1, -2)) * scaling + if attention_mask is not None: + attn_weights = attn_weights + attention_mask + + attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query.dtype) + attn_weights = mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + + attn_output = mint.matmul(attn_weights, value) + attn_output = attn_output.swapaxes(1, 2).contiguous() + + return attn_output, attn_weights + + +class EomtAttention(nn.Cell): + """Multi-headed attention from 'Attention Is All You Need' paper""" + + def __init__(self, config): + super().__init__() + self.config = config + self.embed_dim = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.embed_dim // self.num_heads + if self.head_dim * self.num_heads != self.embed_dim: + raise ValueError( + f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim} and `num_heads`:" + f" {self.num_heads})." 
+ ) + self.scale = self.head_dim**-0.5 + self.dropout = config.attention_dropout + self.is_causal = False + + self.k_proj = mint.nn.Linear(self.embed_dim, self.embed_dim) + self.v_proj = mint.nn.Linear(self.embed_dim, self.embed_dim) + self.q_proj = mint.nn.Linear(self.embed_dim, self.embed_dim) + self.out_proj = mint.nn.Linear(self.embed_dim, self.embed_dim) + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + **kwargs, + ) -> tuple[ms.Tensor, Optional[ms.Tensor]]: + """Input shape: Batch x Time x Channel""" + + batch_size, seq_length, embed_dim = hidden_states.shape + + queries = self.q_proj(hidden_states) + keys = self.k_proj(hidden_states) + values = self.v_proj(hidden_states) + + queries = queries.view(batch_size, seq_length, self.num_heads, self.head_dim).swapaxes(1, 2) + keys = keys.view(batch_size, seq_length, self.num_heads, self.head_dim).swapaxes(1, 2) + values = values.view(batch_size, seq_length, self.num_heads, self.head_dim).swapaxes(1, 2) + + attention_interface: Callable = eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + queries, + keys, + values, + attention_mask, + is_causal=self.is_causal, + scaling=self.scale, + dropout=0.0 if not self.training else self.dropout, + ) + + attn_output = attn_output.reshape(batch_size, seq_length, embed_dim).contiguous() + attn_output = self.out_proj(attn_output) + + return attn_output, attn_weights + + +class EomtLayerScale(nn.Cell): + def __init__(self, config) -> None: + super().__init__() + self.lambda1 = ms.Parameter(config.layerscale_value * mint.ones(config.hidden_size), name="lambda1") + + def construct(self, hidden_state: ms.Tensor) -> ms.Tensor: + return hidden_state * self.lambda1 + + +def drop_path(input: ms.Tensor, drop_prob: float = 0.0, training: bool = False) -> ms.Tensor: + """ + Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet, etc networks, + however, the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing the + layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use 'survival rate' as the + argument. 
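+
+    Example (a minimal sketch; shapes are illustrative):
+
+        >>> from mindspore import mint
+        >>> x = mint.ones((4, 16, 768))
+        >>> drop_path(x, drop_prob=0.1, training=False) is x     # no-op outside training
+        True
+        >>> out = drop_path(x, drop_prob=0.1, training=True)     # whole samples zeroed, survivors scaled by 1 / 0.9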
+ """ + if drop_prob == 0.0 or not training: + return input + keep_prob = 1 - drop_prob + shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + mint.rand(shape, dtype=input.dtype) + random_tensor.floor_() # binarize + output = input.div(keep_prob) * random_tensor + return output + + +class EomtDropPath(nn.Cell): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob: Optional[float] = None) -> None: + super().__init__() + self.drop_prob = drop_prob + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + return drop_path(hidden_states, self.drop_prob, self.training) + + def extra_repr(self) -> str: + return f"p={self.drop_prob}" + + +class EomtMLP(nn.Cell): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + self.fc1 = mint.nn.Linear(in_features, hidden_features, bias=True) + if isinstance(config.hidden_act, str): + self.activation = ACT2FN[config.hidden_act] + else: + self.activation = config.hidden_act + self.fc2 = mint.nn.Linear(hidden_features, out_features, bias=True) + + def construct(self, hidden_state: ms.Tensor) -> ms.Tensor: + hidden_state = self.fc1(hidden_state) + hidden_state = self.activation(hidden_state) + hidden_state = self.fc2(hidden_state) + return hidden_state + + +class EomtSwiGLUFFN(nn.Cell): + def __init__(self, config) -> None: + super().__init__() + in_features = out_features = config.hidden_size + hidden_features = int(config.hidden_size * config.mlp_ratio) + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + + self.weights_in = mint.nn.Linear(in_features, 2 * hidden_features, bias=True) + self.weights_out = mint.nn.Linear(hidden_features, out_features, bias=True) + + def construct(self, hidden_state: ms.Tensor) -> ms.Tensor: + hidden_state = self.weights_in(hidden_state) + x1, x2 = hidden_state.chunk(2, dim=-1) + hidden = mint.nn.functional.silu(x1) * x2 + return self.weights_out(hidden) + + +class EomtLayer(GradientCheckpointingLayer): + """This corresponds to the Block class in the original implementation.""" + + def __init__(self, config: EomtConfig) -> None: + super().__init__() + + self.norm1 = mint.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + self.attention = EomtAttention(config) + self.layer_scale1 = EomtLayerScale(config) + self.drop_path = EomtDropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else mint.nn.Identity() + + self.norm2 = mint.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + if config.use_swiglu_ffn: + self.mlp = EomtSwiGLUFFN(config) + else: + self.mlp = EomtMLP(config) + self.layer_scale2 = EomtLayerScale(config) + + def construct( + self, + hidden_states: ms.Tensor, + head_mask: Optional[ms.Tensor] = None, + output_attentions: bool = False, + ) -> Union[tuple[ms.Tensor, ms.Tensor], tuple[ms.Tensor]]: + self_attention_outputs = self.attention( + self.norm1(hidden_states), # in Eomt, layernorm is applied before self-attention + head_mask, + output_attentions=output_attentions, + ) + attention_output = self_attention_outputs[0] + + attention_output = self.layer_scale1(attention_output) + outputs = self_attention_outputs[1:] # add self attentions if we output attention weights + + # first residual connection + hidden_states = self.drop_path(attention_output) + hidden_states + + # in Eomt, layernorm is also 
applied after self-attention + layer_output = self.norm2(hidden_states) + layer_output = self.mlp(layer_output) + layer_output = self.layer_scale2(layer_output) + + # second residual connection + layer_output = self.drop_path(layer_output) + hidden_states + + outputs = (layer_output,) + outputs + + return outputs + + +class EomtLayerNorm2d(mint.nn.LayerNorm): + def __init__(self, num_channels, eps=1e-6, affine=True): + super().__init__(num_channels, eps=eps, elementwise_affine=affine) + + def construct(self, hidden_state: ms.Tensor) -> ms.Tensor: + hidden_state = hidden_state.permute(0, 2, 3, 1) + hidden_state = F.layer_norm(hidden_state, self.normalized_shape, self.weight, self.bias, self.eps) + hidden_state = hidden_state.permute(0, 3, 1, 2) + return hidden_state + + +class EomtScaleLayer(nn.Cell): + def __init__(self, config: EomtConfig): + super().__init__() + hidden_size = config.hidden_size + self.conv1 = mint.nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2) + self.activation = ACT2FN[config.hidden_act] + self.conv2 = mint.nn.Conv2d( + hidden_size, + hidden_size, + kernel_size=3, + padding=1, + groups=hidden_size, + bias=False, + ) + + self.layernorm2d = EomtLayerNorm2d(hidden_size) + + def construct(self, hidden_states: ms.tensor) -> ms.Tensor: + hidden_states = self.conv1(hidden_states) + hidden_states = self.activation(hidden_states) + hidden_states = self.conv2(hidden_states) + hidden_states = self.layernorm2d(hidden_states) + return hidden_states + + +class EomtScaleBlock(nn.Cell): + def __init__(self, config: EomtConfig): + super().__init__() + self.num_blocks = config.num_upscale_blocks + self.block = nn.CellList([EomtScaleLayer(config) for _ in range(self.num_blocks)]) + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + for block in self.block: + hidden_states = block(hidden_states) + return hidden_states + + +class EomtMaskHead(nn.Cell): + def __init__(self, config: EomtConfig): + super().__init__() + + hidden_size = config.hidden_size + self.fc1 = mint.nn.Linear(hidden_size, hidden_size) + self.fc2 = mint.nn.Linear(hidden_size, hidden_size) + self.fc3 = mint.nn.Linear(hidden_size, hidden_size) + self.activation = ACT2FN[config.hidden_act] + + def construct(self, hidden_states: ms.Tensor) -> ms.Tensor: + hidden_states = self.activation(self.fc1(hidden_states)) + hidden_states = self.activation(self.fc2(hidden_states)) + hidden_states = self.fc3(hidden_states) + return hidden_states + + +class EomtPreTrainedModel(PreTrainedModel): + """ + An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained + models. 
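+
+    Weight initialization is intentionally a no-op in this port (`_init_weights` does nothing), so parameters are
+    expected to come from a pretrained checkpoint via `from_pretrained`.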
+ """ + + config: EomtConfig + base_model_prefix = "eomt" + main_input_name = "pixel_values" + supports_gradient_checkpointing = False + _no_split_modules = ["EomtLayer"] + _supports_sdpa = True + _supports_flash_attn = True + + def _init_weights(self, module: nn.Cell) -> None: + pass + + +class EomtForUniversalSegmentation(EomtPreTrainedModel): + main_input_name = "pixel_values" + + def __init__(self, config: EomtConfig) -> None: + super().__init__(config) + self.config = config + self.num_hidden_layers = config.num_hidden_layers + self.embeddings = EomtEmbeddings(config) + self.layernorm = mint.nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps) + + self.query = mint.nn.Embedding(config.num_queries, config.hidden_size) + self.layers = nn.CellList([EomtLayer(config) for _ in range(config.num_hidden_layers)]) + + self.upscale_block = EomtScaleBlock(config) + self.mask_head = EomtMaskHead(config) + + self.class_predictor = mint.nn.Linear(config.hidden_size, config.num_labels + 1) + + self.grid_size = (config.image_size // config.patch_size, config.image_size // config.patch_size) + self.weight_dict: dict[str, float] = { + "loss_cross_entropy": config.class_weight, + "loss_mask": config.mask_weight, + "loss_dice": config.dice_weight, + } + + self.criterion = EomtLoss(config=config, weight_dict=self.weight_dict) + + self.attn_mask_probs = ms.Parameter(mint.ones(config.num_blocks), requires_grad=False, name="attn_mask_probs") + + self.post_init() + + def get_loss_dict( + self, + masks_queries_logits: Tensor, + class_queries_logits: Tensor, + mask_labels: Tensor, + class_labels: Tensor, + auxiliary_predictions: dict[str, Tensor], + ) -> dict[str, Tensor]: + loss_dict: dict[str, Tensor] = self.criterion( + masks_queries_logits=masks_queries_logits, + class_queries_logits=class_queries_logits, + mask_labels=mask_labels, + class_labels=class_labels, + auxiliary_predictions=auxiliary_predictions, + ) + + # weight each loss by `self.weight_dict[]` including auxiliary losses + for key, weight in self.weight_dict.items(): + for loss_key, loss in loss_dict.items(): + if key in loss_key: + loss *= weight + + return loss_dict + + def get_loss(self, loss_dict: dict[str, Tensor]) -> Tensor: + return sum(loss_dict.values()) + + @can_return_tuple + def construct( + self, + pixel_values: Tensor, + mask_labels: Optional[list[Tensor]] = None, + class_labels: Optional[list[Tensor]] = None, + output_hidden_states: Optional[bool] = None, + output_attentions: Optional[bool] = None, + patch_offsets: Optional[list[Tensor]] = None, + ) -> EomtForUniversalSegmentationOutput: + r""" + mask_labels (`list[torch.Tensor]`, *optional*): + list of mask labels of shape `(num_labels, height, width)` to be fed to a model + class_labels (`list[torch.LongTensor]`, *optional*): + list of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify the + labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` if `class_labels[i][j]`. + patch_offsets (`list[torch.Tensor]`, *optional*): + list of tuples indicating the image index and start and end positions of patches for semantic segementation. 
+ """ + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + + all_hidden_states = () if output_hidden_states else None + all_attentions = () if output_attentions else None + + masks_queries_logits_per_layer, class_queries_logits_per_layer = (), () + attention_mask = None + + if pixel_values is None: + raise ValueError("You have to specify pixel_values") + + hidden_states = self.embeddings(pixel_values) + + for idx, layer_module in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + if idx == self.num_hidden_layers - self.config.num_blocks: + query = self.query.weight[None, :, :].broadcast_to((hidden_states.shape[0], -1, -1)) + hidden_states = mint.cat((query, hidden_states), dim=1) + + if idx >= self.num_hidden_layers - self.config.num_blocks and ( + self.training or self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks] > 0 + ): + norm_hidden_states = self.layernorm(hidden_states) + masks_queries_logits, class_queries_logits = self.predict(norm_hidden_states) + + masks_queries_logits_per_layer += (masks_queries_logits,) + class_queries_logits_per_layer += (class_queries_logits,) + + attention_mask = mint.ones( + (hidden_states.shape[0], hidden_states.shape[1], hidden_states.shape[1]), dtype=ms.bool_ + ) + + interpolated_logits = F.interpolate(masks_queries_logits, size=self.grid_size, mode="bilinear") + interpolated_logits = interpolated_logits.view( + interpolated_logits.shape[0], interpolated_logits.shape[1], -1 + ) + + num_query_tokens = self.config.num_queries + encoder_start_tokens = num_query_tokens + self.embeddings.num_prefix_tokens + + # Set attention mask for queries to focus on encoder tokens based on interpolated logits + attention_mask[:, :num_query_tokens, encoder_start_tokens:] = interpolated_logits > 0 + + # Disable attention mask for random query tokens. + attention_mask = self._disable_attention_mask( + attention_mask, + prob=self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks], + num_query_tokens=num_query_tokens, + encoder_start_tokens=encoder_start_tokens, + ) + + # Expand attention mask to 4d mask. 
+ attention_mask = attention_mask[:, None, ...].broadcast_to( + (-1, self.config.num_attention_heads, -1, -1) + ) + attention_mask = attention_mask.float().masked_fill(~attention_mask, -1e9) + + layer_outputs = layer_module(hidden_states, attention_mask, output_attentions) + hidden_states = layer_outputs[0] + + if output_attentions: + all_attentions += (layer_outputs[1],) + + sequence_output = self.layernorm(hidden_states) + if output_hidden_states: + all_hidden_states += (sequence_output,) + + masks_queries_logits, class_queries_logits = self.predict(sequence_output) + masks_queries_logits_per_layer += (masks_queries_logits,) + class_queries_logits_per_layer += (class_queries_logits,) + + loss = None + if mask_labels is not None and class_labels is not None: + loss = 0.0 + for masks_queries_logits, class_queries_logits in zip( + masks_queries_logits_per_layer, class_queries_logits_per_layer + ): + loss_dict = self.get_loss_dict( + masks_queries_logits=masks_queries_logits, + class_queries_logits=class_queries_logits, + mask_labels=mask_labels, + class_labels=class_labels, + auxiliary_predictions=None, + ) + loss += self.get_loss(loss_dict) + + return EomtForUniversalSegmentationOutput( + loss=loss, + masks_queries_logits=masks_queries_logits, + class_queries_logits=class_queries_logits, + last_hidden_state=sequence_output, + hidden_states=all_hidden_states, + attentions=all_attentions, + patch_offsets=patch_offsets, + ) + + def get_input_embeddings(self): + return self.embeddings.patch_embeddings + + def predict(self, logits: ms.Tensor): + query_tokens = logits[:, : self.config.num_queries, :] + class_logits = self.class_predictor(query_tokens) + + prefix_tokens = logits[:, self.config.num_queries + self.embeddings.num_prefix_tokens :, :] + prefix_tokens = prefix_tokens.swapaxes(1, 2) + + prefix_tokens = prefix_tokens.reshape(prefix_tokens.shape[0], -1, *self.grid_size) + + query_tokens = self.mask_head(query_tokens) + prefix_tokens = self.upscale_block(prefix_tokens) + + mask_logits = mint.einsum("bqc, bchw -> bqhw", query_tokens, prefix_tokens) + + return mask_logits, class_logits + + @staticmethod + def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens): + if prob < 1: + # Generate random queries to disable based on the probs + random_queries = mint.rand(attn_mask.shape[0], num_query_tokens) > prob + + # Disable attention to the query tokens, considering the prefix tokens + attn_mask[:, :num_query_tokens, encoder_start_tokens:][random_queries] = 1 + + return attn_mask + + +__all__ = ["EomtPreTrainedModel", "EomtForUniversalSegmentation"] diff --git a/mindone/transformers/models/timesfm/__init__.py b/mindone/transformers/models/timesfm/__init__.py new file mode 100644 index 0000000000..0bac1b9cbb --- /dev/null +++ b/mindone/transformers/models/timesfm/__init__.py @@ -0,0 +1,17 @@ +# Copyright 2025 The HuggingFace Team. All rights reserved. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. +from .modeling_timesfm import * diff --git a/mindone/transformers/models/timesfm/modeling_timesfm.py b/mindone/transformers/models/timesfm/modeling_timesfm.py new file mode 100644 index 0000000000..8ba19d187d --- /dev/null +++ b/mindone/transformers/models/timesfm/modeling_timesfm.py @@ -0,0 +1,819 @@ +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# This file was automatically generated from src/transformers/models/timesfm/modular_timesfm.py. +# Do NOT edit this file manually as any edits will be overwritten by the generation of +# the file from the modular. If any change should be done, please apply the change to the +# modular_timesfm.py file directly. One of our CI enforces this. +# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨 +# coding=utf-8 +# Copyright 2025 Google LLC and HuggingFace Inc. team. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from collections.abc import Sequence +from dataclasses import dataclass +from typing import Callable, Optional, Union + +from transformers.models.timesfm.configuration_timesfm import TimesFmConfig + +import mindspore as ms +import mindspore.mint.nn.functional as F +from mindspore import mint, nn, ops + +from ...mindspore_adapter import dtype_to_min +from ...modeling_flash_attention_utils import FlashAttentionKwargs +from ...modeling_outputs import BaseModelOutput +from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel +from ...processing_utils import Unpack +from ...utils import TransformersKwargs, can_return_tuple, logging + + +logger = logging.get_logger(__name__) + + +@dataclass +class TimesFmOutput(BaseModelOutput): + r""" + loc (`ms.Tensor` of shape `(batch_size, )`): + The mean of the time series inputs. + scale (`ms.Tensor` of shape `(batch_size,)`): + The scale of the time series inputs. + """ + + loc: Optional[ms.Tensor] = None + scale: Optional[ms.Tensor] = None + + +@dataclass +class TimesFmOutputForPrediction(BaseModelOutput): + r""" + mean_predictions (`ms.Tensor` of shape `(batch_size, sequence_length)`): + The mean predictions of the time series. + full_predictions (`ms.Tensor` of shape `(batch_size, sequence_length)`): + The full predictions of the time series including the mean and the quantiles. + loss (`ms.Tensor` of shape `(1,)`, *optional*, returned when `future_values` is provided): + The loss of the TimesFM model. 
+ """ + + mean_predictions: Optional[ms.Tensor] = None + full_predictions: Optional[ms.Tensor] = None + loss: Optional[Union[ms.Tensor, float]] = None + + +class TimesFmMLP(nn.Cell): + """Pax MLP in mindspore.""" + + def __init__(self, config: TimesFmConfig): + super().__init__() + hidden_size = config.hidden_size + intermediate_size = config.intermediate_size + + self.gate_proj = mint.nn.Linear(hidden_size, intermediate_size) + self.down_proj = mint.nn.Linear(intermediate_size, hidden_size) + self.layer_norm = mint.nn.LayerNorm(normalized_shape=hidden_size, eps=1e-6) + + def construct(self, x, paddings=None): + gate_inp = self.layer_norm(x) + gate = self.gate_proj(gate_inp) + gate = F.relu(gate) + outputs = self.down_proj(gate) + if paddings is not None: + outputs = outputs * (1.0 - paddings[:, :, None]) + return outputs + x + + +class TimesFmResidualBlock(nn.Cell): + """TimesFM residual block.""" + + def __init__(self, input_dims, hidden_dims, output_dims): + super().__init__() + self.input_dims = input_dims + self.hidden_dims = hidden_dims + self.output_dims = output_dims + + self.input_layer = mint.nn.Linear(input_dims, hidden_dims) + self.activation = mint.nn.SiLU() + self.output_layer = mint.nn.Linear(hidden_dims, output_dims) + self.residual_layer = mint.nn.Linear(input_dims, output_dims) + + def construct(self, x): + hidden = self.input_layer(x) + hidden = self.activation(hidden) + output = self.output_layer(hidden) + residual = self.residual_layer(x) + return output + residual + + +class TimesFmRMSNorm(nn.Cell): + def __init__(self, hidden_size, eps=1e-6): + """ + TimesFmRMSNorm is equivalent to T5LayerNorm + """ + super().__init__() + self.weight = ms.Parameter(mint.ones(hidden_size), name="weight") + self.variance_epsilon = eps + + def construct(self, hidden_states): + input_dtype = hidden_states.dtype + hidden_states = hidden_states.to(ms.float32) + variance = hidden_states.pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * mint.rsqrt(variance + self.variance_epsilon) + return self.weight * hidden_states.to(input_dtype) + + def extra_repr(self): + return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}" + + +class TimesFmPositionalEmbedding(nn.Cell): + """Generates position embedding for a given 1-d sequence.""" + + def __init__(self, config: TimesFmConfig): + super().__init__() + min_timescale = config.min_timescale + max_timescale = config.max_timescale + self.embedding_dims = config.hidden_size + + num_timescales = self.embedding_dims // 2 + log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / max(num_timescales - 1, 1) + self.inv_timescales = ms.Parameter( + min_timescale * mint.exp(mint.arange(num_timescales, dtype=ms.float32) * -log_timescale_increment), + requires_grad=False, + name="inv_timescales", + ) + + def construct(self, seq_length=None, position=None): + """Generates a Tensor of sinusoids with different frequencies. + + Args: + seq_length: an optional Python int defining the output sequence length. + if the `position` argument is specified. + position: [B, seq_length], optional position for each token in the + sequence, only required when the sequence is packed. 
+ + Returns: + [B, seqlen, D] if `position` is specified, else [1, seqlen, D] + """ + if position is None and seq_length is None: + raise ValueError("Either position or seq_length must be provided") + + if position is None: + # [1, seqlen] + position = mint.arange(seq_length, dtype=ms.float32).unsqueeze(0) + elif position.ndim != 2: + raise ValueError(f"position must be 2-dimensional, got shape {position.shape}") + + scaled_time = position.view(*position.shape, 1) * self.inv_timescales.view(1, 1, -1) + signal = mint.cat([mint.sin(scaled_time), mint.cos(scaled_time)], dim=2) + + # Padding to ensure correct embedding dimension + signal = F.pad(signal, (0, 0, 0, self.embedding_dims % 2)) + return signal + + +def simple_eager_attention_forward( + module: nn.Cell, + query_states: ms.Tensor, + key_states: ms.Tensor, + value_states: ms.Tensor, + attention_mask: Optional[ms.Tensor], + scaling: float, + dropout: float = 0.0, + **kwargs: Unpack[TransformersKwargs], +): + attn_weights = mint.matmul(query_states, key_states.swapaxes(2, 3)) * scaling + if attention_mask is not None: + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + attn_weights = mint.nn.functional.softmax(attn_weights, dim=-1, dtype=ms.float32).to(query_states.dtype) + attn_weights = mint.nn.functional.dropout(attn_weights, p=dropout, training=module.training) + attn_output = mint.matmul(attn_weights, value_states) + attn_output = attn_output.swapaxes(1, 2).contiguous() + + return attn_output, attn_weights + + +class TimesFmAttention(nn.Cell): + """Implements the attention used in TimesFM. One key difference is that there is _per_dim_scaling of the query.""" + + def __init__(self, config: TimesFmConfig, layer_idx: int): + super().__init__() + self.config = config + self.is_causal = True + self.attention_dropout = config.attention_dropout + self.layer_idx = layer_idx + + self.num_heads = config.num_attention_heads + self.hidden_size = config.hidden_size + self.head_dim = config.head_dim + + self.q_size = self.num_heads * self.head_dim + self.kv_size = self.num_heads * self.head_dim + self.scaling = ms.Parameter(mint.empty((self.head_dim,)), name="scaling") + + self.q_proj = mint.nn.Linear(self.hidden_size, self.num_heads * self.head_dim) + self.k_proj = mint.nn.Linear(self.hidden_size, self.num_heads * self.head_dim) + self.v_proj = mint.nn.Linear(self.hidden_size, self.num_heads * self.head_dim) + self.o_proj = mint.nn.Linear(self.num_heads * self.head_dim, self.hidden_size) + + def _scale_query(self, query: ms.Tensor) -> ms.Tensor: + scale = F.softplus(self.scaling).mul(1.442695041 / math.sqrt(self.head_dim)) + return query * scale[None, None, None, :] + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: Optional[ms.Tensor] = None, + **kwargs: Unpack[FlashAttentionKwargs], + ) -> tuple[ms.Tensor, Optional[ms.Tensor]]: + input_shape = hidden_states.shape[:-1] + hidden_shape = (*input_shape, -1, self.head_dim) + + query_states = self.q_proj(hidden_states).view(hidden_shape).swapaxes(1, 2) + query_states = self._scale_query(query_states) + key_states = self.k_proj(hidden_states).view(hidden_shape).swapaxes(1, 2) + value_states = self.v_proj(hidden_states).view(hidden_shape).swapaxes(1, 2) + + attention_interface: Callable = simple_eager_attention_forward + if self.config._attn_implementation != "eager": + attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation] + + attn_output, attn_weights = attention_interface( + self, + 
query_states, + key_states, + value_states, + attention_mask, + dropout=0.0 if not self.training else self.attention_dropout, + scaling=1.0, + **kwargs, + ) + attn_output = attn_output.reshape(*input_shape, -1).contiguous() + attn_output = self.o_proj(attn_output) + return attn_output, attn_weights + + +class TimesFmDecoderLayer(nn.Cell): + """Transformer layer.""" + + def __init__(self, config: TimesFmConfig, layer_idx: int): + super().__init__() + + self.self_attn = TimesFmAttention(config, layer_idx=layer_idx) + self.mlp = TimesFmMLP(config) + self.input_layernorm = TimesFmRMSNorm(config.hidden_size, eps=config.rms_norm_eps) + + def construct( + self, + hidden_states: ms.Tensor, + attention_mask: ms.Tensor, + paddings: ms.Tensor, + output_attentions: bool = False, + ) -> tuple[Optional[ms.Tensor], ms.Tensor]: + # Self Attention + residual = hidden_states + hidden_states = self.input_layernorm(hidden_states) + hidden_states, scores = self.self_attn( + hidden_states=hidden_states, + attention_mask=attention_mask, + output_attentions=output_attentions, + ) + hidden_states = residual + hidden_states + + # MLP + hidden_states = self.mlp(hidden_states, paddings=paddings) + + return scores, hidden_states + + +class TimesFmPreTrainedModel(PreTrainedModel): + config: TimesFmConfig + base_model_prefix = "timesfm" + _no_split_modules = ["TimesFmDecoderLayer"] + main_input_name = "past_values" + _supports_sdpa = True + + def _init_weights(self, module): + pass + + +class TimesFmModel(TimesFmPreTrainedModel): + def __init__(self, config: TimesFmConfig): + super().__init__(config) + + self.config = config + self.input_ff_layer = TimesFmResidualBlock( + input_dims=2 * config.patch_length, + output_dims=config.hidden_size, + hidden_dims=config.intermediate_size, + ) + self.freq_emb = mint.nn.Embedding(num_embeddings=config.freq_size, embedding_dim=config.hidden_size) + self.layers = nn.CellList( + [TimesFmDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)] + ) + if self.config.use_positional_embedding: + self.position_emb = TimesFmPositionalEmbedding(config=config) + + # Initialize weights and apply final processing + self.post_init() + + def _forward_transform( + self, inputs: ms.Tensor, patched_pads: ms.Tensor + ) -> tuple[ms.Tensor, tuple[ms.Tensor, ms.Tensor]]: + """Input is of shape [B, N, P].""" + mu, sigma = self._timesfm_masked_mean_std(inputs, patched_pads) + sigma = mint.where( + sigma < self.config.tolerance, + ms.tensor(1.0, dtype=sigma.dtype), + sigma, + ) + + # Normalize each patch + outputs = (inputs - mu[:, None, None]) / sigma[:, None, None] + outputs = mint.where( + mint.abs(inputs - self.config.pad_val) < self.config.tolerance, + ms.tensor(self.config.pad_val, dtype=outputs.dtype), + outputs, + ) + return outputs, (mu, sigma) + + @can_return_tuple + def construct( + self, + past_values: ms.Tensor, + past_values_padding: ms.Tensor, + freq: ms.Tensor, + output_attentions: bool = False, + output_hidden_states: bool = False, + ) -> TimesFmOutput: + r""" + past_values (`ms.Tensor` of shape `(batch_size, sequence_length)`): + Past values of the time series that serves as input to the model. + past_values_padding (`ms.Tensor` of shape `(batch_size, sequence_length)`): + The padding indicator of the time series. + freq (`ms.Tensor` of shape `(batch_size,)`): + Frequency indices for the time series data. 
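+
+        Example (a minimal sketch assuming a small, randomly initialized configuration; the
+        zero-valued inputs below are only meant to illustrate the expected shapes):
+
+        ```python
+        >>> import mindspore as ms
+        >>> from mindspore import mint
+        >>> from transformers import TimesFmConfig
+        >>> from mindone.transformers import TimesFmModel
+
+        >>> config = TimesFmConfig(
+        ...     num_hidden_layers=1, hidden_size=16, intermediate_size=32, head_dim=8, num_attention_heads=2
+        ... )
+        >>> model = TimesFmModel(config)
+
+        >>> past_values = mint.zeros((2, config.context_length), dtype=ms.float32)
+        >>> past_values_padding = mint.zeros((2, config.context_length), dtype=ms.float32)
+        >>> freq = mint.zeros((2, 1), dtype=ms.int64)
+
+        >>> outputs = model(past_values=past_values, past_values_padding=past_values_padding, freq=freq)
+        >>> outputs.last_hidden_state.shape  # (2, context_length // patch_length, hidden_size)
+        ```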
+ """ + # Reshape into patches (using view for efficiency) + bsize = past_values.shape[0] + patched_inputs = past_values.view(bsize, -1, self.config.patch_length) + patched_pads = past_values_padding.view(bsize, -1, self.config.patch_length) + + patched_inputs = mint.where( + mint.abs(patched_pads - 1.0) < self.config.tolerance, + ms.tensor(0.0, dtype=patched_inputs.dtype), + patched_inputs, + ) + patched_pads = mint.where( + mint.abs(patched_inputs - self.config.pad_val) < self.config.tolerance, + ms.tensor(1.0, dtype=patched_pads.dtype), + patched_pads, + ) + patched_inputs, stats = self._forward_transform(patched_inputs, patched_pads) + + # B x N x D + patched_inputs = patched_inputs * (1.0 - patched_pads) + concat_inputs = mint.cat([patched_inputs, patched_pads], dim=-1) + model_input = self.input_ff_layer(concat_inputs) + + # A patch should not be padded even if there is at least one zero. + patched_padding = mint.min(patched_pads, dim=-1)[0] # Get the values from the min result + if self.config.use_positional_embedding: + pos_emb = self.position_emb(model_input.shape[1]) + pos_emb = mint.concat([pos_emb] * model_input.shape[0], dim=0) + pos_emb = self._timesfm_shift_padded_seq(patched_padding, pos_emb) + model_input += pos_emb + + f_emb = self.freq_emb(freq) # B x 1 x D + model_input += f_emb + + # Convert paddings to attention mask and combine with causal mask + hidden_states = model_input + attention_mask = self._prepare_4d_attention_mask( + attention_mask=patched_padding, + sequence_length=hidden_states.shape[1], + dtype=hidden_states.dtype, + is_causal=True, + ) + + all_attentions = [] + all_hidden_states = [] + + for layer in self.layers[: self.config.num_hidden_layers]: + scores, hidden_states = layer( + hidden_states=hidden_states, + attention_mask=attention_mask, + paddings=patched_padding, + output_attentions=output_attentions, + ) + if output_attentions: + all_attentions.append(scores) + if output_hidden_states: + all_hidden_states.append(hidden_states) + + if output_hidden_states: + all_hidden_states = [model_input] + all_hidden_states + else: + all_hidden_states = None + + return TimesFmOutput( + last_hidden_state=hidden_states, + hidden_states=all_hidden_states, + attentions=all_attentions if output_attentions else None, + loc=stats[0], + scale=stats[1], + ) + + @staticmethod + def _prepare_4d_attention_mask( + attention_mask: Optional[ms.Tensor], + sequence_length: int, + dtype: ms.Type, + is_causal: bool = True, + ) -> Optional[ms.Tensor]: + """ + Creates 4D attention mask and combines causal and padding masks if needed. 
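+        Padding positions receive an additive value of `dtype_to_min(dtype)`; when `is_causal` is True,
+        a causal mask built the same way is merged with the padding mask via an elementwise `minimum`,
+        so a position is masked whenever either mask excludes it.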
+
+        Args:
+            attention_mask: Optional tensor of shape (batch_size, seq_length) containing padding mask
+            sequence_length: Length of the sequence
+            dtype: Data type of the mask
+            is_causal: Whether to apply causal masking
+
+        Returns:
+            4D attention mask of shape (batch_size, 1, seq_length, seq_length)
+        """
+        # Get minimum value for the dtype
+        min_value = dtype_to_min(dtype)
+
+        # Handle padding mask
+        if attention_mask is not None:
+            # Convert 2D padding mask to 4D attention mask
+            attention_mask = attention_mask.view(attention_mask.shape[0], 1, 1, -1)
+            attention_mask = attention_mask * min_value
+
+        # Create causal mask if needed
+        if is_causal:
+            causal_mask = mint.triu(mint.ones((sequence_length, sequence_length), dtype=dtype) * min_value, diagonal=1)
+            causal_mask = causal_mask.view(1, 1, sequence_length, sequence_length)
+
+            # Combine with padding mask if it exists
+            if attention_mask is not None:
+                attention_mask = mint.minimum(attention_mask, causal_mask)
+            else:
+                attention_mask = causal_mask
+
+        return attention_mask
+
+    @staticmethod
+    def _timesfm_masked_mean_std(inputs: ms.Tensor, padding: ms.Tensor) -> tuple[ms.Tensor, ms.Tensor]:
+        """Calculates mean and standard deviation of `inputs` across axis 1.
+
+        It excludes values where `padding` is 1.
+
+        Args:
+            inputs: A MindSpore tensor of shape [b, n, p].
+            padding: A MindSpore tensor of shape [b, n, p] with values 0 or 1.
+
+        Returns:
+            A tuple containing the mean and standard deviation.
+            We return the statistics of the first patch with at least three non-padded values.
+        """
+
+        # Selecting the first patch with at least 3 unpadded values.
+        def _get_patch_index(arr: ms.Tensor):
+            indices = mint.argmax((arr >= 3).to(ms.int32), dim=1)
+            row_sum = (arr >= 3).to(ms.int32).sum(dim=1)
+            return mint.where(row_sum == 0, arr.shape[1] - 1, indices)
+
+        pad_sum = mint.sum(1 - padding, dim=2)
+        patch_indices = _get_patch_index(pad_sum)
+        bidxs = mint.arange(inputs.shape[0])
+
+        arr = inputs[bidxs, patch_indices, :]
+        pad = padding[bidxs, patch_indices, :]
+
+        # Create a mask where padding is 0
+        mask = 1 - pad
+
+        # Calculate the number of valid elements
+        num_valid_elements = mint.sum(mask, dim=1)
+        num_valid_elements = mint.where(
+            num_valid_elements == 0, ms.tensor(1, dtype=num_valid_elements.dtype), num_valid_elements
+        )
+
+        # Calculate the masked sum and squared sum
+        masked_sum = mint.sum(arr * mask, dim=1)
+        masked_squared_sum = mint.sum((arr * mask) ** 2, dim=1)
+
+        # Calculate the masked mean and standard deviation
+        masked_mean = masked_sum / num_valid_elements
+        masked_var = masked_squared_sum / num_valid_elements - masked_mean**2
+        masked_var = mint.where(masked_var < 0.0, ms.tensor(0.0, dtype=masked_var.dtype), masked_var)
+        masked_std = mint.sqrt(masked_var)
+
+        return masked_mean, masked_std
+
+    @staticmethod
+    def _timesfm_shift_padded_seq(mask: ms.Tensor, seq: ms.Tensor) -> ms.Tensor:
+        """Shifts rows of seq based on the first 0 in each row of the mask.
+
+        Args:
+            mask: mask tensor of shape [B, N]
+            seq: seq tensor of shape [B, N, P]
+
+        Returns:
+            The shifted sequence.
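+
+        For example, if `mask = [[1, 1, 0, 0]]`, the first zero is at index 2, so row 0 of `seq`
+        is rotated such that `shifted_seq[0, i] = seq[0, (i - 2) % num_seq]`, i.e. the original
+        first row entry ends up at position 2.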
+        """
+        batch_size, num_seq, feature_dim = seq.shape
+
+        new_mask: ms.Tensor = mask == 0
+
+        # Use argmax to find the first True value in each row
+        indices = new_mask.to(ms.int32).argmax(dim=1)
+
+        # Handle rows with all zeros
+        indices[~new_mask.any(dim=1)] = -1
+
+        # Create index ranges for each sequence in the batch
+        idx_range = mint.arange(num_seq).view(1, -1, 1).broadcast_to((batch_size, -1, feature_dim))
+
+        # Calculate shifted indices for each element in each sequence
+        shifted_idx = (idx_range - indices[:, None, None]) % num_seq
+
+        # Gather values from seq using shifted indices
+        shifted_seq = seq.gather(1, shifted_idx)
+
+        return shifted_seq
+
+
+class TimesFmModelForPrediction(TimesFmPreTrainedModel):
+    """TimesFM model for quantile and mean prediction."""
+
+    def __init__(self, config: TimesFmConfig):
+        super().__init__(config)
+
+        self.config = config
+        self.context_len = config.context_length
+        self.horizon_len = config.horizon_length
+
+        self.decoder = TimesFmModel(config)
+
+        # quantile and mean output
+        self.horizon_ff_layer = TimesFmResidualBlock(
+            input_dims=config.hidden_size,
+            output_dims=config.horizon_length * (1 + len(config.quantiles)),
+            hidden_dims=config.intermediate_size,
+        )
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _preprocess(self, inputs: Sequence[ms.Tensor], freq: Sequence[int]) -> tuple[ms.Tensor, ms.Tensor, ms.Tensor]:
+        """Formats and pads raw inputs to feed into the model.
+
+        This function left-pads or truncates each time series to the model context length and
+        builds the corresponding padding indicator, which also covers the forecast horizon.
+
+        Args:
+            inputs: A list of 1d Tensors. Each Tensor is the context time series of
+                a single forecast task.
+            freq: list of frequencies
+
+        Returns:
+            A tuple of:
+            - the padded input time series, matching the required model context length.
+            - the padding indicator.
+            - the frequency indices for each example, as an int32 tensor of shape `(batch_size, 1)`.
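+
+        For example, with `context_length=512`, `horizon_length=128` and a context series of
+        length 80, the series is left-padded with 432 zeros to length 512, and its padding
+        indicator is 432 ones followed by 208 zeros (the 80 context steps plus the 128-step
+        horizon).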
+ """ + input_ts, input_padding, inp_freq = [], [], [] + + for i, ts in enumerate(inputs): + input_len = ts.shape[0] + padding = mint.zeros(input_len + self.horizon_len, dtype=ts.dtype) + if input_len < self.context_len: + num_front_pad = self.context_len - input_len + ts = mint.cat([mint.zeros(num_front_pad, dtype=ts.dtype), ts], dim=0) + padding = mint.cat([mint.ones(num_front_pad, dtype=ts.dtype), padding], dim=0) + elif input_len > self.context_len: + ts = ts[-self.context_len :] + padding = padding[-(self.context_len + self.horizon_len) :] + + input_ts.append(ts) + input_padding.append(padding) + inp_freq.append(freq[i]) + + return ( + mint.stack(input_ts, dim=0), + mint.stack(input_padding, dim=0), + ms.tensor(inp_freq, dtype=ms.int32).reshape(-1, 1), + ) + + def _postprocess_output(self, model_output: ms.Tensor, stats: tuple[ms.Tensor, ms.Tensor]) -> ms.Tensor: + """Postprocess output of stacked transformer.""" + + # B x N x (H.Q) + output_ts = self.horizon_ff_layer(model_output) + + # Reshape using view + b, n, _ = output_ts.shape + output_ts = output_ts.view(b, n, self.config.horizon_length, len(self.config.quantiles) + 1) + + mu, sigma = stats + return output_ts * sigma[:, None, None, None] + mu[:, None, None, None] + + def _quantile_loss(self, predictions: ms.Tensor, targets: ms.Tensor) -> ms.Tensor: + losses = [] + for i, q in enumerate(self.config.quantiles): + errors = targets - predictions[..., i] + loss = mint.max((q - 1) * errors, q * errors) + losses.append(loss.mean()) + return mint.stack(losses).mean() + + @can_return_tuple + def construct( + self, + past_values: Sequence[ms.Tensor], + freq: Optional[Sequence[Union[ms.Tensor, int]]] = None, + window_size: Optional[int] = None, + future_values: Optional[ms.Tensor] = None, + forecast_context_len: Optional[int] = None, + return_forecast_on_context: bool = False, + truncate_negative: bool = False, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + ) -> TimesFmOutputForPrediction: + r""" + past_values (`ms.Tensor` of shape `(batch_size, sequence_length)`): + Past values of the time series that serves as input to the model. + freq (`ms.Tensor` of shape `(batch_size,)`): + Frequency indices for the time series data. + window_size (`int`, *optional*): + Window size of trend + residual decomposition. If None then we do not do decomposition. + future_values (`ms.Tensor`, *optional*): + Optional future time series values to be used for loss computation. + forecast_context_len (`int`, *optional*): + Optional max context length. + return_forecast_on_context (`bool`, *optional*): + True to return the forecast on the context when available, i.e. after the first input patch. + truncate_negative (`bool`, *optional*): + Truncate to only non-negative values if any of the contexts have non-negative values, + otherwise do nothing. + output_attentions (`bool`, *optional*): + Whether to output the attentions. + output_hidden_states (`bool`, *optional*): + Whether to output the hidden states. 
+ + Example: + + ```python + >>> from mindone.transformers import TimesFmModelForPrediction + >>> import mindspore as ms + >>> from mindspore import mint + + >>> model = TimesFmModelForPrediction.from_pretrained("google/timesfm-2.0-500m-pytorch") + + >>> forecast_input = [mint.linspace(0, 20, 100).sin(), mint.linspace(0, 20, 200).sin(), mint.linspace(0, 20, 400).sin()] + >>> frequency_input = ms.tensor([0, 1, 2], dtype=ms.int64) + + >>> # Generate + >>> outputs = model(past_values=forecast_input, freq=frequency_input, return_dict=True) + >>> point_forecast_conv = outputs.mean_predictions + >>> quantile_forecast_conv = outputs.full_predictions + ``` + """ + if forecast_context_len is None: + fcontext_len = self.context_len + else: + fcontext_len = forecast_context_len + + # Truncate inputs to forecast_context_len + inputs = [ts[-fcontext_len:] for ts in past_values] + inp_min = mint.min(mint.stack([mint.min(ts) for ts in inputs])) + + if window_size is not None: + new_inputs = [] + new_freqs = [] + for i, ts in enumerate(inputs): + new_inputs.extend(self._timesfm_moving_average(ts, window_size)) + if freq is not None: + new_freqs.extend([freq[i]] * 2) + inputs = new_inputs + if freq is not None: + freq = new_freqs + + if freq is None: + logger.info("No frequency provided via `freq`. Default to high (0).") + freq = [0] * len(inputs) + + if output_attentions is None: + output_attentions = self.config.output_attentions + if output_hidden_states is None: + output_hidden_states = self.config.output_hidden_states + + input_ts, input_padding, inp_freq = self._preprocess(inputs, freq) + + final_out = input_ts + context_len = final_out.shape[1] + full_outputs = [] + + if input_padding.shape[1] != final_out.shape[1] + self.horizon_len: + raise ValueError( + "Length of paddings must match length of input + horizon_len:" + f" {input_padding.shape[1]} != {final_out.shape[1]} + {self.horizon_len}" + ) + output_patch_len = self.config.horizon_length + + num_decode_patches = (self.horizon_len + output_patch_len - 1) // output_patch_len + for step_index in range(num_decode_patches): + current_padding = input_padding[:, 0 : final_out.shape[1]] + input_ts = final_out[:, -fcontext_len:] + input_padding = current_padding[:, -fcontext_len:] + decoder_output = self.decoder( + past_values=input_ts, + past_values_padding=input_padding, + freq=inp_freq, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + ) + fprop_outputs = self._postprocess_output( + decoder_output.last_hidden_state, + (decoder_output.loc, decoder_output.scale), + ) + + if return_forecast_on_context and step_index == 0: + # For the first decodings step, collect the model forecast on the + # context except the unavailable first input batch forecast. + new_full_ts = fprop_outputs[:, :-1, : self.config.patch_length, :] + # We have to use reshape and not view for non-contiguous memory + new_full_ts = new_full_ts.reshape(new_full_ts.shape[0], -1, new_full_ts.shape[3]) + + full_outputs.append(new_full_ts) + + # (full batch, last patch, output_patch_len, index of mean forecast = 0) + new_ts = fprop_outputs[:, -1, :output_patch_len, 0] + new_full_ts = fprop_outputs[:, -1, :output_patch_len, :] + # (full batch, last patch, output_patch_len, all output indices) + full_outputs.append(new_full_ts) + final_out = mint.cat([final_out, new_ts], dim=-1) + + if return_forecast_on_context: + # `full_outputs` indexing starts at after the first input patch. 
+ full_outputs = mint.cat(full_outputs, dim=1)[ + :, : (context_len - self.config.patch_length + self.horizon_len), : + ] + else: + # `full_outputs` indexing starts at the forecast horizon. + full_outputs = mint.cat(full_outputs, dim=1)[:, 0 : self.horizon_len, :] + + mean_outputs = full_outputs[:, :, 0] + if window_size is not None: + mean_outputs = mean_outputs[0::2, ...] + mean_outputs[1::2, ...] + full_outputs = full_outputs[0::2, ...] + full_outputs[1::2, ...] + if inp_min >= 0 and truncate_negative: + mean_outputs = mint.maximum(mean_outputs, 0.0) + full_outputs = mint.maximum(full_outputs, 0.0) + + loss = None + if future_values is not None: + mse_loss = F.mse_loss(mean_outputs, future_values) + quantile_loss = self._quantile_loss(full_outputs[:, :, 1:], future_values) + loss = mse_loss + quantile_loss + + return TimesFmOutputForPrediction( + last_hidden_state=decoder_output.last_hidden_state, + attentions=decoder_output.attentions if output_attentions else None, + hidden_states=decoder_output.hidden_states if output_hidden_states else None, + mean_predictions=mean_outputs, + full_predictions=full_outputs, + loss=loss, + ) + + @staticmethod + def _timesfm_moving_average(arr: ms.Tensor, window_size: int) -> list[ms.Tensor]: + """Calculates the moving average using MindSpore's convolution function.""" + # Pad with zeros to handle initial window positions + arr_padded = F.pad(arr, (window_size - 1, 0), "constant", 0) + # Create a convolution kernel + kernel = mint.ones(window_size, dtype=arr.dtype) / window_size + # Apply convolution to calculate the moving average + smoothed_arr = ops.conv1d(arr_padded.view(1, 1, -1), kernel.view(1, 1, -1)).squeeze() + return [smoothed_arr, arr - smoothed_arr] + + +__all__ = ["TimesFmModelForPrediction", "TimesFmPreTrainedModel", "TimesFmModel"] diff --git a/tests/transformers_tests/models/eomt/__init__.py b/tests/transformers_tests/models/eomt/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/eomt/test_modeling_eomt.py b/tests/transformers_tests/models/eomt/test_modeling_eomt.py new file mode 100644 index 0000000000..625541eb36 --- /dev/null +++ b/tests/transformers_tests/models/eomt/test_modeling_eomt.py @@ -0,0 +1,187 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""Testing suite for the MindSpore EoMT model.""" + +import inspect + +import numpy as np +import pytest +import torch +from transformers import EomtConfig + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) +from tests.transformers_tests.models.modeling_common import floats_numpy + +DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2} +MODES = [1] + + +class EomtForUniversalSegmentationTester: + def __init__( + self, + batch_size=2, + is_training=True, + image_size=40, + patch_size=2, + num_queries=5, + num_register_tokens=19, + num_labels=4, + hidden_size=8, + num_attention_heads=2, + num_hidden_layers=4, + ): + self.batch_size = batch_size + self.is_training = is_training + self.num_queries = num_queries + self.image_size = image_size + self.patch_size = patch_size + self.num_labels = num_labels + self.hidden_size = hidden_size + self.num_attention_heads = num_attention_heads + self.num_hidden_layers = num_hidden_layers + self.num_register_tokens = num_register_tokens + + num_patches = (image_size // patch_size) ** 2 + self.seq_length = num_patches + 1 + + def get_config(self): + config = { + "image_size": self.image_size, + "patch_size": self.patch_size, + "num_labels": self.num_labels, + "hidden_size": self.hidden_size, + "num_attention_heads": self.num_attention_heads, + "num_hidden_layers": self.num_hidden_layers, + "num_register_tokens": self.num_register_tokens, + "num_queries": self.num_queries, + "num_blocks": 1, + } + return EomtConfig(**config) + + def prepare_config_and_inputs(self): + pixel_values = floats_numpy([self.batch_size, 3, self.image_size, self.image_size]) + + mask_labels = (np.random.rand(self.batch_size, self.num_labels, self.image_size, self.image_size) > 0.5).astype( + np.float32 + ) + class_labels = (np.random.rand(self.batch_size, self.num_labels) > 0.5).astype(np.int64) + + config = self.get_config() + return config, pixel_values, mask_labels, class_labels + + +model_tester = EomtForUniversalSegmentationTester() +config, pixel_values, mask_labels, class_labels = model_tester.prepare_config_and_inputs() +EOMT_CASES = [ + [ + "EomtForUniversalSegmentation", + "transformers.EomtForUniversalSegmentation", + "mindone.transformers.EomtForUniversalSegmentation", + (config,), + {}, + (pixel_values,), + { + "mask_labels": mask_labels, + "class_labels": class_labels, + }, + { + "last_hidden_state": 3, + }, + ], +] + + +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [ + case + + [ + dtype, + ] + + [ + mode, + ] + for case in EOMT_CASES + for dtype in DTYPE_AND_THRESHOLDS.keys() + for mode in MODES + ], +) +def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, +): + ms.set_context(mode=mode) + + ( + pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + 
ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs) + # print("ms:", ms_outputs) + # print("pt:", pt_outputs) + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in outputs_map.items(): + # print("===map", pt_key, ms_idx) + pt_output = getattr(pt_outputs, pt_key) + ms_output = ms_outputs[ms_idx] + if isinstance(pt_output, (list, tuple)): + pt_outputs_n += list(pt_output) + ms_outputs_n += list(ms_output) + else: + pt_outputs_n.append(pt_output) + ms_outputs_n.append(ms_output) + diffs = compute_diffs(pt_outputs_n, ms_outputs_n) + else: + diffs = compute_diffs(pt_outputs, ms_outputs) + + THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype] + assert (np.array(diffs) < THRESHOLD).all(), ( + f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " + f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" + ) diff --git a/tests/transformers_tests/models/timesfm/__init__.py b/tests/transformers_tests/models/timesfm/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/tests/transformers_tests/models/timesfm/test_modeling_timesfm.py b/tests/transformers_tests/models/timesfm/test_modeling_timesfm.py new file mode 100644 index 0000000000..3c890817b5 --- /dev/null +++ b/tests/transformers_tests/models/timesfm/test_modeling_timesfm.py @@ -0,0 +1,207 @@ +# coding=utf-8 +# Copyright 2025 Google LLC and HuggingFace Inc. team. +# +# This code is adapted from https://github.com/huggingface/transformers +# with modifications to run transformers on mindspore. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import inspect + +import numpy as np +import pytest +import torch +from transformers import TimesFmConfig + +import mindspore as ms + +from tests.modeling_test_utils import ( + MS_DTYPE_MAPPING, + PT_DTYPE_MAPPING, + compute_diffs, + generalized_parse_args, + get_modules, +) + +DTYPE_AND_THRESHOLDS = {"fp32": 5e-4, "fp16": 5e-3, "bf16": 5e-2} +MODES = [1] + + +class TimesFmModelTester: + def __init__( + self, + patch_length: int = 32, + context_length: int = 512, + horizon_length: int = 128, + freq_size: int = 3, + num_hidden_layers: int = 1, + hidden_size: int = 16, + intermediate_size: int = 32, + head_dim: int = 8, + num_heads: int = 2, + tolerance: float = 1e-6, + rms_norm_eps: float = 1e-6, + quantiles: list[float] = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], + pad_val: float = 1123581321.0, + use_positional_embedding: bool = True, + initializer_factor: float = 0.0, + is_training: bool = False, + batch_size: int = 3, + ): + self.patch_length = patch_length + self.context_length = context_length + self.horizon_length = horizon_length + self.quantiles = quantiles + self.pad_val = pad_val + self.freq_size = freq_size + self.hidden_size = hidden_size + self.intermediate_size = intermediate_size + self.head_dim = head_dim + self.num_hidden_layers = num_hidden_layers + self.num_attention_heads = num_heads + self.tolerance = tolerance + self.rms_norm_eps = rms_norm_eps + self.use_positional_embedding = use_positional_embedding + self.initializer_factor = initializer_factor + self.is_training = is_training + self.batch_size = batch_size + + # The size of test input + self.seq_length = context_length // patch_length + self.hidden_size = hidden_size + + def get_config(self): + return TimesFmConfig( + patch_length=self.patch_length, + context_length=self.context_length, + horizon_length=self.horizon_length, + quantiles=self.quantiles, + pad_val=self.pad_val, + freq_size=self.freq_size, + hidden_size=self.hidden_size, + intermediate_size=self.intermediate_size, + head_dim=self.head_dim, + num_hidden_layers=self.num_hidden_layers, + num_attention_heads=self.num_attention_heads, + tolerance=self.tolerance, + rms_norm_eps=self.rms_norm_eps, + use_positional_embedding=self.use_positional_embedding, + initializer_factor=self.initializer_factor, + ) + + def get_pipeline_config(self): + return self.get_config() + + def prepare_config_and_inputs(self): + forecast_input = [ + np.sin(np.linspace(0, 20, 100)).astype(np.float32), + np.cos(np.linspace(0, 20, 100)).astype(np.float32), + np.tan(np.linspace(0, 20, 100)).astype(np.float32), + ] + frequency_input = np.array([0, 1, 2], dtype=np.int64) + + return (self.get_config(), np.stack(forecast_input, axis=0), frequency_input) + + +model_tester = TimesFmModelTester() +config, past_values, freq = model_tester.prepare_config_and_inputs() +TIMESFM_CASES = [ + [ + "TimesFmModelForPrediction", + "transformers.TimesFmModelForPrediction", + "mindone.transformers.TimesFmModelForPrediction", + (config,), + {}, + (past_values,), + { + "freq": freq, + }, + { + "last_hidden_state": 0, + }, + ], +] + + +@pytest.mark.parametrize( + "name,pt_module,ms_module,init_args,init_kwargs,inputs_args,inputs_kwargs,outputs_map,dtype,mode", + [ + case + + [ + dtype, + ] + + [ + mode, + ] + for case in TIMESFM_CASES + for dtype in DTYPE_AND_THRESHOLDS.keys() + for mode in MODES + ], +) +def test_named_modules( + name, + pt_module, + ms_module, + init_args, + init_kwargs, + inputs_args, + inputs_kwargs, + outputs_map, + dtype, + mode, +): + ms.set_context(mode=mode) + + ( + 
pt_model, + ms_model, + pt_dtype, + ms_dtype, + ) = get_modules(pt_module, ms_module, dtype, *init_args, **init_kwargs) + pt_inputs_args, pt_inputs_kwargs, ms_inputs_args, ms_inputs_kwargs = generalized_parse_args( + pt_dtype, ms_dtype, *inputs_args, **inputs_kwargs + ) + + # set `hidden_dtype` if requiring, for some modules always compute in float + # precision and require specific `hidden_dtype` to cast before return + if "hidden_dtype" in inspect.signature(pt_model.forward).parameters: + pt_inputs_kwargs.update({"hidden_dtype": PT_DTYPE_MAPPING[pt_dtype]}) + ms_inputs_kwargs.update({"hidden_dtype": MS_DTYPE_MAPPING[ms_dtype]}) + + with torch.no_grad(): + pt_outputs = pt_model(*pt_inputs_args, **pt_inputs_kwargs) + ms_outputs = ms_model(*ms_inputs_args, **ms_inputs_kwargs, return_dict=False) + # print("ms:", ms_outputs) + # print("pt:", pt_outputs) + if outputs_map: + pt_outputs_n = [] + ms_outputs_n = [] + for pt_key, ms_idx in outputs_map.items(): + # print("===map", pt_key, ms_idx) + pt_output = getattr(pt_outputs, pt_key) + ms_output = ms_outputs[ms_idx] + if isinstance(pt_output, (list, tuple)): + pt_outputs_n += list(pt_output) + ms_outputs_n += list(ms_output) + else: + pt_outputs_n.append(pt_output) + ms_outputs_n.append(ms_output) + diffs = compute_diffs(pt_outputs_n, ms_outputs_n) + else: + diffs = compute_diffs(pt_outputs, ms_outputs) + + THRESHOLD = DTYPE_AND_THRESHOLDS[ms_dtype] + assert (np.array(diffs) < THRESHOLD).all(), ( + f"ms_dtype: {ms_dtype}, pt_type:{pt_dtype}, " + f"Outputs({np.array(diffs).tolist()}) has diff bigger than {THRESHOLD}" + )