diff --git a/src/lanfactory/trainers/torch_mlp.py b/src/lanfactory/trainers/torch_mlp.py
index 0605d85..e252f7d 100755
--- a/src/lanfactory/trainers/torch_mlp.py
+++ b/src/lanfactory/trainers/torch_mlp.py
@@ -74,10 +74,7 @@ def __init__(
 
     def __len__(self) -> int:
         # Number of batches per epoch
-        return (
-            len(self.file_ids)
-            * ((self.file_shape_dict["inputs"][0] // self.batch_size) * self.batch_size)
-        ) // self.batch_size
+        return len(self.file_ids) * self.batches_per_file
 
     def __getitem__(self, index: int) -> tuple[np.ndarray, np.ndarray]:
         # Check if it is time to load the next file
@@ -85,11 +82,9 @@ def __getitem__(self, index: int) -> tuple[np.ndarray, np.ndarray]:
             self.__load_file(file_index=self.indexes[index // self.batches_per_file])
 
         # Generate and return a batch
-        batch_ids = np.arange(
-            ((index % self.batches_per_file) * self.batch_size),
-            ((index % self.batches_per_file) + 1) * self.batch_size,
-            1,
-        )
+        start_idx = (index % self.batches_per_file) * self.batch_size
+        end_idx = start_idx + self.batch_size
+        batch_ids = np.arange(start_idx, end_idx, 1)
 
         X, y = self.__data_generation(batch_ids)
         return X, y
@@ -114,10 +109,21 @@ def __init_file_shape(self) -> None:
             "inputs": init_file[self.features_key].shape,
             "labels": init_file[self.label_key].shape,
         }
-        self.batches_per_file = int(self.file_shape_dict["inputs"][0] / self.batch_size)
+
+        # Validate that samples_per_file is divisible by batch_size
+        samples_per_file = self.file_shape_dict["inputs"][0]
+        if samples_per_file % self.batch_size != 0:
+            raise ValueError(
+                f"samples_per_file ({samples_per_file}) must be divisible by "
+                f"batch_size ({self.batch_size}). Current remainder: "
+                f"{samples_per_file % self.batch_size}"
+            )
+
+        self.batches_per_file = samples_per_file // self.batch_size
+
         self.input_dim = self.file_shape_dict["inputs"][1]
 
-        if "generator_config" in init_file.keys():
+        if "generator_config" in init_file:
             self.data_generator_config = init_file["generator_config"]
 
         if len(self.file_shape_dict["labels"]) > 1:
@@ -176,7 +182,7 @@ def __init__(
         self.input_shape = input_shape
         self.network_config = network_config
 
-        if "train_output_type" in self.network_config.keys():
+        if "train_output_type" in self.network_config:
            self.train_output_type = self.network_config["train_output_type"]
         else:
             self.train_output_type = "logprob"
@@ -393,25 +399,24 @@ def __get_scheduler(self) -> None:
                 mode="min",
                 factor=(
                     self.train_config["lr_scheduler_params"]["factor"]
-                    if "factor" in self.train_config["lr_scheduler_params"].keys()
+                    if "factor" in self.train_config["lr_scheduler_params"]
                     else 0.1
                 ),
                 patience=(
                     self.train_config["lr_scheduler_params"]["patience"]
-                    if "patience" in self.train_config["lr_scheduler_params"].keys()
+                    if "patience" in self.train_config["lr_scheduler_params"]
                     else 2
                 ),
                 threshold=(
                     self.train_config["lr_scheduler_params"]["threshold"]
-                    if "threshold"
-                    in self.train_config["lr_scheduler_params"].keys()
+                    if "threshold" in self.train_config["lr_scheduler_params"]
                     else 0.001
                 ),
                 threshold_mode="rel",
                 cooldown=0,
                 min_lr=(
                     self.train_config["lr_scheduler_params"]["min_lr"]
-                    if "min_lr" in self.train_config["lr_scheduler_params"].keys()
+                    if "min_lr" in self.train_config["lr_scheduler_params"]
                     else 0.00000001
                 ),
             )
@@ -420,7 +425,7 @@ def __get_scheduler(self) -> None:
                 self.optimizer,
                 gamma=(
                     self.train_config["lr_scheduler_params"]["factor"]
-                    if "factor" in self.train_config["lr_scheduler_params"].keys()
+                    if "factor" in self.train_config["lr_scheduler_params"]
                     else 0.1
                 ),
                 last_epoch=-1,
diff --git a/tests/cli/config_network_training_lan.yaml b/tests/cli/config_network_training_lan.yaml
index 24cbf0a..9d7a397 100644
--- a/tests/cli/config_network_training_lan.yaml
+++ b/tests/cli/config_network_training_lan.yaml
@@ -1,6 +1,6 @@
 NETWORK_TYPE: "lan" # Test 4
-CPU_BATCH_SIZE: 1000
-GPU_BATCH_SIZE: 5000
+CPU_BATCH_SIZE: 20000
+GPU_BATCH_SIZE: 20000
 GENERATOR_APPROACH: "lan" # could maybe be deleted
 OPTIMIZER_: "adam"
 N_EPOCHS: 2
diff --git a/tests/constants.py b/tests/constants.py
index 6932138..dd65947 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -41,8 +41,8 @@ class TestTrainConstantsLAN:
     """Test training constants."""
 
     N_EPOCHS: int = 2
-    CPU_BATCH_SIZE: int = 4196
-    GPU_BATCH_SIZE: int = 4196
+    CPU_BATCH_SIZE: int = 100000
+    GPU_BATCH_SIZE: int = 100000
     OPTIMIZER: str = "adam"
     LEARNING_RATE: float = 2e-06
     LR_SCHEDULER: str = "reduce_on_plateau"
@@ -57,8 +57,8 @@ class TestTrainConstantsCPN:
     """Test training constants."""
 
     N_EPOCHS: int = 2
-    CPU_BATCH_SIZE: int = 32
-    GPU_BATCH_SIZE: int = 32
+    CPU_BATCH_SIZE: int = 250
+    GPU_BATCH_SIZE: int = 250
     OPTIMIZER: str = "adam"
     LEARNING_RATE: float = 2e-06
     LR_SCHEDULER: str = "reduce_on_plateau"
@@ -73,8 +73,8 @@ class TestTrainConstantsOPN:
     """Test training constants."""
 
     N_EPOCHS: int = 2
-    CPU_BATCH_SIZE: int = 32
-    GPU_BATCH_SIZE: int = 32
+    CPU_BATCH_SIZE: int = 250
+    GPU_BATCH_SIZE: int = 250
     OPTIMIZER: str = "adam"
     LEARNING_RATE: float = 2e-06
     LR_SCHEDULER: str = "reduce_on_plateau"
diff --git a/tests/test_torch_mlp.py b/tests/test_torch_mlp.py
index f510a80..d230d74 100644
--- a/tests/test_torch_mlp.py
+++ b/tests/test_torch_mlp.py
@@ -42,17 +42,17 @@ def test_dataset_torch_init(create_mock_data_files):  # pylint: disable=redefine
 
     dataset = DatasetTorch(
         file_ids=file_list,
-        batch_size=128,
+        batch_size=100,
         label_lower_bound=-16.0,
         features_key="lan_data",
         label_key="lan_labels",
     )
 
     # Verify attributes are set
-    assert dataset.batch_size == 128
+    assert dataset.batch_size == 100
     assert len(dataset.file_ids) == 2
     assert dataset.input_dim == 6
-    assert dataset.batches_per_file == 1000 // 128  # 7 batches per file
+    assert dataset.batches_per_file == 1000 // 100  # 10 batches per file
 
 
 def test_dataset_torch_len(create_mock_data_files):  # pylint: disable=redefined-outer-name
@@ -80,7 +80,7 @@ def test_dataset_torch_getitem_single_batch(
 
     dataset = DatasetTorch(
         file_ids=file_list,
-        batch_size=128,
+        batch_size=100,
         features_key="lan_data",
         label_key="lan_labels",
     )
@@ -88,8 +88,8 @@
 
     # Get first batch
     X, y = dataset[0]
 
-    assert X.shape == (128, 6)  # batch_size x features
-    assert y.shape == (128, 1)  # batch_size x 1 (expanded)
+    assert X.shape == (100, 6)  # batch_size x features
+    assert y.shape == (100, 1)  # batch_size x 1 (expanded)
     assert isinstance(X, np.ndarray)
     assert isinstance(y, np.ndarray)
@@ -193,7 +193,7 @@ def test_dataset_torch_getitem_with_label_bounds(
 
     dataset = DatasetTorch(
         file_ids=file_list,
-        batch_size=128,
+        batch_size=100,
         label_lower_bound=-10.0,
         label_upper_bound=10.0,
         features_key="lan_data",
@@ -221,7 +221,7 @@ def test_dataset_torch_with_2d_labels(tmp_path):
 
     dataset = DatasetTorch(
         file_ids=[str(file_path)],
-        batch_size=128,
+        batch_size=100,
         features_key="lan_data",
         label_key="lan_labels",
     )
@@ -229,8 +229,8 @@
     X, y = dataset[0]
 
     # 2D labels should remain 2D
-    assert X.shape == (128, 6)
-    assert y.shape == (128, 3)
+    assert X.shape == (100, 6)
+    assert y.shape == (100, 3)
 
 
 def test_dataset_torch_sequential_access_pattern(
@@ -286,7 +286,7 @@ def test_dataset_torch_with_jax_output(
     dataset = DatasetTorch(
         file_ids=file_list,
-        batch_size=128,
+        batch_size=100,
         features_key="lan_data",
         label_key="lan_labels",
         out_framework="jax",
     )
@@ -370,6 +370,29 @@ def test_dataset_torch_3d_labels_raises_error(tmp_path):
         X, y = dataset[0]
 
 
+def test_dataset_torch_batch_size_not_divisible_raises_error(tmp_path):
+    """Test DatasetTorch raises ValueError when batch_size doesn't divide samples_per_file."""
+    file_path = tmp_path / "training_data.pickle"
+    data = {
+        "lan_data": np.random.randn(1000, 6).astype(np.float32),
+        "lan_labels": np.random.randn(1000).astype(np.float32),
+    }
+    with open(file_path, "wb") as f:
+        pickle.dump(data, f)
+
+    # batch_size=128 doesn't divide 1000 evenly (remainder 104)
+    with pytest.raises(
+        ValueError,
+        match=r"samples_per_file \(1000\) must be divisible by batch_size \(128\)\. Current remainder: 104",
+    ):
+        DatasetTorch(
+            file_ids=[str(file_path)],
+            batch_size=128,
+            features_key="lan_data",
+            label_key="lan_labels",
+        )
+
+
 def test_model_trainer_torch_mlp_init_with_dict():
     """Test ModelTrainerTorchMLP initialization with dict train_config."""
     train_config = {
@@ -723,14 +746,14 @@ def test_model_trainer_torch_mlp_with_mse_loss(create_mock_data_files):
     # Create datasets
     train_dataset = DatasetTorch(
         file_ids=file_list,
-        batch_size=16,
+        batch_size=20,
         label_lower_bound=-16.0,
         features_key="lan_data",
         label_key="lan_labels",
     )
     valid_dataset = DatasetTorch(
         file_ids=file_list,
-        batch_size=16,
+        batch_size=20,
         label_lower_bound=-16.0,
         features_key="lan_data",
         label_key="lan_labels",
@@ -790,14 +813,14 @@ def test_model_trainer_torch_mlp_with_bce_loss(create_mock_data_files):
     # Create datasets
     train_dataset = DatasetTorch(
         file_ids=file_list,
-        batch_size=16,
+        batch_size=20,
         label_lower_bound=-16.0,
         features_key="lan_data",
         label_key="lan_labels",
     )
     valid_dataset = DatasetTorch(
         file_ids=file_list,
-        batch_size=16,
+        batch_size=20,
         label_lower_bound=-16.0,
         features_key="lan_data",
         label_key="lan_labels",
@@ -857,14 +880,14 @@ def test_model_trainer_torch_mlp_with_sgd_optimizer(create_mock_data_files):
     # Create datasets
     train_dataset = DatasetTorch(
         file_ids=file_list,
-        batch_size=16,
+        batch_size=20,
         label_lower_bound=-16.0,
         features_key="lan_data",
         label_key="lan_labels",
    )
     valid_dataset = DatasetTorch(
         file_ids=file_list,
-        batch_size=16,
+        batch_size=20,
         label_lower_bound=-16.0,
         features_key="lan_data",
         label_key="lan_labels",
@@ -993,7 +1016,7 @@ def test_model_trainer_torch_mlp_with_none_train_config(create_mock_data_files):
 
     train_dataset = DatasetTorch(
         file_ids=file_list,
-        batch_size=16,
+        batch_size=20,
         label_lower_bound=-16.0,
         features_key="lan_data",
         label_key="lan_labels",
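
Illustrative sketch (not part of the patch): after this change, DatasetTorch requires the per-file sample count to be an exact multiple of batch_size and raises ValueError otherwise, which is why the batch sizes in the tests and configs above switch to divisors of the mock file sizes. The import path and pickle layout below mirror the tests and are assumptions, not code from the diff.

# Hypothetical usage example; assumes DatasetTorch is importable from
# lanfactory.trainers and that training files are pickles keyed like the tests.
import pickle

import numpy as np

from lanfactory.trainers import DatasetTorch  # assumed import path

samples_per_file = 1000
data = {
    "lan_data": np.random.randn(samples_per_file, 6).astype(np.float32),
    "lan_labels": np.random.randn(samples_per_file).astype(np.float32),
}
with open("training_data.pickle", "wb") as f:
    pickle.dump(data, f)

# 100 divides 1000 evenly, so construction succeeds with 10 batches per file.
dataset = DatasetTorch(
    file_ids=["training_data.pickle"],
    batch_size=100,
    features_key="lan_data",
    label_key="lan_labels",
)
assert len(dataset) == samples_per_file // 100  # one file -> 10 batches per epoch

# A batch size such as 128 (remainder 104) would now raise ValueError at
# construction instead of silently dropping the trailing samples of each file.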