|
| 1 | +import numpy as np |
| 2 | +import torch |
| 3 | +from torch.utils.data import dataset |
| 4 | +from torch import utils |
| 5 | +from random import randint |
| 6 | + |
def one_hot_encode(labels, num_classes):
    """
    One hot encode labels and convert to tensor.

    Args:
        labels: Array-like of integer class labels (NumPy array, nested
            list, or a single scalar). Any shape is accepted.
        num_classes: Size of the one-hot axis.

    Returns:
        A ``torch.float32`` tensor whose shape is the shape of ``labels``
        with one trailing axis of length ``num_classes`` appended. A label
        outside ``range(num_classes)`` matches no class and therefore
        produces an all-zero row rather than an error.
    """
    # Coerce first so plain lists and scalars support the `[..., None]`
    # indexing below; ndarray input passes through unchanged.
    label_array = np.asarray(labels)
    # Broadcast-compare every label against every class id.
    matches = np.arange(num_classes) == label_array[..., None]
    return torch.tensor(matches, dtype=torch.float32)
| 12 | + |
class DeterministicCSWDataset(dataset.Dataset):
    """
    Dataset of one-step state transitions built from fixed state sequences.

    Two contexts are defined, each with two possible 5-state sequences over
    states numbered up to 10 (one-hot encoded with 11 classes). For every
    requested sample one of the context's two sequences is drawn uniformly
    at random via ``np.random.choice``. All sampled sequences are laid end
    to end, and each item pairs a state with the state that follows it in
    that flattened stream (so the pairing also spans sequence boundaries).

    Args:
        n_samples_per_context: Number of sequences to draw for each entry
            of ``contexts_to_load``.
        contexts_to_load: Context names (currently ``0`` or ``1``), one per
            entry of ``n_samples_per_context``.

    Each item is a tuple ``(x, context, y)`` of one-hot float tensors: the
    current state, the context of the current state, and the next state.
    """

    def __init__(self, n_samples_per_context, contexts_to_load) -> None:
        super().__init__()
        n_states = 11
        # Indexed as [context, variant, step].
        sequence_table = np.array([
            [[9, 1, 3, 5, 7], [9, 2, 4, 6, 8]],
            [[10, 1, 4, 5, 8], [10, 2, 3, 6, 7]],
        ])
        n_variants = sequence_table.shape[1]
        steps_per_sequence = sequence_table.shape[2]

        # One random variant choice per sampled sequence.
        total_sequences = sum(n_samples_per_context)
        variant_ids = np.random.choice(n_variants, total_sequences, replace=True)

        task_names = [0, 1]  # Flexible so these can be renamed later
        block_contexts = np.array(
            [task_names.index(name) for name in contexts_to_load]
        )

        # Expand the per-block context ids to one id per sampled sequence,
        # look up the chosen sequences and flatten them into one stream.
        sequence_contexts = np.repeat(block_contexts, n_samples_per_context)
        sampled = sequence_table[sequence_contexts, variant_ids]
        states = one_hot_encode(sampled, n_states).reshape((-1, n_states))

        # Same expansion again, this time to one context id per state.
        repeats_per_block = [
            count * steps_per_sequence for count in n_samples_per_context
        ]
        state_contexts = one_hot_encode(
            np.repeat(block_contexts, repeats_per_block), len(task_names)
        )

        # The final state has no successor, so it only ever appears as a
        # target; targets are the same stream shifted by one step.
        self.xs = states[:-1]
        # Cloned so the targets do not share storage with the inputs.
        self.ys = states[1:].clone()
        self.contexts = state_contexts[:-1]

    def __len__(self):
        return len(self.xs)

    def __getitem__(self, idx):
        return self.xs[idx], self.contexts[idx], self.ys[idx]
| 43 | + |
def generate_dataset(condition='Blocked'):
    """
    Build a DeterministicCSWDataset for the blocked or interleaved condition.

    'Blocked' schedules four blocks of 40 samples alternating between
    contexts 0 and 1; 'Interleaved' schedules 160 single-sample blocks
    alternating between contexts 0 and 1. Both schedules are followed by
    40 single-sample blocks whose context is drawn at random.

    Raises:
        ValueError: If ``condition`` is neither 'Blocked' nor 'Interleaved'.
    """
    if condition == 'Blocked':
        scheduled_contexts = [0, 1, 0, 1]
        scheduled_counts = [40, 40, 40, 40]
    elif condition == 'Interleaved':
        scheduled_contexts = [0, 1] * 80
        scheduled_counts = [1] * 160
    else:
        raise ValueError(f'Unknown dataset condition: {condition}')

    # Shared tail: 40 single samples, each from a randomly chosen context.
    random_contexts = [randint(0, 1) for _ in range(40)]
    random_counts = [1] * 40

    return DeterministicCSWDataset(
        scheduled_counts + random_counts,
        scheduled_contexts + random_contexts,
    )
0 commit comments