diff --git a/RELEASENOTES.md b/RELEASENOTES.md index 890a1168e..9a474c034 100644 --- a/RELEASENOTES.md +++ b/RELEASENOTES.md @@ -2,6 +2,15 @@ Releases, starting with 9/2/2021, are listed with the most recent release at the top. + +# Next Version + +- Add support for torch.utils.data.ConcatDataset. +- A new interface `IDataset` has been added, and now `Dataset` implements it. +- More overloads of DataLoader() have been added, to accept `IDataset`. +- Type of `DataLoaders.dataset` has been changed to `IDataset`. +- Type of `DataLoaders.collate_fn` has been changed to `Func<IReadOnlyList<T>, Device, S>`. + # NuGet Version 0.105.0 Move to libtorch 2.5.1. As with the 2.4.0 release, MacOS / Intel is no longer supported by libtorch, so TorchSharp doesn, either. @@ -30,7 +39,7 @@ __Issues fixed__: __API Changes__: - #1382: Add support for torch.nn.functional.normalize
+#1382: Add support for torch.nn.functional.normalize
# NuGet Version 0.103.1 diff --git a/src/TorchSharp/ConcatDataset.cs b/src/TorchSharp/ConcatDataset.cs new file mode 100644 index 000000000..54036b657 --- /dev/null +++ b/src/TorchSharp/ConcatDataset.cs @@ -0,0 +1,108 @@ +// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. +using System; +using System.Collections.Generic; +using System.Linq; +using TorchSharp.Modules; + +namespace TorchSharp +{ + public static partial class torch + { + public static partial class utils + { + public static partial class data + { + public static ConcatDataset ConcatDataset(IEnumerable> datasets) + { + return new ConcatDataset(datasets); + } + } + } + } + + namespace Modules + { + public sealed class ConcatDataset : IDataset + { + private static IEnumerable Cumsum(IEnumerable> datasets) + { + var s = 0L; + foreach (var e in datasets) { + s += e.Count; + yield return s; + } + } + + private static long bisectRight(long[] a, long x) + { + var lo = 0; + var hi = a.Length; + while (lo < hi) { + var mid = (lo + hi) / 2; + if (x < a[mid]) + hi = mid; + else + lo = mid + 1; + } + return lo; + } + + // Here we have to use arrays since the given index is in Int64... + private readonly IDataset[] _datasets; + public IReadOnlyList> datasets => _datasets; + + private readonly long[] _cumulativeSizes; + public IReadOnlyList cumulative_sizes => _cumulativeSizes; + + private readonly bool leaveOpen; + + public ConcatDataset( + IEnumerable> datasets, + bool leaveOpen = false) + { + this._datasets = datasets.ToArray(); + if (this._datasets.Length is 0) + throw new ArgumentException( + "datasets should not be an empty iterable", nameof(datasets)); + + // PyTorch also says 'ConcatDataset does not support IterableDataset'. + // But it's not our torch.utils.data.IterableDataset in TorchSharp. 
+ this._cumulativeSizes = Cumsum(this._datasets).ToArray(); // sum over the materialized array so the caller's sequence is enumerated only once and sizes match the stored datasets + + this.leaveOpen = leaveOpen; + } + + public long Count => this._cumulativeSizes.Last(); + + public T this[long index] + { + get { + if (index < 0) { + if (-index > this.Count) { + throw new ArgumentException( + "absolute value of index should not exceed dataset length", + nameof(index)); + } + index = this.Count + index; + } + + var datasetIdx = bisectRight(this._cumulativeSizes, index); + long sampleIdx; + if (datasetIdx == 0) + sampleIdx = index; + else + sampleIdx = index - this._cumulativeSizes[datasetIdx - 1]; + return this._datasets[datasetIdx][sampleIdx]; + } + } + + public void Dispose() + { + if (!leaveOpen) { + foreach (var dataset in this._datasets) + dataset.Dispose(); + } + } + } + } +} \ No newline at end of file diff --git a/src/TorchSharp/DataLoader.cs b/src/TorchSharp/DataLoader.cs index b368451fb..96b9e08b5 100644 --- a/src/TorchSharp/DataLoader.cs +++ b/src/TorchSharp/DataLoader.cs @@ -7,6 +7,7 @@ using System.Linq; using System.Threading; using System.Threading.Tasks; +using TorchSharp.Modules; using TorchSharp.Utils; namespace TorchSharp @@ -77,6 +78,116 @@ public static Modules.IterableDataLoader DataLoader( num_worker, drop_last, disposeBatch, disposeDataset); } + + static IReadOnlyDictionary DictionaryDataLoaderCollate( + IEnumerable> dic, + torch.Device device) + { + using (torch.NewDisposeScope()) { + Dictionary batch = new(); + foreach (var x in dic.First().Keys) { + var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0); + if (t.device_type != device.type || t.device_index != device.index) + t = t.to(device); + batch[x] = t.MoveToOuterDisposeScope(); + } + return batch; + } + } + + public static Modules.DataLoader< + IReadOnlyDictionary, + IReadOnlyDictionary + > + DataLoader( + IDataset> dataset, + int batchSize, IEnumerable shuffler, + Device device = null, + int num_worker = 1, bool drop_last = false, + bool disposeBatch = true, bool disposeDataset = true) + { + + return 
new( + dataset, + batchSize, DictionaryDataLoaderCollate, + shuffler, + device, + num_worker, drop_last, + disposeBatch, disposeDataset); + } + + public static Modules.DataLoader< + IReadOnlyDictionary, + IReadOnlyDictionary + > + DataLoader( + IDataset> dataset, + int batchSize, bool shuffle = false, + Device device = null, int? seed = null, + int num_worker = 1, bool drop_last = false, + bool disposeBatch = true, bool disposeDataset = true) + { + return new( + dataset, + batchSize, DictionaryDataLoaderCollate, shuffle, + device, seed, + num_worker, drop_last, + disposeBatch, disposeDataset); + } + + static IReadOnlyList ListDataLoaderCollate( + IReadOnlyList> dic, torch.Device device) + { + using (torch.NewDisposeScope()) { + List batch = new(); + for (var x = 0; x < dic[0].Count; x++) { + var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0); + if (t.device_type != device.type || t.device_index != device.index) + t = t.to(device); + batch.Add(t.MoveToOuterDisposeScope()); + } + return batch; + } + } + public static Modules.DataLoader< + IReadOnlyList, + IReadOnlyList + > + DataLoader( + IDataset> dataset, + int batchSize, IEnumerable shuffler, + Device device = null, + int num_worker = 1, bool drop_last = false, + bool disposeBatch = true, bool disposeDataset = true) + { + return new( + dataset, + batchSize, ListDataLoaderCollate, + shuffler, + device, + num_worker, drop_last, + disposeBatch, disposeDataset); + } + + public static Modules.DataLoader< + IReadOnlyList, + IReadOnlyList + > + DataLoader( + IDataset> dataset, + int batchSize, bool shuffle = false, + Device device = null, int? 
seed = null, + int num_worker = 1, bool drop_last = false, + bool disposeBatch = true, bool disposeDataset = true) + { + return new( + dataset, + batchSize, ListDataLoaderCollate, + shuffle, + device, seed, + num_worker, drop_last, + disposeBatch, disposeDataset); + } } } } @@ -264,12 +375,12 @@ public IterableDataLoader( /// public class DataLoader : IEnumerable, IDisposable { - public Dataset dataset { get; } + public IDataset dataset { get; } public int batch_size { get; } public bool drop_last { get; } public IEnumerable sampler { get; } public int num_workers { get; } - public Func, Device, S> collate_fn { get; } + public Func, Device, S> collate_fn { get; } public Device Device { get; } public bool DisposeBatch { get; } @@ -295,7 +406,84 @@ public class DataLoader : IEnumerable, IDisposable /// Indicates whether to dispose the dataset when being disposed. /// public DataLoader( - Dataset dataset, + IDataset dataset, + int batchSize, + Func, torch.Device, S> collate_fn, + IEnumerable shuffler, + Device? device = null, + int num_worker = 1, + bool drop_last = false, + bool disposeBatch = true, + bool disposeDataset = true) + { + this.dataset = dataset; + this.batch_size = batchSize; + this.drop_last = drop_last; + this.Device = device ?? CPU; + this.sampler = shuffler; + this.num_workers = Math.Max(num_worker, 1); + this.collate_fn = collate_fn; + this.DisposeBatch = disposeBatch; + this.DisposeDataset = disposeDataset; + } + + /// + /// Pytorch style dataloader + /// + /// Dataset for create batch + /// Size of batch + /// Callback to merge items to make a batch + /// true if shuffle dataset, false for not + /// device for output tensor + /// Seed for generating shuffle + /// Count of worker + /// + /// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size. + /// If false and the size of dataset is not divisible by the batch size, then the last batch will be smaller. 
+ /// + /// + /// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration. + /// + /// + /// Indicates whether to dispose the dataset when being disposed. + /// + public DataLoader( + IDataset dataset, + int batchSize, + Func, torch.Device, S> collate_fn, + bool shuffle = false, + Device? device = null, + int? seed = null, + int num_worker = 1, + bool drop_last = false, + bool disposeBatch = true, + bool disposeDataset = true) : + this(dataset, batchSize, collate_fn, + shuffle ? new FisherYatesShuffler(dataset.Count, seed) : LongRange(dataset.Count), + device, num_worker, drop_last, disposeBatch, disposeDataset) + { } + + /// + /// Pytorch style dataloader + /// + /// Dataset for create batch + /// Size of batch + /// Callback to merge items to make a batch + /// device for output tensor + /// Shuffler for dataloader + /// Count of worker + /// + /// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size. + /// If false and the size of dataset is not divisible by the batch size, then the last batch will be smaller. + /// + /// + /// Indicates whether to automatically dispose the collated tensors after an iteration. + /// + /// + /// Indicates whether to dispose the dataset when being disposed. + /// + public DataLoader( + IDataset dataset, int batchSize, Func, torch.Device, S> collate_fn, IEnumerable shuffler, @@ -337,7 +525,7 @@ public DataLoader( /// Indicates whether to dispose the dataset when being disposed. 
/// public DataLoader( - Dataset dataset, + IDataset dataset, int batchSize, Func, torch.Device, S> collate_fn, bool shuffle = false, @@ -432,7 +620,7 @@ public bool MoveNext() .WithDegreeOfParallelism(loader.num_workers) .ForAll((i) => { using var getTensorScope = torch.NewDisposeScope(); - tensors[i] = loader.dataset.GetTensor(indices[i]); + tensors[i] = loader.dataset[indices[i]]; getTensorDisposables[i] = getTensorScope.DetachAllAndDispose(); }); diff --git a/src/TorchSharp/Dataset.cs b/src/TorchSharp/Dataset.cs index 8af054f3e..5ff5fe6fe 100644 --- a/src/TorchSharp/Dataset.cs +++ b/src/TorchSharp/Dataset.cs @@ -27,7 +27,7 @@ public abstract class IterableDataset : Dataset> /// /// The base nterface for all Datasets. /// - public abstract class Dataset : IDisposable + public abstract class Dataset : Modules.IDataset, IDisposable { public void Dispose() { @@ -35,9 +35,7 @@ public void Dispose() GC.SuppressFinalize(this); } - /// - /// Size of dataset - /// + /// public abstract long Count { get; } /// @@ -47,6 +45,9 @@ public void Dispose() /// Tensors of index. DataLoader will catenate these tensors into batches. public abstract T GetTensor(long index); + /// + public T this[long index] => GetTensor(index); + protected virtual void Dispose(bool disposing) { } @@ -54,4 +55,25 @@ protected virtual void Dispose(bool disposing) } } } + + namespace Modules + { + /// + /// The base interface for all Datasets. + /// + public interface IDataset : IDisposable + { + /// + /// Size of dataset + /// + long Count { get; } + + /// + /// Get tensor according to index + /// + /// Index for tensor + /// Tensors of index. DataLoader will catenate these tensors into batches. 
+ T this[long index] { get; } + } + } } \ No newline at end of file diff --git a/src/TorchSharp/Utils/TensorDataset.cs b/src/TorchSharp/Utils/TensorDataset.cs index e3dc4161a..883ce7c39 100644 --- a/src/TorchSharp/Utils/TensorDataset.cs +++ b/src/TorchSharp/Utils/TensorDataset.cs @@ -40,16 +40,6 @@ internal TensorDataset(torch.Tensor[] tensors) _tensors = tensors.Select(x => x.alias().DetachFromDisposeScope()).ToArray(); } - /// - /// Indexer - /// - public IList this[long index] { - - get { - return _tensors.Select(t => t[index]).ToList(); - } - } - /// /// Length of the dataset, i.e. the size of the first dimension. /// @@ -59,7 +49,7 @@ public override long Count { public override IList GetTensor(long index) { - return this[index]; + return _tensors.Select(t => t[index]).ToList(); } readonly torch.Tensor[] _tensors; diff --git a/test/TorchSharpTest/TestDataLoader.cs b/test/TorchSharpTest/TestDataLoader.cs index 2ffce70ec..79f13c1ae 100644 --- a/test/TorchSharpTest/TestDataLoader.cs +++ b/test/TorchSharpTest/TestDataLoader.cs @@ -1,7 +1,10 @@ // Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information. 
using System.Collections.Generic; +using System.Linq; +using TorchSharp.Modules; using Xunit; +using static TorchSharp.torch; namespace TorchSharp @@ -27,6 +30,23 @@ private class TestIterableDataset : torch.utils.data.IterableDataset } } + private class TestDatasetFromEnumerable : IDataset + { + private readonly T[] values; + public TestDatasetFromEnumerable(IEnumerable values) + { + this.values = values.ToArray(); + this.Disposed = false; + } + public bool Disposed { get; set; } + public T this[long index] => values[index]; + public long Count => values.LongLength; + public void Dispose() + { + this.Disposed = true; + } + } + [Fact] public void DatasetTest() { @@ -230,5 +250,71 @@ public void CustomSeedTest() iterator.Dispose(); iterator2.Dispose(); } + + [Fact] + public void ConcatDatasetTest() + { + using var dataset1 = new TestDatasetFromEnumerable<(int, int)>(new[] { + (1, 1), // dataset 1 value 1 + (1, 2), // dataset 1 value 2 + (1, 3), + }); + using var dataset2 = new TestDatasetFromEnumerable<(int, int)>(new[] { + (2, 1), + (2, 2), + }); + using var dataset3 = new TestDatasetFromEnumerable<(int, int)>(new[] { + (3, 1), + (3, 2), + (3, 3), + (3, 4), + }); + + using var dataset = new ConcatDataset<(int, int)>(new[] { + dataset1, dataset2, dataset3 + }); + + Assert.Equal(3 + 2 + 4, dataset.Count); + + Assert.Equal((1, 1), dataset[0]); + Assert.Equal((1, 2), dataset[1]); + Assert.Equal((1, 3), dataset[2]); + Assert.Equal((2, 1), dataset[3]); + Assert.Equal((2, 2), dataset[4]); + Assert.Equal((3, 1), dataset[5]); + Assert.Equal((3, 2), dataset[6]); + Assert.Equal((3, 3), dataset[7]); + Assert.Equal((3, 4), dataset[8]); + + Assert.Equal((1, 1), dataset[-9]); + Assert.Equal((1, 2), dataset[-8]); + Assert.Equal((1, 3), dataset[-7]); + Assert.Equal((2, 1), dataset[-6]); + Assert.Equal((2, 2), dataset[-5]); + Assert.Equal((3, 1), dataset[-4]); + Assert.Equal((3, 2), dataset[-3]); + Assert.Equal((3, 3), dataset[-2]); + Assert.Equal((3, 4), dataset[-1]); + + 
Assert.False(dataset1.Disposed); + Assert.False(dataset2.Disposed); + Assert.False(dataset3.Disposed); + dataset.Dispose(); + Assert.True(dataset1.Disposed); + Assert.True(dataset2.Disposed); + Assert.True(dataset3.Disposed); + } + + [Fact] + public void ConcatDatasetDataLoader() + { + using var dataset = torch.utils.data.ConcatDataset( + new[] { + new TestDataset(), + new TestDataset(), + }); + var dataloader = torch.utils.data.DataLoader(dataset, 10, false); + Assert.Equal(2, dataloader.Count); + } } } \ No newline at end of file