JRL PPO test with delayed identity env (#355)
* add a custom delay to identity_env
* min reward 0.8 in delayed identity test
* seed the tests, perfect score on delayed_identity_test
* delay=1 in delayed_identity_test
* flake8 complaints
* increased number of steps in fixed_seq_test
* seed identity tests to ensure reproducibility
* docstrings
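The seeding bullets are load-bearing: IdentityEnv has a single source of randomness, the action space from which both actions and target states are sampled, so seeding it makes a test run fully deterministic. A minimal sketch (assuming baselines is importable; not part of this commit):

    from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv

    env_a = DiscreteIdentityEnv(10, episode_len=50)
    env_b = DiscreteIdentityEnv(10, episode_len=50)
    env_a.seed(0)   # seeds the action space, the env's only RNG
    env_b.seed(0)
    assert env_a.reset() == env_b.reset()  # identical seeded sampling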
baselines/common/tests/envs/identity_env.py
@@ -2,43 +2,45 @@ import numpy as np
 from abc import abstractmethod
 from gym import Env
 from gym.spaces import MultiDiscrete, Discrete, Box
+from collections import deque


 class IdentityEnv(Env):
     def __init__(
             self,
-            episode_len=None
+            episode_len=None,
+            delay=0,
+            zero_first_rewards=True
     ):

         self.observation_space = self.action_space
         self.episode_len = episode_len
         self.time = 0
-        self.reset()
+        self.delay = delay
+        self.zero_first_rewards = zero_first_rewards
+        self.q = deque(maxlen=delay+1)

     def reset(self):
-        self._choose_next_state()
+        self.q.clear()
+        for _ in range(self.delay + 1):
+            self.q.append(self.action_space.sample())
         self.time = 0
-        return self.state
+        return self.q[-1]

     def step(self, actions):
-        rew = self._get_reward(actions)
-        self._choose_next_state()
-        done = False
-        if self.episode_len and self.time >= self.episode_len:
-            done = True
-        return self.state, rew, done, {}
+        rew = self._get_reward(self.q.popleft(), actions)
+        if self.zero_first_rewards and self.time < self.delay:
+            rew = 0
+        self.q.append(self.action_space.sample())
+        self.time += 1
+        done = self.episode_len is not None and self.time >= self.episode_len
+        return self.q[-1], rew, done, {}

     def seed(self, seed=None):
         self.action_space.seed(seed)

-    def _choose_next_state(self):
-        self.state = self.action_space.sample()
-        self.time += 1
-
     @abstractmethod
-    def _get_reward(self, actions):
+    def _get_reward(self, state, actions):
         raise NotImplementedError
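The deque is the whole delay mechanism: q holds the last delay+1 sampled targets, q[-1] is the observation just returned, and q.popleft() in step() is the target shown delay steps earlier, so each reward scores the current action against a stale observation. The first delay rewards would reference targets drawn during reset() that were never observed, which is why they are zeroed while zero_first_rewards is set. A sketch of the resulting timeline (assumes baselines is importable; not part of the diff):

    from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv

    env = DiscreteIdentityEnv(10, episode_len=5, delay=1)
    env.seed(0)
    ob = env.reset()                                   # target of the second step
    next_ob, rew, done, _ = env.step(env.action_space.sample())
    assert rew == 0                                    # time < delay: reward zeroed
    prev_ob, ob = ob, next_ob
    while not done:
        next_ob, rew, done, _ = env.step(prev_ob)      # echo the ob from one step back
        assert rew == 1                                # delayed match scores
        prev_ob, ob = ob, next_ob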
@@ -47,26 +49,29 @@ class DiscreteIdentityEnv(IdentityEnv):
             self,
             dim,
             episode_len=None,
+            delay=0,
+            zero_first_rewards=True
     ):

         self.action_space = Discrete(dim)
-        super().__init__(episode_len=episode_len)
+        super().__init__(episode_len=episode_len, delay=delay, zero_first_rewards=zero_first_rewards)

-    def _get_reward(self, actions):
-        return 1 if self.state == actions else 0
+    def _get_reward(self, state, actions):
+        return 1 if state == actions else 0


 class MultiDiscreteIdentityEnv(IdentityEnv):
     def __init__(
             self,
             dims,
             episode_len=None,
+            delay=0,
     ):

         self.action_space = MultiDiscrete(dims)
-        super().__init__(episode_len=episode_len)
+        super().__init__(episode_len=episode_len, delay=delay)

-    def _get_reward(self, actions):
-        return 1 if all(self.state == actions) else 0
+    def _get_reward(self, state, actions):
+        return 1 if all(state == actions) else 0


 class BoxIdentityEnv(IdentityEnv):
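One detail worth noting in the MultiDiscrete variant: state and actions are numpy arrays, so state == actions is an elementwise comparison and all(...) pays out only when every component matches. A standalone illustration (plain numpy, not part of the diff):

    import numpy as np

    state = np.array([2, 0, 1])
    assert all(state == np.array([2, 0, 1]))       # every component matches: reward 1
    assert not all(state == np.array([2, 0, 2]))   # any mismatch: reward 0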
@@ -79,7 +84,7 @@ class BoxIdentityEnv(IdentityEnv):
         self.action_space = Box(low=-1.0, high=1.0, shape=shape, dtype=np.float32)
         super().__init__(episode_len=episode_len)

-    def _get_reward(self, actions):
-        diff = actions - self.state
+    def _get_reward(self, state, actions):
+        diff = actions - state
         diff = diff[:]
         return -0.5 * np.dot(diff, diff)
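BoxIdentityEnv keeps its quadratic penalty, r = -0.5 * ||actions - state||^2, which peaks at 0 when the action reproduces the target exactly; only the reward's signature changes here. A standalone numeric check (plain numpy, not part of the diff):

    import numpy as np

    state = np.array([0.2, -0.5], dtype=np.float32)
    actions = np.array([0.1, -0.3], dtype=np.float32)
    diff = actions - state                    # [-0.1, 0.2]
    assert np.isclose(-0.5 * np.dot(diff, diff), -0.025)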
baselines/common/tests/envs/identity_env_test.py (new file, 36 lines)
@@ -0,0 +1,36 @@
+from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv
+
+
+def test_discrete_nodelay():
+    nsteps = 100
+    eplen = 50
+    env = DiscreteIdentityEnv(10, episode_len=eplen)
+    ob = env.reset()
+    for t in range(nsteps):
+        action = env.action_space.sample()
+        next_ob, rew, done, info = env.step(action)
+        assert rew == (1 if action == ob else 0)
+        if (t + 1) % eplen == 0:
+            assert done
+            next_ob = env.reset()
+        else:
+            assert not done
+        ob = next_ob
+
+def test_discrete_delay1():
+    eplen = 50
+    env = DiscreteIdentityEnv(10, episode_len=eplen, delay=1)
+    ob = env.reset()
+    prev_ob = None
+    for t in range(eplen):
+        action = env.action_space.sample()
+        next_ob, rew, done, info = env.step(action)
+        if t > 0:
+            assert rew == (1 if action == prev_ob else 0)
+        else:
+            assert rew == 0
+        prev_ob = ob
+        ob = next_ob
+        if t < eplen - 1:
+            assert not done
+    assert done