Skip to content

Commit 8c88fd0

Browse files
hiyuh and Masaru Kimura
authored
Add torchvision.transforms.Resize interpolation and antialias. (#1441)
* Add torchvision.transforms.Resize interpolation and antialias. torchvision.transforms.Resize forced nearest interpolation and no antialias, but shouldn't. Based on my understanding, the original torchvision.transforms.Resize call chain is: - torchvision.transforms.Resize - torchvision.transforms.functional.resize - torchvision.transforms._functional_pil.resize - PIL.Image.Image.resize - torchvision.transforms._functional_tensor.resize - torch.nn.functional.interpolate Note, this PR still keeps nearest interpolation and no antialias by default for torchvision.transforms.Resize to maximize compatibility for existing code using TorchSharp, which makes it incompatible with the original torchvision.transforms.Resize default; however, it would be up to the upstream decision. See also: * https://pytorch.org/vision/main/generated/torchvision.transforms.Resize.html * https://pytorch.org/vision/main/generated/torchvision.transforms.functional.resize.html * https://pytorch.org/docs/stable/generated/torch.nn.functional.interpolate.html * Update RELEASENOTES.md. * Consider maxSize compatibility. * Update antialias documentation. * Add more compatibility handling. * Add test cases for antialias. --------- Co-authored-by: Masaru Kimura <[email protected]>
1 parent b056ca5 commit 8c88fd0

File tree

8 files changed

+264
-17
lines changed

8 files changed

+264
-17
lines changed

RELEASENOTES.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ __API Changes__:
1313
#1374 Add accumulate to index_put_<br/>
1414
`torch.optim.lr_scheduler.PolynomialLR` `power` type has been corrected, is now double.<br/>
1515
Returning an input tensor has been corrected, is now `alias()`.<br/>
16+
Add `torchvision.transforms.Resize` `interpolation` and `antialias`.<br />
1617

1718
# NuGet Version 0.105.0
1819

src/Native/LibTorchSharp/THSNN.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,8 @@ void ApplyInterpolateMode(T& opts, const int8_t mode)
109109
opts = opts.mode(torch::kTrilinear);
110110
if (mode == 5)
111111
opts = opts.mode(torch::kArea);
112+
if (mode == 6)
113+
opts = opts.mode(torch::kNearestExact);
112114
}
113115

114116
template<typename T>
@@ -176,13 +178,14 @@ Tensor THSNN_affine_grid(const Tensor theta, const int64_t* size, const int size
176178
}
177179

178180

179-
EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, const bool recompute_scale_factor, NNAnyModule* outAsAnyModule)
181+
EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, const bool recompute_scale_factor, const bool antialias, NNAnyModule* outAsAnyModule)
180182
{
181183
auto opts = torch::nn::functional::InterpolateFuncOptions().recompute_scale_factor(recompute_scale_factor);
182184
// align_corners -- 0=None, 1=true, 2=false
183185
if (align_corners != 0)
184186
opts.align_corners(align_corners == 1);
185187
ApplyInterpolateMode(opts, mode);
188+
opts.antialias(antialias);
186189

187190
if (size_len > 0) {
188191
std::vector<int64_t> sizes;

src/Native/LibTorchSharp/THSNN.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ EXPORT_API(Tensor) THSNN_pixel_unshuffle(const Tensor tensor, const int64_t do
7171
// Vision -- Functions
7272

7373
EXPORT_API(Tensor) THSNN_pad(const Tensor input, const int64_t* pad, const int pad_length, const int8_t mode, const double value);
74-
EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, const bool recompute_scale_factor, NNAnyModule* outAsAnyModule);
74+
EXPORT_API(Tensor) THSNN_interpolate(const Tensor input, const int64_t* size, const int size_len, const double* scale_factor, const int scale_factor_len, const int8_t mode, const int8_t align_corners, const bool recompute_scale_factor, const bool antialias, NNAnyModule* outAsAnyModule);
7575
EXPORT_API(Tensor) THSNN_grid_sample(const Tensor input, const Tensor grid, const int8_t mode, const int8_t padding_mode, const int8_t align_corners);
7676
EXPORT_API(Tensor) THSNN_affine_grid(const Tensor theta, const int64_t* size, const int size_len, const bool align_corners);
7777

src/TorchSharp/NN/Vision.cs

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ public enum InterpolationMode
2323
Bilinear = 2,
2424
Bicubic = 3,
2525
Trilinear = 4,
26-
Area = 5
26+
Area = 5,
27+
NearestExact = 6
2728
}
2829

2930
public enum GridSampleMode
@@ -194,7 +195,7 @@ public static Tensor affine_grid(Tensor theta, long[]? size = null, bool align_c
194195
/// <param name="x">The input tensor</param>
195196
/// <param name="size">Output spatial size</param>
196197
/// <param name="scale_factor">Multiplier for spatial size. Has to match input size if it is a tuple.</param>
197-
/// <param name="mode">The algorithm used for upsampling: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area'</param>
198+
/// <param name="mode">The algorithm used for upsampling: 'nearest' | 'linear' | 'bilinear' | 'bicubic' | 'trilinear' | 'area' | 'nearest-exact'</param>
198199
/// <param name="align_corners">Geometrically, we consider the pixels of the input and output as squares rather than points.
199200
/// If set to true, the input and output tensors are aligned by the center points of their corner pixels, preserving the values at the corner pixels.
200201
/// If set to false, the input and output tensors are aligned by the corner points of their corner pixels, and the interpolation uses edge value padding for out-of-boundary values, making this operation independent of input size when scale_factor is kept the same.</param>
@@ -205,14 +206,19 @@ public static Tensor affine_grid(Tensor theta, long[]? size = null, bool align_c
205206
/// Otherwise, a new scale_factor will be computed based on the output and input sizes for use in the interpolation computation
206207
/// (i.e. the computation will be identical to if the computed output_size were passed-in explicitly).
207208
/// </param>
209+
/// <param name="antialias">
210+
/// Flag to apply anti-aliasing. Using anti-alias
211+
/// option together with align_corners = false, interpolation result would match Pillow
212+
/// result for downsampling operation. Supported modes: 'bilinear', 'bicubic'.
213+
/// </param>
208214
/// <returns></returns>
209-
public static Tensor interpolate(Tensor x, long[]? size = null, double[]? scale_factor = null, InterpolationMode mode = InterpolationMode.Nearest, bool? align_corners = null, bool recompute_scale_factor = false)
215+
public static Tensor interpolate(Tensor x, long[]? size = null, double[]? scale_factor = null, InterpolationMode mode = InterpolationMode.Nearest, bool? align_corners = null, bool recompute_scale_factor = false, bool antialias = false)
210216
{
211217
unsafe {
212218
fixed (long* psize = size) {
213219
fixed (double* pSF = scale_factor) {
214220
byte ac = (byte)((align_corners.HasValue) ? (align_corners.Value ? 1 : 2) : 0);
215-
var res = THSNN_interpolate(x.Handle, (IntPtr)psize, size is null ? 0 : size.Length, (IntPtr)pSF, scale_factor is null ? 0 : scale_factor.Length, (byte)mode, ac, recompute_scale_factor);
221+
var res = THSNN_interpolate(x.Handle, (IntPtr)psize, size is null ? 0 : size.Length, (IntPtr)pSF, scale_factor is null ? 0 : scale_factor.Length, (byte)mode, ac, recompute_scale_factor, antialias);
216222
if (res == IntPtr.Zero) { torch.CheckForErrors(); }
217223
return new Tensor(res);
218224
}

src/TorchSharp/PInvoke/LibTorchSharp.THSNN.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ internal static extern IntPtr THSNN_custom_module(
4444

4545
[DllImport("LibTorchSharp")]
4646
// align_corners -- 0=None, 1=true, 2=false
47-
internal static extern IntPtr THSNN_interpolate(IntPtr input, IntPtr size, int size_len, IntPtr scale_factor, int scale_factor_len, byte mode, byte align_corners, [MarshalAs(UnmanagedType.U1)] bool recompute_scale_factor);
47+
internal static extern IntPtr THSNN_interpolate(IntPtr input, IntPtr size, int size_len, IntPtr scale_factor, int scale_factor_len, byte mode, byte align_corners, [MarshalAs(UnmanagedType.U1)] bool recompute_scale_factor, [MarshalAs(UnmanagedType.U1)] bool antialias);
4848

4949
[DllImport("LibTorchSharp")]
5050
// align_corners -- 0=None, 1=true, 2=false

src/TorchVision/Functional.cs

Lines changed: 48 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -694,13 +694,23 @@ public static Tensor posterize(Tensor input, int bits)
694694
/// <param name="input">An image tensor.</param>
695695
/// <param name="height">The height of the resized image. Must be > 0.</param>
696696
/// <param name="width">The width of the resized image. Must be > 0.</param>
697+
/// <param name="interpolation">
698+
/// Desired interpolation enum defined by TorchSharp.torch.InterpolationMode.
699+
/// Default is InterpolationMode.Nearest; not InterpolationMode.Bilinear (incompatible to Python's torchvision v0.17 or later for historical reasons).
700+
/// Only InterpolationMode.Nearest, InterpolationMode.NearestExact, InterpolationMode.Bilinear and InterpolationMode.Bicubic are supported.
701+
/// </param>
697702
/// <param name="maxSize">The maximum allowed for the longer edge of the resized image.</param>
703+
/// <param name="antialias">
704+
/// Whether to apply antialiasing.
705+
/// It only affects bilinear or bicubic modes and it is ignored otherwise.
706+
/// Possible values are:
707+
/// * true: will apply antialiasing for bilinear or bicubic modes. Other modes aren't affected. This is probably what you want to use.
708+
/// * false (default, incompatible to Python's torchvision v0.17 or later for historical reasons): will not apply antialiasing on any mode.
709+
/// antialias value will be automatically set to false silently in case interpolation is not InterpolationMode.Bilinear or InterpolationMode.Bicubic.
710+
/// </param>
698711
/// <returns></returns>
699-
public static Tensor resize(Tensor input, int height, int width, int? maxSize = null)
712+
public static Tensor resize(Tensor input, int height, int width, InterpolationMode interpolation = InterpolationMode.Nearest, int? maxSize = null, bool antialias = false)
700713
{
701-
// For now, we don't allow any other modes.
702-
const InterpolationMode interpolation = InterpolationMode.Nearest;
703-
704714
var hoffset = input.Dimensions - 2;
705715
var iHeight = input.shape[hoffset];
706716
var iWidth = input.shape[hoffset + 1];
@@ -727,13 +737,46 @@ public static Tensor resize(Tensor input, int height, int width, int? maxSize =
727737
}
728738
}
729739

740+
// See https://github.com/pytorch/vision/blob/v0.21.0/torchvision/transforms/_functional_tensor.py#L455
741+
// "We manually set it to False to avoid an error downstream in interpolate()
742+
// This behaviour is documented: the parameter is irrelevant for modes
743+
// that are not bilinear or bicubic. We used to raise an error here, but
744+
// now we don't ..."
745+
if (antialias && interpolation != InterpolationMode.Bilinear && interpolation != InterpolationMode.Bicubic)
746+
antialias = false;
747+
730748
using var img0 = SqueezeIn(input, new ScalarType[] { ScalarType.Float32, ScalarType.Float64 }, out var needCast, out var needSqueeze, out var dtype);
731749

732-
using var img1 = torch.nn.functional.interpolate(img0, new long[] { h, w }, mode: interpolation, align_corners: null);
750+
using var img1 = torch.nn.functional.interpolate(img0, new long[] { h, w }, mode: interpolation, align_corners: null, antialias: antialias);
733751

734752
return SqueezeOut(img1, needCast, needSqueeze, dtype);
735753
}
736754

755+
/// <summary>
756+
/// Resize the input image to the given size.
757+
/// </summary>
758+
/// <param name="input">An image tensor.</param>
759+
/// <param name="height">The height of the resized image. Must be > 0.</param>
760+
/// <param name="width">The width of the resized image. Must be > 0.</param>
761+
/// <param name="maxSize">The maximum allowed for the longer edge of the resized image.</param>
762+
/// <returns></returns>
763+
public static Tensor resize(Tensor input, int height, int width, int? maxSize = null)
764+
{
765+
return resize(input, height, width, InterpolationMode.Nearest, maxSize, false);
766+
}
767+
768+
/// <summary>
769+
/// Resize the input image to the given size.
770+
/// </summary>
771+
/// <param name="input">An image tensor.</param>
772+
/// <param name="height">The height of the resized image. Must be > 0.</param>
773+
/// <param name="width">The width of the resized image. Must be > 0.</param>
774+
/// <returns></returns>
775+
public static Tensor resize(Tensor input, int height, int width)
776+
{
777+
return resize(input, height, width, InterpolationMode.Nearest, null, false);
778+
}
779+
737780
/// <summary>
738781
/// Crop the given image and resize it to desired size.
739782
/// </summary>

src/TorchVision/Resize.cs

Lines changed: 76 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -8,24 +8,86 @@ public static partial class torchvision
88
{
99
internal class Resize : ITransform
1010
{
11-
internal Resize(int height, int width, int? maxSize)
11+
internal Resize(int height, int width, InterpolationMode interpolation, int? maxSize, bool antialias)
1212
{
1313
this.height = height;
1414
this.width = width;
15+
this.interpolation = interpolation;
1516
this.maxSize = maxSize;
17+
this.antialias = antialias;
1618
}
1719

1820
public Tensor call(Tensor input)
1921
{
20-
return transforms.functional.resize(input, height, width, maxSize);
22+
return transforms.functional.resize(input, height, width, interpolation, maxSize, antialias);
2123
}
2224

2325
private int height, width;
26+
private InterpolationMode interpolation;
2427
private int? maxSize;
28+
private bool antialias;
2529
}
2630

2731
public static partial class transforms
2832
{
33+
/// <summary>
34+
/// Resize the input image to the given size.
35+
/// </summary>
36+
/// <param name="height">Desired output height</param>
37+
/// <param name="width">Desired output width</param>
38+
/// <param name="interpolation">
39+
/// Desired interpolation enum defined by TorchSharp.torch.InterpolationMode.
40+
/// Default is InterpolationMode.Nearest; not InterpolationMode.Bilinear (incompatible to Python's torchvision v0.17 or later for historical reasons).
41+
/// Only InterpolationMode.Nearest, InterpolationMode.NearestExact, InterpolationMode.Bilinear and InterpolationMode.Bicubic are supported.
42+
/// </param>
43+
/// <param name="maxSize">The maximum allowed for the longer edge of the resized image.</param>
44+
/// <param name="antialias">
45+
/// Whether to apply antialiasing.
46+
/// It only affects bilinear or bicubic modes and it is ignored otherwise.
47+
/// Possible values are:
48+
/// * true: will apply antialiasing for bilinear or bicubic modes. Other modes aren't affected. This is probably what you want to use.
49+
/// * false (default, incompatible to Python's torchvision v0.17 or later for historical reasons): will not apply antialiasing on any mode.
50+
/// </param>
51+
/// <returns></returns>
52+
static public ITransform Resize(int height, int width, InterpolationMode interpolation = InterpolationMode.Nearest, int? maxSize = null, bool antialias = false)
53+
{
54+
return new Resize(height, width, interpolation, maxSize, antialias);
55+
}
56+
57+
/// <summary>
58+
/// Resize the input image to the given size.
59+
/// </summary>
60+
/// <param name="size">Desired output size</param>
61+
/// <param name="interpolation">
62+
/// Desired interpolation enum defined by TorchSharp.torch.InterpolationMode.
63+
/// Default is InterpolationMode.Nearest; not InterpolationMode.Bilinear (incompatible to Python's torchvision v0.17 or later for historical reasons).
64+
/// Only InterpolationMode.Nearest, InterpolationMode.NearestExact, InterpolationMode.Bilinear and InterpolationMode.Bicubic are supported.
65+
/// </param>
66+
/// <param name="maxSize">The maximum allowed for the longer edge of the resized image.</param>
67+
/// <param name="antialias">
68+
/// Whether to apply antialiasing.
69+
/// It only affects bilinear or bicubic modes and it is ignored otherwise.
70+
/// Possible values are:
71+
/// * true: will apply antialiasing for bilinear or bicubic modes. Other modes aren't affected. This is probably what you want to use.
72+
/// * false (default, incompatible to Python's torchvision v0.17 or later for historical reasons): will not apply antialiasing on any mode.
73+
/// </param>
74+
static public ITransform Resize(int size, InterpolationMode interpolation = InterpolationMode.Nearest, int? maxSize = null, bool antialias = false)
75+
{
76+
return new Resize(size, -1, interpolation, maxSize, antialias);
77+
}
78+
79+
/// <summary>
80+
/// Resize the input image to the given size.
81+
/// </summary>
82+
/// <param name="height">Desired output height</param>
83+
/// <param name="width">Desired output width</param>
84+
/// <param name="maxSize">The maximum allowed for the longer edge of the resized image.</param>
85+
/// <returns></returns>
86+
static public ITransform Resize(int height, int width, int? maxSize = null)
87+
{
88+
return new Resize(height, width, InterpolationMode.Nearest, maxSize, false);
89+
}
90+
2991
/// <summary>
3092
/// Resize the input image to the given size.
3193
/// </summary>
@@ -34,17 +96,26 @@ public static partial class transforms
3496
/// <returns></returns>
3597
static public ITransform Resize(int height, int width)
3698
{
37-
return new Resize(height, width, null);
99+
return new Resize(height, width, InterpolationMode.Nearest, null, false);
38100
}
39101

40102
/// <summary>
41103
/// Resize the input image to the given size.
42104
/// </summary>
43105
/// <param name="size">Desired output size</param>
44-
/// <param name="maxSize">Max size</param>
106+
/// <param name="maxSize">The maximum allowed for the longer edge of the resized image.</param>
45107
static public ITransform Resize(int size, int? maxSize = null)
46108
{
47-
return new Resize(size, -1, maxSize);
109+
return new Resize(size, -1, InterpolationMode.Nearest, maxSize, false);
110+
}
111+
112+
/// <summary>
113+
/// Resize the input image to the given size.
114+
/// </summary>
115+
/// <param name="size">Desired output size</param>
116+
static public ITransform Resize(int size)
117+
{
118+
return new Resize(size, -1, InterpolationMode.Nearest, null, false);
48119
}
49120
}
50121
}

0 commit comments

Comments
 (0)