Fixes frame stacking in A2C and ACKTR for multi-channel observation spaces.

This commit is contained in:
Malcolm Karutz
2017-10-09 13:08:41 +11:00
parent 3eb71a0ece
commit cc8818f49e
2 changed files with 6 additions and 4 deletions

View File

@@ -98,6 +98,7 @@ class Runner(object):
nenv = env.num_envs
self.batch_ob_shape = (nenv*nsteps, nh, nw, nc*nstack)
self.obs = np.zeros((nenv, nh, nw, nc*nstack), dtype=np.uint8)
self.nc = nc
obs = env.reset()
self.update_obs(obs)
self.gamma = gamma
@@ -108,8 +109,8 @@ class Runner(object):
def update_obs(self, obs):
    """Push the newest observation into the stacked-frame buffer.

    ``self.obs`` holds ``nstack`` frames of ``self.nc`` channels each,
    concatenated along the last axis (axis 3), oldest first. Each call
    discards the oldest frame and appends ``obs`` as the newest one.

    Args:
        obs: Latest observation from the environments. It is assigned to
            the last ``self.nc`` channels of ``self.obs``, so it must
            broadcast to that slice (presumably shape
            ``(nenv, nh, nw, nc)`` -- confirm against the env wrapper).
    """
    # Do frame-stacking here instead of the FrameStack wrapper to reduce
    # IPC overhead
    # Shift by a whole frame (nc channels), not a single channel, so that
    # multi-channel observations stay aligned; the oldest frame wraps
    # around into the last nc slots.
    self.obs = np.roll(self.obs, shift=-self.nc, axis=3)
    # Overwrite the wrapped-around (stale) frame with every channel of
    # the new observation.
    self.obs[:, :, :, -self.nc:] = obs
def run(self):
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]

View File

@@ -113,6 +113,7 @@ class Runner(object):
nenv = env.num_envs
self.batch_ob_shape = (nenv*nsteps, nh, nw, nc*nstack)
self.obs = np.zeros((nenv, nh, nw, nc*nstack), dtype=np.uint8)
self.nc = nc
obs = env.reset()
self.update_obs(obs)
self.gamma = gamma
@@ -121,8 +122,8 @@ class Runner(object):
self.dones = [False for _ in range(nenv)]
def update_obs(self, obs):
    """Advance the frame stack by one observation.

    The buffer ``self.obs`` stores stacked frames along axis 3, with
    ``self.nc`` channels per frame. The oldest frame is dropped and
    ``obs`` becomes the most recent frame.

    Args:
        obs: Newest observation; written into the final ``self.nc``
            channels of ``self.obs`` (presumably shaped
            ``(nenv, nh, nw, nc)`` -- confirm against the caller).
    """
    # Roll by one full frame (nc channels) so each frame's channels move
    # together; rolling by 1 would only be correct when nc == 1.
    self.obs = np.roll(self.obs, shift=-self.nc, axis=3)
    # The rolled-around oldest frame now occupies the last nc channels;
    # replace it with all channels of the new observation.
    self.obs[:, :, :, -self.nc:] = obs
def run(self):
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]