Reverted VecNormalize to use RunningMeanStd (the NumPy implementation, no TensorFlow)

This commit is contained in:
Peter Zhokhov
2018-08-02 10:55:09 -07:00
parent f6d1115ead
commit 1c5c6563b7

View File

@@ -1,5 +1,5 @@
from baselines.common.vec_env import VecEnvWrapper
from baselines.common.running_mean_std import TfRunningMeanStd
from baselines.common.running_mean_std import RunningMeanStd
import numpy as np
class VecNormalize(VecEnvWrapper):
@@ -8,10 +8,10 @@ class VecNormalize(VecEnvWrapper):
"""
def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8):
VecEnvWrapper.__init__(self, venv)
#self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
#self.ret_rms = RunningMeanStd(shape=()) if ret else None
self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None
self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None
self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None
self.ret_rms = RunningMeanStd(shape=()) if ret else None
#self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None
#self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None
self.clipob = clipob
self.cliprew = cliprew
self.ret = np.zeros(self.num_envs)