import numpy as np

from gym.envs.robotics import rotations, robot_env, utils
def goal_distance(goal_a, goal_b):
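    """Return the Euclidean distance between goal_a and goal_b along the last axis."""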
assert goal_a.shape == goal_b.shape
return np.linalg.norm(goal_a - goal_b, axis=-1)
class FetchEnv(robot_env.RobotEnv):
"""Superclass for all Fetch environments."""
def __init__(
        self,
        model_path,
        n_substeps,
        gripper_extra_height,
        block_gripper,
        has_object,
        target_in_the_air,
        target_offset,
        obj_range,
        target_range,
        distance_threshold,
        initial_qpos,
        reward_type,
):
"""Initializes a new Fetch environment.

        Args:
            model_path (string): path to the environment's XML file
            n_substeps (int): number of substeps the simulation runs on every call to step
            gripper_extra_height (float): additional height above the table when positioning the gripper
            block_gripper (boolean): whether or not the gripper is blocked (i.e. not movable)
            has_object (boolean): whether or not the environment has an object
            target_in_the_air (boolean): whether or not the target should be in the air above the table or on the table surface
            target_offset (float or array with 3 elements): offset of the target
            obj_range (float): range of a uniform distribution for sampling initial object positions
            target_range (float): range of a uniform distribution for sampling a target
            distance_threshold (float): the distance below which a goal is considered achieved
            initial_qpos (dict): a dictionary of joint names and values that define the initial configuration
            reward_type ('sparse' or 'dense'): whether the reward is sparse (0 or -1) or dense (negative goal distance)
"""
        self.gripper_extra_height = gripper_extra_height
        self.block_gripper = block_gripper
        self.has_object = has_object
        self.target_in_the_air = target_in_the_air
        self.target_offset = target_offset
        self.obj_range = obj_range
        self.target_range = target_range
        self.distance_threshold = distance_threshold
        self.reward_type = reward_type

        super(FetchEnv, self).__init__(
model_path=model_path,
            n_substeps=n_substeps,
            n_actions=4,
            initial_qpos=initial_qpos,
)
# GoalEnv methods
# ----------------------------
def compute_reward(self, achieved_goal, goal, info):
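        """Return 0/-1 depending on the distance threshold for 'sparse' rewards, or the negative goal distance for 'dense'."""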
# Compute distance between goal and the achieved goal.
d = goal_distance(achieved_goal, goal)
if self.reward_type == "sparse":
return -(d > self.distance_threshold).astype(np.float32)
else:
return -d
# RobotEnv methods
# ----------------------------
def _step_callback(self):
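        """If the gripper is blocked, keep both finger joints closed after every simulation step."""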
if self.block_gripper:
self.sim.data.set_joint_qpos("robot0:l_gripper_finger_joint", 0.0)
self.sim.data.set_joint_qpos("robot0:r_gripper_finger_joint", 0.0)
self.sim.forward()
def _set_action(self, action):
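        """Apply a 4-dimensional action: a Cartesian displacement of the gripper plus a single symmetric gripper control."""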
assert action.shape == (4,)
action = action.copy() # ensure that we don't change the action outside of this scope
pos_ctrl, gripper_ctrl = action[:3], action[3]
pos_ctrl *= 0.05 # limit maximum change in position
        rot_ctrl = [1.0, 0.0, 1.0, 0.0]  # fixed rotation of the end effector, expressed as a quaternion
gripper_ctrl = np.array([gripper_ctrl, gripper_ctrl])
assert gripper_ctrl.shape == (2,)
if self.block_gripper:
gripper_ctrl = np.zeros_like(gripper_ctrl)
action = np.concatenate([pos_ctrl, rot_ctrl, gripper_ctrl])
# Apply action to simulation.
utils.ctrl_set_action(self.sim, action)
utils.mocap_set_action(self.sim, action)
def _get_obs(self):
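        """Return the goal-based observation dict with 'observation', 'achieved_goal' and 'desired_goal' entries."""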
# positions
grip_pos = self.sim.data.get_site_xpos("robot0:grip")
dt = self.sim.nsubsteps * self.sim.model.opt.timestep
grip_velp = self.sim.data.get_site_xvelp("robot0:grip") * dt
robot_qpos, robot_qvel = utils.robot_get_obs(self.sim)
if self.has_object:
object_pos = self.sim.data.get_site_xpos("object0")
# rotations
object_rot = rotations.mat2euler(self.sim.data.get_site_xmat("object0"))
# velocities
object_velp = self.sim.data.get_site_xvelp("object0") * dt
object_velr = self.sim.data.get_site_xvelr("object0") * dt
            # object state relative to the gripper
object_rel_pos = object_pos - grip_pos
object_velp -= grip_velp
else:
object_pos = object_rot = object_velp = object_velr = object_rel_pos = np.zeros(0)
gripper_state = robot_qpos[-2:]
gripper_vel = robot_qvel[-2:] * dt # change to a scalar if the gripper is made symmetric
if not self.has_object:
achieved_goal = grip_pos.copy()
else:
achieved_goal = np.squeeze(object_pos.copy())
        obs = np.concatenate(
            [
                grip_pos,
                object_pos.ravel(),
                object_rel_pos.ravel(),
                gripper_state,
                object_rot.ravel(),
                object_velp.ravel(),
                object_velr.ravel(),
                grip_velp,
                gripper_vel,
            ]
        )
        return {
            "observation": obs.copy(),
            "achieved_goal": achieved_goal.copy(),
            "desired_goal": self.goal.copy(),
        }
def _viewer_setup(self):
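        """Point the camera at the gripper link with a fixed distance, azimuth and elevation."""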
body_id = self.sim.model.body_name2id("robot0:gripper_link")
lookat = self.sim.data.body_xpos[body_id]
for idx, value in enumerate(lookat):
self.viewer.cam.lookat[idx] = value
self.viewer.cam.distance = 2.5
self.viewer.cam.azimuth = 132.0
self.viewer.cam.elevation = -14.0
def _render_callback(self):
# Visualize target.
sites_offset = (self.sim.data.site_xpos - self.sim.model.site_pos).copy()
site_id = self.sim.model.site_name2id("target0")
self.sim.model.site_pos[site_id] = self.goal - sites_offset[0]
self.sim.forward()
def _reset_sim(self):
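        """Reset the simulation to its initial state and, if an object is present, randomize its start position on the table."""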
self.sim.set_state(self.initial_state)
# Randomize start position of object.
if self.has_object:
object_xpos = self.initial_gripper_xpos[:2]
while np.linalg.norm(object_xpos - self.initial_gripper_xpos[:2]) < 0.1:
object_xpos = self.initial_gripper_xpos[:2] + self.np_random.uniform(-self.obj_range, self.obj_range, size=2)
object_qpos = self.sim.data.get_joint_qpos("object0:joint")
assert object_qpos.shape == (7,)
object_qpos[:2] = object_xpos
self.sim.data.set_joint_qpos("object0:joint", object_qpos)
self.sim.forward()
return True
def _sample_goal(self):
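        """Sample a goal uniformly around the initial gripper position; with an object, the goal lies on the table or, optionally, in the air."""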
if self.has_object:
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
goal += self.target_offset
goal[2] = self.height_offset
if self.target_in_the_air and self.np_random.uniform() < 0.5:
goal[2] += self.np_random.uniform(0, 0.45)
else:
goal = self.initial_gripper_xpos[:3] + self.np_random.uniform(-self.target_range, self.target_range, size=3)
return goal.copy()
def _is_success(self, achieved_goal, desired_goal):
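        """Return 1.0 if the achieved goal is within the distance threshold of the desired goal, otherwise 0.0."""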
d = goal_distance(achieved_goal, desired_goal)
return (d < self.distance_threshold).astype(np.float32)
def _env_setup(self, initial_qpos):
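        """Apply the initial joint configuration, move the end effector into place via the mocap body, and cache values used for goal sampling."""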
for name, value in initial_qpos.items():
self.sim.data.set_joint_qpos(name, value)
utils.reset_mocap_welds(self.sim)
self.sim.forward()
# Move end effector into position.
        gripper_target = np.array([-0.498, 0.005, -0.431 + self.gripper_extra_height]) + self.sim.data.get_site_xpos("robot0:grip")
gripper_rotation = np.array([1.0, 0.0, 1.0, 0.0])
self.sim.data.set_mocap_pos("robot0:mocap", gripper_target)
self.sim.data.set_mocap_quat("robot0:mocap", gripper_rotation)
for _ in range(10):
self.sim.step()
# Extract information for sampling goals.
self.initial_gripper_xpos = self.sim.data.get_site_xpos("robot0:grip").copy()
if self.has_object:
self.height_offset = self.sim.data.get_site_xpos("object0")[2]
def render(self, mode="human", width=500, height=500):
return super(FetchEnv, self).render(mode, width, height)
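

# Illustrative sketch (not from the original module): a concrete task is defined by
# subclassing FetchEnv and forwarding task-specific parameters to the constructor
# documented above. The XML path, joint values and numeric parameters below are
# placeholder assumptions for illustration only.
#
#     class ReachLikeEnv(FetchEnv):
#         def __init__(self, reward_type="sparse"):
#             super(ReachLikeEnv, self).__init__(
#                 model_path="fetch/reach.xml",  # placeholder path to a MuJoCo model
#                 n_substeps=20,
#                 gripper_extra_height=0.2,
#                 block_gripper=True,
#                 has_object=False,
#                 target_in_the_air=True,
#                 target_offset=0.0,
#                 obj_range=0.15,
#                 target_range=0.15,
#                 distance_threshold=0.05,
#                 initial_qpos={"robot0:slide0": 0.405, "robot0:slide1": 0.48, "robot0:slide2": 0.0},
#                 reward_type=reward_type,
#             )
#
# A rollout then follows the usual goal-based Gym loop:
#
#     env = ReachLikeEnv()
#     obs = env.reset()
#     obs, reward, done, info = env.step(env.action_space.sample())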