fix hand serialization tests (#1647)

* fix serialization by moving EzPickle inheritance to the leaf classes and passing correct arguments to EzPickle

* fix touch sensor serialization too
This commit is contained in:
pzhokhov
2019-08-14 10:28:27 -07:00
committed by GitHub
parent f99ce5f324
commit cedecb35e3
2 changed files with 23 additions and 17 deletions

View File

@@ -25,10 +25,10 @@ MANIPULATE_EGG_XML = os.path.join('hand', 'manipulate_egg.xml')
MANIPULATE_PEN_XML = os.path.join('hand', 'manipulate_pen.xml') MANIPULATE_PEN_XML = os.path.join('hand', 'manipulate_pen.xml')
class ManipulateEnv(hand_env.HandEnv, utils.EzPickle): class ManipulateEnv(hand_env.HandEnv):
def __init__( def __init__(
self, model_path, target_position, target_rotation, self, model_path, target_position, target_rotation,
target_position_range, reward_type, initial_qpos={}, target_position_range, reward_type, initial_qpos=None,
randomize_initial_position=True, randomize_initial_rotation=True, randomize_initial_position=True, randomize_initial_rotation=True,
distance_threshold=0.01, rotation_threshold=0.1, n_substeps=20, relative_control=False, distance_threshold=0.01, rotation_threshold=0.1, n_substeps=20, relative_control=False,
ignore_z_target_rotation=False, ignore_z_target_rotation=False,
@@ -71,11 +71,12 @@ class ManipulateEnv(hand_env.HandEnv, utils.EzPickle):
assert self.target_position in ['ignore', 'fixed', 'random'] assert self.target_position in ['ignore', 'fixed', 'random']
assert self.target_rotation in ['ignore', 'fixed', 'xyz', 'z', 'parallel'] assert self.target_rotation in ['ignore', 'fixed', 'xyz', 'z', 'parallel']
initial_qpos = initial_qpos or {}
hand_env.HandEnv.__init__( hand_env.HandEnv.__init__(
self, model_path, n_substeps=n_substeps, initial_qpos=initial_qpos, self, model_path, n_substeps=n_substeps, initial_qpos=initial_qpos,
relative_control=relative_control) relative_control=relative_control)
utils.EzPickle.__init__(self)
def _get_achieved_goal(self): def _get_achieved_goal(self):
# Object position and rotation. # Object position and rotation.
@@ -271,27 +272,30 @@ class ManipulateEnv(hand_env.HandEnv, utils.EzPickle):
} }
class HandBlockEnv(ManipulateEnv): class HandBlockEnv(ManipulateEnv, utils.EzPickle):
def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'): def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'):
super(HandBlockEnv, self).__init__( utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
ManipulateEnv.__init__(self,
model_path=MANIPULATE_BLOCK_XML, target_position=target_position, model_path=MANIPULATE_BLOCK_XML, target_position=target_position,
target_rotation=target_rotation, target_rotation=target_rotation,
target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]), target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
reward_type=reward_type) reward_type=reward_type)
class HandEggEnv(ManipulateEnv): class HandEggEnv(ManipulateEnv, utils.EzPickle):
def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'): def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'):
super(HandEggEnv, self).__init__( utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
ManipulateEnv.__init__(self,
model_path=MANIPULATE_EGG_XML, target_position=target_position, model_path=MANIPULATE_EGG_XML, target_position=target_position,
target_rotation=target_rotation, target_rotation=target_rotation,
target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]), target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
reward_type=reward_type) reward_type=reward_type)
class HandPenEnv(ManipulateEnv): class HandPenEnv(ManipulateEnv, utils.EzPickle):
def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'): def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'):
super(HandPenEnv, self).__init__( utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
ManipulateEnv.__init__(self,
model_path=MANIPULATE_PEN_XML, target_position=target_position, model_path=MANIPULATE_PEN_XML, target_position=target_position,
target_rotation=target_rotation, target_rotation=target_rotation,
target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]), target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),

View File

@@ -10,7 +10,7 @@ MANIPULATE_EGG_XML = os.path.join('hand', 'manipulate_egg_touch_sensors.xml')
MANIPULATE_PEN_XML = os.path.join('hand', 'manipulate_pen_touch_sensors.xml') MANIPULATE_PEN_XML = os.path.join('hand', 'manipulate_pen_touch_sensors.xml')
class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv, utils.EzPickle): class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
def __init__( def __init__(
self, model_path, target_position, target_rotation, self, model_path, target_position, target_rotation,
target_position_range, reward_type, initial_qpos={}, target_position_range, reward_type, initial_qpos={},
@@ -46,7 +46,6 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv, utils.EzPickle):
distance_threshold=distance_threshold, rotation_threshold=rotation_threshold, n_substeps=n_substeps, relative_control=relative_control, distance_threshold=distance_threshold, rotation_threshold=rotation_threshold, n_substeps=n_substeps, relative_control=relative_control,
ignore_z_target_rotation=ignore_z_target_rotation, ignore_z_target_rotation=ignore_z_target_rotation,
) )
utils.EzPickle.__init__(self)
for k, v in self.sim.model._sensor_name2id.items(): # get touch sensor site names and their ids for k, v in self.sim.model._sensor_name2id.items(): # get touch sensor site names and their ids
if 'robot0:TS_' in k: if 'robot0:TS_' in k:
@@ -95,9 +94,10 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv, utils.EzPickle):
} }
class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv): class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'): def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'):
super(HandBlockTouchSensorsEnv, self).__init__( utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
ManipulateTouchSensorsEnv.__init__(self,
model_path=MANIPULATE_BLOCK_XML, model_path=MANIPULATE_BLOCK_XML,
touch_get_obs=touch_get_obs, touch_get_obs=touch_get_obs,
target_rotation=target_rotation, target_rotation=target_rotation,
@@ -106,9 +106,10 @@ class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv):
reward_type=reward_type) reward_type=reward_type)
class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv): class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'): def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'):
super(HandEggTouchSensorsEnv, self).__init__( utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
ManipulateTouchSensorsEnv.__init__(self,
model_path=MANIPULATE_EGG_XML, model_path=MANIPULATE_EGG_XML,
touch_get_obs=touch_get_obs, touch_get_obs=touch_get_obs,
target_rotation=target_rotation, target_rotation=target_rotation,
@@ -117,9 +118,10 @@ class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv):
reward_type=reward_type) reward_type=reward_type)
class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv): class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'): def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'):
super(HandPenTouchSensorsEnv, self).__init__( utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
ManipulateTouchSensorsEnv.__init__(self,
model_path=MANIPULATE_PEN_XML, model_path=MANIPULATE_PEN_XML,
touch_get_obs=touch_get_obs, touch_get_obs=touch_get_obs,
target_rotation=target_rotation, target_rotation=target_rotation,