# CartPole-v1 solved with DQN, following the reinforcement learning tutorial on the official PyTorch website
import math
import random
import time
from collections import namedtuple, deque

import gymnasium as gym
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Create the CartPole environment; render_mode="human" opens a window and renders every step
env = gym.make("CartPole-v1", render_mode="human")
BATCH_SIZE = 128   # number of transitions sampled from the replay buffer per optimization step
GAMMA = 0.99       # discount factor
EPS_START = 0.9    # starting value of epsilon
EPS_END = 0.05     # final value of epsilon
EPS_DECAY = 1000   # decay rate of epsilon; higher means a slower decay
TAU = 0.005        # soft update rate of the target network
LR = 1e-4          # learning rate of the optimizer
SHOW_EVERY = 100   # print progress every SHOW_EVERY episodes
# A single transition in the environment, mapping (state, action) to (next_state, reward)
Transition = namedtuple('Transition', ('state', 'action', 'next_state', 'reward'))

class ReplayMemory(object):
    """A bounded buffer of past transitions, sampled uniformly at random for training."""

    def __init__(self, capacity):
        self.memory = deque([], maxlen=capacity)

    def push(self, *args):
        """Save a transition."""
        self.memory.append(Transition(*args))

    def sample(self, batch_size):
        return random.sample(self.memory, batch_size)

    def __len__(self):
        return len(self.memory)
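
# The Q-network: a small multilayer perceptron mapping an observation vector to one
# Q-value per discrete action; the greedy action is the argmax of its output.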
class DQN(nn.Module):
    def __init__(self, n_observations, n_actions):
        super(DQN, self).__init__()
        self.layer1 = nn.Linear(n_observations, 128)
        self.layer2 = nn.Linear(128, 128)
        self.layer3 = nn.Linear(128, n_actions)

    def forward(self, x):
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        return self.layer3(x)
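
# Epsilon-greedy action selection: with probability epsilon take a random action,
# otherwise take the action with the highest predicted Q-value. Epsilon decays
# exponentially from EPS_START toward EPS_END as steps_done grows.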
def select_action(state, policy_net):
    global steps_done
    sample = random.random()
    eps_threshold = EPS_END + (EPS_START - EPS_END) * math.exp(-1. * steps_done / EPS_DECAY)
    steps_done += 1
    if sample > eps_threshold:
        with torch.no_grad():
            # policy_net(state).max(1) returns, for each row, the largest Q-value and its index;
            # the index is the action with the largest expected return, so we pick that action.
            return policy_net(state).max(1)[1].view(1, 1)
    else:
        return torch.tensor([[env.action_space.sample()]], dtype=torch.long)
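
# One DQN optimization step: sample a minibatch from the replay memory, compute
# Q(s_t, a_t) with the policy network, form the bootstrapped target
# r_t + GAMMA * max_a Q_target(s_{t+1}, a) with the target network (zero for
# terminal states), and minimize the Huber loss between prediction and target.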
def optimize_model():
    if len(memory) < BATCH_SIZE:
        return
    transitions = memory.sample(BATCH_SIZE)
    # Transpose the batch: convert a batch-array of Transitions into a Transition of batch-arrays
    batch = Transition(*zip(*transitions))
    # Compute a mask of non-final states and concatenate the batch elements
    non_final_mask = torch.tensor(tuple(map(lambda s: s is not None, batch.next_state)), dtype=torch.bool)
    non_final_next_states = torch.cat([s for s in batch.next_state if s is not None])
    state_batch = torch.cat(batch.state)
    action_batch = torch.cat(batch.action)
    reward_batch = torch.cat(batch.reward)
    # Compute Q(s_t, a): the model computes Q(s_t), then we select the columns of the actions taken
    state_action_values = policy_net(state_batch).gather(1, action_batch)
    # Compute V(s_{t+1}) for all next states with the target network; final states keep value 0
    next_state_values = torch.zeros(BATCH_SIZE)
    with torch.no_grad():
        next_state_values[non_final_mask] = target_net(non_final_next_states).max(1)[0]
    expected_state_action_values = (next_state_values * GAMMA) + reward_batch
    # Compute Huber loss
    criterion = nn.SmoothL1Loss()
    loss = criterion(state_action_values, expected_state_action_values.unsqueeze(1))
    # Optimize the policy network, clipping gradients to stabilize training
    optimizer.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_value_(policy_net.parameters(), 100)
    optimizer.step()
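
# Build the training state: policy and target networks start with identical weights,
# AdamW is used as the optimizer, and the replay memory holds up to 10000 transitions.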
n_actions = env.action_space.n       # number of discrete actions
state, info = env.reset()            # reset once to read the size of the observation vector
n_observations = len(state)
policy_net = DQN(n_observations, n_actions)
target_net = DQN(n_observations, n_actions)
target_net.load_state_dict(policy_net.state_dict())
optimizer = optim.AdamW(policy_net.parameters(), lr=LR, amsgrad=True)
memory = ReplayMemory(10000)
steps_done = 0
episode_durations = []
num_episodes = 600
score = 0
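
# Training loop: for each episode, roll out the epsilon-greedy policy, store every
# transition in the replay memory, run one optimization step per environment step,
# and softly blend the policy weights into the target network at rate TAU.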
for episode in range(num_episodes):
    display = False
    # Initialize the environment and get its state
    state, info = env.reset()
    state = torch.tensor(state, dtype=torch.float32).unsqueeze(0)
    if episode % SHOW_EVERY == 0:
        display = True
        print("Score over the last {} episodes: {} (at episode {})".format(SHOW_EVERY, score, episode))
        score = 0
    t = 0
    while True:
        if display:
            env.render()
            display = False
        action = select_action(state, policy_net)
        observation, reward, terminated, truncated, _ = env.step(action.item())
        score += reward                    # accumulate the raw float reward for logging
        reward = torch.tensor([reward])
        done = terminated or truncated
        if terminated:
            next_state = None
        else:
            next_state = torch.tensor(observation, dtype=torch.float32).unsqueeze(0)
        # Store the transition in memory and move to the next state
        memory.push(state, action, next_state, reward)
        state = next_state
        # Perform one step of optimization on the policy network
        optimize_model()
        # Soft update of the target network's weights: θ′ ← τ·θ + (1 − τ)·θ′
        target_net_state_dict = target_net.state_dict()
        policy_net_state_dict = policy_net.state_dict()
        for key in policy_net_state_dict:
            target_net_state_dict[key] = policy_net_state_dict[key] * TAU + target_net_state_dict[key] * (1 - TAU)
        target_net.load_state_dict(target_net_state_dict)
        t += 1
        if done:
            episode_durations.append(t)
            print(f"Episode {episode} finished after {t} steps")
            time.sleep(0.1)
            break
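
# Optional wrap-up: a minimal sketch that plots the episode durations collected
# above and closes the environment; it assumes matplotlib's default backend can
# open a window on this machine.
plt.figure()
plt.plot(episode_durations)
plt.xlabel("Episode")
plt.ylabel("Duration (steps)")
plt.title("DQN on CartPole-v1")
plt.show()
env.close()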