import numpy as np

from baselines.common.runners import AbstractEnvRunner


class Runner(AbstractEnvRunner):
    def __init__(self, env, model, nsteps, nstack):
        super().__init__(env=env, model=model, nsteps=nsteps)
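        # AbstractEnvRunner initializes self.env, self.model, self.nenv,
        # self.nsteps, self.states (model.initial_state) and self.dones.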
        self.nstack = nstack
        nh, nw, nc = env.observation_space.shape
        self.nc = nc  # nc = 1 for atari, but just in case
        self.nact = env.action_space.n
        nenv = self.nenv
        self.nbatch = nenv * nsteps
        self.batch_ob_shape = (nenv * (nsteps + 1), nh, nw, nc * nstack)
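        # nsteps + 1: run() also returns the observation following the final
        # step, which is used when bootstrapping returns.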
        self.obs = np.zeros((nenv, nh, nw, nc * nstack), dtype=np.uint8)
        obs = env.reset()
        self.update_obs(obs)

    def update_obs(self, obs, dones=None):
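        # For environments that just finished an episode, zero out the stacked
        # history; then shift the stack left by nc channels and write the
        # newest observation into the last nc channels.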
        if dones is not None:
            self.obs *= (1 - dones.astype(np.uint8))[:, None, None, None]
        self.obs = np.roll(self.obs, shift=-self.nc, axis=3)
        self.obs[:, :, :, -self.nc:] = obs
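        # e.g. with nstack=4 and nc=1, channels (f0, f1, f2, f3) become
        # (f1, f2, f3, f_new) after one update.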

    def run(self):
        enc_obs = np.split(self.obs, self.nstack, axis=3)  # now a list of single-frame observations
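        # enc_obs accumulates nstack + nsteps single frames so that stacked
        # observations can be reconstructed when sampling from the replay buffer.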
        mb_obs, mb_actions, mb_mus, mb_dones, mb_rewards = [], [], [], [], []
        for _ in range(self.nsteps):
            actions, mus, states = self.model._step(self.obs, S=self.states, M=self.dones)
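            # mus are the behaviour policy's action probabilities; ACER stores
            # them to compute the importance weights for off-policy correction.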
            mb_obs.append(np.copy(self.obs))
            mb_actions.append(actions)
            mb_mus.append(mus)
            mb_dones.append(self.dones)
            obs, rewards, dones, _ = self.env.step(actions)
            # states information for stateful models like LSTM
            self.states = states
            self.dones = dones
            self.update_obs(obs, dones)
            mb_rewards.append(rewards)
            enc_obs.append(obs)
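        # Store the final observation and done flags so the batch holds
        # nsteps + 1 entries per environment.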
        mb_obs.append(np.copy(self.obs))
        mb_dones.append(self.dones)

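        # swapaxes(1, 0) turns the time-major [nsteps, nenv, ...] lists into
        # env-major [nenv, nsteps, ...] batches.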
        enc_obs = np.asarray(enc_obs, dtype=np.uint8).swapaxes(1, 0)
        mb_obs = np.asarray(mb_obs, dtype=np.uint8).swapaxes(1, 0)
        mb_actions = np.asarray(mb_actions, dtype=np.int32).swapaxes(1, 0)
        mb_rewards = np.asarray(mb_rewards, dtype=np.float32).swapaxes(1, 0)
        mb_mus = np.asarray(mb_mus, dtype=np.float32).swapaxes(1, 0)

        mb_dones = np.asarray(mb_dones, dtype=bool).swapaxes(1, 0)

        mb_masks = mb_dones  # Used for stateful models like LSTMs to mask state when done
        mb_dones = mb_dones[:, 1:]  # Used for calculating returns. The dones array is now aligned with rewards

        # shapes are now [nenv, nsteps, ...]
        # When pulling from the buffer, arrays will be reshaped in place, avoiding a deep copy.

        return enc_obs, mb_obs, mb_actions, mb_rewards, mb_mus, mb_dones, mb_masks
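
# Usage sketch (assumes `env` is a vectorized Atari env and `model` is an ACER
# model exposing `_step`; the nsteps and nstack values are illustrative):
#
#     runner = Runner(env=env, model=model, nsteps=20, nstack=4)
#     enc_obs, obs, actions, rewards, mus, dones, masks = runner.run()
#     # obs.shape == (nenv, nsteps + 1, nh, nw, nc * nstack)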