Compare commits
10 Commits
peterz_383
...
gdb
Author | SHA1 | Date | |
---|---|---|---|
|
6bbc4635e6 | ||
|
5b41c926c7 | ||
|
ab02fae71d | ||
|
b55eda1dde | ||
|
57e05eb420 | ||
|
01ab1d8ef7 | ||
|
73683435ff | ||
|
4d0746b957 | ||
|
5115707ce9 | ||
|
6c44fb28fe |
@@ -37,9 +37,6 @@ class Runner(AbstractEnvRunner):
|
||||
obs, rewards, dones, _ = self.env.step(actions)
|
||||
self.states = states
|
||||
self.dones = dones
|
||||
for n, done in enumerate(dones):
|
||||
if done:
|
||||
self.obs[n] = self.obs[n]*0
|
||||
self.obs = obs
|
||||
mb_rewards.append(rewards)
|
||||
mb_dones.append(self.dones)
|
||||
|
@@ -75,8 +75,8 @@ class Model(object):
|
||||
train_ob_placeholder = tf.placeholder(dtype=ob_space.dtype, shape=(nenvs*(nsteps+1),) + ob_space.shape)
|
||||
with tf.variable_scope('acer_model', reuse=tf.AUTO_REUSE):
|
||||
|
||||
step_model = policy(observ_placeholder=step_ob_placeholder, sess=sess)
|
||||
train_model = policy(observ_placeholder=train_ob_placeholder, sess=sess)
|
||||
step_model = policy(nbatch=nenvs, nsteps=1, observ_placeholder=step_ob_placeholder, sess=sess)
|
||||
train_model = policy(nbatch=nbatch, nsteps=nsteps, observ_placeholder=train_ob_placeholder, sess=sess)
|
||||
|
||||
|
||||
params = find_trainable_variables("acer_model")
|
||||
@@ -94,7 +94,7 @@ class Model(object):
|
||||
return v
|
||||
|
||||
with tf.variable_scope("acer_model", custom_getter=custom_getter, reuse=True):
|
||||
polyak_model = policy(observ_placeholder=train_ob_placeholder, sess=sess)
|
||||
polyak_model = policy(nbatch=nbatch, nsteps=nsteps, observ_placeholder=train_ob_placeholder, sess=sess)
|
||||
|
||||
# Notation: (var) = batch variable, (var)s = seqeuence variable, (var)_i = variable index by action at step i
|
||||
|
||||
|
@@ -156,9 +156,10 @@ register_benchmark({
|
||||
|
||||
# HER DDPG
|
||||
|
||||
_fetch_tasks = ['FetchReach-v1', 'FetchPush-v1', 'FetchSlide-v1']
|
||||
register_benchmark({
|
||||
'name': 'HerDdpg',
|
||||
'description': 'Smoke-test only benchmark of HER',
|
||||
'tasks': [{'trials': 1, 'env_id': 'FetchReach-v1'}]
|
||||
'name': 'Fetch1M',
|
||||
'description': 'Fetch* benchmarks for 1M timesteps',
|
||||
'tasks': [{'trials': 6, 'env_id': env_id, 'num_timesteps': int(1e6)} for env_id in _fetch_tasks]
|
||||
})
|
||||
|
||||
|
@@ -18,33 +18,50 @@ from baselines.common.vec_env.subproc_vec_env import SubprocVecEnv
|
||||
from baselines.common.vec_env.dummy_vec_env import DummyVecEnv
|
||||
from baselines.common import retro_wrappers
|
||||
|
||||
def make_vec_env(env_id, env_type, num_env, seed, wrapper_kwargs=None, start_index=0, reward_scale=1.0, gamestate=None):
|
||||
def make_vec_env(env_id, env_type, num_env, seed,
|
||||
wrapper_kwargs=None,
|
||||
start_index=0,
|
||||
reward_scale=1.0,
|
||||
flatten_dict_observations=True,
|
||||
gamestate=None,
|
||||
initializer=None,
|
||||
env_kwargs=None,
|
||||
force_dummy=False):
|
||||
"""
|
||||
Create a wrapped, monitored SubprocVecEnv for Atari and MuJoCo.
|
||||
"""
|
||||
if wrapper_kwargs is None: wrapper_kwargs = {}
|
||||
wrapper_kwargs = wrapper_kwargs or {}
|
||||
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||
seed = seed + 10000 * mpi_rank if seed is not None else None
|
||||
def make_thunk(rank):
|
||||
logger_dir = logger.get_dir()
|
||||
def make_thunk(rank, initializer=None):
|
||||
return lambda: make_env(
|
||||
env_id=env_id,
|
||||
env_type=env_type,
|
||||
subrank = rank,
|
||||
mpi_rank=mpi_rank,
|
||||
subrank=rank,
|
||||
seed=seed,
|
||||
reward_scale=reward_scale,
|
||||
gamestate=gamestate,
|
||||
wrapper_kwargs=wrapper_kwargs
|
||||
flatten_dict_observations=flatten_dict_observations,
|
||||
wrapper_kwargs=wrapper_kwargs,
|
||||
logger_dir=logger_dir,
|
||||
initializer=initializer,
|
||||
env_kwargs=env_kwargs,
|
||||
)
|
||||
|
||||
set_global_seeds(seed)
|
||||
if num_env > 1:
|
||||
return SubprocVecEnv([make_thunk(i + start_index) for i in range(num_env)])
|
||||
if not force_dummy and num_env > 1:
|
||||
return SubprocVecEnv([make_thunk(i + start_index, initializer=initializer) for i in range(num_env)])
|
||||
else:
|
||||
return DummyVecEnv([make_thunk(start_index)])
|
||||
return DummyVecEnv([make_thunk(i + start_index, initializer=None) for i in range(num_env)])
|
||||
|
||||
|
||||
def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate=None, wrapper_kwargs={}):
|
||||
mpi_rank = MPI.COMM_WORLD.Get_rank() if MPI else 0
|
||||
def make_env(env_id, env_type, mpi_rank=0, subrank=0, seed=None, reward_scale=1.0, gamestate=None, flatten_dict_observations=True, wrapper_kwargs=None, logger_dir=None, initializer=None, env_kwargs=None):
|
||||
if initializer is not None:
|
||||
initializer(mpi_rank=mpi_rank, subrank=subrank)
|
||||
|
||||
wrapper_kwargs = wrapper_kwargs or {}
|
||||
if env_type == 'atari':
|
||||
env = make_atari(env_id)
|
||||
elif env_type == 'retro':
|
||||
@@ -52,11 +69,15 @@ def make_env(env_id, env_type, subrank=0, seed=None, reward_scale=1.0, gamestate
|
||||
gamestate = gamestate or retro.State.DEFAULT
|
||||
env = retro_wrappers.make_retro(game=env_id, max_episode_steps=10000, use_restricted_actions=retro.Actions.DISCRETE, state=gamestate)
|
||||
else:
|
||||
env = gym.make(env_id)
|
||||
env = gym.make(env_id, **(env_kwargs or {}))
|
||||
|
||||
if flatten_dict_observations and isinstance(env.observation_space, gym.spaces.Dict):
|
||||
keys = env.observation_space.spaces.keys()
|
||||
env = gym.wrappers.FlattenDictWrapper(env, dict_keys=list(keys))
|
||||
|
||||
env.seed(seed + subrank if seed is not None else None)
|
||||
env = Monitor(env,
|
||||
logger.get_dir() and os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)),
|
||||
logger_dir and os.path.join(logger_dir, str(mpi_rank) + '.' + str(subrank)),
|
||||
allow_early_resets=True)
|
||||
|
||||
if env_type == 'atari':
|
||||
@@ -123,6 +144,7 @@ def common_arg_parser():
|
||||
"""
|
||||
parser = arg_parser()
|
||||
parser.add_argument('--env', help='environment ID', type=str, default='Reacher-v2')
|
||||
parser.add_argument('--env_type', help='type of environment, used when the environment type cannot be automatically determined', type=str)
|
||||
parser.add_argument('--seed', help='RNG seed', type=int, default=None)
|
||||
parser.add_argument('--alg', help='Algorithm', type=str, default='ppo2')
|
||||
parser.add_argument('--num_timesteps', type=float, default=1e6),
|
||||
@@ -134,6 +156,7 @@ def common_arg_parser():
|
||||
parser.add_argument('--save_video_interval', help='Save video every x steps (0 = disabled)', default=0, type=int)
|
||||
parser.add_argument('--save_video_length', help='Length of recorded video. Default: 200', default=200, type=int)
|
||||
parser.add_argument('--play', default=False, action='store_true')
|
||||
parser.add_argument('--extra_import', help='Extra module to import to access external environments', type=str, default=None)
|
||||
return parser
|
||||
|
||||
def robotics_arg_parser():
|
||||
|
@@ -75,7 +75,8 @@ class CategoricalPdType(PdType):
|
||||
|
||||
class MultiCategoricalPdType(PdType):
|
||||
def __init__(self, nvec):
|
||||
self.ncats = nvec
|
||||
self.ncats = nvec.astype('int32')
|
||||
assert (self.ncats > 0).all()
|
||||
def pdclass(self):
|
||||
return MultiCategoricalPd
|
||||
def pdfromflat(self, flat):
|
||||
|
@@ -168,6 +168,7 @@ def load_results(root_dir_or_dirs, enable_progress=True, enable_monitor=True, ve
|
||||
- monitor - if enable_monitor is True, this field contains pandas dataframe with loaded monitor.csv file (or aggregate of all *.monitor.csv files in the directory)
|
||||
- progress - if enable_progress is True, this field contains pandas dataframe with loaded progress.csv file
|
||||
'''
|
||||
import re
|
||||
if isinstance(root_dir_or_dirs, str):
|
||||
rootdirs = [osp.expanduser(root_dir_or_dirs)]
|
||||
else:
|
||||
@@ -179,7 +180,9 @@ def load_results(root_dir_or_dirs, enable_progress=True, enable_monitor=True, ve
|
||||
if '-proc' in dirname:
|
||||
files[:] = []
|
||||
continue
|
||||
if set(['metadata.json', 'monitor.json', 'monitor.csv', 'progress.json', 'progress.csv']).intersection(files):
|
||||
monitor_re = re.compile(r'(\d+\.)?(\d+\.)?monitor\.csv')
|
||||
if set(['metadata.json', 'monitor.json', 'progress.json', 'progress.csv']).intersection(files) or \
|
||||
any([f for f in files if monitor_re.match(f)]): # also match monitor files like 0.1.monitor.csv
|
||||
# used to be uncommented, which means do not go deeper than current directory if any of the data files
|
||||
# are found
|
||||
# dirs[:] = []
|
||||
|
39
baselines/common/tests/test_fetchreach.py
Normal file
39
baselines/common/tests/test_fetchreach.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import pytest
|
||||
import gym
|
||||
|
||||
from baselines.run import get_learn_function
|
||||
from baselines.common.tests.util import reward_per_episode_test
|
||||
|
||||
pytest.importorskip('mujoco_py')
|
||||
|
||||
common_kwargs = dict(
|
||||
network='mlp',
|
||||
seed=0,
|
||||
)
|
||||
|
||||
learn_kwargs = {
|
||||
'her': dict(total_timesteps=2000)
|
||||
}
|
||||
|
||||
@pytest.mark.slow
|
||||
@pytest.mark.parametrize("alg", learn_kwargs.keys())
|
||||
def test_fetchreach(alg):
|
||||
'''
|
||||
Test if the algorithm (with an mlp policy)
|
||||
can learn the FetchReach task
|
||||
'''
|
||||
|
||||
kwargs = common_kwargs.copy()
|
||||
kwargs.update(learn_kwargs[alg])
|
||||
|
||||
learn_fn = lambda e: get_learn_function(alg)(env=e, **kwargs)
|
||||
def env_fn():
|
||||
|
||||
env = gym.make('FetchReach-v1')
|
||||
env.seed(0)
|
||||
return env
|
||||
|
||||
reward_per_episode_test(env_fn, learn_fn, -15)
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_fetchreach('her')
|
@@ -18,7 +18,9 @@ def test_function():
|
||||
initialize()
|
||||
|
||||
assert lin(2) == 6
|
||||
assert lin(x=3) == 9
|
||||
assert lin(2, 2) == 10
|
||||
assert lin(x=2, y=3) == 12
|
||||
|
||||
|
||||
def test_multikwargs():
|
||||
|
@@ -63,7 +63,7 @@ def rollout(env, model, n_trials):
|
||||
|
||||
for i in range(n_trials):
|
||||
obs = env.reset()
|
||||
state = model.initial_state
|
||||
state = model.initial_state if hasattr(model, 'initial_state') else None
|
||||
episode_rew = []
|
||||
episode_actions = []
|
||||
episode_obs = []
|
||||
|
@@ -186,6 +186,7 @@ class _Function(object):
|
||||
if not hasattr(inpt, 'make_feed_dict') and not (type(inpt) is tf.Tensor and len(inpt.op.inputs) == 0):
|
||||
assert False, "inputs should all be placeholders, constants, or have a make_feed_dict method"
|
||||
self.inputs = inputs
|
||||
self.input_names = {inp.name.split("/")[-1].split(":")[0]: inp for inp in inputs}
|
||||
updates = updates or []
|
||||
self.update_group = tf.group(*updates)
|
||||
self.outputs_update = list(outputs) + [self.update_group]
|
||||
@@ -197,15 +198,17 @@ class _Function(object):
|
||||
else:
|
||||
feed_dict[inpt] = adjust_shape(inpt, value)
|
||||
|
||||
def __call__(self, *args):
|
||||
assert len(args) <= len(self.inputs), "Too many arguments provided"
|
||||
def __call__(self, *args, **kwargs):
|
||||
assert len(args) + len(kwargs) <= len(self.inputs), "Too many arguments provided"
|
||||
feed_dict = {}
|
||||
# Update the args
|
||||
for inpt, value in zip(self.inputs, args):
|
||||
self._feed_input(feed_dict, inpt, value)
|
||||
# Update feed dict with givens.
|
||||
for inpt in self.givens:
|
||||
feed_dict[inpt] = adjust_shape(inpt, feed_dict.get(inpt, self.givens[inpt]))
|
||||
# Update the args
|
||||
for inpt, value in zip(self.inputs, args):
|
||||
self._feed_input(feed_dict, inpt, value)
|
||||
for inpt_name, value in kwargs.items():
|
||||
self._feed_input(feed_dict, self.input_names[inpt_name], value)
|
||||
results = get_session().run(self.outputs_update, feed_dict=feed_dict)[:-1]
|
||||
return results
|
||||
|
||||
@@ -337,7 +340,7 @@ def save_state(fname, sess=None):
|
||||
|
||||
def save_variables(save_path, variables=None, sess=None):
|
||||
sess = sess or get_session()
|
||||
variables = variables or tf.trainable_variables()
|
||||
variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
|
||||
|
||||
ps = sess.run(variables)
|
||||
save_dict = {v.name: value for v, value in zip(variables, ps)}
|
||||
@@ -348,7 +351,7 @@ def save_variables(save_path, variables=None, sess=None):
|
||||
|
||||
def load_variables(load_path, variables=None, sess=None):
|
||||
sess = sess or get_session()
|
||||
variables = variables or tf.trainable_variables()
|
||||
variables = variables or tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
|
||||
|
||||
loaded_params = joblib.load(os.path.expanduser(load_path))
|
||||
restores = []
|
||||
|
@@ -27,6 +27,7 @@ class DummyVecEnv(VecEnv):
|
||||
self.buf_rews = np.zeros((self.num_envs,), dtype=np.float32)
|
||||
self.buf_infos = [{} for _ in range(self.num_envs)]
|
||||
self.actions = None
|
||||
self.specs = [e.spec for e in self.envs]
|
||||
|
||||
def step_async(self, actions):
|
||||
listify = True
|
||||
|
@@ -54,6 +54,7 @@ class ShmemVecEnv(VecEnv):
|
||||
proc.start()
|
||||
child_pipe.close()
|
||||
self.waiting_step = False
|
||||
self.specs = [f().spec for f in env_fns]
|
||||
self.viewer = None
|
||||
|
||||
def reset(self):
|
||||
|
@@ -57,6 +57,7 @@ class SubprocVecEnv(VecEnv):
|
||||
self.remotes[0].send(('get_spaces', None))
|
||||
observation_space, action_space = self.remotes[0].recv()
|
||||
self.viewer = None
|
||||
self.specs = [f().spec for f in env_fns]
|
||||
VecEnv.__init__(self, len(env_fns), observation_space, action_space)
|
||||
|
||||
def step_async(self, actions):
|
||||
@@ -70,13 +71,13 @@ class SubprocVecEnv(VecEnv):
|
||||
results = [remote.recv() for remote in self.remotes]
|
||||
self.waiting = False
|
||||
obs, rews, dones, infos = zip(*results)
|
||||
return np.stack(obs), np.stack(rews), np.stack(dones), infos
|
||||
return _flatten_obs(obs), np.stack(rews), np.stack(dones), infos
|
||||
|
||||
def reset(self):
|
||||
self._assert_not_closed()
|
||||
for remote in self.remotes:
|
||||
remote.send(('reset', None))
|
||||
return np.stack([remote.recv() for remote in self.remotes])
|
||||
return _flatten_obs([remote.recv() for remote in self.remotes])
|
||||
|
||||
def close_extras(self):
|
||||
self.closed = True
|
||||
@@ -97,3 +98,17 @@ class SubprocVecEnv(VecEnv):
|
||||
|
||||
def _assert_not_closed(self):
|
||||
assert not self.closed, "Trying to operate on a SubprocVecEnv after calling close()"
|
||||
|
||||
|
||||
def _flatten_obs(obs):
|
||||
assert isinstance(obs, list) or isinstance(obs, tuple)
|
||||
assert len(obs) > 0
|
||||
|
||||
if isinstance(obs[0], dict):
|
||||
import collections
|
||||
assert isinstance(obs, collections.OrderedDict)
|
||||
keys = obs[0].keys()
|
||||
return {k: np.stack([o[k] for o in obs]) for k in keys}
|
||||
else:
|
||||
return np.stack(obs)
|
||||
|
||||
|
@@ -24,7 +24,7 @@ Hopper-v1, Walker2d-v1, HalfCheetah-v1, Humanoid-v1, HumanoidStandup-v1. Every i
|
||||
|
||||
For details (e.g., adversarial loss, discriminator accuracy, etc.) about GAIL training, please see [here](https://drive.google.com/drive/folders/1nnU8dqAV9i37-_5_vWIspyFUJFQLCsDD?usp=sharing)
|
||||
|
||||
### Determinstic Polciy (Set std=0)
|
||||
### Determinstic Policy (Set std=0)
|
||||
| | Un-normalized | Normalized |
|
||||
|---|---|---|
|
||||
| Hopper-v1 | <img src='Hopper-unnormalized-deterministic-scores.png'> | <img src='Hopper-normalized-deterministic-scores.png'> |
|
||||
|
@@ -6,26 +6,29 @@ For details on Hindsight Experience Replay (HER), please read the [paper](https:
|
||||
### Getting started
|
||||
Training an agent is very simple:
|
||||
```bash
|
||||
python -m baselines.her.experiment.train
|
||||
python -m baselines.run --alg=her --env=FetchReach-v1 --num_timesteps=5000
|
||||
```
|
||||
This will train a DDPG+HER agent on the `FetchReach` environment.
|
||||
You should see the success rate go up quickly to `1.0`, which means that the agent achieves the
|
||||
desired goal in 100% of the cases.
|
||||
The training script logs other diagnostics as well and pickles the best policy so far (w.r.t. to its test success rate),
|
||||
the latest policy, and, if enabled, a history of policies every K epochs.
|
||||
|
||||
To inspect what the agent has learned, use the play script:
|
||||
desired goal in 100% of the cases (note how HER can solve it in <5k steps - try doing that with PPO by replacing her with ppo2 :))
|
||||
The training script logs other diagnostics as well. Policy at the end of the training can be saved using `--save_path` flag, for instance:
|
||||
```bash
|
||||
python -m baselines.her.experiment.play /path/to/an/experiment/policy_best.pkl
|
||||
python -m baselines.run --alg=her --env=FetchReach-v1 --num_timesteps=5000 --save_path=~/policies/her/fetchreach5k
|
||||
```
|
||||
You can try it right now with the results of the training step (the script prints out the path for you).
|
||||
This should visualize the current policy for 10 episodes and will also print statistics.
|
||||
|
||||
To inspect what the agent has learned, use the `--play` flag:
|
||||
```bash
|
||||
python -m baselines.run --alg=her --env=FetchReach-v1 --num_timesteps=5000 --play
|
||||
```
|
||||
(note `--play` can be combined with `--load_path`, which lets one load trained policies, for more results see [README.md](../../README.md))
|
||||
|
||||
|
||||
### Reproducing results
|
||||
In order to reproduce the results from [Plappert et al. (2018)](https://arxiv.org/abs/1802.09464), run the following command:
|
||||
In [Plappert et al. (2018)](https://arxiv.org/abs/1802.09464), 38 trajectories were generated in parallel
|
||||
(19 MPI processes, each generating computing gradients from 2 trajectories and aggregating).
|
||||
To reproduce that behaviour, use
|
||||
```bash
|
||||
python -m baselines.her.experiment.train --num_cpu 19
|
||||
mpirun -np 19 python -m baselines.run --num_env=2 --alg=her ...
|
||||
```
|
||||
This will require a machine with sufficient amount of physical CPU cores. In our experiments,
|
||||
we used [Azure's D15v2 instances](https://docs.microsoft.com/en-us/azure/virtual-machines/linux/sizes),
|
||||
@@ -45,6 +48,13 @@ python experiment/data_generation/fetch_data_generation.py
|
||||
```
|
||||
This outputs ```data_fetch_random_100.npz``` file which is our data file.
|
||||
|
||||
To launch training with demonstrations (more technically, with behaviour cloning loss as an auxilliary loss), run the following
|
||||
```bash
|
||||
python -m baselines.run --alg=her --env=FetchPickAndPlace-v1 --num_timesteps=2.5e6 --demo_file=/Path/to/demo_file.npz
|
||||
```
|
||||
This will train a DDPG+HER agent on the `FetchPickAndPlace` environment by using previously generated demonstration data.
|
||||
To inspect what the agent has learned, use the `--play` flag as described above.
|
||||
|
||||
#### Configuration
|
||||
The provided configuration is for training an agent with HER without demonstrations, we need to change a few paramters for the HER algorithm to learn through demonstrations, to do that, set:
|
||||
|
||||
@@ -62,13 +72,7 @@ Apart from these changes the reported results also have the following configurat
|
||||
* random_eps: 0.1 - percentage of time a random action is taken
|
||||
* noise_eps: 0.1 - std of gaussian noise added to not-completely-random actions
|
||||
|
||||
Now training an agent with pre-recorded demonstrations:
|
||||
```bash
|
||||
python -m baselines.her.experiment.train --env=FetchPickAndPlace-v0 --n_epochs=1000 --demo_file=/Path/to/demo_file.npz --num_cpu=1
|
||||
```
|
||||
|
||||
This will train a DDPG+HER agent on the `FetchPickAndPlace` environment by using previously generated demonstration data.
|
||||
To inspect what the agent has learned, use the play script as described above.
|
||||
These parameters can be changed either in [experiment/config.py](experiment/config.py) or passed to the command line as `--param=value`)
|
||||
|
||||
### Results
|
||||
Training with demonstrations helps overcome the exploration problem and achieves a faster and better convergence. The following graphs contrast the difference between training with and without demonstration data, We report the mean Q values vs Epoch and the Success Rate vs Epoch:
|
||||
@@ -78,3 +82,4 @@ Training with demonstrations helps overcome the exploration problem and achieves
|
||||
<center><img src="../../data/fetchPickAndPlaceContrast.png"></center>
|
||||
<div class="thecap" align="middle"><b>Training results for Fetch Pick and Place task constrasting between training with and without demonstration data.</b></div>
|
||||
</div>
|
||||
|
||||
|
@@ -10,13 +10,14 @@ from baselines.her.util import (
|
||||
from baselines.her.normalizer import Normalizer
|
||||
from baselines.her.replay_buffer import ReplayBuffer
|
||||
from baselines.common.mpi_adam import MpiAdam
|
||||
from baselines.common import tf_util
|
||||
|
||||
|
||||
def dims_to_shapes(input_dims):
|
||||
return {key: tuple([val]) if val > 0 else tuple() for key, val in input_dims.items()}
|
||||
|
||||
|
||||
global demoBuffer #buffer for demonstrations
|
||||
global DEMO_BUFFER #buffer for demonstrations
|
||||
|
||||
class DDPG(object):
|
||||
@store_args
|
||||
@@ -94,16 +95,16 @@ class DDPG(object):
|
||||
self._create_network(reuse=reuse)
|
||||
|
||||
# Configure the replay buffer.
|
||||
buffer_shapes = {key: (self.T if key != 'o' else self.T+1, *input_shapes[key])
|
||||
buffer_shapes = {key: (self.T-1 if key != 'o' else self.T, *input_shapes[key])
|
||||
for key, val in input_shapes.items()}
|
||||
buffer_shapes['g'] = (buffer_shapes['g'][0], self.dimg)
|
||||
buffer_shapes['ag'] = (self.T+1, self.dimg)
|
||||
buffer_shapes['ag'] = (self.T, self.dimg)
|
||||
|
||||
buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size
|
||||
self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)
|
||||
|
||||
global demoBuffer
|
||||
demoBuffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions) #initialize the demo buffer; in the same way as the primary data buffer
|
||||
global DEMO_BUFFER
|
||||
DEMO_BUFFER = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions) #initialize the demo buffer; in the same way as the primary data buffer
|
||||
|
||||
def _random_action(self, n):
|
||||
return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu))
|
||||
@@ -119,6 +120,11 @@ class DDPG(object):
|
||||
g = np.clip(g, -self.clip_obs, self.clip_obs)
|
||||
return o, g
|
||||
|
||||
def step(self, obs):
|
||||
actions = self.get_actions(obs['observation'], obs['achieved_goal'], obs['desired_goal'])
|
||||
return actions, None, None, None
|
||||
|
||||
|
||||
def get_actions(self, o, ag, g, noise_eps=0., random_eps=0., use_target_net=False,
|
||||
compute_Q=False):
|
||||
o, g = self._preprocess_og(o, ag, g)
|
||||
@@ -151,25 +157,30 @@ class DDPG(object):
|
||||
else:
|
||||
return ret
|
||||
|
||||
def initDemoBuffer(self, demoDataFile, update_stats=True): #function that initializes the demo buffer
|
||||
def init_demo_buffer(self, demoDataFile, update_stats=True): #function that initializes the demo buffer
|
||||
|
||||
demoData = np.load(demoDataFile) #load the demonstration data from data file
|
||||
info_keys = [key.replace('info_', '') for key in self.input_dims.keys() if key.startswith('info_')]
|
||||
info_values = [np.empty((self.T, 1, self.input_dims['info_' + key]), np.float32) for key in info_keys]
|
||||
info_values = [np.empty((self.T - 1, 1, self.input_dims['info_' + key]), np.float32) for key in info_keys]
|
||||
|
||||
demo_data_obs = demoData['obs']
|
||||
demo_data_acs = demoData['acs']
|
||||
demo_data_info = demoData['info']
|
||||
|
||||
for epsd in range(self.num_demo): # we initialize the whole demo buffer at the start of the training
|
||||
obs, acts, goals, achieved_goals = [], [] ,[] ,[]
|
||||
i = 0
|
||||
for transition in range(self.T):
|
||||
obs.append([demoData['obs'][epsd ][transition].get('observation')])
|
||||
acts.append([demoData['acs'][epsd][transition]])
|
||||
goals.append([demoData['obs'][epsd][transition].get('desired_goal')])
|
||||
achieved_goals.append([demoData['obs'][epsd][transition].get('achieved_goal')])
|
||||
for transition in range(self.T - 1):
|
||||
obs.append([demo_data_obs[epsd][transition].get('observation')])
|
||||
acts.append([demo_data_acs[epsd][transition]])
|
||||
goals.append([demo_data_obs[epsd][transition].get('desired_goal')])
|
||||
achieved_goals.append([demo_data_obs[epsd][transition].get('achieved_goal')])
|
||||
for idx, key in enumerate(info_keys):
|
||||
info_values[idx][transition, i] = demoData['info'][epsd][transition][key]
|
||||
info_values[idx][transition, i] = demo_data_info[epsd][transition][key]
|
||||
|
||||
obs.append([demoData['obs'][epsd][self.T].get('observation')])
|
||||
achieved_goals.append([demoData['obs'][epsd][self.T].get('achieved_goal')])
|
||||
|
||||
obs.append([demo_data_obs[epsd][self.T - 1].get('observation')])
|
||||
achieved_goals.append([demo_data_obs[epsd][self.T - 1].get('achieved_goal')])
|
||||
|
||||
episode = dict(o=obs,
|
||||
u=acts,
|
||||
@@ -179,10 +190,9 @@ class DDPG(object):
|
||||
episode['info_{}'.format(key)] = value
|
||||
|
||||
episode = convert_episode_to_batch_major(episode)
|
||||
global demoBuffer
|
||||
demoBuffer.store_episode(episode) # create the observation dict and append them into the demonstration buffer
|
||||
|
||||
print("Demo buffer size currently ", demoBuffer.get_current_size()) #print out the demonstration buffer size
|
||||
global DEMO_BUFFER
|
||||
DEMO_BUFFER.store_episode(episode) # create the observation dict and append them into the demonstration buffer
|
||||
logger.debug("Demo buffer size currently ", DEMO_BUFFER.get_current_size()) #print out the demonstration buffer size
|
||||
|
||||
if update_stats:
|
||||
# add transitions to normalizer to normalize the demo data as well
|
||||
@@ -191,7 +201,7 @@ class DDPG(object):
|
||||
num_normalizing_transitions = transitions_in_episode_batch(episode)
|
||||
transitions = self.sample_transitions(episode, num_normalizing_transitions)
|
||||
|
||||
o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag']
|
||||
o, g, ag = transitions['o'], transitions['g'], transitions['ag']
|
||||
transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
|
||||
# No need to preprocess the o_2 and g_2 since this is only used for stats
|
||||
|
||||
@@ -202,6 +212,8 @@ class DDPG(object):
|
||||
self.g_stats.recompute_stats()
|
||||
episode.clear()
|
||||
|
||||
logger.info("Demo buffer size: ", DEMO_BUFFER.get_current_size()) #print out the demonstration buffer size
|
||||
|
||||
def store_episode(self, episode_batch, update_stats=True):
|
||||
"""
|
||||
episode_batch: array of batch_size x (T or T+1) x dim_key
|
||||
@@ -217,7 +229,7 @@ class DDPG(object):
|
||||
num_normalizing_transitions = transitions_in_episode_batch(episode_batch)
|
||||
transitions = self.sample_transitions(episode_batch, num_normalizing_transitions)
|
||||
|
||||
o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag']
|
||||
o, g, ag = transitions['o'], transitions['g'], transitions['ag']
|
||||
transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
|
||||
# No need to preprocess the o_2 and g_2 since this is only used for stats
|
||||
|
||||
@@ -251,9 +263,9 @@ class DDPG(object):
|
||||
def sample_batch(self):
|
||||
if self.bc_loss: #use demonstration buffer to sample as well if bc_loss flag is set TRUE
|
||||
transitions = self.buffer.sample(self.batch_size - self.demo_batch_size)
|
||||
global demoBuffer
|
||||
transitionsDemo = demoBuffer.sample(self.demo_batch_size) #sample from the demo buffer
|
||||
for k, values in transitionsDemo.items():
|
||||
global DEMO_BUFFER
|
||||
transitions_demo = DEMO_BUFFER.sample(self.demo_batch_size) #sample from the demo buffer
|
||||
for k, values in transitions_demo.items():
|
||||
rolloutV = transitions[k].tolist()
|
||||
for v in values:
|
||||
rolloutV.append(v.tolist())
|
||||
@@ -302,10 +314,7 @@ class DDPG(object):
|
||||
|
||||
def _create_network(self, reuse=False):
|
||||
logger.info("Creating a DDPG agent with action space %d x %s..." % (self.dimu, self.max_u))
|
||||
|
||||
self.sess = tf.get_default_session()
|
||||
if self.sess is None:
|
||||
self.sess = tf.InteractiveSession()
|
||||
self.sess = tf_util.get_session()
|
||||
|
||||
# running averages
|
||||
with tf.variable_scope('o_stats') as vs:
|
||||
@@ -401,7 +410,7 @@ class DDPG(object):
|
||||
logs += [('stats_g/mean', np.mean(self.sess.run([self.g_stats.mean])))]
|
||||
logs += [('stats_g/std', np.mean(self.sess.run([self.g_stats.std])))]
|
||||
|
||||
if prefix is not '' and not prefix.endswith('/'):
|
||||
if prefix != '' and not prefix.endswith('/'):
|
||||
return [(prefix + '/' + key, val) for key, val in logs]
|
||||
else:
|
||||
return logs
|
||||
@@ -433,3 +442,7 @@ class DDPG(object):
|
||||
assert(len(vars) == len(state["tf"]))
|
||||
node = [tf.assign(var, val) for var, val in zip(vars, state["tf"])]
|
||||
self.sess.run(node)
|
||||
|
||||
def save(self, save_path):
|
||||
tf_util.save_variables(save_path)
|
||||
|
||||
|
@@ -1,10 +1,11 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import gym
|
||||
|
||||
from baselines import logger
|
||||
from baselines.her.ddpg import DDPG
|
||||
from baselines.her.her import make_sample_her_transitions
|
||||
|
||||
from baselines.her.her_sampler import make_sample_her_transitions
|
||||
from baselines.bench.monitor import Monitor
|
||||
|
||||
DEFAULT_ENV_PARAMS = {
|
||||
'FetchReach-v1': {
|
||||
@@ -72,16 +73,32 @@ def cached_make_env(make_env):
|
||||
def prepare_params(kwargs):
|
||||
# DDPG params
|
||||
ddpg_params = dict()
|
||||
|
||||
env_name = kwargs['env_name']
|
||||
|
||||
def make_env():
|
||||
return gym.make(env_name)
|
||||
def make_env(subrank=None):
|
||||
env = gym.make(env_name)
|
||||
if subrank is not None and logger.get_dir() is not None:
|
||||
try:
|
||||
from mpi4py import MPI
|
||||
mpi_rank = MPI.COMM_WORLD.Get_rank()
|
||||
except ImportError:
|
||||
MPI = None
|
||||
mpi_rank = 0
|
||||
logger.warn('Running with a single MPI process. This should work, but the results may differ from the ones publshed in Plappert et al.')
|
||||
|
||||
max_episode_steps = env._max_episode_steps
|
||||
env = Monitor(env,
|
||||
os.path.join(logger.get_dir(), str(mpi_rank) + '.' + str(subrank)),
|
||||
allow_early_resets=True)
|
||||
# hack to re-expose _max_episode_steps (ideally should replace reliance on it downstream)
|
||||
env = gym.wrappers.TimeLimit(env, max_episode_steps=max_episode_steps)
|
||||
return env
|
||||
|
||||
kwargs['make_env'] = make_env
|
||||
tmp_env = cached_make_env(kwargs['make_env'])
|
||||
assert hasattr(tmp_env, '_max_episode_steps')
|
||||
kwargs['T'] = tmp_env._max_episode_steps
|
||||
tmp_env.reset()
|
||||
|
||||
kwargs['max_u'] = np.array(kwargs['max_u']) if isinstance(kwargs['max_u'], list) else kwargs['max_u']
|
||||
kwargs['gamma'] = 1. - 1. / kwargs['T']
|
||||
if 'lr' in kwargs:
|
||||
|
@@ -1,18 +1,5 @@
|
||||
import gym
|
||||
import time
|
||||
import random
|
||||
import numpy as np
|
||||
import rospy
|
||||
import roslaunch
|
||||
|
||||
from random import randint
|
||||
from std_srvs.srv import Empty
|
||||
from sensor_msgs.msg import JointState
|
||||
from geometry_msgs.msg import PoseStamped
|
||||
from geometry_msgs.msg import Pose
|
||||
from std_msgs.msg import Float64
|
||||
from controller_manager_msgs.srv import SwitchController
|
||||
from gym.utils import seeding
|
||||
|
||||
|
||||
"""Data generation for the case of a single block pick and place in Fetch Env"""
|
||||
@@ -22,7 +9,7 @@ observations = []
|
||||
infos = []
|
||||
|
||||
def main():
|
||||
env = gym.make('FetchPickAndPlace-v0')
|
||||
env = gym.make('FetchPickAndPlace-v1')
|
||||
numItr = 100
|
||||
initStateSpace = "random"
|
||||
env.reset()
|
||||
@@ -31,21 +18,19 @@ def main():
|
||||
obs = env.reset()
|
||||
print("ITERATION NUMBER ", len(actions))
|
||||
goToGoal(env, obs)
|
||||
|
||||
|
||||
|
||||
fileName = "data_fetch"
|
||||
fileName += "_" + initStateSpace
|
||||
fileName += "_" + str(numItr)
|
||||
fileName += ".npz"
|
||||
|
||||
|
||||
np.savez_compressed(fileName, acs=actions, obs=observations, info=infos) # save the file
|
||||
|
||||
def goToGoal(env, lastObs):
|
||||
|
||||
goal = lastObs['desired_goal']
|
||||
objectPos = lastObs['observation'][3:6]
|
||||
gripperPos = lastObs['observation'][:3]
|
||||
gripperState = lastObs['observation'][9:11]
|
||||
object_rel_pos = lastObs['observation'][6:9]
|
||||
episodeAcs = []
|
||||
episodeObs = []
|
||||
@@ -53,7 +38,7 @@ def goToGoal(env, lastObs):
|
||||
|
||||
object_oriented_goal = object_rel_pos.copy()
|
||||
object_oriented_goal[2] += 0.03 # first make the gripper go slightly above the object
|
||||
|
||||
|
||||
timeStep = 0 #count the total number of timesteps
|
||||
episodeObs.append(lastObs)
|
||||
|
||||
@@ -76,8 +61,6 @@ def goToGoal(env, lastObs):
|
||||
episodeObs.append(obsDataNew)
|
||||
|
||||
objectPos = obsDataNew['observation'][3:6]
|
||||
gripperPos = obsDataNew['observation'][:3]
|
||||
gripperState = obsDataNew['observation'][9:11]
|
||||
object_rel_pos = obsDataNew['observation'][6:9]
|
||||
|
||||
while np.linalg.norm(object_rel_pos) >= 0.005 and timeStep <= env._max_episode_steps :
|
||||
@@ -96,8 +79,6 @@ def goToGoal(env, lastObs):
|
||||
episodeObs.append(obsDataNew)
|
||||
|
||||
objectPos = obsDataNew['observation'][3:6]
|
||||
gripperPos = obsDataNew['observation'][:3]
|
||||
gripperState = obsDataNew['observation'][9:11]
|
||||
object_rel_pos = obsDataNew['observation'][6:9]
|
||||
|
||||
|
||||
@@ -117,8 +98,6 @@ def goToGoal(env, lastObs):
|
||||
episodeObs.append(obsDataNew)
|
||||
|
||||
objectPos = obsDataNew['observation'][3:6]
|
||||
gripperPos = obsDataNew['observation'][:3]
|
||||
gripperState = obsDataNew['observation'][9:11]
|
||||
object_rel_pos = obsDataNew['observation'][6:9]
|
||||
|
||||
while True: #limit the number of timesteps in the episode to a fixed duration
|
||||
@@ -134,8 +113,6 @@ def goToGoal(env, lastObs):
|
||||
episodeObs.append(obsDataNew)
|
||||
|
||||
objectPos = obsDataNew['observation'][3:6]
|
||||
gripperPos = obsDataNew['observation'][:3]
|
||||
gripperState = obsDataNew['observation'][9:11]
|
||||
object_rel_pos = obsDataNew['observation'][6:9]
|
||||
|
||||
if timeStep >= env._max_episode_steps: break
|
||||
|
@@ -1,3 +1,4 @@
|
||||
# DEPRECATED, use --play flag to baselines.run instead
|
||||
import click
|
||||
import numpy as np
|
||||
import pickle
|
||||
|
@@ -1,3 +1,5 @@
|
||||
# DEPRECATED, use baselines.common.plot_util instead
|
||||
|
||||
import os
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
|
@@ -1,194 +0,0 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
import click
|
||||
import numpy as np
|
||||
import json
|
||||
from mpi4py import MPI
|
||||
|
||||
from baselines import logger
|
||||
from baselines.common import set_global_seeds
|
||||
from baselines.common.mpi_moments import mpi_moments
|
||||
import baselines.her.experiment.config as config
|
||||
from baselines.her.rollout import RolloutWorker
|
||||
from baselines.her.util import mpi_fork
|
||||
|
||||
from subprocess import CalledProcessError
|
||||
|
||||
|
||||
def mpi_average(value):
|
||||
if value == []:
|
||||
value = [0.]
|
||||
if not isinstance(value, list):
|
||||
value = [value]
|
||||
return mpi_moments(np.array(value))[0]
|
||||
|
||||
|
||||
def train(policy, rollout_worker, evaluator,
|
||||
n_epochs, n_test_rollouts, n_cycles, n_batches, policy_save_interval,
|
||||
save_policies, demo_file, **kwargs):
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
|
||||
latest_policy_path = os.path.join(logger.get_dir(), 'policy_latest.pkl')
|
||||
best_policy_path = os.path.join(logger.get_dir(), 'policy_best.pkl')
|
||||
periodic_policy_path = os.path.join(logger.get_dir(), 'policy_{}.pkl')
|
||||
|
||||
logger.info("Training...")
|
||||
best_success_rate = -1
|
||||
|
||||
if policy.bc_loss == 1: policy.initDemoBuffer(demo_file) #initialize demo buffer if training with demonstrations
|
||||
for epoch in range(n_epochs):
|
||||
# train
|
||||
rollout_worker.clear_history()
|
||||
for _ in range(n_cycles):
|
||||
episode = rollout_worker.generate_rollouts()
|
||||
policy.store_episode(episode)
|
||||
for _ in range(n_batches):
|
||||
policy.train()
|
||||
policy.update_target_net()
|
||||
|
||||
# test
|
||||
evaluator.clear_history()
|
||||
for _ in range(n_test_rollouts):
|
||||
evaluator.generate_rollouts()
|
||||
|
||||
# record logs
|
||||
logger.record_tabular('epoch', epoch)
|
||||
for key, val in evaluator.logs('test'):
|
||||
logger.record_tabular(key, mpi_average(val))
|
||||
for key, val in rollout_worker.logs('train'):
|
||||
logger.record_tabular(key, mpi_average(val))
|
||||
for key, val in policy.logs():
|
||||
logger.record_tabular(key, mpi_average(val))
|
||||
|
||||
if rank == 0:
|
||||
logger.dump_tabular()
|
||||
|
||||
# save the policy if it's better than the previous ones
|
||||
success_rate = mpi_average(evaluator.current_success_rate())
|
||||
if rank == 0 and success_rate >= best_success_rate and save_policies:
|
||||
best_success_rate = success_rate
|
||||
logger.info('New best success rate: {}. Saving policy to {} ...'.format(best_success_rate, best_policy_path))
|
||||
evaluator.save_policy(best_policy_path)
|
||||
evaluator.save_policy(latest_policy_path)
|
||||
if rank == 0 and policy_save_interval > 0 and epoch % policy_save_interval == 0 and save_policies:
|
||||
policy_path = periodic_policy_path.format(epoch)
|
||||
logger.info('Saving periodic policy to {} ...'.format(policy_path))
|
||||
evaluator.save_policy(policy_path)
|
||||
|
||||
# make sure that different threads have different seeds
|
||||
local_uniform = np.random.uniform(size=(1,))
|
||||
root_uniform = local_uniform.copy()
|
||||
MPI.COMM_WORLD.Bcast(root_uniform, root=0)
|
||||
if rank != 0:
|
||||
assert local_uniform[0] != root_uniform[0]
|
||||
|
||||
|
||||
def launch(
|
||||
env, logdir, n_epochs, num_cpu, seed, replay_strategy, policy_save_interval, clip_return,
|
||||
demo_file, override_params={}, save_policies=True
|
||||
):
|
||||
# Fork for multi-CPU MPI implementation.
|
||||
if num_cpu > 1:
|
||||
try:
|
||||
whoami = mpi_fork(num_cpu, ['--bind-to', 'core'])
|
||||
except CalledProcessError:
|
||||
# fancy version of mpi call failed, try simple version
|
||||
whoami = mpi_fork(num_cpu)
|
||||
|
||||
if whoami == 'parent':
|
||||
sys.exit(0)
|
||||
import baselines.common.tf_util as U
|
||||
U.single_threaded_session().__enter__()
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
|
||||
# Configure logging
|
||||
if rank == 0:
|
||||
if logdir or logger.get_dir() is None:
|
||||
logger.configure(dir=logdir)
|
||||
else:
|
||||
logger.configure()
|
||||
logdir = logger.get_dir()
|
||||
assert logdir is not None
|
||||
os.makedirs(logdir, exist_ok=True)
|
||||
|
||||
# Seed everything.
|
||||
rank_seed = seed + 1000000 * rank
|
||||
set_global_seeds(rank_seed)
|
||||
|
||||
# Prepare params.
|
||||
params = config.DEFAULT_PARAMS
|
||||
params['env_name'] = env
|
||||
params['replay_strategy'] = replay_strategy
|
||||
if env in config.DEFAULT_ENV_PARAMS:
|
||||
params.update(config.DEFAULT_ENV_PARAMS[env]) # merge env-specific parameters in
|
||||
params.update(**override_params) # makes it possible to override any parameter
|
||||
with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as f:
|
||||
json.dump(params, f)
|
||||
params = config.prepare_params(params)
|
||||
config.log_params(params, logger=logger)
|
||||
|
||||
if num_cpu == 1:
|
||||
logger.warn()
|
||||
logger.warn('*** Warning ***')
|
||||
logger.warn(
|
||||
'You are running HER with just a single MPI worker. This will work, but the ' +
|
||||
'experiments that we report in Plappert et al. (2018, https://arxiv.org/abs/1802.09464) ' +
|
||||
'were obtained with --num_cpu 19. This makes a significant difference and if you ' +
|
||||
'are looking to reproduce those results, be aware of this. Please also refer to ' +
|
||||
'https://github.com/openai/baselines/issues/314 for further details.')
|
||||
logger.warn('****************')
|
||||
logger.warn()
|
||||
|
||||
dims = config.configure_dims(params)
|
||||
policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return)
|
||||
|
||||
rollout_params = {
|
||||
'exploit': False,
|
||||
'use_target_net': False,
|
||||
'use_demo_states': True,
|
||||
'compute_Q': False,
|
||||
'T': params['T'],
|
||||
}
|
||||
|
||||
eval_params = {
|
||||
'exploit': True,
|
||||
'use_target_net': params['test_with_polyak'],
|
||||
'use_demo_states': False,
|
||||
'compute_Q': True,
|
||||
'T': params['T'],
|
||||
}
|
||||
|
||||
for name in ['T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']:
|
||||
rollout_params[name] = params[name]
|
||||
eval_params[name] = params[name]
|
||||
|
||||
rollout_worker = RolloutWorker(params['make_env'], policy, dims, logger, **rollout_params)
|
||||
rollout_worker.seed(rank_seed)
|
||||
|
||||
evaluator = RolloutWorker(params['make_env'], policy, dims, logger, **eval_params)
|
||||
evaluator.seed(rank_seed)
|
||||
|
||||
train(
|
||||
logdir=logdir, policy=policy, rollout_worker=rollout_worker,
|
||||
evaluator=evaluator, n_epochs=n_epochs, n_test_rollouts=params['n_test_rollouts'],
|
||||
n_cycles=params['n_cycles'], n_batches=params['n_batches'],
|
||||
policy_save_interval=policy_save_interval, save_policies=save_policies, demo_file=demo_file)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('--env', type=str, default='FetchReach-v1', help='the name of the OpenAI Gym environment that you want to train on')
|
||||
@click.option('--logdir', type=str, default=None, help='the path to where logs and policy pickles should go. If not specified, creates a folder in /tmp/')
|
||||
@click.option('--n_epochs', type=int, default=50, help='the number of training epochs to run')
|
||||
@click.option('--num_cpu', type=int, default=1, help='the number of CPU cores to use (using MPI)')
|
||||
@click.option('--seed', type=int, default=0, help='the random seed used to seed both the environment and the training code')
|
||||
@click.option('--policy_save_interval', type=int, default=5, help='the interval with which policy pickles are saved. If set to 0, only the best and latest policy will be pickled.')
|
||||
@click.option('--replay_strategy', type=click.Choice(['future', 'none']), default='future', help='the HER replay strategy to be used. "future" uses HER, "none" disables HER.')
|
||||
@click.option('--clip_return', type=int, default=1, help='whether or not returns should be clipped')
|
||||
@click.option('--demo_file', type=str, default = 'PATH/TO/DEMO/DATA/FILE.npz', help='demo data file path')
|
||||
def main(**kwargs):
|
||||
launch(**kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
@@ -1,63 +1,193 @@
|
||||
import os
|
||||
|
||||
import click
|
||||
import numpy as np
|
||||
import json
|
||||
from mpi4py import MPI
|
||||
|
||||
from baselines import logger
|
||||
from baselines.common import set_global_seeds, tf_util
|
||||
from baselines.common.mpi_moments import mpi_moments
|
||||
import baselines.her.experiment.config as config
|
||||
from baselines.her.rollout import RolloutWorker
|
||||
|
||||
def mpi_average(value):
|
||||
if not isinstance(value, list):
|
||||
value = [value]
|
||||
if not any(value):
|
||||
value = [0.]
|
||||
return mpi_moments(np.array(value))[0]
|
||||
|
||||
|
||||
def make_sample_her_transitions(replay_strategy, replay_k, reward_fun):
|
||||
"""Creates a sample function that can be used for HER experience replay.
|
||||
def train(*, policy, rollout_worker, evaluator,
|
||||
n_epochs, n_test_rollouts, n_cycles, n_batches, policy_save_interval,
|
||||
save_path, demo_file, **kwargs):
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
|
||||
Args:
|
||||
replay_strategy (in ['future', 'none']): the HER replay strategy; if set to 'none',
|
||||
regular DDPG experience replay is used
|
||||
replay_k (int): the ratio between HER replays and regular replays (e.g. k = 4 -> 4 times
|
||||
as many HER replays as regular replays are used)
|
||||
reward_fun (function): function to re-compute the reward with substituted goals
|
||||
"""
|
||||
if replay_strategy == 'future':
|
||||
future_p = 1 - (1. / (1 + replay_k))
|
||||
else: # 'replay_strategy' == 'none'
|
||||
future_p = 0
|
||||
if save_path:
|
||||
latest_policy_path = os.path.join(save_path, 'policy_latest.pkl')
|
||||
best_policy_path = os.path.join(save_path, 'policy_best.pkl')
|
||||
periodic_policy_path = os.path.join(save_path, 'policy_{}.pkl')
|
||||
|
||||
def _sample_her_transitions(episode_batch, batch_size_in_transitions):
|
||||
"""episode_batch is {key: array(buffer_size x T x dim_key)}
|
||||
"""
|
||||
T = episode_batch['u'].shape[1]
|
||||
rollout_batch_size = episode_batch['u'].shape[0]
|
||||
batch_size = batch_size_in_transitions
|
||||
logger.info("Training...")
|
||||
best_success_rate = -1
|
||||
|
||||
# Select which episodes and time steps to use.
|
||||
episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
|
||||
t_samples = np.random.randint(T, size=batch_size)
|
||||
transitions = {key: episode_batch[key][episode_idxs, t_samples].copy()
|
||||
for key in episode_batch.keys()}
|
||||
if policy.bc_loss == 1: policy.init_demo_buffer(demo_file) #initialize demo buffer if training with demonstrations
|
||||
|
||||
# Select future time indexes proportional with probability future_p. These
|
||||
# will be used for HER replay by substituting in future goals.
|
||||
her_indexes = np.where(np.random.uniform(size=batch_size) < future_p)
|
||||
future_offset = np.random.uniform(size=batch_size) * (T - t_samples)
|
||||
future_offset = future_offset.astype(int)
|
||||
future_t = (t_samples + 1 + future_offset)[her_indexes]
|
||||
# num_timesteps = n_epochs * n_cycles * rollout_length * number of rollout workers
|
||||
for epoch in range(n_epochs):
|
||||
# train
|
||||
rollout_worker.clear_history()
|
||||
for _ in range(n_cycles):
|
||||
episode = rollout_worker.generate_rollouts()
|
||||
policy.store_episode(episode)
|
||||
for _ in range(n_batches):
|
||||
policy.train()
|
||||
policy.update_target_net()
|
||||
|
||||
# Replace goal with achieved goal but only for the previously-selected
|
||||
# HER transitions (as defined by her_indexes). For the other transitions,
|
||||
# keep the original goal.
|
||||
future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t]
|
||||
transitions['g'][her_indexes] = future_ag
|
||||
# test
|
||||
evaluator.clear_history()
|
||||
for _ in range(n_test_rollouts):
|
||||
evaluator.generate_rollouts()
|
||||
|
||||
# Reconstruct info dictionary for reward computation.
|
||||
info = {}
|
||||
for key, value in transitions.items():
|
||||
if key.startswith('info_'):
|
||||
info[key.replace('info_', '')] = value
|
||||
# record logs
|
||||
logger.record_tabular('epoch', epoch)
|
||||
for key, val in evaluator.logs('test'):
|
||||
logger.record_tabular(key, mpi_average(val))
|
||||
for key, val in rollout_worker.logs('train'):
|
||||
logger.record_tabular(key, mpi_average(val))
|
||||
for key, val in policy.logs():
|
||||
logger.record_tabular(key, mpi_average(val))
|
||||
|
||||
# Re-compute reward since we may have substituted the goal.
|
||||
reward_params = {k: transitions[k] for k in ['ag_2', 'g']}
|
||||
reward_params['info'] = info
|
||||
transitions['r'] = reward_fun(**reward_params)
|
||||
if rank == 0:
|
||||
logger.dump_tabular()
|
||||
|
||||
transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:])
|
||||
for k in transitions.keys()}
|
||||
# save the policy if it's better than the previous ones
|
||||
success_rate = mpi_average(evaluator.current_success_rate())
|
||||
if rank == 0 and success_rate >= best_success_rate and save_path:
|
||||
best_success_rate = success_rate
|
||||
logger.info('New best success rate: {}. Saving policy to {} ...'.format(best_success_rate, best_policy_path))
|
||||
evaluator.save_policy(best_policy_path)
|
||||
evaluator.save_policy(latest_policy_path)
|
||||
if rank == 0 and policy_save_interval > 0 and epoch % policy_save_interval == 0 and save_path:
|
||||
policy_path = periodic_policy_path.format(epoch)
|
||||
logger.info('Saving periodic policy to {} ...'.format(policy_path))
|
||||
evaluator.save_policy(policy_path)
|
||||
|
||||
assert(transitions['u'].shape[0] == batch_size_in_transitions)
|
||||
# make sure that different threads have different seeds
|
||||
local_uniform = np.random.uniform(size=(1,))
|
||||
root_uniform = local_uniform.copy()
|
||||
MPI.COMM_WORLD.Bcast(root_uniform, root=0)
|
||||
if rank != 0:
|
||||
assert local_uniform[0] != root_uniform[0]
|
||||
|
||||
return transitions
|
||||
return policy
|
||||
|
||||
return _sample_her_transitions
|
||||
|
||||
def learn(*, network, env, total_timesteps,
|
||||
seed=None,
|
||||
eval_env=None,
|
||||
replay_strategy='future',
|
||||
policy_save_interval=5,
|
||||
clip_return=True,
|
||||
demo_file=None,
|
||||
override_params=None,
|
||||
load_path=None,
|
||||
save_path=None,
|
||||
**kwargs
|
||||
):
|
||||
|
||||
override_params = override_params or {}
|
||||
if MPI is not None:
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
num_cpu = MPI.COMM_WORLD.Get_size()
|
||||
|
||||
# Seed everything.
|
||||
rank_seed = seed + 1000000 * rank if seed is not None else None
|
||||
set_global_seeds(rank_seed)
|
||||
|
||||
# Prepare params.
|
||||
params = config.DEFAULT_PARAMS
|
||||
env_name = env.specs[0].id
|
||||
params['env_name'] = env_name
|
||||
params['replay_strategy'] = replay_strategy
|
||||
if env_name in config.DEFAULT_ENV_PARAMS:
|
||||
params.update(config.DEFAULT_ENV_PARAMS[env_name]) # merge env-specific parameters in
|
||||
params.update(**override_params) # makes it possible to override any parameter
|
||||
with open(os.path.join(logger.get_dir(), 'params.json'), 'w') as f:
|
||||
json.dump(params, f)
|
||||
params = config.prepare_params(params)
|
||||
params['rollout_batch_size'] = env.num_envs
|
||||
|
||||
if demo_file is not None:
|
||||
params['bc_loss'] = 1
|
||||
params.update(kwargs)
|
||||
|
||||
config.log_params(params, logger=logger)
|
||||
|
||||
if num_cpu == 1:
|
||||
logger.warn()
|
||||
logger.warn('*** Warning ***')
|
||||
logger.warn(
|
||||
'You are running HER with just a single MPI worker. This will work, but the ' +
|
||||
'experiments that we report in Plappert et al. (2018, https://arxiv.org/abs/1802.09464) ' +
|
||||
'were obtained with --num_cpu 19. This makes a significant difference and if you ' +
|
||||
'are looking to reproduce those results, be aware of this. Please also refer to ' +
|
||||
'https://github.com/openai/baselines/issues/314 for further details.')
|
||||
logger.warn('****************')
|
||||
logger.warn()
|
||||
|
||||
dims = config.configure_dims(params)
|
||||
policy = config.configure_ddpg(dims=dims, params=params, clip_return=clip_return)
|
||||
if load_path is not None:
|
||||
tf_util.load_variables(load_path)
|
||||
|
||||
rollout_params = {
|
||||
'exploit': False,
|
||||
'use_target_net': False,
|
||||
'use_demo_states': True,
|
||||
'compute_Q': False,
|
||||
'T': params['T'],
|
||||
}
|
||||
|
||||
eval_params = {
|
||||
'exploit': True,
|
||||
'use_target_net': params['test_with_polyak'],
|
||||
'use_demo_states': False,
|
||||
'compute_Q': True,
|
||||
'T': params['T'],
|
||||
}
|
||||
|
||||
for name in ['T', 'rollout_batch_size', 'gamma', 'noise_eps', 'random_eps']:
|
||||
rollout_params[name] = params[name]
|
||||
eval_params[name] = params[name]
|
||||
|
||||
eval_env = eval_env or env
|
||||
|
||||
rollout_worker = RolloutWorker(env, policy, dims, logger, monitor=True, **rollout_params)
|
||||
evaluator = RolloutWorker(eval_env, policy, dims, logger, **eval_params)
|
||||
|
||||
n_cycles = params['n_cycles']
|
||||
n_epochs = total_timesteps // n_cycles // rollout_worker.T // rollout_worker.rollout_batch_size
|
||||
|
||||
return train(
|
||||
save_path=save_path, policy=policy, rollout_worker=rollout_worker,
|
||||
evaluator=evaluator, n_epochs=n_epochs, n_test_rollouts=params['n_test_rollouts'],
|
||||
n_cycles=params['n_cycles'], n_batches=params['n_batches'],
|
||||
policy_save_interval=policy_save_interval, demo_file=demo_file)
|
||||
|
||||
|
||||
@click.command()
|
||||
@click.option('--env', type=str, default='FetchReach-v1', help='the name of the OpenAI Gym environment that you want to train on')
|
||||
@click.option('--total_timesteps', type=int, default=int(5e5), help='the number of timesteps to run')
|
||||
@click.option('--seed', type=int, default=0, help='the random seed used to seed both the environment and the training code')
|
||||
@click.option('--policy_save_interval', type=int, default=5, help='the interval with which policy pickles are saved. If set to 0, only the best and latest policy will be pickled.')
|
||||
@click.option('--replay_strategy', type=click.Choice(['future', 'none']), default='future', help='the HER replay strategy to be used. "future" uses HER, "none" disables HER.')
|
||||
@click.option('--clip_return', type=int, default=1, help='whether or not returns should be clipped')
|
||||
@click.option('--demo_file', type=str, default = 'PATH/TO/DEMO/DATA/FILE.npz', help='demo data file path')
|
||||
def main(**kwargs):
|
||||
learn(**kwargs)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
|
63
baselines/her/her_sampler.py
Normal file
63
baselines/her/her_sampler.py
Normal file
@@ -0,0 +1,63 @@
|
||||
import numpy as np
|
||||
|
||||
|
||||
def make_sample_her_transitions(replay_strategy, replay_k, reward_fun):
|
||||
"""Creates a sample function that can be used for HER experience replay.
|
||||
|
||||
Args:
|
||||
replay_strategy (in ['future', 'none']): the HER replay strategy; if set to 'none',
|
||||
regular DDPG experience replay is used
|
||||
replay_k (int): the ratio between HER replays and regular replays (e.g. k = 4 -> 4 times
|
||||
as many HER replays as regular replays are used)
|
||||
reward_fun (function): function to re-compute the reward with substituted goals
|
||||
"""
|
||||
if replay_strategy == 'future':
|
||||
future_p = 1 - (1. / (1 + replay_k))
|
||||
else: # 'replay_strategy' == 'none'
|
||||
future_p = 0
|
||||
|
||||
def _sample_her_transitions(episode_batch, batch_size_in_transitions):
|
||||
"""episode_batch is {key: array(buffer_size x T x dim_key)}
|
||||
"""
|
||||
T = episode_batch['u'].shape[1]
|
||||
rollout_batch_size = episode_batch['u'].shape[0]
|
||||
batch_size = batch_size_in_transitions
|
||||
|
||||
# Select which episodes and time steps to use.
|
||||
episode_idxs = np.random.randint(0, rollout_batch_size, batch_size)
|
||||
t_samples = np.random.randint(T, size=batch_size)
|
||||
transitions = {key: episode_batch[key][episode_idxs, t_samples].copy()
|
||||
for key in episode_batch.keys()}
|
||||
|
||||
# Select future time indexes proportional with probability future_p. These
|
||||
# will be used for HER replay by substituting in future goals.
|
||||
her_indexes = np.where(np.random.uniform(size=batch_size) < future_p)
|
||||
future_offset = np.random.uniform(size=batch_size) * (T - t_samples)
|
||||
future_offset = future_offset.astype(int)
|
||||
future_t = (t_samples + 1 + future_offset)[her_indexes]
|
||||
|
||||
# Replace goal with achieved goal but only for the previously-selected
|
||||
# HER transitions (as defined by her_indexes). For the other transitions,
|
||||
# keep the original goal.
|
||||
future_ag = episode_batch['ag'][episode_idxs[her_indexes], future_t]
|
||||
transitions['g'][her_indexes] = future_ag
|
||||
|
||||
# Reconstruct info dictionary for reward computation.
|
||||
info = {}
|
||||
for key, value in transitions.items():
|
||||
if key.startswith('info_'):
|
||||
info[key.replace('info_', '')] = value
|
||||
|
||||
# Re-compute reward since we may have substituted the goal.
|
||||
reward_params = {k: transitions[k] for k in ['ag_2', 'g']}
|
||||
reward_params['info'] = info
|
||||
transitions['r'] = reward_fun(**reward_params)
|
||||
|
||||
transitions = {k: transitions[k].reshape(batch_size, *transitions[k].shape[1:])
|
||||
for k in transitions.keys()}
|
||||
|
||||
assert(transitions['u'].shape[0] == batch_size_in_transitions)
|
||||
|
||||
return transitions
|
||||
|
||||
return _sample_her_transitions
|
@@ -2,7 +2,6 @@ from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import pickle
|
||||
from mujoco_py import MujocoException
|
||||
|
||||
from baselines.her.util import convert_episode_to_batch_major, store_args
|
||||
|
||||
@@ -10,9 +9,9 @@ from baselines.her.util import convert_episode_to_batch_major, store_args
|
||||
class RolloutWorker:
|
||||
|
||||
@store_args
|
||||
def __init__(self, make_env, policy, dims, logger, T, rollout_batch_size=1,
|
||||
def __init__(self, venv, policy, dims, logger, T, rollout_batch_size=1,
|
||||
exploit=False, use_target_net=False, compute_Q=False, noise_eps=0,
|
||||
random_eps=0, history_len=100, render=False, **kwargs):
|
||||
random_eps=0, history_len=100, render=False, monitor=False, **kwargs):
|
||||
"""Rollout worker generates experience by interacting with one or many environments.
|
||||
|
||||
Args:
|
||||
@@ -31,7 +30,7 @@ class RolloutWorker:
|
||||
history_len (int): length of history for statistics smoothing
|
||||
render (boolean): whether or not to render the rollouts
|
||||
"""
|
||||
self.envs = [make_env() for _ in range(rollout_batch_size)]
|
||||
|
||||
assert self.T > 0
|
||||
|
||||
self.info_keys = [key.replace('info_', '') for key in dims.keys() if key.startswith('info_')]
|
||||
@@ -40,26 +39,14 @@ class RolloutWorker:
|
||||
self.Q_history = deque(maxlen=history_len)
|
||||
|
||||
self.n_episodes = 0
|
||||
self.g = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # goals
|
||||
self.initial_o = np.empty((self.rollout_batch_size, self.dims['o']), np.float32) # observations
|
||||
self.initial_ag = np.empty((self.rollout_batch_size, self.dims['g']), np.float32) # achieved goals
|
||||
self.reset_all_rollouts()
|
||||
self.clear_history()
|
||||
|
||||
def reset_rollout(self, i):
|
||||
"""Resets the `i`-th rollout environment, re-samples a new goal, and updates the `initial_o`
|
||||
and `g` arrays accordingly.
|
||||
"""
|
||||
obs = self.envs[i].reset()
|
||||
self.initial_o[i] = obs['observation']
|
||||
self.initial_ag[i] = obs['achieved_goal']
|
||||
self.g[i] = obs['desired_goal']
|
||||
|
||||
def reset_all_rollouts(self):
|
||||
"""Resets all `rollout_batch_size` rollout workers.
|
||||
"""
|
||||
for i in range(self.rollout_batch_size):
|
||||
self.reset_rollout(i)
|
||||
self.obs_dict = self.venv.reset()
|
||||
self.initial_o = self.obs_dict['observation']
|
||||
self.initial_ag = self.obs_dict['achieved_goal']
|
||||
self.g = self.obs_dict['desired_goal']
|
||||
|
||||
def generate_rollouts(self):
|
||||
"""Performs `rollout_batch_size` rollouts in parallel for time horizon `T` with the current
|
||||
@@ -75,7 +62,8 @@ class RolloutWorker:
|
||||
|
||||
# generate episodes
|
||||
obs, achieved_goals, acts, goals, successes = [], [], [], [], []
|
||||
info_values = [np.empty((self.T, self.rollout_batch_size, self.dims['info_' + key]), np.float32) for key in self.info_keys]
|
||||
dones = []
|
||||
info_values = [np.empty((self.T - 1, self.rollout_batch_size, self.dims['info_' + key]), np.float32) for key in self.info_keys]
|
||||
Qs = []
|
||||
for t in range(self.T):
|
||||
policy_output = self.policy.get_actions(
|
||||
@@ -99,27 +87,27 @@ class RolloutWorker:
|
||||
ag_new = np.empty((self.rollout_batch_size, self.dims['g']))
|
||||
success = np.zeros(self.rollout_batch_size)
|
||||
# compute new states and observations
|
||||
for i in range(self.rollout_batch_size):
|
||||
try:
|
||||
# We fully ignore the reward here because it will have to be re-computed
|
||||
# for HER.
|
||||
curr_o_new, _, _, info = self.envs[i].step(u[i])
|
||||
if 'is_success' in info:
|
||||
success[i] = info['is_success']
|
||||
o_new[i] = curr_o_new['observation']
|
||||
ag_new[i] = curr_o_new['achieved_goal']
|
||||
for idx, key in enumerate(self.info_keys):
|
||||
info_values[idx][t, i] = info[key]
|
||||
if self.render:
|
||||
self.envs[i].render()
|
||||
except MujocoException as e:
|
||||
return self.generate_rollouts()
|
||||
obs_dict_new, _, done, info = self.venv.step(u)
|
||||
o_new = obs_dict_new['observation']
|
||||
ag_new = obs_dict_new['achieved_goal']
|
||||
success = np.array([i.get('is_success', 0.0) for i in info])
|
||||
|
||||
if any(done):
|
||||
# here we assume all environments are done is ~same number of steps, so we terminate rollouts whenever any of the envs returns done
|
||||
# trick with using vecenvs is not to add the obs from the environments that are "done", because those are already observations
|
||||
# after a reset
|
||||
break
|
||||
|
||||
for i, info_dict in enumerate(info):
|
||||
for idx, key in enumerate(self.info_keys):
|
||||
info_values[idx][t, i] = info[i][key]
|
||||
|
||||
if np.isnan(o_new).any():
|
||||
self.logger.warn('NaN caught during rollout generation. Trying again...')
|
||||
self.reset_all_rollouts()
|
||||
return self.generate_rollouts()
|
||||
|
||||
dones.append(done)
|
||||
obs.append(o.copy())
|
||||
achieved_goals.append(ag.copy())
|
||||
successes.append(success.copy())
|
||||
@@ -129,7 +117,6 @@ class RolloutWorker:
|
||||
ag[...] = ag_new
|
||||
obs.append(o.copy())
|
||||
achieved_goals.append(ag.copy())
|
||||
self.initial_o[:] = o
|
||||
|
||||
episode = dict(o=obs,
|
||||
u=acts,
|
||||
@@ -176,13 +163,8 @@ class RolloutWorker:
|
||||
logs += [('mean_Q', np.mean(self.Q_history))]
|
||||
logs += [('episode', self.n_episodes)]
|
||||
|
||||
if prefix is not '' and not prefix.endswith('/'):
|
||||
if prefix != '' and not prefix.endswith('/'):
|
||||
return [(prefix + '/' + key, val) for key, val in logs]
|
||||
else:
|
||||
return logs
|
||||
|
||||
def seed(self, seed):
|
||||
"""Seeds each environment with a distinct seed derived from the passed in global seed.
|
||||
"""
|
||||
for idx, env in enumerate(self.envs):
|
||||
env.seed(seed + 1000 * idx)
|
||||
|
@@ -110,7 +110,8 @@ def build_env(args):
|
||||
config.gpu_options.allow_growth = True
|
||||
get_session(config=config)
|
||||
|
||||
env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale)
|
||||
flatten_dict_observations = alg not in {'her'}
|
||||
env = make_vec_env(env_id, env_type, args.num_env or 1, seed, reward_scale=args.reward_scale, flatten_dict_observations=flatten_dict_observations)
|
||||
|
||||
if env_type == 'mujoco':
|
||||
env = VecNormalize(env)
|
||||
@@ -119,6 +120,11 @@ def build_env(args):
|
||||
|
||||
|
||||
def get_env_type(env_id):
|
||||
# Re-parse the gym registry, since we could have new envs since last time.
|
||||
for env in gym.envs.registry.all():
|
||||
env_type = env._entry_point.split(':')[0].split('.')[-1]
|
||||
_game_envs[env_type].add(env.id) # This is a set so add is idempotent
|
||||
|
||||
if env_id in _game_envs.keys():
|
||||
env_type = env_id
|
||||
env_id = [g for g in _game_envs[env_type]][0]
|
||||
@@ -188,6 +194,9 @@ def main(args):
|
||||
args, unknown_args = arg_parser.parse_known_args(args)
|
||||
extra_args = parse_cmdline_kwargs(unknown_args)
|
||||
|
||||
if args.extra_import is not None:
|
||||
import_module(args.extra_import)
|
||||
|
||||
if MPI is None or MPI.COMM_WORLD.Get_rank() == 0:
|
||||
rank = 0
|
||||
logger.configure()
|
||||
@@ -206,11 +215,16 @@ def main(args):
|
||||
logger.log("Running trained model")
|
||||
env = build_env(args)
|
||||
obs = env.reset()
|
||||
def initialize_placeholders(nlstm=128,**kwargs):
|
||||
return np.zeros((args.num_env or 1, 2*nlstm)), np.zeros((1))
|
||||
state, dones = initialize_placeholders(**extra_args)
|
||||
|
||||
state = model.initial_state if hasattr(model, 'initial_state') else None
|
||||
dones = np.zeros((1,))
|
||||
|
||||
while True:
|
||||
actions, _, state, _ = model.step(obs,S=state, M=dones)
|
||||
if state is not None:
|
||||
actions, _, state, _ = model.step(obs,S=state, M=dones)
|
||||
else:
|
||||
actions, _, _, _ = model.step(obs)
|
||||
|
||||
obs, _, done, _ = env.step(actions)
|
||||
env.render()
|
||||
done = done.any() if isinstance(done, np.ndarray) else done
|
||||
|
@@ -3,6 +3,5 @@ select = F,E999,W291,W293
|
||||
exclude =
|
||||
.git,
|
||||
__pycache__,
|
||||
baselines/her,
|
||||
baselines/ppo1,
|
||||
baselines/bench,
|
||||
|
6
setup.py
6
setup.py
@@ -53,11 +53,11 @@ setup(name='baselines',
|
||||
# ensure there is some tensorflow build with version above 1.4
|
||||
import pkg_resources
|
||||
tf_pkg = None
|
||||
for tf_pkg_name in ['tensorflow', 'tensorflow-gpu']:
|
||||
for tf_pkg_name in ['tensorflow', 'tensorflow-gpu', 'tf-nightly', 'tf-nightly-gpu']:
|
||||
try:
|
||||
tf_pkg = pkg_resources.get_distribution(tf_pkg_name)
|
||||
except pkg_resources.DistributionNotFound:
|
||||
pass
|
||||
assert tf_pkg is not None, 'TensorFlow needed, of version above 1.4'
|
||||
from distutils.version import StrictVersion
|
||||
assert StrictVersion(re.sub(r'-?rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0')
|
||||
from distutils.version import LooseVersion
|
||||
assert LooseVersion(re.sub(r'-?rc\d+$', '', tf_pkg.version)) >= LooseVersion('1.4.0')
|
||||
|
Reference in New Issue
Block a user