Improvements for the 'play' option.

This commit is contained in:
gyunt
2019-03-21 20:45:01 +09:00
parent 1ecc242aec
commit 299d7d2f86
2 changed files with 9 additions and 4 deletions

View File

@@ -129,8 +129,13 @@ class Model(object):
if MPI is not None:
sync_from_root(sess, global_variables) # pylint: disable=E1101
def step(self, *args, **kwargs):
return self.act_model.step(*args, **kwargs)
def step_as_dict(self, **kwargs):
    """Run one step of the acting model and return its result unchanged.

    All keyword arguments are forwarded verbatim to
    ``self.act_model.step``. Callers in this change treat the result as a
    mapping (they merge it with ``dict.update`` and index it by
    ``'values'``).
    """
    act_model = self.act_model
    return act_model.step(**kwargs)
def step(self, observations, **kwargs):
    """Run one step of the acting model and return the result as a tuple.

    Args:
        observations: Forwarded to ``self.act_model.step`` as the
            ``observations`` keyword argument.
        **kwargs: Any further keyword arguments, forwarded unchanged.

    Returns:
        A 4-tuple ``(actions, values, states, neglogpacs)`` taken from the
        mapping returned by ``self.act_model.step``.

    Raises:
        KeyError: If the returned mapping lacks any of those four keys.
    """
    # ``observations`` is a named parameter, so it can never also be in
    # ``kwargs``; passing it directly avoids mutating ``kwargs``.
    transition = self.act_model.step(observations=observations, **kwargs)
    return (
        transition['actions'],
        transition['values'],
        transition['states'],
        transition['neglogpacs'],
    )
def train(self,
lr,

View File

@@ -54,7 +54,7 @@ class Runner(AbstractEnvRunner):
transition = {}
transition['obs'] = self.obs.copy()
transition['dones'] = dones
transition.update(self.model.step(observations=self.obs, **prev_state))
transition.update(self.model.step_as_dict(observations=self.obs, **prev_state))
transition['values'] = transition['values']
# Take actions in env and look the results
@@ -84,7 +84,7 @@ class Runner(AbstractEnvRunner):
dtype = data_type[key] if key in data_type else np.float
minibatch[key] = np.asarray(minibatch[key], dtype=dtype)
last_values = self.model.step(observations=self.obs, **self.states)['values']
last_values = self.model.step_as_dict(observations=self.obs, **self.states)['values']
# Calculate returns and advantages.
minibatch['advs'], minibatch['returns'] = \