From 9ea59a1c8de3e7bede00165447969b877f6f5d50 Mon Sep 17 00:00:00 2001 From: yueyinqiu Date: Fri, 15 Nov 2024 15:35:22 +0800 Subject: [PATCH 1/6] dataset interface --- src/TorchSharp/DataLoader.cs | 197 +++++++++++++++++++++++++- src/TorchSharp/Dataset.cs | 27 +++- src/TorchSharp/Utils/TensorDataset.cs | 12 +- 3 files changed, 216 insertions(+), 20 deletions(-) diff --git a/src/TorchSharp/DataLoader.cs b/src/TorchSharp/DataLoader.cs index b368451fb..0fc479ffb 100644 --- a/src/TorchSharp/DataLoader.cs +++ b/src/TorchSharp/DataLoader.cs @@ -77,6 +77,116 @@ public static Modules.IterableDataLoader DataLoader( num_worker, drop_last, disposeBatch, disposeDataset); } + + static IReadOnlyDictionary DictionaryDataLoaderCollate( + IEnumerable> dic, + torch.Device device) + { + using (torch.NewDisposeScope()) { + Dictionary batch = new(); + foreach (var x in dic.First().Keys) { + var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0); + if (t.device_type != device.type || t.device_index != device.index) + t = t.to(device); + batch[x] = t.MoveToOuterDisposeScope(); + } + return batch; + } + } + + public static Modules.DataLoader< + IReadOnlyDictionary, + IReadOnlyDictionary + > + DataLoader( + IDataset> dataset, + int batchSize, IEnumerable shuffler, + Device device = null, + int num_worker = 1, bool drop_last = false, + bool disposeBatch = true, bool disposeDataset = true) + { + + return new( + dataset, + batchSize, DictionaryDataLoaderCollate, + shuffler, + device, + num_worker, drop_last, + disposeBatch, disposeDataset); + } + + public static Modules.DataLoader< + IReadOnlyDictionary, + IReadOnlyDictionary + > + DataLoader( + IDataset> dataset, + int batchSize, bool shuffle = false, + Device device = null, int? seed = null, + int num_worker = 1, bool drop_last = false, + bool disposeBatch = true, bool disposeDataset = true) + { + return new( + dataset, + batchSize, DictionaryDataLoaderCollate, shuffle, + device, seed, + num_worker, drop_last, + disposeBatch, disposeDataset); + } + + static IReadOnlyList ListDataLoaderCollate( + IReadOnlyList> dic, torch.Device device) + { + using (torch.NewDisposeScope()) { + List batch = new(); + for (var x = 0; x < dic[0].Count; x++) { + var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0); + if (t.device_type != device.type || t.device_index != device.index) + t = t.to(device); + batch.Add(t.MoveToOuterDisposeScope()); + } + return batch; + } + } + public static Modules.DataLoader< + IReadOnlyList, + IReadOnlyList + > + DataLoader( + IDataset> dataset, + int batchSize, IEnumerable shuffler, + Device device = null, + int num_worker = 1, bool drop_last = false, + bool disposeBatch = true, bool disposeDataset = true) + { + return new( + dataset, + batchSize, ListDataLoaderCollate, + shuffler, + device, + num_worker, drop_last, + disposeBatch, disposeDataset); + } + + public static Modules.DataLoader< + IReadOnlyList, + IReadOnlyList + > + DataLoader( + IDataset> dataset, + int batchSize, bool shuffle = false, + Device device = null, int? seed = null, + int num_worker = 1, bool drop_last = false, + bool disposeBatch = true, bool disposeDataset = true) + { + return new( + dataset, + batchSize, ListDataLoaderCollate, + shuffle, + device, seed, + num_worker, drop_last, + disposeBatch, disposeDataset); + } } } } @@ -264,12 +374,12 @@ public IterableDataLoader( /// public class DataLoader : IEnumerable, IDisposable { - public Dataset dataset { get; } + public IDataset dataset { get; } public int batch_size { get; } public bool drop_last { get; } public IEnumerable sampler { get; } public int num_workers { get; } - public Func, Device, S> collate_fn { get; } + public Func, Device, S> collate_fn { get; } public Device Device { get; } public bool DisposeBatch { get; } @@ -295,7 +405,84 @@ public class DataLoader : IEnumerable, IDisposable /// Indicates whether to dispose the dataset when being disposed. /// public DataLoader( - Dataset dataset, + IDataset dataset, + int batchSize, + Func, torch.Device, S> collate_fn, + IEnumerable shuffler, + Device? device = null, + int num_worker = 1, + bool drop_last = false, + bool disposeBatch = true, + bool disposeDataset = true) + { + this.dataset = dataset; + this.batch_size = batchSize; + this.drop_last = drop_last; + this.Device = device ?? CPU; + this.sampler = shuffler; + this.num_workers = Math.Max(num_worker, 1); + this.collate_fn = collate_fn; + this.DisposeBatch = disposeBatch; + this.DisposeDataset = disposeDataset; + } + + /// + /// Pytorch style dataloader + /// + /// Dataset for create batch + /// Size of batch + /// Callback to merge items to make a batch + /// true if shuffle dataset, false for not + /// device for output tensor + /// Seed for generating shuffle + /// Count of worker + /// + /// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size. + /// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller. + /// + /// + /// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration. + /// + /// + /// Indicates whether to dispose the dataset when being disposed. + /// + public DataLoader( + IDataset dataset, + int batchSize, + Func, torch.Device, S> collate_fn, + bool shuffle = false, + Device? device = null, + int? seed = null, + int num_worker = 1, + bool drop_last = false, + bool disposeBatch = true, + bool disposeDataset = true) : + this(dataset, batchSize, collate_fn, + shuffle ? new FisherYatesShuffler(dataset.Count, seed) : LongRange(dataset.Count), + device, num_worker, drop_last, disposeBatch, disposeDataset) + { } + + /// + /// Pytorch style dataloader + /// + /// Dataset for create batch + /// Size of batch + /// Callback to merge items make to a batch + /// device for output tensor + /// Shuffler for dataloader + /// Count of worker + /// + /// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size. + /// If alse and the size of dataset is not divisible by the batch size, then the last batch will be smaller. + /// + /// + /// Indicates whether to automatically dispose the collated tensors after an iteration. + /// + /// + /// Indicates whether to dispose the dataset when being disposed. + /// + public DataLoader( + IDataset dataset, int batchSize, Func, torch.Device, S> collate_fn, IEnumerable shuffler, @@ -337,7 +524,7 @@ public DataLoader( /// Indicates whether to dispose the dataset when being disposed. /// public DataLoader( - Dataset dataset, + IDataset dataset, int batchSize, Func, torch.Device, S> collate_fn, bool shuffle = false, @@ -432,7 +619,7 @@ public bool MoveNext() .WithDegreeOfParallelism(loader.num_workers) .ForAll((i) => { using var getTensorScope = torch.NewDisposeScope(); - tensors[i] = loader.dataset.GetTensor(indices[i]); + tensors[i] = loader.dataset[indices[i]]; getTensorDisposables[i] = getTensorScope.DetachAllAndDispose(); }); diff --git a/src/TorchSharp/Dataset.cs b/src/TorchSharp/Dataset.cs index 8af054f3e..22c3ee32b 100644 --- a/src/TorchSharp/Dataset.cs +++ b/src/TorchSharp/Dataset.cs @@ -24,10 +24,28 @@ public abstract class IterableDataset : Dataset> { } + /// + /// The base interface for all Datasets. + /// + public interface IDataset : IDisposable + { + /// + /// Size of dataset + /// + long Count { get; } + + /// + /// Get tensor according to index + /// + /// Index for tensor + /// Tensors of index. DataLoader will catenate these tensors into batches. + T this[long index] { get; } + } + /// /// The base nterface for all Datasets. /// - public abstract class Dataset : IDisposable + public abstract class Dataset : IDataset, IDisposable { public void Dispose() { @@ -35,9 +53,7 @@ public void Dispose() GC.SuppressFinalize(this); } - /// - /// Size of dataset - /// + /// public abstract long Count { get; } /// @@ -47,6 +63,9 @@ public void Dispose() /// Tensors of index. DataLoader will catenate these tensors into batches. public abstract T GetTensor(long index); + /// + public T this[long index] => GetTensor(index); + protected virtual void Dispose(bool disposing) { } diff --git a/src/TorchSharp/Utils/TensorDataset.cs b/src/TorchSharp/Utils/TensorDataset.cs index e3dc4161a..883ce7c39 100644 --- a/src/TorchSharp/Utils/TensorDataset.cs +++ b/src/TorchSharp/Utils/TensorDataset.cs @@ -40,16 +40,6 @@ internal TensorDataset(torch.Tensor[] tensors) _tensors = tensors.Select(x => x.alias().DetachFromDisposeScope()).ToArray(); } - /// - /// Indexer - /// - public IList this[long index] { - - get { - return _tensors.Select(t => t[index]).ToList(); - } - } - /// /// Length of the dataset, i.e. the size of the first dimension. /// @@ -59,7 +49,7 @@ public override long Count { public override IList GetTensor(long index) { - return this[index]; + return _tensors.Select(t => t[index]).ToList(); } readonly torch.Tensor[] _tensors; From 32b08dda5d3a84ba56eeb7a50fb17ed6b65dec85 Mon Sep 17 00:00:00 2001 From: yueyinqiu Date: Fri, 15 Nov 2024 15:51:12 +0800 Subject: [PATCH 2/6] concat dataset --- src/TorchSharp/ConcatDataset.cs | 108 ++++++++++++++++++++++++++++++++ src/TorchSharp/DataLoader.cs | 1 + src/TorchSharp/Dataset.cs | 39 ++++++------ 3 files changed, 130 insertions(+), 18 deletions(-) create mode 100644 src/TorchSharp/ConcatDataset.cs diff --git a/src/TorchSharp/ConcatDataset.cs b/src/TorchSharp/ConcatDataset.cs new file mode 100644 index 000000000..09fb7dedf --- /dev/null +++ b/src/TorchSharp/ConcatDataset.cs @@ -0,0 +1,108 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Collections.Generic; +using System.Linq; +using TorchSharp.Modules; + +namespace TorchSharp +{ + public static partial class torch + { + public static partial class utils + { + public static partial class data + { + public static ConcatDataset ConcatDataset(IEnumerable> datasets) + { + return new ConcatDataset(datasets); + } + } + } + } + + namespace Modules + { + public class ConcatDataset : IDataset + { + private static IEnumerable Cumsum(IEnumerable> datasets) + { + var s = 0L; + foreach (var e in datasets) { + s += e.Count; + yield return s; + } + } + + private static long bisectRight(long[] a, long x) + { + var lo = 0; + var hi = a.Length; + while (lo < hi) { + var mid = (lo + hi) / 2; + if (x < a[mid]) + hi = mid; + else + lo = mid + 1; + } + return lo; + } + + // Here we have to use arrays since the given index is in Int64... + private readonly IDataset[] _datasets; + public IReadOnlyList> datasets => _datasets; + + private readonly long[] _cumulativeSizes; + public IReadOnlyList cumulative_sizes => _cumulativeSizes; + + private readonly bool leaveOpen; + + public ConcatDataset( + IEnumerable> datasets, + bool leaveOpen = false) + { + this._datasets = datasets.ToArray(); + if (this._datasets.Length is 0) + throw new ArgumentException( + "datasets should not be an empty iterable", nameof(datasets)); + + // PyTorch also says 'ConcatDataset does not support IterableDataset'. + // But it's not our torch.utils.data.IterableDataset in TorchSharp. + this._cumulativeSizes = Cumsum(datasets).ToArray(); + + this.leaveOpen = leaveOpen; + } + + public long Count => this._cumulativeSizes.Last(); + + public T this[long index] + { + get { + if (index < 0) { + if (-index > this.Count) { + throw new ArgumentException( + "absolute value of index should not exceed dataset length", + nameof(index)); + } + index = this.Count + index; + } + + var datasetIdx = bisectRight(this._cumulativeSizes, index); + long sampleIdx; + if (datasetIdx == 0) + sampleIdx = index; + else + sampleIdx = index - this._cumulativeSizes[datasetIdx - 1]; + return this._datasets[datasetIdx][sampleIdx]; + } + } + + public void Dispose() + { + if (!leaveOpen) { + foreach (var dataset in this._datasets) + dataset.Dispose(); + } + } + } + } +} \ No newline at end of file diff --git a/src/TorchSharp/DataLoader.cs b/src/TorchSharp/DataLoader.cs index 0fc479ffb..96b9e08b5 100644 --- a/src/TorchSharp/DataLoader.cs +++ b/src/TorchSharp/DataLoader.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using TorchSharp.Modules; using TorchSharp.Utils; namespace TorchSharp diff --git a/src/TorchSharp/Dataset.cs b/src/TorchSharp/Dataset.cs index 22c3ee32b..363cd41c3 100644 --- a/src/TorchSharp/Dataset.cs +++ b/src/TorchSharp/Dataset.cs @@ -24,24 +24,6 @@ public abstract class IterableDataset : Dataset> { } - /// - /// The base interface for all Datasets. - /// - public interface IDataset : IDisposable - { - /// - /// Size of dataset - /// - long Count { get; } - - /// - /// Get tensor according to index - /// - /// Index for tensor - /// Tensors of index. DataLoader will catenate these tensors into batches. - T this[long index] { get; } - } - /// /// The base nterface for all Datasets. /// @@ -73,4 +55,25 @@ protected virtual void Dispose(bool disposing) } } } + + namespace Modules + { + /// + /// The base interface for all Datasets. + /// + public interface IDataset : IDisposable + { + /// + /// Size of dataset + /// + long Count { get; } + + /// + /// Get tensor according to index + /// + /// Index for tensor + /// Tensors of index. DataLoader will catenate these tensors into batches. + T this[long index] { get; } + } + } } \ No newline at end of file From f724dcb222ee94372da68b2cc8d1129bd8b5374a Mon Sep 17 00:00:00 2001 From: yueyinqiu Date: Fri, 15 Nov 2024 15:56:48 +0800 Subject: [PATCH 3/6] unit test for concat dataset --- src/TorchSharp/ConcatDataset.cs | 2 +- src/TorchSharp/Dataset.cs | 2 +- test/TorchSharpTest/TestDataLoader.cs | 86 +++++++++++++++++++++++++++ 3 files changed, 88 insertions(+), 2 deletions(-) diff --git a/src/TorchSharp/ConcatDataset.cs b/src/TorchSharp/ConcatDataset.cs index 09fb7dedf..54036b657 100644 --- a/src/TorchSharp/ConcatDataset.cs +++ b/src/TorchSharp/ConcatDataset.cs @@ -22,7 +22,7 @@ public static ConcatDataset ConcatDataset(IEnumerable> dataset namespace Modules { - public class ConcatDataset : IDataset + public sealed class ConcatDataset : IDataset { private static IEnumerable Cumsum(IEnumerable> datasets) { diff --git a/src/TorchSharp/Dataset.cs b/src/TorchSharp/Dataset.cs index 363cd41c3..5ff5fe6fe 100644 --- a/src/TorchSharp/Dataset.cs +++ b/src/TorchSharp/Dataset.cs @@ -27,7 +27,7 @@ public abstract class IterableDataset : Dataset> /// /// The base nterface for all Datasets. /// - public abstract class Dataset : IDataset, IDisposable + public abstract class Dataset : Modules.IDataset, IDisposable { public void Dispose() { diff --git a/test/TorchSharpTest/TestDataLoader.cs b/test/TorchSharpTest/TestDataLoader.cs index 2ffce70ec..79f13c1ae 100644 --- a/test/TorchSharpTest/TestDataLoader.cs +++ b/test/TorchSharpTest/TestDataLoader.cs @@ -1,7 +1,10 @@ // Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. using System.Collections.Generic; +using System.Linq; +using TorchSharp.Modules; using Xunit; +using static TorchSharp.torch; namespace TorchSharp @@ -27,6 +30,23 @@ private class TestIterableDataset : torch.utils.data.IterableDataset } } + private class TestDatasetFromEnumerable : IDataset + { + private readonly T[] values; + public TestDatasetFromEnumerable(IEnumerable values) + { + this.values = values.ToArray(); + this.Disposed = false; + } + public bool Disposed { get; set; } + public T this[long index] => values[index]; + public long Count => values.LongLength; + public void Dispose() + { + this.Disposed = true; + } + } + [Fact] public void DatasetTest() { @@ -230,5 +250,71 @@ public void CustomSeedTest() iterator.Dispose(); iterator2.Dispose(); } + + [Fact] + public void ConcatDatasetTest() + { + using var dataset1 = new TestDatasetFromEnumerable<(int, int)>(new[] { + (1, 1), // dataset 1 value 1 + (1, 2), // dataset 1 value 2 + (1, 3), + }); + using var dataset2 = new TestDatasetFromEnumerable<(int, int)>(new[] { + (2, 1), + (2, 2), + }); + using var dataset3 = new TestDatasetFromEnumerable<(int, int)>(new[] { + (3, 1), + (3, 2), + (3, 3), + (3, 4), + }); + + using var dataset = new ConcatDataset<(int, int)>(new[] { + dataset1, dataset2, dataset3 + }); + + Assert.Equal(3 + 2 + 4, dataset.Count); + + Assert.Equal((1, 1), dataset[0]); + Assert.Equal((1, 2), dataset[1]); + Assert.Equal((1, 3), dataset[2]); + Assert.Equal((2, 1), dataset[3]); + Assert.Equal((2, 2), dataset[4]); + Assert.Equal((3, 1), dataset[5]); + Assert.Equal((3, 2), dataset[6]); + Assert.Equal((3, 3), dataset[7]); + Assert.Equal((3, 4), dataset[8]); + + Assert.Equal((1, 1), dataset[-9]); + Assert.Equal((1, 2), dataset[-8]); + Assert.Equal((1, 3), dataset[-7]); + Assert.Equal((2, 1), dataset[-6]); + Assert.Equal((2, 2), dataset[-5]); + Assert.Equal((3, 1), dataset[-4]); + Assert.Equal((3, 2), dataset[-3]); + Assert.Equal((3, 3), dataset[-2]); + Assert.Equal((3, 4), dataset[-1]); + + Assert.False(dataset1.Disposed); + Assert.False(dataset2.Disposed); + Assert.False(dataset3.Disposed); + dataset.Dispose(); + Assert.True(dataset1.Disposed); + Assert.True(dataset2.Disposed); + Assert.True(dataset3.Disposed); + } + + [Fact] + public void ConcatDatasetDataLoader() + { + using var dataset = torch.utils.data.ConcatDataset( + new[] { + new TestDataset(), + new TestDataset(), + }); + var dataloader = torch.utils.data.DataLoader(dataset, 10, false); + Assert.Equal(2, dataloader.Count); + } } } \ No newline at end of file From ee0b6aabadcf8e62c96816bb1a31bf280d465eec Mon Sep 17 00:00:00 2001 From: yueyinqiu Date: Fri, 15 Nov 2024 16:30:42 +0800 Subject: [PATCH 4/6] Update RELEASENOTES.md --- RELEASENOTES.md | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/RELEASENOTES.md b/RELEASENOTES.md index f475beb79..7d381d837 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -17,6 +17,8 @@ The names of several arguments have been changed to align better with Pytorch na The argument defaults for `torch.diagonal()` and `Tensor.diagonal()` arguments have been corrected. +Type of `DataLoaders.dataset` has been changed to `IDataset`.
+ __Issues fixed__: #1397 Look into whether parameter creation from a tensor leads to incorrect dispose scope statistics. This bug was discovered during testing of the PR.
@@ -26,7 +28,12 @@ __Issues fixed__: __API Changes__: - #1382: Add support for torch.nn.functional.normalize
+#1382: Add support for torch.nn.functional.normalize
+ +Add support for torch.utils.data.ConcatDataset.
+A new interface `IDataset` has been added, and now `Dataset` implements it.
+More overloads of DataLoader() has been added, to accept `IDataset`.
+Type of `DataLoaders.collate_fn` has been changed to `Func, Device, S>`.
# NuGet Version 0.103.1 From 6220336d4de29c83ddc14e6636287edcac4d4ded Mon Sep 17 00:00:00 2001 From: yueyinqiu Date: Wed, 11 Dec 2024 16:05:39 +0800 Subject: [PATCH 5/6] Update RELEASENOTES.md --- RELEASENOTES.md | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/RELEASENOTES.md b/RELEASENOTES.md index 3453635d4..14566cd9f 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -2,6 +2,21 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the top. +# Next Version + +__Breaking Changes__: + +Type of `DataLoaders.dataset` has been changed to `IDataset`.
+ +__Issues fixed__: + +__API Changes__: + +Add support for torch.utils.data.ConcatDataset.
+A new interface `IDataset` has been added, and now `Dataset` implements it.
+More overloads of DataLoader() has been added, to accept `IDataset`.
+Type of `DataLoaders.collate_fn` has been changed to `Func, Device, S>`.
+ # NuGet Version 0.104.0 This is a big change in implementation, but not as big in API surface area. Many of the builtin modules, but not all, were re-implemented in managed code calling into native code via the functional APIs. This has several advantages: @@ -17,8 +32,6 @@ The names of several arguments have been changed to align better with Pytorch na The argument defaults for `torch.diagonal()` and `Tensor.diagonal()` arguments have been corrected.
The default `newLine` for `str`, `jlstr`, `npstr`, `cstr` and `print` have been corrected.
-Type of `DataLoaders.dataset` has been changed to `IDataset`.
- __Issues fixed__: #1397 Look into whether parameter creation from a tensor leads to incorrect dispose scope statistics. This bug was discovered during testing of the PR.
@@ -30,11 +43,6 @@ __API Changes__: #1382: Add support for torch.nn.functional.normalize
-Add support for torch.utils.data.ConcatDataset.
-A new interface `IDataset` has been added, and now `Dataset` implements it.
-More overloads of DataLoader() has been added, to accept `IDataset`.
-Type of `DataLoaders.collate_fn` has been changed to `Func, Device, S>`.
- # NuGet Version 0.103.1 __Breaking Changes__: From abe85e3184b6e6d78992839ce7de6af18bc27895 Mon Sep 17 00:00:00 2001 From: yueyinqiu Date: Wed, 11 Dec 2024 16:05:46 +0800 Subject: [PATCH 6/6] Update RELEASENOTES.md --- RELEASENOTES.md | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/RELEASENOTES.md b/RELEASENOTES.md index 14566cd9f..d91c89382 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -6,8 +6,6 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the __Breaking Changes__: -Type of `DataLoaders.dataset` has been changed to `IDataset`.
- __Issues fixed__: __API Changes__: @@ -15,6 +13,7 @@ __API Changes__: Add support for torch.utils.data.ConcatDataset.
A new interface `IDataset` has been added, and now `Dataset` implements it.
More overloads of DataLoader() has been added, to accept `IDataset`.
+Type of `DataLoaders.dataset` has been changed to `IDataset`.
Type of `DataLoaders.collate_fn` has been changed to `Func, Device, S>`.
# NuGet Version 0.104.0