Skip to content

Commit

Permalink
Merge branch 'mnist' of github.com:realityengines/post_hoc_debiasing …
Browse files Browse the repository at this point in the history
…into mnist
  • Loading branch information
yashsavani committed Jul 14, 2020
2 parents a659e35 + ef2d855 commit cbccb3a
Showing 1 changed file with 42 additions and 253 deletions.
295 changes: 42 additions & 253 deletions celebA.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"# CelebA"
"# Visualize debiasing experiments on CelebA"
]
},
{
Expand All @@ -25,25 +25,13 @@
"outputs": [],
"source": [
"import numpy as np\n",
"import matplotlib.pyplot as plt\n",
"import sys\n",
"import os\n",
"from os import listdir\n",
"from os.path import isfile, join\n",
"from pathlib import Path\n",
"from PIL import Image\n",
"import cv2\n",
"\n",
"from os.path import join\n",
"import matplotlib.pyplot as plt\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch.optim as optim\n",
"import yaml\n",
"from sklearn.metrics import roc_auc_score\n",
"from torchvision import models, transforms\n",
"\n",
"from celeb_race import CelebRace, unambiguous\n",
"from post_hoc_celeba import load_celeba, get_resnet_model"
"from post_hoc_celeba import load_celeba, get_resnet_model\n",
"from PIL import Image"
]
},
{
Expand All @@ -53,7 +41,6 @@
"outputs": [],
"source": [
"device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
"torch.manual_seed(0)\n",
"\n",
"descriptions = ['5_o_Clock_Shadow', 'Arched_Eyebrows', 'Attractive',\n",
" 'Bags_Under_Eyes', 'Bald', 'Bangs', 'Big_Lips', 'Big_Nose',\n",
Expand All @@ -65,7 +52,10 @@
" 'Receding_Hairline', 'Rosy_Cheeks', 'Sideburns', 'Smiling',\n",
" 'Straight_Hair', 'Wavy_Hair', 'Wearing_Earrings', 'Wearing_Hat',\n",
" 'Wearing_Lipstick', 'Wearing_Necklace', 'Wearing_Necktie',\n",
" 'Young', 'White', 'Black', 'Asian', 'Index']"
" 'Young', 'White', 'Black', 'Asian', 'Index']\n",
"\n",
"def sigmoid(x):\n",
"    \"\"\"Map logits to probabilities with the logistic function 1 / (1 + exp(-x)).\n",
"\n",
"    Evaluated as exp(-log(1 + exp(-x))) through np.logaddexp, so a large-magnitude\n",
"    negative logit yields 0.0 instead of overflowing inside np.exp.\n",
"    x may be a scalar or a NumPy array; the result has the matching NumPy type.\n",
"    \"\"\"\n",
"    return np.exp(-np.logaddexp(0, -x))"
]
},
{
Expand All @@ -75,26 +65,22 @@
"outputs": [],
"source": [
"def image_from_index(index, folder='~/post_hoc_debiasing/data/celeba/img_align_celeba/', show=False):\n",
"    \"\"\"Open the image in `folder` whose six-digit, zero-padded file name matches `index`.\n",
"\n",
"    index:  image number, e.g. 42 -> '000042.jpg'\n",
"    folder: directory holding the .jpg files ('~' is expanded)\n",
"    show:   if True, also draw the image with matplotlib\n",
"    Returns the PIL image with its pixel data already read into memory.\n",
"    \"\"\"\n",
"    # imported here so the helper does not rely on the import cell providing `os`\n",
"    from os.path import expanduser\n",
"\n",
"    # int() so a float-valued item such as 42.0 still maps to '000042.jpg'\n",
"    file = str(int(index)).zfill(6)+'.jpg'\n",
"    img = Image.open(join(expanduser(folder), file))\n",
"    # Image.open is lazy and keeps the file open; load() reads the pixels now so the\n",
"    # handle can be released instead of staying open for every image a caller collects\n",
"    img.load()\n",
"    if show:\n",
"        plt.imshow(img)\n",
"        plt.show()\n",
"    return img\n",
"\n",
"def imshow_from_tensor(img):\n",
" # plot from a tensor. Only works for non-transformed data\n",
" npimg = img.numpy()\n",
" plt.imshow(np.transpose(npimg, (1, 2, 0)))\n",
" plt.show()\n",
" \n",
"def imshow_group(data, n):\n",
"def imshow_group(imgs, n):\n",
" # plot the first n images side by side in one row of a single 20x10 figure\n",
" # n must not exceed the number of images supplied: positions 0..n-1 are indexed directly\n",
" plt.figure(figsize=(20,10))\n",
" columns = n\n",
" \n",
" for i in range(n):\n",
" plt.subplot(1, columns, i + 1)\n",
" img = data[i]\n",
" img = imgs[i]\n",
" #img = img.astype(int)\n",
" plt.axis('off')\n",
" plt.imshow(img)"
Expand All @@ -108,39 +94,36 @@
"source": [
"def output_debiased_imgs(biased_net,\n",
" debiased_net,\n",
" testloader,\n",
" loader,\n",
" protected_attr,\n",
" prediction_attr, \n",
" threshold):\n",
" prediction_attr):\n",
" \"\"\"\n",
" Display images of people from the protected class who were predicted as the unfavorable\n",
" outcome in the original model, but predicted the favorable outcome in the debiased model \n",
" Collect every image together with its label, protected-attribute value, and the biased and debiased sigmoid predictions\n",
" \"\"\" \n",
" # column positions of the requested attributes inside each label row\n",
" prediction_index = descriptions.index(prediction_attr)\n",
" protected_index = descriptions.index(protected_attr)\n",
" ind = descriptions.index('Index')\n",
"\n",
" imgs = []\n",
" total_batches = len(testloader)\n",
" for batch_num, (inputs, labels) in enumerate(testloader):\n",
" outputs = []\n",
" total_batches = len(loader)\n",
" for batch_num, (inputs, labels) in enumerate(loader):\n",
" inputs, labels = inputs.to(device), labels.to(device)\n",
" # NOTE(review): no torch.no_grad() or .eval() is visible around these forward passes - confirm inference mode is set elsewhere\n",
" # only column 0 of each network's output is kept\n",
" biased_outputs = biased_net(inputs)[:, 0]\n",
" debiased_outputs = debiased_net(inputs)[:, 0]\n",
"\n",
" for i in range(len(inputs)):\n",
" # the 'Index' label entry is what image_from_index turns into a file name\n",
" img = image_from_index(labels[i][ind].item())\n",
" label = labels[i][prediction_index].item()\n",
" protected = labels[i][protected_index].item()\n",
" biased_output = biased_outputs[i].item()\n",
" debiased_output = debiased_outputs[i].item()\n",
" \n",
" if protected and debiased_output - biased_output > threshold:\n",
" index = labels[i][ind].item()\n",
" imgs.append(image_from_index(index))\n",
" # raw column-0 outputs are passed through sigmoid before being stored\n",
" biased_output = sigmoid(biased_outputs[i].item())\n",
" debiased_output = sigmoid(debiased_outputs[i].item()) \n",
"\n",
" # one record per image: [PIL image, label value, protected value, biased score, debiased score]\n",
" outputs.append([img, label, protected, biased_output, debiased_output])\n",
"\n",
" # progress report every 10 batches\n",
" if batch_num % 10 == 0:\n",
" print('found', len(imgs), 'at', batch_num, '/', total_batches)\n",
" print('At', batch_num, '/', total_batches)\n",
"\n",
" return imgs"
" return outputs"
]
},
{
Expand All @@ -151,150 +134,28 @@
"source": [
"# load the test set (only the sixth value returned by load_celeba is kept)\n",
"_, _, _, _, _, testloader = load_celeba(trainsize=0, \n",
" testsize=1000, \n",
" testsize=100, \n",
" num_workers=0, \n",
" batch_size=32,\n",
" transform_type='tensor')\n",
"\n",
"# NOTE(review): these 'by_' checkpoints appear next to 'bs_' paths below, while the call that follows predicts 'Smiling' - confirm the file names match the attribute pair\n",
"biased_model_path = 'models/by_random_checkpoint.pt'\n",
"debiased_model_path = 'models/by_checkpoint.pt'\n",
"\n",
"# load the biased and debiased models\n",
"# the biased checkpoint is used directly as a state dict; the debiased one is read from its 'model_state_dict' entry\n",
"# NOTE(review): neither net is moved with .to(device) here although the inputs are - confirm get_resnet_model() places the model\n",
"biased_net = get_resnet_model()\n",
"biased_net.load_state_dict(torch.load('models/bs_random_checkpoint.pt'))\n",
"biased_net.load_state_dict(torch.load(biased_model_path, map_location=device))\n",
"\n",
"debiased_net = get_resnet_model()\n",
"debiased_net.load_state_dict(torch.load('models/bs_checkpoint.pt')['model_state_dict'])\n",
"debiased_net.load_state_dict(torch.load(debiased_model_path, map_location=device)['model_state_dict'])\n",
"\n",
"# output images which were debiased\n",
"imgs = output_debiased_imgs(biased_net=biased_net,\n",
" debiased_net=debiased_net,\n",
" testloader=testloader,\n",
" protected_attr = 'Black',\n",
" prediction_attr = 'Smiling',\n",
" threshold = .5)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"imshow_group(imgs[:7], 7)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(int(len(imgs)/8)):\n",
" imshow_group(imgs[8*i:8*(i+1)], 8)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Older data exploration functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def load(n=100, folder='~/post_hoc_debiasing/data/celeba/img_align_celeba/'):\n",
"    \"\"\"Read the images numbered 1..n (six-digit, zero-padded .jpg names) from `folder` into one numpy array.\"\"\"\n",
"    frames = [\n",
"        np.array(Image.open(join(os.path.expanduser(folder), str(number).zfill(6)+'.jpg')))\n",
"        for number in range(1, n+1)\n",
"    ]\n",
"    return np.array(frames)\n",
"\n",
"\n",
"\n",
"def load_race(filepath='~/post_hoc_debiasing/celebrace/'):\n",
"    \"\"\"Load the three *_full.npy arrays found under `filepath`, in the order black, asian, white.\"\"\"\n",
"    # (an earlier variant read the *_100k.npy files instead)\n",
"    names = ('black_full.npy', 'asian_full.npy', 'white_full.npy')\n",
"    return [np.load(os.path.expanduser(os.path.join(filepath, name))) for name in names]\n",
"\n",
"def load_attrs(file='~/post_hoc_debiasing/data/celeba/list_attr_celeba.txt', max_n=-1):\n",
"    \"\"\"Parse an attribute text file into an integer array plus its column names.\n",
"\n",
"    Line 0 of the file is skipped, line 1 supplies the attribute names, and every later\n",
"    line contributes one row built from its whitespace-separated fields after the first.\n",
"    Reading stops at the line whose zero-based number equals max_n; values below 2\n",
"    (including the default -1) never stop it.\n",
"    Prints the array shape and returns (attrs, descriptions).\n",
"    \"\"\"\n",
"    attrs = []\n",
"    descriptions = []\n",
"    # the with-block closes the file even when a row fails to parse\n",
"    with open(os.path.expanduser(file), \"r\") as f:\n",
"        for index, line in enumerate(f):\n",
"            if index == 0:\n",
"                # the first row is a header that is not used\n",
"                continue\n",
"            if index == 1:\n",
"                descriptions = line.split()\n",
"            elif index == max_n:\n",
"                break\n",
"            else:\n",
"                # drop the leading field and parse the remaining ones as integers\n",
"                attrs.append([int(num) for num in line.split()[1:]])\n",
"\n",
"    attrs = np.array(attrs)\n",
"    print(attrs.shape)\n",
"    return attrs, descriptions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# load all the data\n",
"data = load(n=20000) # 202599\n",
"print(data.shape)\n",
"attrs, descriptions = load_attrs()\n",
"races = load_race()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# check the attributes are correct\n",
"print(descriptions)\n",
"for i in range(3):\n",
" plt.imshow(data[i])\n",
" plt.show()\n",
" for attr in ['Male', 'Attractive', 'Smiling', 'Pale_Skin']:\n",
" print(attr, attrs[i][descriptions.index(attr)])\n",
" print('black', races[0][i])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# check features\n",
"print(descriptions)\n",
"attr = 'Goatee'\n",
"inds = [i for i in range(1000) if attrs[i][descriptions.index(attr)]==1]\n",
"imshow_group([data[i] for i in inds[8:16]], 8)"
"# output images\n",
"outputs = output_debiased_imgs(biased_net=biased_net,\n",
" debiased_net=debiased_net,\n",
" loader=testloader,\n",
" protected_attr = 'Black',\n",
" prediction_attr = 'Smiling')\n",
"imgs = [output[0] for output in outputs]"
]
},
{
Expand All @@ -303,81 +164,9 @@
"metadata": {},
"outputs": [],
"source": [
"# check races\n",
"for race in range(1):\n",
" inds = [i for i in range(20000) if races[race][i]>.6]\n",
" print(len(inds))\n",
" k = 0\n",
" print(inds[8*k:8*(k+1)])\n",
" imshow_group([data[i] for i in inds[8*k:8*(k+1)]], 8)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": false
},
"outputs": [],
"source": [
"# check races\n",
"for race in [0,2]:\n",
" print('Attractive')\n",
" inds = [i for i in range(20000) if races[race][i]>.8 and attrs[i][descriptions.index('Attractive')]==1]\n",
" imshow_group([data[i] for i in inds[0:8]], 8)\n",
" plt.show()\n",
" print('Unattractive')\n",
" inds = [i for i in range(10000) if races[race][i]>.8 and attrs[i][descriptions.index('Attractive')]==-1]\n",
" imshow_group([data[i] for i in inds[0:8]], 8)\n",
" plt.show()\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"inds = [i for i in range(1000) if (attrs[i][descriptions.index('Male')]==1 and attrs[i][descriptions.index('Attractive')]==1)]\n",
"imshow_group([data[i] for i in inds[0:8]], 8)\n",
"inds = [i for i in range(1000) if (attrs[i][descriptions.index('Male')]==-1) and (attrs[i][descriptions.index('Attractive')]==1)]\n",
"imshow_group([data[i] for i in inds[0:8]], 8)\n",
"inds = [i for i in range(1000) if (attrs[i][descriptions.index('Male')]==1 and attrs[i][descriptions.index('Attractive')]==-1)]\n",
"imshow_group([data[i] for i in inds[0:8]], 8)\n",
"inds = [i for i in range(1000) if (attrs[i][descriptions.index('Male')]==-1) and (attrs[i][descriptions.index('Attractive')]==-1)]\n",
"imshow_group([data[i] for i in inds[0:8]], 8)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"inds = [i for i in range(1000) if (attrs[i][descriptions.index('Attractive')]==1)]\n",
"imshow_group([data[i] for i in inds[0:8]], 8)\n",
"imshow_group([data[i] for i in inds[8:16]], 8)\n",
"imshow_group([data[i] for i in inds[16:24]], 8)\n",
"inds = [i for i in range(1000) if (attrs[i][descriptions.index('Attractive')]==-1)]\n",
"imshow_group([data[i] for i in inds[0:8]], 8)\n",
"imshow_group([data[i] for i in inds[8:16]], 8)\n",
"imshow_group([data[i] for i in inds[16:24]], 8)\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# get statistics for races\n",
"counts = [0,0,0]\n",
"for i in range(len(races[0])):\n",
" for r in range(3):\n",
" counts[r] += (races[r][i] > .501)\n",
"\n",
"counts = [c / len(races[0]) for c in counts]\n",
"print(counts)"
"# show at most 5 rows of `rowsize` images each; the floor division means a trailing partial row is not drawn\n",
"rowsize = 8\n",
"for i in range(min(len(imgs)//rowsize, 5)):\n",
" imshow_group(imgs[rowsize*i:rowsize*(i+1)], rowsize)"
]
}
],
Expand Down

0 comments on commit cbccb3a

Please sign in to comment.