From 934d1d3789555677a91a12f9f91f1268dfc4f74a Mon Sep 17 00:00:00 2001
From: Ashwin Vaswani
Date: Wed, 9 Jun 2021 12:54:42 +0530
Subject: [PATCH] Error handling and minor bug fixes in CSKD (#100)

* pairwise sampler added and csdk updated

* links added in init

* Finalised and logs added

* CSKD added with tests

* Docs added

* Testing internal kdloss

* Adding docstrings and paper summary in tutorial

* Minor correction in docs

* Error handling for non-none teacher added

* cskd reformatted
---
 KD_Lib/KD/vision/CSKD/cskd.py | 13 +++++++++----
 1 file changed, 9 insertions(+), 4 deletions(-)

diff --git a/KD_Lib/KD/vision/CSKD/cskd.py b/KD_Lib/KD/vision/CSKD/cskd.py
index fac0ae18..8ab1db3e 100644
--- a/KD_Lib/KD/vision/CSKD/cskd.py
+++ b/KD_Lib/KD/vision/CSKD/cskd.py
@@ -11,14 +11,14 @@
 
 class CSKD(BaseClass):
     """
-    Implementation of assisted Knowledge distillation from the paper "Improved Knowledge
-    Distillation via Teacher Assistant" https://arxiv.org/pdf/1902.03393.pdf
+    Implementation of "Regularizing Class-wise Predictions via Self-knowledge Distillation"
+    https://arxiv.org/pdf/2003.13964.pdf
 
-    :param teacher_model (torch.nn.Module): Teacher model
+    :param teacher_model (torch.nn.Module): Teacher model -> Should be None
     :param student_model (torch.nn.Module): Student model
     :param train_loader (torch.utils.data.DataLoader): Dataloader for training
     :param val_loader (torch.utils.data.DataLoader): Dataloader for validation/testing
-    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher
+    :param optimizer_teacher (torch.optim.*): Optimizer used for training teacher -> Should be None
     :param optimizer_student (torch.optim.*): Optimizer used for training student
     :param loss_fn (torch.nn.Module): Calculates loss during distillation
     :param temp (float): Temperature parameter for distillation
@@ -60,6 +60,11 @@ def __init__(
             logdir,
         )
         self.lamda = lamda
+        if teacher_model is not None or optimizer_teacher is not None:
+            print(
+                "Error!!! Teacher model and Teacher optimizer should be None for self-distillation, please refer to the documentation."
+            )
+            assert teacher_model == None
 
     def calculate_kd_loss(self, y_pred_pair_1, y_pred_pair_2):
         """
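
A note on the guard introduced above, as a hypothetical alternative and not
part of the patch: the print-then-assert pair only hard-fails on
teacher_model, so a call that passes a None teacher together with a non-None
optimizer_teacher prints the message and keeps running, the assert vanishes
entirely under python -O, and PEP 8 prefers "is None" over "== None". A
minimal sketch of a stricter check, using an invented helper name rather
than anything in KD_Lib:

    def _reject_teacher_args(teacher_model, optimizer_teacher):
        # CSKD is self-distillation: there is no teacher to train, so both
        # arguments must be None. Raising ValueError (instead of asserting)
        # survives python -O and names the offending argument.
        if teacher_model is not None:
            raise ValueError(
                "CSKD is self-distillation: teacher_model must be None."
            )
        if optimizer_teacher is not None:
            raise ValueError(
                "CSKD is self-distillation: optimizer_teacher must be None."
            )

Inside CSKD.__init__, this sketch would replace the print/assert pair with a
single call: _reject_teacher_args(teacher_model, optimizer_teacher).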