fix DQN learning bug (#632)
* Update run.py
* Update utils.py
* Update utils.py
This commit is contained in:
@@ -1,8 +1,6 @@
|
||||
from baselines.common.input import observation_input
|
||||
from baselines.common.tf_util import adjust_shape
|
||||
|
||||
import tensorflow as tf
|
||||
|
||||
# ================================================================
|
||||
# Placeholders
|
||||
# ================================================================
|
||||
@@ -40,29 +38,6 @@ class PlaceholderTfInput(TfInput):
|
||||
return {self._placeholder: adjust_shape(self._placeholder, data)}
|
||||
|
||||
|
||||
class Uint8Input(PlaceholderTfInput):
|
||||
def __init__(self, shape, name=None):
|
||||
"""Takes input in uint8 format which is cast to float32 and divided by 255
|
||||
before passing it to the model.
|
||||
|
||||
On GPU this ensures lower data transfer times.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
shape: [int]
|
||||
shape of the tensor.
|
||||
name: str
|
||||
name of the underlying placeholder
|
||||
"""
|
||||
|
||||
super().__init__(tf.placeholder(tf.uint8, [None] + list(shape), name=name))
|
||||
self._shape = shape
|
||||
self._output = tf.cast(super().get(), tf.float32) / 255.0
|
||||
|
||||
def get(self):
|
||||
return self._output
|
||||
|
||||
|
||||
class ObservationInput(PlaceholderTfInput):
|
||||
def __init__(self, observation_space, name=None):
|
||||
"""Creates an input placeholder tailored to a specific observation space
|
||||
|
@@ -99,7 +99,7 @@ def build_env(args):
|
||||
env = atari_wrappers.make_atari(env_id)
|
||||
env.seed(seed)
|
||||
env = bench.Monitor(env, logger.get_dir())
|
||||
env = atari_wrappers.wrap_deepmind(env, frame_stack=True, scale=True)
|
||||
env = atari_wrappers.wrap_deepmind(env, frame_stack=True)
|
||||
elif alg == 'trpo_mpi':
|
||||
env = atari_wrappers.make_atari(env_id)
|
||||
env.seed(seed)
|
||||
|
Reference in New Issue
Block a user