Fixes frame stacking in A2C and ACKTR for multi-channel observation spaces.
This commit is contained in:
@@ -98,6 +98,7 @@ class Runner(object):
|
|||||||
nenv = env.num_envs
|
nenv = env.num_envs
|
||||||
self.batch_ob_shape = (nenv*nsteps, nh, nw, nc*nstack)
|
self.batch_ob_shape = (nenv*nsteps, nh, nw, nc*nstack)
|
||||||
self.obs = np.zeros((nenv, nh, nw, nc*nstack), dtype=np.uint8)
|
self.obs = np.zeros((nenv, nh, nw, nc*nstack), dtype=np.uint8)
|
||||||
|
self.nc = nc
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
self.update_obs(obs)
|
self.update_obs(obs)
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
@@ -108,8 +109,8 @@ class Runner(object):
|
|||||||
def update_obs(self, obs):
|
def update_obs(self, obs):
|
||||||
# Do frame-stacking here instead of the FrameStack wrapper to reduce
|
# Do frame-stacking here instead of the FrameStack wrapper to reduce
|
||||||
# IPC overhead
|
# IPC overhead
|
||||||
self.obs = np.roll(self.obs, shift=-1, axis=3)
|
self.obs = np.roll(self.obs, shift=-self.nc, axis=3)
|
||||||
self.obs[:, :, :, -1] = obs[:, :, :, 0]
|
self.obs[:, :, :, -self.nc:] = obs
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
|
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
|
||||||
|
@@ -113,6 +113,7 @@ class Runner(object):
|
|||||||
nenv = env.num_envs
|
nenv = env.num_envs
|
||||||
self.batch_ob_shape = (nenv*nsteps, nh, nw, nc*nstack)
|
self.batch_ob_shape = (nenv*nsteps, nh, nw, nc*nstack)
|
||||||
self.obs = np.zeros((nenv, nh, nw, nc*nstack), dtype=np.uint8)
|
self.obs = np.zeros((nenv, nh, nw, nc*nstack), dtype=np.uint8)
|
||||||
|
self.nc = nc
|
||||||
obs = env.reset()
|
obs = env.reset()
|
||||||
self.update_obs(obs)
|
self.update_obs(obs)
|
||||||
self.gamma = gamma
|
self.gamma = gamma
|
||||||
@@ -121,8 +122,8 @@ class Runner(object):
|
|||||||
self.dones = [False for _ in range(nenv)]
|
self.dones = [False for _ in range(nenv)]
|
||||||
|
|
||||||
def update_obs(self, obs):
|
def update_obs(self, obs):
|
||||||
self.obs = np.roll(self.obs, shift=-1, axis=3)
|
self.obs = np.roll(self.obs, shift=-self.nc, axis=3)
|
||||||
self.obs[:, :, :, -1] = obs[:, :, :, 0]
|
self.obs[:, :, :, -self.nc:] = obs
|
||||||
|
|
||||||
def run(self):
|
def run(self):
|
||||||
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
|
mb_obs, mb_rewards, mb_actions, mb_values, mb_dones = [],[],[],[],[]
|
||||||
|
Reference in New Issue
Block a user