diff --git a/baselines/ppo2/ppo2.py b/baselines/ppo2/ppo2.py index 122e09c..f1bbc79 100644 --- a/baselines/ppo2/ppo2.py +++ b/baselines/ppo2/ppo2.py @@ -72,6 +72,7 @@ class Model(object): for p, loaded_p in zip(params, loaded_params): restores.append(p.assign(loaded_p)) sess.run(restores) + # If you want to load weights, also save/load observation scaling inside VecNormalize self.train = train self.train_model = train_model