Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -206,10 +206,11 @@ Contains visualization functions for the datasets and the trained models.

## NeurIPS 2019 Disentanglement Challenge

The library is also used for the [NeurIPS 2019 Disentanglement challenge](https://www.aicrowd.com/challenges/neurips-2019-disentanglement-challenge). The challenge consists of three different datasets.
The library is also used for the [NeurIPS 2019 Disentanglement challenge](https://www.aicrowd.com/challenges/neurips-2019-disentanglement-challenge). The challenge consists of four different datasets.
1. Simplistic rendered images ([mpi3d_toy](https://storage.googleapis.com/disentanglement_dataset/data_npz/sim_toy_64x_ordered_without_heldout_factors.npz))
2. Realistic rendered images (mpi3d_realistic): _not yet published_
3. Real-world images (mpi3d_real): _not yet published_
4. Subset of real-world images (mpi3d_real_subset): _not yet published_

Currently, only the simplistic rendered dataset is publicly available; it is downloaded automatically by running the following command.
```
Expand Down
21 changes: 18 additions & 3 deletions disentanglement_lib/data/ground_truth/mpi3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,10 @@ class MPI3D(ground_truth_data.GroundTruthData):
"""

def __init__(self, mode="mpi3d_toy"):
self.factor_sizes = [4, 4, 2, 3, 3, 40, 40]
self.latent_factor_indices = [0, 1, 2, 3, 4, 5, 6]
self.num_total_factors = 7

if mode == "mpi3d_toy":
mpi3d_path = os.path.join(
os.environ.get("DISENTANGLEMENT_LIB_DATA", "."), "mpi3d_toy",
Expand Down Expand Up @@ -85,13 +89,23 @@ def __init__(self, mode="mpi3d_toy"):
else:
with tf.io.gfile.GFile(mpi3d_path, "rb") as f:
data = np.load(f)
elif mode == "mpi3d_real_subset":
self.factor_sizes = [2, 2, 2, 3, 3, 40, 40]

mpi3d_path = os.path.join(
os.environ.get("DISENTANGLEMENT_LIB_DATA", "."), "mpi3d_real",
"mpi3d_real_subset.npz")
if not tf.io.gfile.exists(mpi3d_path):
raise ValueError(
"Dataset '{}' not found. Make sure the dataset is publicly available and downloaded correctly."
.format(mode))
else:
with tf.io.gfile.GFile(mpi3d_path, "rb") as f:
data = np.load(f)
else:
raise ValueError("Unknown mode provided.")

self.images = data["images"]
self.factor_sizes = [4, 4, 2, 3, 3, 40, 40]
self.latent_factor_indices = [0, 1, 2, 3, 4, 5, 6]
self.num_total_factors = 7
self.state_space = util.SplitDiscreteStateSpace(self.factor_sizes,
self.latent_factor_indices)
self.factor_bases = np.prod(self.factor_sizes) / np.cumprod(
Expand All @@ -117,3 +131,4 @@ def sample_observations_from_factors(self, factors, random_state):
all_factors = self.state_space.sample_all_factors(factors, random_state)
indices = np.array(np.dot(all_factors, self.factor_bases), dtype=np.int64)
return self.images[indices] / 255.

2 changes: 2 additions & 0 deletions disentanglement_lib/data/ground_truth/named_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ def get_named_ground_truth_data(name):
return mpi3d.MPI3D(mode="mpi3d_realistic")
elif name == "mpi3d_real":
return mpi3d.MPI3D(mode="mpi3d_real")
elif name == "mpi3d_real_subset":
return mpi3d.MPI3D(mode="mpi3d_real_subset")
elif name == "shapes3d":
return shapes3d.Shapes3D()
elif name == "dummy_data":
Expand Down