HER : new functionality, enables demo based training (#474)
* Add, initialize, normalize and sample from a demo buffer * Modify losses and add cloning loss * Add demo file parameter to train.py * Introduce new params in config.py for demo based training * Change logger.warning to logger.warn in rollout.py;bug * Add data generation file for Fetch environments * Update README file
This commit is contained in:
@@ -30,3 +30,51 @@ python -m baselines.her.experiment.train --num_cpu 19
|
||||
This will require a machine with sufficient amount of physical CPU cores. In our experiments,
|
||||
we used [Azure's D15v2 instances](https://docs.microsoft.com/en-us/azure/virtual-machines/linux/sizes),
|
||||
which have 20 physical cores. We only scheduled the experiment on 19 of those to leave some head-room on the system.
|
||||
|
||||
|
||||
## Hindsight Experience Replay with Demonstrations
|
||||
Using pre-recorded demonstrations to Overcome the exploration problem in HER based Reinforcement learning.
|
||||
For details, please read the [paper](https://arxiv.org/pdf/1709.10089.pdf).
|
||||
|
||||
### Getting started
|
||||
The first step is to generate the demonstration dataset. This can be done in two ways, either by using a VR system to manipulate the arm using physical VR trackers or the simpler way is to write a script to carry out the respective task. Now some tasks can be complex and thus it would be difficult to write a hardcoded script for that task (eg. Fetch Push), but here our focus is on providing an algorithm that helps the agent to learn from demonstrations, and not on the demonstration generation paradigm itself. Thus the data collection part is left to the reader's choice.
|
||||
|
||||
We provide a script for the Fetch Pick and Place task, to generate demonstrations for the Pick and Place task execute:
|
||||
```bash
|
||||
python experiment/data_generation/fetch_data_generation.py
|
||||
```
|
||||
This outputs ```data_fetch_random_100.npz``` file which is our data file.
|
||||
|
||||
#### Configuration
|
||||
The provided configuration is for training an agent with HER without demonstrations, we need to change a few paramters for the HER algorithm to learn through demonstrations, to do that, set:
|
||||
|
||||
* bc_loss: 1 - whether or not to use the behavior cloning loss as an auxilliary loss
|
||||
* q_filter: 1 - whether or not a Q value filter should be used on the Actor outputs
|
||||
* num_demo: 100 - number of expert demo episodes
|
||||
* demo_batch_size: 128 - number of samples to be used from the demonstrations buffer, per mpi thread
|
||||
* prm_loss_weight: 0.001 - Weight corresponding to the primary loss
|
||||
* aux_loss_weight: 0.0078 - Weight corresponding to the auxilliary loss also called the cloning loss
|
||||
|
||||
Apart from these changes the reported results also have the following configurational changes:
|
||||
|
||||
* n_cycles: 20 - per epoch
|
||||
* batch_size: 1024 - per mpi thread, total batch size
|
||||
* random_eps: 0.1 - percentage of time a random action is taken
|
||||
* noise_eps: 0.1 - std of gaussian noise added to not-completely-random actions
|
||||
|
||||
Now training an agent with pre-recorded demonstrations:
|
||||
```bash
|
||||
python -m baselines.her.experiment.train --env=FetchPickAndPlace-v0 --n_epochs=1000 --demo_file=/Path/to/demo_file.npz --num_cpu=1
|
||||
```
|
||||
|
||||
This will train a DDPG+HER agent on the `FetchPickAndPlace` environment by using previously generated demonstration data.
|
||||
To inspect what the agent has learned, use the play script as described above.
|
||||
|
||||
### Results
|
||||
Training with demonstrations helps overcome the exploration problem and achieves a faster and better convergence. The following graphs contrast the difference between training with and without demonstration data, We report the mean Q values vs Epoch and the Success Rate vs Epoch:
|
||||
|
||||
|
||||
<div class="imgcap" align="middle">
|
||||
<center><img src="../../data/fetchPickAndPlaceContrast.png"></center>
|
||||
<div class="thecap" align="middle"><b>Training results for Fetch Pick and Place task constrasting between training with and without demonstration data.</b></div>
|
||||
</div>
|
||||
|
@@ -6,7 +6,7 @@ from tensorflow.contrib.staging import StagingArea
|
||||
|
||||
from baselines import logger
|
||||
from baselines.her.util import (
|
||||
import_function, store_args, flatten_grads, transitions_in_episode_batch)
|
||||
import_function, store_args, flatten_grads, transitions_in_episode_batch, convert_episode_to_batch_major)
|
||||
from baselines.her.normalizer import Normalizer
|
||||
from baselines.her.replay_buffer import ReplayBuffer
|
||||
from baselines.common.mpi_adam import MpiAdam
|
||||
@@ -16,13 +16,17 @@ def dims_to_shapes(input_dims):
|
||||
return {key: tuple([val]) if val > 0 else tuple() for key, val in input_dims.items()}
|
||||
|
||||
|
||||
global demoBuffer #buffer for demonstrations
|
||||
|
||||
class DDPG(object):
|
||||
@store_args
|
||||
def __init__(self, input_dims, buffer_size, hidden, layers, network_class, polyak, batch_size,
|
||||
Q_lr, pi_lr, norm_eps, norm_clip, max_u, action_l2, clip_obs, scope, T,
|
||||
rollout_batch_size, subtract_goals, relative_goals, clip_pos_returns, clip_return,
|
||||
bc_loss, q_filter, num_demo, demo_batch_size, prm_loss_weight, aux_loss_weight,
|
||||
sample_transitions, gamma, reuse=False, **kwargs):
|
||||
"""Implementation of DDPG that is used in combination with Hindsight Experience Replay (HER).
|
||||
Added functionality to use demonstrations for training to Overcome exploration problem.
|
||||
|
||||
Args:
|
||||
input_dims (dict of ints): dimensions for the observation (o), the goal (g), and the
|
||||
@@ -50,6 +54,12 @@ class DDPG(object):
|
||||
sample_transitions (function) function that samples from the replay buffer
|
||||
gamma (float): gamma used for Q learning updates
|
||||
reuse (boolean): whether or not the networks should be reused
|
||||
bc_loss: whether or not the behavior cloning loss should be used as an auxilliary loss
|
||||
q_filter: whether or not a filter on the q value update should be used when training with demonstartions
|
||||
num_demo: Number of episodes in to be used in the demonstration buffer
|
||||
demo_batch_size: number of samples to be used from the demonstrations buffer, per mpi thread
|
||||
prm_loss_weight: Weight corresponding to the primary loss
|
||||
aux_loss_weight: Weight corresponding to the auxilliary loss also called the cloning loss
|
||||
"""
|
||||
if self.clip_return is None:
|
||||
self.clip_return = np.inf
|
||||
@@ -92,6 +102,9 @@ class DDPG(object):
|
||||
buffer_size = (self.buffer_size // self.rollout_batch_size) * self.rollout_batch_size
|
||||
self.buffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions)
|
||||
|
||||
global demoBuffer
|
||||
demoBuffer = ReplayBuffer(buffer_shapes, buffer_size, self.T, self.sample_transitions) #initialize the demo buffer; in the same way as the primary data buffer
|
||||
|
||||
def _random_action(self, n):
|
||||
return np.random.uniform(low=-self.max_u, high=self.max_u, size=(n, self.dimu))
|
||||
|
||||
@@ -138,6 +151,57 @@ class DDPG(object):
|
||||
else:
|
||||
return ret
|
||||
|
||||
def initDemoBuffer(self, demoDataFile, update_stats=True): #function that initializes the demo buffer
|
||||
|
||||
demoData = np.load(demoDataFile) #load the demonstration data from data file
|
||||
info_keys = [key.replace('info_', '') for key in self.input_dims.keys() if key.startswith('info_')]
|
||||
info_values = [np.empty((self.T, 1, self.input_dims['info_' + key]), np.float32) for key in info_keys]
|
||||
|
||||
for epsd in range(self.num_demo): # we initialize the whole demo buffer at the start of the training
|
||||
obs, acts, goals, achieved_goals = [], [] ,[] ,[]
|
||||
i = 0
|
||||
for transition in range(self.T):
|
||||
obs.append([demoData['obs'][epsd ][transition].get('observation')])
|
||||
acts.append([demoData['acs'][epsd][transition]])
|
||||
goals.append([demoData['obs'][epsd][transition].get('desired_goal')])
|
||||
achieved_goals.append([demoData['obs'][epsd][transition].get('achieved_goal')])
|
||||
for idx, key in enumerate(info_keys):
|
||||
info_values[idx][transition, i] = demoData['info'][epsd][transition][key]
|
||||
|
||||
obs.append([demoData['obs'][epsd][self.T].get('observation')])
|
||||
achieved_goals.append([demoData['obs'][epsd][self.T].get('achieved_goal')])
|
||||
|
||||
episode = dict(o=obs,
|
||||
u=acts,
|
||||
g=goals,
|
||||
ag=achieved_goals)
|
||||
for key, value in zip(info_keys, info_values):
|
||||
episode['info_{}'.format(key)] = value
|
||||
|
||||
episode = convert_episode_to_batch_major(episode)
|
||||
global demoBuffer
|
||||
demoBuffer.store_episode(episode) # create the observation dict and append them into the demonstration buffer
|
||||
|
||||
print("Demo buffer size currently ", demoBuffer.get_current_size()) #print out the demonstration buffer size
|
||||
|
||||
if update_stats:
|
||||
# add transitions to normalizer to normalize the demo data as well
|
||||
episode['o_2'] = episode['o'][:, 1:, :]
|
||||
episode['ag_2'] = episode['ag'][:, 1:, :]
|
||||
num_normalizing_transitions = transitions_in_episode_batch(episode)
|
||||
transitions = self.sample_transitions(episode, num_normalizing_transitions)
|
||||
|
||||
o, o_2, g, ag = transitions['o'], transitions['o_2'], transitions['g'], transitions['ag']
|
||||
transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
|
||||
# No need to preprocess the o_2 and g_2 since this is only used for stats
|
||||
|
||||
self.o_stats.update(transitions['o'])
|
||||
self.g_stats.update(transitions['g'])
|
||||
|
||||
self.o_stats.recompute_stats()
|
||||
self.g_stats.recompute_stats()
|
||||
episode.clear()
|
||||
|
||||
def store_episode(self, episode_batch, update_stats=True):
|
||||
"""
|
||||
episode_batch: array of batch_size x (T or T+1) x dim_key
|
||||
@@ -185,7 +249,18 @@ class DDPG(object):
|
||||
self.pi_adam.update(pi_grad, self.pi_lr)
|
||||
|
||||
def sample_batch(self):
|
||||
transitions = self.buffer.sample(self.batch_size)
|
||||
if self.bc_loss: #use demonstration buffer to sample as well if bc_loss flag is set TRUE
|
||||
transitions = self.buffer.sample(self.batch_size - self.demo_batch_size)
|
||||
global demoBuffer
|
||||
transitionsDemo = demoBuffer.sample(self.demo_batch_size) #sample from the demo buffer
|
||||
for k, values in transitionsDemo.items():
|
||||
rolloutV = transitions[k].tolist()
|
||||
for v in values:
|
||||
rolloutV.append(v.tolist())
|
||||
transitions[k] = np.array(rolloutV)
|
||||
else:
|
||||
transitions = self.buffer.sample(self.batch_size) #otherwise only sample from primary buffer
|
||||
|
||||
o, o_2, g = transitions['o'], transitions['o_2'], transitions['g']
|
||||
ag, ag_2 = transitions['ag'], transitions['ag_2']
|
||||
transitions['o'], transitions['g'] = self._preprocess_og(o, ag, g)
|
||||
@@ -248,6 +323,9 @@ class DDPG(object):
|
||||
for i, key in enumerate(self.stage_shapes.keys())])
|
||||
batch_tf['r'] = tf.reshape(batch_tf['r'], [-1, 1])
|
||||
|
||||
#choose only the demo buffer samples
|
||||
mask = np.concatenate((np.zeros(self.batch_size - self.demo_batch_size), np.ones(self.demo_batch_size)), axis = 0)
|
||||
|
||||
# networks
|
||||
with tf.variable_scope('main') as vs:
|
||||
if reuse:
|
||||
@@ -270,6 +348,25 @@ class DDPG(object):
|
||||
clip_range = (-self.clip_return, 0. if self.clip_pos_returns else np.inf)
|
||||
target_tf = tf.clip_by_value(batch_tf['r'] + self.gamma * target_Q_pi_tf, *clip_range)
|
||||
self.Q_loss_tf = tf.reduce_mean(tf.square(tf.stop_gradient(target_tf) - self.main.Q_tf))
|
||||
|
||||
if self.bc_loss ==1 and self.q_filter == 1 : # train with demonstrations and use bc_loss and q_filter both
|
||||
maskMain = tf.reshape(tf.boolean_mask(self.main.Q_tf > self.main.Q_pi_tf, mask), [-1]) #where is the demonstrator action better than actor action according to the critic? choose those samples only
|
||||
#define the cloning loss on the actor's actions only on the samples which adhere to the above masks
|
||||
self.cloning_loss_tf = tf.reduce_sum(tf.square(tf.boolean_mask(tf.boolean_mask((self.main.pi_tf), mask), maskMain, axis=0) - tf.boolean_mask(tf.boolean_mask((batch_tf['u']), mask), maskMain, axis=0)))
|
||||
self.pi_loss_tf = -self.prm_loss_weight * tf.reduce_mean(self.main.Q_pi_tf) #primary loss scaled by it's respective weight prm_loss_weight
|
||||
self.pi_loss_tf += self.prm_loss_weight * self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u)) #L2 loss on action values scaled by the same weight prm_loss_weight
|
||||
self.pi_loss_tf += self.aux_loss_weight * self.cloning_loss_tf #adding the cloning loss to the actor loss as an auxilliary loss scaled by its weight aux_loss_weight
|
||||
|
||||
elif self.bc_loss == 1 and self.q_filter == 0: # train with demonstrations without q_filter
|
||||
self.cloning_loss_tf = tf.reduce_sum(tf.square(tf.boolean_mask((self.main.pi_tf), mask) - tf.boolean_mask((batch_tf['u']), mask)))
|
||||
self.pi_loss_tf = -self.prm_loss_weight * tf.reduce_mean(self.main.Q_pi_tf)
|
||||
self.pi_loss_tf += self.prm_loss_weight * self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
|
||||
self.pi_loss_tf += self.aux_loss_weight * self.cloning_loss_tf
|
||||
|
||||
else: #If not training with demonstrations
|
||||
self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
|
||||
self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
|
||||
|
||||
self.pi_loss_tf = -tf.reduce_mean(self.main.Q_pi_tf)
|
||||
self.pi_loss_tf += self.action_l2 * tf.reduce_mean(tf.square(self.main.pi_tf / self.max_u))
|
||||
Q_grads_tf = tf.gradients(self.Q_loss_tf, self._vars('main/Q'))
|
||||
|
@@ -44,6 +44,13 @@ DEFAULT_PARAMS = {
|
||||
# normalization
|
||||
'norm_eps': 0.01, # epsilon used for observation normalization
|
||||
'norm_clip': 5, # normalized observations are cropped to this values
|
||||
|
||||
'bc_loss': 0, # whether or not to use the behavior cloning loss as an auxilliary loss
|
||||
'q_filter': 0, # whether or not a Q value filter should be used on the Actor outputs
|
||||
'num_demo': 100, # number of expert demo episodes
|
||||
'demo_batch_size': 128, #number of samples to be used from the demonstrations buffer, per mpi thread 128/1024 or 32/256
|
||||
'prm_loss_weight': 0.001, #Weight corresponding to the primary loss
|
||||
'aux_loss_weight': 0.0078, #Weight corresponding to the auxilliary loss also called the cloning loss
|
||||
}
|
||||
|
||||
|
||||
@@ -145,6 +152,12 @@ def configure_ddpg(dims, params, reuse=False, use_mpi=True, clip_return=True):
|
||||
'subtract_goals': simple_goal_subtract,
|
||||
'sample_transitions': sample_her_transitions,
|
||||
'gamma': gamma,
|
||||
'bc_loss': params['bc_loss'],
|
||||
'q_filter': params['q_filter'],
|
||||
'num_demo': params['num_demo'],
|
||||
'demo_batch_size': params['demo_batch_size'],
|
||||
'prm_loss_weight': params['prm_loss_weight'],
|
||||
'aux_loss_weight': params['aux_loss_weight'],
|
||||
})
|
||||
ddpg_params['info'] = {
|
||||
'env_name': params['env_name'],
|
||||
|
@@ -0,0 +1,149 @@
|
||||
import gym
|
||||
import time
|
||||
import random
|
||||
import numpy as np
|
||||
import rospy
|
||||
import roslaunch
|
||||
|
||||
from random import randint
|
||||
from std_srvs.srv import Empty
|
||||
from sensor_msgs.msg import JointState
|
||||
from geometry_msgs.msg import PoseStamped
|
||||
from geometry_msgs.msg import Pose
|
||||
from std_msgs.msg import Float64
|
||||
from controller_manager_msgs.srv import SwitchController
|
||||
from gym.utils import seeding
|
||||
|
||||
|
||||
"""Data generation for the case of a single block pick and place in Fetch Env"""
|
||||
|
||||
actions = []
|
||||
observations = []
|
||||
infos = []
|
||||
|
||||
def main():
|
||||
env = gym.make('FetchPickAndPlace-v0')
|
||||
numItr = 100
|
||||
initStateSpace = "random"
|
||||
env.reset()
|
||||
print("Reset!")
|
||||
while len(actions) < numItr:
|
||||
obs = env.reset()
|
||||
print("ITERATION NUMBER ", len(actions))
|
||||
goToGoal(env, obs)
|
||||
|
||||
|
||||
fileName = "data_fetch"
|
||||
fileName += "_" + initStateSpace
|
||||
fileName += "_" + str(numItr)
|
||||
fileName += ".npz"
|
||||
|
||||
np.savez_compressed(fileName, acs=actions, obs=observations, info=infos) # save the file
|
||||
|
||||
def goToGoal(env, lastObs):
|
||||
|
||||
goal = lastObs['desired_goal']
|
||||
objectPos = lastObs['observation'][3:6]
|
||||
gripperPos = lastObs['observation'][:3]
|
||||
gripperState = lastObs['observation'][9:11]
|
||||
object_rel_pos = lastObs['observation'][6:9]
|
||||
episodeAcs = []
|
||||
episodeObs = []
|
||||
episodeInfo = []
|
||||
|
||||
object_oriented_goal = object_rel_pos.copy()
|
||||
object_oriented_goal[2] += 0.03 # first make the gripper go slightly above the object
|
||||
|
||||
timeStep = 0 #count the total number of timesteps
|
||||
episodeObs.append(lastObs)
|
||||
|
||||
while np.linalg.norm(object_oriented_goal) >= 0.005 and timeStep <= env._max_episode_steps:
|
||||
env.render()
|
||||
action = [0, 0, 0, 0]
|
||||
object_oriented_goal = object_rel_pos.copy()
|
||||
object_oriented_goal[2] += 0.03
|
||||
|
||||
for i in range(len(object_oriented_goal)):
|
||||
action[i] = object_oriented_goal[i]*6
|
||||
|
||||
action[len(action)-1] = 0.05 #open
|
||||
|
||||
obsDataNew, reward, done, info = env.step(action)
|
||||
timeStep += 1
|
||||
|
||||
episodeAcs.append(action)
|
||||
episodeInfo.append(info)
|
||||
episodeObs.append(obsDataNew)
|
||||
|
||||
objectPos = obsDataNew['observation'][3:6]
|
||||
gripperPos = obsDataNew['observation'][:3]
|
||||
gripperState = obsDataNew['observation'][9:11]
|
||||
object_rel_pos = obsDataNew['observation'][6:9]
|
||||
|
||||
while np.linalg.norm(object_rel_pos) >= 0.005 and timeStep <= env._max_episode_steps :
|
||||
env.render()
|
||||
action = [0, 0, 0, 0]
|
||||
for i in range(len(object_rel_pos)):
|
||||
action[i] = object_rel_pos[i]*6
|
||||
|
||||
action[len(action)-1] = -0.005
|
||||
|
||||
obsDataNew, reward, done, info = env.step(action)
|
||||
timeStep += 1
|
||||
|
||||
episodeAcs.append(action)
|
||||
episodeInfo.append(info)
|
||||
episodeObs.append(obsDataNew)
|
||||
|
||||
objectPos = obsDataNew['observation'][3:6]
|
||||
gripperPos = obsDataNew['observation'][:3]
|
||||
gripperState = obsDataNew['observation'][9:11]
|
||||
object_rel_pos = obsDataNew['observation'][6:9]
|
||||
|
||||
|
||||
while np.linalg.norm(goal - objectPos) >= 0.01 and timeStep <= env._max_episode_steps :
|
||||
env.render()
|
||||
action = [0, 0, 0, 0]
|
||||
for i in range(len(goal - objectPos)):
|
||||
action[i] = (goal - objectPos)[i]*6
|
||||
|
||||
action[len(action)-1] = -0.005
|
||||
|
||||
obsDataNew, reward, done, info = env.step(action)
|
||||
timeStep += 1
|
||||
|
||||
episodeAcs.append(action)
|
||||
episodeInfo.append(info)
|
||||
episodeObs.append(obsDataNew)
|
||||
|
||||
objectPos = obsDataNew['observation'][3:6]
|
||||
gripperPos = obsDataNew['observation'][:3]
|
||||
gripperState = obsDataNew['observation'][9:11]
|
||||
object_rel_pos = obsDataNew['observation'][6:9]
|
||||
|
||||
while True: #limit the number of timesteps in the episode to a fixed duration
|
||||
env.render()
|
||||
action = [0, 0, 0, 0]
|
||||
action[len(action)-1] = -0.005 # keep the gripper closed
|
||||
|
||||
obsDataNew, reward, done, info = env.step(action)
|
||||
timeStep += 1
|
||||
|
||||
episodeAcs.append(action)
|
||||
episodeInfo.append(info)
|
||||
episodeObs.append(obsDataNew)
|
||||
|
||||
objectPos = obsDataNew['observation'][3:6]
|
||||
gripperPos = obsDataNew['observation'][:3]
|
||||
gripperState = obsDataNew['observation'][9:11]
|
||||
object_rel_pos = obsDataNew['observation'][6:9]
|
||||
|
||||
if timeStep >= env._max_episode_steps: break
|
||||
|
||||
actions.append(episodeAcs)
|
||||
observations.append(episodeObs)
|
||||
infos.append(episodeInfo)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
@@ -26,7 +26,7 @@ def mpi_average(value):
|
||||
|
||||
def train(policy, rollout_worker, evaluator,
|
||||
n_epochs, n_test_rollouts, n_cycles, n_batches, policy_save_interval,
|
||||
save_policies, **kwargs):
|
||||
save_policies, demo_file, **kwargs):
|
||||
rank = MPI.COMM_WORLD.Get_rank()
|
||||
|
||||
latest_policy_path = os.path.join(logger.get_dir(), 'policy_latest.pkl')
|
||||
@@ -35,6 +35,8 @@ def train(policy, rollout_worker, evaluator,
|
||||
|
||||
logger.info("Training...")
|
||||
best_success_rate = -1
|
||||
|
||||
if policy.bc_loss == 1: policy.initDemoBuffer(demo_file) #initialize demo buffer if training with demonstrations
|
||||
for epoch in range(n_epochs):
|
||||
# train
|
||||
rollout_worker.clear_history()
|
||||
@@ -84,7 +86,7 @@ def train(policy, rollout_worker, evaluator,
|
||||
|
||||
def launch(
|
||||
env, logdir, n_epochs, num_cpu, seed, replay_strategy, policy_save_interval, clip_return,
|
||||
override_params={}, save_policies=True
|
||||
demo_file, override_params={}, save_policies=True
|
||||
):
|
||||
# Fork for multi-CPU MPI implementation.
|
||||
if num_cpu > 1:
|
||||
@@ -171,7 +173,7 @@ def launch(
|
||||
logdir=logdir, policy=policy, rollout_worker=rollout_worker,
|
||||
evaluator=evaluator, n_epochs=n_epochs, n_test_rollouts=params['n_test_rollouts'],
|
||||
n_cycles=params['n_cycles'], n_batches=params['n_batches'],
|
||||
policy_save_interval=policy_save_interval, save_policies=save_policies)
|
||||
policy_save_interval=policy_save_interval, save_policies=save_policies, demo_file=demo_file)
|
||||
|
||||
|
||||
@click.command()
|
||||
@@ -183,6 +185,7 @@ def launch(
|
||||
@click.option('--policy_save_interval', type=int, default=5, help='the interval with which policy pickles are saved. If set to 0, only the best and latest policy will be pickled.')
|
||||
@click.option('--replay_strategy', type=click.Choice(['future', 'none']), default='future', help='the HER replay strategy to be used. "future" uses HER, "none" disables HER.')
|
||||
@click.option('--clip_return', type=int, default=1, help='whether or not returns should be clipped')
|
||||
@click.option('--demo_file', type=str, default = 'PATH/TO/DEMO/DATA/FILE.npz', help='demo data file path')
|
||||
def main(**kwargs):
|
||||
launch(**kwargs)
|
||||
|
||||
|
BIN
data/fetchPickAndPlaceContrast.png
Normal file
BIN
data/fetchPickAndPlaceContrast.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 68 KiB |
Reference in New Issue
Block a user