mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 09:55:39 +00:00
fix hand serialization tests (#1647)
* fix serialization by moving EzPickle inheritance to the leaf classes and passing correct arguments to EzPickle * fix touch sensor serialization too
This commit is contained in:
@@ -25,10 +25,10 @@ MANIPULATE_EGG_XML = os.path.join('hand', 'manipulate_egg.xml')
|
||||
MANIPULATE_PEN_XML = os.path.join('hand', 'manipulate_pen.xml')
|
||||
|
||||
|
||||
class ManipulateEnv(hand_env.HandEnv, utils.EzPickle):
|
||||
class ManipulateEnv(hand_env.HandEnv):
|
||||
def __init__(
|
||||
self, model_path, target_position, target_rotation,
|
||||
target_position_range, reward_type, initial_qpos={},
|
||||
target_position_range, reward_type, initial_qpos=None,
|
||||
randomize_initial_position=True, randomize_initial_rotation=True,
|
||||
distance_threshold=0.01, rotation_threshold=0.1, n_substeps=20, relative_control=False,
|
||||
ignore_z_target_rotation=False,
|
||||
@@ -71,11 +71,12 @@ class ManipulateEnv(hand_env.HandEnv, utils.EzPickle):
|
||||
|
||||
assert self.target_position in ['ignore', 'fixed', 'random']
|
||||
assert self.target_rotation in ['ignore', 'fixed', 'xyz', 'z', 'parallel']
|
||||
initial_qpos = initial_qpos or {}
|
||||
|
||||
hand_env.HandEnv.__init__(
|
||||
self, model_path, n_substeps=n_substeps, initial_qpos=initial_qpos,
|
||||
relative_control=relative_control)
|
||||
utils.EzPickle.__init__(self)
|
||||
|
||||
|
||||
def _get_achieved_goal(self):
|
||||
# Object position and rotation.
|
||||
@@ -271,27 +272,30 @@ class ManipulateEnv(hand_env.HandEnv, utils.EzPickle):
|
||||
}
|
||||
|
||||
|
||||
class HandBlockEnv(ManipulateEnv):
|
||||
class HandBlockEnv(ManipulateEnv, utils.EzPickle):
|
||||
def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'):
|
||||
super(HandBlockEnv, self).__init__(
|
||||
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
||||
ManipulateEnv.__init__(self,
|
||||
model_path=MANIPULATE_BLOCK_XML, target_position=target_position,
|
||||
target_rotation=target_rotation,
|
||||
target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
|
||||
reward_type=reward_type)
|
||||
|
||||
|
||||
class HandEggEnv(ManipulateEnv):
|
||||
class HandEggEnv(ManipulateEnv, utils.EzPickle):
|
||||
def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'):
|
||||
super(HandEggEnv, self).__init__(
|
||||
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
||||
ManipulateEnv.__init__(self,
|
||||
model_path=MANIPULATE_EGG_XML, target_position=target_position,
|
||||
target_rotation=target_rotation,
|
||||
target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
|
||||
reward_type=reward_type)
|
||||
|
||||
|
||||
class HandPenEnv(ManipulateEnv):
|
||||
class HandPenEnv(ManipulateEnv, utils.EzPickle):
|
||||
def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'):
|
||||
super(HandPenEnv, self).__init__(
|
||||
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
||||
ManipulateEnv.__init__(self,
|
||||
model_path=MANIPULATE_PEN_XML, target_position=target_position,
|
||||
target_rotation=target_rotation,
|
||||
target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
|
||||
|
@@ -10,7 +10,7 @@ MANIPULATE_EGG_XML = os.path.join('hand', 'manipulate_egg_touch_sensors.xml')
|
||||
MANIPULATE_PEN_XML = os.path.join('hand', 'manipulate_pen_touch_sensors.xml')
|
||||
|
||||
|
||||
class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv, utils.EzPickle):
|
||||
class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
|
||||
def __init__(
|
||||
self, model_path, target_position, target_rotation,
|
||||
target_position_range, reward_type, initial_qpos={},
|
||||
@@ -46,7 +46,6 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv, utils.EzPickle):
|
||||
distance_threshold=distance_threshold, rotation_threshold=rotation_threshold, n_substeps=n_substeps, relative_control=relative_control,
|
||||
ignore_z_target_rotation=ignore_z_target_rotation,
|
||||
)
|
||||
utils.EzPickle.__init__(self)
|
||||
|
||||
for k, v in self.sim.model._sensor_name2id.items(): # get touch sensor site names and their ids
|
||||
if 'robot0:TS_' in k:
|
||||
@@ -95,9 +94,10 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv, utils.EzPickle):
|
||||
}
|
||||
|
||||
|
||||
class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv):
|
||||
class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
|
||||
def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'):
|
||||
super(HandBlockTouchSensorsEnv, self).__init__(
|
||||
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
|
||||
ManipulateTouchSensorsEnv.__init__(self,
|
||||
model_path=MANIPULATE_BLOCK_XML,
|
||||
touch_get_obs=touch_get_obs,
|
||||
target_rotation=target_rotation,
|
||||
@@ -106,9 +106,10 @@ class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv):
|
||||
reward_type=reward_type)
|
||||
|
||||
|
||||
class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv):
|
||||
class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
|
||||
def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'):
|
||||
super(HandEggTouchSensorsEnv, self).__init__(
|
||||
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
|
||||
ManipulateTouchSensorsEnv.__init__(self,
|
||||
model_path=MANIPULATE_EGG_XML,
|
||||
touch_get_obs=touch_get_obs,
|
||||
target_rotation=target_rotation,
|
||||
@@ -117,9 +118,10 @@ class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv):
|
||||
reward_type=reward_type)
|
||||
|
||||
|
||||
class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv):
|
||||
class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
|
||||
def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'):
|
||||
super(HandPenTouchSensorsEnv, self).__init__(
|
||||
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
|
||||
ManipulateTouchSensorsEnv.__init__(self,
|
||||
model_path=MANIPULATE_PEN_XML,
|
||||
touch_get_obs=touch_get_obs,
|
||||
target_rotation=target_rotation,
|
||||
|
Reference in New Issue
Block a user