fix DQN learning bug (#632)

* Update run.py

* Update utils.py

* Update utils.py
This commit is contained in:
pzhokhov
2018-10-03 14:37:40 -07:00
committed by GitHub
parent 34ae3194b4
commit 4121d9c1a8
2 changed files with 1 addition and 26 deletions

View File

@@ -1,8 +1,6 @@
from baselines.common.input import observation_input
from baselines.common.tf_util import adjust_shape
import tensorflow as tf
# ================================================================
# Placeholders
# ================================================================
@@ -40,29 +38,6 @@ class PlaceholderTfInput(TfInput):
return {self._placeholder: adjust_shape(self._placeholder, data)}
class Uint8Input(PlaceholderTfInput):
    def __init__(self, shape, name=None):
        """Wrap a uint8 placeholder whose values are cast to float32 and
        rescaled by 1/255 before being handed to the model.

        Transferring uint8 data (instead of float32) keeps host-to-GPU
        copy times low.

        Parameters
        ----------
        shape: [int]
            shape of the tensor.
        name: str
            name of the underlying placeholder
        """
        # Build the raw placeholder first, then register it with the base class.
        raw = tf.placeholder(tf.uint8, [None] + list(shape), name=name)
        super().__init__(raw)
        self._shape = shape
        # Precompute the normalized float tensor once at construction time.
        self._output = tf.cast(super().get(), tf.float32) / 255.0

    def get(self):
        # Expose the normalized float32 tensor rather than the raw uint8 placeholder.
        return self._output
class ObservationInput(PlaceholderTfInput):
def __init__(self, observation_space, name=None):
"""Creates an input placeholder tailored to a specific observation space

View File

@@ -99,7 +99,7 @@ def build_env(args):
env = atari_wrappers.make_atari(env_id)
env.seed(seed)
env = bench.Monitor(env, logger.get_dir())
env = atari_wrappers.wrap_deepmind(env, frame_stack=True, scale=True)
env = atari_wrappers.wrap_deepmind(env, frame_stack=True)
elif alg == 'trpo_mpi':
env = atari_wrappers.make_atari(env_id)
env.seed(seed)