Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add missing accumulate argument for index_put_ method #1460

Merged
merged 14 commits into from
Mar 27, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions RELEASENOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,10 @@ __Bug Fixes__:

#1426 Sequential.eval() does not put model into eval mode<br/>
`torch.optim.lr_scheduler.LinearLR` `end_factor` default has been corrected, is now 1.0.<br/>

__API Changes__:

#1374 Add accumulate to index_put_<br/>
`torch.optim.lr_scheduler.PolynomialLR` `power` type has been corrected, is now double.<br/>

# NuGet Version 0.105.0
Expand Down
25 changes: 25 additions & 0 deletions src/Native/LibTorchSharp/THSTensor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -837,6 +837,31 @@ void THSTensor_index_put_(Tensor tensor,
CATCH(tensor->index_put_(indices, *value););
}

// In-place index_put_ with an explicit 'accumulate' flag.
// When accumulate is true the values are *added* to the existing elements at the
// selected positions; otherwise they replace them.
//
// indexStarts/indexEnds/indexSteps/indexTensors encode a mixed index list
// (slices, scalars, tensors) of length indicesLength; indexTensors[i] is null
// for entries that are not tensor indices.
void THSTensor_index_put_(Tensor tensor,
    const int64_t* indexStarts,
    const int64_t* indexEnds,
    const int64_t* indexSteps,
    const Tensor* indexTensors,
    const int indicesLength,
    const Tensor value,
    const bool accumulate)
{
    if (accumulate) {
        // The accumulating overload of index_put_ takes a list of optional index
        // tensors rather than TensorIndex objects. Represent non-tensor indices
        // (slices etc.) as nullopt instead of dereferencing a null entry.
        c10::List<std::optional<at::Tensor>> indicesList;
        for (int i = 0; i < indicesLength; i++) {
            if (indexTensors[i] != nullptr)
                indicesList.push_back(std::optional<at::Tensor>(*indexTensors[i]));
            else
                indicesList.push_back(std::nullopt);
        }
        CATCH(tensor->index_put_(indicesList, *value, /*accumulate=*/true););
    }
    else {
        // Non-accumulating path: build the TensorIndex array only when it is
        // actually consumed. Stack allocation mirrors the sibling overloads.
        at::indexing::TensorIndex* indicesArray = (at::indexing::TensorIndex*)alloca(indicesLength * sizeof(at::indexing::TensorIndex));
        memset(indicesArray, 0, indicesLength * sizeof(at::indexing::TensorIndex));
        completeTensorIndices(indexStarts, indexEnds, indexSteps, indexTensors, indicesArray, indicesLength);
        auto indices = at::ArrayRef<at::indexing::TensorIndex>(indicesArray, indicesLength);
        CATCH(tensor->index_put_(indices, *value););
    }
}

void THSTensor_index_put_scalar_(Tensor tensor,
const int64_t* indexStarts,
const int64_t* indexEnds,
Expand Down
3 changes: 2 additions & 1 deletion src/Native/LibTorchSharp/THSTensor.h
Original file line number Diff line number Diff line change
Expand Up @@ -683,7 +683,8 @@ EXPORT_API(void) THSTensor_index_put_(Tensor tensor,
const int64_t* indexSteps,
const Tensor* indexTensors,
const int indicesLength,
const Tensor value);
const Tensor value,
const bool accumulate = false);

EXPORT_API(Tensor) THSTensor_index_select(Tensor tensor, int64_t dim, Tensor index);

Expand Down
2 changes: 1 addition & 1 deletion src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
internal static extern void THSTensor_index_put_scalar_(IntPtr tensor, IntPtr indexStarts, IntPtr indexEnds, IntPtr indexSteps, IntPtr indexTensors, int indicesLength, IntPtr value);

[DllImport("LibTorchSharp")]
internal static extern void THSTensor_index_put_(IntPtr tensor, IntPtr indexStarts, IntPtr indexEnds, IntPtr indexSteps, IntPtr indexTensors, int indicesLength, IntPtr value);
internal static extern void THSTensor_index_put_(IntPtr tensor, IntPtr indexStarts, IntPtr indexEnds, IntPtr indexSteps, IntPtr indexTensors, int indicesLength, IntPtr value, [MarshalAs(UnmanagedType.U1)] bool accumulate);

[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_get1(IntPtr handle, long i1);
Expand Down
25 changes: 24 additions & 1 deletion src/TorchSharp/Tensor/Tensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1604,7 +1604,25 @@ public Tensor index_put_(Tensor value, params TensorIndex[] indices)
unsafe {
fixed (long* ptrKindAndStarts = arrKindAndStarts, ptrStops = arrStops, ptrSteps = arrSteps) {
fixed (IntPtr* ptrTensors = arrTensors) {
NativeMethods.THSTensor_index_put_(Handle, (IntPtr)ptrKindAndStarts, (IntPtr)ptrStops, (IntPtr)ptrSteps, (IntPtr)ptrTensors, indices.Length, value.Handle);
NativeMethods.THSTensor_index_put_(Handle, (IntPtr)ptrKindAndStarts, (IntPtr)ptrStops, (IntPtr)ptrSteps, (IntPtr)ptrTensors, indices.Length, value.Handle, false);
CheckForErrors();
GC.KeepAlive(indices); // don't release or finalize Tensor indices whose handles have been put into ptrTensors
GC.KeepAlive(value);
return this;
}
}
}
}

public Tensor index_put_(Tensor value, TensorIndex[] indices, bool accumulate = false)
{
EncodeIndices(indices, out var arrKindAndStarts, out var arrStops, out var arrSteps, out var arrTensors);
if (accumulate && arrTensors == null)
throw new Exception("Invalid 'indices' parameter. Must be an array of TensorIndex objects containing tensors with indices that match the shape of the tensor to update");
unsafe {
fixed (long* ptrKindAndStarts = arrKindAndStarts, ptrStops = arrStops, ptrSteps = arrSteps) {
fixed (IntPtr* ptrTensors = arrTensors) {
NativeMethods.THSTensor_index_put_(Handle, (IntPtr)ptrKindAndStarts, (IntPtr)ptrStops, (IntPtr)ptrSteps, (IntPtr)ptrTensors, indices.Length, value.Handle, accumulate);
CheckForErrors();
GC.KeepAlive(indices); // don't release or finalize Tensor indices whose handles have been put into ptrTensors
GC.KeepAlive(value);
Expand All @@ -1622,6 +1640,11 @@ public Tensor index_put_(Tensor value, params Tensor[] indices)
return index_put_(value, indices.Select(t => TensorIndex.Tensor(t)).ToArray());
}

/// <summary>
/// Index into the tensor using the given index tensors and write <paramref name="value"/>
/// at the selected positions, in place.
/// </summary>
/// <param name="value">The values to place (or add) at the indexed positions.</param>
/// <param name="indices">One index tensor per indexed dimension.</param>
/// <param name="accumulate">If true, add the values to the existing elements instead of replacing them.</param>
/// <returns>The tensor itself, modified in place.</returns>
public Tensor index_put_(Tensor value, Tensor[] indices, bool accumulate = false)
{
    var tensorIndices = new TensorIndex[indices.Length];
    for (var i = 0; i < indices.Length; i++) {
        tensorIndices[i] = TensorIndex.Tensor(indices[i]);
    }
    return index_put_(value, tensorIndices, accumulate);
}


/// <summary>
/// Index into the tensor using Python-like indexing expressions and place a scalar tensor at the index.
Expand Down
150 changes: 112 additions & 38 deletions test/TorchSharpTest/TestTorchTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -290,17 +290,13 @@ public void TestTensorDefaultPrint()
Tensor t = torch.zeros(2, 2);
string expectedOutput = t.ToString(TensorStringStyle.Default) + Environment.NewLine;
var originalOut = Console.Out;
using (var sw = new StringWriter())
{
try
{
using (var sw = new StringWriter()) {
try {
Console.SetOut(sw);
t.print();
var result = sw.ToString();
Assert.Equal(expectedOutput, result);
}
finally
{
} finally {
Console.SetOut(originalOut);
}
}
Expand Down Expand Up @@ -807,7 +803,7 @@ public void FromArrayFactory()
() => Assert.Equal(1, t.ndim),
() => Assert.Equal(ScalarType.Byte, t.dtype));
}

{
var array = new Memory<long>(new long[8]);
using var t = torch.tensor(array, new long[] { 8 }, device: device);
Expand All @@ -816,11 +812,11 @@ public void FromArrayFactory()
() => Assert.Equal(1, t.ndim),
() => Assert.Equal(ScalarType.Int64, t.dtype));
}

{
var array = new long[18];
array[5] = 17;
var mem = new Memory<long>(array,4,10);
var mem = new Memory<long>(array, 4, 10);
using var t = torch.tensor(mem, new long[] { 8 }, device: device);
Assert.Multiple(
() => Assert.Equal(device.type, t.device_type),
Expand Down Expand Up @@ -3165,6 +3161,86 @@ public void IndexFill2()
() => Assert.Equal(1.0, x[2, 2].ToSingle()));
}

/// <summary>
/// index_put_ with a single scalar value and a single tensor index:
/// verifies replace semantics (accumulate omitted or false) and add semantics (accumulate true).
/// </summary>
[Fact]
[TestOf(nameof(Tensor.index_put_))]
public void IndexPutOneValueOneIndex()
{
    using var _ = NewDisposeScope();

    var idx = new TensorIndex[] { TensorIndex.Tensor(1) };
    var newValue = torch.tensor(5.0f);
    var expectedReplaced = torch.tensor(new float[] { 1.0f, 5.0f, 1.0f, 1.0f, 1.0f });
    var expectedAccumulated = torch.tensor(new float[] { 1.0f, 6.0f, 1.0f, 1.0f, 1.0f });

    // accumulate omitted (defaults to false): element at index 1 is overwritten with 5.
    var target = ones(5);
    target.index_put_(newValue, idx);
    Assert.True(target.Equals(expectedReplaced));

    // accumulate: false spelled out — identical overwrite semantics.
    target = ones(5);
    target.index_put_(newValue, idx, accumulate: false);
    Assert.True(target.Equals(expectedReplaced));

    // accumulate: true — the value is added to the element: 1 + 5 = 6.
    target = ones(5);
    target.index_put_(newValue, idx, accumulate: true);
    Assert.True(target.Equals(expectedAccumulated));
}

/// <summary>
/// index_put_ with one scalar value broadcast across multiple indices:
/// verifies replace semantics (accumulate default/false) and add semantics (accumulate true).
/// </summary>
[Fact]
[TestOf(nameof(Tensor.index_put_))]
public void IndexPutOneValueMultipleIndexes()
{
using var _ = NewDisposeScope();

var tensor = ones(5);
var indices = new TensorIndex[] { TensorIndex.Tensor(new long[] {1, 2}) };
var values = torch.tensor(10.0f);

// default accumulate value is false, should only replace value at given indexes
tensor.index_put_(values, indices);
Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 10.0f, 10.0f, 1.0f, 1.0f })));

tensor = ones(5);
// accumulate value is true, should add value to given indexes
tensor.index_put_(values, indices, true);
Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 11.0f, 11.0f, 1.0f, 1.0f })));

// accumulate value is false, explicitly set, should replace value at given indexes
// (note: deliberately reuses the accumulated tensor {1, 11, 11, 1, 1} from the previous
// step — a replace overwrites the indexed elements, so no reset to ones(5) is needed)
tensor.index_put_(values, indices, false);
Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 10.0f, 10.0f, 1.0f, 1.0f })));
}

/// <summary>
/// index_put_ on a 2-D tensor with paired row/column index tensors (advanced indexing):
/// element i of 'values' goes to position (rowIndex[i], colIndex[i]). Verifies both
/// replace semantics (accumulate default/false) and add semantics (accumulate true).
/// </summary>
[Fact]
[TestOf(nameof(Tensor.index_put_))]
public void IndexPutMultipleValuesMultipleIndexes()
{
using var _ = NewDisposeScope();

var tensor = ones(5, 2);
var indices = new TensorIndex[]
{
TensorIndex.Tensor(new long[] { 1, 2, 0, 3 }), // for first tensor dimension (row)
TensorIndex.Tensor(new long[] { 0, 1, 0, 0 }) // for second tensor dimension (column)
};
var values = torch.tensor(new float[] { 3.0f, 4.0f, 5.0f, 10f });

// default accumulate value is false, should only replace values at given indices with 3, 4, 5, 10
// Indexes to be replaced: (1, 0) -> 3.0, (2, 1) -> 4.0, (0, 0) -> 5.0, (3, 0) -> 10.0
tensor.index_put_(values, indices);
Assert.True(tensor.Equals(torch.tensor(new float[,] { { 5.0f, 1.0f }, { 3.0f, 1.0f }, { 1.0f, 4.0f }, { 10.0f, 1.0f }, { 1.0f, 1.0f } })));

tensor = ones(5, 2);
// accumulate value is true, should perform addition at given indices, 1 + 3 = 4, 1 + 4 = 5, 1 + 5 = 6, 1 + 10 = 11
// Indexes to be replaced: (1, 0) -> 4.0, (2, 1) -> 5.0, (0, 0) -> 6.0, (3, 0) -> 11.0
tensor.index_put_(values, indices, true);
Assert.True(tensor.Equals(torch.tensor(new float[,] { { 6.0f, 1.0f }, { 4.0f, 1.0f }, { 1.0f, 5.0f }, { 11.0f, 1.0f }, { 1.0f, 1.0f } })));

// accumulate value is false, explicitly set, should only replace values at given indices with 3, 4, 5, 10
// (reuses the accumulated tensor from the previous step; replace overwrites the indexed elements)
// Indexes to be replaced: (1, 0) -> 3.0, (2, 1) -> 4.0, (0, 0) -> 5.0, (3, 0) -> 10.0
tensor.index_put_(values, indices, false);
Assert.True(tensor.Equals(torch.tensor(new float[,] { { 5.0f, 1.0f }, { 3.0f, 1.0f }, { 1.0f, 4.0f }, { 10.0f, 1.0f }, { 1.0f, 1.0f } })));
}

[Fact]
[TestOf(nameof(TensorExtensionMethods.ToTensor))]
public void ScalarToTensor()
Expand Down Expand Up @@ -3257,7 +3333,7 @@ public void ScalarToTensor3()
[TestOf(nameof(Tensor))]
public void ScalarToTensorDoesNotLeakMemory()
{
AssertTensorDoesNotLeak(()=>{
AssertTensorDoesNotLeak(() => {
Tensor tensor = 1;
return tensor;
});
Expand All @@ -3273,20 +3349,20 @@ public void ScalarToTensorDoesNotLeakMemory()
[TestOf(nameof(Tensor))]
public void ScalarArrayToTensorDoesNotLeakMemory()
{
AssertTensorDoesNotLeak(() => (new byte[]{1}).ToTensor(new long[]{1}));
AssertTensorDoesNotLeak(() => (new sbyte[]{-1}).ToTensor(new long[]{1}));
AssertTensorDoesNotLeak(() => (new short[]{-1}).ToTensor(new long[]{1}));
AssertTensorDoesNotLeak(() => (new long[]{-1}).ToTensor(new long[]{1}));
AssertTensorDoesNotLeak(() => (new float[]{-1}).ToTensor(new long[]{1}));
AssertTensorDoesNotLeak(() => (new double[]{-1}).ToTensor(new long[]{1}));
AssertTensorDoesNotLeak(() => (new byte[] { 1 }).ToTensor(new long[] { 1 }));
AssertTensorDoesNotLeak(() => (new sbyte[] { -1 }).ToTensor(new long[] { 1 }));
AssertTensorDoesNotLeak(() => (new short[] { -1 }).ToTensor(new long[] { 1 }));
AssertTensorDoesNotLeak(() => (new long[] { -1 }).ToTensor(new long[] { 1 }));
AssertTensorDoesNotLeak(() => (new float[] { -1 }).ToTensor(new long[] { 1 }));
AssertTensorDoesNotLeak(() => (new double[] { -1 }).ToTensor(new long[] { 1 }));
}

[Fact]
[TestOf(nameof(Tensor))]
public void ComplexNumberOfDoubleDoesNotLeakMemory()
{
AssertTensorDoesNotLeak(() => ( torch.tensor((double)-1, (double)-2)));
AssertTensorDoesNotLeak(() => ( torch.tensor(((double)-1, (double)-2))));
AssertTensorDoesNotLeak(() => (torch.tensor((double)-1, (double)-2)));
AssertTensorDoesNotLeak(() => (torch.tensor(((double)-1, (double)-2))));
}

[Fact]
Expand Down Expand Up @@ -4106,7 +4182,7 @@ public void CastMoveAndDisposeAfter()
Assert.True(input.IsInvalid);
Assert.False(cast.IsInvalid);
// make sure we can access the values
Assert.Equal(1, cast[0].ToInt32());
Assert.Equal(1, cast[0].ToInt32());
}
if (torch.cuda.is_available()) {
{
Expand Down Expand Up @@ -8517,28 +8593,27 @@ public void DefaultDTypeCreation()
{
var dt = torch.get_default_dtype();

var t = torch.zeros(5,5);
var t = torch.zeros(5, 5);
Assert.Equal(torch.float32, t.dtype);

try {
torch.set_default_dtype(torch.float64);
t = torch.zeros(5,5);
torch.set_default_dtype(torch.float64);

t = torch.zeros(5, 5);
Assert.Equal(torch.float64, t.dtype);

t = torch.ones(5,5);
t = torch.ones(5, 5);
Assert.Equal(torch.float64, t.dtype);

t = torch.rand(5,5);
t = torch.rand(5, 5);
Assert.Equal(torch.float64, t.dtype);

t = torch.randn(5,5);
t = torch.randn(5, 5);
Assert.Equal(torch.float64, t.dtype);

t = torch.logspace(5, 15, 20);
Assert.Equal(torch.float64, t.dtype);
}
finally {
} finally {
torch.set_default_dtype(dt);
}
}
Expand All @@ -8548,28 +8623,27 @@ public void DefaultDeviceCreation()
{
var dt = torch.get_default_device();

var t = torch.zeros(5,5);
var t = torch.zeros(5, 5);
Assert.Equal(DeviceType.CPU, t.device_type);

try {
torch.set_default_device(torch.META);
t = torch.zeros(5,5);
torch.set_default_device(torch.META);

t = torch.zeros(5, 5);
Assert.Equal(DeviceType.META, t.device_type);

t = torch.ones(5,5);
t = torch.ones(5, 5);
Assert.Equal(DeviceType.META, t.device_type);

t = torch.rand(5,5);
t = torch.rand(5, 5);
Assert.Equal(DeviceType.META, t.device_type);

t = torch.randn(5,5);
t = torch.randn(5, 5);
Assert.Equal(DeviceType.META, t.device_type);

t = torch.logspace(5, 15, 20);
Assert.Equal(DeviceType.META, t.device_type);
}
finally {
} finally {
torch.set_default_device(dt);
}
}
Expand Down