This commit is contained in:
Jesse Farebrother
2019-06-21 19:50:54 -04:00
committed by pzhokhov
parent f4f2b6a133
commit fc6891c9c8
2 changed files with 23 additions and 1 deletions

View File

@@ -26,6 +26,8 @@ class AtariEnv(gym.Env, utils.EzPickle):
def __init__( def __init__(
self, self,
game='pong', game='pong',
mode=0,
difficulty=0,
obs_type='ram', obs_type='ram',
frameskip=(2, 5), frameskip=(2, 5),
repeat_action_probability=0., repeat_action_probability=0.,
@@ -36,12 +38,18 @@ class AtariEnv(gym.Env, utils.EzPickle):
utils.EzPickle.__init__( utils.EzPickle.__init__(
self, self,
game, game,
mode,
difficulty,
obs_type, obs_type,
frameskip, frameskip,
repeat_action_probability) repeat_action_probability)
assert obs_type in ('ram', 'image') assert obs_type in ('ram', 'image')
self.game = game
self.game_path = atari_py.get_game_path(game) self.game_path = atari_py.get_game_path(game)
self.game_mode = mode
self.game_difficulty = difficulty
if not os.path.exists(self.game_path): if not os.path.exists(self.game_path):
msg = 'You asked for game %s but path %s does not exist' msg = 'You asked for game %s but path %s does not exist'
raise IOError(msg % (game, self.game_path)) raise IOError(msg % (game, self.game_path))
@@ -81,6 +89,20 @@ class AtariEnv(gym.Env, utils.EzPickle):
# Empirically, we need to seed before loading the ROM. # Empirically, we need to seed before loading the ROM.
self.ale.setInt(b'random_seed', seed2) self.ale.setInt(b'random_seed', seed2)
self.ale.loadROM(self.game_path) self.ale.loadROM(self.game_path)
modes = self.ale.getAvailableModes()
difficulties = self.ale.getAvailableDifficulties()
assert self.game_mode in modes, (
"Invalid game mode \"{}\" for game {}.\nAvailable modes are: {}"
).format(self.game_mode, self.game, modes)
assert self.game_difficulty in difficulties, (
"Invalid game difficulty \"{}\" for game {}.\nAvailable difficulties are: {}"
).format(self.game_difficulty, self.game, difficulties)
self.ale.setMode(self.game_mode)
self.ale.setDifficulty(self.game_difficulty)
return [seed1, seed2] return [seed1, seed2]
def step(self, a): def step(self, a):

View File

@@ -7,7 +7,7 @@ from version import VERSION
# Environment-specific dependencies. # Environment-specific dependencies.
extras = { extras = {
'atari': ['atari_py~=0.1.4', 'Pillow', 'opencv-python'], 'atari': ['atari_py~=0.2.0', 'Pillow', 'opencv-python'],
'box2d': ['box2d-py~=2.3.5'], 'box2d': ['box2d-py~=2.3.5'],
'classic_control': [], 'classic_control': [],
'mujoco': ['mujoco_py>=1.50, <2.1', 'imageio'], 'mujoco': ['mujoco_py>=1.50, <2.1', 'imageio'],