Improves the 'play' option.
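This splits Model.step into two entry points: step_as_dict(**kwargs) returns the raw transition dictionary produced by act_model.step, while step(observations, **kwargs) unpacks that dictionary into the (actions, values, states, neglogpacs) tuple that positional callers (presumably the 'play' path, per the commit title) expect. The Runner call sites switch to step_as_dict so the full dictionary can be merged into the per-step transition record. A minimal usage sketch of the two call styles, assuming a constructed model and a batch of observations obs (illustrative names, not part of the commit):

    # Tuple-style access via the new Model.step:
    actions, values, states, neglogpacs = model.step(obs)

    # Dict-style access via Model.step_as_dict, as the runner now does:
    transition = model.step_as_dict(observations=obs)
    actions = transition['actions']
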
@@ -129,8 +129,13 @@ class Model(object):
         if MPI is not None:
             sync_from_root(sess, global_variables) # pylint: disable=E1101

-    def step(self, *args, **kwargs):
-        return self.act_model.step(*args, **kwargs)
+    def step_as_dict(self, **kwargs):
+        return self.act_model.step(**kwargs)
+
+    def step(self, observations, **kwargs):
+        kwargs.update({'observations': observations})
+        transition = self.act_model.step(**kwargs)
+        return transition['actions'], transition['values'], transition['states'], transition['neglogpacs']

     def train(self,
               lr,
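Note that step_as_dict simply forwards to act_model.step, so this refactor assumes the policy's step already returns a mapping with at least the keys the new Model.step indexes. A stub of that assumed contract, for illustration only (not part of the commit):

    import numpy as np

    def act_model_step_stub(**kwargs):
        # Assumed return shape: a dict keyed by the fields Model.step unpacks.
        n = len(kwargs['observations'])
        return {
            'actions': np.zeros(n, dtype=np.int64),
            'values': np.zeros(n, dtype=np.float32),
            'states': None,  # recurrent state, if the policy has one
            'neglogpacs': np.zeros(n, dtype=np.float32),
        }
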
@@ -54,7 +54,7 @@ class Runner(AbstractEnvRunner):
         transition = {}
         transition['obs'] = self.obs.copy()
         transition['dones'] = dones
-        transition.update(self.model.step(observations=self.obs, **prev_state))
+        transition.update(self.model.step_as_dict(observations=self.obs, **prev_state))
         transition['values'] = transition['values']

         # Take actions in env and look the results
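With step_as_dict, the policy outputs are merged into the transition record in one update call rather than unpacked positionally, keeping the rollout code keyed by field name. The same pattern, sketched with illustrative names:

    transition = {'obs': obs.copy(), 'dones': dones}
    transition.update(model.step_as_dict(observations=obs))  # adds actions/values/states/neglogpacs
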
@@ -84,7 +84,7 @@ class Runner(AbstractEnvRunner):
             dtype = data_type[key] if key in data_type else np.float
             minibatch[key] = np.asarray(minibatch[key], dtype=dtype)

-        last_values = self.model.step(observations=self.obs, **self.states)['values']
+        last_values = self.model.step_as_dict(observations=self.obs, **self.states)['values']

         # Calculate returns and advantages.
         minibatch['advs'], minibatch['returns'] = \