minor fixes from internal (#665)

* sync internal changes. Make ddpg work with vecenvs

* B -> nenvs for consistency with other algos, small cleanups

* eval_done[d]==True -> eval_done[d]

* flake8 and numpy.random.random_integers deprecation warning

* Merge branch 'master' of github.com:openai/games into peterz_track_baselines_branch
This commit is contained in:
pzhokhov
2018-10-22 09:15:04 -07:00
committed by GitHub
parent bd390c2ade
commit c0fa11a3a7
4 changed files with 16 additions and 11 deletions

View File

@@ -213,8 +213,11 @@ class LazyFrames(object):
def __getitem__(self, i):
return self._force()[i]
def make_atari(env_id):
def make_atari(env_id, timelimit=True):
# XXX(john): remove timelimit argument after gym is upgraded to allow double wrapping
env = gym.make(env_id)
if not timelimit:
env = env.env
assert 'NoFrameskip' in env.spec.id
env = NoopResetEnv(env, noop_max=30)
env = MaxAndSkipEnv(env, skip=4)

View File

@@ -28,6 +28,3 @@ class VecFrameStack(VecEnvWrapper):
self.stackedobs[...] = 0
self.stackedobs[..., -obs.shape[-1]:] = obs
return self.stackedobs
def close(self):
self.venv.close()

View File

@@ -106,7 +106,8 @@ class CSVOutputFormat(KVWriter):
def writekvs(self, kvs):
# Add our current row to the history
extra_keys = kvs.keys() - self.keys
extra_keys = list(kvs.keys() - self.keys)
extra_keys.sort()
if extra_keys:
self.keys.extend(extra_keys)
self.file.seek(0)

View File

@@ -48,9 +48,13 @@ setup(name='baselines',
# ensure there is some tensorflow build with version above 1.4
try:
from distutils.version import StrictVersion
import tensorflow
assert StrictVersion(re.sub(r'-rc\d+$', '', tensorflow.__version__)) >= StrictVersion('1.4.0')
except ImportError:
assert False, "TensorFlow needed, of version above 1.4"
import pkg_resources
tf_pkg = None
for tf_pkg_name in ['tensorflow', 'tensorflow-gpu']:
try:
tf_pkg = pkg_resources.get_distribution(tf_pkg_name)
except pkg_resources.DistributionNotFound:
pass
assert tf_pkg is not None, 'TensorFlow needed, of version above 1.4'
from distutils.version import StrictVersion
assert StrictVersion(re.sub(r'-rc\d+$', '', tf_pkg.version)) >= StrictVersion('1.4.0')