* exported rl-algs * more stuff from rl-algs * run slow tests * re-exported rl_algs * re-exported rl_algs - fixed problems with serialization test and test_cartpole * replaced atari_arg_parser with common_arg_parser * run.py can run algos from both baselines and rl_algs * added approximate humanoid reward with ppo2 into the README for reference * dummy commit to RUN BENCHMARKS * dummy commit to RUN BENCHMARKS * dummy commit to RUN BENCHMARKS * dummy commit to RUN BENCHMARKS * very dummy commit to RUN BENCHMARKS * serialize variables as a dict, not as a list * running_mean_std uses tensorflow variables * fixed import in vec_normalize * dummy commit to RUN BENCHMARKS * dummy commit to RUN BENCHMARKS * flake8 complaints * save all variables to make sure we save the vec_normalize normalization * benchmarks on ppo2 only RUN BENCHMARKS * make_atari_env compatible with mpi * run ppo_mpi benchmarks only RUN BENCHMARKS * hardcode names of retro environments * add defaults * changed default ppo2 lr schedule to linear RUN BENCHMARKS * non-tf normalization benchmark RUN BENCHMARKS * use ncpu=1 for mujoco sessions - gives a bit of a performance speedup * reverted running_mean_std to use property decorators for mean, var, count * reverted VecNormalize to use RunningMeanStd (no tf) * reverted VecNormalize to use RunningMeanStd (no tf) * profiling wip * use VecNormalize with regular RunningMeanStd * added acer runner (missing import) * flake8 complaints * added a note in README about TfRunningMeanStd and serialization of VecNormalize * dummy commit to RUN BENCHMARKS * merged benchmarks branch
83 lines
2.7 KiB
Python
83 lines
2.7 KiB
Python
import numpy as np
|
|
from gym import spaces
|
|
from collections import OrderedDict
|
|
from . import VecEnv
|
|
|
|
class DummyVecEnv(VecEnv):
    """Vectorized environment that steps its sub-environments sequentially.

    Every environment is constructed and stepped in the calling process, one
    after another. Observations are gathered into preallocated numpy buffers
    keyed by sub-space name; a non-Dict observation space is stored under the
    key ``None``.
    """

    def __init__(self, env_fns):
        """Build one environment per factory in ``env_fns``.

        Args:
            env_fns: iterable of zero-argument callables, each returning an
                environment. The first environment's observation and action
                spaces are used for the whole vector (assumed identical across
                environments -- not checked here).
        """
        self.envs = [fn() for fn in env_fns]
        env = self.envs[0]
        # Count the built envs rather than ``env_fns`` so one-shot iterables
        # (already consumed by the comprehension above) also work.
        VecEnv.__init__(self, len(self.envs), env.observation_space, env.action_space)
        shapes, dtypes = {}, {}
        self.keys = []
        obs_space = env.observation_space

        if isinstance(obs_space, spaces.Dict):
            # Buffer/key order follows the Dict space's declared order.
            assert isinstance(obs_space.spaces, OrderedDict)
            subspaces = obs_space.spaces
        else:
            # A single (non-Dict) space is stored under the key None.
            subspaces = {None: obs_space}

        for key, box in subspaces.items():
            shapes[key] = box.shape
            dtypes[key] = box.dtype
            self.keys.append(key)

        self.buf_obs = {
            k: np.zeros((self.num_envs,) + tuple(shapes[k]), dtype=dtypes[k])
            for k in self.keys
        }
        # ``np.bool`` was just an alias of the builtin ``bool``; it was
        # deprecated in NumPy 1.20 and removed in 1.24.
        self.buf_dones = np.zeros((self.num_envs,), dtype=bool)
        self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
        self.buf_infos = [{} for _ in range(self.num_envs)]
        self.actions = None

    def step_async(self, actions):
        """Store the actions to apply on the next ``step_wait``.

        Args:
            actions: a sized sequence with one action per environment, or --
                only when there is exactly one environment -- a single action,
                which is wrapped into a one-element list.
        """
        listify = True
        try:
            if len(actions) == self.num_envs:
                listify = False
        except TypeError:
            # Unsized action (e.g. a plain int): fall through to wrapping.
            pass

        if not listify:
            self.actions = actions
        else:
            assert self.num_envs == 1, "actions {} is either not a list or has a wrong size - cannot match to {} environments".format(actions, self.num_envs)
            self.actions = [actions]

    def step_wait(self):
        """Step every environment with the stored actions.

        An environment whose step reports done is reset immediately, and its
        reset observation is recorded in place of the terminal one.

        Returns:
            Tuple ``(obs, rewards, dones, infos)``: a copy of the observation
            buffers (see ``_obs_from_buf``), copies of the reward and done
            arrays, and a shallow copy of the list of info dicts.
        """
        for e in range(self.num_envs):
            action = self.actions[e]
            if isinstance(self.envs[e].action_space, spaces.Discrete):
                action = int(action)

            obs, self.buf_rews[e], self.buf_dones[e], self.buf_infos[e] = self.envs[e].step(action)
            if self.buf_dones[e]:
                obs = self.envs[e].reset()
            self._save_obs(e, obs)
        # ``_obs_from_buf`` already copies; wrapping it in ``np.copy`` would
        # turn a Dict observation into a 0-d object array.
        return (self._obs_from_buf(), np.copy(self.buf_rews), np.copy(self.buf_dones),
                self.buf_infos.copy())

    def reset(self):
        """Reset every environment and return a copy of the initial observations."""
        for e in range(self.num_envs):
            obs = self.envs[e].reset()
            self._save_obs(e, obs)
        return self._obs_from_buf()

    def close(self):
        """Close every sub-environment to release the resources they hold."""
        for env in self.envs:
            env.close()

    def render(self, mode='human'):
        """Render each environment with ``mode`` and return the results as a list."""
        return [e.render(mode=mode) for e in self.envs]

    def _save_obs(self, e, obs):
        """Write observation ``obs`` of environment ``e`` into the buffers."""
        for k in self.keys:
            if k is None:
                self.buf_obs[k][e] = obs
            else:
                self.buf_obs[k][e] = obs[k]

    def _obs_from_buf(self):
        """Return a copy of the current observation buffers.

        A copy is returned so that later steps, which overwrite the buffers in
        place, cannot mutate observations already handed to the caller. For a
        non-Dict space this is a single array; otherwise a dict of arrays in
        the Dict space's key order.
        """
        if self.keys == [None]:
            return self.buf_obs[None].copy()
        return {k: self.buf_obs[k].copy() for k in self.keys}
|