Improve support for the 'play' option.

This commit is contained in:
gyunt
2019-03-21 20:45:01 +09:00
parent 1ecc242aec
commit 299d7d2f86
2 changed files with 9 additions and 4 deletions

View File

@@ -129,8 +129,13 @@ class Model(object):
if MPI is not None: if MPI is not None:
sync_from_root(sess, global_variables) # pylint: disable=E1101 sync_from_root(sess, global_variables) # pylint: disable=E1101
def step(self, *args, **kwargs): def step_as_dict(self, **kwargs):
return self.act_model.step(*args, **kwargs) return self.act_model.step(**kwargs)
def step(self, observations, **kwargs):
kwargs.update({'observations': observations})
transition = self.act_model.step(**kwargs)
return transition['actions'], transition['values'], transition['states'], transition['neglogpacs']
def train(self, def train(self,
lr, lr,

View File

@@ -54,7 +54,7 @@ class Runner(AbstractEnvRunner):
transition = {} transition = {}
transition['obs'] = self.obs.copy() transition['obs'] = self.obs.copy()
transition['dones'] = dones transition['dones'] = dones
transition.update(self.model.step(observations=self.obs, **prev_state)) transition.update(self.model.step_as_dict(observations=self.obs, **prev_state))
transition['values'] = transition['values'] transition['values'] = transition['values']
# Take actions in env and look the results # Take actions in env and look the results
@@ -84,7 +84,7 @@ class Runner(AbstractEnvRunner):
dtype = data_type[key] if key in data_type else np.float dtype = data_type[key] if key in data_type else np.float
minibatch[key] = np.asarray(minibatch[key], dtype=dtype) minibatch[key] = np.asarray(minibatch[key], dtype=dtype)
last_values = self.model.step(observations=self.obs, **self.states)['values'] last_values = self.model.step_as_dict(observations=self.obs, **self.states)['values']
# Calculate returns and advantages. # Calculate returns and advantages.
minibatch['advs'], minibatch['returns'] = \ minibatch['advs'], minibatch['returns'] = \