From cedecb35e3428985fd4efad738befeb75b9077f1 Mon Sep 17 00:00:00 2001 From: pzhokhov Date: Wed, 14 Aug 2019 10:28:27 -0700 Subject: [PATCH] fix hand serialization tests (#1647) * fix serialization by moving EzPickle inheritance to the leaf classes and passing correct arguments to EzPickle * fix touch sensor serialization too --- gym/envs/robotics/hand/manipulate.py | 22 +++++++++++-------- .../robotics/hand/manipulate_touch_sensors.py | 18 ++++++++------- 2 files changed, 23 insertions(+), 17 deletions(-) diff --git a/gym/envs/robotics/hand/manipulate.py b/gym/envs/robotics/hand/manipulate.py index de55f3482..41b7bcff9 100644 --- a/gym/envs/robotics/hand/manipulate.py +++ b/gym/envs/robotics/hand/manipulate.py @@ -25,10 +25,10 @@ MANIPULATE_EGG_XML = os.path.join('hand', 'manipulate_egg.xml') MANIPULATE_PEN_XML = os.path.join('hand', 'manipulate_pen.xml') -class ManipulateEnv(hand_env.HandEnv, utils.EzPickle): +class ManipulateEnv(hand_env.HandEnv): def __init__( self, model_path, target_position, target_rotation, - target_position_range, reward_type, initial_qpos={}, + target_position_range, reward_type, initial_qpos=None, randomize_initial_position=True, randomize_initial_rotation=True, distance_threshold=0.01, rotation_threshold=0.1, n_substeps=20, relative_control=False, ignore_z_target_rotation=False, @@ -71,11 +71,12 @@ class ManipulateEnv(hand_env.HandEnv, utils.EzPickle): assert self.target_position in ['ignore', 'fixed', 'random'] assert self.target_rotation in ['ignore', 'fixed', 'xyz', 'z', 'parallel'] + initial_qpos = initial_qpos or {} hand_env.HandEnv.__init__( self, model_path, n_substeps=n_substeps, initial_qpos=initial_qpos, relative_control=relative_control) - utils.EzPickle.__init__(self) + def _get_achieved_goal(self): # Object position and rotation. @@ -271,27 +272,30 @@ class ManipulateEnv(hand_env.HandEnv, utils.EzPickle): } -class HandBlockEnv(ManipulateEnv): +class HandBlockEnv(ManipulateEnv, utils.EzPickle): def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'): - super(HandBlockEnv, self).__init__( + utils.EzPickle.__init__(self, target_position, target_rotation, reward_type) + ManipulateEnv.__init__(self, model_path=MANIPULATE_BLOCK_XML, target_position=target_position, target_rotation=target_rotation, target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]), reward_type=reward_type) -class HandEggEnv(ManipulateEnv): +class HandEggEnv(ManipulateEnv, utils.EzPickle): def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'): - super(HandEggEnv, self).__init__( + utils.EzPickle.__init__(self, target_position, target_rotation, reward_type) + ManipulateEnv.__init__(self, model_path=MANIPULATE_EGG_XML, target_position=target_position, target_rotation=target_rotation, target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]), reward_type=reward_type) -class HandPenEnv(ManipulateEnv): +class HandPenEnv(ManipulateEnv, utils.EzPickle): def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'): - super(HandPenEnv, self).__init__( + utils.EzPickle.__init__(self, target_position, target_rotation, reward_type) + ManipulateEnv.__init__(self, model_path=MANIPULATE_PEN_XML, target_position=target_position, target_rotation=target_rotation, target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]), diff --git a/gym/envs/robotics/hand/manipulate_touch_sensors.py b/gym/envs/robotics/hand/manipulate_touch_sensors.py index 115d07740..c364868b8 100644 --- a/gym/envs/robotics/hand/manipulate_touch_sensors.py +++ b/gym/envs/robotics/hand/manipulate_touch_sensors.py @@ -10,7 +10,7 @@ MANIPULATE_EGG_XML = os.path.join('hand', 'manipulate_egg_touch_sensors.xml') MANIPULATE_PEN_XML = os.path.join('hand', 'manipulate_pen_touch_sensors.xml') -class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv, utils.EzPickle): +class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv): def __init__( self, model_path, target_position, target_rotation, target_position_range, reward_type, initial_qpos={}, @@ -46,7 +46,6 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv, utils.EzPickle): distance_threshold=distance_threshold, rotation_threshold=rotation_threshold, n_substeps=n_substeps, relative_control=relative_control, ignore_z_target_rotation=ignore_z_target_rotation, ) - utils.EzPickle.__init__(self) for k, v in self.sim.model._sensor_name2id.items(): # get touch sensor site names and their ids if 'robot0:TS_' in k: @@ -95,9 +94,10 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv, utils.EzPickle): } -class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv): +class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle): def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'): - super(HandBlockTouchSensorsEnv, self).__init__( + utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type) + ManipulateTouchSensorsEnv.__init__(self, model_path=MANIPULATE_BLOCK_XML, touch_get_obs=touch_get_obs, target_rotation=target_rotation, @@ -106,9 +106,10 @@ class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv): reward_type=reward_type) -class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv): +class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle): def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'): - super(HandEggTouchSensorsEnv, self).__init__( + utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type) + ManipulateTouchSensorsEnv.__init__(self, model_path=MANIPULATE_EGG_XML, touch_get_obs=touch_get_obs, target_rotation=target_rotation, @@ -117,9 +118,10 @@ class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv): reward_type=reward_type) -class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv): +class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle): def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'): - super(HandPenTouchSensorsEnv, self).__init__( + utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type) + ManipulateTouchSensorsEnv.__init__(self, model_path=MANIPULATE_PEN_XML, touch_get_obs=touch_get_obs, target_rotation=target_rotation,