Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor TorchVision Normalize method like pytorch. #1428

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
62 changes: 50 additions & 12 deletions src/TorchVision/Functional.cs
Original file line number Diff line number Diff line change
@@ -24,6 +24,35 @@ public static partial class transforms
{
public static partial class functional
{

/// <summary>
/// Tells whether a tensor has at least two dimensions, which is the
/// minimum this class accepts for treating it as an image.
/// </summary>
/// <param name="img">The tensor to inspect.</param>
/// <returns>true when <c>img.ndim</c> is 2 or greater; false otherwise.</returns>
private static bool IsTensorImage(Tensor img) => img.ndim >= 2;

/// <summary>
/// Guard used before image operations: passes tensors accepted by
/// IsTensorImage and rejects everything else.
/// </summary>
/// <param name="img">The tensor to validate.</param>
/// <returns>Always true when the method returns normally.</returns>
/// <exception cref="ArgumentException">The tensor fails the IsTensorImage check.</exception>
private static bool AssertTensorImage(Tensor img)
{
    if (IsTensorImage(img)) {
        return true;
    }

    throw new ArgumentException("Tensor is not a torch image.");
}

/// <summary>
/// Reports how many channels an image tensor has.
/// </summary>
/// <param name="img">(Tensor) – The image to be checked.</param>
/// <returns>
/// 1 for a two-dimensional tensor; otherwise the size of the third
/// dimension counted from the end.
/// </returns>
/// <exception cref="ArgumentException">The tensor is rejected by AssertTensorImage.</exception>
public static long get_image_num_channels(Tensor img)
{
    AssertTensorImage(img);

    var rank = img.ndim;

    // A two-dimensional tensor is reported as a single channel.
    if (rank == 2) {
        return 1;
    }

    // With three or more dimensions, the channel axis sits third from the end.
    if (rank > 2) {
        return img.shape[rank - 3];
    }

    throw new ArgumentException($"Input ndim should be 2 or more. Got {rank}");
}

/// <summary>
/// Get the image dimensions
/// </summary>
@@ -533,20 +562,29 @@ public static Tensor invert(Tensor input)
/// <param name="input">An image tensor.</param>
/// <param name="means">Sequence of means for each channel.</param>
/// <param name="stdevs">Sequence of standard deviations for each channel.</param>
/// <param name="dtype">Scalar type (defaults to Float32). The statements shown here take the element type from the input tensor and do not read this parameter.</param>
/// <param name="inplace">When false (the default) the input is cloned and the clone is modified and returned; when true the input tensor itself is modified.</param>
/// <returns>The tensor with the per-channel mean subtracted and the result divided by the per-channel standard deviation.</returns>
// NOTE(review): this span is a rendered diff without +/- markers, so two
// signatures and two bodies appear back to back. Presumably the `dtype`
// form is being replaced by the `inplace` form — confirm against the PR.
public static Tensor normalize(Tensor input, double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32)
public static Tensor normalize(Tensor input, double[] means, double[] stdevs, bool inplace = false)
{
// Both per-channel sequences must have the same number of entries.
if (means.Length != stdevs.Length)
throw new ArgumentException("means and stdevs must be the same length in call to Normalize");
// The channel count is read from dimension 1 of the input here.
if (means.Length != input.shape[1])
throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations");

using var mean = means.ToTensor(new long[] { 1, means.Length, 1, 1 }).to(input.dtype, input.device); // Assumes NxCxHxW
using var stdev = stdevs.ToTensor(new long[] { 1, stdevs.Length, 1, 1 }).to(input.dtype, input.device); // Assumes NxCxHxW
using var t0 = input - mean;

return t0 / stdev;
// Dispose scope for the temporaries created below; the result is moved to
// the outer scope by the final return statement.
using var _ = NewDisposeScope();
AssertTensorImage(input);
// Only floating-point inputs are accepted.
if (!input.is_floating_point())
throw new ArgumentException($"Input tensor should be a float tensor. Got {input.dtype}.");
// At least three dimensions are required.
if (input.ndim < 3)
throw new ArgumentException($"Expected tensor to be a tensor image of size (..., C, H, W). Got tensor.size() = ({string.Join(", ", input.shape)})");
// Work on a copy unless the caller asked for in-place modification.
if (!inplace)
input = input.clone();


// Build the statistics with the input's dtype and on the input's device.
var mean = as_tensor(means, dtype: input.dtype, device: input.device);
var stdev = as_tensor(stdevs, dtype: input.dtype, device: input.device);
// Reject any zero standard deviation before dividing by it.
if ((stdev == 0).any().item<bool>())
throw new ArgumentException($"std evaluated to zero after conversion to {input.dtype}, leading to division by zero.");
// A 1-D vector is reshaped to (C, 1, 1) — presumably so that it broadcasts
// across the trailing H and W dimensions; confirm.
if (mean.ndim == 1)
mean = mean.view(-1, 1, 1);
if (stdev.ndim == 1)
stdev = stdev.view(-1, 1, 1);
return input.sub_(mean).div_(stdev).MoveToOuterDisposeScope();
}

private static Tensor _pad(Tensor input, ReadOnlySpan<long> padding, double fill = 0, PaddingModes padding_mode = PaddingModes.Constant)
63 changes: 16 additions & 47 deletions src/TorchVision/Normalize.cs
Original file line number Diff line number Diff line change
@@ -7,62 +7,32 @@ namespace TorchSharp
{
public static partial class torchvision
{
// Transform that normalizes an image tensor with per-channel means and
// standard deviations.
// NOTE(review): this span is a rendered diff without +/- markers — two class
// headers, two constructors, two `call` bodies and two sets of fields are
// interleaved below. Confirm which lines belong to which version against the PR.
internal class Normalize : ITransform, IDisposable
internal class Normalize : ITransform
{
internal Normalize(double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32, torch.Device? device = null)
internal Normalize(double[] means, double[] stdevs,bool inplace = false)
{
// Null and length validation of the two per-channel sequences.
if (means is null) throw new ArgumentNullException(nameof(means));
if (stdevs is null) throw new ArgumentNullException(nameof(stdevs));
if (means.Length != stdevs.Length)
throw new ArgumentException($"{nameof(means)} and {nameof(stdevs)} must be the same length in call to Normalize");
if (means.Length != 1 && means.Length != 3)
throw new ArgumentException($"Since they correspond to the number of channels in an image, {nameof(means)} and {nameof(stdevs)} must both be either 1 or 3 long");
// Keep the raw arrays and the inplace flag for use in call().
this.means = means;
this.stdevs = stdevs;
this.inplace = inplace;

// Tensor form of the statistics, shaped (1, C, 1, 1).
this.means = means.ToTensor(new long[] { 1, means.Length, 1, 1 }); // Assumes NxCxHxW
this.stdevs = stdevs.ToTensor(new long[] { 1, stdevs.Length, 1, 1 }); // Assumes NxCxHxW

// The type conversion is skipped when Float64 is requested — presumably
// because double[] already yields Float64 tensors; confirm.
if (dtype != ScalarType.Float64) {
this.means = this.means.to_type(dtype);
this.stdevs = this.stdevs.to_type(dtype);
}

// Move to the requested device only when one is given and it is not the CPU.
if (device != null && device.type != DeviceType.CPU) {
this.means = this.means.to(device);
this.stdevs = this.stdevs.to(device);
}
}

public Tensor call(Tensor input)
{
// The channel count is compared on dimension 1 of both tensors.
if (means.size(1) != input.size(1)) throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations");
return (input - means) / stdevs;
}

private Tensor means;
private Tensor stdevs;
bool disposedValue;

protected virtual void Dispose(bool disposing)
{
// Dispose the two statistic tensors once; later calls do nothing.
if (!disposedValue) {
means?.Dispose();
stdevs?.Dispose();
disposedValue = true;
}
// The channel count comes from get_image_num_channels and must equal the
// number of means.
var expectedChannels = transforms.functional.get_image_num_channels(input);
if (expectedChannels != means.Length)
throw new ArgumentException("The number of channels is not equal to the number of means and standard deviations");
// Delegate the arithmetic to the functional normalize with the stored arguments.
return transforms.functional.normalize(input, means, stdevs, inplace);
}

~Normalize()
{
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
Dispose(disposing: false);
}

public void Dispose()
{
// Do not change this code. Put cleanup code in 'Dispose(bool disposing)' method
Dispose(disposing: true);
GC.SuppressFinalize(this);
}
// Array-backed state: the per-channel statistics and the inplace flag.
private readonly double[] means;
private readonly double[] stdevs;
private readonly bool inplace;

}

public static partial class transforms
@@ -72,12 +42,11 @@ public static partial class transforms
/// </summary>
/// <param name="means">Sequence of means for each channel.</param>
/// <param name="stdevs">Sequence of standard deviations for each channel.</param>
/// <param name="dtype">Scalar type forwarded to the Normalize constructor (defaults to Float32).</param>
/// <param name="device">The device to place the output tensor on.</param>
/// <param name="inplace">Bool to make this operation inplace.</param>
/// <returns>A new Normalize transform constructed from the given arguments.</returns>
// NOTE(review): rendered diff without +/- markers — two signatures and two
// return statements are interleaved. Presumably the `dtype`/`device` form is
// being replaced by the `inplace` form; confirm against the PR.
static public ITransform Normalize(double[] means, double[] stdevs, ScalarType dtype = ScalarType.Float32, torch.Device? device = null)
static public ITransform Normalize(double[] means, double[] stdevs, bool inplace = false)
{
return new Normalize(means, stdevs, dtype, device);
return new Normalize(means, stdevs, inplace);
}
}
}
11 changes: 0 additions & 11 deletions test/TorchSharpTest/TestTorchVision.cs
Original file line number Diff line number Diff line change
@@ -845,17 +845,6 @@ public void TestConstructor_ThrowsArgumentException_IfMeansAndStdevsHaveDifferen
Assert.Throws<ArgumentException>(() => Normalize(means, stdevs));
}

[Fact]
public void TestConstructor_ThrowsArgumentException_IfMeansAndStdevsHaveWrongLengths()
{
    // Two entries per sequence: the lengths match each other but are neither 1 nor 3.
    var twoMeans = new[] { 0.485, 0.456 };
    var twoStdevs = new[] { 0.229, 0.224 };

    // Building the transform is expected to fail with ArgumentException.
    Assert.Throws<ArgumentException>(() => Normalize(twoMeans, twoStdevs));
}

[Fact]
public void TestConstructor_CreatesNewNormalizeObject_WithValidArguments()
{