You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
324 lines
13 KiB
324 lines
13 KiB
from copy import deepcopy
from typing import Any, Dict, Optional

import gymnasium as gym
import numpy as np
from scipy.spatial.transform import Rotation as R

from decoupled_wbc.control.base.humanoid_env import Hands, HumanoidEnv
from decoupled_wbc.control.envs.g1.g1_body import G1Body
from decoupled_wbc.control.envs.g1.g1_hand import G1ThreeFingerHand
from decoupled_wbc.control.envs.g1.sim.simulator_factory import SimulatorFactory, init_channel
from decoupled_wbc.control.envs.g1.utils.joint_safety import JointSafetyMonitor
from decoupled_wbc.control.robot_model.instantiation.g1 import instantiate_g1_robot_model
from decoupled_wbc.control.robot_model.robot_model import RobotModel
from decoupled_wbc.control.utils.ros_utils import ROSManager
|
|
|
|
|
|
class G1Env(HumanoidEnv):
    """Humanoid environment combining a G1 body, optional hands and an optional simulator.

    The environment reads joint measurements from ``G1Body`` (and from two
    ``G1ThreeFingerHand`` instances when ``with_hands`` is enabled), converts them
    from actuator order to joint order through the robot model, and forwards
    joint-position actions back to the body and hands. When ``ENV_TYPE`` is
    ``"sim"`` a simulator is created through ``SimulatorFactory``.
    """

    def __init__(
        self,
        env_name: str = "default",
        robot_model: Optional[RobotModel] = None,
        wbc_version: str = "v2",
        config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Create the body, optional hands, optional simulator and the ROS 2 node.

        Args:
            env_name: Forwarded to ``SimulatorFactory.create_simulator`` in sim mode.
            robot_model: Robot model; a deep copy of it is stored on the environment.
            wbc_version: Forwarded to ``SimulatorFactory.create_simulator`` in sim mode.
            config: Settings dictionary. Keys read here: ``ENV_TYPE``, ``with_hands``,
                ``enable_gravity_compensation``, ``gravity_compensation_joints``,
                ``ENABLE_ONSCREEN`` and ``ENABLE_OFFSCREEN``. The same dictionary is
                also handed to ``init_channel``, ``G1Body`` and the simulator factory.
            **kwargs: Extra keyword arguments forwarded to the simulator factory.
                ``onscreen`` and ``offscreen`` are always overwritten from ``config``;
                ``body_ik_solver_settings_type`` defaults to ``"default"``.

        Raises:
            ValueError: If ``config`` is ``None``.
        """
        super().__init__()
        if config is None:
            # Every step below reads from the config, so fail early with a clear message
            # instead of an AttributeError on ``None.get``.
            raise ValueError("G1Env requires a `config` dictionary; got None.")

        self.robot_model = deepcopy(robot_model)  # need to cache FK results
        self.config = config

        # Initialize safety monitor (visualization disabled).
        # The monitor receives the caller's model object, not the deep copy stored above.
        self.safety_monitor = JointSafetyMonitor(
            robot_model, enable_viz=False, env_type=self.config.get("ENV_TYPE", "real")
        )
        self.last_obs = None
        self.last_safety_ok = True  # Track last safety status from queue_action

        init_channel(config=self.config)

        # Initialize body and hands
        self._body = G1Body(config=self.config)

        self.with_hands = self.config.get("with_hands", False)

        # Gravity compensation settings
        self.enable_gravity_compensation = self.config.get("enable_gravity_compensation", False)
        self.gravity_compensation_joints = self.config.get(
            "gravity_compensation_joints", ["arms"]
        )

        if self.enable_gravity_compensation:
            print(
                f"Gravity compensation enabled for joint groups: {self.gravity_compensation_joints}"
            )
        if self.with_hands:
            self._hands = Hands()
            self._hands.left = G1ThreeFingerHand(is_left=True)
            self._hands.right = G1ThreeFingerHand(is_left=False)

        # Initialize simulator if in simulation mode
        self.use_sim = self.config.get("ENV_TYPE") == "sim"

        if self.use_sim:
            # Merge everything into one keyword dict so that no argument is passed twice:
            # giving ``body_ik_solver_settings_type`` both explicitly and through
            # ``**kwargs`` raises "got multiple values for keyword argument".
            sim_kwargs = dict(kwargs)
            sim_kwargs.update(
                {
                    "onscreen": self.config.get("ENABLE_ONSCREEN", True),
                    "offscreen": self.config.get("ENABLE_OFFSCREEN", False),
                }
            )
            sim_kwargs.setdefault("body_ik_solver_settings_type", "default")

            # Create simulator using factory
            self.sim = SimulatorFactory.create_simulator(
                config=self.config,
                env_name=env_name,
                wbc_version=wbc_version,
                **sim_kwargs,
            )
        else:
            self.sim = None

            # using the real robot
            self.calibrate_hands()

        # Initialize ROS 2 node
        self.ros_manager = ROSManager(node_name="g1_env")
        self.ros_node = self.ros_manager.node

        self.delay_list = []
        self.visualize_delay = False
        self.print_delay_interval = 100
        self.cnt = 0

    def start_simulator(self):
        """Start the simulator as a thread via ``SimulatorFactory``."""
        # Image publishing is disabled since the sim is running in a sub-thread
        SimulatorFactory.start_simulator(self.sim, as_thread=True, enable_image_publish=False)

    def step_simulator(self):
        """Advance the simulator by ``REWARD_DT / SIMULATE_DT`` steps, then refresh the viewer."""
        # The small tolerance keeps a quotient such as 0.3 / 0.1 (2.9999999999999996 in
        # floating point) from being truncated to one step too few.
        sim_num_steps = int(self.config["REWARD_DT"] / self.config["SIMULATE_DT"] + 1e-9)
        for _ in range(sim_num_steps):
            self.sim.sim_env.sim_step()
        self.sim.sim_env.update_viewer()

    def body(self) -> G1Body:
        """Return the ``G1Body`` created in ``__init__``."""
        return self._body

    def hands(self) -> Hands:
        """Return the ``Hands`` container.

        Raises:
            RuntimeError: If the environment was configured without hands.
        """
        if not self.with_hands:
            raise RuntimeError(
                "Hands not initialized. Use --with_hands True to enable hand functionality."
            )
        return self._hands

    def observe(self) -> Dict[str, Any]:
        """Collect the current observation and remember it in ``last_obs``.

        Returns:
            A dictionary with the whole-body ``q``, ``dq``, ``ddq`` and ``tau_est``,
            the floating-base pose/velocity/acceleration, ``wrist_pose`` (left wrist
            pose followed by right wrist pose), ``torso_quat`` and ``torso_ang_vel``.
            In sim mode the simulator's privileged observations are merged in as well.
        """
        # Get observations from body and hands
        body_obs = self.body().observe()
        if self.with_hands:
            left_hand_obs = self.hands().left.observe()
            right_hand_obs = self.hands().right.observe()

        # Body and hand joint measurements come in actuator order, so we need to convert
        # them to joint order. The same conversion is applied to every signal.
        whole = {}
        for signal in ("q", "dq", "ddq", "tau_est"):
            actuated = {"body_actuated_joint_values": body_obs[f"body_{signal}"]}
            if self.with_hands:
                actuated["left_hand_actuated_joint_values"] = left_hand_obs[f"hand_{signal}"]
                actuated["right_hand_actuated_joint_values"] = right_hand_obs[f"hand_{signal}"]
            whole[signal] = self.robot_model.get_configuration_from_actuated_joints(**actuated)

        eef_obs = self.get_eef_obs(whole["q"])

        obs = {
            "q": whole["q"],
            "dq": whole["dq"],
            "ddq": whole["ddq"],
            "tau_est": whole["tau_est"],
            "floating_base_pose": body_obs["floating_base_pose"],
            "floating_base_vel": body_obs["floating_base_vel"],
            "floating_base_acc": body_obs["floating_base_acc"],
            "wrist_pose": np.concatenate([eef_obs["left_wrist_pose"], eef_obs["right_wrist_pose"]]),
            "torso_quat": body_obs["torso_quat"],
            "torso_ang_vel": body_obs["torso_ang_vel"],
        }

        if self.use_sim and self.sim:
            obs.update(self.sim.get_privileged_obs())

        # Store last observation for safety checking
        self.last_obs = obs

        return obs

    @property
    def observation_space(self) -> gym.Space:
        """Unbounded ``gym.spaces.Dict`` describing the core observation entries.

        NOTE(review): ``observe`` also returns ``torso_quat``, ``torso_ang_vel`` and, in
        sim mode, privileged observations; none of those are declared here.
        """
        # @todo: check if the low and high bounds are correct for body_obs.

        def unbounded(*shape: int) -> gym.spaces.Box:
            return gym.spaces.Box(low=-np.inf, high=np.inf, shape=shape)

        num_dofs = self.robot_model.num_dofs
        return gym.spaces.Dict(
            {
                "floating_base_pose": unbounded(7),
                "floating_base_vel": unbounded(6),
                "floating_base_acc": unbounded(6),
                "q": unbounded(num_dofs),
                "dq": unbounded(num_dofs),
                "ddq": unbounded(num_dofs),
                "tau_est": unbounded(num_dofs),
                "wrist_pose": unbounded(7 + 7),
            }
        )

    def queue_action(self, action: Dict[str, Any]):
        """Run the safety monitor on ``action`` and queue it on the body and hands.

        Args:
            action: Dictionary whose ``"q"`` entry holds the target configuration in
                joint order.
        """
        # Safety check (skipped until the first observation has been taken)
        if self.last_obs is not None:
            safety_result = self.safety_monitor.handle_violations(self.last_obs, action)
            action = safety_result["action"]
            # NOTE(review): ``last_safety_ok`` is not updated from ``safety_result`` here,
            # so ``get_joint_safety_status`` keeps returning its initial value — confirm.

        # Map action from joint order to actuator order
        body_actuator_q = self.robot_model.get_body_actuated_joints(action["q"])

        # Only positions are commanded; velocity and torque targets are zero.
        self.body().queue_action(
            {
                "body_q": body_actuator_q,
                "body_dq": np.zeros_like(body_actuator_q),
                "body_tau": np.zeros_like(body_actuator_q),
            }
        )

        if self.with_hands:
            left_hand_actuator_q = self.robot_model.get_hand_actuated_joints(
                action["q"], side="left"
            )
            right_hand_actuator_q = self.robot_model.get_hand_actuated_joints(
                action["q"], side="right"
            )
            self.hands().left.queue_action({"hand_q": left_hand_actuator_q})
            self.hands().right.queue_action({"hand_q": right_hand_actuator_q})

    def action_space(self) -> gym.Space:
        """Return an unbounded box with one entry per robot-model degree of freedom."""
        return gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,))

    def calibrate_hands(self):
        """Calibrate the hand joint qpos if real robot"""
        if self.with_hands:
            print("calibrating left hand")
            self.hands().left.calibrate_hand()
            print("calibrating right hand")
            self.hands().right.calibrate_hand()
        else:
            print("Skipping hand calibration - hands disabled")

    def set_ik_indicator(self, teleop_cmd):
        """Set the IK indicators for the simulator

        Args:
            teleop_cmd: Mapping; the indicators are updated only when it contains both
                ``"left_wrist"`` and ``"right_wrist"``.

        Raises:
            NotImplementedError: If the configured simulator is not robocasa.
        """
        if self.config["SIMULATOR"] == "robocasa":
            if "left_wrist" in teleop_cmd and "right_wrist" in teleop_cmd:
                left_wrist_input_pose = teleop_cmd["left_wrist"]
                right_wrist_input_pose = teleop_cmd["right_wrist"]
                ik_wrapper = self.sim.env.env.unwrapped.env
                ik_wrapper.set_target_poses_outside_env(
                    [left_wrist_input_pose, right_wrist_input_pose]
                )
        else:
            raise NotImplementedError("IK indicators are only implemented for robocasa simulator")

    def set_sync_mode(self, sync_mode: bool, steps_per_action: int = 4):
        """When set to True, the simulator will wait for the action to be sent to it

        Only forwarded to the simulator when ``SIMULATOR`` is ``"robocasa"``; otherwise
        the call does nothing.
        """
        if self.config["SIMULATOR"] == "robocasa":
            self.sim.set_sync_mode(sync_mode, steps_per_action)

    def reset(self):
        """Reset the simulator, if one exists."""
        if self.sim:
            self.sim.reset()

    def close(self):
        """Close the simulator, if one exists."""
        if self.sim:
            self.sim.close()

    def robot_model(self) -> RobotModel:
        """Return the robot model stored on the environment.

        NOTE: ``__init__`` assigns an instance attribute with the same name, so on an
        instance ``env.robot_model`` resolves to the stored model object rather than to
        this method.
        """
        return self.robot_model

    def get_reward(self):
        """Return the simulator's reward, or ``None`` when there is no simulator."""
        if self.sim:
            return self.sim.get_reward()
        return None

    def reset_obj_pos(self):
        """Call ``reset_obj_pos`` on the simulator's ``base_env`` when both exist."""
        if hasattr(self.sim, "base_env") and hasattr(self.sim.base_env, "reset_obj_pos"):
            self.sim.base_env.reset_obj_pos()

    def get_eef_obs(self, q: np.ndarray) -> Dict[str, np.ndarray]:
        """Compute both wrist poses from forward kinematics at configuration ``q``.

        Returns:
            ``{"left_wrist_pose": ..., "right_wrist_pose": ...}`` where each value is the
            translation followed by the scalar-first quaternion of the hand frame.
        """
        self.robot_model.cache_forward_kinematics(q)
        eef_obs = {}
        for side in ["left", "right"]:
            wrist_placement = self.robot_model.frame_placement(
                self.robot_model.supplemental_info.hand_frame_names[side]
            )
            wrist_pos = wrist_placement.translation[:3]
            wrist_quat = R.from_matrix(wrist_placement.rotation).as_quat(scalar_first=True)
            eef_obs[f"{side}_wrist_pose"] = np.concatenate([wrist_pos, wrist_quat])

        return eef_obs

    def get_joint_safety_status(self) -> bool:
        """Return the cached joint safety flag ``last_safety_ok``.

        NOTE(review): the flag is set to True in ``__init__`` and nothing in this class
        changes it afterwards — confirm whether ``queue_action`` is meant to update it.

        Returns:
            bool: True if joints are safe (no shutdown required), False if unsafe
        """
        return self.last_safety_ok

    def handle_keyboard_button(self, key):
        """Forward a keyboard key to the simulator when running the mujoco simulator."""
        # Only handles keyboard buttons for the mujoco simulator for now.
        if self.use_sim and self.config.get("SIMULATOR", "mujoco") == "mujoco":
            self.sim.handle_keyboard_button(key)
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual smoke run: build an environment around a freshly instantiated G1 model
    # and print observations in an endless loop.
    # NOTE(review): no ``config`` is passed here, while G1Env.__init__ reads its
    # settings from ``config`` — confirm this entry point is expected to run as-is.
    env = G1Env(robot_model=instantiate_g1_robot_model(), wbc_version="gear_wbc")
    while True:
        observation = env.observe()
        print(observation)
|