Skip to content
This repository has been archived by the owner on Jul 23, 2022. It is now read-only.

Commit

Permalink
Support observer for DeepRacerEnv.
Browse files Browse the repository at this point in the history
  • Loading branch information
Chris La committed Mar 3, 2022
1 parent bc3b097 commit 6d54d30
Show file tree
Hide file tree
Showing 4 changed files with 166 additions and 9 deletions.
2 changes: 1 addition & 1 deletion VERSION
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.1.2
0.1.3
2 changes: 1 addition & 1 deletion deepracer_env/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# limitations under the License. #
#################################################################################
"""DeepRacerEnv modules"""
from .deepracer_env import DeepRacerEnv
from .deepracer_env import DeepRacerEnv, DeepRacerEnvObserverInterface

"""DeepRacer Environment Config modules"""
from deepracer_env_config import TrackDirection
Expand Down
82 changes: 79 additions & 3 deletions deepracer_env/deepracer_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
"""A class for DeepRacerEnv environment."""
from typing import Dict, Optional, List, Tuple, Any, FrozenSet, Union
import math
from threading import RLock

from gym import Space

Expand All @@ -32,6 +33,43 @@
)


class DeepRacerEnvObserverInterface(object):
"""
DeepRacerEnv Observer Interface
"""
def on_step(self, env: 'DeepRacerEnv', step_result: UDEStepResult) -> None:
"""
On Step callback.
- Called after step completed.
Args:
env (DeepRacerEnv): DeepRacer environment.
step_result (UDEStepResult): step result (obs, reward, done, last action, info)
"""
pass

def on_reset(self, env: 'DeepRacerEnv', reset_result: UDEResetResult) -> None:
"""
On Reset callback.
- Called after reset completed.
Args:
env (DeepRacerEnv): DeepRacer environment.
reset_result (UDEResetResult): reset result (obs, info)
"""
pass

def on_close(self, env: 'DeepRacerEnv') -> None:
"""
On Close callback.
- Called after close completed.
Args:
env (DeepRacerEnv): DeepRacer environment.
"""
pass


class DeepRacerEnv(UDEEnvironmentInterface):
"""
DeepRacerEnv Class.
Expand Down Expand Up @@ -76,6 +114,28 @@ def __init__(self,
area_config = self._deepracer_config.get_area()
self._track_names = area_config.track_names
self._shell_names = area_config.shell_names
self._observer_lock = RLock()
self._observers = set()

def register(self, observer: DeepRacerEnvObserverInterface) -> None:
"""
Register given observer.
Args:
observer (DeepRacerEnvObserverInterface): observer
"""
with self._observer_lock:
self._observers.add(observer)

def unregister(self, observer: DeepRacerEnvObserverInterface) -> None:
"""
Unregister given observer.
Args:
observer (DeepRacerEnvObserverInterface): observer to discard
"""
with self._observer_lock:
self._observers.discard(observer)

def step(self, action_dict: MultiAgentDict) -> UDEStepResult:
"""
Expand All @@ -101,7 +161,13 @@ def step(self, action_dict: MultiAgentDict) -> UDEStepResult:
math.isnan(speed) or math.isinf(speed):
raise ValueError("Agent's action value cannot contain nan or inf: {{}: {}}".format(agent_id,
action))
return self._env.step(action_dict=action_dict)
step_result = self._env.step(action_dict=action_dict)

with self._observer_lock:
observers = self._observers.copy()
for observer in observers:
observer.on_step(env=self, step_result=step_result)
return step_result

def reset(self) -> UDEResetResult:
"""
Expand All @@ -111,13 +177,23 @@ def reset(self) -> UDEResetResult:
Returns:
UDEResetResult: first observation and info in new episode.
"""
return self._env.reset()
reset_result = self._env.reset()

with self._observer_lock:
observers = self._observers.copy()
for observer in observers:
observer.on_reset(env=self, reset_result=reset_result)
return reset_result

def close(self) -> None:
"""
Close the environment, and environment will be no longer available to be used.
"""
return self._env.close()
self._env.close()
with self._observer_lock:
observers = self._observers.copy()
for observer in observers:
observer.on_close(env=self)

@property
def observation_space(self) -> Dict[AgentID, Space]:
Expand Down
89 changes: 85 additions & 4 deletions test/deepracer_env/test_deepracer_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,26 @@

import math

from deepracer_env import DeepRacerEnv
from ude import Compression

from deepracer_env import DeepRacerEnv, DeepRacerEnvObserverInterface
from ude import Compression, UDEResetResult, UDEStepResult

myself: Callable[[], Any] = lambda: inspect.stack()[1][3]


class DummyObserver(DeepRacerEnvObserverInterface):
def __init__(self):
self.mock = MagicMock()

def on_step(self, env: 'DeepRacerEnv', step_result: UDEStepResult) -> None:
self.mock.on_step(env=env, step_result=step_result)

def on_reset(self, env: 'DeepRacerEnv', reset_result: UDEResetResult) -> None:
self.mock.on_reset(env=env, reset_result=reset_result)

def on_close(self, env: 'DeepRacerEnv') -> None:
self.mock.on_close(env=env)


@patch("deepracer_env.deepracer_env.Client")
@patch("deepracer_env.deepracer_env.UDEEnvironment")
@patch("deepracer_env.deepracer_env.RemoteEnvironmentAdapter")
Expand Down Expand Up @@ -94,6 +107,31 @@ def test_initialize_with_param(self,
assert env.track_names == deepracer_config_mock.return_value.get_area.return_value.track_names
assert env.shell_names == deepracer_config_mock.return_value.get_area.return_value.shell_names

def test_register(self,
remote_env_adapter_mock,
ude_env_mock,
deepracer_config_mock):
address = "test_ip"
env = DeepRacerEnv(address=address)
observer_mock = DummyObserver()
env.register(observer=observer_mock)

assert observer_mock in env._observers

def test_unregister(self,
remote_env_adapter_mock,
ude_env_mock,
deepracer_config_mock):
address = "test_ip"
env = DeepRacerEnv(address=address)
observer_mock = DummyObserver()
env.register(observer=observer_mock)

assert observer_mock in env._observers

env.unregister(observer=observer_mock)
assert observer_mock not in env._observers

def test_step(self,
remote_env_adapter_mock,
ude_env_mock,
Expand Down Expand Up @@ -136,7 +174,6 @@ def test_step_ignore_more_than_two_value(self,
ude_env_mock.return_value.step.assert_called_once_with(action_dict=expected_action_dict)
assert step_result == ude_env_mock.return_value.step.return_value


def test_step_nan_or_inf(self,
remote_env_adapter_mock,
ude_env_mock,
Expand Down Expand Up @@ -168,6 +205,23 @@ def test_step_nan_or_inf(self,
with self.assertRaises(ValueError):
_ = env.step(action_dict=action_dict)

def test_step_with_observer(self,
remote_env_adapter_mock,
ude_env_mock,
deepracer_config_mock):
address = "test_ip"

env = DeepRacerEnv(address=address)
observer_mock = DummyObserver()
env.register(observer=observer_mock)

action_dict = {"agent1": (1.0, 2.0)}

step_result = env.step(action_dict=action_dict)
ude_env_mock.return_value.step.assert_called_once_with(action_dict=action_dict)
assert step_result == ude_env_mock.return_value.step.return_value
observer_mock.mock.on_step.assert_called_once_with(env=env,
step_result=step_result)

def test_reset(self,
remote_env_adapter_mock,
Expand All @@ -179,6 +233,21 @@ def test_reset(self,
ude_env_mock.return_value.reset.assert_called_once()
assert reset_result == ude_env_mock.return_value.reset.return_value

def test_reset_with_observer(self,
remote_env_adapter_mock,
ude_env_mock,
deepracer_config_mock):
address = "test_ip"
env = DeepRacerEnv(address=address)
observer_mock = DummyObserver()
env.register(observer=observer_mock)

reset_result = env.reset()
ude_env_mock.return_value.reset.assert_called_once()
assert reset_result == ude_env_mock.return_value.reset.return_value
observer_mock.mock.on_reset.assert_called_once_with(env=env,
reset_result=reset_result)

def test_close(self,
remote_env_adapter_mock,
ude_env_mock,
Expand All @@ -188,6 +257,18 @@ def test_close(self,
env.close()
ude_env_mock.return_value.close.assert_called_once()

def test_close_with_observer(self,
remote_env_adapter_mock,
ude_env_mock,
deepracer_config_mock):
address = "test_ip"
env = DeepRacerEnv(address=address)
observer_mock = DummyObserver()
env.register(observer=observer_mock)
env.close()
ude_env_mock.return_value.close.assert_called_once()
observer_mock.mock.on_close.assert_called_once_with(env=env)

def test_observation_space(self,
remote_env_adapter_mock,
ude_env_mock,
Expand Down

0 comments on commit 6d54d30

Please sign in to comment.