JRL PPO test with delayed identity env (#355)
* add a custom delay to identity_env
* min reward 0.8 in delayed identity test
* seed the tests, perfect score on delayed_identity_test
* delay=1 in delayed_identity_test
* flake8 complaints
* increased number of steps in fixed_seq_test
* seed identity tests to ensure reproducibility
* docstrings
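For orientation before the diff: with delay=1 the reward at each step is computed against the observation from one step earlier, and the first reward of an episode is zeroed. A minimal sketch of the behavior (not part of the commit; it uses the import path from the new test file below):

from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv

env = DiscreteIdentityEnv(10, episode_len=50, delay=1)
env.seed(0)          # seed() forwards to action_space.seed() for reproducibility
ob = env.reset()
prev_ob = None
for t in range(5):
    # an optimal agent must remember the previous observation and echo it
    action = prev_ob if prev_ob is not None else ob
    next_ob, rew, done, _ = env.step(action)
    print(t, rew)    # 0 at t=0 (zero_first_rewards), then 1 every step
    prev_ob, ob = ob, next_ob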
baselines/common/tests/envs/identity_env.py

@@ -2,43 +2,45 @@ import numpy as np
 from abc import abstractmethod
 from gym import Env
 from gym.spaces import MultiDiscrete, Discrete, Box
+from collections import deque
 
 
 class IdentityEnv(Env):
     def __init__(
             self,
-            episode_len=None
+            episode_len=None,
+            delay=0,
+            zero_first_rewards=True
     ):
 
         self.observation_space = self.action_space
         self.episode_len = episode_len
         self.time = 0
-        self.reset()
+        self.delay = delay
+        self.zero_first_rewards = zero_first_rewards
+        self.q = deque(maxlen=delay+1)
 
     def reset(self):
-        self._choose_next_state()
+        self.q.clear()
+        for _ in range(self.delay + 1):
+            self.q.append(self.action_space.sample())
         self.time = 0
 
-        return self.state
+        return self.q[-1]
 
     def step(self, actions):
-        rew = self._get_reward(actions)
-        self._choose_next_state()
-        done = False
-        if self.episode_len and self.time >= self.episode_len:
-            done = True
-
-        return self.state, rew, done, {}
+        rew = self._get_reward(self.q.popleft(), actions)
+        if self.zero_first_rewards and self.time < self.delay:
+            rew = 0
+        self.q.append(self.action_space.sample())
+        self.time += 1
+        done = self.episode_len is not None and self.time >= self.episode_len
+        return self.q[-1], rew, done, {}
 
     def seed(self, seed=None):
         self.action_space.seed(seed)
 
-    def _choose_next_state(self):
-        self.state = self.action_space.sample()
-        self.time += 1
-
     @abstractmethod
-    def _get_reward(self, actions):
+    def _get_reward(self, state, actions):
         raise NotImplementedError
 
 
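Pausing before the subclass hunks: the core of the change is the deque above. reset() pre-fills it with delay+1 sampled states, and each step() pops the oldest state for the reward while appending a fresh one as the next observation. A standalone sketch of the same queue discipline, stripped of the gym wrapper (illustrative names, not from the commit):

import random
from collections import deque

def make_delayed_identity(dim, delay):
    # same FIFO mechanism as IdentityEnv above
    q = deque(maxlen=delay + 1)
    for _ in range(delay + 1):       # reset(): pre-fill with delay+1 states
        q.append(random.randrange(dim))
    def step(action):
        target = q.popleft()         # state sampled `delay` steps before now
        rew = 1 if action == target else 0
        q.append(random.randrange(dim))
        return q[-1], rew            # newest state is the next observation
    return q[-1], step

ob, step = make_delayed_identity(dim=10, delay=2)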
@@ -47,26 +49,29 @@ class DiscreteIdentityEnv(IdentityEnv):
     def __init__(
             self,
             dim,
             episode_len=None,
+            delay=0,
+            zero_first_rewards=True
     ):
 
         self.action_space = Discrete(dim)
-        super().__init__(episode_len=episode_len)
+        super().__init__(episode_len=episode_len, delay=delay, zero_first_rewards=zero_first_rewards)
 
-    def _get_reward(self, actions):
-        return 1 if self.state == actions else 0
+    def _get_reward(self, state, actions):
+        return 1 if state == actions else 0
 
 
 class MultiDiscreteIdentityEnv(IdentityEnv):
     def __init__(
             self,
             dims,
             episode_len=None,
+            delay=0,
     ):
 
         self.action_space = MultiDiscrete(dims)
-        super().__init__(episode_len=episode_len)
+        super().__init__(episode_len=episode_len, delay=delay)
 
-    def _get_reward(self, actions):
-        return 1 if all(self.state == actions) else 0
+    def _get_reward(self, state, actions):
+        return 1 if all(state == actions) else 0
 
 
 class BoxIdentityEnv(IdentityEnv):
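One detail worth noting in the MultiDiscrete variant: `state == actions` compares NumPy arrays elementwise, so the built-in `all(...)` yields reward 1 only when every component matches. For example (illustrative values):

import numpy as np

state = np.array([2, 0, 1])
print(all(state == np.array([2, 0, 1])))   # True  -> reward 1
print(all(state == np.array([2, 1, 1])))   # False -> reward 0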
@@ -79,7 +84,7 @@ class BoxIdentityEnv(IdentityEnv):
         self.action_space = Box(low=-1.0, high=1.0, shape=shape, dtype=np.float32)
         super().__init__(episode_len=episode_len)
 
-    def _get_reward(self, actions):
-        diff = actions - self.state
+    def _get_reward(self, state, actions):
+        diff = actions - state
         diff = diff[:]
         return -0.5 * np.dot(diff, diff)
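The Box variant keeps its dense reward, now computed against the dequeued state: the negative half squared error, maximized at 0 exactly when the action reproduces the (delayed) state. A worked example (values chosen for illustration, not from the commit):

import numpy as np

state = np.array([0.2, -0.5])
actions = np.array([0.1, -0.1])
diff = actions - state               # [-0.1,  0.4]
print(-0.5 * np.dot(diff, diff))     # -0.5 * (0.01 + 0.16) ≈ -0.085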
baselines/common/tests/envs/identity_env_test.py (new file, +36 lines)

@@ -0,0 +1,36 @@
+from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv
+
+
+def test_discrete_nodelay():
+    nsteps = 100
+    eplen = 50
+    env = DiscreteIdentityEnv(10, episode_len=eplen)
+    ob = env.reset()
+    for t in range(nsteps):
+        action = env.action_space.sample()
+        next_ob, rew, done, info = env.step(action)
+        assert rew == (1 if action == ob else 0)
+        if (t + 1) % eplen == 0:
+            assert done
+            next_ob = env.reset()
+        else:
+            assert not done
+        ob = next_ob
+
+def test_discrete_delay1():
+    eplen = 50
+    env = DiscreteIdentityEnv(10, episode_len=eplen, delay=1)
+    ob = env.reset()
+    prev_ob = None
+    for t in range(eplen):
+        action = env.action_space.sample()
+        next_ob, rew, done, info = env.step(action)
+        if t > 0:
+            assert rew == (1 if action == prev_ob else 0)
+        else:
+            assert rew == 0
+        prev_ob = ob
+        ob = next_ob
+        if t < eplen - 1:
+            assert not done
+    assert done
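Both tests drive the env with randomly sampled actions, so reproducibility comes from seeding, per the commit message's "seed the tests". A hedged sketch of the pattern, assuming gym's usual space-seeding semantics (re-seeding the action space replays the same sample sequence):

from baselines.common.tests.envs.identity_env import DiscreteIdentityEnv

env = DiscreteIdentityEnv(10, episode_len=50, delay=1)
env.seed(0)                  # forwards to action_space.seed(0)
first = env.reset()
env.seed(0)
assert env.reset() == first  # identical seed -> identical sampled states

The functions follow pytest's test_ naming convention, so the repo's test runner picks them up automatically.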