Skip to content

Commit

Permalink
added transfer learning demos
Browse files Browse the repository at this point in the history
  • Loading branch information
jmacey committed Dec 11, 2024
1 parent 3786153 commit 2b9d1f6
Show file tree
Hide file tree
Showing 3 changed files with 229 additions and 94 deletions.
92 changes: 42 additions & 50 deletions PreTrainedModels/PreTrainedModelsPart1.ipynb

Large diffs are not rendered by default.

218 changes: 174 additions & 44 deletions PreTrainedModels/TransferLearning.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
},
{
"cell_type": "code",
"execution_count": 6,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand All @@ -28,19 +28,7 @@
}
],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.optim import Adam\n",
"from torch.utils.data import Dataset, DataLoader\n",
"import torchvision.transforms.v2 as transforms\n",
"import torchvision.io as tv_io\n",
"from torchvision.models import vgg16\n",
"from torchvision.models import VGG16_Weights\n",
"import sys\n",
"import pathlib\n",
"import glob\n",
"import json\n",
"from PIL import Image\n",
"from imports import *\n",
"\n",
"\n",
"sys.path.append(\"../\")\n",
Expand All @@ -62,13 +50,52 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"We will now download the pre-trained model and as before and do the setup "
"## Dataset download\n",
"\n",
"In this demo we are going to add a number of images of different types of pokemon and train the model to recognize them, we also need something that is not a pokemon so we need to download other images too. We will then use the model to determine if we have a pokemon or animal image.\n",
"\n",
"There are a number of datasets on Kaggle we can use. In this case we are going to use the following datasets:\n",
"\n",
"https://www.kaggle.com/api/v1/datasets/download/vishalsubbiah/pokemon-images-and-types\n",
"\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"url = \"https://www.kaggle.com/api/v1/datasets/download/vishalsubbiah/pokemon-images-and-types\"\n",
"\n",
"desitnation = DATASET_LOCATION + \"pokemon.zip\"\n",
"if not pathlib.Path(desitnation).exists():\n",
" Utils.download(url, desitnation)\n",
" Utils.unzip_file(desitnation, DATASET_LOCATION)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# VGG16 Model\n",
"\n",
"we are going to usethe vgg16 model which has a 1000 categories, we will remove the last layer and add a new layer with 1001 categories so we can add pokemon as a new category.\n",
"\n",
"We also need to add the new images to the dataset and retrain the model with a new label 1001 for the pokemon images.\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"We will now download the pre-trained model and as before and do the setup "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
Expand Down Expand Up @@ -120,18 +147,128 @@
")"
]
},
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from torchvision.models import vgg16\n",
"from torchvision.models import VGG16_Weights\n",
"\n",
"# load the VGG16 network *pre-trained* on the ImageNet dataset\n",
"weights = VGG16_Weights.DEFAULT\n",
"vgg_model = vgg16(weights=weights)\n",
"vgg_model.to(device)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Loaded 809 images\n"
]
}
],
"source": [
"import pathlib\n",
"\n",
"pre_trans = weights.transforms()\n",
"\n",
"# IMAGE_WIDTH, IMAGE_HEIGHT = (224, 224)\n",
"\n",
"# pre_trans = transforms.Compose([\n",
"# transforms.ToDtype(torch.float32, scale=True), # Converts [0, 255] to [0, 1]\n",
"# transforms.Resize((IMAGE_WIDTH, IMAGE_HEIGHT)),\n",
"# transforms.Normalize(\n",
"# mean=[0.485, 0.456, 0.406],\n",
"# std=[0.229, 0.224, 0.225],\n",
"# ),\n",
"# transforms.CenterCrop(224)\n",
"# ])\n",
"\n",
"\n",
"class MyDataset(Dataset):\n",
" def __init__(self, data_dir):\n",
" self.imgs = []\n",
" self.labels = []\n",
" images=list(pathlib.Path(data_dir).rglob(\"*.png\"))\n",
" for image in images:\n",
" img = Image.open(image).convert(\"RGB\")\n",
" img_transformed=pre_trans(img)\n",
" self.imgs.append(img_transformed.to(device))\n",
" self.labels.append(torch.tensor(1001).to(device).float())\n",
" print(f\"Loaded {len(self.imgs)} images\")\n",
"\n",
" def __getitem__(self, idx):\n",
" img = self.imgs[idx]\n",
" label = self.labels[idx]\n",
" return img, label\n",
"\n",
"\n",
" def __len__(self):\n",
" return len(self.imgs)\n",
" \n",
"data_loader = DataLoader(MyDataset(DATASET_LOCATION)) "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The vgg16 model has a classifier attribute, which is a sequential module defining the fully connected layers. The last layer is the classification layer."
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Sequential(\n",
" (0): Linear(in_features=25088, out_features=4096, bias=True)\n",
" (1): ReLU(inplace=True)\n",
" (2): Dropout(p=0.5, inplace=False)\n",
" (3): Linear(in_features=4096, out_features=4096, bias=True)\n",
" (4): ReLU(inplace=True)\n",
" (5): Dropout(p=0.5, inplace=False)\n",
" (6): Linear(in_features=4096, out_features=1000, bias=True)\n",
")\n"
]
}
],
"source": [
"print(vgg_model.classifier)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"The final layer outputs logits for 1000 classes (ImageNet categories). To add a new category, you must replace this layer."
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Number of input features to the final layer\n",
"num_features = vgg_model.classifier[6].in_features\n",
"\n",
"# Replace the final layer\n",
"vgg_model.classifier[6] = torch.nn.Linear(num_features, 1001).to(device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
Expand All @@ -151,43 +288,36 @@
]
},
{
"cell_type": "markdown",
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"## Dataset download\n",
"\n",
"In this demo we are going to add a number of images of different types of pokemon and train the model to recognize them, we also need something that is not a pokemon so we need to download other images too. We will then use the model to determine if we have a pokemon or animal image.\n",
"\n",
"There are a number of datasets on Kaggle we can use. In this case we are going to use the following datasets:\n",
"\n",
"https://www.kaggle.com/api/v1/datasets/download/vishalsubbiah/pokemon-images-and-types\n",
"\n"
"for param in vgg_model.features.parameters():\n",
" param.requires_grad = False\n",
" "
]
},
{
"cell_type": "code",
"execution_count": 13,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/Users/jmacey/.pyenv/versions/anaconda3-2024.02-1/lib/python3.11/site-packages/urllib3/connectionpool.py:1100: InsecureRequestWarning: Unverified HTTPS request is being made to host 'www.kaggle.com'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#tls-warnings\n",
" warnings.warn(\n",
"/Users/jmacey/.pyenv/versions/anaconda3-2024.02-1/lib/python3.11/site-packages/urllib3/connectionpool.py:1100: InsecureRequestWarning: Unverified HTTPS request is being made to host 'storage.googleapis.com'. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#tls-warnings\n",
" warnings.warn(\n",
"./pokemon/pokemon.zip: 100%|██████████| 3.68M/3.68M [00:01<00:00, 2.48MiB/s]\n"
]
}
],
"outputs": [],
"source": [
"url = \"https://www.kaggle.com/api/v1/datasets/download/vishalsubbiah/pokemon-images-and-types\"\n",
"## Fine Tuning\n",
"\n",
"desitnation = DATASET_LOCATION + \"pokemon.zip\"\n",
"if not pathlib.Path(desitnation).exists():\n",
" Utils.download(url, desitnation)\n",
" Utils.unzip_file(desitnation, DATASET_LOCATION)"
"loss_function = torch.nn.CrossEntropyLoss()\n",
"optimizer = torch.optim.Adam(vgg_model.classifier.parameters(), lr=0.001)\n",
"\n",
"# Example training loop\n",
"num_epochs = 10\n",
"for epoch in range(num_epochs):\n",
" for inputs, labels in data_loader: \n",
" optimizer.zero_grad()\n",
" outputs = vgg_model(inputs)\n",
" loss = loss_function(outputs, labels)\n",
" loss.backward()\n",
" optimizer.step()\n"
]
},
{
Expand Down
13 changes: 13 additions & 0 deletions PreTrainedModels/imports.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
import json
import pathlib
import sys

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import torch
import torchvision.io as tv_io
import torchvision.transforms.functional as F
import torchvision.transforms.v2 as transforms
from PIL import Image
from torch.utils.data import Dataset, DataLoader

0 comments on commit 2b9d1f6

Please sign in to comment.