Add CartPole environment (#81)
mthrok authored Dec 31, 2016
1 parent 3f205f2 commit 34899c3
Showing 11 changed files with 585 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .pylintrc
@@ -336,7 +336,7 @@ exclude-protected=_asdict,_fields,_replace,_source,_make
[DESIGN]

# Maximum number of arguments for function / method
-max-args=5
+max-args=10

# Argument names that match this expression will be ignored. Default to name
# with leading underscore
2 changes: 2 additions & 0 deletions example/CartPole_agent.yml
@@ -0,0 +1,2 @@
name: CartPoleAgent
args: {}
7 changes: 7 additions & 0 deletions example/CartPole_env.yml
@@ -0,0 +1,7 @@
name: CartPole
args:
  angle_limit: 12  # Degree
  distance_limit: 2.4  # meter
  pole_length: 0.5
  cart_mass: 1.0
  display_screen: True
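The two example configs above name the environment and agent classes and list their constructor arguments. As a minimal sketch (not part of this commit), and assuming the CartPole constructor accepts the listed args as keyword arguments, the environment config could be turned into an instance like this:

import yaml

from luchador.env.cart_pole import CartPole

with open('example/CartPole_env.yml') as f:
    config = yaml.safe_load(f)

# Build the environment from the declared args
# (angle_limit=12, distance_limit=2.4, pole_length=0.5, ...).
env = CartPole(**config['args'])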
1 change: 1 addition & 0 deletions luchador/agent/__init__.py
@@ -2,3 +2,4 @@

from .base import * # noqa: F401, F403
from .dqn import * # noqa: F401, F403
+from .cart_pole import * # noqa: F401, F403
113 changes: 113 additions & 0 deletions luchador/agent/cart_pole.py
@@ -0,0 +1,113 @@
from __future__ import division
from __future__ import absolute_import

import numpy as np

from . base import BaseAgent

_1_DEG = 0.0174532
_6_DEG = 6 * _1_DEG
_12_DEG = 12 * _1_DEG
_15_DEG = 15 * _1_DEG

N_BOX = 162


def _get_box(x, x_dot, theta, theta_dot):
    if abs(x) > 2.4 or abs(theta) > _12_DEG:
        return -1

    box = 0
    if x < -0.8:
        pass
    elif x < 0.8:
        box = 1
    else:
        box = 2

    if x_dot < -0.5:
        pass
    elif x_dot < 0.5:
        box += 3
    else:
        box += 6

    if theta < -_6_DEG:
        pass
    elif theta < -_1_DEG:
        box += 9
    elif theta < 0:
        box += 18
    elif theta < _1_DEG:
        box += 27
    elif theta < _6_DEG:
        box += 36
    else:
        box += 45

    if theta_dot < -_15_DEG:
        pass
    elif theta_dot < _15_DEG:
        box += 54
    else:
        box += 108

    return box


def _truncated_sigmoid(s):
    return 1. / (1. + np.exp(-max(-50., min(s, 50.))))


class CartPoleAgent(BaseAgent):
    def __init__(self,
                 action_lr=1000,
                 critic_lr=0.5,
                 critic_discount=0.95,
                 action_decay=0.9,
                 critic_decay=0.8):

        self.action_lr = action_lr
        self.critic_lr = critic_lr
        self.critic_discount = critic_discount
        self.action_decay = action_decay
        self.critic_decay = critic_decay

        self.action_weight = np.zeros((N_BOX,))
        self.critic_weight = np.zeros((N_BOX,))
        self.action_eligibility = np.zeros((N_BOX,))
        self.critic_eligibility = np.zeros((N_BOX,))

        self.box = 0

    def init(self, env):
        pass

    def reset(self, observation):
        self.box = _get_box(**observation)
        self.action_eligibility = np.zeros((N_BOX,))
        self.critic_eligibility = np.zeros((N_BOX,))

    def observe(self, action, outcome):
        p_prev = self.critic_weight[self.box]
        self.box = _get_box(**outcome.observation)
        p_current = 0.0 if outcome.terminal else self.critic_weight[self.box]

        r_hat = outcome.reward + self.critic_discount * p_current - p_prev

        self.action_weight += self.action_lr * r_hat * self.action_eligibility
        self.critic_weight += self.critic_lr * r_hat * self.critic_eligibility

        self.action_eligibility *= self.action_decay
        self.critic_eligibility *= self.critic_decay

    def act(self, observation):
        prob = _truncated_sigmoid(self.action_weight[self.box])
        action = int(np.random.rand() < prob)
        update = action - 0.5
        self.action_eligibility[self.box] += (1.0 - self.action_decay) * update
        self.critic_eligibility[self.box] += (1.0 - self.critic_decay)
        return action

    def perform_post_episode_task(self, stats):
        pass
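The agent above is a tabular actor-critic with eligibility traces in the style of the classic "boxes" cart-pole controller: _get_box discretizes the four-dimensional state into one of 162 boxes, the critic keeps a value estimate per box, and the actor keeps a push-left/push-right preference per box, both updated from the TD error r_hat = reward + critic_discount * p_current - p_prev. A small check of the discretization (not part of the commit), using the module path introduced above:

from luchador.agent.cart_pole import _get_box

# Cart slightly right of center, at rest, pole tilted ~0.57 degrees (0.01 rad).
box = _get_box(x=0.1, x_dot=0.0, theta=0.01, theta_dot=0.0)
print(box)  # 1 (x) + 3 (x_dot) + 27 (theta) + 54 (theta_dot) = 85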
10 changes: 6 additions & 4 deletions luchador/command/exercise.py
@@ -24,23 +24,25 @@ def _main(env, agent, episodes, steps, report_every=1000):
    runner = EpisodeRunner(env, agent, max_steps=steps)

    _LG.info('Running %s episodes', episodes)
-   n_ep, time_, steps_, rewards = 0, 0, 0, 0.0
+   n_ep, time_, steps_, rewards_ = 0, 0, 0, 0.0
    for i in range(1, episodes+1):
        stats = runner.run_episode()

        n_ep += 1
        time_ += stats['time']
        steps_ += stats['steps']
-       rewards += stats['rewards']
+       rewards_ += stats['rewards']
        if i % report_every == 0 or i == episodes:
            _LG.info('Finished episode: %d', i)
-           _LG.info(' Rewards: %12.3f [/epi]', rewards / n_ep)
+           _LG.info(' Rewards: %12.3f', rewards_)
+           _LG.info(' %12.3f [/epi]', rewards_ / n_ep)
+           _LG.info(' %12.3f [/steps]', rewards_ / steps_)
            _LG.info(' Steps: %8d', steps_)
            _LG.info(' %12.3f [/epi]', steps_ / n_ep)
            _LG.info(' %12.3f [/sec]', steps_ / time_)
            _LG.info(' Total Steps: %8d', runner.steps)
            _LG.info(' Total Time: %s', _format_time(runner.time))
-           n_ep, time_, steps_, rewards = 0, 0, 0, 0.0
+           n_ep, time_, steps_, rewards_ = 0, 0, 0, 0.
    _LG.info('Done')


1 change: 1 addition & 0 deletions luchador/env/base.py
@@ -122,6 +122,7 @@ def step(self, action):
_ENVIRONMENT_MODULE_MAPPING = {
    'ALEEnvironment': 'ale',
    'FlappyBird': 'flappy_bird',
+   'CartPole': 'cart_pole',
}
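The mapping above ties an environment's class name to the module that defines it. As a rough sketch (not luchador's actual loader), the kind of dynamic lookup such a mapping typically supports, shown only to illustrate why the new 'CartPole': 'cart_pole' entry is needed:

import importlib

from luchador.env.base import _ENVIRONMENT_MODULE_MAPPING

def _load_env_class(name):
    # Import luchador.env.<module> and pull the class of the given name from it.
    module_name = 'luchador.env.{}'.format(_ENVIRONMENT_MODULE_MAPPING[name])
    return getattr(importlib.import_module(module_name), name)

CartPole = _load_env_class('CartPole')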


Expand Down
3 changes: 3 additions & 0 deletions luchador/env/cart_pole/__init__.py
@@ -0,0 +1,3 @@
from __future__ import absolute_import

from .cart_pole import CartPole # noqa: F401
(Diffs for the remaining changed files were not loaded.)
