From 9e9446aafe91d3f8b9b33befb68156942d163f1f Mon Sep 17 00:00:00 2001 From: Masaru Kimura Date: Mon, 3 Feb 2025 19:20:02 +0900 Subject: [PATCH 1/6] Add torchvision.transforms.Resize interpolation and antialias. torchvision.transforms.Resize forced nearest interpolation and no antialias, but shouln't. Based on my understanding, original torchvision.transforms.Resize calls like; - torchvision.transforms.Resize - torchvision.transforms.functional.resize - torchvision.transforms._functional_pil.resize - PIL.Image.Image.resize - torchvision.transforms._functional_tensor.resize - torch.nn.functional.interpolate Note, this PR still keeps nearest interpolation and no antialias by default for torchvision.transforms.Resize to maximize compatibility for existing code using TorchSharp and make it being incompatible to original torchvision.transforms.Resize default, however, it would be up to the upstream decision. See also; * https://pytorch.org/vision/main/generated/torchvision.transforms.Resize.html * https://pytorch.org/vision/main/generated/torchvision.transforms.functional.resize.html * https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html --- src/Native/LibTorchSharp/THSNN.cpp | 5 ++- src/Native/LibTorchSharp/THSNN.h | 2 +- src/TorchSharp/NN/Vision.cs | 14 ++++-- src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs | 2 +- src/TorchVision/Functional.cs | 22 +++++++--- src/TorchVision/Resize.cs | 43 ++++++++++++++++--- test/TorchSharpTest/NN.cs | 12 ++++++ test/TorchSharpTest/TestTorchVision.cs | 6 +-- 8 files changed, 84 insertions(+), 22 deletions(-) diff --git a/src/Native/LibTorchSharp/THSNN.cpp b/src/Native/LibTorchSharp/THSNN.cpp index f5e9643e7..941399e62 100644 --- a/src/Native/LibTorchSharp/THSNN.cpp +++ b/src/Native/LibTorchSharp/THSNN.cpp @@ -109,6 +109,8 @@ void ApplyInterpolateMode(T& opts, const int8_t mode) opts = opts.mode(torch::kTrilinear); if (mode == 5) opts = opts.mode(torch::kArea); + if (mode == 6) + opts = opts.mode(torch::kNearestExact); } template @@ -176,13 +178,14 @@ Tensor THSNN_affine_grid(const Tensor theta, const int64_t* size, const int size } -EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, const bool recompute_scale_factor, NNAnyModule* outAsAnyModule) +EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, const bool recompute_scale_factor, const bool antialias, NNAnyModule* outAsAnyModule) { auto opts = torch::nn::functional::InterpolateFuncOptions().recompute_scale_factor(recompute_scale_factor); // align_corners -- 0=None, 1=true, 2=false if (align_corners != 0) opts.align_corners(align_corners == 1); ApplyInterpolateMode(opts, mode); + opts.antialias(antialias); if (size_len > 0) { std::vector sizes; diff --git a/src/Native/LibTorchSharp/THSNN.h b/src/Native/LibTorchSharp/THSNN.h index f7af6bd1f..3dab43f90 100644 --- a/src/Native/LibTorchSharp/THSNN.h +++ b/src/Native/LibTorchSharp/THSNN.h @@ -71,7 +71,7 @@ EXPORT_API(Tensor) THSNN_pixel_unshuffle(const Tensor tensor, const int64_t do // Vision -- Functions EXPORT_API(Tensor) THSNN_pad(const Tensor input, const int64_t* pad, const int pad_length, const int8_t mode, const double value); -EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, const bool recompute_scale_factor, NNAnyModule* outAsAnyModule); +EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, const bool recompute_scale_factor, const bool antialias, NNAnyModule* outAsAnyModule); EXPORT_API(Tensor) THSNN_grid_sample(const Tensor input, const Tensor grid, const int8_t mode, const int8_t padding_mode, const int8_t align_corners); EXPORT_API(Tensor) THSNN_affine_grid(const Tensor theta, const int64_t* size, const int size_len, const bool align_corners); diff --git a/src/TorchSharp/NN/Vision.cs b/src/TorchSharp/NN/Vision.cs index 5dd5fe6e2..db751a7ae 100644 --- a/src/TorchSharp/NN/Vision.cs +++ b/src/TorchSharp/NN/Vision.cs @@ -23,7 +23,8 @@ public enum InterpolationMode Bilinear = 2, Bicubic = 3, Trilinear = 4, - Area = 5 + Area = 5, + NearestExact = 6 } public enum GridSampleMode @@ -194,7 +195,7 @@ public static Tensor affine_grid(Tensor theta, long[]? size = null, bool align_c /// The input tensor /// Output spatial size /// Multiplier for spatial size. Has to match input size if it is a tuple. - /// The algorithm used for upsampling: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area' + /// The algorithm used for upsampling: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area' | 'nearest-exact' /// Geometrically, we consider the pixels of the input and output as squares rather than points. /// If set to true, the input and output tensors are aligned by the center points of their corner pixels, preserving the values at the corner pixels. /// If set to false, the input and output tensors are aligned by the corner points of their corner pixels, and the interpolation uses edge value padding for out-of-boundary values, making this operation independent of input size when scale_factor is kept the same. @@ -205,14 +206,19 @@ public static Tensor affine_grid(Tensor theta, long[]? size = null, bool align_c /// Otherwise, a new scale_factor will be computed based on the output and input sizes for use in the interpolation computation /// (i.e. the computation will be identical to if the computed output_size were passed-in explicitly). /// + /// + /// Flag to apply anti-aliasing. Using anti-alias + /// option together with align_corners = false, interpolation result would match Pillow + /// result for downsampling operation. Supported modes: 'bilinear', 'bicubic'. + /// /// - public static Tensor interpolate(Tensor x, long[]? size = null, double[]? scale_factor = null, InterpolationMode mode = InterpolationMode.Nearest, bool? align_corners = null, bool recompute_scale_factor = false) + public static Tensor interpolate(Tensor x, long[]? size = null, double[]? scale_factor = null, InterpolationMode mode = InterpolationMode.Nearest, bool? align_corners = null, bool recompute_scale_factor = false, bool antialias = false) { unsafe { fixed (long* psize = size) { fixed (double* pSF = scale_factor) { byte ac = (byte)((align_corners.HasValue) ? (align_corners.Value ? 1 : 2) : 0); - var res = THSNN_interpolate(x.Handle, (IntPtr)psize, size is null ? 0 : size.Length, (IntPtr)pSF, scale_factor is null ? 0 : scale_factor.Length, (byte)mode, ac, recompute_scale_factor); + var res = THSNN_interpolate(x.Handle, (IntPtr)psize, size is null ? 0 : size.Length, (IntPtr)pSF, scale_factor is null ? 0 : scale_factor.Length, (byte)mode, ac, recompute_scale_factor, antialias); if (res == IntPtr.Zero) { torch.CheckForErrors(); } return new Tensor(res); } diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs index fd24a26c4..2b97b61eb 100644 --- a/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs +++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs @@ -44,7 +44,7 @@ internal static extern IntPtr THSNN_custom_module( [DllImport("LibTorchSharp")] // align_corners -- 0=None, 1=true, 2=false - internal static extern IntPtr THSNN_interpolate(IntPtr input, IntPtr size, int size_len, IntPtr scale_factor, int scale_factor_len, byte mode, byte align_corners, [MarshalAs(UnmanagedType.U1)] bool recompute_scale_factor); + internal static extern IntPtr THSNN_interpolate(IntPtr input, IntPtr size, int size_len, IntPtr scale_factor, int scale_factor_len, byte mode, byte align_corners, [MarshalAs(UnmanagedType.U1)] bool recompute_scale_factor, [MarshalAs(UnmanagedType.U1)] bool antialias); [DllImport("LibTorchSharp")] // align_corners -- 0=None, 1=true, 2=false diff --git a/src/TorchVision/Functional.cs b/src/TorchVision/Functional.cs index 0f8b00259..6495bd60f 100644 --- a/src/TorchVision/Functional.cs +++ b/src/TorchVision/Functional.cs @@ -694,13 +694,22 @@ public static Tensor posterize(Tensor input, int bits) /// An image tensor. /// The height of the resized image. Must be > 0. /// The width of the resized image. Must be > 0. + /// + /// Desired interpolation enum defined by TorchSharp.torch.InterpolationMode. + /// Default is InterpolationMode.Nearest; not InterpolationMode.Bilinear (incompatible to Python's torchvision v0.17 or later for historical reasons). + /// Only InterpolationMode.Nearest, InterpolationMode.NearestExact, InterpolationMode.Bilinear and InterpolationMode.Bicubic are supported. + /// /// The maximum allowed for the longer edge of the resized image. + /// + /// Whether to apply antialiasing. + /// It only affects bilinear or bicubic modes and it is ignored otherwise. + /// Possible values are: + /// * true: will apply antialiasing for bilinear or bicubic modes. Other mode aren't affected. This is probably what you want to use. + /// * false (default, incompatible to Python's torchvision v0.17 or later for historical reasons): will not apply antialiasing on any mode. + /// /// - public static Tensor resize(Tensor input, int height, int width, int? maxSize = null) + public static Tensor resize(Tensor input, int height, int width, InterpolationMode interpolation = InterpolationMode.Nearest, int ? maxSize = null, bool antialias = false) { - // For now, we don't allow any other modes. - const InterpolationMode interpolation = InterpolationMode.Nearest; - var hoffset = input.Dimensions - 2; var iHeight = input.shape[hoffset]; var iWidth = input.shape[hoffset + 1]; @@ -727,9 +736,12 @@ public static Tensor resize(Tensor input, int height, int width, int? maxSize = } } + if (antialias && interpolation != InterpolationMode.Bilinear && interpolation != InterpolationMode.Bicubic) + antialias = false; + using var img0 = SqueezeIn(input, new ScalarType[] { ScalarType.Float32, ScalarType.Float64 }, out var needCast, out var needSqueeze, out var dtype); - using var img1 = torch.nn.functional.interpolate(img0, new long[] { h, w }, mode: interpolation, align_corners: null); + using var img1 = torch.nn.functional.interpolate(img0, new long[] { h, w }, mode: interpolation, align_corners: null, antialias: antialias); return SqueezeOut(img1, needCast, needSqueeze, dtype); } diff --git a/src/TorchVision/Resize.cs b/src/TorchVision/Resize.cs index 6d5b77751..c0798d75c 100644 --- a/src/TorchVision/Resize.cs +++ b/src/TorchVision/Resize.cs @@ -8,20 +8,24 @@ public static partial class torchvision { internal class Resize : ITransform { - internal Resize(int height, int width, int? maxSize) + internal Resize(int height, int width, InterpolationMode interpolation, int? maxSize, bool antialias) { this.height = height; this.width = width; + this.interpolation = interpolation; this.maxSize = maxSize; + this.antialias = antialias; } public Tensor call(Tensor input) { - return transforms.functional.resize(input, height, width, maxSize); + return transforms.functional.resize(input, height, width, interpolation, maxSize, antialias); } private int height, width; + private InterpolationMode interpolation; private int? maxSize; + private bool antialias; } public static partial class transforms @@ -31,20 +35,45 @@ public static partial class transforms /// /// Desired output height /// Desired output width + /// + /// Desired interpolation enum defined by TorchSharp.torch.InterpolationMode. + /// Default is InterpolationMode.Nearest; not InterpolationMode.Bilinear (incompatible to Python's torchvision v0.17 or later for historical reasons). + /// Only InterpolationMode.Nearest, InterpolationMode.NearestExact, InterpolationMode.Bilinear and InterpolationMode.Bicubic are supported. + /// + /// The maximum allowed for the longer edge of the resized image. + /// + /// Whether to apply antialiasing. + /// It only affects bilinear or bicubic modes and it is ignored otherwise. + /// Possible values are: + /// * true: will apply antialiasing for bilinear or bicubic modes. Other mode aren't affected. This is probably what you want to use. + /// * false (default, incompatible to Python's torchvision v0.17 or later for historical reasons): will not apply antialiasing on any mode. + /// /// - static public ITransform Resize(int height, int width) + static public ITransform Resize(int height, int width, InterpolationMode interpolation = InterpolationMode.Nearest, int? maxSize = null, bool antialias = false) { - return new Resize(height, width, null); + return new Resize(height, width, interpolation, maxSize, antialias); } /// /// Resize the input image to the given size. /// /// Desired output size - /// Max size - static public ITransform Resize(int size, int? maxSize = null) + /// + /// Desired interpolation enum defined by TorchSharp.torch.InterpolationMode. + /// Default is InterpolationMode.Nearest; not InterpolationMode.Bilinear (incompatible to Python's torchvision v0.17 or later for historical reasons). + /// Only InterpolationMode.Nearest, InterpolationMode.NearestExact, InterpolationMode.Bilinear and InterpolationMode.Bicubic are supported. + /// + /// The maximum allowed for the longer edge of the resized image. + /// + /// Whether to apply antialiasing. + /// It only affects bilinear or bicubic modes and it is ignored otherwise. + /// Possible values are: + /// * true: will apply antialiasing for bilinear or bicubic modes. Other mode aren't affected. This is probably what you want to use. + /// * false (default, incompatible to Python's torchvision v0.17 or later for historical reasons): will not apply antialiasing on any mode. + /// + static public ITransform Resize(int size, InterpolationMode interpolation = InterpolationMode.Nearest, int? maxSize = null, bool antialias = false) { - return new Resize(size, -1, maxSize); + return new Resize(size, -1, interpolation, maxSize, antialias); } } } diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs index e94eb83c4..c2c9ada18 100644 --- a/test/TorchSharpTest/NN.cs +++ b/test/TorchSharpTest/NN.cs @@ -6668,6 +6668,18 @@ public void TestInterpolateTrilinear() } } + [Fact] + public void TestInterpolateNearestExact() + { + foreach (var device in TestUtils.AvailableDevices()) { + using (Tensor input = torch.arange(1, 5, float32, device: device).view(1, 1, 2, 2)) + using (var res = interpolate(input, scale_factor: new double[] { 2, 2 }, mode: InterpolationMode.NearestExact)) { + Assert.Equal(device.type, res.device_type); + Assert.Equal(new long[] { 1, 1, 4, 4 }, res.shape); + } + } + } + [Fact] public void TestUpsampleNearest() { diff --git a/test/TorchSharpTest/TestTorchVision.cs b/test/TorchSharpTest/TestTorchVision.cs index c8f1bc341..2c902b2b8 100644 --- a/test/TorchSharpTest/TestTorchVision.cs +++ b/test/TorchSharpTest/TestTorchVision.cs @@ -938,7 +938,7 @@ public void Resize_WithSizeAndMaxSize_ReturnsTensor() int size = 20; int? maxSize = 30; var input = torch.randn(1, 3, 256, 256); - var transform = Resize(size, maxSize); + var transform = Resize(size, maxSize: maxSize); //Act var result = transform.call(input); @@ -1345,7 +1345,7 @@ public void Resize_WhenMaxSizeNotMet_ThrowsArgumentException() int? maxSize = 8; // Act + Assert - Assert.Throws(() => functional.resize(input, height, -1, maxSize)); + Assert.Throws(() => functional.resize(input, height, -1, maxSize: maxSize)); } [Fact] @@ -1357,7 +1357,7 @@ public void Resize_WhenMaxSizeMet_DoesNotThrowException() int? maxSize = 10; // Act + Assert - functional.resize(input, height, -1, maxSize); + functional.resize(input, height, -1, maxSize: maxSize); } From 4ef180ccf6abe33e46f58e842004be9d60630e6e Mon Sep 17 00:00:00 2001 From: Masaru Kimura Date: Tue, 25 Mar 2025 23:16:27 +0900 Subject: [PATCH 2/6] Update RELEASENOTES.md. --- RELEASENOTES.md | 1 + 1 file changed, 1 insertion(+) diff --git a/RELEASENOTES.md b/RELEASENOTES.md index f11f83cdc..331635671 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -13,6 +13,7 @@ __API Changes__: #1374 Add accumulate to index_put_
`torch.optim.lr_scheduler.PolynomialLR` `power` type has been corrected, is now double.
Returning an input tensor has been corrected, is now `alias()`.
+Add `torchvision.transforms.Resize` `interpolation` and `antialias`.
# NuGet Version 0.105.0 From 7cdcdf4f239501d66f607674a9bec0c6d6f1670e Mon Sep 17 00:00:00 2001 From: Masaru Kimura Date: Mon, 31 Mar 2025 16:53:54 +0900 Subject: [PATCH 3/6] Consider maxSize compatibility. --- src/TorchVision/Functional.cs | 15 ++++++++++++++- src/TorchVision/Resize.cs | 22 ++++++++++++++++++++++ test/TorchSharpTest/TestTorchVision.cs | 6 +++--- 3 files changed, 39 insertions(+), 4 deletions(-) diff --git a/src/TorchVision/Functional.cs b/src/TorchVision/Functional.cs index 6495bd60f..ce54735bb 100644 --- a/src/TorchVision/Functional.cs +++ b/src/TorchVision/Functional.cs @@ -708,7 +708,7 @@ public static Tensor posterize(Tensor input, int bits) /// * false (default, incompatible to Python's torchvision v0.17 or later for historical reasons): will not apply antialiasing on any mode. /// /// - public static Tensor resize(Tensor input, int height, int width, InterpolationMode interpolation = InterpolationMode.Nearest, int ? maxSize = null, bool antialias = false) + public static Tensor resize(Tensor input, int height, int width, InterpolationMode interpolation = InterpolationMode.Nearest, int? maxSize = null, bool antialias = false) { var hoffset = input.Dimensions - 2; var iHeight = input.shape[hoffset]; @@ -746,6 +746,19 @@ public static Tensor resize(Tensor input, int height, int width, InterpolationMo return SqueezeOut(img1, needCast, needSqueeze, dtype); } + /// + /// Resize the input image to the given size. + /// + /// An image tensor. + /// The height of the resized image. Must be > 0. + /// The width of the resized image. Must be > 0. + /// The maximum allowed for the longer edge of the resized image. + /// + public static Tensor resize(Tensor input, int height, int width, int? maxSize = null) + { + return resize(input, height, width, InterpolationMode.Nearest, maxSize, false); + } + /// /// Crop the given image and resize it to desired size. /// diff --git a/src/TorchVision/Resize.cs b/src/TorchVision/Resize.cs index c0798d75c..85f725f00 100644 --- a/src/TorchVision/Resize.cs +++ b/src/TorchVision/Resize.cs @@ -75,6 +75,28 @@ static public ITransform Resize(int size, InterpolationMode interpolation = Inte { return new Resize(size, -1, interpolation, maxSize, antialias); } + + /// + /// Resize the input image to the given size. + /// + /// Desired output height + /// Desired output width + /// The maximum allowed for the longer edge of the resized image. + /// + static public ITransform Resize(int height, int width, int? maxSize = null) + { + return new Resize(height, width, InterpolationMode.Nearest, maxSize, false); + } + + /// + /// Resize the input image to the given size. + /// + /// Desired output size + /// The maximum allowed for the longer edge of the resized image. + static public ITransform Resize(int size, int? maxSize = null) + { + return new Resize(size, -1, InterpolationMode.Nearest, maxSize, false); + } } } } \ No newline at end of file diff --git a/test/TorchSharpTest/TestTorchVision.cs b/test/TorchSharpTest/TestTorchVision.cs index 2c902b2b8..c8f1bc341 100644 --- a/test/TorchSharpTest/TestTorchVision.cs +++ b/test/TorchSharpTest/TestTorchVision.cs @@ -938,7 +938,7 @@ public void Resize_WithSizeAndMaxSize_ReturnsTensor() int size = 20; int? maxSize = 30; var input = torch.randn(1, 3, 256, 256); - var transform = Resize(size, maxSize: maxSize); + var transform = Resize(size, maxSize); //Act var result = transform.call(input); @@ -1345,7 +1345,7 @@ public void Resize_WhenMaxSizeNotMet_ThrowsArgumentException() int? maxSize = 8; // Act + Assert - Assert.Throws(() => functional.resize(input, height, -1, maxSize: maxSize)); + Assert.Throws(() => functional.resize(input, height, -1, maxSize)); } [Fact] @@ -1357,7 +1357,7 @@ public void Resize_WhenMaxSizeMet_DoesNotThrowException() int? maxSize = 10; // Act + Assert - functional.resize(input, height, -1, maxSize: maxSize); + functional.resize(input, height, -1, maxSize); } From b1b2128b1584899bb24290ae3542bdc6cd76e7c4 Mon Sep 17 00:00:00 2001 From: Masaru Kimura Date: Wed, 9 Apr 2025 11:23:12 +0900 Subject: [PATCH 4/6] Update antialias documentation. --- src/TorchVision/Functional.cs | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/TorchVision/Functional.cs b/src/TorchVision/Functional.cs index ce54735bb..00106aad5 100644 --- a/src/TorchVision/Functional.cs +++ b/src/TorchVision/Functional.cs @@ -706,6 +706,7 @@ public static Tensor posterize(Tensor input, int bits) /// Possible values are: /// * true: will apply antialiasing for bilinear or bicubic modes. Other mode aren't affected. This is probably what you want to use. /// * false (default, incompatible to Python's torchvision v0.17 or later for historical reasons): will not apply antialiasing on any mode. + /// antialias value will be automatically set to false silently in case interpolation is not InterpolationMode.Bilinear or InterpolationMode.Bicubic. /// /// public static Tensor resize(Tensor input, int height, int width, InterpolationMode interpolation = InterpolationMode.Nearest, int? maxSize = null, bool antialias = false) @@ -736,6 +737,11 @@ public static Tensor resize(Tensor input, int height, int width, InterpolationMo } } + // See https://github.com/pytorch/vision/blob/v0.21.0/torchvision/transforms/_functional_tensor.py#L455 + // "We manually set it to False to avoid an error downstream in interpolate() + // This behaviour is documented: the parameter is irrelevant for modes + // that are not bilinear or bicubic. We used to raise an error here, but + // now we don't ..." if (antialias && interpolation != InterpolationMode.Bilinear && interpolation != InterpolationMode.Bicubic) antialias = false; From 1764c73a3208dab5175a265fc3aebdd75168a771 Mon Sep 17 00:00:00 2001 From: Masaru Kimura Date: Wed, 9 Apr 2025 14:56:03 +0900 Subject: [PATCH 5/6] Add more compatibility handling. --- src/TorchVision/Functional.cs | 12 ++++++++++++ src/TorchVision/Resize.cs | 20 ++++++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/src/TorchVision/Functional.cs b/src/TorchVision/Functional.cs index 00106aad5..b26d4715f 100644 --- a/src/TorchVision/Functional.cs +++ b/src/TorchVision/Functional.cs @@ -765,6 +765,18 @@ public static Tensor resize(Tensor input, int height, int width, int? maxSize = return resize(input, height, width, InterpolationMode.Nearest, maxSize, false); } + /// + /// Resize the input image to the given size. + /// + /// An image tensor. + /// The height of the resized image. Must be > 0. + /// The width of the resized image. Must be > 0. + /// + public static Tensor resize(Tensor input, int height, int width) + { + return resize(input, height, width, InterpolationMode.Nearest, null, false); + } + /// /// Crop the given image and resize it to desired size. /// diff --git a/src/TorchVision/Resize.cs b/src/TorchVision/Resize.cs index 85f725f00..733c59aac 100644 --- a/src/TorchVision/Resize.cs +++ b/src/TorchVision/Resize.cs @@ -88,6 +88,17 @@ static public ITransform Resize(int height, int width, int? maxSize = null) return new Resize(height, width, InterpolationMode.Nearest, maxSize, false); } + /// + /// Resize the input image to the given size. + /// + /// Desired output height + /// Desired output width + /// + static public ITransform Resize(int height, int width) + { + return new Resize(height, width, InterpolationMode.Nearest, null, false); + } + /// /// Resize the input image to the given size. /// @@ -97,6 +108,15 @@ static public ITransform Resize(int size, int? maxSize = null) { return new Resize(size, -1, InterpolationMode.Nearest, maxSize, false); } + + /// + /// Resize the input image to the given size. + /// + /// Desired output size + static public ITransform Resize(int size) + { + return new Resize(size, -1, InterpolationMode.Nearest, null, false); + } } } } \ No newline at end of file From 5d0ab742682d72919eec7f4f6ab8f05395edf496 Mon Sep 17 00:00:00 2001 From: Masaru Kimura Date: Wed, 9 Apr 2025 14:56:41 +0900 Subject: [PATCH 6/6] Add test cases for antialias. --- test/TorchSharpTest/NN.cs | 111 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 111 insertions(+) diff --git a/test/TorchSharpTest/NN.cs b/test/TorchSharpTest/NN.cs index c2c9ada18..86e339d7f 100644 --- a/test/TorchSharpTest/NN.cs +++ b/test/TorchSharpTest/NN.cs @@ -6643,6 +6643,117 @@ public void TestInterpolateBilinear2D() } } + [Fact] + public void TestInterpolateBilinear2DNoAntialias() + { + foreach (var device in TestUtils.AvailableDevices()) { + using Tensor input = torch.tensor(rawArray: new float[] { + 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f + }, new long[] { 1, 1, 9, 9 }, float32, device: device); + using var res = torch.nn.functional.interpolate(input, new long[] { 6, 6 }, mode: InterpolationMode.Bilinear, antialias: false); + using Tensor expect = torch.tensor(rawArray: new float[] { + 0.7500f, 0.0000f, 0.7500f, 0.0000f, 0.7500f, 0.0000f, + 0.7500f, 0.0000f, 0.7500f, 0.0000f, 0.7500f, 0.0000f, + 0.2500f, 0.2500f, 0.2500f, 0.2500f, 0.2500f, 0.2500f, + 0.2500f, 0.2500f, 0.2500f, 0.2500f, 0.2500f, 0.2500f, + 0.0000f, 0.7500f, 0.0000f, 0.7500f, 0.0000f, 0.7500f, + 0.0000f, 0.7500f, 0.0000f, 0.7500f, 0.0000f, 0.7500f + }, new long[] { 1, 1, 6, 6 }, float32, device: device); + Assert.True(torch.allclose(res, expect, rtol: 0.0, atol: 1E-04)); + } + } + + [Fact] + public void TestInterpolateBilinear2DAntialias() + { + foreach (var device in TestUtils.AvailableDevices()) { + using Tensor input = torch.tensor(rawArray: new float[] { + 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f + }, new long[] { 1, 1, 9, 9 }, float32, device: device); + using var res = torch.nn.functional.interpolate(input, new long[] { 6, 6 }, mode: InterpolationMode.Bilinear, antialias: true); + using Tensor expect = torch.tensor(rawArray: new float[] { + 0.6250f, 0.1111f, 0.5556f, 0.1111f, 0.5556f, 0.0000f, + 0.5972f, 0.1358f, 0.5309f, 0.1358f, 0.5309f, 0.0417f, + 0.4028f, 0.3086f, 0.3580f, 0.3086f, 0.3580f, 0.3333f, + 0.3333f, 0.3580f, 0.3086f, 0.3580f, 0.3086f, 0.4028f, + 0.0417f, 0.5309f, 0.1358f, 0.5309f, 0.1358f, 0.5972f, + 0.0000f, 0.5556f, 0.1111f, 0.5556f, 0.1111f, 0.6250f + }, new long[] { 1, 1, 6, 6 }, float32, device: device); + Assert.True(torch.allclose(res, expect, rtol: 0.0, atol: 1E-04)); + } + } + + [Fact] + public void TestInterpolateBicubic2DNoAntialias() + { + foreach (var device in TestUtils.AvailableDevices()) { + using Tensor input = torch.tensor(rawArray: new float[] { + 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f + }, new long[] { 1, 1, 9, 9 }, float32, device: device); + using var res = torch.nn.functional.interpolate(input, new long[] { 6, 6 }, mode: InterpolationMode.Bicubic, antialias: true); + using Tensor expect = torch.tensor(rawArray: new float[] { + 0.6493f, 0.0467f, 0.6196f, 0.0471f, 0.6226f, -0.0440f, + 0.6356f, 0.0619f, 0.6042f, 0.0624f, 0.6069f, -0.0205f, + 0.4083f, 0.3155f, 0.3487f, 0.3180f, 0.3464f, 0.3712f, + 0.3712f, 0.3464f, 0.3180f, 0.3487f, 0.3155f, 0.4083f, + -0.0205f, 0.6069f, 0.0624f, 0.6042f, 0.0619f, 0.6356f, + -0.0440f, 0.6226f, 0.0471f, 0.6196f, 0.0467f, 0.6493f + }, new long[] { 1, 1, 6, 6 }, float32, device: device); + Assert.True(torch.allclose(res, expect, rtol: 0.0, atol: 1E-04)); + } + } + + [Fact] + public void TestInterpolateBicubic2DAntialias() + { + foreach (var device in TestUtils.AvailableDevices()) { + using Tensor input = torch.tensor(rawArray: new float[] { + 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, + 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f, 0.0f, 0.0f, 1.0f + }, new long[] { 1, 1, 9, 9 }, float32, device: device); + using var res = torch.nn.functional.interpolate(input, new long[] { 6, 6 }, mode: InterpolationMode.Bicubic, antialias: false); + using Tensor expect = torch.tensor(rawArray: new float[] { + 0.7734f, -0.1406f, 0.8789f, -0.1406f, 0.8789f, -0.0352f, + 0.8274f, -0.1831f, 0.9440f, -0.1831f, 0.9440f, -0.0665f, + 0.2077f, 0.3042f, 0.1966f, 0.3042f, 0.1966f, 0.2930f, + 0.2930f, 0.1966f, 0.3042f, 0.1966f, 0.3042f, 0.2077f, + -0.0665f, 0.9440f, -0.1831f, 0.9440f, -0.1831f, 0.8274f, + -0.0352f, 0.8789f, -0.1406f, 0.8789f, -0.1406f, 0.7734f + }, new long[] { 1, 1, 6, 6 }, float32, device: device); + Assert.True(torch.allclose(res, expect, rtol: 0.0, atol: 1E-04)); + } + } [Fact] public void TestInterpolateArea()