-
Notifications
You must be signed in to change notification settings - Fork 9
Expand file tree
/
Copy pathtwo_layer_perceptron_pytorch.py
More file actions
60 lines (53 loc) · 1.89 KB
/
two_layer_perceptron_pytorch.py
File metadata and controls
60 lines (53 loc) · 1.89 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import torch
from torch.autograd import Variable
import torch.nn as nn
import numpy as np
def two_layer_perceptron(
        number_of_features, number_of_hidden, number_of_classes):
    """Build a two-layer perceptron with a softmax output.

    The network is ``Linear -> ReLU -> Linear -> Softmax(dim=1)``, so it
    expects batched input (assumed shape ``(batch, number_of_features)``;
    the softmax is taken over dimension 1).

    Args:
        number_of_features: Size of each input sample.
        number_of_hidden: Number of units in the hidden layer.
        number_of_classes: Number of outputs of the final layer.

    Returns:
        The ``nn.Sequential`` network, moved to the GPU when CUDA is
        available and left on the CPU otherwise.
    """
    network = nn.Sequential(
        nn.Linear(number_of_features, number_of_hidden),
        nn.ReLU(),
        nn.Linear(number_of_hidden, number_of_classes),
        nn.Softmax(dim=1))
    # An unconditional .cuda() raises on machines without a usable GPU;
    # fall back to the CPU so the model can still be built there.
    if torch.cuda.is_available():
        return network.cuda()
    return network
def cost(points, labels, network, criterion):
    """Evaluate the loss of ``network`` on a labelled data set.

    Args:
        points: Input samples; anything ``torch.FloatTensor`` accepts.
        labels: Targets for ``points``; converted to a float tensor, so
            ``criterion`` must accept floating-point targets.
        network: The ``nn.Module`` to evaluate.
        criterion: Callable ``criterion(predictions, labels)`` returning
            the loss tensor.

    Returns:
        The loss as a NumPy array on the CPU.
    """
    # Put the data wherever the network lives instead of assuming CUDA.
    first_parameter = next(network.parameters(), None)
    if first_parameter is not None:
        device = first_parameter.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    points = torch.FloatTensor(points).to(device)
    labels = torch.FloatTensor(labels).to(device)
    # Pure evaluation: no autograd graph is needed for the returned value.
    with torch.no_grad():
        # Call the module (not .forward) so registered hooks are honoured.
        predictions = network(points)
        loss = criterion(predictions, labels)
    return loss.cpu().detach().numpy()
def initialize(number_of_features, number_of_hidden, number_of_classes):
    """Create a fresh network for the given layer sizes.

    Delegates to ``two_layer_perceptron`` with the input, hidden and
    output sizes, in that order, and returns its result unchanged.
    """
    layer_sizes = (number_of_features, number_of_hidden, number_of_classes)
    return two_layer_perceptron(*layer_sizes)
def step(points, labels, network, criterion, batch_size, learning_rate):
    """Run one pass of mini-batch gradient descent over the data.

    The samples are visited in order in consecutive slices of
    ``batch_size`` (the last slice may be shorter). For each slice the
    loss is back-propagated and every parameter is moved by
    ``-learning_rate * gradient``. The network is updated in place.

    Args:
        points: Sliceable sequence of input samples.
        labels: Sliceable sequence of targets aligned with ``points``;
            converted to a float tensor.
        network: The ``nn.Module`` being trained.
        criterion: Callable ``criterion(predictions, labels)`` returning
            a scalar loss tensor.
        batch_size: Positive number of samples per mini-batch.
        learning_rate: Step size of the gradient-descent update.

    Raises:
        ValueError: If ``batch_size`` is not positive (the batch loop
            would otherwise never terminate).
    """
    if batch_size <= 0:
        raise ValueError("batch_size must be a positive integer")

    parameters = list(network.parameters())
    # Put the data wherever the network lives instead of assuming CUDA.
    if parameters:
        device = parameters[0].device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    number_of_samples = len(points)
    for start in range(0, number_of_samples, batch_size):
        end = min(start + batch_size, number_of_samples)
        batch_points = torch.FloatTensor(points[start:end]).to(device)
        batch_labels = torch.FloatTensor(labels[start:end]).to(device)

        # backward() accumulates into .grad, so stale gradients from the
        # previous batch must be cleared before computing new ones.
        network.zero_grad()
        # Call the module (not .forward) so registered hooks are honoured.
        predictions = network(batch_points)
        loss = criterion(predictions, batch_labels)
        loss.backward()

        # The update itself must not be recorded by autograd.
        with torch.no_grad():
            for parameter in parameters:
                if parameter.grad is not None:
                    parameter -= learning_rate * parameter.grad
def train(points, labels, network, criterion, batch_size, learning_rate,
          number_of_epochs=1000):
    """Train ``network`` by repeatedly calling ``step`` on the data.

    Args:
        points: Input samples, forwarded to ``step``.
        labels: Targets aligned with ``points``, forwarded to ``step``.
        network: The network being trained; forwarded to ``step``.
        criterion: Loss callable, forwarded to ``step``.
        batch_size: Mini-batch size, forwarded to ``step``.
        learning_rate: Gradient-descent step size, forwarded to ``step``.
        number_of_epochs: How many times ``step`` is invoked. Defaults to
            1000, the previously hard-coded value.
    """
    for _ in range(number_of_epochs):
        step(points, labels, network, criterion, batch_size, learning_rate)
def classify(point, network):
    """Classify a single point with a two-output network.

    The point is wrapped into a batch of one, passed through ``network``
    and the first two outputs are compared.

    Args:
        point: One input sample; anything ``torch.FloatTensor`` accepts.
        network: The ``nn.Module`` producing at least two outputs per
            sample.

    Returns:
        ``0`` if the first output is strictly greater than the second,
        otherwise ``1`` (ties therefore yield ``1``).
    """
    # Put the data wherever the network lives instead of assuming CUDA.
    first_parameter = next(network.parameters(), None)
    if first_parameter is not None:
        device = first_parameter.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    points = torch.FloatTensor(point).unsqueeze(0).to(device)
    # Inference only: skip building an autograd graph.
    with torch.no_grad():
        # Call the module (not .forward) so registered hooks are honoured.
        predictions = network(points)
    if predictions[0][0] > predictions[0][1]:
        return 0
    return 1
def all_labels(labels):
    """Tell whether both classes occur among ``labels``.

    Each label is a pair-like object; it counts as "red" when its first
    entry is strictly greater than its second and as "blue" otherwise.

    Returns:
        ``True`` only if at least one red and at least one blue label
        are present, ``False`` otherwise (including for empty input).
    """
    # One pass over every label, recording which side of the comparison
    # has been observed.
    observed = {bool(label[0] > label[1]) for label in labels}
    return len(observed) == 2