JRL PPO test with delayed identity env (#355)

* add a custom delay to identity_env

* min reward 0.8 in delayed identity test

* seed the tests, perfect score on delayed_identity_test

* delay=1 in delayed_identity_test

* flake8 complaints

* increased number of steps in fixed_seq_test

* seed identity tests to ensure reproducibility

* docstrings
This commit is contained in:
pzhokhov
2019-04-24 17:04:36 -07:00
committed by Peter Zhokhov
parent 07536451ee
commit 1fa6ac38f1
2 changed files with 66 additions and 25 deletions

View File

@@ -2,43 +2,45 @@ import numpy as np
from abc import abstractmethod
from gym import Env
from gym.spaces import MultiDiscrete, Discrete, Box
from collections import deque
class IdentityEnv(Env):
    """Abstract identity environment.

    The observation at each step is a sample from the action space, and the
    agent is rewarded for echoing back the state it observed ``delay`` steps
    ago. Subclasses must set ``self.action_space`` before calling
    ``__init__`` and must implement ``_get_reward``.
    """

    def __init__(
            self,
            episode_len=None,
            delay=0,
            zero_first_rewards=True
    ):
        """
        :param episode_len: steps per episode; None means the episode never
            terminates
        :param delay: number of steps between observing a state and being
            rewarded for matching it
        :param zero_first_rewards: if True, the first ``delay`` rewards of an
            episode are forced to 0 (no valid target has been observed yet)
        """
        # Observations are drawn from the same space the agent acts in.
        self.observation_space = self.action_space
        self.episode_len = episode_len
        self.time = 0
        self.delay = delay
        self.zero_first_rewards = zero_first_rewards
        # Queue of pending target states; it holds delay+1 entries so that
        # popleft() yields the state that was shown `delay` steps ago.
        self.q = deque(maxlen=delay + 1)

    def reset(self):
        # Refill the target queue; the newest entry is the first observation.
        self.q.clear()
        for _ in range(self.delay + 1):
            self.q.append(self.action_space.sample())
        self.time = 0
        return self.q[-1]

    def step(self, actions):
        # Reward is computed against the state observed `delay` steps ago.
        rew = self._get_reward(self.q.popleft(), actions)
        if self.zero_first_rewards and self.time < self.delay:
            rew = 0
        self.q.append(self.action_space.sample())
        self.time += 1
        done = self.episode_len is not None and self.time >= self.episode_len
        return self.q[-1], rew, done, {}

    def seed(self, seed=None):
        # Seeding the action space makes the sampled target states
        # reproducible.
        self.action_space.seed(seed)

    @abstractmethod
    def _get_reward(self, state, actions):
        """Return the reward for taking ``actions`` given target ``state``."""
        raise NotImplementedError
class DiscreteIdentityEnv(IdentityEnv):
    """Identity env over a ``Discrete(dim)`` space: reward is 1 when the
    action exactly matches the (possibly delayed) target state, else 0."""

    def __init__(
            self,
            dim,
            episode_len=None,
            delay=0,
            zero_first_rewards=True
    ):
        """
        :param dim: number of discrete actions/observations
        :param episode_len: steps per episode (None = never done)
        :param delay: reward delay, see IdentityEnv
        :param zero_first_rewards: zero out the first ``delay`` rewards
        """
        # action_space must be set before super().__init__, which reads it
        # to define observation_space.
        self.action_space = Discrete(dim)
        super().__init__(episode_len=episode_len, delay=delay,
                         zero_first_rewards=zero_first_rewards)

    def _get_reward(self, state, actions):
        # Exact match => 1, otherwise 0.
        return 1 if state == actions else 0
class MultiDiscreteIdentityEnv(IdentityEnv):
    """Identity env over a ``MultiDiscrete`` space: reward is 1 only when
    every component of the action matches the target state."""

    def __init__(
            self,
            dims,
            episode_len=None,
            delay=0,
    ):
        """
        :param dims: per-component sizes of the MultiDiscrete space
        :param episode_len: steps per episode (None = never done)
        :param delay: reward delay, see IdentityEnv
        """
        # action_space must be set before super().__init__, which reads it
        # to define observation_space.
        self.action_space = MultiDiscrete(dims)
        super().__init__(episode_len=episode_len, delay=delay)

    def _get_reward(self, state, actions):
        # Every component must match for a reward of 1.
        return 1 if all(state == actions) else 0
class BoxIdentityEnv(IdentityEnv): class BoxIdentityEnv(IdentityEnv):
@@ -79,7 +84,7 @@ class BoxIdentityEnv(IdentityEnv):
self.action_space = Box(low=-1.0, high=1.0, shape=shape, dtype=np.float32) self.action_space = Box(low=-1.0, high=1.0, shape=shape, dtype=np.float32)
super().__init__(episode_len=episode_len) super().__init__(episode_len=episode_len)
def _get_reward(self, actions): def _get_reward(self, state, actions):
diff = actions - self.state diff = actions - state
diff = diff[:] diff = diff[:]
return -0.5 * np.dot(diff, diff) return -0.5 * np.dot(diff, diff)

View File

@@ -0,0 +1,36 @@
from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv
def test_discrete_nodelay():
    """With delay=0 the reward must immediately reflect whether the action
    matched the current observation, and done fires every `eplen` steps."""
    nsteps = 100
    eplen = 50
    env = DiscreteIdentityEnv(10, episode_len=eplen)
    ob = env.reset()
    for t in range(nsteps):
        action = env.action_space.sample()
        next_ob, rew, done, info = env.step(action)
        # Undelayed identity reward: 1 iff action equals the observation.
        assert rew == (1 if action == ob else 0)
        if (t + 1) % eplen == 0:
            assert done
            next_ob = env.reset()
        else:
            assert not done
        ob = next_ob
def test_discrete_delay1():
    """With delay=1 the reward at step t must reflect the observation from
    step t-1, and the very first reward of the episode must be zero."""
    eplen = 50
    env = DiscreteIdentityEnv(10, episode_len=eplen, delay=1)
    ob = env.reset()
    prev_ob = None
    for t in range(eplen):
        action = env.action_space.sample()
        next_ob, rew, done, info = env.step(action)
        if t > 0:
            # Delayed reward: compares the action to the previous observation.
            assert rew == (1 if action == prev_ob else 0)
        else:
            # zero_first_rewards=True zeroes the first `delay` rewards.
            assert rew == 0
        prev_ob = ob
        ob = next_ob
        if t < eplen - 1:
            assert not done
    assert done