diff --git a/baselines/ppo2/model.py b/baselines/ppo2/model.py index 3ba3e82..bbaa4d8 100644 --- a/baselines/ppo2/model.py +++ b/baselines/ppo2/model.py @@ -136,7 +136,8 @@ class Model(object): def step(self, observations, **kwargs): kwargs.update({'observations': observations}) transition = self.act_model.step(**kwargs) - return transition['actions'], transition['values'], transition['states'], transition['neglogpacs'] + states = transition['states'] if 'states' in transition else None + return transition['actions'], transition['values'], states, transition['neglogpacs'] def train(self, lr,