diff --git a/baselines/ppo2/model.py b/baselines/ppo2/model.py
index 7437043..45629c8 100644
--- a/baselines/ppo2/model.py
+++ b/baselines/ppo2/model.py
@@ -129,8 +129,13 @@ class Model(object):
         if MPI is not None:
             sync_from_root(sess, global_variables) # pylint: disable=E1101
 
-    def step(self, *args, **kwargs):
-        return self.act_model.step(*args, **kwargs)
+    def step_as_dict(self, **kwargs):
+        return self.act_model.step(**kwargs)
+
+    def step(self, observations, **kwargs):
+        kwargs.update({'observations': observations})
+        transition = self.act_model.step(**kwargs)
+        return transition['actions'], transition['values'], transition['states'], transition['neglogpacs']
 
     def train(self, lr,
diff --git a/baselines/ppo2/runner.py b/baselines/ppo2/runner.py
index 9a3c770..c87a85c 100644
--- a/baselines/ppo2/runner.py
+++ b/baselines/ppo2/runner.py
@@ -54,7 +54,7 @@ class Runner(AbstractEnvRunner):
             transition = {}
             transition['obs'] = self.obs.copy()
             transition['dones'] = dones
-            transition.update(self.model.step(observations=self.obs, **prev_state))
+            transition.update(self.model.step_as_dict(observations=self.obs, **prev_state))
             transition['values'] = transition['values']
 
             # Take actions in env and look the results
@@ -84,7 +84,7 @@ class Runner(AbstractEnvRunner):
             dtype = data_type[key] if key in data_type else np.float
             minibatch[key] = np.asarray(minibatch[key], dtype=dtype)
 
-        last_values = self.model.step(observations=self.obs, **self.states)['values']
+        last_values = self.model.step_as_dict(observations=self.obs, **self.states)['values']
 
         # Calculate returns and advantages.
         minibatch['advs'], minibatch['returns'] = \
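
A minimal usage sketch of the two entry points after this change, assuming `model` is an instance of the Model class above and `obs` / `state_kwargs` are placeholders for the caller's observations and recurrent-state keyword arguments (they are not names from the diff):

    # Dict-style interface used by the runner: keys mirror the policy's output dict.
    transition = model.step_as_dict(observations=obs, **state_kwargs)
    actions = transition['actions']

    # Tuple-style interface kept for existing callers: step() forwards the
    # observations as a keyword argument and unpacks the dict into the old
    # (actions, values, states, neglogpacs) order.
    actions, values, states, neglogpacs = model.step(obs, **state_kwargs)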