* import rl-algs from 2e3a166 commit * extra import of the baselines badge * exported commit with identity test * proper rng seeding in the test_identity
84 lines
2.4 KiB
Python
84 lines
2.4 KiB
Python
from baselines.common.input import observation_input
|
|
|
|
import tensorflow as tf
|
|
|
|
# ================================================================
|
|
# Placeholders
|
|
# ================================================================
|
|
|
|
|
|
class TfInput(object):
    """Generalized Tensorflow placeholder.

    The main differences from a plain placeholder are:
        - possibly uses multiple placeholders internally and returns
          multiple values
        - can apply light postprocessing to the value fed to the placeholder.

    This is an abstract base: subclasses must override ``get`` and
    ``make_feed_dict``.
    """

    def __init__(self, name="(unnamed)"):
        """Store the human-readable name of this input.

        Parameters
        ----------
        name: str
            name used to identify this input.
        """
        self.name = name

    def get(self):
        """Return the tf variable(s) representing the possibly postprocessed
        value of placeholder(s).

        Raises
        ------
        NotImplementedError
            always; subclasses must provide the implementation.
        """
        # NotImplemented is a non-callable singleton meant for binary
        # operators; the exception class is NotImplementedError.
        raise NotImplementedError

    def make_feed_dict(self, data):
        """Given data input it to the placeholder(s).

        Parameters
        ----------
        data:
            value to feed into the underlying placeholder(s).

        Raises
        ------
        NotImplementedError
            always; subclasses must provide the implementation.
        """
        raise NotImplementedError
|
|
|
|
|
|
class PlaceholderTfInput(TfInput):
    """TfInput backed by exactly one regular tensorflow placeholder."""

    def __init__(self, placeholder):
        """Wrap ``placeholder``; the input takes its name from it.

        Parameters
        ----------
        placeholder:
            the tensorflow placeholder to wrap. Its ``name`` attribute is
            forwarded to the base class.
        """
        input_name = placeholder.name
        super().__init__(input_name)
        self._placeholder = placeholder

    def get(self):
        """Return the wrapped placeholder unchanged."""
        return self._placeholder

    def make_feed_dict(self, data):
        """Return a feed dict mapping the wrapped placeholder to ``data``."""
        feed = {}
        feed[self._placeholder] = data
        return feed
|
|
|
|
|
|
class Uint8Input(PlaceholderTfInput):
    """Placeholder fed with uint8 data and exposed to the model as float32."""

    def __init__(self, shape, name=None):
        """Takes input in uint8 format which is cast to float32 and divided
        by 255 before passing it to the model.

        On GPU this ensures lower data transfer times.

        Parameters
        ----------
        shape: [int]
            shape of the tensor, without the leading batch dimension
            (a ``None`` dimension is prepended here).
        name: str
            name of the underlying placeholder
        """
        # Leading None leaves the first dimension unconstrained.
        full_shape = [None, *shape]
        raw_placeholder = tf.placeholder(tf.uint8, full_shape, name=name)
        super().__init__(raw_placeholder)

        self._shape = shape

        # Cast to float32 and divide by 255 once, at graph-construction time.
        as_float = tf.cast(super().get(), tf.float32)
        self._output = as_float / 255.0

    def get(self):
        """Return the float32 tensor (uint8 input cast and divided by 255)."""
        return self._output
|
|
|
|
|
|
class ObservationInput(PlaceholderTfInput):
    """Input whose placeholder is built from an observation space."""

    def __init__(self, observation_space, name=None):
        """Creates an input placeholder tailored to a specific observation
        space.

        Parameters
        ----------
        observation_space:
            observation space of the environment. Should be one of the
            gym.spaces types
        name: str
            tensorflow name of the underlying placeholder
        """
        # observation_input yields a pair: the value that gets wrapped as the
        # placeholder, and the processed value handed out by get().
        built = observation_input(observation_space, name=name)
        raw_input, processed = built
        self.processed_inpt = processed
        super().__init__(raw_input)

    def get(self):
        """Return the processed input produced by ``observation_input``."""
        return self.processed_inpt
|
|
|
|
|