Get rid of class variables in algorithmic env. (#1910)

Make reward_shortfalls and min_length instance variables.
Fixes the TODO in algorithmic_env.py.

Co-authored-by: pzhokhov <peterz@openai.com>
InstanceLabs authored on 2020-05-29 23:23:44 +02:00, committed by GitHub
parent 821841c1a1
commit d8908cbf10
2 changed files with 13 additions and 24 deletions
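
For context, the TODO removed in this diff warned that class-level state could cause "spooky action at a distance" when several algorithmic envs are used at once. A toy sketch of that coupling (SharedStateEnv / OwnStateEnv below are illustrative only, not gym code):

# Toy illustration of class-level vs. instance-level state; the class names
# are hypothetical and only mirror the pattern this commit fixes.
class SharedStateEnv:
    min_length = 2                      # class variable, shared by every instance

    def level_up(self):
        SharedStateEnv.min_length += 1  # mutates the value all instances see


class OwnStateEnv:
    def __init__(self):
        self.min_length = 2             # instance variable, private to this env

    def level_up(self):
        self.min_length += 1


a, b = SharedStateEnv(), SharedStateEnv()
a.level_up()
assert b.min_length == 3                # b was "promoted" too: action at a distance

c, d = OwnStateEnv(), OwnStateEnv()
c.level_up()
assert (c.min_length, d.min_length) == (3, 2)   # d is unaffected, as after this change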

algorithmic_env.py

@@ -62,16 +62,13 @@ class AlgorithmicEnv(Env):
         self.episode_total_reward = None
         # Running tally of reward shortfalls. e.g. if there were 10 points to
         # earn and we got 8, we'd append -2
-        AlgorithmicEnv.reward_shortfalls = []
+        self.reward_shortfalls = []
         if chars:
             self.charmap = [chr(ord('A')+i) for i in range(base)]
         else:
             self.charmap = [str(i) for i in range(base)]
         self.charmap.append(' ')
-        # TODO: Not clear why this is a class variable rather than instance.
-        # Could lead to some spooky action at a distance if someone is working
-        # with multiple algorithmic envs at once. Also makes testing tricky.
-        AlgorithmicEnv.min_length = starting_min_length
+        self.min_length = starting_min_length
         # Three sub-actions:
         # 1. Move read head left or right (or up/down)
         # 2. Write or not
@@ -211,15 +208,13 @@ class AlgorithmicEnv(Env):
         if self.episode_total_reward is None:
             # This is before the first episode/call to reset(). Nothing to do.
             return
-        AlgorithmicEnv.reward_shortfalls.append(
-            self.episode_total_reward - len(self.target)
-        )
-        AlgorithmicEnv.reward_shortfalls = AlgorithmicEnv.reward_shortfalls[-self.last:]
-        if len(AlgorithmicEnv.reward_shortfalls) == self.last and \
-                min(AlgorithmicEnv.reward_shortfalls) >= self.MIN_REWARD_SHORTFALL_FOR_PROMOTION and \
-                AlgorithmicEnv.min_length < 30:
-            AlgorithmicEnv.min_length += 1
-            AlgorithmicEnv.reward_shortfalls = []
+        self.reward_shortfalls.append(self.episode_total_reward - len(self.target))
+        self.reward_shortfalls = self.reward_shortfalls[-self.last:]
+        if len(self.reward_shortfalls) == self.last and \
+                min(self.reward_shortfalls) >= self.MIN_REWARD_SHORTFALL_FOR_PROMOTION and \
+                self.min_length < 30:
+            self.min_length += 1
+            self.reward_shortfalls = []
 
     def reset(self):
         self._check_levelup()
@@ -229,7 +224,7 @@ class AlgorithmicEnv(Env):
         self.write_head_position = 0
         self.episode_total_reward = 0.0
         self.time = 0
-        length = self.np_random.randint(3) + AlgorithmicEnv.min_length
+        length = self.np_random.randint(3) + self.min_length
         self.input_data = self.generate_input_data(length)
         self.target = self.target_from_input_data(self.input_data)
         return self._get_obs()

test_algorithmic.py

@@ -58,7 +58,7 @@ class TestAlgorithmicEnvInteractions(unittest.TestCase):
     def test_levelup(self):
         obs = self.env.reset()
         # Kind of a hack
-        alg.algorithmic_env.AlgorithmicEnv.reward_shortfalls = []
+        self.env.reward_shortfalls = []
         min_length = self.env.min_length
         for i in range(self.env.last):
             obs, reward, done, _ = self.env.step([self.RIGHT, 1, 0])
@@ -67,17 +67,11 @@ class TestAlgorithmicEnvInteractions(unittest.TestCase):
             self.assertTrue(done)
             self.env.reset()
             if i < self.env.last-1:
-                self.assertEqual(
-                    len(alg.algorithmic_env.AlgorithmicEnv.reward_shortfalls),
-                    i+1
-                )
+                self.assertEqual(len(self.env.reward_shortfalls), i+1)
             else:
                 # Should have leveled up on the last iteration
                 self.assertEqual(self.env.min_length, min_length+1)
-                self.assertEqual(
-                    len(alg.algorithmic_env.AlgorithmicEnv.reward_shortfalls),
-                    0
-                )
+                self.assertEqual(len(self.env.reward_shortfalls), 0)
 
     def test_walk_off_the_end(self):
         obs = self.env.reset()