You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

301 lines
13 KiB

from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
import numpy as np
from robocasa.models.robots import remove_mimic_joints
from robosuite.models.robots import RobotModel as RobosuiteRobotModel
from decoupled_wbc.control.robot_model import RobotModel
class Gr00tJointInfo:
    """
    Mapping from decoupled_wbc actuated joint names to robocasa joint names.

    Body joints get the ``robot0_`` prefix and hand joints get
    ``gripper0_left_`` / ``gripper0_right_``. All orderings follow the indices
    returned by ``robot_model.get_joint_group_indices`` for the "body",
    "left_hand" and "right_hand" groups.

    Attributes:
        actuated_joint_names: prefixed names of every actuated joint (body and
            both hands), sorted by joint-group index.
        body_actuated_joint_to_index: prefixed body joint name -> position
            within the index-sorted body group.
        gripper_actuated_joint_to_index: ``(left, right)`` dicts mapping a
            prefixed hand joint name -> position within that hand's
            index-sorted group.
        actuated_joint_name_to_index: prefixed joint name -> position in
            ``actuated_joint_names``.
    """

    def __init__(self, robot_model: "RobotModel"):
        """
        Build all joint-name lookups from ``robot_model``.

        Args:
            robot_model: decoupled_wbc robot model. Its ``supplemental_info``
                must provide ``body_actuated_joints``,
                ``left_hand_actuated_joints`` and ``right_hand_actuated_joints``,
                and the model must implement ``get_joint_group_indices(name)``.

        Raises:
            ValueError: if ``robot_model.supplemental_info`` is None.
        """
        self.robocasa_body_prefix = "robot0_"
        self.robocasa_gripper_prefix = "gripper0_"
        self.robot_model: "RobotModel" = robot_model
        # Validate before the first supplemental_info attribute access, so a
        # missing config raises the intended ValueError instead of AttributeError.
        info = self._require_supplemental_info()
        self.body_actuated_joint_names: List[str] = info.body_actuated_joints
        self.left_hand_actuated_joint_names: List[str] = info.left_hand_actuated_joints
        self.right_hand_actuated_joint_names: List[str] = info.right_hand_actuated_joints
        self.actuated_joint_names: List[str] = self._get_gr00t_actuated_joint_names()
        self.body_actuated_joint_to_index: Dict[str, int] = (
            self._get_gr00t_body_actuated_joint_name_to_index()
        )
        self.gripper_actuated_joint_to_index: Tuple[Dict[str, int], Dict[str, int]] = (
            self._get_gr00t_gripper_actuated_joint_name_to_index()
        )
        self.actuated_joint_name_to_index: Dict[str, int] = (
            self._get_gr00t_actuated_joint_name_to_index()
        )

    def _require_supplemental_info(self):
        """Return ``robot_model.supplemental_info``; raise ValueError if it is None."""
        info = self.robot_model.supplemental_info
        if info is None:
            raise ValueError("Robot model must have supplemental_info")
        return info

    def _index_sorted_group(
        self, joint_names: List[str], group_name: str
    ) -> List[Tuple[int, str]]:
        """
        Pair ``joint_names`` with the model indices of ``group_name``, sorted by index.

        Names and indices are paired positionally (``zip``), so a length
        mismatch truncates to the shorter sequence.
        """
        group_indices = self.robot_model.get_joint_group_indices(group_name)
        return sorted(zip(group_indices, joint_names), key=lambda pair: pair[0])

    def _get_gr00t_actuated_joint_names(self) -> List[str]:
        """Get list of gr00t actuated joint names (prefixed) ordered by their indices."""
        info = self._require_supplemental_info()
        groups = (
            (info.body_actuated_joints, "body", self.robocasa_body_prefix),
            (info.left_hand_actuated_joints, "left_hand", self.robocasa_gripper_prefix + "left_"),
            (info.right_hand_actuated_joints, "right_hand", self.robocasa_gripper_prefix + "right_"),
        )
        # Keyed by joint-group index; if groups share an index, the later
        # group (body -> left -> right) wins.
        index_to_name: Dict[int, str] = {}
        for names, group_name, prefix in groups:
            for idx, name in self._index_sorted_group(names, group_name):
                index_to_name[idx] = prefix + name
        return [index_to_name[idx] for idx in sorted(index_to_name)]

    def _get_gr00t_body_actuated_joint_name_to_index(self) -> Dict[str, int]:
        """Map each prefixed body joint name to its position in the index-sorted body group."""
        info = self._require_supplemental_info()
        ordered = self._index_sorted_group(info.body_actuated_joints, "body")
        return {self.robocasa_body_prefix + name: pos for pos, (_, name) in enumerate(ordered)}

    def _get_gr00t_gripper_actuated_joint_name_to_index(
        self,
    ) -> Tuple[Dict[str, int], Dict[str, int]]:
        """Return ``(left, right)`` dicts: prefixed hand joint name -> position in that hand."""
        info = self._require_supplemental_info()
        left = self._index_sorted_group(info.left_hand_actuated_joints, "left_hand")
        right = self._index_sorted_group(info.right_hand_actuated_joints, "right_hand")
        left_prefix = self.robocasa_gripper_prefix + "left_"
        right_prefix = self.robocasa_gripper_prefix + "right_"
        return (
            {left_prefix + name: pos for pos, (_, name) in enumerate(left)},
            {right_prefix + name: pos for pos, (_, name) in enumerate(right)},
        )

    def _get_gr00t_actuated_joint_name_to_index(self) -> Dict[str, int]:
        """Map each prefixed actuated joint name to its position in ``actuated_joint_names``."""
        return {name: pos for pos, name in enumerate(self.actuated_joint_names)}
@dataclass
class Gr00tObsActionConverter:
    """
    Converter to align simulation environment joint action space with real environment joint action space.

    Handles joint order and range conversion between the decoupled_wbc (gr00t)
    actuated-joint order and the per-part joint/actuator order of the
    robosuite/robocasa robot.

    Attributes:
        robot_model: decoupled_wbc robot model; must have ``supplemental_info``.
        robosuite_robot_model: robosuite robot; its ``_ref_joints_indexes_dict``,
            ``_ref_actuators_indexes_dict``, ``sim.model``,
            ``composite_controller`` and ``default_base`` are read.
        robocasa_body_prefix: body joint-name prefix.
        robocasa_gripper_prefix: gripper joint-name prefix.
    """

    robot_model: "RobotModel"
    robosuite_robot_model: "RobosuiteRobotModel"
    # NOTE(review): these two prefixes are not passed to Gr00tJointInfo, which
    # uses its own "robot0_" / "gripper0_" prefixes, so overriding them does
    # not change any mapping built in __post_init__.
    robocasa_body_prefix: str = "robot0_"
    robocasa_gripper_prefix: str = "gripper0_"

    def __post_init__(self):
        """
        Initialize converter with robot configuration.

        Builds the gr00t joint-name/index mappings, the robocasa per-part
        joint and actuator names, the gr00t -> robocasa actuator index mapping,
        the per-robot sign multipliers and the DOF counts.

        Raises:
            ValueError: if ``robot_model.supplemental_info`` is None.
        """
        if self.robot_model.supplemental_info is None:
            raise ValueError("Robot model must have supplemental_info")
        self.robot_key = self.robot_model.supplemental_info.name
        self.gr00t_joint_info = Gr00tJointInfo(self.robot_model)
        joint_info = self.gr00t_joint_info

        self.robocasa_joint_names_for_each_part: Dict[str, List[str]] = (
            self._get_robocasa_joint_names_for_each_part()
        )
        self.robocasa_actuator_names_for_each_part: Dict[str, List[str]] = (
            self._get_robotcasa_actuator_names_for_each_part()
        )

        # Store mappings directly as instance attributes.
        self.gr00t_joint_name_to_index = joint_info.actuated_joint_name_to_index
        self.gr00t_body_joint_name_to_index = joint_info.body_actuated_joint_to_index
        left_map, right_map = joint_info.gripper_actuated_joint_to_index
        self.gr00t_gripper_joint_name_to_index = {"left": left_map, "right": right_map}
        self.gr00t_to_robocasa_actuator_indices = self._get_actuator_mapping()

        if self.robot_key == "GR1_Fourier":
            # Element-wise +/-1 sign flips for this robot. actuator_multiplier is
            # applied to gripper actions in gr00t_to_robocasa_action_dict;
            # joint_multiplier is not used within this class.
            joint_signs = np.array([-1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1])
            actuator_signs = np.array([-1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1])
            self.joint_multiplier = lambda x: joint_signs * x
            self.actuator_multiplier = lambda x: actuator_signs * x
        else:
            self.joint_multiplier = lambda x: x
            self.actuator_multiplier = lambda x: x

        # DOF counts; actuator counts (``*_nu``) equal the joint counts.
        self.body_dof = len(joint_info.body_actuated_joint_names)
        self.gripper_dof = len(joint_info.left_hand_actuated_joint_names) + len(
            joint_info.right_hand_actuated_joint_names
        )
        self.whole_dof = self.body_dof + self.gripper_dof
        self.body_nu = self.body_dof
        self.gripper_nu = self.gripper_dof
        self.whole_nu = self.whole_dof

    def _get_robocasa_joint_names_for_each_part(self) -> Dict[str, List[str]]:
        """Map each robosuite part name to the MuJoCo names of its joints."""
        robot = self.robosuite_robot_model
        return {
            part_name: [robot.sim.model.joint_id2name(j) for j in joint_indices]
            for part_name, joint_indices in robot._ref_joints_indexes_dict.items()
        }

    def _get_robotcasa_actuator_names_for_each_part(self) -> Dict[str, List[str]]:
        """Map each robosuite part name (except "base") to the MuJoCo names of its actuators."""
        robot = self.robosuite_robot_model
        names_by_part: Dict[str, List[str]] = {}
        for part_name, actuator_indices in robot._ref_actuators_indexes_dict.items():
            if part_name == "base":
                # Skipped so no gr00t indices are ever mapped for the base part.
                continue
            names_by_part[part_name] = [
                robot.sim.model.actuator_id2name(a) for a in actuator_indices
            ]
        return names_by_part

    def _get_actuator_mapping(self) -> Dict[str, List[int]]:
        """
        Get mapping from decoupled_wbc actuated joint order to robocasa actuated joint order for whole body.

        For each robocasa part, returns the gr00t indices of that part's
        actuators in robocasa order, so ``action_vec[indices]`` yields the
        part's action. Looks robocasa actuator names up in the gr00t joint-name
        table, so it relies on the two naming schemes matching (KeyError
        otherwise).
        """
        name_to_index = self.gr00t_joint_info.actuated_joint_name_to_index
        return {
            part_name: [name_to_index[actuator] for actuator in actuator_names]
            for part_name, actuator_names in self.robocasa_actuator_names_for_each_part.items()
        }

    def check_action_dim_match(self, vec_dim: int) -> bool:
        """
        Check if input vector dimension matches expected dimension.

        Args:
            vec_dim: Dimension of input vector

        Returns:
            bool: True if ``vec_dim == whole_dof``
        """
        return vec_dim == self.whole_dof

    def gr00t_to_robocasa_action_dict(self, action_vec: np.ndarray) -> Dict[str, Any]:
        """
        Convert gr00t flat action vector to robocasa dictionary mapping part names to actions.

        Args:
            action_vec: Full action vector in gr00t actuated joint order, of
                length ``whole_dof``. Plain lists/tuples are converted to a
                numpy array so per-part index lists can be applied.

        Returns:
            dict: Mapping from composite-controller part names to action
            vectors for robocasa. Gripper parts have ``actuator_multiplier``
            applied and mimic joints removed; base parts get no entry.

        Raises:
            ValueError: if ``len(action_vec) != whole_dof``.
        """
        if isinstance(action_vec, (list, tuple)):
            # Fancy indexing with an index list below needs an ndarray.
            action_vec = np.asarray(action_vec)
        if not self.check_action_dim_match(len(action_vec)):
            raise ValueError(
                f"Action vector dimension mismatch: {len(action_vec)} != {self.whole_dof}"
            )
        action_dict = {}
        cc = self.robosuite_robot_model.composite_controller
        for part_name in cc.part_controllers:
            if "gripper" in part_name:
                robocasa_action = action_vec[self.gr00t_to_robocasa_actuator_indices[part_name]]
                if self.actuator_multiplier is not None:
                    robocasa_action = self.actuator_multiplier(robocasa_action)
                action_dict[part_name] = remove_mimic_joints(
                    cc.grippers[part_name], robocasa_action
                )
            elif "base" in part_name:
                # No action is produced for the base; require that no gr00t
                # actuators were mapped to it unless it is a floating legged base.
                assert (
                    len(self.gr00t_to_robocasa_actuator_indices.get(part_name, [])) == 0
                    or self.robosuite_robot_model.default_base == "FloatingLeggedBase"
                )
            else:
                action_dict[part_name] = action_vec[
                    self.gr00t_to_robocasa_actuator_indices[part_name]
                ]
        return action_dict

    def robocasa_to_gr00t_actuated_order(
        self, joint_names: List[str], q: np.ndarray, obs_type: str = "body"
    ) -> np.ndarray:
        """
        Convert observation from robocasa joint order to gr00t actuated joint order.

        Args:
            joint_names: List of joint names in robocasa order (with prefixes)
            q: Joint positions corresponding to joint_names
            obs_type: Type of observation ("body", "left_gripper", "right_gripper", or "whole")

        Returns:
            Float array in gr00t actuated joint order for the selected group;
            entries not named in ``joint_names`` are left at 0.

        Raises:
            AssertionError: if ``joint_names`` and ``q`` differ in length.
            ValueError: for an unknown ``obs_type``.
        """
        assert len(joint_names) == len(q), "Joint names and q must have the same length"
        if obs_type == "body":
            name_to_index = self.gr00t_body_joint_name_to_index
            size = self.body_dof
        elif obs_type == "left_gripper":
            # Each hand is sized by its own mapping: the hands need not have
            # the same number of joints, so gripper_dof // 2 would be wrong.
            name_to_index = self.gr00t_gripper_joint_name_to_index["left"]
            size = len(name_to_index)
        elif obs_type == "right_gripper":
            name_to_index = self.gr00t_gripper_joint_name_to_index["right"]
            size = len(name_to_index)
        elif obs_type == "whole":
            name_to_index = self.gr00t_joint_name_to_index
            size = self.whole_dof
        else:
            raise ValueError(f"Unknown observation type: {obs_type}")
        actuated_q = np.zeros(size)
        for jn, value in zip(joint_names, q):
            actuated_q[name_to_index[jn]] = value
        return actuated_q

    def gr00t_to_robocasa_joint_order(
        self, joint_names: List[str], q_in_actuated_order: np.ndarray
    ) -> np.ndarray:
        """
        Convert gr00t actuated joint order to robocasa joint order.

        Args:
            joint_names: List of joint names in robocasa order (with prefixes)
            q_in_actuated_order: Whole-body joint positions in gr00t actuated joint order

        Returns:
            Float array of joint positions in robocasa order, one per name in
            ``joint_names``.
        """
        indices = [self.gr00t_joint_name_to_index[jn] for jn in joint_names]
        return np.asarray(q_in_actuated_order, dtype=np.float64)[indices]