Compare commits


1 commit

10 changed files with 26 additions and 30 deletions

View File

@@ -11,7 +11,7 @@ WORKDIR $CODE_DIR/baselines
 # Clean up pycache and pyc files
 RUN rm -rf __pycache__ && \
     find . -name "*.pyc" -delete && \
-    pip install 'tensorflow < 2' && \
+    pip install tensorflow && \
     pip install -e .[test]

View File

@@ -1,4 +1,4 @@
-**Status:** Maintenance (expect bug fixes and minor updates)
+**Status:** Active (under active development, breaking changes may occur)
 <img src="data/logo.jpg" width=25% align="right" /> [![Build status](https://travis-ci.org/openai/baselines.svg?branch=master)](https://travis-ci.org/openai/baselines)
@@ -40,7 +40,7 @@ More thorough tutorial on virtualenvs and options can be found [here](https://vi
 ## Tensorflow versions
-The master branch supports Tensorflow from version 1.4 to 1.14. For Tensorflow 2.0 support, please use tf2 branch.
+The master branch supports Tensorflow from version 1.4 to 1.14. For Tensorflow 2.0 support, please use tf-2 branch.
 ## Installation
 - Clone the repo and cd into it:
@@ -48,15 +48,15 @@ The master branch supports Tensorflow from version 1.4 to 1.14. For Tensorflow 2
     git clone https://github.com/openai/baselines.git
     cd baselines
     ```
-- If you don't have TensorFlow installed already, install your favourite flavor of TensorFlow. In most cases, you may use
+- If you don't have TensorFlow installed already, install your favourite flavor of TensorFlow. In most cases,
     ```bash
-    pip install tensorflow-gpu==1.14 # if you have a CUDA-compatible gpu and proper drivers
+    pip install tensorflow-gpu # if you have a CUDA-compatible gpu and proper drivers
     ```
     or
     ```bash
-    pip install tensorflow==1.14
+    pip install tensorflow
     ```
-    to install Tensorflow 1.14, which is the latest version of Tensorflow supported by the master branch. Refer to [TensorFlow installation guide](https://www.tensorflow.org/install/)
+    should be sufficient. Refer to [TensorFlow installation guide](https://www.tensorflow.org/install/)
     for more details.
 - Install baselines package
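
For a quick sanity check alongside these installation hunks (a hedged sketch, not part of the commit): verify that whatever TensorFlow ends up installed sits inside the 1.4 to 1.14 window the master branch is said to support.

```python
# Hedged sketch, not from this diff: confirm the installed TensorFlow falls in
# the 1.4 <= version < 2.0 range that baselines' master branch targets.
from distutils.version import LooseVersion

import tensorflow as tf

version = LooseVersion(tf.__version__)
if not (LooseVersion('1.4.0') <= version < LooseVersion('2.0.0')):
    raise RuntimeError(
        'baselines master expects TensorFlow 1.4-1.14, found {}'.format(tf.__version__))
```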

View File

@@ -6,7 +6,7 @@ from baselines import logger
 from baselines.common import set_global_seeds
 from baselines.common.policies import build_policy
-from baselines.common.tf_util import get_session, save_variables, load_variables
+from baselines.common.tf_util import get_session, save_variables
 from baselines.common.vec_env.vec_frame_stack import VecFrameStack
 from baselines.a2c.utils import batch_to_seq, seq_to_batch
@@ -216,8 +216,7 @@ class Model(object):
         self.train = train
-        self.save = functools.partial(save_variables, sess=sess)
-        self.load = functools.partial(load_variables, sess=sess)
+        self.save = functools.partial(save_variables, sess=sess, variables=params)
         self.train_model = train_model
         self.step_model = step_model
         self._step = _step
@@ -359,9 +358,6 @@ def learn(network, env, seed=None, nsteps=20, total_timesteps=int(80e6), q_coef=
                   total_timesteps=total_timesteps, lrschedule=lrschedule, c=c,
                   trust_region=trust_region, alpha=alpha, delta=delta)
-    if load_path is not None:
-        model.load(load_path)
     runner = Runner(env=env, model=model, nsteps=nsteps)
     if replay_ratio > 0:
         buffer = Buffer(env=env, nsteps=nsteps, size=buffer_size)
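
For context on the Model/learn hunks above, a hedged sketch (not code from this commit) of the checkpoint pattern on the '-' side: baselines.common.tf_util's save_variables/load_variables get bound to the session with functools.partial, so later callers only pass a path. The variable w and the /tmp path below are illustrative only.

```python
# Hedged sketch of the functools.partial checkpoint pattern the hunks touch.
# Assumes a TF1-style baselines tree; 'w' and the path are made up here.
import functools

import tensorflow as tf
from baselines.common.tf_util import get_session, save_variables, load_variables

sess = get_session()
w = tf.get_variable('w', shape=(2, 2))
sess.run(tf.global_variables_initializer())

save = functools.partial(save_variables, sess=sess)   # as on the '-' side: no variable list pinned
load = functools.partial(load_variables, sess=sess)

save('/tmp/acer_ckpt')   # serialize current variable values
load('/tmp/acer_ckpt')   # restore them into the same graph
```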

View File

@@ -77,7 +77,6 @@ class Monitor(Wrapper):
         self.total_steps += 1
     def close(self):
-        super(Monitor, self).close()
         if self.f is not None:
             self.f.close()

View File

@@ -9,7 +9,7 @@ except ImportError:
     MPI = None
 import gym
-from gym.wrappers import FlattenObservation, FilterObservation
+from gym.wrappers import FlattenDictWrapper
 from baselines import logger
 from baselines.bench import Monitor
 from baselines.common import set_global_seeds
@@ -81,7 +81,8 @@ def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.
     env = gym.make(env_id, **env_kwargs)
     if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
-        env = FlattenObservation(env)
+        keys = env.observation_space.spaces.keys()
+        env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))
     env.seed(seed + subrank if seed is not None else None)
     env = Monitor(env,
@@ -127,7 +128,7 @@ def make_robotics_env(env_id, seed, rank=0):
     """
     set_global_seeds(seed)
     env = gym.make(env_id)
-    env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal']))
+    env = FlattenDictWrapper(env, ['observation', 'desired_goal'])
     env = Monitor(
         env, logger.get_dir() and os.path.join(logger.get_dir(), str(rank)),
         info_keywords=('is_success',))
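
A hedged sketch of the two gym APIs the make_env/make_robotics_env hunks switch between: newer gym drops FlattenDictWrapper in favour of composing FilterObservation with FlattenObservation. FetchReach-v1 is only an illustrative dict-observation environment, not something this diff specifies.

```python
# Hedged sketch: flatten a Dict observation space with whichever wrapper the
# installed gym provides; the environment name is just an example.
import gym

env = gym.make('FetchReach-v1')   # observation_space is a gym.spaces.Dict

try:
    # newer gym (the '-' side of the hunks)
    from gym.wrappers import FilterObservation, FlattenObservation
    env = FlattenObservation(FilterObservation(env, ['observation', 'desired_goal']))
except ImportError:
    # older gym (the '+' side of the hunks)
    from gym.wrappers import FlattenDictWrapper
    env = FlattenDictWrapper(env, ['observation', 'desired_goal'])

print(env.observation_space)      # flattened to a single Box
```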

View File

@@ -26,7 +26,7 @@ def worker(remote, parent_remote, env_fn_wrappers):
                 remote.close()
                 break
             elif cmd == 'get_spaces_spec':
-                remote.send(CloudpickleWrapper((envs[0].observation_space, envs[0].action_space, envs[0].spec)))
+                remote.send((envs[0].observation_space, envs[0].action_space, envs[0].spec))
             else:
                 raise NotImplementedError
     except KeyboardInterrupt:
@@ -68,7 +68,7 @@ class SubprocVecEnv(VecEnv):
             remote.close()
         self.remotes[0].send(('get_spaces_spec', None))
-        observation_space, action_space, self.spec = self.remotes[0].recv().x
+        observation_space, action_space, self.spec = self.remotes[0].recv()
         self.viewer = None
         VecEnv.__init__(self, nenvs, observation_space, action_space)
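
For readers of the SubprocVecEnv hunks, a hedged sketch (not from this commit) of what the CloudpickleWrapper on the '-' side buys: it serializes its payload with cloudpickle, so objects that plain multiprocessing pickling may reject (for example lambdas attached to an EnvSpec) still survive the pipe, and the parent unwraps them through the .x attribute, matching the .recv().x above. The lambda payload below is made up for illustration.

```python
# Hedged sketch of the CloudpickleWrapper round trip used by the worker/parent pair.
import pickle

from baselines.common.vec_env.vec_env import CloudpickleWrapper

payload = CloudpickleWrapper((lambda: 'hard to pickle normally', {'id': 'Env-v0'}))
wire_bytes = pickle.dumps(payload)      # roughly what Pipe.send() does
fn, meta = pickle.loads(wire_bytes).x   # parent side: .recv().x
print(fn(), meta)
```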

View File

@@ -23,7 +23,7 @@ from baselines.gail.dataset.mujoco_dset import Mujoco_Dset
 def argsparser():
     parser = argparse.ArgumentParser("Tensorflow Implementation of Behavior Cloning")
-    parser.add_argument('--env_id', help='environment ID', default='Hopper-v2')
+    parser.add_argument('--env_id', help='environment ID', default='Hopper-v1')
     parser.add_argument('--seed', help='RNG seed', type=int, default=0)
     parser.add_argument('--expert_path', type=str, default='data/deterministic.trpo.Hopper.0.00.npz')
     parser.add_argument('--checkpoint_dir', help='the directory to save model', default='checkpoint')
@@ -73,7 +73,7 @@ def learn(env, policy_func, dataset, optim_batch_size=128, max_iters=1e4,
         savedir_fname = tempfile.TemporaryDirectory().name
     else:
         savedir_fname = osp.join(ckpt_dir, task_name)
-    U.save_variables(savedir_fname, variables=pi.get_variables())
+    U.save_state(savedir_fname, var_list=pi.get_variables())
     return savedir_fname

View File

@@ -165,7 +165,7 @@ def runner(env, policy_func, load_model_path, timesteps_per_batch, number_trajs,
     U.initialize()
     # Prepare for rollouts
     # ----------------------------------------
-    U.load_variables(load_model_path)
+    U.load_state(load_model_path)
     obs_list = []
     acs_list = []
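
The two gail hunks rename the same pair of checkpoint helpers: save_variables/load_variables on the '-' side versus the older save_state/load_state on the '+' side. A hedged, purely illustrative shim can resolve whichever names the installed tree exposes:

```python
# Hedged sketch: pick the baselines.common.tf_util checkpoint helpers by name,
# since older trees ship save_state/load_state and newer ones save_variables/load_variables.
import baselines.common.tf_util as U

save_fn = getattr(U, 'save_variables', None) or getattr(U, 'save_state')
load_fn = getattr(U, 'load_variables', None) or getattr(U, 'load_state')
print('using', save_fn.__name__, 'and', load_fn.__name__)
```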

View File

@@ -226,7 +226,7 @@ def main(args):
         state = model.initial_state if hasattr(model, 'initial_state') else None
         dones = np.zeros((1,))
-        episode_rew = np.zeros(env.num_envs) if isinstance(env, VecEnv) else np.zeros(1)
+        episode_rew = 0
         while True:
             if state is not None:
                 actions, _, state, _ = model.step(obs,S=state, M=dones)
@@ -234,13 +234,13 @@
                 actions, _, _, _ = model.step(obs)
             obs, rew, done, _ = env.step(actions)
-            episode_rew += rew
+            episode_rew += rew[0] if isinstance(env, VecEnv) else rew
             env.render()
-            done_any = done.any() if isinstance(done, np.ndarray) else done
-            if done_any:
-                for i in np.nonzero(done)[0]:
-                    print('episode_rew={}'.format(episode_rew[i]))
-                    episode_rew[i] = 0
+            done = done.any() if isinstance(done, np.ndarray) else done
+            if done:
+                print('episode_rew={}'.format(episode_rew))
+                episode_rew = 0
+                obs = env.reset()
     env.close()
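
The '-' side of these main() hunks keeps one running return per vectorized environment instead of a single scalar, since a VecEnv returns reward and done arrays. A hedged, standalone illustration of that bookkeeping (CartPole-v1 and the random actions are stand-ins, not part of the diff):

```python
# Hedged sketch of per-environment episode-return tracking with a VecEnv.
import numpy as np
import gym
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv

env = DummyVecEnv([lambda: gym.make('CartPole-v1') for _ in range(4)])
obs = env.reset()
episode_rew = np.zeros(env.num_envs)

for _ in range(200):
    actions = np.array([env.action_space.sample() for _ in range(env.num_envs)])
    obs, rew, done, _ = env.step(actions)      # rew and done are arrays of length num_envs
    episode_rew += rew
    for i in np.nonzero(done)[0]:              # the VecEnv auto-resets finished envs
        print('episode_rew={}'.format(episode_rew[i]))
        episode_rew[i] = 0
```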

View File

@@ -31,7 +31,7 @@ setup(name='baselines',
       packages=[package for package in find_packages()
                 if package.startswith('baselines')],
       install_requires=[
-          'gym>=0.15.4, <0.16.0',
+          'gym>=0.10.0, <1.0.0',
           'scipy',
           'tqdm',
           'joblib',
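
As a hedged aside on the setup() requirements hunk (not part of the commit), pkg_resources can check whether the installed gym satisfies the tighter pin on the '-' side:

```python
# Hedged sketch: raises VersionConflict or DistributionNotFound if the installed
# gym does not satisfy the stricter '-' side requirement.
import pkg_resources

pkg_resources.require('gym>=0.15.4, <0.16.0')
```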