git subrepo pull (merge) baselines
subrepo: subdir: "baselines" merged: "8785db28" upstream: origin: "git@github.com:openai/baselines.git" branch: "master" commit: "35e95ee8" git-subrepo: version: "0.4.0" origin: "git@github.com:ingydotnet/git-subrepo.git" commit: "74339e8"
This commit is contained in:
@@ -26,6 +26,7 @@ class VecNormalize(VecEnvWrapper):
|
|||||||
if self.ret_rms:
|
if self.ret_rms:
|
||||||
self.ret_rms.update(self.ret)
|
self.ret_rms.update(self.ret)
|
||||||
rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
|
rews = np.clip(rews / np.sqrt(self.ret_rms.var + self.epsilon), -self.cliprew, self.cliprew)
|
||||||
|
self.ret[news] = 0.
|
||||||
return obs, rews, news, infos
|
return obs, rews, news, infos
|
||||||
|
|
||||||
def _obfilt(self, obs):
|
def _obfilt(self, obs):
|
||||||
@@ -37,5 +38,6 @@ class VecNormalize(VecEnvWrapper):
|
|||||||
return obs
|
return obs
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
|
self.ret = np.zeros(self.num_envs)
|
||||||
obs = self.venv.reset()
|
obs = self.venv.reset()
|
||||||
return self._obfilt(obs)
|
return self._obfilt(obs)
|
||||||
|
@@ -53,7 +53,7 @@ _game_envs['retro'] = {
|
|||||||
|
|
||||||
def train(args, extra_args):
|
def train(args, extra_args):
|
||||||
env_type, env_id = get_env_type(args.env)
|
env_type, env_id = get_env_type(args.env)
|
||||||
print(f'env_type: {env_type}')
|
print('env_type: {}'.format(env_type))
|
||||||
|
|
||||||
total_timesteps = int(args.num_timesteps)
|
total_timesteps = int(args.num_timesteps)
|
||||||
seed = args.seed
|
seed = args.seed
|
||||||
|
Reference in New Issue
Block a user