PyTorch implementation of hierarchical classification for CIFAR5.
The CIFAR5 dataset is a subset of the CIFAR10 dataset and contains five target classes: plane, car, bird, horse and truck.
The hierarchical classification scheme is shown in the following diagram:
To run:
python main.py
The code uses the torchvision.transforms module to convert the CIFAR5 input images from RGB with dimensions (3,32,32) to single-channel greyscale with dimensions (1,50,50):
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(50),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])
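To make the last step of the pipeline concrete, the sketch below reproduces the arithmetic of Normalize in plain Python, without torchvision. ToTensor scales pixel values to the range [0, 1], and Normalize((0.5,), (0.5,)) then computes (x - mean) / std on the single greyscale channel, so the network receives values in [-1, 1]. The helper name normalize_pixel is hypothetical and is not part of the repository.

```python
def normalize_pixel(x, mean=0.5, std=0.5):
    """Apply the per-channel formula used by transforms.Normalize: (x - mean) / std."""
    return (x - mean) / std


# ToTensor produces values in [0, 1]; check the two extremes and the midpoint.
for value in (0.0, 0.5, 1.0):
    print(value, "->", normalize_pixel(value))
```

Black pixels map to -1, mid-grey to 0 and white to 1, which centres the inputs around zero.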
User-defined parameters are set at the start of the main.py script:
batch_size = 4 # number of samples per mini-batch
imsize = 50 # image height and width in pixels
params = [2,4,5] # number of classes at each level: [coarse1, coarse2, fine]
weights = [0.8,0.1,0.1] # weights for loss function, one per level
lr0 = torch.tensor(1e-3) # initial learning rate
momentum = torch.tensor(8e-1) # momentum for optimizer
decay = torch.tensor(1e-6) # weight decay for regularisation
random_seed = 42 # seed for reproducibility
This runs the BCNN defined in models.py, which has the structure:
The loss function is defined as:
L = w1*l1 + w2*l2 + w3*l3
where l1, l2 and l3 are the cross-entropy losses from the COARSE1, COARSE2 and FINE levels respectively, and [w1, w2, w3] is the vector of weights (weights in main.py) that controls the contribution of each level to the combined loss.
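The combination of the three loss terms can be sketched as follows. This is a minimal illustration in plain Python, assuming the combined loss is a weighted sum of the three per-level cross-entropy terms. It is not the code from main.py or models.py. The function names and the logits are made up for the example. In PyTorch, each per-level term would typically be computed with torch.nn.functional.cross_entropy instead.

```python
import math


def cross_entropy(logits, target_index):
    """Softmax cross entropy for one sample: -log(softmax(logits)[target_index])."""
    peak = max(logits)  # subtract the max before exponentiating, for numerical stability
    log_sum_exp = peak + math.log(sum(math.exp(z - peak) for z in logits))
    return log_sum_exp - logits[target_index]


def combined_loss(level_logits, level_targets, weights):
    """Weighted sum of the per-level cross-entropy losses (assumed form)."""
    losses = [cross_entropy(z, t) for z, t in zip(level_logits, level_targets)]
    return sum(w * l for w, l in zip(weights, losses))


# Made-up logits for one image at the COARSE1, COARSE2 and FINE levels.
logits = [[2.0, 0.0], [0.0, 1.0, 0.0, 0.0], [0.0, 0.0, 3.0, 0.0, 0.0]]
targets = [0, 1, 2]        # target class index at each level
weights = [0.8, 0.1, 0.1]  # same values as the weights parameter
loss = combined_loss(logits, targets, weights)
print(loss)
```

With weights of [0.8, 0.1, 0.1], mistakes at the COARSE1 level dominate the gradient, while the COARSE2 and FINE levels contribute less.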
For example, an input image of a truck would have the following one-hot vectors as its targets for the three loss components: [[1 0][0 1 0 0][0 0 1 0 0]]. Coarse-level classifications are calculated from the original CIFAR5 target classes in the utils.py script.
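The truck targets above can be rebuilt from one class index per level, as in the short sketch below. It assumes that the entries of params give the number of classes at each level, which matches the lengths of the three vectors. The one_hot helper is hypothetical and does not reproduce the mapping implemented in utils.py.

```python
def one_hot(index, num_classes):
    """Return a list of length num_classes with a 1 at position index."""
    if not 0 <= index < num_classes:
        raise ValueError("class index out of range")
    return [1 if i == index else 0 for i in range(num_classes)]


params = [2, 4, 5]         # assumed: number of classes at [coarse1, coarse2, fine]
truck_indices = [0, 1, 2]  # positions of the 1s in the truck example

truck_targets = [one_hot(i, n) for i, n in zip(truck_indices, params)]
print(truck_targets)  # [[1, 0], [0, 1, 0, 0], [0, 0, 1, 0, 0]]
```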