2016-04-27 08:00:58 -07:00
|
|
|
"""
Task is to return every nth character from the input tape, where n is the
duplication factor: the input is built from runs of n identical symbols, and
the target keeps one symbol from each run.

http://arxiv.org/abs/1511.07275
"""
|
2016-10-21 16:06:48 -07:00
|
|
|
from __future__ import division
|
2016-04-27 08:00:58 -07:00
|
|
|
import numpy as np
|
|
|
|
from gym.envs.algorithmic import algorithmic_env
|
|
|
|
|
2016-10-21 16:06:48 -07:00
|
|
|
class DuplicatedInputEnv(algorithmic_env.TapeAlgorithmicEnv):
    """Tape task: de-duplicate an input made of runs of repeated symbols.

    The input tape consists of consecutive runs of ``duplication`` identical
    symbols; the target keeps exactly one symbol from each run.
    """

    def __init__(self, duplication=2, base=5):
        """Create the environment.

        Args:
            duplication: length of each run of identical symbols on the input
                tape; must be a positive integer.
            base: forwarded to the base class; symbols are later drawn with
                ``self.np_random.randint(self.base)``.

        Raises:
            ValueError: if ``duplication`` is less than 1 (0 would otherwise
                raise ZeroDivisionError while generating input, and negative
                values would silently produce empty tapes).
        """
        if duplication < 1:
            raise ValueError(
                "duplication must be a positive integer, got {!r}".format(
                    duplication))
        # Set before calling the base constructor. NOTE(review): presumably
        # the base __init__ may already generate an episode (and so call
        # generate_input_data) -- confirm in algorithmic_env.
        self.duplication = duplication
        # NOTE(review): chars=True presumably makes the base class render
        # symbols as characters -- confirm in algorithmic_env.
        super(DuplicatedInputEnv, self).__init__(base=base, chars=True)

    def generate_input_data(self, size):
        """Return a random input tape of runs of ``duplication`` equal symbols.

        ``size`` is raised to at least ``duplication`` and then rounded down
        to a multiple of ``duplication``, so the tape always holds whole runs.
        One random symbol is drawn per run.
        """
        n_runs = max(size, self.duplication) // self.duplication
        res = []
        for _ in range(n_runs):
            symbol = self.np_random.randint(self.base)
            res.extend([symbol] * self.duplication)
        return res

    def target_from_input_data(self, input_data):
        """Return the first symbol of every run, i.e. every nth input symbol."""
        return list(input_data[::self.duplication])
|