You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
589 lines
24 KiB
589 lines
24 KiB
import sys
|
|
from typing import Any, Dict, Tuple
|
|
|
|
import gymnasium as gym
|
|
from gymnasium.envs.registration import register
|
|
import numpy as np
|
|
from robocasa.environments.locomanipulation import REGISTERED_LOCOMANIPULATION_ENVS
|
|
from robocasa.models.robots import GR00T_LOCOMANIP_ENVS_ROBOTS
|
|
from robosuite.environments.robot_env import RobotEnv
|
|
from scipy.spatial.transform import Rotation as R
|
|
|
|
from gr00t_wbc.control.envs.g1.utils.joint_safety import JointSafetyMonitor
|
|
from gr00t_wbc.control.envs.robocasa.utils.controller_utils import (
|
|
update_robosuite_controller_configs,
|
|
)
|
|
from gr00t_wbc.control.envs.robocasa.utils.robocasa_env import ( # noqa: F401
|
|
ALLOWED_LANGUAGE_CHARSET,
|
|
Gr00tLocomanipRoboCasaEnv,
|
|
)
|
|
from gr00t_wbc.control.envs.robocasa.utils.sim_utils import change_simulation_timestep
|
|
from gr00t_wbc.control.robot_model.instantiation import get_robot_type_and_model
|
|
from gr00t_wbc.control.utils.n1_utils import (
|
|
prepare_gym_space_for_eval,
|
|
prepare_observation_for_eval,
|
|
)
|
|
from gr00t_wbc.data.constants import RS_VIEW_CAMERA_HEIGHT, RS_VIEW_CAMERA_WIDTH
|
|
|
|
|
|
class SyncEnv(gym.Env):
|
|
MAX_MUJOCO_STATE_LEN = 800
|
|
|
|
def __init__(self, env_name, robot_name, **kwargs):
|
|
self.env_name = env_name
|
|
self.robot_name = robot_name
|
|
self.onscreen = kwargs.get("onscreen", True)
|
|
self.enable_gravity_compensation = kwargs.pop("enable_gravity_compensation", False)
|
|
self.gravity_compensation_joints = kwargs.pop("gravity_compensation_joints", ["arms"])
|
|
_, self.robot_model = get_robot_type_and_model(
|
|
robot_name, enable_waist_ik=kwargs.pop("enable_waist", False)
|
|
)
|
|
|
|
change_simulation_timestep(kwargs.get("sim_freq", 1 / 0.005))
|
|
|
|
env_kwargs = {
|
|
"onscreen": kwargs.get("onscreen", True),
|
|
"offscreen": kwargs.get("offscreen", False),
|
|
"renderer": kwargs.get("renderer", "mjviewer"),
|
|
"render_camera": kwargs.get("render_camera", "frontview"),
|
|
"camera_names": kwargs.get("camera_names", ["frontview"]),
|
|
"camera_heights": kwargs.get("camera_heights", None),
|
|
"camera_widths": kwargs.get("camera_widths", None),
|
|
"controller_configs": kwargs["controller_configs"],
|
|
"control_freq": kwargs.get("control_freq", 50),
|
|
"translucent_robot": kwargs.get("translucent_robot", True),
|
|
"ik_indicator": kwargs.get("ik_indicator", False),
|
|
"randomize_cameras": kwargs.get("randomize_cameras", True),
|
|
}
|
|
self.env = Gr00tLocomanipRoboCasaEnv(
|
|
env_name, robot_name, robot_model=self.robot_model, **env_kwargs
|
|
)
|
|
self.init_cache()
|
|
|
|
self.reset()
|
|
|
|
@property
|
|
def base_env(self) -> RobotEnv:
|
|
return self.env.env
|
|
|
|
def overwrite_floating_base_action(self, navigate_cmd):
|
|
if self.base_env.robots[0].robot_model.default_base == "FloatingLeggedBase":
|
|
self.env.unwrapped.overridden_floating_base_action = navigate_cmd
|
|
|
|
def get_mujoco_state_info(self):
|
|
mujoco_state = self.base_env.sim.get_state().flatten()
|
|
assert len(mujoco_state) < SyncEnv.MAX_MUJOCO_STATE_LEN
|
|
padding_width = SyncEnv.MAX_MUJOCO_STATE_LEN - len(mujoco_state)
|
|
padded_mujoco_state = np.pad(
|
|
mujoco_state, (0, padding_width), mode="constant", constant_values=0
|
|
)
|
|
max_mujoco_state_len = SyncEnv.MAX_MUJOCO_STATE_LEN
|
|
mujoco_state_len = len(mujoco_state)
|
|
mujoco_state = padded_mujoco_state.copy()
|
|
return max_mujoco_state_len, mujoco_state_len, mujoco_state
|
|
|
|
def reset_to(self, state: Dict[str, Any]) -> Dict[str, Any] | None:
|
|
if hasattr(self.base_env, "reset_to"):
|
|
self.base_env.reset_to(state)
|
|
else:
|
|
# todo: maybe update robosuite to have reset_to()
|
|
env = self.base_env
|
|
if "model_file" in state:
|
|
xml = env.edit_model_xml(state["model_file"])
|
|
env.reset_from_xml_string(xml)
|
|
env.sim.reset()
|
|
if "states" in state:
|
|
env.sim.set_state_from_flattened(state["states"])
|
|
env.sim.forward()
|
|
|
|
obs = self.env.force_update_observation(timestep=0)
|
|
self.cache["obs"] = obs
|
|
return
|
|
|
|
def get_state(self) -> Dict[str, Any]:
|
|
return self.base_env.get_state()
|
|
|
|
def is_success(self):
|
|
"""
|
|
Check if the task condition(s) is reached. Should return a dictionary
|
|
{ str: bool } with at least a "task" key for the overall task success,
|
|
and additional optional keys corresponding to other task criteria.
|
|
"""
|
|
# First, try to use the base environment's is_success method if it exists
|
|
if hasattr(self.base_env, "is_success"):
|
|
return self.base_env.is_success()
|
|
|
|
# Fall back to using _check_success if available
|
|
elif hasattr(self.base_env, "_check_success"):
|
|
succ = self.base_env._check_success()
|
|
if isinstance(succ, dict):
|
|
assert "task" in succ
|
|
return succ
|
|
return {"task": succ}
|
|
|
|
# If neither method exists, return failure
|
|
else:
|
|
return {"task": False}
|
|
|
|
def init_cache(self):
|
|
self.cache = {
|
|
"obs": None,
|
|
"reward": None,
|
|
"terminated": None,
|
|
"truncated": None,
|
|
"info": None,
|
|
}
|
|
|
|
def reset(self, seed=None, options=None) -> Tuple[Dict[str, any], Dict[str, any]]:
|
|
self.init_cache()
|
|
obs, info = self.env.reset(seed=seed, options=options)
|
|
self.cache["obs"] = obs
|
|
self.cache["reward"] = 0
|
|
self.cache["terminated"] = False
|
|
self.cache["truncated"] = False
|
|
self.cache["info"] = info
|
|
return self.observe(), info
|
|
|
|
def observe(self) -> Dict[str, any]:
|
|
# Get observations from body and hands
|
|
assert (
|
|
self.cache["obs"] is not None
|
|
), "Observation cache is not initialized, please reset the environment first"
|
|
raw_obs = self.cache["obs"]
|
|
|
|
# Body and hand joint measurements come in actuator order, so we need to convert them to joint order
|
|
whole_q = self.robot_model.get_configuration_from_actuated_joints(
|
|
body_actuated_joint_values=raw_obs["body_q"],
|
|
left_hand_actuated_joint_values=raw_obs["left_hand_q"],
|
|
right_hand_actuated_joint_values=raw_obs["right_hand_q"],
|
|
)
|
|
whole_dq = self.robot_model.get_configuration_from_actuated_joints(
|
|
body_actuated_joint_values=raw_obs["body_dq"],
|
|
left_hand_actuated_joint_values=raw_obs["left_hand_dq"],
|
|
right_hand_actuated_joint_values=raw_obs["right_hand_dq"],
|
|
)
|
|
whole_ddq = self.robot_model.get_configuration_from_actuated_joints(
|
|
body_actuated_joint_values=raw_obs["body_ddq"],
|
|
left_hand_actuated_joint_values=raw_obs["left_hand_ddq"],
|
|
right_hand_actuated_joint_values=raw_obs["right_hand_ddq"],
|
|
)
|
|
whole_tau_est = self.robot_model.get_configuration_from_actuated_joints(
|
|
body_actuated_joint_values=raw_obs["body_tau_est"],
|
|
left_hand_actuated_joint_values=raw_obs["left_hand_tau_est"],
|
|
right_hand_actuated_joint_values=raw_obs["right_hand_tau_est"],
|
|
)
|
|
eef_obs = self.get_eef_obs(whole_q)
|
|
|
|
obs = {
|
|
"q": whole_q,
|
|
"dq": whole_dq,
|
|
"ddq": whole_ddq,
|
|
"tau_est": whole_tau_est,
|
|
"floating_base_pose": raw_obs["floating_base_pose"],
|
|
"floating_base_vel": raw_obs["floating_base_vel"],
|
|
"floating_base_acc": raw_obs["floating_base_acc"],
|
|
"wrist_pose": np.concatenate([eef_obs["left_wrist_pose"], eef_obs["right_wrist_pose"]]),
|
|
}
|
|
|
|
# Add state keys for model input
|
|
obs = prepare_observation_for_eval(self.robot_model, obs)
|
|
|
|
obs["annotation.human.task_description"] = raw_obs["language.language_instruction"]
|
|
|
|
if hasattr(self.base_env, "get_privileged_obs_keys"):
|
|
for key in self.base_env.get_privileged_obs_keys():
|
|
obs[key] = raw_obs[key]
|
|
|
|
for key in raw_obs.keys():
|
|
if key.endswith("_image"):
|
|
obs[key] = raw_obs[key]
|
|
# TODO: add video.key without _image suffix for evaluation, convert to uint8, remove later
|
|
obs[f"video.{key.replace('_image', '')}"] = raw_obs[key]
|
|
return obs
|
|
|
|
def step(
|
|
self, action: Dict[str, any]
|
|
) -> Tuple[Dict[str, any], float, bool, bool, Dict[str, any]]:
|
|
self.queue_action(action)
|
|
return self.get_step_info()
|
|
|
|
def get_observation(self):
|
|
return self.base_env._get_observations() # assumes base env is robosuite
|
|
|
|
def get_step_info(self) -> Tuple[Dict[str, any], float, bool, bool, Dict[str, any]]:
|
|
return (
|
|
self.observe(),
|
|
self.cache["reward"],
|
|
self.cache["terminated"],
|
|
self.cache["truncated"],
|
|
self.cache["info"],
|
|
)
|
|
|
|
def convert_q_to_actuated_joint_order(self, q: np.ndarray) -> np.ndarray:
|
|
body_q = self.robot_model.get_body_actuated_joints(q)
|
|
left_hand_q = self.robot_model.get_hand_actuated_joints(q, side="left")
|
|
right_hand_q = self.robot_model.get_hand_actuated_joints(q, side="right")
|
|
whole_q = np.zeros_like(q)
|
|
whole_q[self.robot_model.get_joint_group_indices("body")] = body_q
|
|
whole_q[self.robot_model.get_joint_group_indices("left_hand")] = left_hand_q
|
|
whole_q[self.robot_model.get_joint_group_indices("right_hand")] = right_hand_q
|
|
|
|
return whole_q
|
|
|
|
def set_ik_indicator(self, teleop_cmd):
|
|
"""Set the IK indicators for the simulator"""
|
|
if "left_wrist" in teleop_cmd and "right_wrist" in teleop_cmd:
|
|
left_wrist_input_pose = teleop_cmd["left_wrist"]
|
|
right_wrist_input_pose = teleop_cmd["right_wrist"]
|
|
ik_wrapper = self.base_env
|
|
ik_wrapper.set_target_poses_outside_env([left_wrist_input_pose, right_wrist_input_pose])
|
|
|
|
def render(self):
|
|
if self.base_env.viewer is not None:
|
|
self.base_env.viewer.update()
|
|
if self.onscreen:
|
|
self.base_env.render()
|
|
|
|
def queue_action(self, action: Dict[str, any]):
|
|
# action is in pinocchio joint order, we need to convert it to actuator order
|
|
action_q = self.convert_q_to_actuated_joint_order(action["q"])
|
|
|
|
# Compute gravity compensation torques if enabled
|
|
tau_q = np.zeros_like(action_q)
|
|
if self.enable_gravity_compensation and self.robot_model is not None:
|
|
try:
|
|
# Get current robot configuration from cache (more efficient than observe())
|
|
raw_obs = self.cache["obs"]
|
|
|
|
# Convert from actuated joint order to joint order for Pinocchio
|
|
current_q_joint_order = self.robot_model.get_configuration_from_actuated_joints(
|
|
body_actuated_joint_values=raw_obs["body_q"],
|
|
left_hand_actuated_joint_values=raw_obs["left_hand_q"],
|
|
right_hand_actuated_joint_values=raw_obs["right_hand_q"],
|
|
)
|
|
|
|
# Compute gravity compensation in joint order using current robot configuration
|
|
gravity_torques_joint_order = self.robot_model.compute_gravity_compensation_torques(
|
|
current_q_joint_order, joint_groups=self.gravity_compensation_joints
|
|
)
|
|
|
|
# Convert gravity torques to actuated joint order
|
|
gravity_torques_actuated = self.convert_q_to_actuated_joint_order(
|
|
gravity_torques_joint_order
|
|
)
|
|
|
|
# Add gravity compensation to torques
|
|
tau_q += gravity_torques_actuated
|
|
|
|
except Exception as e:
|
|
print(f"Error applying gravity compensation in sync_env: {e}")
|
|
|
|
obs, reward, terminated, truncated, info = self.env.step({"q": action_q, "tau": tau_q})
|
|
self.cache["obs"] = obs
|
|
self.cache["reward"] = reward
|
|
self.cache["terminated"] = terminated
|
|
self.cache["truncated"] = truncated
|
|
self.cache["info"] = info
|
|
|
|
def queue_state(self, state: Dict[str, any]):
|
|
# This function is for debugging or cross-playback between sim and real only.
|
|
state_q = self.convert_q_to_actuated_joint_order(state["q"])
|
|
obs, reward, terminated, truncated, info = self.env.unwrapped.step_only_kinematics(
|
|
{"q": state_q}
|
|
)
|
|
self.cache["obs"] = obs
|
|
self.cache["reward"] = reward
|
|
self.cache["terminated"] = terminated
|
|
self.cache["truncated"] = truncated
|
|
self.cache["info"] = info
|
|
|
|
@property
|
|
def observation_space(self) -> gym.Space:
|
|
# @todo: check if the low and high bounds are correct for body_obs.
|
|
q_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,))
|
|
dq_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,))
|
|
ddq_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,))
|
|
tau_est_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,))
|
|
floating_base_pose_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7,))
|
|
floating_base_vel_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(6,))
|
|
floating_base_acc_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(6,))
|
|
wrist_pose_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7 + 7,))
|
|
|
|
obs_space = gym.spaces.Dict(
|
|
{
|
|
"floating_base_pose": floating_base_pose_space,
|
|
"floating_base_vel": floating_base_vel_space,
|
|
"floating_base_acc": floating_base_acc_space,
|
|
"q": q_space,
|
|
"dq": dq_space,
|
|
"ddq": ddq_space,
|
|
"tau_est": tau_est_space,
|
|
"wrist_pose": wrist_pose_space,
|
|
}
|
|
)
|
|
|
|
obs_space = prepare_gym_space_for_eval(self.robot_model, obs_space)
|
|
|
|
obs_space["annotation.human.task_description"] = gym.spaces.Text(
|
|
max_length=256, charset=ALLOWED_LANGUAGE_CHARSET
|
|
)
|
|
|
|
if hasattr(self.base_env, "get_privileged_obs_keys"):
|
|
for key, shape in self.base_env.get_privileged_obs_keys().items():
|
|
space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=shape)
|
|
obs_space[key] = space
|
|
|
|
robocasa_obs_space = self.env.observation_space
|
|
for key in robocasa_obs_space.keys():
|
|
if key.endswith("_image"):
|
|
space = gym.spaces.Box(
|
|
low=-np.inf, high=np.inf, shape=robocasa_obs_space[key].shape
|
|
)
|
|
obs_space[key] = space
|
|
# TODO: add video.key without _image suffix for evaluation, remove later
|
|
space_uint = gym.spaces.Box(low=0, high=255, shape=space.shape, dtype=np.uint8)
|
|
obs_space[f"video.{key.replace('_image', '')}"] = space_uint
|
|
|
|
return obs_space
|
|
|
|
def reset_obj_pos(self):
|
|
# For Tairan's goal-reaching task, a hacky way to reset the object position is needed.
|
|
if hasattr(self.base_env, "reset_obj_pos"):
|
|
self.base_env.reset_obj_pos()
|
|
|
|
@property
|
|
def action_space(self) -> gym.Space:
|
|
return self.env.action_space
|
|
|
|
def close(self):
|
|
self.env.close()
|
|
|
|
def __repr__(self):
|
|
return (
|
|
f"SyncEnv(env_name={self.env_name}, \n"
|
|
f" observation_space={self.observation_space}, \n"
|
|
f" action_space={self.action_space})"
|
|
)
|
|
|
|
def get_joint_gains(self):
|
|
controller = self.base_env.robots[0].composite_controller
|
|
|
|
gains = {}
|
|
key_mapping = {
|
|
"left": "left_arm",
|
|
"right": "right_arm",
|
|
"legs": "legs",
|
|
"torso": "waist",
|
|
"head": "neck",
|
|
}
|
|
for k in controller.part_controllers.keys():
|
|
if hasattr(controller.part_controllers[k], "kp"):
|
|
if k in key_mapping:
|
|
gains[key_mapping[k]] = controller.part_controllers[k].kp
|
|
else:
|
|
gains[k] = controller.part_controllers[k].kp
|
|
gains.update(
|
|
{
|
|
"left_hand": self.base_env.sim.model.actuator_gainprm[
|
|
self.base_env.robots[0]._ref_actuators_indexes_dict["left_gripper"], 0
|
|
],
|
|
"right_hand": self.base_env.sim.model.actuator_gainprm[
|
|
self.base_env.robots[0]._ref_actuators_indexes_dict["right_gripper"], 0
|
|
],
|
|
}
|
|
)
|
|
joint_gains = np.zeros(self.robot_model.num_dofs)
|
|
for k in gains.keys():
|
|
joint_gains[self.robot_model.get_joint_group_indices(k)] = gains[k]
|
|
return joint_gains
|
|
|
|
def get_joint_damping(self):
|
|
controller = self.base_env.robots[0].composite_controller
|
|
damping = {}
|
|
key_mapping = {
|
|
"left": "left_arm",
|
|
"right": "right_arm",
|
|
"legs": "legs",
|
|
"torso": "waist",
|
|
"head": "neck",
|
|
}
|
|
for k in controller.part_controllers.keys():
|
|
if hasattr(controller.part_controllers[k], "kd"):
|
|
if k in key_mapping:
|
|
damping[key_mapping[k]] = controller.part_controllers[k].kd
|
|
else:
|
|
damping[k] = controller.part_controllers[k].kd
|
|
damping.update(
|
|
{
|
|
"left_hand": -self.base_env.sim.model.actuator_biasprm[
|
|
self.base_env.robots[0]._ref_actuators_indexes_dict["left_gripper"], 2
|
|
],
|
|
"right_hand": -self.base_env.sim.model.actuator_biasprm[
|
|
self.base_env.robots[0]._ref_actuators_indexes_dict["right_gripper"], 2
|
|
],
|
|
}
|
|
)
|
|
joint_damping = np.zeros(self.robot_model.num_dofs)
|
|
for k in damping.keys():
|
|
joint_damping[self.robot_model.get_joint_group_indices(k)] = damping[k]
|
|
return joint_damping
|
|
|
|
def get_eef_obs(self, q: np.ndarray) -> Dict[str, np.ndarray]:
|
|
self.robot_model.cache_forward_kinematics(q)
|
|
eef_obs = {}
|
|
for side in ["left", "right"]:
|
|
wrist_placement = self.robot_model.frame_placement(
|
|
self.robot_model.supplemental_info.hand_frame_names[side]
|
|
)
|
|
wrist_pos, wrist_quat = wrist_placement.translation[:3], R.from_matrix(
|
|
wrist_placement.rotation
|
|
).as_quat(scalar_first=True)
|
|
eef_obs[f"{side}_wrist_pose"] = np.concatenate([wrist_pos, wrist_quat])
|
|
|
|
return eef_obs
|
|
|
|
|
|
class G1SyncEnv(SyncEnv):
|
|
def __init__(
|
|
self,
|
|
env_name,
|
|
robot_name,
|
|
**kwargs,
|
|
):
|
|
renderer = kwargs.get("renderer", "mjviewer")
|
|
if renderer == "mjviewer":
|
|
default_render_camera = ["robot0_oak_egoview"]
|
|
elif renderer in ["mujoco", "rerun"]:
|
|
default_render_camera = [
|
|
"robot0_oak_egoview",
|
|
"robot0_oak_left_monoview",
|
|
"robot0_oak_right_monoview",
|
|
]
|
|
else:
|
|
raise NotImplementedError
|
|
default_camera_names = [
|
|
"robot0_oak_egoview",
|
|
"robot0_oak_left_monoview",
|
|
"robot0_oak_right_monoview",
|
|
]
|
|
default_camera_heights = [
|
|
RS_VIEW_CAMERA_HEIGHT,
|
|
RS_VIEW_CAMERA_HEIGHT,
|
|
RS_VIEW_CAMERA_HEIGHT,
|
|
]
|
|
default_camera_widths = [
|
|
RS_VIEW_CAMERA_WIDTH,
|
|
RS_VIEW_CAMERA_WIDTH,
|
|
RS_VIEW_CAMERA_WIDTH,
|
|
]
|
|
|
|
kwargs.update(
|
|
{
|
|
"onscreen": kwargs.get("onscreen", True),
|
|
"offscreen": kwargs.get("offscreen", False),
|
|
"render_camera": kwargs.get("render_camera", default_render_camera),
|
|
"camera_names": kwargs.get("camera_names", default_camera_names),
|
|
"camera_heights": kwargs.get("camera_heights", default_camera_heights),
|
|
"camera_widths": kwargs.get("camera_widths", default_camera_widths),
|
|
"translucent_robot": kwargs.get("translucent_robot", False),
|
|
}
|
|
)
|
|
super().__init__(env_name=env_name, robot_name=robot_name, **kwargs)
|
|
|
|
# Initialize safety monitor (visualization disabled) - G1 specific
|
|
self.safety_monitor = JointSafetyMonitor(
|
|
self.robot_model,
|
|
enable_viz=False,
|
|
env_type="sim", # G1SyncEnv is always simulation
|
|
)
|
|
self.safety_monitor.ramp_duration_steps = 0
|
|
self.safety_monitor.startup_complete = True
|
|
self.safety_monitor.LOWER_BODY_VELOCITY_LIMIT = (
|
|
1e10 # disable lower body velocity limits since it impacts WBC
|
|
)
|
|
self.last_safety_ok = True # Track last safety status from queue_action
|
|
|
|
@property
|
|
def observation_space(self):
|
|
obs_space = super().observation_space
|
|
obs_space["torso_quat"] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,))
|
|
obs_space["torso_ang_vel"] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(3,))
|
|
return obs_space
|
|
|
|
def observe(self):
|
|
obs = super().observe()
|
|
obs["torso_quat"] = self.cache["obs"]["secondary_imu_quat"]
|
|
obs["torso_ang_vel"] = self.cache["obs"]["secondary_imu_vel"][3:6]
|
|
return obs
|
|
|
|
def queue_action(self, action: Dict[str, any]):
|
|
# Safety check before queuing action
|
|
obs = self.observe()
|
|
safety_result = self.safety_monitor.handle_violations(obs, action)
|
|
action = safety_result["action"]
|
|
# Save safety status for efficient access
|
|
self.last_safety_ok = not safety_result.get("shutdown_required", False)
|
|
# Check if shutdown is required
|
|
if safety_result["shutdown_required"]:
|
|
self.safety_monitor.trigger_system_shutdown()
|
|
|
|
# Call parent queue_action with potentially modified action
|
|
super().queue_action(action)
|
|
|
|
def get_joint_safety_status(self) -> bool:
|
|
"""Get current joint safety status from the last queue_action safety check.
|
|
|
|
Returns:
|
|
bool: True if joints are safe (no shutdown required), False if unsafe
|
|
"""
|
|
return self.last_safety_ok
|
|
|
|
|
|
def create_gym_sync_env_class(env, robot, robot_alias, wbc_version):
|
|
class_name = f"{env}_{robot}_{wbc_version}"
|
|
id_name = f"gr00tlocomanip_{robot_alias}/{class_name}"
|
|
|
|
if robot_alias.startswith("g1"):
|
|
env_class_type = G1SyncEnv
|
|
elif robot_alias.startswith("gr1"):
|
|
env_class_type = globals().get("GR1SyncEnv", SyncEnv)
|
|
else:
|
|
env_class_type = SyncEnv
|
|
|
|
controller_configs = update_robosuite_controller_configs(
|
|
robot=robot,
|
|
wbc_version=wbc_version,
|
|
)
|
|
|
|
env_class_type = type(
|
|
class_name,
|
|
(env_class_type,),
|
|
{
|
|
"__init__": lambda self, **kwargs: super(self.__class__, self).__init__(
|
|
env_name=env,
|
|
robot_name=robot,
|
|
controller_configs=controller_configs,
|
|
**kwargs,
|
|
)
|
|
},
|
|
)
|
|
|
|
current_module = sys.modules["gr00t_wbc.control.envs.robocasa.sync_env"]
|
|
setattr(current_module, class_name, env_class_type)
|
|
register(
|
|
id=id_name, # Unique ID for the environment
|
|
entry_point=f"gr00t_wbc.control.envs.robocasa.sync_env:{class_name}",
|
|
)
|
|
|
|
|
|
WBC_VERSION = "gear_wbc"
|
|
|
|
for ENV in REGISTERED_LOCOMANIPULATION_ENVS:
|
|
for ROBOT, ROBOT_ALIAS in GR00T_LOCOMANIP_ENVS_ROBOTS.items():
|
|
create_gym_sync_env_class(ENV, ROBOT, ROBOT_ALIAS, WBC_VERSION)
|
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
env = gym.make("gr00tlocomanip_g1_sim/PnPBottle_g1_gear_wbc")
|
|
print(env.observation_space)
|