41 changes: 23 additions & 18 deletions src/lanfactory/trainers/torch_mlp.py
@@ -74,22 +74,17 @@ def __init__(

def __len__(self) -> int:
# Number of batches per epoch
-return (
-    len(self.file_ids)
-    * ((self.file_shape_dict["inputs"][0] // self.batch_size) * self.batch_size)
-) // self.batch_size
+return len(self.file_ids) * self.batches_per_file

def __getitem__(self, index: int) -> tuple[np.ndarray, np.ndarray]:
# Check if it is time to load the next file
if ((index % self.batches_per_file) == 0) or (self.tmp_data == {}):
self.__load_file(file_index=self.indexes[index // self.batches_per_file])

# Generate and return a batch
-batch_ids = np.arange(
-    ((index % self.batches_per_file) * self.batch_size),
-    ((index % self.batches_per_file) + 1) * self.batch_size,
-    1,
-)
+start_idx = (index % self.batches_per_file) * self.batch_size
+end_idx = start_idx + self.batch_size
+batch_ids = np.arange(start_idx, end_idx, 1)
X, y = self.__data_generation(batch_ids)
return X, y

@@ -114,10 +109,21 @@ def __init_file_shape(self) -> None:
"inputs": init_file[self.features_key].shape,
"labels": init_file[self.label_key].shape,
}
-self.batches_per_file = int(self.file_shape_dict["inputs"][0] / self.batch_size)
+
+# Validate that samples_per_file is divisible by batch_size
+samples_per_file = self.file_shape_dict["inputs"][0]
+if samples_per_file % self.batch_size != 0:
+    raise ValueError(
+        f"samples_per_file ({samples_per_file}) must be divisible by "
+        f"batch_size ({self.batch_size}). Current remainder: "
+        f"{samples_per_file % self.batch_size}"
+    )
+
+self.batches_per_file = samples_per_file // self.batch_size

self.input_dim = self.file_shape_dict["inputs"][1]

if "generator_config" in init_file.keys():
if "generator_config" in init_file:
self.data_generator_config = init_file["generator_config"]

if len(self.file_shape_dict["labels"]) > 1:
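
The new guard replaces the silent truncation in the old int(samples / batch_size): with a non-divisible batch size, the trailing rows of each file were simply never served. A quick illustration with the numbers used in the new test below (1000 samples, batch_size=128):

samples_per_file = 1000
batch_size = 128
samples_per_file // batch_size  # 7 full batches; rows 896-999 were silently dropped before
samples_per_file % batch_size   # 104 -> DatasetTorch now raises ValueError instead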
@@ -176,7 +182,7 @@ def __init__(
self.input_shape = input_shape
self.network_config = network_config

if "train_output_type" in self.network_config.keys():
if "train_output_type" in self.network_config:
self.train_output_type = self.network_config["train_output_type"]
else:
self.train_output_type = "logprob"
@@ -393,25 +399,24 @@ def __get_scheduler(self) -> None:
mode="min",
factor=(
self.train_config["lr_scheduler_params"]["factor"]
if "factor" in self.train_config["lr_scheduler_params"].keys()
if "factor" in self.train_config["lr_scheduler_params"]
else 0.1
),
patience=(
self.train_config["lr_scheduler_params"]["patience"]
if "patience" in self.train_config["lr_scheduler_params"].keys()
if "patience" in self.train_config["lr_scheduler_params"]
else 2
),
threshold=(
self.train_config["lr_scheduler_params"]["threshold"]
if "threshold"
in self.train_config["lr_scheduler_params"].keys()
if "threshold" in self.train_config["lr_scheduler_params"]
else 0.001
),
threshold_mode="rel",
cooldown=0,
min_lr=(
self.train_config["lr_scheduler_params"]["min_lr"]
if "min_lr" in self.train_config["lr_scheduler_params"].keys()
if "min_lr" in self.train_config["lr_scheduler_params"]
else 0.00000001
),
)
Expand All @@ -420,7 +425,7 @@ def __get_scheduler(self) -> None:
self.optimizer,
gamma=(
self.train_config["lr_scheduler_params"]["factor"]
if "factor" in self.train_config["lr_scheduler_params"].keys()
if "factor" in self.train_config["lr_scheduler_params"]
else 0.1
),
last_epoch=-1,
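Aside: the repeated membership checks in __get_scheduler could be collapsed further with dict.get, which behaves identically for these defaults. A possible follow-up sketch, not part of this diff, assuming the result is stored on self.scheduler as the surrounding code suggests:

params = self.train_config["lr_scheduler_params"]
self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    self.optimizer,
    mode="min",
    factor=params.get("factor", 0.1),
    patience=params.get("patience", 2),
    threshold=params.get("threshold", 0.001),
    threshold_mode="rel",
    cooldown=0,
    min_lr=params.get("min_lr", 0.00000001),
)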
4 changes: 2 additions & 2 deletions tests/cli/config_network_training_lan.yaml
@@ -1,6 +1,6 @@
NETWORK_TYPE: "lan" # Test 4
-CPU_BATCH_SIZE: 1000
-GPU_BATCH_SIZE: 5000
+CPU_BATCH_SIZE: 20000
+GPU_BATCH_SIZE: 20000
GENERATOR_APPROACH: "lan" # could maybe be deleted
OPTIMIZER_: "adam"
N_EPOCHS: 2
12 changes: 6 additions & 6 deletions tests/constants.py
@@ -41,8 +41,8 @@ class TestTrainConstantsLAN:
"""Test training constants."""

N_EPOCHS: int = 2
-CPU_BATCH_SIZE: int = 4196
-GPU_BATCH_SIZE: int = 4196
+CPU_BATCH_SIZE: int = 100000
+GPU_BATCH_SIZE: int = 100000
OPTIMIZER: str = "adam"
LEARNING_RATE: float = 2e-06
LR_SCHEDULER: str = "reduce_on_plateau"
@@ -57,8 +57,8 @@ class TestTrainConstantsCPN:
"""Test training constants."""

N_EPOCHS: int = 2
-CPU_BATCH_SIZE: int = 32
-GPU_BATCH_SIZE: int = 32
+CPU_BATCH_SIZE: int = 250
+GPU_BATCH_SIZE: int = 250
OPTIMIZER: str = "adam"
LEARNING_RATE: float = 2e-06
LR_SCHEDULER: str = "reduce_on_plateau"
@@ -73,8 +73,8 @@ class TestTrainConstantsOPN:
"""Test training constants."""

N_EPOCHS: int = 2
-CPU_BATCH_SIZE: int = 32
-GPU_BATCH_SIZE: int = 32
+CPU_BATCH_SIZE: int = 250
+GPU_BATCH_SIZE: int = 250
OPTIMIZER: str = "adam"
LEARNING_RATE: float = 2e-06
LR_SCHEDULER: str = "reduce_on_plateau"
59 changes: 41 additions & 18 deletions tests/test_torch_mlp.py
@@ -42,17 +42,17 @@ def test_dataset_torch_init(create_mock_data_files): # pylint: disable=redefined-outer-name

dataset = DatasetTorch(
file_ids=file_list,
-batch_size=128,
+batch_size=100,
label_lower_bound=-16.0,
features_key="lan_data",
label_key="lan_labels",
)

# Verify attributes are set
-assert dataset.batch_size == 128
+assert dataset.batch_size == 100
assert len(dataset.file_ids) == 2
assert dataset.input_dim == 6
-assert dataset.batches_per_file == 1000 // 128  # 7 batches per file
+assert dataset.batches_per_file == 1000 // 100  # 10 batches per file


def test_dataset_torch_len(create_mock_data_files): # pylint: disable=redefined-outer-name
@@ -80,16 +80,16 @@ def test_dataset_torch_getitem_single_batch(

dataset = DatasetTorch(
file_ids=file_list,
-batch_size=128,
+batch_size=100,
features_key="lan_data",
label_key="lan_labels",
)

# Get first batch
X, y = dataset[0]

-assert X.shape == (128, 6)  # batch_size x features
-assert y.shape == (128, 1)  # batch_size x 1 (expanded)
+assert X.shape == (100, 6)  # batch_size x features
+assert y.shape == (100, 1)  # batch_size x 1 (expanded)
assert isinstance(X, np.ndarray)
assert isinstance(y, np.ndarray)

@@ -193,7 +193,7 @@ def test_dataset_torch_getitem_with_label_bounds(

dataset = DatasetTorch(
file_ids=file_list,
-batch_size=128,
+batch_size=100,
label_lower_bound=-10.0,
label_upper_bound=10.0,
features_key="lan_data",
@@ -221,16 +221,16 @@ def test_dataset_torch_with_2d_labels(tmp_path):

dataset = DatasetTorch(
file_ids=[str(file_path)],
-batch_size=128,
+batch_size=100,
features_key="lan_data",
label_key="lan_labels",
)

X, y = dataset[0]

# 2D labels should remain 2D
-assert X.shape == (128, 6)
-assert y.shape == (128, 3)
+assert X.shape == (100, 6)
+assert y.shape == (100, 3)


def test_dataset_torch_sequential_access_pattern(
@@ -286,7 +286,7 @@ def test_dataset_torch_with_jax_output(

dataset = DatasetTorch(
file_ids=file_list,
-batch_size=128,
+batch_size=100,
features_key="lan_data",
label_key="lan_labels",
out_framework="jax",
@@ -370,6 +370,29 @@ def test_dataset_torch_3d_labels_raises_error(tmp_path):
X, y = dataset[0]


+def test_dataset_torch_batch_size_not_divisible_raises_error(tmp_path):
+    """Test DatasetTorch raises ValueError when batch_size doesn't divide samples_per_file."""
+    file_path = tmp_path / "training_data.pickle"
+    data = {
+        "lan_data": np.random.randn(1000, 6).astype(np.float32),
+        "lan_labels": np.random.randn(1000).astype(np.float32),
+    }
+    with open(file_path, "wb") as f:
+        pickle.dump(data, f)
+
+    # batch_size=128 doesn't divide 1000 evenly (remainder 104)
+    with pytest.raises(
+        ValueError,
+        match=r"samples_per_file \(1000\) must be divisible by batch_size \(128\)\. Current remainder: 104",
+    ):
+        DatasetTorch(
+            file_ids=[str(file_path)],
+            batch_size=128,
+            features_key="lan_data",
+            label_key="lan_labels",
+        )


def test_model_trainer_torch_mlp_init_with_dict():
"""Test ModelTrainerTorchMLP initialization with dict train_config."""
train_config = {
Expand Down Expand Up @@ -723,14 +746,14 @@ def test_model_trainer_torch_mlp_with_mse_loss(create_mock_data_files):
# Create datasets
train_dataset = DatasetTorch(
file_ids=file_list,
-batch_size=16,
+batch_size=20,
label_lower_bound=-16.0,
features_key="lan_data",
label_key="lan_labels",
)
valid_dataset = DatasetTorch(
file_ids=file_list,
-batch_size=16,
+batch_size=20,
label_lower_bound=-16.0,
features_key="lan_data",
label_key="lan_labels",
@@ -790,14 +813,14 @@ def test_model_trainer_torch_mlp_with_bce_loss(create_mock_data_files):
# Create datasets
train_dataset = DatasetTorch(
file_ids=file_list,
-batch_size=16,
+batch_size=20,
label_lower_bound=-16.0,
features_key="lan_data",
label_key="lan_labels",
)
valid_dataset = DatasetTorch(
file_ids=file_list,
-batch_size=16,
+batch_size=20,
label_lower_bound=-16.0,
features_key="lan_data",
label_key="lan_labels",
@@ -857,14 +880,14 @@ def test_model_trainer_torch_mlp_with_sgd_optimizer(create_mock_data_files):
# Create datasets
train_dataset = DatasetTorch(
file_ids=file_list,
-batch_size=16,
+batch_size=20,
label_lower_bound=-16.0,
features_key="lan_data",
label_key="lan_labels",
)
valid_dataset = DatasetTorch(
file_ids=file_list,
-batch_size=16,
+batch_size=20,
label_lower_bound=-16.0,
features_key="lan_data",
label_key="lan_labels",
@@ -993,7 +1016,7 @@ def test_model_trainer_torch_mlp_with_none_train_config(create_mock_data_files):

train_dataset = DatasetTorch(
file_ids=file_list,
-batch_size=16,
+batch_size=20,
label_lower_bound=-16.0,
features_key="lan_data",
label_key="lan_labels",
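
With the stricter divisibility contract, callers that do not control samples_per_file need a compatible batch size up front. One way to choose one, as a hypothetical helper that is not part of the library:

def largest_valid_batch_size(samples_per_file: int, desired: int) -> int:
    # Largest batch size <= desired that divides samples_per_file evenly.
    for candidate in range(min(desired, samples_per_file), 0, -1):
        if samples_per_file % candidate == 0:
            return candidate
    return 1  # unreachable: 1 always divides

largest_valid_batch_size(1000, 128)  # -> 125, since 1000 == 8 * 125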