Compare commits

...

16 Commits
tf2 ... master

Author SHA1 Message Date
Harry Uglow
ea25b9e8b2 Monitor should close what it inherits (#1076) 2020-01-31 05:06:18 -08:00
pzhokhov
9ee399f5b2 Fix build with latest gym (#1034)
* update to use latest version of gym

* fix imports

* narrow down gym version to 0.15.4 <= gym < 0.16.0
2019-11-10 11:10:01 -08:00
Tomasz Wrona
391811d98c SubprocVecEnv uses CloudpickleWrapper to send specs (#1028) 2019-11-08 15:23:49 -08:00
Yen-Chen Lin
665b888eeb Fix behavior cloning due to API changes (#1014) 2019-10-25 15:44:43 -07:00
Christopher Hesse
f40a477a17 fix tf2 branch name 2019-10-25 15:27:30 -07:00
johannespitz
c6144bdb6a Fix RuntimeError (#910) (#1015)
* Update the commands to install Tensorflow

The current 'tensorflow' package is for Tensorflow 2, which is not supported by the master branch of baselines.

* Update command to install Tensorflow 1.14

* Fix RuntimeError (#910)

 - Removed interfering calls to env.reset() in play mode.
   (Note that the worker in the subprocess is calling env.reset() already)

 - Fixed the printed reward when running multiple envs in play mode.
2019-10-25 15:24:41 -07:00
Peter Zhokhov
adba88b218 add quote marks to tensorflow < 2 to avoid bash logic 2019-10-11 17:13:43 -07:00
Peter Zhokhov
bfbc3bae14 update status, fix the tensorflow version in the build 2019-10-11 15:23:14 -07:00
Haiyang Chen
f703776c91 fix a bug in acer saving and loading model (#990) 2019-09-27 15:39:41 -07:00
pzhokhov
53797293e5 use allreduce instead of Allreduce (send pickled data instead of floats) - probably affects performance somewhat, but avoid element number mismatch. Fixes 998 (#1000) 2019-09-27 14:45:31 -07:00
tanzhenyu
229a772b81 Release notes for Tensorflow 2.0 support. (#997) 2019-08-29 14:25:44 -07:00
Tomasz Wrona
d80b075904 Make SubprocVecEnv works with DummyVecEnv (#908)
* Make SubprocVecEnv works with DummyVecEnv (nested environments for synchronous sampling)

* SubprocVecEnv now supports running environments in series in each process

* Added docstring to the test definition

* Added additional test to check, whether SubprocVecEnv results with the same output when in_series parameter is enabled and not

* Added more test cases for in_series parameter

* Refactored worker function, added docstring for in_series parameter

* Remove check for TF presence in setup.py
2019-08-29 12:16:25 -07:00
NicoBach
0182fe1877 entrypoint variable made public (#970) 2019-08-06 02:03:19 +03:00
Seungjae Ryan Lee
1fb4dfb780 Fix typo in GAIL dataset log (#950) 2019-08-06 02:02:43 +03:00
Timo Kaufmann
7cadef715f Fix typo (#930)
* Fix typo

* Fix train_freq documentation

Seems to be a copy-paste error, train_freq has nothing to do with
printing.

* Fix documentation typo
2019-08-06 02:02:21 +03:00
tanzhenyu
fce4370ba2 Remove duplicate code in adaptive param noise. (#976) 2019-08-06 02:01:54 +03:00
17 changed files with 120 additions and 54 deletions

View File: Dockerfile

@@ -11,7 +11,7 @@ WORKDIR $CODE_DIR/baselines

 # Clean up pycache and pyc files
 RUN rm -rf __pycache__ && \
     find . -name "*.pyc" -delete && \
-    pip install tensorflow && \
+    pip install 'tensorflow < 2' && \
     pip install -e .[test]
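The pinned install above keeps the master-branch image on the TensorFlow 1.x API; the quotes around 'tensorflow < 2' also stop the shell from treating `< 2` as a redirect, which is what the "add quote marks to tensorflow < 2 to avoid bash logic" commit in the list above fixes. A runtime guard mirroring the pin, as a sketch (the assert is illustrative, not part of the diff):

```python
# Illustrative guard mirroring the 'tensorflow < 2' pin above; not part of
# the repository, just a sketch of the constraint the Dockerfile enforces.
import tensorflow as tf

assert tf.__version__.startswith('1.'), \
    'baselines master expects TensorFlow 1.x; the tf2 branch targets 2.0'
```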

View File: README.md

@@ -1,4 +1,4 @@
-**Status:** Active (under active development, breaking changes may occur)
+**Status:** Maintenance (expect bug fixes and minor updates)

 <img src="data/logo.jpg" width=25% align="right" /> [![Build status](https://travis-ci.org/openai/baselines.svg?branch=master)](https://travis-ci.org/openai/baselines)
@@ -39,21 +39,24 @@ To activate a virtualenv:
 More thorough tutorial on virtualenvs and options can be found [here](https://virtualenv.pypa.io/en/stable/)

+## Tensorflow versions
+The master branch supports Tensorflow from version 1.4 to 1.14. For Tensorflow 2.0 support, please use tf2 branch.
+
 ## Installation
 - Clone the repo and cd into it:
     ```bash
     git clone https://github.com/openai/baselines.git
     cd baselines
     ```
-- If you don't have TensorFlow installed already, install your favourite flavor of TensorFlow. In most cases,
+- If you don't have TensorFlow installed already, install your favourite flavor of TensorFlow. In most cases, you may use
     ```bash
-    pip install tensorflow-gpu # if you have a CUDA-compatible gpu and proper drivers
+    pip install tensorflow-gpu==1.14 # if you have a CUDA-compatible gpu and proper drivers
     ```
     or
     ```bash
-    pip install tensorflow
+    pip install tensorflow==1.14
     ```
-  should be sufficient. Refer to [TensorFlow installation guide](https://www.tensorflow.org/install/)
+  to install Tensorflow 1.14, which is the latest version of Tensorflow supported by the master branch. Refer to [TensorFlow installation guide](https://www.tensorflow.org/install/)
   for more details.
 - Install baselines package
- Install baselines package - Install baselines package

View File: baselines/acer/acer.py

@@ -6,7 +6,7 @@ from baselines import logger
 from baselines.common import set_global_seeds
 from baselines.common.policies import build_policy
-from baselines.common.tf_util import get_session, save_variables
+from baselines.common.tf_util import get_session, save_variables, load_variables
 from baselines.common.vec_env.vec_frame_stack import VecFrameStack
 from baselines.a2c.utils import batch_to_seq, seq_to_batch
@@ -216,7 +216,8 @@ class Model(object):
         self.train = train
-        self.save = functools.partial(save_variables, sess=sess, variables=params)
+        self.save = functools.partial(save_variables, sess=sess)
+        self.load = functools.partial(load_variables, sess=sess)
         self.train_model = train_model
         self.step_model = step_model
         self._step = _step
@@ -358,6 +359,9 @@ def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=
                   total_timesteps=total_timesteps, lrschedule=lrschedule, c=c,
                   trust_region=trust_region, alpha=alpha, delta=delta)

+    if load_path is not None:
+        model.load(load_path)
+
     runner = Runner(env=env, model=model, nsteps=nsteps)
     if replay_ratio > 0:
         buffer = Buffer(env=env, nsteps=nsteps, size=buffer_size)
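The Model change above binds the session into both helpers once, so callers pass only a path; the learn() hunk then uses the new load hook for load_path. A minimal sketch of that partial-binding pattern, assuming a TF1 session (the toy variable and paths are illustrative):

```python
import functools
import tensorflow as tf
from baselines.common.tf_util import make_session, initialize, save_variables, load_variables

# Sketch of the pattern introduced above: bind sess once via functools.partial,
# then save/load take only a checkpoint path. Variable and path are illustrative.
sess = make_session(num_cpu=1, make_default=True)
v = tf.Variable(0.0, name='v')
initialize()

save = functools.partial(save_variables, sess=sess)
load = functools.partial(load_variables, sess=sess)
save('/tmp/acer_ckpt')  # snapshot variables, as model.save does
load('/tmp/acer_ckpt')  # restore them, as learn(..., load_path=...) now does
```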

View File: baselines/bench/monitor.py

@@ -77,6 +77,7 @@ class Monitor(Wrapper):
         self.total_steps += 1

     def close(self):
+        super(Monitor, self).close()
         if self.f is not None:
             self.f.close()
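With the added super(Monitor, self).close(), closing the monitor now propagates to the wrapped environment rather than only closing the stats file. A sketch of the effect (the env id is illustrative):

```python
import gym
from baselines.bench import Monitor

# Sketch: after the fix above, close() runs gym.Wrapper.close() (closing the
# wrapped env) before closing the monitor's results file.
env = Monitor(gym.make('CartPole-v1'), filename=None)
env.reset()
env.step(env.action_space.sample())
env.close()  # closes the underlying env *and* the results file
```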

View File: baselines/common/cmd_util.py

@@ -9,7 +9,7 @@ except ImportError:
     MPI = None

 import gym
-from gym.wrappers import FlattenDictWrapper
+from gym.wrappers import FlattenObservation, FilterObservation
 from baselines import logger
 from baselines.bench import Monitor
 from baselines.common import set_global_seeds
@@ -81,8 +81,7 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
         env = gym.make(env_id, **env_kwargs)

     if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
-        keys = env.observation_space.spaces.keys()
-        env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))
+        env = FlattenObservation(env)

     env.seed(seed + subrank if seed is not None else None)
     env = Monitor(env,
@@ -128,7 +127,7 @@ def make_robotics_env(env_id, seed, rank=0):
     """
     set_global_seeds(seed)
     env = gym.make(env_id)
-    env = FlattenDictWrapper(env, ['observation', 'desired_goal'])
+    env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal']))
     env = Monitor(
         env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
         info_keywords=('is_success',))
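gym 0.15 dropped FlattenDictWrapper, so the replacement composes FilterObservation (keep selected Dict keys) with FlattenObservation (concatenate them into one Box). A sketch of the robotics path above ('FetchReach-v1' is illustrative and needs the mujoco robotics envs installed):

```python
import gym
from gym.wrappers import FilterObservation, FlattenObservation

# Sketch of the replacement pattern: filter the Dict observation to the chosen
# keys, then flatten what remains into a single Box space.
env = gym.make('FetchReach-v1')
env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal']))
print(env.observation_space)  # a flat Box instead of a Dict
```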

View File: baselines/common/mpi_moments.py

@@ -12,8 +12,9 @@ def mpi_mean(x, axis=0, comm=None, keepdims=False):
     localsum = np.zeros(n+1, x.dtype)
     localsum[:n] = xsum.ravel()
     localsum[n] = x.shape[axis]
-    globalsum = np.zeros_like(localsum)
-    comm.Allreduce(localsum, globalsum, op=MPI.SUM)
+    # globalsum = np.zeros_like(localsum)
+    # comm.Allreduce(localsum, globalsum, op=MPI.SUM)
+    globalsum = comm.allreduce(localsum, op=MPI.SUM)
     return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n]

 def mpi_moments(x, axis=0, comm=None, keepdims=False):
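Per the commit message above, the buffer-based Allreduce fails with an element-count mismatch when ranks disagree on array size, while the lowercase allreduce pickles the object and sidesteps that check at some serialization cost. A side-by-side sketch, assuming an mpi4py launch under mpiexec:

```python
# Sketch contrasting the two mpi4py calls swapped above; run under mpiexec.
from mpi4py import MPI
import numpy as np

comm = MPI.COMM_WORLD
local = np.ones(4, dtype='float64') * comm.Get_rank()

# Uppercase Allreduce: raw-buffer protocol, fast, but every rank must
# contribute the same element count (the mismatch behind issue 998).
out = np.zeros_like(local)
comm.Allreduce(local, out, op=MPI.SUM)

# Lowercase allreduce: pickles the object instead; slower, but this is the
# trade the fix above accepts to avoid the strict buffer check.
out2 = comm.allreduce(local, op=MPI.SUM)
```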

View File: baselines/common/vec_env/subproc_vec_env.py

@@ -4,33 +4,36 @@ import numpy as np

 from .vec_env import VecEnv, CloudpickleWrapper, clear_mpi_env_vars

-def worker(remote, parent_remote, env_fn_wrapper):
+def worker(remote, parent_remote, env_fn_wrappers):
+    def step_env(env, action):
+        ob, reward, done, info = env.step(action)
+        if done:
+            ob = env.reset()
+        return ob, reward, done, info
+
     parent_remote.close()
-    env = env_fn_wrapper.x()
+    envs = [env_fn_wrapper() for env_fn_wrapper in env_fn_wrappers.x]
     try:
         while True:
             cmd, data = remote.recv()
             if cmd == 'step':
-                ob, reward, done, info = env.step(data)
-                if done:
-                    ob = env.reset()
-                remote.send((ob, reward, done, info))
+                remote.send([step_env(env, action) for env, action in zip(envs, data)])
             elif cmd == 'reset':
-                ob = env.reset()
-                remote.send(ob)
+                remote.send([env.reset() for env in envs])
             elif cmd == 'render':
-                remote.send(env.render(mode='rgb_array'))
+                remote.send([env.render(mode='rgb_array') for env in envs])
             elif cmd == 'close':
                 remote.close()
                 break
             elif cmd == 'get_spaces_spec':
-                remote.send((env.observation_space, env.action_space, env.spec))
+                remote.send(CloudpickleWrapper((envs[0].observation_space, envs[0].action_space, envs[0].spec)))
             else:
                 raise NotImplementedError
     except KeyboardInterrupt:
         print('SubprocVecEnv worker: got KeyboardInterrupt')
     finally:
-        env.close()
+        for env in envs:
+            env.close()

 class SubprocVecEnv(VecEnv):
@@ -38,17 +41,23 @@ class SubprocVecEnv(VecEnv):
     VecEnv that runs multiple environments in parallel in subproceses and communicates with them via pipes.
     Recommended to use when num_envs > 1 and step() can be a bottleneck.
     """
-    def __init__(self, env_fns, spaces=None, context='spawn'):
+    def __init__(self, env_fns, spaces=None, context='spawn', in_series=1):
         """
         Arguments:

         env_fns: iterable of callables - functions that create environments to run in subprocesses. Need to be cloud-pickleable
+        in_series: number of environments to run in series in a single process
+        (e.g. when len(env_fns) == 12 and in_series == 3, it will run 4 processes, each running 3 envs in series)
         """
         self.waiting = False
         self.closed = False
+        self.in_series = in_series
         nenvs = len(env_fns)
+        assert nenvs % in_series == 0, "Number of envs must be divisible by number of envs to run in series"
+        self.nremotes = nenvs // in_series
+        env_fns = np.array_split(env_fns, self.nremotes)
         ctx = mp.get_context(context)
-        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(nenvs)])
+        self.remotes, self.work_remotes = zip(*[ctx.Pipe() for _ in range(self.nremotes)])
         self.ps = [ctx.Process(target=worker, args=(work_remote, remote, CloudpickleWrapper(env_fn)))
                    for (work_remote, remote, env_fn) in zip(self.work_remotes, self.remotes, env_fns)]
         for p in self.ps:
@@ -59,12 +68,13 @@ class SubprocVecEnv(VecEnv):
             remote.close()

         self.remotes[0].send(('get_spaces_spec', None))
-        observation_space, action_space, self.spec = self.remotes[0].recv()
+        observation_space, action_space, self.spec = self.remotes[0].recv().x
         self.viewer = None
-        VecEnv.__init__(self, len(env_fns), observation_space, action_space)
+        VecEnv.__init__(self, nenvs, observation_space, action_space)

     def step_async(self, actions):
         self._assert_not_closed()
+        actions = np.array_split(actions, self.nremotes)
         for remote, action in zip(self.remotes, actions):
             remote.send(('step', action))
         self.waiting = True
@@ -72,6 +82,7 @@ class SubprocVecEnv(VecEnv):
     def step_wait(self):
         self._assert_not_closed()
         results = [remote.recv() for remote in self.remotes]
+        results = _flatten_list(results)
         self.waiting = False
         obs, rews, dones, infos = zip(*results)
         return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos
@@ -80,7 +91,9 @@ class SubprocVecEnv(VecEnv):
         self._assert_not_closed()
         for remote in self.remotes:
             remote.send(('reset', None))
-        return _flatten_obs([remote.recv() for remote in self.remotes])
+        obs = [remote.recv() for remote in self.remotes]
+        obs = _flatten_list(obs)
+        return _flatten_obs(obs)

     def close_extras(self):
         self.closed = True
@@ -97,6 +110,7 @@ class SubprocVecEnv(VecEnv):
         for pipe in self.remotes:
             pipe.send(('render', None))
         imgs = [pipe.recv() for pipe in self.remotes]
+        imgs = _flatten_list(imgs)
         return imgs

     def _assert_not_closed(self):
@@ -115,3 +129,10 @@ def _flatten_obs(obs):
         return {k: np.stack([o[k] for o in obs]) for k in keys}
     else:
         return np.stack(obs)
+
+def _flatten_list(l):
+    assert isinstance(l, (list, tuple))
+    assert len(l) > 0
+    assert all([len(l_) > 0 for l_ in l])
+
+    return [l__ for l_ in l for l__ in l_]
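The new in_series argument lets one worker process step several environments sequentially, so 12 env thunks can share 4 processes instead of spawning 12, as the docstring example above spells out. A usage sketch (the env id is illustrative; nenvs must divide evenly by in_series):

```python
import gym
from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv

# Sketch of the new option: 12 envs in 4 worker processes, 3 per process.
env_fns = [lambda: gym.make('CartPole-v1') for _ in range(12)]
venv = SubprocVecEnv(env_fns, in_series=3)   # 12 % 3 == 0 is required
obs = venv.reset()                           # still one row per env: (12, ...)
obs, rews, dones, infos = venv.step([venv.action_space.sample() for _ in range(12)])
venv.close()
```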

View File: baselines/common/vec_env/test_vec_env.py

@@ -67,6 +67,50 @@ def test_vec_env(klass, dtype): # pylint: disable=R0914
     assert_venvs_equal(env1, env2, num_steps=num_steps)

+@pytest.mark.parametrize('dtype', ('uint8', 'float32'))
+@pytest.mark.parametrize('num_envs_in_series', (3, 4, 6))
+def test_sync_sampling(dtype, num_envs_in_series):
+    """
+    Test that a SubprocVecEnv running with envs in series
+    outputs the same as DummyVecEnv.
+    """
+    num_envs = 12
+    num_steps = 100
+    shape = (3, 8)
+
+    def make_fn(seed):
+        """
+        Get an environment constructor with a seed.
+        """
+        return lambda: SimpleEnv(seed, shape, dtype)
+
+    fns = [make_fn(i) for i in range(num_envs)]
+    env1 = DummyVecEnv(fns)
+    env2 = SubprocVecEnv(fns, in_series=num_envs_in_series)
+    assert_venvs_equal(env1, env2, num_steps=num_steps)
+
+@pytest.mark.parametrize('dtype', ('uint8', 'float32'))
+@pytest.mark.parametrize('num_envs_in_series', (3, 4, 6))
+def test_sync_sampling_sanity(dtype, num_envs_in_series):
+    """
+    Test that a SubprocVecEnv running with envs in series
+    outputs the same as SubprocVecEnv without running in series.
+    """
+    num_envs = 12
+    num_steps = 100
+    shape = (3, 8)
+
+    def make_fn(seed):
+        """
+        Get an environment constructor with a seed.
+        """
+        return lambda: SimpleEnv(seed, shape, dtype)
+
+    fns = [make_fn(i) for i in range(num_envs)]
+    env1 = SubprocVecEnv(fns)
+    env2 = SubprocVecEnv(fns, in_series=num_envs_in_series)
+    assert_venvs_equal(env1, env2, num_steps=num_steps)
+
 class SimpleEnv(gym.Env):
     """
     An environment with a pre-determined observation space

View File: baselines/ddpg/ddpg.py

@@ -378,11 +378,6 @@ class DDPG(object):
             self.param_noise_stddev: self.param_noise.current_stddev,
         })

-        if MPI is not None:
-            mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
-        else:
-            mean_distance = distance
-
         if MPI is not None:
             mean_distance = MPI.COMM_WORLD.allreduce(distance, op=MPI.SUM) / MPI.COMM_WORLD.Get_size()
         else:

View File: baselines/deepq/build_graph.py

@@ -13,7 +13,7 @@ The functions in this file can are used to create the following functions:
     stochastic: bool
         if set to False all the actions are always deterministic (default False)
     update_eps_ph: float
-        update epsilon a new value, if negative not update happens
+        update epsilon a new value, if negative no update happens
         (default: no update)

     Returns

View File: baselines/deepq/deepq.py

@@ -142,9 +142,8 @@ def learn(env,
         final value of random action probability
     train_freq: int
         update the model every `train_freq` steps.
-        set to None to disable printing
     batch_size: int
-        size of a batched sampled from replay buffer for training
+        size of a batch sampled from replay buffer for training
     print_freq: int
         how often to print out training progress
         set to None to disable printing

View File: baselines/gail/behavior_clone.py

@@ -23,7 +23,7 @@ from baselines.gail.dataset.mujoco_dset import Mujoco_Dset

 def argsparser():
     parser = argparse.ArgumentParser("Tensorflow Implementation of Behavior Cloning")
-    parser.add_argument('--env_id', help='environment ID', default='Hopper-v1')
+    parser.add_argument('--env_id', help='environment ID', default='Hopper-v2')
     parser.add_argument('--seed', help='RNG seed', type=int, default=0)
     parser.add_argument('--expert_path', type=str, default='data/deterministic.trpo.Hopper.0.00.npz')
     parser.add_argument('--checkpoint_dir', help='the directory to save model', default='checkpoint')
@@ -73,7 +73,7 @@ def learn(env, policy_func, dataset, optim_batch_size=128, max_iters=1e4,
         savedir_fname = tempfile.TemporaryDirectory().name
     else:
         savedir_fname = osp.join(ckpt_dir, task_name)
-    U.save_state(savedir_fname, var_list=pi.get_variables())
+    U.save_variables(savedir_fname, variables=pi.get_variables())
     return savedir_fname

View File: baselines/gail/dataset/mujoco_dset.py

@@ -77,7 +77,7 @@ class Mujoco_Dset(object):
         self.log_info()

     def log_info(self):
-        logger.log("Total trajectorues: %d" % self.num_traj)
+        logger.log("Total trajectories: %d" % self.num_traj)
         logger.log("Total transitions: %d" % self.num_transition)
         logger.log("Average returns: %f" % self.avg_ret)
         logger.log("Std for returns: %f" % self.std_ret)

View File: baselines/gail/run_mujoco.py

@@ -165,7 +165,7 @@ def runner(env, policy_func, load_model_path, timesteps_per_batch, number_trajs,
     U.initialize()
     # Prepare for rollouts
     # ----------------------------------------
-    U.load_state(load_model_path)
+    U.load_variables(load_model_path)

     obs_list = []
     acs_list = []
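Both GAIL hunks (here and in behavior_clone above) migrate from the deprecated save_state/load_state helpers to save_variables/load_variables, which default to every variable in the current graph. A minimal sketch of the migrated calls, assuming a TF1 session (the toy variable and path are illustrative):

```python
import tensorflow as tf
from baselines.common import tf_util as U

# Sketch of the tf_util migration in the two GAIL diffs; names are illustrative.
U.make_session(num_cpu=1, make_default=True)
w = tf.Variable(1.0, name='w')
U.initialize()

U.save_variables('/tmp/bc_model')  # replaces U.save_state(path, var_list=...)
U.load_variables('/tmp/bc_model')  # replaces U.load_state(path)
```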

View File: baselines/her/rollout.py

@@ -15,8 +15,7 @@ class RolloutWorker:
         """Rollout worker generates experience by interacting with one or many environments.

         Args:
-            make_env (function): a factory function that creates a new instance of the environment
-                when called
+            venv: vectorized gym environments.
             policy (object): the policy that is used to act
             dims (dict of ints): the dimensions for observations (o), goals (g), and actions (u)
             logger (object): the logger that is used by the rollout worker

View File: baselines/run.py

@@ -32,7 +32,7 @@ except ImportError:

 _game_envs = defaultdict(set)
 for env in gym.envs.registry.all():
     # TODO: solve this with regexes
-    env_type = env._entry_point.split(':')[0].split('.')[-1]
+    env_type = env.entry_point.split(':')[0].split('.')[-1]
     _game_envs[env_type].add(env.id)

 # reading benchmark names directly from retro requires
@@ -126,7 +126,7 @@ def get_env_type(args):
         # Re-parse the gym registry, since we could have new envs since last time.
         for env in gym.envs.registry.all():
-            env_type = env._entry_point.split(':')[0].split('.')[-1]
+            env_type = env.entry_point.split(':')[0].split('.')[-1]
             _game_envs[env_type].add(env.id)  # This is a set so add is idempotent

         if env_id in _game_envs.keys():
@@ -226,7 +226,7 @@ def main(args):
         state = model.initial_state if hasattr(model, 'initial_state') else None
         dones = np.zeros((1,))

-        episode_rew = 0
+        episode_rew = np.zeros(env.num_envs) if isinstance(env, VecEnv) else np.zeros(1)
         while True:
             if state is not None:
                 actions, _, state, _ = model.step(obs,S=state, M=dones)
@@ -234,13 +234,13 @@ def main(args):
                 actions, _, _, _ = model.step(obs)

             obs, rew, done, _ = env.step(actions)
-            episode_rew += rew[0] if isinstance(env, VecEnv) else rew
+            episode_rew += rew
             env.render()
-            done = done.any() if isinstance(done, np.ndarray) else done
-            if done:
-                print('episode_rew={}'.format(episode_rew))
-                episode_rew = 0
-                obs = env.reset()
+            done_any = done.any() if isinstance(done, np.ndarray) else done
+            if done_any:
+                for i in np.nonzero(done)[0]:
+                    print('episode_rew={}'.format(episode_rew[i]))
+                    episode_rew[i] = 0

         env.close()
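The play-loop fix keeps one accumulator per vectorized env and resets only the finished entries, leaning on the subprocess workers' auto-reset instead of an extra env.reset() (the RuntimeError from #910). A self-contained sketch of the bookkeeping with stand-in step results:

```python
import numpy as np

# Self-contained sketch of the per-env accounting above; rew and done stand in
# for one VecEnv.step() result across four envs.
episode_rew = np.zeros(4)
rew = np.array([1.0, 0.5, 0.0, 2.0])
done = np.array([False, True, False, True])

episode_rew += rew
for i in np.nonzero(done)[0]:
    print('episode_rew={}'.format(episode_rew[i]))
    episode_rew[i] = 0  # no env.reset() here: the workers already auto-reset
```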

View File: setup.py

@@ -31,7 +31,7 @@ setup(name='baselines',
       packages=[package for package in find_packages()
                 if package.startswith('baselines')],
       install_requires=[
-          'gym>=0.10.0, <1.0.0',
+          'gym>=0.15.4, <0.16.0',
           'scipy',
           'tqdm',
           'joblib',
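The tightened pin matches the wrapper API the diffs above assume (FlattenObservation and FilterObservation exist as of gym 0.15). A sketch checking an installed gym against the new range:

```python
# Sketch: verify the installed gym against the pin above.
from distutils.version import LooseVersion
import gym

assert LooseVersion('0.15.4') <= LooseVersion(gym.__version__) < LooseVersion('0.16.0'), \
    'this revision expects 0.15.4 <= gym < 0.16.0'
```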