concat dataset #1411

Open · wants to merge 10 commits into main
11 changes: 10 additions & 1 deletion RELEASENOTES.md
@@ -2,6 +2,15 @@

Releases, starting with 9/2/2021, are listed with the most recent release at the top.


# Next Version

- Add support for torch.utils.data.ConcatDataset.
- A new interface `IDataset<out T>` has been added, and now `Dataset<T>` implements it.
- More overloads of DataLoader() have been added to accept `IDataset`.
- The type of `DataLoader<T, S>.dataset` has been changed to `IDataset<T>`.
- The type of `DataLoader<T, S>.collate_fn` has been changed to `Func<IReadOnlyList<T>, Device, S>`.
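
A minimal usage sketch of the new surface area (`MyDataset` and the `"data"`/`"label"` keys are illustrative placeholders, not part of this PR):

```csharp
// Sketch only: MyDataset stands in for any existing
// Dataset<IReadOnlyDictionary<string, Tensor>> implementation.
using System.Collections.Generic;
using TorchSharp;
using static TorchSharp.torch;

var first = new MyDataset("train-part-1");
var second = new MyDataset("train-part-2");

// Concatenate the two datasets and feed them to the new IDataset-based overload.
var combined = torch.utils.data.ConcatDataset(new[] { first, second });
using var loader = torch.utils.data.DataLoader(combined, batchSize: 32, shuffle: true);

foreach (var batch in loader) {
    var data = batch["data"];      // samples are stacked along a new leading batch dimension
    var labels = batch["label"];
    // ... training step ...
}
```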

# NuGet Version 0.105.0

Move to libtorch 2.5.1. As with the 2.4.0 release, MacOS / Intel is no longer supported by libtorch, so TorchSharp doesn't support it, either.
@@ -30,7 +39,7 @@ __Issues fixed__:

__API Changes__:

#1382: Add support for torch.nn.functional.normalize<br/>

# NuGet Version 0.103.1

108 changes: 108 additions & 0 deletions src/TorchSharp/ConcatDataset.cs
@@ -0,0 +1,108 @@
// Copyright (c) .NET Foundation and Contributors. All Rights Reserved. See LICENSE in the project root for license information.
using System;
using System.Collections.Generic;
using System.Linq;
using TorchSharp.Modules;

namespace TorchSharp
{
public static partial class torch
{
public static partial class utils
{
public static partial class data
{
public static ConcatDataset<T> ConcatDataset<T>(IEnumerable<IDataset<T>> datasets)
{
return new ConcatDataset<T>(datasets);
}
}
}
}

namespace Modules
{
public sealed class ConcatDataset<T> : IDataset<T>
{
private static IEnumerable<long> Cumsum(IEnumerable<IDataset<T>> datasets)
{
var s = 0L;
foreach (var e in datasets) {
s += e.Count;
yield return s;
}
}

private static long bisectRight(long[] a, long x)
{
var lo = 0;
var hi = a.Length;
while (lo < hi) {
var mid = (lo + hi) / 2;
if (x < a[mid])
hi = mid;
else
lo = mid + 1;
}
return lo;
}

// Here we have to use arrays since the given index is in Int64...
private readonly IDataset<T>[] _datasets;
public IReadOnlyList<IDataset<T>> datasets => _datasets;

private readonly long[] _cumulativeSizes;
public IReadOnlyList<long> cumulative_sizes => _cumulativeSizes;

private readonly bool leaveOpen;

public ConcatDataset(
IEnumerable<IDataset<T>> datasets,
bool leaveOpen = false)
{
this._datasets = datasets.ToArray();
if (this._datasets.Length is 0)
throw new ArgumentException(
"datasets should not be an empty iterable", nameof(datasets));

// PyTorch also checks that no dataset is an IterableDataset ('ConcatDataset does not support IterableDataset'),
// but that refers to Python's IterableDataset, which is not torch.utils.data.IterableDataset in TorchSharp.
this._cumulativeSizes = Cumsum(datasets).ToArray();

this.leaveOpen = leaveOpen;
}

public long Count => this._cumulativeSizes.Last();

public T this[long index]
{
get {
if (index < 0) {
if (-index > this.Count) {
throw new ArgumentException(
"absolute value of index should not exceed dataset length",
nameof(index));
}
index = this.Count + index;
}

var datasetIdx = bisectRight(this._cumulativeSizes, index);
long sampleIdx;
if (datasetIdx == 0)
sampleIdx = index;
else
sampleIdx = index - this._cumulativeSizes[datasetIdx - 1];
return this._datasets[datasetIdx][sampleIdx];
}
}

public void Dispose()
{
if (!leaveOpen) {
foreach (var dataset in this._datasets)
dataset.Dispose();
}
}
}
}
}
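For reference, a small worked example of the index arithmetic above, with made-up dataset lengths of 3 and 5 (so `_cumulativeSizes` is `[3, 8]`): a global index of 4 falls in the second dataset at local index 4 - 3 = 1. The sketch below reproduces that mapping outside the class:

```csharp
// Standalone sketch of ConcatDataset's global-to-local index mapping.
// The lengths (3 and 5) are illustrative values only.
using System;

long[] cumulative = { 3, 8 };   // Cumsum of datasets with lengths { 3, 5 }
long globalIndex = 4;

// bisectRight: index of the first cumulative sum strictly greater than globalIndex.
int datasetIdx = 0;
while (datasetIdx < cumulative.Length && cumulative[datasetIdx] <= globalIndex)
    datasetIdx++;

long sampleIdx = datasetIdx == 0
    ? globalIndex
    : globalIndex - cumulative[datasetIdx - 1];

Console.WriteLine($"dataset {datasetIdx}, sample {sampleIdx}");   // prints "dataset 1, sample 1"
```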
198 changes: 193 additions & 5 deletions src/TorchSharp/DataLoader.cs
@@ -7,6 +7,7 @@
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using TorchSharp.Modules;
using TorchSharp.Utils;

namespace TorchSharp
@@ -77,6 +78,116 @@ public static Modules.IterableDataLoader DataLoader(
num_worker, drop_last,
disposeBatch, disposeDataset);
}

static IReadOnlyDictionary<string, torch.Tensor> DictionaryDataLoaderCollate(
IEnumerable<IReadOnlyDictionary<string, torch.Tensor>> dic,
torch.Device device)
{
using (torch.NewDisposeScope()) {
Dictionary<string, torch.Tensor> batch = new();
foreach (var x in dic.First().Keys) {
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
if (t.device_type != device.type || t.device_index != device.index)
t = t.to(device);
batch[x] = t.MoveToOuterDisposeScope();
}
return batch;
}
}

public static Modules.DataLoader<
IReadOnlyDictionary<string, Tensor>,
IReadOnlyDictionary<string, Tensor>
>
DataLoader(
IDataset<IReadOnlyDictionary<string, Tensor>> dataset,
int batchSize, IEnumerable<long> shuffler,
Device device = null,
int num_worker = 1, bool drop_last = false,
bool disposeBatch = true, bool disposeDataset = true)
{

return new(
dataset,
batchSize, DictionaryDataLoaderCollate,
shuffler,
device,
num_worker, drop_last,
disposeBatch, disposeDataset);
}

public static Modules.DataLoader<
IReadOnlyDictionary<string, Tensor>,
IReadOnlyDictionary<string, Tensor>
>
DataLoader(
IDataset<IReadOnlyDictionary<string, Tensor>> dataset,
int batchSize, bool shuffle = false,
Device device = null, int? seed = null,
int num_worker = 1, bool drop_last = false,
bool disposeBatch = true, bool disposeDataset = true)
{
return new(
dataset,
batchSize, DictionaryDataLoaderCollate, shuffle,
device, seed,
num_worker, drop_last,
disposeBatch, disposeDataset);
}

static IReadOnlyList<torch.Tensor> ListDataLoaderCollate(
IReadOnlyList<IReadOnlyList<torch.Tensor>> dic, torch.Device device)
{
using (torch.NewDisposeScope()) {
List<torch.Tensor> batch = new();
for (var x = 0; x < dic[0].Count; x++) {
var t = cat(dic.Select(k => k[x].unsqueeze(0)).ToArray(), 0);
if (t.device_type != device.type || t.device_index != device.index)
t = t.to(device);
batch.Add(t.MoveToOuterDisposeScope());
}
return batch;
}
}
public static Modules.DataLoader<
IReadOnlyList<Tensor>,
IReadOnlyList<Tensor>
>
DataLoader(
IDataset<IReadOnlyList<Tensor>> dataset,
int batchSize, IEnumerable<long> shuffler,
Device device = null,
int num_worker = 1, bool drop_last = false,
bool disposeBatch = true, bool disposeDataset = true)
{
return new(
dataset,
batchSize, ListDataLoaderCollate,
shuffler,
device,
num_worker, drop_last,
disposeBatch, disposeDataset);
}

public static Modules.DataLoader<
IReadOnlyList<Tensor>,
IReadOnlyList<Tensor>
>
DataLoader(
IDataset<IReadOnlyList<Tensor>> dataset,
int batchSize, bool shuffle = false,
Device device = null, int? seed = null,
int num_worker = 1, bool drop_last = false,
bool disposeBatch = true, bool disposeDataset = true)
{
return new(
dataset,
batchSize, ListDataLoaderCollate,
shuffle,
device, seed,
num_worker, drop_last,
disposeBatch, disposeDataset);
}
}
}
}
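Both collate helpers above follow the same pattern: each sample contributes one tensor per key (or per list position), every tensor gets a leading batch dimension via `unsqueeze(0)`, the results are concatenated with `cat` along dimension 0, and the batch is moved to the target device if needed. A rough illustration of that per-key step (the key name and shapes are made up):

```csharp
// Illustrative check of the per-key batching behaviour described above.
using System;
using System.Collections.Generic;
using System.Linq;
using static TorchSharp.torch;

var samples = new List<IReadOnlyDictionary<string, Tensor>> {
    new Dictionary<string, Tensor> { ["data"] = rand(3, 28, 28) },
    new Dictionary<string, Tensor> { ["data"] = rand(3, 28, 28) },
};

// The same unsqueeze-then-cat step the collate functions apply for each key:
var batched = cat(samples.Select(s => s["data"].unsqueeze(0)).ToArray(), 0);
Console.WriteLine(string.Join("x", batched.shape));   // 2x3x28x28
```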
@@ -264,12 +375,12 @@ public IterableDataLoader(
/// </summary>
public class DataLoader<T, S> : IEnumerable<S>, IDisposable
{
public Dataset<T> dataset { get; }
public IDataset<T> dataset { get; }
public int batch_size { get; }
public bool drop_last { get; }
public IEnumerable<long> sampler { get; }
public int num_workers { get; }
public Func<IEnumerable<T>, Device, S> collate_fn { get; }
public Func<IReadOnlyList<T>, Device, S> collate_fn { get; }

public Device Device { get; }
public bool DisposeBatch { get; }
@@ -295,7 +406,84 @@ public class DataLoader<T, S> : IEnumerable<S>, IDisposable
/// Indicates whether to dispose the dataset when being disposed.
/// </param>
public DataLoader(
Dataset<T> dataset,
IDataset<T> dataset,
int batchSize,
Func<IReadOnlyList<T>, torch.Device, S> collate_fn,
IEnumerable<long> shuffler,
Device? device = null,
int num_worker = 1,
bool drop_last = false,
bool disposeBatch = true,
bool disposeDataset = true)
{
this.dataset = dataset;
this.batch_size = batchSize;
this.drop_last = drop_last;
this.Device = device ?? CPU;
this.sampler = shuffler;
this.num_workers = Math.Max(num_worker, 1);
this.collate_fn = collate_fn;
this.DisposeBatch = disposeBatch;
this.DisposeDataset = disposeDataset;
}

/// <summary>
/// Pytorch style dataloader
/// </summary>
/// <param name="dataset">Dataset from which batches are created</param>
/// <param name="batchSize">Size of batch</param>
/// <param name="collate_fn">Callback to merge items to make a batch</param>
/// <param name="shuffle">true to shuffle the dataset, false otherwise</param>
/// <param name="device">device for output tensor</param>
/// <param name="seed">Seed for generating shuffle</param>
/// <param name="num_worker">Number of workers</param>
/// <param name="drop_last">
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
/// If false and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.
/// </param>
/// <param name="disposeBatch">
/// Indicates whether to automatically dispose the collated tensors (a batch) after an iteration.
/// </param>
/// <param name="disposeDataset">
/// Indicates whether to dispose the dataset when being disposed.
/// </param>
public DataLoader(
IDataset<T> dataset,
int batchSize,
Func<IReadOnlyList<T>, torch.Device, S> collate_fn,
bool shuffle = false,
Device? device = null,
int? seed = null,
int num_worker = 1,
bool drop_last = false,
bool disposeBatch = true,
bool disposeDataset = true) :
this(dataset, batchSize, collate_fn,
shuffle ? new FisherYatesShuffler(dataset.Count, seed) : LongRange(dataset.Count),
device, num_worker, drop_last, disposeBatch, disposeDataset)
{ }

/// <summary>
/// Pytorch style dataloader
/// </summary>
/// <param name="dataset">Dataset from which batches are created</param>
/// <param name="batchSize">Size of batch</param>
/// <param name="collate_fn">Callback to merge items to make a batch</param>
/// <param name="device">device for output tensor</param>
/// <param name="shuffler">Shuffler for dataloader</param>
/// <param name="num_worker">Number of workers</param>
/// <param name="drop_last">
/// Set to true to drop the last incomplete batch, if the dataset size is not divisible by the batch size.
/// If false and the size of the dataset is not divisible by the batch size, then the last batch will be smaller.
/// </param>
/// <param name="disposeBatch">
/// Indicates whether to automatically dispose the collated tensors after an iteration.
/// </param>
/// <param name="disposeDataset">
/// Indicates whether to dispose the dataset when being disposed.
/// </param>
public DataLoader(
IDataset<T> dataset,
int batchSize,
Func<IEnumerable<T>, torch.Device, S> collate_fn,
IEnumerable<long> shuffler,
@@ -337,7 +525,7 @@ public DataLoader(
/// Indicates whether to dispose the dataset when being disposed.
/// </param>
public DataLoader(
Dataset<T> dataset,
IDataset<T> dataset,
int batchSize,
Func<IEnumerable<T>, torch.Device, S> collate_fn,
bool shuffle = false,
@@ -432,7 +620,7 @@ public bool MoveNext()
.WithDegreeOfParallelism(loader.num_workers)
.ForAll((i) => {
using var getTensorScope = torch.NewDisposeScope();
tensors[i] = loader.dataset.GetTensor(indices[i]);
tensors[i] = loader.dataset[indices[i]];
getTensorDisposables[i] = getTensorScope.DetachAllAndDispose();
});
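For context, the loop above now reads samples through the `IDataset<T>` indexer rather than `Dataset<T>.GetTensor`, so a custom dataset can implement `IDataset<T>` directly. The sketch below assumes the interface exposes `Count`, a `long` indexer, and `Dispose()` (as `ConcatDataset` implements in this PR) and that it lives in `TorchSharp.Modules`; key names and shapes are illustrative.

```csharp
// Minimal sketch of a custom IDataset<T> implementation under the
// assumptions stated above; it serves synthetic data only.
using System.Collections.Generic;
using TorchSharp.Modules;              // assumed namespace of IDataset<T>
using static TorchSharp.torch;

sealed class RandomImageDataset : IDataset<IReadOnlyDictionary<string, Tensor>>
{
    // 100 synthetic samples; a real dataset would read from storage instead.
    public long Count => 100;

    public IReadOnlyDictionary<string, Tensor> this[long index] =>
        new Dictionary<string, Tensor> {
            ["data"] = rand(3, 28, 28),
            ["label"] = tensor(index % 10)
        };

    public void Dispose() { }
}
```

Such a dataset could then be handed to the new `DataLoader` overloads shown earlier, e.g. `torch.utils.data.DataLoader(new RandomImageDataset(), batchSize: 16, shuffle: true)`.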
