sync internal changes. Make ddpg work with vecenvs
This commit is contained in:
@@ -14,7 +14,7 @@ learn_kwargs = {
|
||||
'a2c' : {},
|
||||
'acktr': {},
|
||||
'deepq': {},
|
||||
'ddpg': dict(nb_epochs=None, layer_norm=True),
|
||||
'ddpg': dict(layer_norm=True),
|
||||
'ppo2': dict(lr=1e-3, nsteps=64, ent_coef=0.0),
|
||||
'trpo_mpi': dict(timesteps_per_batch=100, cg_iters=10, gamma=0.9, lam=1.0, max_kl=0.01)
|
||||
}
|
||||
|
@@ -293,7 +293,7 @@ def display_var_info(vars):
|
||||
if "/Adam" in name or "beta1_power" in name or "beta2_power" in name: continue
|
||||
v_params = np.prod(v.shape.as_list())
|
||||
count_params += v_params
|
||||
if "/b:" in name or "/biases" in name: continue # Wx+b, bias is not interesting to look at => count params, but not print
|
||||
if "/b:" in name or "/bias" in name: continue # Wx+b, bias is not interesting to look at => count params, but not print
|
||||
logger.info(" %s%s %i params %s" % (name, " "*(55-len(name)), v_params, str(v.shape)))
|
||||
|
||||
logger.info("Total model parameters: %0.2f million" % (count_params*1e-6))
|
||||
|
@@ -104,6 +104,7 @@ class VecEnv(ABC):
|
||||
bigimg = tile_images(imgs)
|
||||
if mode == 'human':
|
||||
self.get_viewer().imshow(bigimg)
|
||||
return self.get_viewer().isopen
|
||||
elif mode == 'rgb_array':
|
||||
return bigimg
|
||||
else:
|
||||
|
@@ -78,6 +78,7 @@ def learn(network, env,
|
||||
|
||||
max_action = env.action_space.high
|
||||
logger.info('scaling actions by {} before executing in env'.format(max_action))
|
||||
|
||||
agent = DDPG(actor, critic, memory, env.observation_space.shape, env.action_space.shape,
|
||||
gamma=gamma, tau=tau, normalize_returns=normalize_returns, normalize_observations=normalize_observations,
|
||||
batch_size=batch_size, action_noise=action_noise, param_noise=param_noise, critic_l2_reg=critic_l2_reg,
|
||||
@@ -94,16 +95,21 @@ def learn(network, env,
|
||||
sess.graph.finalize()
|
||||
|
||||
agent.reset()
|
||||
|
||||
obs = env.reset()
|
||||
if eval_env is not None:
|
||||
eval_obs = eval_env.reset()
|
||||
done = False
|
||||
episode_reward = 0.
|
||||
episode_step = 0
|
||||
episodes = 0
|
||||
t = 0
|
||||
B = obs.shape[0]
|
||||
|
||||
episode_reward = np.zeros(B, dtype = np.float32) #vector
|
||||
episode_step = np.zeros(B, dtype = int) # vector
|
||||
episodes = 0 #scalar
|
||||
t = 0 # scalar
|
||||
|
||||
epoch = 0
|
||||
|
||||
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
epoch_episode_rewards = []
|
||||
@@ -114,16 +120,22 @@ def learn(network, env,
|
||||
for epoch in range(nb_epochs):
|
||||
for cycle in range(nb_epoch_cycles):
|
||||
# Perform rollouts.
|
||||
if B > 1:
|
||||
# When simulating multiple envs in parallel, the agent cannot be reset at the end of an episode in each
|
||||
# individual environment, so it is reset here, at the start of every rollout cycle, instead.
|
||||
agent.reset()
|
||||
for t_rollout in range(nb_rollout_steps):
|
||||
# Predict next action.
|
||||
action, q, _, _ = agent.step(obs, apply_noise=True, compute_Q=True)
|
||||
assert action.shape == env.action_space.shape
|
||||
|
||||
# Execute next action.
|
||||
if rank == 0 and render:
|
||||
env.render()
|
||||
assert max_action.shape == action.shape
|
||||
|
||||
# max_action has dimension A, whereas action has dimension (B, A); the multiplication is broadcast across the batch
|
||||
new_obs, r, done, info = env.step(max_action * action) # scale for execution in env (as far as DDPG is concerned, every action is in [-1, 1])
|
||||
# note these outputs are batched from vecenv
|
||||
|
||||
t += 1
|
||||
if rank == 0 and render:
|
||||
env.render()
|
||||
@@ -133,21 +145,24 @@ def learn(network, env,
|
||||
# Book-keeping.
|
||||
epoch_actions.append(action)
|
||||
epoch_qs.append(q)
|
||||
agent.store_transition(obs, action, r, new_obs, done)
|
||||
agent.store_transition(obs, action, r, new_obs, done) # the batched data is unrolled per environment inside DDPG.store_transition
|
||||
|
||||
obs = new_obs
|
||||
|
||||
if done:
|
||||
# Episode done.
|
||||
epoch_episode_rewards.append(episode_reward)
|
||||
episode_rewards_history.append(episode_reward)
|
||||
epoch_episode_steps.append(episode_step)
|
||||
episode_reward = 0.
|
||||
episode_step = 0
|
||||
epoch_episodes += 1
|
||||
episodes += 1
|
||||
for d in range(len(done)):
|
||||
if done[d] == True:
|
||||
# Episode done.
|
||||
epoch_episode_rewards.append(episode_reward[d])
|
||||
episode_rewards_history.append(episode_reward[d])
|
||||
epoch_episode_steps.append(episode_step[d])
|
||||
episode_reward[d] = 0.
|
||||
episode_step[d] = 0
|
||||
epoch_episodes += 1
|
||||
episodes += 1
|
||||
if B == 1:
|
||||
agent.reset()
|
||||
|
||||
|
||||
agent.reset()
|
||||
obs = env.reset()
|
||||
|
||||
# Train.
|
||||
epoch_actor_losses = []
|
||||
@@ -168,7 +183,8 @@ def learn(network, env,
|
||||
eval_episode_rewards = []
|
||||
eval_qs = []
|
||||
if eval_env is not None:
|
||||
eval_episode_reward = 0.
|
||||
B = eval_obs.shape[0]
|
||||
eval_episode_reward = np.zeros(B, dtype = np.float32)
|
||||
for t_rollout in range(nb_eval_steps):
|
||||
eval_action, eval_q, _, _ = agent.step(eval_obs, apply_noise=False, compute_Q=True)
|
||||
eval_obs, eval_r, eval_done, eval_info = eval_env.step(max_action * eval_action) # scale for execution in env (as far as DDPG is concerned, every action is in [-1, 1])
|
||||
@@ -177,11 +193,11 @@ def learn(network, env,
|
||||
eval_episode_reward += eval_r
|
||||
|
||||
eval_qs.append(eval_q)
|
||||
if eval_done:
|
||||
eval_obs = eval_env.reset()
|
||||
eval_episode_rewards.append(eval_episode_reward)
|
||||
eval_episode_rewards_history.append(eval_episode_reward)
|
||||
eval_episode_reward = 0.
|
||||
for d in range(len(eval_done)):
|
||||
if eval_done[d] == True:
|
||||
eval_episode_rewards.append(eval_episode_reward[d])
|
||||
eval_episode_rewards_history.append(eval_episode_reward[d])
|
||||
eval_episode_reward[d] = 0.0
|
||||
|
||||
mpi_size = MPI.COMM_WORLD.Get_size()
|
||||
# Log stats.
|
||||
@@ -216,7 +232,8 @@ def learn(network, env,
|
||||
return x
|
||||
else:
|
||||
raise ValueError('expected scalar, got %s'%x)
|
||||
combined_stats_sums = MPI.COMM_WORLD.allreduce(np.array([as_scalar(x) for x in combined_stats.values()]))
|
||||
|
||||
combined_stats_sums = MPI.COMM_WORLD.allreduce(np.array([ np.array(x).flatten()[0] for x in combined_stats.values()]))
|
||||
combined_stats = {k : v / mpi_size for (k,v) in zip(combined_stats.keys(), combined_stats_sums)}
|
||||
|
||||
# Total statistics.
|
||||
@@ -225,7 +242,9 @@ def learn(network, env,
|
||||
|
||||
for key in sorted(combined_stats.keys()):
|
||||
logger.record_tabular(key, combined_stats[key])
|
||||
logger.dump_tabular()
|
||||
|
||||
if rank == 0:
|
||||
logger.dump_tabular()
|
||||
logger.info('')
|
||||
logdir = logger.get_dir()
|
||||
if rank == 0 and logdir:
|
||||
|
@@ -265,19 +265,24 @@ class DDPG(object):
|
||||
else:
|
||||
action = self.sess.run(actor_tf, feed_dict=feed_dict)
|
||||
q = None
|
||||
action = action.flatten()
|
||||
|
||||
if self.action_noise is not None and apply_noise:
|
||||
noise = self.action_noise()
|
||||
assert noise.shape == action.shape
|
||||
action += noise
|
||||
action = np.clip(action, self.action_range[0], self.action_range[1])
|
||||
|
||||
|
||||
return action, q, None, None
|
||||
|
||||
def store_transition(self, obs0, action, reward, obs1, terminal1):
|
||||
reward *= self.reward_scale
|
||||
self.memory.append(obs0, action, reward, obs1, terminal1)
|
||||
if self.normalize_observations:
|
||||
self.obs_rms.update(np.array([obs0]))
|
||||
|
||||
B = obs0.shape[0]
|
||||
for b in range(B):
|
||||
self.memory.append(obs0[b], action[b], reward[b], obs1[b], terminal1[b])
|
||||
if self.normalize_observations:
|
||||
self.obs_rms.update(np.array([obs0[b]]))
|
||||
|
||||
def train(self):
|
||||
# Get a batch.
|
||||
|
Reference in New Issue
Block a user