61 lines
2.6 KiB
Python
61 lines
2.6 KiB
Python
import numpy as np
|
|
from baselines.common.runners import AbstractEnvRunner
|
|
|
|
class Runner(AbstractEnvRunner):
    """Roll out a model in a vectorized env while maintaining frame-stacked observations.

    ``self.obs`` is a uint8 array of shape ``(nenv, nh, nw, nc * nstack)``
    holding the last ``nstack`` observations of each environment,
    concatenated along axis 3 with the oldest first.

    Relies on ``AbstractEnvRunner`` (not visible here) to provide
    ``self.env``, ``self.model``, ``self.nsteps``, ``self.nenv``,
    ``self.states`` and ``self.dones``.
    """

    def __init__(self, env, model, nsteps, nstack):
        """Allocate the stacked-observation buffer and seed it from ``env.reset()``.

        Args:
            env: vectorized environment whose ``observation_space.shape`` is
                ``(nh, nw, nc)`` and whose ``action_space`` exposes ``n``.
            model: policy object exposing ``_step(obs, S=..., M=...)``.
            nsteps: number of environment steps collected per ``run()`` call.
            nstack: number of consecutive observations stacked per model input.
        """
        super().__init__(env=env, model=model, nsteps=nsteps)
        self.nstack = nstack
        nh, nw, nc = env.observation_space.shape
        self.nc = nc  # nc = 1 for atari, but just in case
        self.nact = env.action_space.n
        nenv = self.nenv
        self.nbatch = nenv * nsteps
        # (nsteps + 1) per env, matching the number of stacked observations
        # that run() records (one before each step plus one after the last).
        self.batch_ob_shape = (nenv * (nsteps + 1), nh, nw, nc * nstack)
        self.obs = np.zeros((nenv, nh, nw, nc * nstack), dtype=np.uint8)
        obs = env.reset()
        self.update_obs(obs)

    def update_obs(self, obs, dones=None):
        """Push the newest observation into the frame stack.

        Args:
            obs: latest observations, expected shape ``(nenv, nh, nw, nc)``.
            dones: optional per-env done flags (array-like of length nenv).
                Flagged environments have their whole stack zeroed first so
                frames from before the done do not leak into the new stack.
        """
        if dones is not None:
            keep = 1 - np.asarray(dones, dtype=np.uint8)
            self.obs *= keep[:, None, None, None]
        # Shift the stack left by one observation (nc channels); np.roll returns
        # a new array. Then write the newest observation into the trailing slot.
        self.obs = np.roll(self.obs, shift=-self.nc, axis=3)
        self.obs[:, :, :, -self.nc:] = obs[:, :, :, :]

    def run(self):
        """Collect ``nsteps`` transitions from every environment.

        All returned arrays have the env axis first (assuming the per-step
        arrays from the model/env are indexed by env on axis 0).

        Returns:
            ``(enc_obs, mb_obs, mb_actions, mb_rewards, mb_mus, mb_dones, mb_masks)``:

            * ``enc_obs`` (uint8): the ``nstack`` single frames stacked at the
              start of the rollout, followed by each raw observation returned
              by ``env.step``; axis 1 has length ``nstack + nsteps``.
            * ``mb_obs`` (uint8): the stacked observation seen before each step
              plus the one after the last step; axis 1 has length ``nsteps + 1``.
            * ``mb_actions`` (int32) and ``mb_rewards`` (float32): axis 1 has
              length ``nsteps``.
            * ``mb_mus`` (float32): the ``mus`` output of ``model._step`` at each
              step (presumably the behaviour policy's action probabilities --
              TODO confirm).
            * ``mb_dones`` (bool): done flags returned by ``env.step``, aligned
              with ``mb_rewards``; axis 1 has length ``nsteps``.
            * ``mb_masks`` (bool): the done flags held before each step plus the
              final ones; axis 1 has length ``nsteps + 1``.
        """
        # Split a copy, not self.obs itself: np.split returns views, and
        # update_obs() scales self.obs in place (before rebinding it via
        # np.roll) when an env is done on the first step. Without the copy,
        # that would zero the initial frames recorded here.
        enc_obs = np.split(np.copy(self.obs), self.nstack, axis=3)  # list of per-frame arrays
        mb_obs, mb_actions, mb_mus, mb_dones, mb_rewards = [], [], [], [], []
        for _ in range(self.nsteps):
            actions, mus, states = self.model._step(self.obs, S=self.states, M=self.dones)
            mb_obs.append(np.copy(self.obs))
            mb_actions.append(actions)
            mb_mus.append(mus)
            mb_dones.append(self.dones)
            obs, rewards, dones, _ = self.env.step(actions)
            # States information for stateful models like LSTMs.
            self.states = states
            self.dones = dones
            self.update_obs(obs, dones)
            mb_rewards.append(rewards)
            enc_obs.append(obs)
        mb_obs.append(np.copy(self.obs))
        mb_dones.append(self.dones)

        # Stack the per-step lists (time-major) and move the env axis first.
        enc_obs = np.asarray(enc_obs, dtype=np.uint8).swapaxes(1, 0)
        mb_obs = np.asarray(mb_obs, dtype=np.uint8).swapaxes(1, 0)
        mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
        mb_mus = np.asarray(mb_mus, dtype=np.float32).swapaxes(1, 0)

        # Builtin bool: the np.bool alias was deprecated in NumPy 1.20 and
        # removed in 1.24 (it was the same dtype).
        mb_dones = np.asarray(mb_dones, dtype=bool).swapaxes(1, 0)

        mb_masks = mb_dones  # Used for stateful models like LSTMs to mask state when done
        mb_dones = mb_dones[:, 1:]  # Used for calculating returns; now aligned with rewards

        # Shapes are now [nenv, time, ...].
        # When pulling from buffer, arrays will now be reshaped in place, preventing a deep copy.
        return enc_obs, mb_obs, mb_actions, mb_rewards, mb_mus, mb_dones, mb_masks
|
|
|