diff --git a/ML/Pytorch/object_detection/YOLO/dataset.py b/ML/Pytorch/object_detection/YOLO/dataset.py index 2958da79..b8743eb2 100755 --- a/ML/Pytorch/object_detection/YOLO/dataset.py +++ b/ML/Pytorch/object_detection/YOLO/dataset.py @@ -73,16 +73,16 @@ def __getitem__(self, index): # If no object already found for specific cell i,j # Note: This means we restrict to ONE object # per cell! - if label_matrix[i, j, 20] == 0: + if label_matrix[i, j, self.C] == 0: # Set that there exists an object - label_matrix[i, j, 20] = 1 + label_matrix[i, j, self.C] = 1 # Box coordinates box_coordinates = torch.tensor( [x_cell, y_cell, width_cell, height_cell] ) - label_matrix[i, j, 21:25] = box_coordinates + label_matrix[i, j, (self.C + 1):(self.C + 5)] = box_coordinates # Set one hot encoding for class_label label_matrix[i, j, class_label] = 1