From cc8818f49eb2c3a9f5097d4de2e73febbfa6b994 Mon Sep 17 00:00:00 2001
From: Malcolm Karutz
Date: Mon, 9 Oct 2017 13:08:41 +1100
Subject: [PATCH] Fixes frame stacking in A2C and ACKTR for multi-channel observation spaces.

---
 baselines/a2c/a2c.py          | 5 +++--
 baselines/acktr/acktr_disc.py | 5 +++--
 2 files changed, 6 insertions(+), 4 deletions(-)

diff --git a/baselines/a2c/a2c.py b/baselines/a2c/a2c.py
index 56d5430..cfb1d7c 100644
--- a/baselines/a2c/a2c.py
+++ b/baselines/a2c/a2c.py
@@ -98,6 +98,7 @@ class Runner(object):
         nenv = env.num_envs
         self.batch_ob_shape = (nenv*nsteps, nh, nw, nc*nstack)
         self.obs = np.zeros((nenv, nh, nw, nc*nstack), dtype=np.uint8)
+        self.nc = nc
         obs = env.reset()
         self.update_obs(obs)
         self.gamma = gamma
@@ -108,8 +109,8 @@ class Runner(object):
     def update_obs(self, obs):
         # Do frame-stacking here instead of the FrameStack wrapper to reduce
         # IPC overhead
-        self.obs = np.roll(self.obs, shift=-1, axis=3)
-        self.obs[:, :, :, -1] = obs[:, :, :, 0]
+        self.obs = np.roll(self.obs, shift=-self.nc, axis=3)
+        self.obs[:, :, :, -self.nc:] = obs
 
     def run(self):
         mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
diff --git a/baselines/acktr/acktr_disc.py b/baselines/acktr/acktr_disc.py
index 28c1cbf..feb702c 100644
--- a/baselines/acktr/acktr_disc.py
+++ b/baselines/acktr/acktr_disc.py
@@ -113,6 +113,7 @@ class Runner(object):
         nenv = env.num_envs
         self.batch_ob_shape = (nenv*nsteps, nh, nw, nc*nstack)
         self.obs = np.zeros((nenv, nh, nw, nc*nstack), dtype=np.uint8)
+        self.nc = nc
         obs = env.reset()
         self.update_obs(obs)
         self.gamma = gamma
@@ -121,8 +122,8 @@ class Runner(object):
         self.dones = [False for _ in range(nenv)]
 
     def update_obs(self, obs):
-        self.obs = np.roll(self.obs, shift=-1, axis=3)
-        self.obs[:, :, :, -1] = obs[:, :, :, 0]
+        self.obs = np.roll(self.obs, shift=-self.nc, axis=3)
+        self.obs[:, :, :, -self.nc:] = obs
 
     def run(self):
         mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]