130 changes: 123 additions & 7 deletions lightning_pose/data/datamodules.py
@@ -9,7 +9,7 @@
import torch
from lightning.pytorch.utilities import CombinedLoader
from omegaconf import DictConfig
from torch.utils.data import DataLoader, Subset, random_split
from torch.utils.data import DataLoader, Subset, random_split, WeightedRandomSampler, RandomSampler

from lightning_pose.data.dali import PrepareDALI
from lightning_pose.data.datatypes import SemiSupervisedDataLoaderDict
@@ -38,6 +38,7 @@ def __init__(
test_probability: float | None = None,
train_frames: float | int | None = None,
torch_seed: int = 42,
enable_weighted_sampler: bool = True,
) -> None:
"""Data module splits a dataset into train, val, and test data loaders.

@@ -56,7 +57,8 @@ def __init__(
(exclusive) and defines the fraction of the initially selected
train frames
torch_seed: control data splits

enable_weighted_sampler: If True, use a WeightedRandomSampler
for the training dataloader to oversample examples with rarer keypoints.
"""
super().__init__()
self.dataset = dataset
@@ -80,8 +82,83 @@ def __init__(
self.val_dataset = None # populated by self.setup()
self.test_dataset = None # populated by self.setup()
self.torch_seed = torch_seed
self.enable_weighted_sampler = enable_weighted_sampler
self.train_sampler = None
self._setup()


def _calculate_train_sampler_weights(self, epsilon: float = 1e-6) -> None:
"""Calculates weights for WeightedRandomSampler based on keypoint presence."""

if not isinstance(self.train_dataset, Subset):
print("Warning: Sampler weight calculation expects self.train_dataset to be a Subset. Skipping.")
self.train_sampler = None
return

# Determine how to access keypoints based on dataset type
underlying_dataset = self.train_dataset.dataset
if hasattr(underlying_dataset, 'keypoints'): # BaseTrackingDataset or HeatmapDataset
all_keypoints = underlying_dataset.keypoints
elif hasattr(underlying_dataset, 'dataset') and isinstance(underlying_dataset.dataset, dict): # MultiviewHeatmapDataset
# Using the first view as reference
try:
first_view_key = list(underlying_dataset.dataset.keys())[0]
all_keypoints = underlying_dataset.dataset[first_view_key].keypoints
print(f"Calculating sampler weights based on '{first_view_key}' view's keypoints for multiview.")
except (IndexError, AttributeError):
print("Warning: Could not access keypoints from the first view of Multiview dataset. Skipping sampler.")
self.train_sampler = None
return
else:
print("Warning: Could not find keypoints attribute for sampler weight calculation. Skipping.")
self.train_sampler = None
return

try:
train_indices = self.train_dataset.indices
# Ensure indices are valid for the keypoints tensor
if max(train_indices) >= len(all_keypoints):
print(f"Warning: train_indices ({max(train_indices)}) out of bounds for all_keypoints ({len(all_keypoints)}). Skipping sampler.")
self.train_sampler = None
return
train_keypoints = all_keypoints[train_indices]
except IndexError as e:
print(f"Error indexing keypoints with train_indices: {e}. Skipping sampler.")
self.train_sampler = None
return
except Exception as e:
print(f"Unexpected error accessing train keypoints: {e}. Skipping sampler.")
self.train_sampler = None
return

# Check for NaNs (use x-coordinate)
is_present = ~torch.isnan(train_keypoints[:, :, 0]) # Shape: [num_train_samples, num_keypoints]

# Calculate frequency
keypoint_counts = torch.sum(is_present, dim=0).float()
num_train_samples = len(train_indices)
if num_train_samples == 0:
print("Warning: Zero samples in training set. Skipping sampler.")
self.train_sampler = None
return
keypoint_frequencies = keypoint_counts / num_train_samples

# Inverse frequency weights for keypoints
inverse_frequencies = 1.0 / (keypoint_frequencies + epsilon)

# Assign weight to each sample
sample_weights = torch.sum(is_present * inverse_frequencies.unsqueeze(0), dim=1)

# Handle cases where all keypoints might be NaN for a sample
sample_weights[sample_weights == 0] = epsilon # Assign a tiny weight instead of zero

# Create the sampler
self.train_sampler = WeightedRandomSampler(
weights=sample_weights,
num_samples=len(sample_weights),
replacement=True
)
print("Created WeightedRandomSampler for training data.")

def _setup(self) -> None:

datalen = self.dataset.__len__()
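As a sanity check on the weighting scheme in `_calculate_train_sampler_weights` above, here is a minimal, self-contained sketch of how inverse keypoint frequencies turn into per-sample weights. The keypoints tensor, shapes, and values below are made up for illustration and are not the repository's data:

```python
import torch
from torch.utils.data import WeightedRandomSampler

# Hypothetical toy data: 4 frames, 3 keypoints, (x, y) coords; NaN marks a missing label.
nan = float("nan")
keypoints = torch.tensor([
    [[1., 2.], [nan, nan], [5., 6.]],
    [[1., 2.], [3., 4.],   [5., 6.]],
    [[1., 2.], [nan, nan], [5., 6.]],
    [[1., 2.], [nan, nan], [5., 6.]],
])

epsilon = 1e-6
is_present = ~torch.isnan(keypoints[:, :, 0])            # [4, 3] presence mask (x-coordinate)
frequencies = is_present.float().mean(dim=0)             # per-keypoint presence rate
inverse_frequencies = 1.0 / (frequencies + epsilon)      # rarer keypoints -> larger weights
sample_weights = (is_present * inverse_frequencies).sum(dim=1)
sample_weights[sample_weights == 0] = epsilon            # all-NaN frames keep a tiny weight

sampler = WeightedRandomSampler(
    weights=sample_weights, num_samples=len(sample_weights), replacement=True
)
print(sample_weights)  # the one frame containing the rare middle keypoint gets the largest weight
```

Frames that contain rarer keypoints therefore get drawn more often when sampling with replacement.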
@@ -143,6 +220,17 @@ def _setup(self) -> None:
# train_frames
self.train_dataset.indices = self.train_dataset.indices[:n_frames]

if self.enable_weighted_sampler:
self._calculate_train_sampler_weights()
else:
self.train_sampler = None

# print sampler status
if self.train_sampler:
print("Training sampler: WeightedRandomSampler enabled.")
else:
print("Training sampler: Standard shuffling enabled.")

print(
f"Dataset splits -- "
f"train: {len(self.train_dataset)}, "
@@ -151,15 +239,43 @@
)

def train_dataloader(self) -> torch.utils.data.DataLoader:
return DataLoader(
self.train_dataset,

if self.train_sampler is not None:
sampler_arg = self.train_sampler
shuffle_arg = None
generator_arg = None
print(f"DEBUG train_dataloader: Using sampler={type(sampler_arg)}")
else:
sampler_arg = None
shuffle_arg = True
generator_arg = torch.Generator().manual_seed(self.torch_seed)
print(f"DEBUG train_dataloader: Using shuffle={shuffle_arg}, sampler=None")

loader = DataLoader(
dataset=self.train_dataset,
batch_size=self.train_batch_size,
num_workers=self.num_workers,
persistent_workers=True if self.num_workers > 0 else False,
shuffle=True,
generator=torch.Generator().manual_seed(self.torch_seed),
sampler=sampler_arg,
shuffle=shuffle_arg,
generator=generator_arg,
)

# (Optional debug prints after creation)
print(f"DEBUG train_dataloader: DataLoader created.")
if hasattr(loader, 'batch_sampler') and loader.batch_sampler is not None:
print(f" -> batch_sampler type: {type(loader.batch_sampler)}")
if hasattr(loader.batch_sampler, 'sampler'):
print(f" -> underlying sampler type: {type(loader.batch_sampler.sampler)}")
else:
print(" -> batch_sampler has no 'sampler' attribute")
else:
print(f" -> No batch_sampler found on DataLoader.")
print(f" -> shuffle attribute (post-init): {getattr(loader, 'shuffle', 'N/A')}")

return loader


def val_dataloader(self) -> torch.utils.data.DataLoader:
return DataLoader(
self.val_dataset,
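For context on the `sampler_arg`/`shuffle_arg`/`generator_arg` handling in `train_dataloader` above: PyTorch's `DataLoader` treats an explicit `sampler` and `shuffle=True` as mutually exclusive, which is why the sampler branch leaves `shuffle` and `generator` unset. A minimal sketch with a hypothetical toy dataset (not part of this PR):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset, WeightedRandomSampler

dataset = TensorDataset(torch.arange(8).float())  # toy stand-in dataset
sampler = WeightedRandomSampler(weights=torch.ones(8), num_samples=8, replacement=True)

# OK: explicit sampler, shuffle left at its default (None).
loader = DataLoader(dataset, batch_size=4, sampler=sampler)

# Not OK: passing shuffle=True together with a sampler raises a ValueError.
try:
    DataLoader(dataset, batch_size=4, sampler=sampler, shuffle=True)
except ValueError as err:
    print(err)
```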
81 changes: 80 additions & 1 deletion tests/data/test_datamodules.py
@@ -3,8 +3,9 @@
import numpy as np
import pytest
from lightning.pytorch.utilities import CombinedLoader
from torch.utils.data import RandomSampler
from torch.utils.data import WeightedRandomSampler, RandomSampler, DataLoader, BatchSampler

from lightning_pose.data.datamodules import BaseDataModule

def test_base_datamodule(cfg, base_data_module):

@@ -404,3 +405,81 @@ def test_multiview_heatmap_data_module_combined_context(
)
assert batch["unlabeled"]["transforms"].shape == (num_views, 1, 2, 3)
assert batch["unlabeled"]["bbox"].shape == (train_size_unlabeled, num_views * 4)

@pytest.mark.parametrize("enable_weighted_sampler", [True, False])
def test_weighted_sampler_optionality(base_dataset, enable_weighted_sampler, capsys): # Add capsys to capture prints
"""Verify that WeightedRandomSampler is used only when configured."""

print(f"\nInstantiating BaseDataModule with enable_weighted_sampler={enable_weighted_sampler}")
data_module = BaseDataModule(
dataset=base_dataset,
train_batch_size=4,
val_batch_size=4,
test_batch_size=4,
num_workers=0,
train_probability=0.8,
enable_weighted_sampler=enable_weighted_sampler,
torch_seed=123,
)
# Capture print output during setup
captured_setup = capsys.readouterr()
print("--- Output during BaseDataModule setup ---")
print(captured_setup.out)
print(captured_setup.err)
print("-----------------------------------------")

train_loader = data_module.train_dataloader()

# Debugging prints
print(f"\n--- Testing with enable_weighted_sampler={enable_weighted_sampler} ---")
print(f"data_module.train_sampler type: {type(data_module.train_sampler)}")
print(f"train_loader type: {type(train_loader)}")
# The DataLoader wraps a user-supplied sampler in a BatchSampler, so the reliable way
# to inspect it is via train_loader.batch_sampler.sampler.
print(f"train_loader.batch_sampler type: {type(train_loader.batch_sampler)}")
if hasattr(train_loader.batch_sampler, 'sampler'):
print(f"train_loader.batch_sampler.sampler type: {type(train_loader.batch_sampler.sampler)}")
else:
print("train_loader.batch_sampler has no 'sampler' attribute")
print(f"train_loader.shuffle: {getattr(train_loader, 'shuffle', 'N/A')}") # shuffle attr might not exist depending on init
print("-----------------------------------------")


if enable_weighted_sampler:
# Case 1: Sampler was enabled AND successfully created
if data_module.train_sampler is not None:
print("Checking case: Sampler enabled and created.")
assert isinstance(data_module.train_sampler, WeightedRandomSampler), \
"data_module.train_sampler should be WeightedRandomSampler instance"
# When a sampler is provided, shuffle is left unset (None) in the DataLoader init.
# The DataLoader instance itself might not have a .shuffle attribute after init.
# The key is that the batch_sampler's underlying sampler is ours.
assert isinstance(train_loader.batch_sampler, BatchSampler), \
"DataLoader should use a BatchSampler"
assert isinstance(train_loader.batch_sampler.sampler, WeightedRandomSampler), \
"BatchSampler's underlying sampler should be WeightedRandomSampler"

# Case 2: Sampler was enabled BUT creation failed (fallback)
else:
print("Checking case: Sampler enabled but creation failed (fallback).")
# Check our internal state
assert data_module.train_sampler is None, \
"data_module.train_sampler should be None if creation failed"
# Check the DataLoader's resulting state (should be like sampler disabled)
assert isinstance(train_loader.batch_sampler, BatchSampler), \
"DataLoader should use a BatchSampler"
# When shuffle=True, DataLoader uses RandomSampler internally
assert isinstance(train_loader.batch_sampler.sampler, RandomSampler), \
"BatchSampler's underlying sampler should be RandomSampler in fallback"

else: # enable_weighted_sampler is False
print("Checking case: Sampler disabled.")
# Check our internal state
assert data_module.train_sampler is None, \
"data_module.train_sampler should be None when disabled"
# Check the DataLoader's resulting state
assert isinstance(train_loader.batch_sampler, BatchSampler), \
"DataLoader should use a BatchSampler"
# When shuffle=True, DataLoader uses RandomSampler internally
assert isinstance(train_loader.batch_sampler.sampler, RandomSampler), \
"BatchSampler's underlying sampler should be RandomSampler when disabled"