mirror of
https://github.com/Farama-Foundation/Gymnasium.git
synced 2025-08-30 09:55:39 +00:00
fix hand serialization tests (#1647)
* fix serialization by moving EzPickle inheritance to the leaf classes and passing correct arguments to EzPickle * fix touch sensor serialization too
This commit is contained in:
@@ -25,10 +25,10 @@ MANIPULATE_EGG_XML = os.path.join('hand', 'manipulate_egg.xml')
|
|||||||
MANIPULATE_PEN_XML = os.path.join('hand', 'manipulate_pen.xml')
|
MANIPULATE_PEN_XML = os.path.join('hand', 'manipulate_pen.xml')
|
||||||
|
|
||||||
|
|
||||||
class ManipulateEnv(hand_env.HandEnv, utils.EzPickle):
|
class ManipulateEnv(hand_env.HandEnv):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_path, target_position, target_rotation,
|
self, model_path, target_position, target_rotation,
|
||||||
target_position_range, reward_type, initial_qpos={},
|
target_position_range, reward_type, initial_qpos=None,
|
||||||
randomize_initial_position=True, randomize_initial_rotation=True,
|
randomize_initial_position=True, randomize_initial_rotation=True,
|
||||||
distance_threshold=0.01, rotation_threshold=0.1, n_substeps=20, relative_control=False,
|
distance_threshold=0.01, rotation_threshold=0.1, n_substeps=20, relative_control=False,
|
||||||
ignore_z_target_rotation=False,
|
ignore_z_target_rotation=False,
|
||||||
@@ -71,11 +71,12 @@ class ManipulateEnv(hand_env.HandEnv, utils.EzPickle):
|
|||||||
|
|
||||||
assert self.target_position in ['ignore', 'fixed', 'random']
|
assert self.target_position in ['ignore', 'fixed', 'random']
|
||||||
assert self.target_rotation in ['ignore', 'fixed', 'xyz', 'z', 'parallel']
|
assert self.target_rotation in ['ignore', 'fixed', 'xyz', 'z', 'parallel']
|
||||||
|
initial_qpos = initial_qpos or {}
|
||||||
|
|
||||||
hand_env.HandEnv.__init__(
|
hand_env.HandEnv.__init__(
|
||||||
self, model_path, n_substeps=n_substeps, initial_qpos=initial_qpos,
|
self, model_path, n_substeps=n_substeps, initial_qpos=initial_qpos,
|
||||||
relative_control=relative_control)
|
relative_control=relative_control)
|
||||||
utils.EzPickle.__init__(self)
|
|
||||||
|
|
||||||
def _get_achieved_goal(self):
|
def _get_achieved_goal(self):
|
||||||
# Object position and rotation.
|
# Object position and rotation.
|
||||||
@@ -271,27 +272,30 @@ class ManipulateEnv(hand_env.HandEnv, utils.EzPickle):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class HandBlockEnv(ManipulateEnv):
|
class HandBlockEnv(ManipulateEnv, utils.EzPickle):
|
||||||
def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'):
|
def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'):
|
||||||
super(HandBlockEnv, self).__init__(
|
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
||||||
|
ManipulateEnv.__init__(self,
|
||||||
model_path=MANIPULATE_BLOCK_XML, target_position=target_position,
|
model_path=MANIPULATE_BLOCK_XML, target_position=target_position,
|
||||||
target_rotation=target_rotation,
|
target_rotation=target_rotation,
|
||||||
target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
|
target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
|
||||||
reward_type=reward_type)
|
reward_type=reward_type)
|
||||||
|
|
||||||
|
|
||||||
class HandEggEnv(ManipulateEnv):
|
class HandEggEnv(ManipulateEnv, utils.EzPickle):
|
||||||
def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'):
|
def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'):
|
||||||
super(HandEggEnv, self).__init__(
|
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
||||||
|
ManipulateEnv.__init__(self,
|
||||||
model_path=MANIPULATE_EGG_XML, target_position=target_position,
|
model_path=MANIPULATE_EGG_XML, target_position=target_position,
|
||||||
target_rotation=target_rotation,
|
target_rotation=target_rotation,
|
||||||
target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
|
target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
|
||||||
reward_type=reward_type)
|
reward_type=reward_type)
|
||||||
|
|
||||||
|
|
||||||
class HandPenEnv(ManipulateEnv):
|
class HandPenEnv(ManipulateEnv, utils.EzPickle):
|
||||||
def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'):
|
def __init__(self, target_position='random', target_rotation='xyz', reward_type='sparse'):
|
||||||
super(HandPenEnv, self).__init__(
|
utils.EzPickle.__init__(self, target_position, target_rotation, reward_type)
|
||||||
|
ManipulateEnv.__init__(self,
|
||||||
model_path=MANIPULATE_PEN_XML, target_position=target_position,
|
model_path=MANIPULATE_PEN_XML, target_position=target_position,
|
||||||
target_rotation=target_rotation,
|
target_rotation=target_rotation,
|
||||||
target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
|
target_position_range=np.array([(-0.04, 0.04), (-0.06, 0.02), (0.0, 0.06)]),
|
||||||
|
@@ -10,7 +10,7 @@ MANIPULATE_EGG_XML = os.path.join('hand', 'manipulate_egg_touch_sensors.xml')
|
|||||||
MANIPULATE_PEN_XML = os.path.join('hand', 'manipulate_pen_touch_sensors.xml')
|
MANIPULATE_PEN_XML = os.path.join('hand', 'manipulate_pen_touch_sensors.xml')
|
||||||
|
|
||||||
|
|
||||||
class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv, utils.EzPickle):
|
class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, model_path, target_position, target_rotation,
|
self, model_path, target_position, target_rotation,
|
||||||
target_position_range, reward_type, initial_qpos={},
|
target_position_range, reward_type, initial_qpos={},
|
||||||
@@ -46,7 +46,6 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv, utils.EzPickle):
|
|||||||
distance_threshold=distance_threshold, rotation_threshold=rotation_threshold, n_substeps=n_substeps, relative_control=relative_control,
|
distance_threshold=distance_threshold, rotation_threshold=rotation_threshold, n_substeps=n_substeps, relative_control=relative_control,
|
||||||
ignore_z_target_rotation=ignore_z_target_rotation,
|
ignore_z_target_rotation=ignore_z_target_rotation,
|
||||||
)
|
)
|
||||||
utils.EzPickle.__init__(self)
|
|
||||||
|
|
||||||
for k, v in self.sim.model._sensor_name2id.items(): # get touch sensor site names and their ids
|
for k, v in self.sim.model._sensor_name2id.items(): # get touch sensor site names and their ids
|
||||||
if 'robot0:TS_' in k:
|
if 'robot0:TS_' in k:
|
||||||
@@ -95,9 +94,10 @@ class ManipulateTouchSensorsEnv(manipulate.ManipulateEnv, utils.EzPickle):
|
|||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv):
|
class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
|
||||||
def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'):
|
def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'):
|
||||||
super(HandBlockTouchSensorsEnv, self).__init__(
|
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
|
||||||
|
ManipulateTouchSensorsEnv.__init__(self,
|
||||||
model_path=MANIPULATE_BLOCK_XML,
|
model_path=MANIPULATE_BLOCK_XML,
|
||||||
touch_get_obs=touch_get_obs,
|
touch_get_obs=touch_get_obs,
|
||||||
target_rotation=target_rotation,
|
target_rotation=target_rotation,
|
||||||
@@ -106,9 +106,10 @@ class HandBlockTouchSensorsEnv(ManipulateTouchSensorsEnv):
|
|||||||
reward_type=reward_type)
|
reward_type=reward_type)
|
||||||
|
|
||||||
|
|
||||||
class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv):
|
class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
|
||||||
def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'):
|
def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'):
|
||||||
super(HandEggTouchSensorsEnv, self).__init__(
|
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
|
||||||
|
ManipulateTouchSensorsEnv.__init__(self,
|
||||||
model_path=MANIPULATE_EGG_XML,
|
model_path=MANIPULATE_EGG_XML,
|
||||||
touch_get_obs=touch_get_obs,
|
touch_get_obs=touch_get_obs,
|
||||||
target_rotation=target_rotation,
|
target_rotation=target_rotation,
|
||||||
@@ -117,9 +118,10 @@ class HandEggTouchSensorsEnv(ManipulateTouchSensorsEnv):
|
|||||||
reward_type=reward_type)
|
reward_type=reward_type)
|
||||||
|
|
||||||
|
|
||||||
class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv):
|
class HandPenTouchSensorsEnv(ManipulateTouchSensorsEnv, utils.EzPickle):
|
||||||
def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'):
|
def __init__(self, target_position='random', target_rotation='xyz', touch_get_obs='sensordata', reward_type='sparse'):
|
||||||
super(HandPenTouchSensorsEnv, self).__init__(
|
utils.EzPickle.__init__(self, target_position, target_rotation, touch_get_obs, reward_type)
|
||||||
|
ManipulateTouchSensorsEnv.__init__(self,
|
||||||
model_path=MANIPULATE_PEN_XML,
|
model_path=MANIPULATE_PEN_XML,
|
||||||
touch_get_obs=touch_get_obs,
|
touch_get_obs=touch_get_obs,
|
||||||
target_rotation=target_rotation,
|
target_rotation=target_rotation,
|
||||||
|
Reference in New Issue
Block a user