diff --git a/CIFAR10/output/confusion_matrix_0.png b/CIFAR10/output/confusion_matrix_0.png new file mode 100644 index 00000000..3ed88246 Binary files /dev/null and b/CIFAR10/output/confusion_matrix_0.png differ diff --git a/CIFAR10/output/confusion_matrix_1.png b/CIFAR10/output/confusion_matrix_1.png new file mode 100644 index 00000000..bfe10513 Binary files /dev/null and b/CIFAR10/output/confusion_matrix_1.png differ diff --git a/CIFAR10/output/confusion_matrix_2.png b/CIFAR10/output/confusion_matrix_2.png new file mode 100644 index 00000000..0ac8e328 Binary files /dev/null and b/CIFAR10/output/confusion_matrix_2.png differ diff --git a/CIFAR10/output/confusion_matrix_3.png b/CIFAR10/output/confusion_matrix_3.png new file mode 100644 index 00000000..e0f225fa Binary files /dev/null and b/CIFAR10/output/confusion_matrix_3.png differ diff --git a/CIFAR10/output/confusion_matrix_4.png b/CIFAR10/output/confusion_matrix_4.png new file mode 100644 index 00000000..d12e6ad4 Binary files /dev/null and b/CIFAR10/output/confusion_matrix_4.png differ diff --git a/CIFAR10/output/confusion_matrix_5.png b/CIFAR10/output/confusion_matrix_5.png new file mode 100644 index 00000000..b21498bb Binary files /dev/null and b/CIFAR10/output/confusion_matrix_5.png differ diff --git a/CIFAR10/output/confusion_matrix_6.png b/CIFAR10/output/confusion_matrix_6.png new file mode 100644 index 00000000..34718b6c Binary files /dev/null and b/CIFAR10/output/confusion_matrix_6.png differ diff --git a/CIFAR10/output/confusion_matrix_7.png b/CIFAR10/output/confusion_matrix_7.png new file mode 100644 index 00000000..081d065f Binary files /dev/null and b/CIFAR10/output/confusion_matrix_7.png differ diff --git a/CIFAR10/output/metrics_0.json b/CIFAR10/output/metrics_0.json new file mode 100644 index 00000000..e9a3bf57 --- /dev/null +++ b/CIFAR10/output/metrics_0.json @@ -0,0 +1,130 @@ +{ + "epoch": 0, + "loss": 0.036152522850036624, + "accuracy": 0.0987, + "precision": 0.00987, + "recall": 0.1, + "f1": 0.017966687903886412, + "confusion_matrix": [ + [ + 0, + 991, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 987, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 990, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 968, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 984, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 1019, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 1041, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 988, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 1007, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 1025, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + ] +} \ No newline at end of file diff --git a/CIFAR10/output/metrics_1.json b/CIFAR10/output/metrics_1.json new file mode 100644 index 00000000..c27b0830 --- /dev/null +++ b/CIFAR10/output/metrics_1.json @@ -0,0 +1,130 @@ +{ + "epoch": 1, + "loss": 0.07208095054626465, + "accuracy": 0.0987, + "precision": 0.00987, + "recall": 0.1, + "f1": 0.017966687903886412, + "confusion_matrix": [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1000, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 989, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1004, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1003, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 937, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1045, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 987, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1030, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, 
+ 988, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1017, + 0, + 0, + 0 + ] + ] +} \ No newline at end of file diff --git a/CIFAR10/output/metrics_2.json b/CIFAR10/output/metrics_2.json new file mode 100644 index 00000000..aadc57f7 --- /dev/null +++ b/CIFAR10/output/metrics_2.json @@ -0,0 +1,130 @@ +{ + "epoch": 2, + "loss": 0.07207579364776612, + "accuracy": 0.0987, + "precision": 0.00987, + "recall": 0.1, + "f1": 0.017966687903886412, + "confusion_matrix": [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1031, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1010, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 967, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1056, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1000, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 985, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 987, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1006, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1002, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 956, + 0, + 0, + 0 + ] + ] +} \ No newline at end of file diff --git a/CIFAR10/output/metrics_3.json b/CIFAR10/output/metrics_3.json new file mode 100644 index 00000000..4499d3c6 --- /dev/null +++ b/CIFAR10/output/metrics_3.json @@ -0,0 +1,130 @@ +{ + "epoch": 3, + "loss": 0.07207806870937347, + "accuracy": 0.0967, + "precision": 0.00967, + "recall": 0.1, + "f1": 0.017634722348864776, + "confusion_matrix": [ + [ + 0, + 0, + 1031, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 1010, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 967, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 1056, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 1000, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 985, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 987, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 1006, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 1002, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ], + [ + 0, + 0, + 956, + 0, + 0, + 0, + 0, + 0, + 0, + 0 + ] + ] +} \ No newline at end of file diff --git a/CIFAR10/output/metrics_4.json b/CIFAR10/output/metrics_4.json new file mode 100644 index 00000000..6c2d8c29 --- /dev/null +++ b/CIFAR10/output/metrics_4.json @@ -0,0 +1,130 @@ +{ + "epoch": 4, + "loss": 0.07220881459712983, + "accuracy": 0.0984, + "precision": 0.00984, + "recall": 0.1, + "f1": 0.0179169701383831, + "confusion_matrix": [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 982, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1048, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 944, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1015, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1051, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1020, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1005, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 963, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 984, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 988, + 0 + ] + ] +} \ No newline at end of file diff --git a/CIFAR10/output/metrics_5.json b/CIFAR10/output/metrics_5.json new file mode 100644 index 00000000..c635318b --- /dev/null +++ b/CIFAR10/output/metrics_5.json @@ -0,0 +1,130 @@ +{ + "epoch": 5, + "loss": 0.0724243151664734, + "accuracy": 0.1005, + "precision": 0.01005, + "recall": 0.1, + "f1": 0.01826442526124489, + "confusion_matrix": [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 982, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1048, 
+ 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 944, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1015, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1051, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1020, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 1005, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 963, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 984, + 0, + 0, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 988, + 0, + 0, + 0 + ] + ] +} \ No newline at end of file diff --git a/CIFAR10/output/metrics_6.json b/CIFAR10/output/metrics_6.json new file mode 100644 index 00000000..9afaedbb --- /dev/null +++ b/CIFAR10/output/metrics_6.json @@ -0,0 +1,130 @@ +{ + "epoch": 6, + "loss": 0.07233350160121918, + "accuracy": 0.0984, + "precision": 0.00984, + "recall": 0.1, + "f1": 0.0179169701383831, + "confusion_matrix": [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 982, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1048, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 944, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1015, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1051, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1020, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1005, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 963, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 984, + 0 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 988, + 0 + ] + ] +} \ No newline at end of file diff --git a/CIFAR10/output/metrics_7.json b/CIFAR10/output/metrics_7.json new file mode 100644 index 00000000..c9fa9fc1 --- /dev/null +++ b/CIFAR10/output/metrics_7.json @@ -0,0 +1,130 @@ +{ + "epoch": 7, + "loss": 0.07240547826290131, + "accuracy": 0.0988, + "precision": 0.00988, + "recall": 0.1, + "f1": 0.017983254459410267, + "confusion_matrix": [ + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 982 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1048 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 944 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1015 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1051 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1020 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1005 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 963 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 984 + ], + [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 988 + ] + ] +} \ No newline at end of file diff --git a/main_transformer.py b/main_transformer.py new file mode 100644 index 00000000..eaac8765 --- /dev/null +++ b/main_transformer.py @@ -0,0 +1,99 @@ +import os +import torch +import torch.nn as nn +import torch.optim as optim +from torchvision import datasets, transforms +import time + +from networks.vision_transformer import VisionTransformer +from dataloader.cifar10_dataset import CIFAR10Dataset +from dataloader.dataloader import get_data_loaders +from train.train import train +from train.val import val +from test.test import test + + +# Hyperparameters (can use CLI) +batch_size = 64 +learning_rate = 0.00001 +epochs = 10 + +# transformer params +img_size=32 +patch_size=4 +in_channels=3 # CIFAR-10 images are RGB, so 3 input channels +embed_size=128 # Smaller embedding size for a smaller dataset +num_layers=6 # Fewer layers might be sufficient for CIFAR-10 +num_heads=8 # Adjusted number of heads +num_classes=10 # CIFAR-10 has 10 classes +dropout=0.01 + +# device +device = torch.device("mps" if 
torch.backends.mps.is_available() else "cpu") +# device = 'cpu' +print(f"Using device: {device}") + + +def count_parameters(model): + return sum(p.numel() for p in model.parameters() if p.requires_grad) + + +def main(): + start_time = time.time() + + # Define transforms + transform = transforms.Compose([ + transforms.Resize((32, 32)), + transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)) + ]) + + # Create the datasets + data_path = 'CIFAR10/' + if not os.path.exists(os.path.join(data_path, 'output')): + os.mkdir(os.path.join(data_path, 'output')) + train_dataset = CIFAR10Dataset(os.path.join(data_path, 'train'), transform=transform) + test_dataset = CIFAR10Dataset(os.path.join(data_path, 'test'), transform=transform) + + # Create data loaders + train_loader, val_loader, test_loader = get_data_loaders(train_dataset, test_dataset, batch_size) + + # Model & loss & optimizer + model = VisionTransformer( + img_size=img_size, + patch_size=patch_size, + in_channels=in_channels, + embed_size=embed_size, + num_layers=num_layers, + num_heads=num_heads, + num_classes=num_classes, + dropout=dropout + ).to(device) + criterion = nn.CrossEntropyLoss() + optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9) + # optimizer = optim.Adam(model.parameters(), lr=learning_rate) + + print(model) + print(count_parameters(model)) + + # Train the model + for epoch in range(epochs): + print('training') + train(model, device, train_loader, optimizer, criterion, epoch) + print('validating') + val(model, device, val_loader, criterion, epoch, data_path) + + # Test the model + test(model, device, test_loader, criterion, data_path) + + # Save the model checkpoint + torch.save(model.state_dict(), f'{data_path}output/model.pth') + print('Finished Training. Model saved as model.pth.') + + end_time = time.time() + print("Total Time: ", end_time-start_time) + print("Start Time: ", start_time) + print("End Time: ", end_time) + + +if __name__ == '__main__': + main() diff --git a/networks/vision_transformer.py b/networks/vision_transformer.py new file mode 100644 index 00000000..90650fd2 --- /dev/null +++ b/networks/vision_transformer.py @@ -0,0 +1,170 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +# https://uvadlc-notebooks.readthedocs.io/en/latest/tutorial_notebooks/JAX/tutorial15/Vision_Transformer.html + +# # patch embedding for transformer model +# # split image into patches and put them in the embedding space +# class PatchEmbedding(nn.Module): +# def __init__(self, img_size, patch_size, in_channels, embed_size): +# super().__init__() +# self.img_size = img_size +# self.patch_size = patch_size +# self.n_patches = (img_size // patch_size) ** 2 +# self.in_channels = in_channels +# self.embed_size = embed_size + +# self.flatten = nn.Flatten(2) +# self.projection = nn.Linear(patch_size * patch_size * in_channels, embed_size) + +# def forward(self, x): +# x = x.unfold(2, self.patch_size, self.patch_size).unfold(3, self.patch_size, self.patch_size) +# x = x.contiguous().view(x.size(0), -1, self.patch_size * self.patch_size * self.in_channels) +# x = self.projection(x) +# return x + + +# patch embedding for transformer model +# split image into patches and put them in the embedding space +# using CNN to do feature extraction +# class PatchEmbedding(nn.Module): +# def __init__(self, img_size, patch_size, in_channels, embed_size): +# super().__init__() +# self.img_size = img_size +# self.patch_size = patch_size +# self.n_patches = (img_size // patch_size) ** 2 +# self.embed_size = 
embed_size + +# self.proj = nn.Conv2d(in_channels, embed_size, kernel_size=patch_size, stride=patch_size) + +# def forward(self, x): +# x = self.proj(x) # shape: [batch_size, embed_size, H', W'] +# x = x.flatten(2) # shape: [batch_size, embed_size, n_patches] +# x = x.transpose(1, 2) # shape: [batch_size, n_patches, embed_size] +# return x + + +# patch embedding for transformer model +# split image into patches and put them in the embedding space +# using CNN to do feature extraction +class PatchEmbedding(nn.Module): + def __init__(self, img_size=32, patch_size=4, in_channels=3, embed_size=64): + super().__init__() + self.img_size = img_size + self.patch_size = patch_size + self.n_patches = (img_size // patch_size) ** 2 + self.embed_size = embed_size + + self.proj = nn.Conv2d(in_channels, embed_size, kernel_size=patch_size, stride=patch_size) + + def forward(self, x): + x = self.proj(x) # shape: [batch_size, embed_size, H', W'] + x = x.flatten(2) # shape: [batch_size, embed_size, n_patches] + x = x.transpose(1, 2) # shape: [batch_size, n_patches, embed_size] + return x + + +# actual multi-head attention and feed forward network +class TransformerBlock(nn.Module): + def __init__(self, embed_size, heads, dropout): + super().__init__() + self.attention = nn.MultiheadAttention(embed_size, heads, dropout=dropout) + self.norm1 = nn.LayerNorm(embed_size) + self.norm2 = nn.LayerNorm(embed_size) + + self.feed_forward = nn.Sequential( + nn.Linear(embed_size, embed_size * 4), + # nn.GELU(), + nn.Linear(embed_size * 4, embed_size), + nn.Dropout(dropout) + ) + + def forward(self, x): + attn_output, _ = self.attention(x, x, x) + x = self.norm1(attn_output + x) + x = self.norm2(self.feed_forward(x) + x) + return x + + +# combining the patch layer and multi-head attention layer +# class VisionTransformer(nn.Module): +# def __init__(self, img_size, patch_size, in_channels, embed_size, num_layers, num_heads, num_classes, dropout): +# super().__init__() +# self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_size) +# self.cls_token = nn.Parameter(torch.randn(1, 1, embed_size)) +# self.positional_embedding = nn.Parameter(torch.randn(1, 1 + self.patch_embedding.n_patches, embed_size)) +# self.layers = nn.ModuleList([TransformerBlock(embed_size, num_heads, dropout) for _ in range(num_layers)]) +# self.to_cls_token = nn.Identity() +# self.mlp_head = nn.Sequential( +# nn.LayerNorm(embed_size), +# nn.Linear(embed_size, embed_size), +# nn.Linear(embed_size, num_classes), +# nn.LogSoftmax(dim=1) +# ) + +# def forward(self, x): +# x = self.patch_embedding(x) +# cls_token = self.cls_token.expand(x.shape[0], -1, -1) +# x = torch.cat((cls_token, x), dim=1) +# x += self.positional_embedding +# for layer in self.layers: +# x = layer(x) +# x = self.to_cls_token(x[:, 0]) +# return self.mlp_head(x) + +# combining the patch layer and multi-head attention layer +# update with changes to patch embeddings +class VisionTransformer(nn.Module): + def __init__(self, img_size, patch_size, in_channels, embed_size, num_layers, num_heads, num_classes, dropout): + super().__init__() + self.patch_embedding = PatchEmbedding(img_size, patch_size, in_channels, embed_size) + self.cls_token = nn.Parameter(torch.randn(1, 1, embed_size)) + self.positional_embedding = nn.Parameter(torch.randn(1, 1 + self.patch_embedding.n_patches, embed_size)) + self.layers = nn.ModuleList([TransformerBlock(embed_size, num_heads, dropout) for _ in range(num_layers)]) + self.to_cls_token = nn.Identity() + self.mlp_head = nn.Sequential( + 
nn.LayerNorm(embed_size), + nn.Linear(embed_size, embed_size), + nn.Linear(embed_size, num_classes), + nn.Softmax(dim=1) + # nn.LogSoftmax(dim=1) + ) + + def forward(self, x): + x = self.patch_embedding(x) + cls_token = self.cls_token.expand(x.shape[0], -1, -1) + x = torch.cat((cls_token, x), dim=1) + x += self.positional_embedding + for layer in self.layers: + x = layer(x) + x = self.to_cls_token(x[:, 0]) + return self.mlp_head(x) + + +if __name__ == '__main__': + # CIFAR10 model + model = VisionTransformer( + img_size=32, + patch_size=4, + in_channels=3, # CIFAR-10 images are RGB, so 3 input channels + embed_size=64, # Smaller embedding size for a smaller dataset + num_layers=6, # Fewer layers might be sufficient for CIFAR-10 + num_heads=8, # Adjusted number of heads + num_classes=10, # CIFAR-10 has 10 classes + dropout=0.1 + ) + + example_input = torch.randn(16, 3, 32, 32) # (batch_size, channels, height, width) + example_target = torch.randint(0, 10, (16,)) + + # Forward pass through the model + output = model(example_input) + print(output.shape) + print(output[0]) + + criterion = nn.CrossEntropyLoss() + # optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) + + loss = criterion(output, example_target) + print(loss) diff --git a/run_caffeinated.sh b/run_caffeinated.sh new file mode 100755 index 00000000..95a05c92 --- /dev/null +++ b/run_caffeinated.sh @@ -0,0 +1,3 @@ +echo "start script caffeinated" +caffeinate python main_transformer.py +echo "end script caffeinated" diff --git a/train/val.py b/train/val.py index 1db2d1a1..e47ebe73 100644 --- a/train/val.py +++ b/train/val.py @@ -32,6 +32,9 @@ def val(model, device, val_loader, criterion, epoch, output_path, num_classes=10 for t, p in zip(target.view(-1), pred.view(-1)): conf_mat[t.long(), p.long()] += 1 + print(output) + print(target) + # Metrics calculation val_loss /= len(val_loader.dataset) accuracy = correct / len(val_loader.dataset)
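
Note on the metrics committed under CIFAR10/output/: every metrics_*.json reports accuracy pinned near 0.10 across all eight epochs, and each confusion matrix collapses all 10,000 validation predictions into a single predicted class that merely drifts from epoch to epoch, i.e. the model never moves past chance. The most likely culprit in networks/vision_transformer.py is the nn.Softmax(dim=1) at the end of mlp_head: nn.CrossEntropyLoss already applies log-softmax internally, so feeding it softmaxed outputs squashes the logits into [0, 1] and starves the gradients. Two related issues: the feed_forward block has its nn.GELU() commented out, so the two stacked Linear layers collapse into a single affine map, and nn.MultiheadAttention is constructed without batch_first=True while the input is shaped (batch, n_patches + 1, embed_size), so attention mixes across the batch dimension instead of across patches. Separately, the transform in main_transformer.py applies Normalize without ToTensor, which will fail if CIFAR10Dataset yields PIL images rather than tensors, and the print(output)/print(target) added to train/val.py look like leftover debug output that will fire on every validation batch.

A minimal sketch of the suspected fixes follows, assuming the rest of networks/vision_transformer.py stays as committed; the helper make_mlp_head is illustrative only and not part of this diff.

import torch
import torch.nn as nn


class TransformerBlock(nn.Module):
    def __init__(self, embed_size, heads, dropout):
        super().__init__()
        # batch_first=True so attention runs over the patch dimension of a
        # (batch, n_patches + 1, embed_size) input, not over the batch dimension.
        self.attention = nn.MultiheadAttention(embed_size, heads, dropout=dropout, batch_first=True)
        self.norm1 = nn.LayerNorm(embed_size)
        self.norm2 = nn.LayerNorm(embed_size)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, embed_size * 4),
            nn.GELU(),  # restore the nonlinearity between the two Linear layers
            nn.Linear(embed_size * 4, embed_size),
            nn.Dropout(dropout),
        )

    def forward(self, x):
        attn_output, _ = self.attention(x, x, x)
        x = self.norm1(attn_output + x)
        x = self.norm2(self.feed_forward(x) + x)
        return x


def make_mlp_head(embed_size, num_classes):
    # Classification head that emits raw logits; nn.CrossEntropyLoss applies
    # log-softmax itself, so no Softmax/LogSoftmax layer belongs here.
    return nn.Sequential(
        nn.LayerNorm(embed_size),
        nn.Linear(embed_size, num_classes),
    )

With raw logits feeding CrossEntropyLoss directly, the 1e-5 SGD learning rate is also worth revisiting (the commented-out Adam optimizer at around 1e-4 to 3e-4 is a common starting point for small ViTs on CIFAR-10), but that is a tuning choice rather than a bug.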