From c2df27bee4c8153cebebcc19c54be0648a4a2ae3 Mon Sep 17 00:00:00 2001 From: Peter Zhokhov Date: Thu, 2 Aug 2018 09:41:41 -0700 Subject: [PATCH] non-tf normalization benchmark RUN BENCHMARKS --- baselines/common/vec_env/vec_normalize.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/baselines/common/vec_env/vec_normalize.py b/baselines/common/vec_env/vec_normalize.py index f47e0e5..37ee02f 100644 --- a/baselines/common/vec_env/vec_normalize.py +++ b/baselines/common/vec_env/vec_normalize.py @@ -1,5 +1,5 @@ from baselines.common.vec_env import VecEnvWrapper -from baselines.common.running_mean_std import TfRunningMeanStd +from baselines.common.running_mean_std import TfRunningMeanStd, RunningMeanStd import numpy as np class VecNormalize(VecEnvWrapper): @@ -8,10 +8,10 @@ class VecNormalize(VecEnvWrapper): """ def __init__(self, venv, ob=True, ret=True, clipob=10., cliprew=10., gamma=0.99, epsilon=1e-8): VecEnvWrapper.__init__(self, venv) - #self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None - #self.ret_rms = RunningMeanStd(shape=()) if ret else None - self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None - self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None + self.ob_rms = RunningMeanStd(shape=self.observation_space.shape) if ob else None + self.ret_rms = RunningMeanStd(shape=()) if ret else None + #self.ob_rms = TfRunningMeanStd(shape=self.observation_space.shape, scope='observation_running_mean_std') if ob else None + #self.ret_rms = TfRunningMeanStd(shape=(), scope='return_running_mean_std') if ret else None self.clipob = clipob self.cliprew = cliprew self.ret = np.zeros(self.num_envs)