Files
baselines/baselines/common/vec_env/subproc_vec_env.py

101 lines
3.5 KiB
Python
Raw Normal View History

2017-08-18 09:25:39 -07:00
import numpy as np
from multiprocessing import Process, Pipe
from baselines.common.vec_env import VecEnv, CloudpickleWrapper
from baselines.common.tile_images import tile_images
2017-08-18 09:25:39 -07:00
def worker(remote, parent_remote, env_fn_wrapper):
    """
    Subprocess entry point: build one environment and serve commands for it.

    remote: this process's end of the pipe; (cmd, data) tuples are received
        from it and replies are sent back on it
    parent_remote: the other end of the same pipe; closed right away because
        the worker only ever talks through `remote`
    env_fn_wrapper: object whose `x` attribute is a zero-argument callable
        returning the environment

    Supported commands (data is ignored unless noted):
        'step'        data is the action; replies (ob, reward, done, info).
                      When done is true the env is reset and the reply carries
                      the observation returned by reset().
        'reset'       replies with env.reset()
        'reset_task'  replies with env.reset_task()
        'render'      replies with env.render(mode='rgb_array')
        'get_spaces'  replies (env.observation_space, env.action_space)
        'close'       closes `remote` and leaves the loop without replying

    Any other command raises NotImplementedError. The environment is closed
    on every exit path.
    """
    parent_remote.close()
    env = env_fn_wrapper.x()
    try:
        while True:
            cmd, data = remote.recv()
            if cmd == 'step':
                ob, reward, done, info = env.step(data)
                if done:
                    # Auto-reset: the caller receives the first observation of
                    # the next episode together with the final reward/done/info.
                    ob = env.reset()
                remote.send((ob, reward, done, info))
            elif cmd == 'reset':
                remote.send(env.reset())
            elif cmd == 'reset_task':
                # SubprocVecEnv.reset_task() sends this command and waits for
                # a reply; the wrapped env must implement reset_task().
                remote.send(env.reset_task())
            elif cmd == 'render':
                remote.send(env.render(mode='rgb_array'))
            elif cmd == 'close':
                remote.close()
                break
            elif cmd == 'get_spaces':
                remote.send((env.observation_space, env.action_space))
            else:
                raise NotImplementedError('unknown command: {!r}'.format(cmd))
    except KeyboardInterrupt:
        print('SubprocVecEnv worker: got KeyboardInterrupt')
    finally:
        env.close()
2017-08-18 09:25:39 -07:00
class SubprocVecEnv(VecEnv):
    """
    VecEnv that runs each environment in its own subprocess and exchanges
    (command, data) messages with it over a multiprocessing Pipe.
    """

    def __init__(self, env_fns, spaces=None):
        """
        env_fns: list of callables, one per subprocess; each one is wrapped in
            CloudpickleWrapper and passed to worker()
        spaces: not used by this implementation
        """
        self.waiting = False
        self.closed = False

        n_workers = len(env_fns)
        self.remotes, self.work_remotes = zip(*[Pipe() for _ in range(n_workers)])

        # Create every Process object first; none of them is started yet.
        procs = []
        for child_end, parent_end, make_env in zip(self.work_remotes, self.remotes, env_fns):
            worker_args = (child_end, parent_end, CloudpickleWrapper(make_env))
            procs.append(Process(target=worker, args=worker_args))
        self.ps = procs

        for proc in self.ps:
            # if the main process crashes, we should not cause things to hang
            proc.daemon = True
            proc.start()

        # The parent keeps only its own ends of the pipes.
        for child_end in self.work_remotes:
            child_end.close()

        # The first worker is asked for the spaces passed to VecEnv.__init__.
        first_remote = self.remotes[0]
        first_remote.send(('get_spaces', None))
        observation_space, action_space = first_remote.recv()
        VecEnv.__init__(self, n_workers, observation_space, action_space)

    def step_async(self, actions):
        """Send one 'step' command per worker, pairing remotes with actions in order."""
        for conn, act in zip(self.remotes, actions):
            conn.send(('step', act))
        self.waiting = True

    def step_wait(self):
        """
        Collect one reply from every worker.

        Returns (obs, rews, dones, infos): the first three are np.stack'ed
        across workers, infos is the tuple of per-worker info objects.
        """
        replies = []
        for conn in self.remotes:
            replies.append(conn.recv())
        self.waiting = False

        obs, rews, dones, infos = zip(*replies)
        stacked = (np.stack(obs), np.stack(rews), np.stack(dones))
        return stacked + (infos,)

    def reset(self):
        """Send 'reset' to every worker and return the stacked replies."""
        for conn in self.remotes:
            conn.send(('reset', None))
        observations = [conn.recv() for conn in self.remotes]
        return np.stack(observations)

    def reset_task(self):
        """Send 'reset_task' to every worker and return the stacked replies."""
        for conn in self.remotes:
            conn.send(('reset_task', None))
        observations = [conn.recv() for conn in self.remotes]
        return np.stack(observations)

    def close(self):
        """
        Ask every worker to close and join the processes.

        Does nothing when already closed. If a step is still pending, its
        replies are received (and discarded) first.
        """
        if self.closed:
            return

        if self.waiting:
            for conn in self.remotes:
                conn.recv()

        for conn in self.remotes:
            conn.send(('close', None))
        for proc in self.ps:
            proc.join()
        self.closed = True

    def render(self, mode='human'):
        """
        Gather a frame from every worker and tile them into one image.

        mode='rgb_array' returns the tiled image; mode='human' shows it in a
        cv2 window named 'vecenv' and returns None. Any other mode raises
        NotImplementedError, after the frames have already been collected.
        """
        for conn in self.remotes:
            conn.send(('render', None))
        frames = [conn.recv() for conn in self.remotes]
        bigimg = tile_images(frames)

        if mode == 'rgb_array':
            return bigimg
        if mode != 'human':
            raise NotImplementedError

        import cv2
        # Last axis is reversed before display (presumably RGB -> BGR for OpenCV).
        cv2.imshow('vecenv', bigimg[:, :, ::-1])
        cv2.waitKey(1)