diff --git a/RELEASENOTES.md b/RELEASENOTES.md
index 49bb3583d..721dd1a7f 100644
--- a/RELEASENOTES.md
+++ b/RELEASENOTES.md
@@ -7,6 +7,10 @@ __Bug Fixes__:
#1426 Sequential.eval() does not put model into eval mode<br/>
`torch.optim.lr_scheduler.LinearLR` `end_factor` default has been corrected, is now 1.0.<br/>
`torch.optim.lr_scheduler.PolynomialLR` `power` type has been corrected, is now double.<br/>
+
+__API Changes__:
+
+#1374 Add accumulate to index_put_<br/>
# NuGet Version 0.105.0
diff --git a/src/Native/LibTorchSharp/THSTensor.cpp b/src/Native/LibTorchSharp/THSTensor.cpp
index 5cff3ab82..13cd5787e 100644
--- a/src/Native/LibTorchSharp/THSTensor.cpp
+++ b/src/Native/LibTorchSharp/THSTensor.cpp
@@ -837,6 +837,31 @@ void THSTensor_index_put_(Tensor tensor,
CATCH(tensor->index_put_(indices, *value););
}
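+// Variant of THSTensor_index_put_ that exposes the 'accumulate' flag of at::Tensor::index_put_.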
+void THSTensor_index_put_(Tensor tensor,
+ const int64_t* indexStarts,
+ const int64_t* indexEnds,
+ const int64_t* indexSteps,
+ const Tensor* indexTensors,
+ const int indicesLength,
+ const Tensor value,
+ const bool accumulate)
+{
+ at::indexing::TensorIndex* indicesArray = (at::indexing::TensorIndex*)alloca(indicesLength * sizeof(at::indexing::TensorIndex));
+ memset(indicesArray, 0, indicesLength * sizeof(at::indexing::TensorIndex));
+ completeTensorIndices(indexStarts, indexEnds, indexSteps, indexTensors, indicesArray, indicesLength);
+ auto indices = at::ArrayRef<at::indexing::TensorIndex>(indicesArray, indicesLength);
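+ // 'accumulate' is only supported by the index_put_ overload that takes a list of index tensors, so gather the (optional) tensor indices below.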
+ if (accumulate) {
+ c10::List<c10::optional<at::Tensor>> indicesList = c10::List<c10::optional<at::Tensor>>();
+ for (int i = 0; i < indicesLength; i++) {
+ indicesList.push_back(c10::optional<at::Tensor>(*indexTensors[i]));
+ }
+ CATCH(tensor->index_put_(indicesList, *value, accumulate););
+ }
+ else {
+ CATCH(tensor->index_put_(indices, *value););
+ }
+}
+
void THSTensor_index_put_scalar_(Tensor tensor,
const int64_t* indexStarts,
const int64_t* indexEnds,
diff --git a/src/Native/LibTorchSharp/THSTensor.h b/src/Native/LibTorchSharp/THSTensor.h
index ebbdf8302..0925cd4e0 100644
--- a/src/Native/LibTorchSharp/THSTensor.h
+++ b/src/Native/LibTorchSharp/THSTensor.h
@@ -683,7 +683,8 @@ EXPORT_API(void) THSTensor_index_put_(Tensor tensor,
const int64_t* indexSteps,
const Tensor* indexTensors,
const int indicesLength,
- const Tensor value);
+ const Tensor value,
+ const bool accumulate = false);
EXPORT_API(Tensor) THSTensor_index_select(Tensor tensor, int64_t dim, Tensor index);
diff --git a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
index 65018f5a5..bb568ae68 100644
--- a/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
+++ b/src/TorchSharp/PInvoke/LibTorchSharp.THSTensor.cs
@@ -410,7 +410,7 @@ internal static extern IntPtr THSTensor_upsample_nearest3d(IntPtr input,
internal static extern void THSTensor_index_put_scalar_(IntPtr tensor, IntPtr indexStarts, IntPtr indexEnds, IntPtr indexSteps, IntPtr indexTensors, int indicesLength, IntPtr value);
[DllImport("LibTorchSharp")]
- internal static extern void THSTensor_index_put_(IntPtr tensor, IntPtr indexStarts, IntPtr indexEnds, IntPtr indexSteps, IntPtr indexTensors, int indicesLength, IntPtr value);
+ internal static extern void THSTensor_index_put_(IntPtr tensor, IntPtr indexStarts, IntPtr indexEnds, IntPtr indexSteps, IntPtr indexTensors, int indicesLength, IntPtr value, [MarshalAs(UnmanagedType.U1)] bool accumulate);
[DllImport("LibTorchSharp")]
internal static extern IntPtr THSTensor_get1(IntPtr handle, long i1);
diff --git a/src/TorchSharp/Tensor/Tensor.cs b/src/TorchSharp/Tensor/Tensor.cs
index 41b007c9e..c17995a52 100644
--- a/src/TorchSharp/Tensor/Tensor.cs
+++ b/src/TorchSharp/Tensor/Tensor.cs
@@ -1604,7 +1604,25 @@ public Tensor index_put_(Tensor value, params TensorIndex[] indices)
unsafe {
fixed (long* ptrKindAndStarts = arrKindAndStarts, ptrStops = arrStops, ptrSteps = arrSteps) {
fixed (IntPtr* ptrTensors = arrTensors) {
- NativeMethods.THSTensor_index_put_(Handle, (IntPtr)ptrKindAndStarts, (IntPtr)ptrStops, (IntPtr)ptrSteps, (IntPtr)ptrTensors, indices.Length, value.Handle);
+ NativeMethods.THSTensor_index_put_(Handle, (IntPtr)ptrKindAndStarts, (IntPtr)ptrStops, (IntPtr)ptrSteps, (IntPtr)ptrTensors, indices.Length, value.Handle, false);
+ CheckForErrors();
+ GC.KeepAlive(indices); // don't release or finalize Tensor indices whose handles have been put into ptrTensors
+ GC.KeepAlive(value);
+ return this;
+ }
+ }
+ }
+ }
+
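+ /// <summary>
+ /// Index into the tensor using Python-like indexing expressions and write the value tensor at the resulting indices.
+ /// </summary>
+ /// <param name="value">The tensor holding the values to write, or to add when accumulate is true.</param>
+ /// <param name="indices">The indexing expressions. Accumulation requires tensor (advanced) indices.</param>
+ /// <param name="accumulate">If true, add the values to the existing elements instead of replacing them.</param>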
+ public Tensor index_put_(Tensor value, TensorIndex[] indices, bool accumulate = false)
+ {
+ EncodeIndices(indices, out var arrKindAndStarts, out var arrStops, out var arrSteps, out var arrTensors);
+ if (accumulate && arrTensors == null)
+ throw new ArgumentException("Invalid 'indices' parameter: accumulation requires tensor indices, so each element must be a TensorIndex wrapping an index tensor.", nameof(indices));
+ unsafe {
+ fixed (long* ptrKindAndStarts = arrKindAndStarts, ptrStops = arrStops, ptrSteps = arrSteps) {
+ fixed (IntPtr* ptrTensors = arrTensors) {
+ NativeMethods.THSTensor_index_put_(Handle, (IntPtr)ptrKindAndStarts, (IntPtr)ptrStops, (IntPtr)ptrSteps, (IntPtr)ptrTensors, indices.Length, value.Handle, accumulate);
CheckForErrors();
GC.KeepAlive(indices); // don't release or finalize Tensor indices whose handles have been put into ptrTensors
GC.KeepAlive(value);
@@ -1622,6 +1640,11 @@ public Tensor index_put_(Tensor value, params Tensor[] indices)
return index_put_(value, indices.Select(t => TensorIndex.Tensor(t)).ToArray());
}
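+ /// <summary>
+ /// Index into the tensor using index tensors and write the value tensor at the resulting indices, adding to the existing elements when accumulate is true.
+ /// </summary>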
+ public Tensor index_put_(Tensor value, Tensor[] indices, bool accumulate = false)
+ {
+ return index_put_(value, indices.Select(t => TensorIndex.Tensor(t)).ToArray(), accumulate);
+ }
+
/// <summary>
/// Index into the tensor using Python-like indexing expressions and place a scalar tensor at the index.
diff --git a/test/TorchSharpTest/TestTorchTensor.cs b/test/TorchSharpTest/TestTorchTensor.cs
index 63cbdf59d..69ba732f3 100644
--- a/test/TorchSharpTest/TestTorchTensor.cs
+++ b/test/TorchSharpTest/TestTorchTensor.cs
@@ -290,17 +290,13 @@ public void TestTensorDefaultPrint()
Tensor t = torch.zeros(2, 2);
string expectedOutput = t.ToString(TensorStringStyle.Default) + Environment.NewLine;
var originalOut = Console.Out;
- using (var sw = new StringWriter())
- {
- try
- {
+ using (var sw = new StringWriter()) {
+ try {
Console.SetOut(sw);
t.print();
var result = sw.ToString();
Assert.Equal(expectedOutput, result);
- }
- finally
- {
+ } finally {
Console.SetOut(originalOut);
}
}
@@ -807,7 +803,7 @@ public void FromArrayFactory()
() => Assert.Equal(1, t.ndim),
() => Assert.Equal(ScalarType.Byte, t.dtype));
}
-
+
{
var array = new Memory<long>(new long[8]);
using var t = torch.tensor(array, new long[] { 8 }, device: device);
@@ -816,11 +812,11 @@ public void FromArrayFactory()
() => Assert.Equal(1, t.ndim),
() => Assert.Equal(ScalarType.Int64, t.dtype));
}
-
+
{
var array = new long[18];
array[5] = 17;
- var mem = new Memory<long>(array,4,10);
+ var mem = new Memory<long>(array, 4, 10);
using var t = torch.tensor(mem, new long[] { 8 }, device: device);
Assert.Multiple(
() => Assert.Equal(device.type, t.device_type),
@@ -3165,6 +3161,86 @@ public void IndexFill2()
() => Assert.Equal(1.0, x[2, 2].ToSingle()));
}
+ [Fact]
+ [TestOf(nameof(Tensor.index_put_))]
+ public void IndexPutOneValueOneIndex()
+ {
+ using var _ = NewDisposeScope();
+
+ var tensor = ones(5);
+ var indices = new TensorIndex[] { TensorIndex.Tensor(1) };
+ var values = torch.tensor(5.0f);
+
+ // default accumulate value is false, should only replace value at index 1 with 5
+ tensor.index_put_(values, indices);
+ Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 5.0f, 1.0f, 1.0f, 1.0f })));
+
+ tensor = ones(5);
+ // accumulate value is false, explicitly set, should only replace value at index 1 with 5
+ tensor.index_put_(values, indices, accumulate: false);
+ Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 5.0f, 1.0f, 1.0f, 1.0f })));
+
+ tensor = ones(5);
+ // accumulate value is true, should add value to index 1, 1 + 5 = 6
+ tensor.index_put_(values, indices, accumulate: true);
+ Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 6.0f, 1.0f, 1.0f, 1.0f })));
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor.index_put_))]
+ public void IndexPutOneValueMultipleIndexes()
+ {
+ using var _ = NewDisposeScope();
+
+ var tensor = ones(5);
+ var indices = new TensorIndex[] { TensorIndex.Tensor(new long[] { 1, 2 }) };
+ var values = torch.tensor(10.0f);
+
+ // default accumulate value is false, should only replace value at given indexes
+ tensor.index_put_(values, indices);
+ Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 10.0f, 10.0f, 1.0f, 1.0f })));
+
+ tensor = ones(5);
+ // accumulate value is true, should add value to given indexes
+ tensor.index_put_(values, indices, true);
+ Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 11.0f, 11.0f, 1.0f, 1.0f })));
+
+ // accumulate value is false, explicitly set, should replace value at given indexes
+ tensor.index_put_(values, indices, false);
+ Assert.True(tensor.Equals(torch.tensor(new float[] { 1.0f, 10.0f, 10.0f, 1.0f, 1.0f })));
+ }
+
+ [Fact]
+ [TestOf(nameof(Tensor.index_put_))]
+ public void IndexPutMultipleValuesMultipleIndexes()
+ {
+ using var _ = NewDisposeScope();
+
+ var tensor = ones(5, 2);
+ var indices = new TensorIndex[]
+ {
+ TensorIndex.Tensor(new long[] { 1, 2, 0, 3 }), // for first tensor dimension (row)
+ TensorIndex.Tensor(new long[] { 0, 1, 0, 0 }) // for second tensor dimension (column)
+ };
+ var values = torch.tensor(new float[] { 3.0f, 4.0f, 5.0f, 10f });
+
+ // default accumulate value is false, should only replace values at given indices with 3, 4, 5, 10
+ // Indexes to be replaced: (1, 0) -> 3.0, (2, 1) -> 4.0, (0, 0) -> 5.0, (3, 0) -> 10.0
+ tensor.index_put_(values, indices);
+ Assert.True(tensor.Equals(torch.tensor(new float[,] { { 5.0f, 1.0f }, { 3.0f, 1.0f }, { 1.0f, 4.0f }, { 10.0f, 1.0f }, { 1.0f, 1.0f } })));
+
+ tensor = ones(5, 2);
+ // accumulate value is true, should perform addition at given indices, 1 + 3 = 4, 1 + 4 = 5, 1 + 5 = 6, 1 + 10 = 11
+ // Indexes to be replaced: (1, 0) -> 4.0, (2, 1) -> 5.0, (0, 0) -> 6.0, (3, 0) -> 11.0
+ tensor.index_put_(values, indices, true);
+ Assert.True(tensor.Equals(torch.tensor(new float[,] { { 6.0f, 1.0f }, { 4.0f, 1.0f }, { 1.0f, 5.0f }, { 11.0f, 1.0f }, { 1.0f, 1.0f } })));
+
+ // accumulate value is false, explicitly set, should only replace values at given indices with 3, 4, 5, 10
+ // Indexes to be replaced: (1, 0) -> 3.0, (2, 1) -> 4.0, (0, 0) -> 5.0, (3, 0) -> 10.0
+ tensor.index_put_(values, indices, false);
+ Assert.True(tensor.Equals(torch.tensor(new float[,] { { 5.0f, 1.0f }, { 3.0f, 1.0f }, { 1.0f, 4.0f }, { 10.0f, 1.0f }, { 1.0f, 1.0f } })));
+ }
+
[Fact]
[TestOf(nameof(TensorExtensionMethods.ToTensor))]
public void ScalarToTensor()
@@ -3257,7 +3333,7 @@ public void ScalarToTensor3()
[TestOf(nameof(Tensor))]
public void ScalarToTensorDoesNotLeakMemory()
{
- AssertTensorDoesNotLeak(()=>{
+ AssertTensorDoesNotLeak(() => {
Tensor tensor = 1;
return tensor;
});
@@ -3273,20 +3349,20 @@ public void ScalarToTensorDoesNotLeakMemory()
[TestOf(nameof(Tensor))]
public void ScalarArrayToTensorDoesNotLeakMemory()
{
- AssertTensorDoesNotLeak(() => (new byte[]{1}).ToTensor(new long[]{1}));
- AssertTensorDoesNotLeak(() => (new sbyte[]{-1}).ToTensor(new long[]{1}));
- AssertTensorDoesNotLeak(() => (new short[]{-1}).ToTensor(new long[]{1}));
- AssertTensorDoesNotLeak(() => (new long[]{-1}).ToTensor(new long[]{1}));
- AssertTensorDoesNotLeak(() => (new float[]{-1}).ToTensor(new long[]{1}));
- AssertTensorDoesNotLeak(() => (new double[]{-1}).ToTensor(new long[]{1}));
+ AssertTensorDoesNotLeak(() => (new byte[] { 1 }).ToTensor(new long[] { 1 }));
+ AssertTensorDoesNotLeak(() => (new sbyte[] { -1 }).ToTensor(new long[] { 1 }));
+ AssertTensorDoesNotLeak(() => (new short[] { -1 }).ToTensor(new long[] { 1 }));
+ AssertTensorDoesNotLeak(() => (new long[] { -1 }).ToTensor(new long[] { 1 }));
+ AssertTensorDoesNotLeak(() => (new float[] { -1 }).ToTensor(new long[] { 1 }));
+ AssertTensorDoesNotLeak(() => (new double[] { -1 }).ToTensor(new long[] { 1 }));
}
[Fact]
[TestOf(nameof(Tensor))]
public void ComplexNumberOfDoubleDoesNotLeakMemory()
{
- AssertTensorDoesNotLeak(() => ( torch.tensor((double)-1, (double)-2)));
- AssertTensorDoesNotLeak(() => ( torch.tensor(((double)-1, (double)-2))));
+ AssertTensorDoesNotLeak(() => (torch.tensor((double)-1, (double)-2)));
+ AssertTensorDoesNotLeak(() => (torch.tensor(((double)-1, (double)-2))));
}
[Fact]
@@ -4106,7 +4182,7 @@ public void CastMoveAndDisposeAfter()
Assert.True(input.IsInvalid);
Assert.False(cast.IsInvalid);
// make sure we can access the values
- Assert.Equal(1, cast[0].ToInt32());
+ Assert.Equal(1, cast[0].ToInt32());
}
if (torch.cuda.is_available()) {
{
@@ -8517,28 +8593,27 @@ public void DefaultDTypeCreation()
{
var dt = torch.get_default_dtype();
- var t = torch.zeros(5,5);
+ var t = torch.zeros(5, 5);
Assert.Equal(torch.float32, t.dtype);
try {
- torch.set_default_dtype(torch.float64);
-
- t = torch.zeros(5,5);
+ torch.set_default_dtype(torch.float64);
+
+ t = torch.zeros(5, 5);
Assert.Equal(torch.float64, t.dtype);
- t = torch.ones(5,5);
+ t = torch.ones(5, 5);
Assert.Equal(torch.float64, t.dtype);
- t = torch.rand(5,5);
+ t = torch.rand(5, 5);
Assert.Equal(torch.float64, t.dtype);
- t = torch.randn(5,5);
+ t = torch.randn(5, 5);
Assert.Equal(torch.float64, t.dtype);
t = torch.logspace(5, 15, 20);
Assert.Equal(torch.float64, t.dtype);
- }
- finally {
+ } finally {
torch.set_default_dtype(dt);
}
}
@@ -8548,28 +8623,27 @@ public void DefaultDeviceCreation()
{
var dt = torch.get_default_device();
- var t = torch.zeros(5,5);
+ var t = torch.zeros(5, 5);
Assert.Equal(DeviceType.CPU, t.device_type);
try {
- torch.set_default_device(torch.META);
-
- t = torch.zeros(5,5);
+ torch.set_default_device(torch.META);
+
+ t = torch.zeros(5, 5);
Assert.Equal(DeviceType.META, t.device_type);
- t = torch.ones(5,5);
+ t = torch.ones(5, 5);
Assert.Equal(DeviceType.META, t.device_type);
- t = torch.rand(5,5);
+ t = torch.rand(5, 5);
Assert.Equal(DeviceType.META, t.device_type);
- t = torch.randn(5,5);
+ t = torch.randn(5, 5);
Assert.Equal(DeviceType.META, t.device_type);
t = torch.logspace(5, 15, 20);
Assert.Equal(DeviceType.META, t.device_type);
- }
- finally {
+ } finally {
torch.set_default_device(dt);
}
}