diff --git a/main.py b/main.py index 15d1456..8c6870a 100644 --- a/main.py +++ b/main.py @@ -9,25 +9,56 @@ import torchvision.transforms as transforms # Define the neural network architecture +import torch +import torch.nn as nn + class Net(nn.Module): def __init__(self): super(Net, self).__init__() - self.conv1 = nn.Conv2d(3, 6, 5) + + # Changed convolution layers to lazy convolution layers. + # Perks: Lazy layers defer the initialization of parameters until the input is passed through. + # This avoids the need to explicitly define input sizes beforehand, making the model more flexible. + # It also helps in memory efficiency as the actual memory is allocated only when the layers are used. + # Added BatchNormalisation for faster convergence + # Added Dropout for regularisation + # Added adaptive_pool to avoid any tensor size mismatches + + self.conv1 = nn.LazyConv2d(6, 5) + self.bn1 = nn.BatchNorm2d(6) # Added Batch Normalization self.pool = nn.MaxPool2d(2, 2) - self.conv2 = nn.Conv2d(6, 16, 5) + + self.conv2 = nn.LazyConv2d(16, 5) + self.bn2 = nn.BatchNorm2d(16) # Added Batch Normalization + + self.adaptive_pool = nn.AdaptiveAvgPool2d((5, 5)) # Adaptive pooling instead of fixed size + + # Fully connected layers self.fc1 = nn.Linear(16 * 5 * 5, 120) + self.dropout1 = nn.Dropout(0.5) # Added Dropout for regularization + self.fc2 = nn.Linear(120, 84) + self.dropout2 = nn.Dropout(0.5) # Added Dropout for regularization + self.fc3 = nn.Linear(84, 10) def forward(self, x): - x = self.pool(torch.relu(self.conv1(x))) - x = self.pool(torch.relu(self.conv2(x))) - x = x.view(-1, 16 * 5 * 5) + x = self.pool(torch.relu(self.bn1(self.conv1(x)))) + x = self.pool(torch.relu(self.bn2(self.conv2(x)))) + x = self.adaptive_pool(x) + + x = torch.flatten(x, 1) # Flatten feature maps + x = torch.relu(self.fc1(x)) + x = self.dropout1(x) # Dropout after fully connected layer + x = torch.relu(self.fc2(x)) + x = self.dropout2(x) # Dropout after fully connected layer + x = self.fc3(x) return x + # Load and preprocess the data transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)