Compare commits

...

11 Commits

Author SHA1 Message Date
Peter Zhokhov
8d9e20fec3 narrow down gym version to 0.15.4 <= gym < 0.16.0 2019-11-10 11:08:59 -08:00
Peter Zhokhov
fc23c78c77 fix imports 2019-11-08 15:39:57 -08:00
Peter Zhokhov
25f750d84f update to use latest version of gym 2019-11-08 15:31:40 -08:00
Tomasz Wrona
391811d98c SubprocVecEnv uses CloudpickleWrapper to send specs (#1028) 2019-11-08 15:23:49 -08:00
Yen-Chen Lin
665b888eeb Fix behavior cloning due to API changes (#1014) 2019-10-25 15:44:43 -07:00
Christopher Hesse
f40a477a17 fix tf2 branch name 2019-10-25 15:27:30 -07:00
johannespitz
c6144bdb6a Fix RuntimeError (#910) (#1015)
* Update the commands to install Tensorflow

The current 'tensorflow' package is for Tensorflow 2, which is not supported by the master branch of baselines.

* Update command to install Tensorflow 1.14

* Fix RuntimeError (#910)

 - Removed interfering calls to env.reset() in play mode.
   (Note that the worker in the subprocess is calling env.reset() already)

 - Fixed the printed reward when running multiple envs in play mode.
2019-10-25 15:24:41 -07:00
Peter Zhokhov
adba88b218 add quote marks to tensorflow < 2 to avoid bash logic 2019-10-11 17:13:43 -07:00
Peter Zhokhov
bfbc3bae14 update status, fix the tensorflow version in the build 2019-10-11 15:23:14 -07:00
Haiyang Chen
f703776c91 fix a bug in acer saving and loading model (#990) 2019-09-27 15:39:41 -07:00
pzhokhov
53797293e5 use allreduce instead of Allreduce (sends pickled data instead of floats) - probably affects performance somewhat, but avoids element number mismatch. Fixes #998 (#1000) 2019-09-27 14:45:31 -07:00
10 changed files with 32 additions and 28 deletions

View File

@@ -11,7 +11,7 @@ WORKDIR $CODE_DIR/baselines
# Clean up pycache and pyc files # Clean up pycache and pyc files
RUN rm -rf __pycache__ && \ RUN rm -rf __pycache__ && \
find . -name "*.pyc" -delete && \ find . -name "*.pyc" -delete && \
pip install tensorflow && \ pip install 'tensorflow < 2' && \
pip install -e .[test] pip install -e .[test]

View File

@@ -1,4 +1,4 @@
**Status:** Active (under active development, breaking changes may occur) **Status:** Maintenance (expect bug fixes and minor updates)
<img src="data/logo.jpg" width=25% align="right" /> [![Build status](https://travis-ci.org/openai/baselines.svg?branch=master)](https://travis-ci.org/openai/baselines) <img src="data/logo.jpg" width=25% align="right" /> [![Build status](https://travis-ci.org/openai/baselines.svg?branch=master)](https://travis-ci.org/openai/baselines)
@@ -40,7 +40,7 @@ More thorough tutorial on virtualenvs and options can be found [here](https://vi
## Tensorflow versions ## Tensorflow versions
The master branch supports Tensorflow from version 1.4 to 1.14. For Tensorflow 2.0 support, please use tf-2 branch. The master branch supports Tensorflow from version 1.4 to 1.14. For Tensorflow 2.0 support, please use tf2 branch.
## Installation ## Installation
- Clone the repo and cd into it: - Clone the repo and cd into it:
@@ -48,15 +48,15 @@ The master branch supports Tensorflow from version 1.4 to 1.14. For Tensorflow 2
git clone https://github.com/openai/baselines.git git clone https://github.com/openai/baselines.git
cd baselines cd baselines
``` ```
- If you don't have TensorFlow installed already, install your favourite flavor of TensorFlow. In most cases, - If you don't have TensorFlow installed already, install your favourite flavor of TensorFlow. In most cases, you may use
```bash ```bash
pip install tensorflow-gpu # if you have a CUDA-compatible gpu and proper drivers pip install tensorflow-gpu==1.14 # if you have a CUDA-compatible gpu and proper drivers
``` ```
or or
```bash ```bash
pip install tensorflow pip install tensorflow==1.14
``` ```
should be sufficient. Refer to [TensorFlow installation guide](https://www.tensorflow.org/install/) to install Tensorflow 1.14, which is the latest version of Tensorflow supported by the master branch. Refer to [TensorFlow installation guide](https://www.tensorflow.org/install/)
for more details. for more details.
- Install baselines package - Install baselines package

View File

@@ -6,7 +6,7 @@ from baselines import logger
from baselines.common import set_global_seeds from baselines.common import set_global_seeds
from baselines.common.policies import build_policy from baselines.common.policies import build_policy
from baselines.common.tf_util import get_session, save_variables from baselines.common.tf_util import get_session, save_variables, load_variables
from baselines.common.vec_env.vec_frame_stack import VecFrameStack from baselines.common.vec_env.vec_frame_stack import VecFrameStack
from baselines.a2c.utils import batch_to_seq, seq_to_batch from baselines.a2c.utils import batch_to_seq, seq_to_batch
@@ -216,7 +216,8 @@ class Model(object):
self.train = train self.train = train
self.save = functools.partial(save_variables, sess=sess, variables=params) self.save = functools.partial(save_variables, sess=sess)
self.load = functools.partial(load_variables, sess=sess)
self.train_model = train_model self.train_model = train_model
self.step_model = step_model self.step_model = step_model
self._step = _step self._step = _step
@@ -358,6 +359,9 @@ def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=
total_timesteps=total_timesteps, lrschedule=lrschedule, c=c, total_timesteps=total_timesteps, lrschedule=lrschedule, c=c,
trust_region=trust_region, alpha=alpha, delta=delta) trust_region=trust_region, alpha=alpha, delta=delta)
if load_path is not None:
model.load(load_path)
runner = Runner(env=env, model=model, nsteps=nsteps) runner = Runner(env=env, model=model, nsteps=nsteps)
if replay_ratio > 0: if replay_ratio > 0:
buffer = Buffer(env=env, nsteps=nsteps, size=buffer_size) buffer = Buffer(env=env, nsteps=nsteps, size=buffer_size)

View File

@@ -9,7 +9,7 @@ except ImportError:
MPI = None MPI = None
import gym import gym
from gym.wrappers import FlattenDictWrapper from gym.wrappers import FlattenObservation, FilterObservation
from baselines import logger from baselines import logger
from baselines.bench import Monitor from baselines.bench import Monitor
from baselines.common import set_global_seeds from baselines.common import set_global_seeds
@@ -81,8 +81,7 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
env = gym.make(env_id, **env_kwargs) env = gym.make(env_id, **env_kwargs)
if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict): if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
keys = env.observation_space.spaces.keys() env = FlattenObservation(env)
env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))
env.seed(seed + subrank if seed is not None else None) env.seed(seed + subrank if seed is not None else None)
env = Monitor(env, env = Monitor(env,
@@ -128,7 +127,7 @@ def make_robotics_env(env_id, seed, rank=0):
""" """
set_global_seeds(seed) set_global_seeds(seed)
env = gym.make(env_id) env = gym.make(env_id)
env = FlattenDictWrapper(env, ['observation', 'desired_goal']) env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal']))
env = Monitor( env = Monitor(
env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)), env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
info_keywords=('is_success',)) info_keywords=('is_success',))

View File

@@ -12,8 +12,9 @@ def mpi_mean(x, axis=0, comm=None, keepdims=False):
localsum = np.zeros(n+1, x.dtype) localsum = np.zeros(n+1, x.dtype)
localsum[:n] = xsum.ravel() localsum[:n] = xsum.ravel()
localsum[n] = x.shape[axis] localsum[n] = x.shape[axis]
globalsum = np.zeros_like(localsum) # globalsum = np.zeros_like(localsum)
comm.Allreduce(localsum, globalsum, op=MPI.SUM) # comm.Allreduce(localsum, globalsum, op=MPI.SUM)
globalsum = comm.allreduce(localsum, op=MPI.SUM)
return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n] return globalsum[:n].reshape(xsum.shape) / globalsum[n], globalsum[n]
def mpi_moments(x, axis=0, comm=None, keepdims=False): def mpi_moments(x, axis=0, comm=None, keepdims=False):

View File

@@ -26,7 +26,7 @@ def worker(remote, parent_remote, env_fn_wrappers):
remote.close() remote.close()
break break
elif cmd == 'get_spaces_spec': elif cmd == 'get_spaces_spec':
remote.send((envs[0].observation_space, envs[0].action_space, envs[0].spec)) remote.send(CloudpickleWrapper((envs[0].observation_space, envs[0].action_space, envs[0].spec)))
else: else:
raise NotImplementedError raise NotImplementedError
except KeyboardInterrupt: except KeyboardInterrupt:
@@ -68,7 +68,7 @@ class SubprocVecEnv(VecEnv):
remote.close() remote.close()
self.remotes[0].send(('get_spaces_spec', None)) self.remotes[0].send(('get_spaces_spec', None))
observation_space, action_space, self.spec = self.remotes[0].recv() observation_space, action_space, self.spec = self.remotes[0].recv().x
self.viewer = None self.viewer = None
VecEnv.__init__(self, nenvs, observation_space, action_space) VecEnv.__init__(self, nenvs, observation_space, action_space)

View File

@@ -23,7 +23,7 @@ from baselines.gail.dataset.mujoco_dset import Mujoco_Dset
def argsparser(): def argsparser():
parser = argparse.ArgumentParser("Tensorflow Implementation of Behavior Cloning") parser = argparse.ArgumentParser("Tensorflow Implementation of Behavior Cloning")
parser.add_argument('--env_id', help='environment ID', default='Hopper-v1') parser.add_argument('--env_id', help='environment ID', default='Hopper-v2')
parser.add_argument('--seed', help='RNG seed', type=int, default=0) parser.add_argument('--seed', help='RNG seed', type=int, default=0)
parser.add_argument('--expert_path', type=str, default='data/deterministic.trpo.Hopper.0.00.npz') parser.add_argument('--expert_path', type=str, default='data/deterministic.trpo.Hopper.0.00.npz')
parser.add_argument('--checkpoint_dir', help='the directory to save model', default='checkpoint') parser.add_argument('--checkpoint_dir', help='the directory to save model', default='checkpoint')
@@ -73,7 +73,7 @@ def learn(env, policy_func, dataset, optim_batch_size=128, max_iters=1e4,
savedir_fname = tempfile.TemporaryDirectory().name savedir_fname = tempfile.TemporaryDirectory().name
else: else:
savedir_fname = osp.join(ckpt_dir, task_name) savedir_fname = osp.join(ckpt_dir, task_name)
U.save_state(savedir_fname, var_list=pi.get_variables()) U.save_variables(savedir_fname, variables=pi.get_variables())
return savedir_fname return savedir_fname

View File

@@ -165,7 +165,7 @@ def runner(env, policy_func, load_model_path, timesteps_per_batch, number_trajs,
U.initialize() U.initialize()
# Prepare for rollouts # Prepare for rollouts
# ---------------------------------------- # ----------------------------------------
U.load_state(load_model_path) U.load_variables(load_model_path)
obs_list = [] obs_list = []
acs_list = [] acs_list = []

View File

@@ -226,7 +226,7 @@ def main(args):
state = model.initial_state if hasattr(model, 'initial_state') else None state = model.initial_state if hasattr(model, 'initial_state') else None
dones = np.zeros((1,)) dones = np.zeros((1,))
episode_rew = 0 episode_rew = np.zeros(env.num_envs) if isinstance(env, VecEnv) else np.zeros(1)
while True: while True:
if state is not None: if state is not None:
actions, _, state, _ = model.step(obs,S=state, M=dones) actions, _, state, _ = model.step(obs,S=state, M=dones)
@@ -234,13 +234,13 @@ def main(args):
actions, _, _, _ = model.step(obs) actions, _, _, _ = model.step(obs)
obs, rew, done, _ = env.step(actions) obs, rew, done, _ = env.step(actions)
episode_rew += rew[0] if isinstance(env, VecEnv) else rew episode_rew += rew
env.render() env.render()
done = done.any() if isinstance(done, np.ndarray) else done done_any = done.any() if isinstance(done, np.ndarray) else done
if done: if done_any:
print('episode_rew={}'.format(episode_rew)) for i in np.nonzero(done)[0]:
episode_rew = 0 print('episode_rew={}'.format(episode_rew[i]))
obs = env.reset() episode_rew[i] = 0
env.close() env.close()

View File

@@ -31,7 +31,7 @@ setup(name='baselines',
packages=[package for package in find_packages() packages=[package for package in find_packages()
if package.startswith('baselines')], if package.startswith('baselines')],
install_requires=[ install_requires=[
'gym>=0.10.0, <1.0.0', 'gym>=0.15.4, <0.16.0',
'scipy', 'scipy',
'tqdm', 'tqdm',
'joblib', 'joblib',