You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

772 lines
31 KiB

import argparse
import pathlib
from pathlib import Path
import threading
from threading import Lock, Thread
from typing import Any, Dict, Optional

import mujoco
import mujoco.viewer
import numpy as np
import rclpy
from unitree_sdk2py.core.channel import ChannelFactoryInitialize
import yaml

from decoupled_wbc.control.envs.g1.sim.image_publish_utils import ImagePublishProcess
from decoupled_wbc.control.envs.g1.sim.metric_utils import check_contact, check_height
from decoupled_wbc.control.envs.g1.sim.sim_utilts import get_subtree_body_names
from decoupled_wbc.control.envs.g1.sim.unitree_sdk2py_bridge import ElasticBand, UnitreeSdk2Bridge
# Directory six levels above this file's resolved location (presumably the repository
# root — verify against the package layout). Scene paths taken from the config
# (config["ROBOT_SCENE"]) are joined onto it when the MuJoCo model is loaded.
DECOUPLED_WBC_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent.parent
class DefaultEnv:
    """Base environment class that handles simulation environment setup and step.

    The environment owns the MuJoCo model/data pair, the optional passive viewer,
    the optional offscreen renderers and the optional image publishing subprocess.
    Motor commands are read from a Unitree SDK bridge that the simulator injects
    through :meth:`set_unitree_bridge`.

    Indexing convention used throughout the class: a joint id ``j`` is mapped to
    ``qpos[j + 7 - 1]``, ``qvel[j + 6 - 1]`` and actuator ``j - 1``.
    NOTE(review): this assumes joint 0 is a free (floating-base) joint and every
    other joint has a single degree of freedom — confirm against the scene XML.
    """

    # Scene for which the extra "global_view" camera is registered automatically.
    _GLOBAL_VIEW_SCENE = "decoupled_wbc/control/robot_model/model_data/g1/scene_29dof.xml"
    # A joint whose name contains one of these substrings is treated as a body joint.
    _BODY_JOINT_KEYWORDS = ("hip", "knee", "ankle", "waist", "shoulder", "elbow", "wrist")

    def __init__(
        self,
        config: Dict[str, Any],
        env_name: str = "default",
        camera_configs: Optional[Dict[str, Any]] = None,
        onscreen: bool = False,
        offscreen: bool = False,
        enable_image_publish: bool = False,
    ):
        """Create the environment and load the scene.

        Args:
            config: Simulation configuration. Keys read here: ``ROBOT_SCENE``,
                ``NUM_JOINTS``, ``NUM_HAND_JOINTS``, ``SIMULATE_DT``,
                ``motor_effort_limit_list`` and optionally ``IMAGE_DT``.
            env_name: Name stored on the instance.
            camera_configs: Mapping of camera name to a dict with at least
                ``height`` and ``width`` (and optionally ``params``). ``None``
                means no cameras. The mapping is copied, never mutated.
            onscreen: Launch the passive MuJoCo viewer.
            offscreen: Create one offscreen renderer per configured camera.
            enable_image_publish: Accepted for interface compatibility; it is
                not read by this class.
        """
        # Work on a private copy: the original mutated its (shared, mutable) default
        # argument, leaking "global_view" into later instances and into caller dicts.
        camera_configs = dict(camera_configs) if camera_configs else {}
        # global_view is only set up for this specific scene for now.
        if config["ROBOT_SCENE"] == self._GLOBAL_VIEW_SCENE:
            camera_configs["global_view"] = {
                "height": 400,
                "width": 400,
            }
        self.config = config
        self.env_name = env_name
        self.num_body_dof = self.config["NUM_JOINTS"]
        self.num_hand_dof = self.config["NUM_HAND_JOINTS"]
        self.sim_dt = self.config["SIMULATE_DT"]
        self.obs = None
        self.torques = np.zeros(self.num_body_dof + self.num_hand_dof * 2)
        self.torque_limit = np.array(self.config["motor_effort_limit_list"])
        self.camera_configs = camera_configs
        # Thread safety lock
        self.reward_lock = Lock()
        # Unitree bridge will be initialized by the simulator
        self.unitree_bridge = None
        # Store display mode
        self.onscreen = onscreen
        # Defined up front so that code paths which do not enable the elastic band
        # or offscreen rendering can still test for them safely.
        self.elastic_band = None
        self.band_attached_link = None
        self.renderers = {}
        # Initialize scene (may be overridden by subclasses)
        self.init_scene()
        self.last_reward = 0
        # Setup offscreen rendering if needed
        self.offscreen = offscreen
        if self.offscreen:
            self.init_renderers()
        self.image_dt = self.config.get("IMAGE_DT", 0.033333)
        self.image_publish_process = None

    def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555):
        """Start the subprocess that publishes rendered camera images.

        Does nothing (besides printing a warning) when no cameras are configured.

        Args:
            start_method: Multiprocessing start method, used when the config does
                not define ``MP_START_METHOD``.
            camera_port: Port passed to the publisher as ``zmq_port``.
        """
        if not self.camera_configs:
            print(
                "Warning: No camera configs provided, image publishing subprocess will not be started"
            )
            return
        # The configured method wins; the argument is the fallback. (The original
        # ignored the argument and always fell back to "spawn".)
        start_method = self.config.get("MP_START_METHOD", start_method)
        self.image_publish_process = ImagePublishProcess(
            camera_configs=self.camera_configs,
            image_dt=self.image_dt,
            zmq_port=camera_port,
            start_method=start_method,
            verbose=self.config.get("verbose", False),
        )
        self.image_publish_process.start_process()

    def _band_link_name(self) -> str:
        """Return the name of the body the elastic band is attached to."""
        robot_type = self.config["ROBOT_TYPE"]
        if "g1" in robot_type:
            return "pelvis" if self.config["enable_waist"] else "torso_link"
        if "h1" in robot_type:
            return "torso_link"
        return "base_link"

    def init_scene(self):
        """Initialize the default robot scene.

        Loads the MuJoCo model named by ``config["ROBOT_SCENE"]``, optionally sets
        up the elastic band and the passive viewer, and classifies the model's
        joints into body / left-hand / right-hand index arrays.
        """
        scene_path = pathlib.Path(DECOUPLED_WBC_ROOT) / self.config["ROBOT_SCENE"]
        self.mj_model = mujoco.MjModel.from_xml_path(str(scene_path))
        self.mj_data = mujoco.MjData(self.mj_model)
        self.mj_model.opt.timestep = self.sim_dt
        self.torso_index = mujoco.mj_name2id(self.mj_model, mujoco.mjtObj.mjOBJ_BODY, "torso_link")
        self.root_body = "pelvis"

        viewer_kwargs = {"show_left_ui": False, "show_right_ui": False}
        if self.config["ENABLE_ELASTIC_BAND"]:
            # Enable the elastic band
            self.elastic_band = ElasticBand()
            self.band_attached_link = self.mj_model.body(self._band_link_name()).id
            # (sic) the callback name is spelled this way by the bridge module.
            viewer_kwargs["key_callback"] = self.elastic_band.MujuocoKeyCallback
        else:
            self.elastic_band = None
            self.band_attached_link = None

        if self.onscreen:
            self.viewer = mujoco.viewer.launch_passive(
                self.mj_model, self.mj_data, **viewer_kwargs
            )
        else:
            mujoco.mj_forward(self.mj_model, self.mj_data)
            self.viewer = None

        if self.viewer is not None:
            # viewer camera
            self.viewer.cam.azimuth = 120  # Horizontal rotation in degrees
            self.viewer.cam.elevation = -30  # Vertical tilt in degrees
            self.viewer.cam.distance = 2.0  # Distance from camera to target
            self.viewer.cam.lookat = np.array([0, 0, 0.5])  # Point the camera is looking at

        # Note that the actuator order is the same as the joint order in the mujoco model.
        body_ids, left_ids, right_ids = [], [], []
        for joint_id in range(self.mj_model.njnt):
            joint_name = self.mj_model.joint(joint_id).name
            if any(keyword in joint_name for keyword in self._BODY_JOINT_KEYWORDS):
                body_ids.append(joint_id)
            elif "left_hand" in joint_name:
                left_ids.append(joint_id)
            elif "right_hand" in joint_name:
                right_ids.append(joint_id)
        assert len(body_ids) == self.config["NUM_JOINTS"], (
            f"expected {self.config['NUM_JOINTS']} body joints, found {len(body_ids)}"
        )
        assert len(left_ids) == self.config["NUM_HAND_JOINTS"], (
            f"expected {self.config['NUM_HAND_JOINTS']} left hand joints, found {len(left_ids)}"
        )
        assert len(right_ids) == self.config["NUM_HAND_JOINTS"], (
            f"expected {self.config['NUM_HAND_JOINTS']} right hand joints, found {len(right_ids)}"
        )
        # Explicit integer dtype: an empty list would otherwise become a float64
        # array, which numpy refuses to use as an index.
        self.body_joint_index = np.array(body_ids, dtype=int)
        self.left_hand_index = np.array(left_ids, dtype=int)
        self.right_hand_index = np.array(right_ids, dtype=int)

    def init_renderers(self):
        """Create one offscreen ``mujoco.Renderer`` per configured camera."""
        self.renderers = {
            camera_name: mujoco.Renderer(
                self.mj_model, height=camera_config["height"], width=camera_config["width"]
            )
            for camera_name, camera_config in self.camera_configs.items()
        }

    @staticmethod
    def _pd_torque(motor_cmd, q, dq):
        """Feed-forward torque plus PD feedback on position ``q`` and velocity ``dq``."""
        return (
            motor_cmd.tau
            + motor_cmd.kp * (motor_cmd.q - q)
            + motor_cmd.kd * (motor_cmd.dq - dq)
        )

    def compute_body_torques(self) -> np.ndarray:
        """Compute body torques based on the current robot state.

        Returns:
            Array of length ``num_body_dof``; all zeros when no bridge is set or
            the bridge has no low-level command yet.
        """
        body_torques = np.zeros(self.num_body_dof)
        bridge = self.unitree_bridge
        if bridge is None or not bridge.low_cmd:
            return body_torques
        low_cmd = bridge.low_cmd
        num_motor = bridge.num_body_motor
        for i in range(num_motor):
            if bridge.use_sensor:
                # Sensor layout read here: positions first, then velocities,
                # each block num_body_motor long.
                q = self.mj_data.sensordata[i]
                dq = self.mj_data.sensordata[i + num_motor]
            else:
                q = self.mj_data.qpos[self.body_joint_index[i] + 7 - 1]
                dq = self.mj_data.qvel[self.body_joint_index[i] + 6 - 1]
            body_torques[i] = self._pd_torque(low_cmd.motor_cmd[i], q, dq)
        return body_torques

    def compute_hand_torques(self) -> np.ndarray:
        """Compute hand torques based on the current robot state.

        Returns:
            Concatenation ``[left, right]`` of length ``2 * num_hand_dof``; all
            zeros when no bridge is set or the bridge has no low-level command.
        """
        left_hand_torques = np.zeros(self.num_hand_dof)
        right_hand_torques = np.zeros(self.num_hand_dof)
        bridge = self.unitree_bridge
        # As in the original, availability is gated on the body command (low_cmd).
        if bridge is not None and bridge.low_cmd:
            for i in range(bridge.num_hand_motor):
                left_hand_torques[i] = self._pd_torque(
                    bridge.left_hand_cmd.motor_cmd[i],
                    self.mj_data.qpos[self.left_hand_index[i] + 7 - 1],
                    self.mj_data.qvel[self.left_hand_index[i] + 6 - 1],
                )
                right_hand_torques[i] = self._pd_torque(
                    bridge.right_hand_cmd.motor_cmd[i],
                    self.mj_data.qpos[self.right_hand_index[i] + 7 - 1],
                    self.mj_data.qvel[self.right_hand_index[i] + 6 - 1],
                )
        return np.concatenate((left_hand_torques, right_hand_torques))

    def compute_body_qpos(self) -> np.ndarray:
        """Compute body joint positions based on the current command.

        Returns:
            Commanded positions, length ``num_body_dof`` (zeros without a command).
        """
        body_qpos = np.zeros(self.num_body_dof)
        bridge = self.unitree_bridge
        if bridge is not None and bridge.low_cmd:
            for i in range(bridge.num_body_motor):
                body_qpos[i] = bridge.low_cmd.motor_cmd[i].q
        return body_qpos

    def compute_hand_qpos(self) -> np.ndarray:
        """Compute hand joint positions based on the current command.

        Returns:
            Commanded positions ``[left, right]``, length ``2 * num_hand_dof``
            (zeros without a command).
        """
        hand_qpos = np.zeros(self.num_hand_dof * 2)
        bridge = self.unitree_bridge
        if bridge is not None and bridge.low_cmd:
            for i in range(bridge.num_hand_motor):
                hand_qpos[i] = bridge.left_hand_cmd.motor_cmd[i].q
                hand_qpos[i + self.num_hand_dof] = bridge.right_hand_cmd.motor_cmd[i].q
        return hand_qpos

    def prepare_obs(self) -> Dict[str, Any]:
        """Prepare observation dictionary from the current robot state.

        NOTE(review): the floating-base entries are plain slices of the MuJoCo
        buffers (presumably views that change on the next step), while the
        index-array lookups produce copies — confirm before storing them long-term.
        """
        data = self.mj_data
        obs = {
            "floating_base_pose": data.qpos[:7],
            "floating_base_vel": data.qvel[:6],
            "floating_base_acc": data.qacc[:6],
            "secondary_imu_quat": data.xquat[self.torso_index],
            "secondary_imu_vel": data.cvel[self.torso_index],
            "body_q": data.qpos[self.body_joint_index + 7 - 1],
            "body_dq": data.qvel[self.body_joint_index + 6 - 1],
            "body_ddq": data.qacc[self.body_joint_index + 6 - 1],
            "body_tau_est": data.actuator_force[self.body_joint_index - 1],
        }
        if self.num_hand_dof > 0:
            for side, index in (
                ("left_hand", self.left_hand_index),
                ("right_hand", self.right_hand_index),
            ):
                obs[f"{side}_q"] = data.qpos[index + 7 - 1]
                obs[f"{side}_dq"] = data.qvel[index + 6 - 1]
                obs[f"{side}_ddq"] = data.qacc[index + 6 - 1]
                obs[f"{side}_tau_est"] = data.actuator_force[index - 1]
        obs["time"] = data.time
        return obs

    def _apply_elastic_band(self):
        """Write the elastic band force for this step into ``xfrc_applied``.

        No-op when the band is disabled in the config; writes zeros when the band
        exists but is currently switched off.
        """
        if not self.config["ENABLE_ELASTIC_BAND"]:
            return
        link = self.band_attached_link
        if not self.elastic_band.enable:
            # explicitly resetting the force when the band is not enabled
            self.mj_data.xfrc_applied[link] = np.zeros(6)
            return
        # Get Cartesian pose and velocity of the band_attached_link
        pose = np.concatenate(
            [
                self.mj_data.xpos[link],  # link position in world
                self.mj_data.xquat[link],  # link quaternion in world [w,x,y,z]
                np.zeros(6),  # placeholder for velocity
            ]
        )
        # Get velocity in world frame (written in place into pose[7:13])
        mujoco.mj_objectVelocity(
            self.mj_model,
            self.mj_data,
            mujoco.mjtObj.mjOBJ_BODY,
            link,
            pose[7:13],
            0,  # 0 for world frame
        )
        # Reorder velocity from [ang, lin] to [lin, ang]
        angular = pose[7:10].copy()
        pose[7:10] = pose[10:13]
        pose[10:13] = angular
        self.mj_data.xfrc_applied[link] = self.elastic_band.Advance(pose)

    def _publish_state(self, obs):
        """Publish the low-level state (and joystick state, if any) over the bridge."""
        if self.unitree_bridge is None:
            return
        self.unitree_bridge.PublishLowState(obs)
        if self.unitree_bridge.joystick:
            self.unitree_bridge.PublishWirelessController()

    def sim_step(self):
        """Publish the current state, apply commanded torques and advance one step."""
        self.obs = self.prepare_obs()
        # Guarded like kinematics_step: without a bridge the torques are simply zero.
        self._publish_state(self.obs)
        self._apply_elastic_band()
        body_torques = self.compute_body_torques()
        hand_torques = self.compute_hand_torques()
        self.torques[self.body_joint_index - 1] = body_torques
        if self.num_hand_dof > 0:
            self.torques[self.left_hand_index - 1] = hand_torques[: self.num_hand_dof]
            self.torques[self.right_hand_index - 1] = hand_torques[self.num_hand_dof :]
        self.torques = np.clip(self.torques, -self.torque_limit, self.torque_limit)
        if self.config["FREE_BASE"]:
            # Six zero controls are prepended ahead of the joint torques.
            self.mj_data.ctrl = np.concatenate((np.zeros(6), self.torques))
        else:
            self.mj_data.ctrl = self.torques
        mujoco.mj_step(self.mj_model, self.mj_data)
        # self.check_self_collision()

    def kinematics_step(self):
        """
        Run kinematics only: compute the qpos of the robot and directly set the qpos.
        For debugging purposes.
        """
        if self.unitree_bridge is not None:
            self._publish_state(self.prepare_obs())
        self._apply_elastic_band()
        body_qpos = self.compute_body_qpos()  # (num_body_dof,)
        hand_qpos = self.compute_hand_qpos()  # (num_hand_dof * 2,)
        self.mj_data.qpos[self.body_joint_index + 7 - 1] = body_qpos
        # Same guard as prepare_obs / sim_step for robots without hands.
        if self.num_hand_dof > 0:
            self.mj_data.qpos[self.left_hand_index + 7 - 1] = hand_qpos[: self.num_hand_dof]
            self.mj_data.qpos[self.right_hand_index + 7 - 1] = hand_qpos[self.num_hand_dof :]
        mujoco.mj_kinematics(self.mj_model, self.mj_data)
        mujoco.mj_comPos(self.mj_model, self.mj_data)

    def apply_perturbation(self, key):
        """Apply a velocity perturbation to the robot base.

        Args:
            key: ``"up"``/``"down"`` push forward/backward and ``"left"``/``"right"``
                push sideways, expressed in the body frame; any other key adds zero.
        """
        # Velocity perturbations in body frame: (x forward/backward, y left/right)
        body_frame_push = {
            "up": (1.0, 0.0),  # forward
            "down": (-1.0, 0.0),  # backward
            "left": (0.0, 1.0),  # left
            "right": (0.0, -1.0),  # right
        }
        push_x, push_y = body_frame_push.get(key, (0.0, 0.0))
        # Transform body frame velocity to world frame using MuJoCo's rotation
        vel_body = np.array([push_x, push_y, 0.0])
        vel_world = np.zeros(3)
        base_quat = self.mj_data.qpos[3:7]  # [w, x, y, z] quaternion
        mujoco.mju_rotVecQuat(vel_world, vel_body, base_quat)
        # Apply to base linear velocity in world frame
        self.mj_data.qvel[0] += vel_world[0]  # world X velocity
        self.mj_data.qvel[1] += vel_world[1]  # world Y velocity
        # Update dynamics after velocity change
        mujoco.mj_forward(self.mj_model, self.mj_data)

    def update_viewer(self):
        """Sync the passive viewer with the simulation state, if a viewer exists."""
        if self.viewer is not None:
            self.viewer.sync()

    def update_viewer_camera(self):
        """Toggle the viewer camera between tracking and free mode."""
        if self.viewer is None:
            return
        if self.viewer.cam.type == mujoco.mjtCamera.mjCAMERA_TRACKING:
            self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FREE
        else:
            self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING

    def update_reward(self):
        """Calculate reward. Should be implemented by subclasses."""
        with self.reward_lock:
            self.last_reward = 0

    def get_reward(self):
        """Thread-safe way to get the last calculated reward."""
        with self.reward_lock:
            return self.last_reward

    def set_unitree_bridge(self, unitree_bridge):
        """Set the unitree bridge from the simulator"""
        self.unitree_bridge = unitree_bridge

    def get_privileged_obs(self):
        """Get privileged observation. Should be implemented by subclasses."""
        return {}

    def update_render_caches(self):
        """Update render cache and shared memory for subprocess.

        Returns:
            Dict mapping ``"<camera_name>_image"`` to the rendered frame. Cameras
            without a renderer (e.g. offscreen rendering disabled) are skipped,
            so the result may be empty.
        """
        render_caches = {}
        for camera_name, camera_config in self.camera_configs.items():
            renderer = self.renderers.get(camera_name)
            if renderer is None:
                continue
            camera = camera_config["params"] if "params" in camera_config else camera_name
            renderer.update_scene(self.mj_data, camera=camera)
            render_caches[camera_name + "_image"] = renderer.render()
        # Update shared memory if image publishing process is available
        if render_caches and self.image_publish_process is not None:
            self.image_publish_process.update_shared_memory(render_caches)
        return render_caches

    def handle_keyboard_button(self, key):
        """Dispatch a keyboard key to the elastic band, reset, camera or perturbation."""
        # getattr keeps this safe for subclasses whose init_scene never sets the band.
        elastic_band = getattr(self, "elastic_band", None)
        if elastic_band is not None:
            elastic_band.handle_keyboard_button(key)
        if key == "backspace":
            self.reset()
        if key == "v":
            self.update_viewer_camera()
        if key in ["up", "down", "left", "right"]:
            self.apply_perturbation(key)

    def check_fall(self):
        """Check if the robot has fallen (base height below 0.2) and reset if so."""
        base_height = self.mj_data.qpos[2]
        self.fall = bool(base_height < 0.2)
        if self.fall:
            print(f"Warning: Robot has fallen, height: {base_height:.3f} m")
            self.reset()

    def check_self_collision(self):
        """Check for self-collision of the robot.

        Returns:
            The collision flag reported by ``check_contact`` for the robot's own
            bodies (the subtree rooted at ``root_body``).
        """
        robot_bodies = get_subtree_body_names(self.mj_model, self.mj_model.body(self.root_body).id)
        self_collision, contact_bodies = check_contact(
            self.mj_model, self.mj_data, robot_bodies, robot_bodies, return_all_contact_bodies=True
        )
        if self_collision:
            print(f"Warning: Self-collision detected: {contact_bodies}")
        return self_collision

    def reset(self):
        """Reset the MuJoCo data to the model's default state."""
        mujoco.mj_resetData(self.mj_model, self.mj_data)
class CubeEnv(DefaultEnv):
    """Pick-and-place environment containing a cube object."""

    def __init__(
        self,
        config: Dict[str, any],
        onscreen: bool = False,
        offscreen: bool = False,
        enable_image_publish: bool = False,
    ):
        """Load the cube scene on top of a copy of ``config``.

        Args:
            config: Simulation configuration; copied so the caller's dict keeps
                its own ``ROBOT_SCENE`` value.
            onscreen: Forwarded to the base environment.
            offscreen: Forwarded to the base environment.
            enable_image_publish: Forwarded to the base environment.
        """
        scene_config = config.copy()  # leave the caller's config untouched
        # This environment always uses the cube scene.
        scene_config["ROBOT_SCENE"] = (
            "decoupled_wbc/control/robot_model/model_data/g1/pnp_cube_43dof.xml"
        )
        super().__init__(
            scene_config,
            "cube",
            {},
            onscreen,
            offscreen,
            enable_image_publish,
        )

    def update_reward(self):
        """Store the reward: right-hand contact with the cube AND the cube height check."""
        right_hand_links = [
            "right_hand_thumb_2_link",
            "right_hand_middle_1_link",
            "right_hand_index_1_link",
        ]
        touching = check_contact(self.mj_model, self.mj_data, right_hand_links, "cube_body")
        lifted = check_height(self.mj_model, self.mj_data, "cube", 0.85, 2.0)
        reward = touching & lifted
        with self.reward_lock:
            self.last_reward = reward
class BoxEnv(DefaultEnv):
    """Environment with a box object for manipulation tasks"""

    def __init__(
        self,
        config: Dict[str, Any],
        onscreen: bool = False,
        offscreen: bool = False,
        enable_image_publish: bool = False,
    ):
        """Load the box scene on top of a copy of ``config``.

        Args:
            config: Simulation configuration; copied so the caller's dict is not modified.
            onscreen: Forwarded to the base environment.
            offscreen: Forwarded to the base environment.
            enable_image_publish: Forwarded to the base environment.
        """
        # Override the robot scene
        config = config.copy()  # Create a copy to avoid modifying the original
        config["ROBOT_SCENE"] = "decoupled_wbc/control/robot_model/model_data/g1/lift_box_43dof.xml"
        super().__init__(config, "box", {}, onscreen, offscreen, enable_image_publish)

    def _evaluate(self):
        """Return ``(gripper_box_contact, box_lifted)`` for the current state."""
        left_hand_body = [
            "left_hand_thumb_2_link",
            "left_hand_middle_1_link",
            "left_hand_index_1_link",
        ]
        right_hand_body = [
            "right_hand_thumb_2_link",
            "right_hand_middle_1_link",
            "right_hand_index_1_link",
        ]
        # Both hands must be in contact with the box.
        gripper_box_contact = check_contact(self.mj_model, self.mj_data, left_hand_body, "box_body")
        gripper_box_contact &= check_contact(
            self.mj_model, self.mj_data, right_hand_body, "box_body"
        )
        box_lifted = check_height(self.mj_model, self.mj_data, "box", 0.92, 2.0)
        return gripper_box_contact, box_lifted

    def update_reward(self):
        """Calculate reward based on gripper contact with box and box height.

        This is the hook the simulator loop calls; without this override the
        base class implementation kept the box reward at 0 forever.
        """
        gripper_box_contact, box_lifted = self._evaluate()
        with self.reward_lock:
            self.last_reward = gripper_box_contact & box_lifted

    def reward(self):
        """Compute, print, store and return the reward.

        Kept for backward compatibility with callers of the original method.

        Returns:
            ``gripper_box_contact & box_lifted`` for the current state.
        """
        gripper_box_contact, box_lifted = self._evaluate()
        print("gripper_box_contact: ", gripper_box_contact, "box_lifted: ", box_lifted)
        reward = gripper_box_contact & box_lifted
        with self.reward_lock:
            self.last_reward = reward
        # Return the local value rather than re-reading shared state outside the lock.
        return reward
class BottleEnv(DefaultEnv):
    """Manipulation environment containing a bottle (cylinder) object."""

    def __init__(
        self,
        config: Dict[str, any],
        onscreen: bool = False,
        offscreen: bool = False,
        enable_image_publish: bool = False,
    ):
        """Load the bottle scene with an "egoview" camera.

        Args:
            config: Simulation configuration; copied so the caller's dict keeps
                its own ``ROBOT_SCENE`` value.
            onscreen: Forwarded to the base environment.
            offscreen: Forwarded to the base environment.
            enable_image_publish: Forwarded to the base environment.
        """
        scene_config = config.copy()  # leave the caller's config untouched
        # This environment always uses the bottle scene.
        scene_config["ROBOT_SCENE"] = (
            "decoupled_wbc/control/robot_model/model_data/g1/pnp_bottle_43dof.xml"
        )
        egoview_cameras = {"egoview": {"height": 400, "width": 400}}
        super().__init__(
            scene_config,
            "cylinder",
            egoview_cameras,
            onscreen,
            offscreen,
            enable_image_publish,
        )
        # Handles used when reporting the bottle pose.
        self.bottle_body = self.mj_model.body("bottle_body")
        self.bottle_geom = self.mj_model.geom("bottle")
        if self.viewer is None:
            return
        # With a viewer, look through the fixed "egoview" camera.
        self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
        self.viewer.cam.fixedcamid = self.mj_model.camera("egoview").id

    def update_reward(self):
        """Intentionally a no-op: no reward is computed here, so last_reward is left as is."""

    def get_privileged_obs(self):
        """Return the bottle body's world position and quaternion."""
        body_id = self.bottle_body.id
        return {
            "bottle_pos": self.mj_data.xpos[body_id],
            "bottle_quat": self.mj_data.xquat[self.bottle_body.id],
        }
class BaseSimulator:
    """Base simulator class that handles initialization and running of simulations.

    It owns the ROS 2 node used for rate control, the simulation environment
    selected by ``env_name`` and the Unitree SDK bridge handed to that environment.
    """

    def __init__(self, config: Dict[str, Any], env_name: str = "default", **kwargs):
        """Set up ROS 2, the environment, DDS and the Unitree bridge.

        Args:
            config: Simulation configuration. Keys read here: ``SIMULATE_DT``,
                ``DOMAIN_ID``, ``USE_JOYSTICK`` (plus ``JOYSTICK_DEVICE`` and
                ``JOYSTICK_TYPE`` when it is true) and optionally ``REWARD_DT``,
                ``IMAGE_DT``, ``VIEWER_DT`` and ``INTERFACE``.
            env_name: One of ``"default"``, ``"pnp_cube"``, ``"lift_box"``,
                ``"pnp_bottle"``.
            **kwargs: Forwarded to the environment constructor.

        Raises:
            ValueError: If ``env_name`` is not one of the supported names.
        """
        self.config = config
        self.env_name = env_name
        # Initialize ROS 2 node
        if not rclpy.ok():
            rclpy.init()
            self.node = rclpy.create_node("sim_mujoco")
            self.thread = threading.Thread(target=rclpy.spin, args=(self.node,), daemon=True)
            self.thread.start()
        else:
            # ROS is already running: reuse a node of the global executor.
            self.thread = None
            executor = rclpy.get_global_executor()
            self.node = executor.get_nodes()[0]  # will only take the first node
        # Create rate objects for different update frequencies
        self.sim_dt = self.config["SIMULATE_DT"]
        self.reward_dt = self.config.get("REWARD_DT", 0.02)
        self.image_dt = self.config.get("IMAGE_DT", 0.033333)
        self.viewer_dt = self.config.get("VIEWER_DT", 0.02)
        self.rate = self.node.create_rate(1 / self.sim_dt)
        # Create the appropriate environment based on name
        if env_name == "default":
            self.sim_env = DefaultEnv(config, env_name, **kwargs)
        elif env_name == "pnp_cube":
            self.sim_env = CubeEnv(config, **kwargs)
        elif env_name == "lift_box":
            self.sim_env = BoxEnv(config, **kwargs)
        elif env_name == "pnp_bottle":
            self.sim_env = BottleEnv(config, **kwargs)
        else:
            raise ValueError(f"Invalid environment name: {env_name}")
        # Initialize the DDS communication layer - should be safe to call multiple times
        try:
            if self.config.get("INTERFACE", None):
                ChannelFactoryInitialize(self.config["DOMAIN_ID"], self.config["INTERFACE"])
            else:
                ChannelFactoryInitialize(self.config["DOMAIN_ID"])
        except Exception as e:
            # If it fails because it's already initialized, that's okay
            print(f"Note: Channel factory initialization attempt: {e}")
        # Initialize the unitree bridge and pass it to the environment
        self.init_unitree_bridge()
        self.sim_env.set_unitree_bridge(self.unitree_bridge)
        # Initialize additional components
        self.init_subscriber()
        self.init_publisher()
        self.sim_thread = None

    def start_as_thread(self):
        """Run :meth:`start` in a new (non-daemon) thread stored in ``sim_thread``."""
        self.sim_thread = Thread(target=self.start)
        self.sim_thread.start()

    def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555):
        """Start the image publish subprocess"""
        self.sim_env.start_image_publish_subprocess(start_method, camera_port)

    def init_subscriber(self):
        """Initialize subscribers. Can be overridden by subclasses."""
        pass

    def init_publisher(self):
        """Initialize publishers. Can be overridden by subclasses."""
        pass

    def init_unitree_bridge(self):
        """Initialize the unitree SDK bridge (and its joystick when configured)."""
        self.unitree_bridge = UnitreeSdk2Bridge(self.config)
        if self.config["USE_JOYSTICK"]:
            self.unitree_bridge.SetupJoystick(
                device_id=self.config["JOYSTICK_DEVICE"], js_type=self.config["JOYSTICK_TYPE"]
            )

    @staticmethod
    def _steps_per_period(period: float, sim_dt: float) -> int:
        """Return how many simulation steps make up ``period`` (never less than 1).

        Rounding (instead of truncating) keeps ratios such as 2.9999999 at 3, and
        the lower bound of 1 avoids a modulo-by-zero when ``period < sim_dt``.
        """
        return max(1, int(round(period / sim_dt)))

    def start(self):
        """Main simulation loop.

        Runs while the viewer is open, or forever when there is no viewer. The
        viewer, the reward and the render caches are refreshed every
        ``viewer_dt`` / ``reward_dt`` / ``image_dt`` worth of simulation steps.
        """
        viewer_every = self._steps_per_period(self.viewer_dt, self.sim_dt)
        reward_every = self._steps_per_period(self.reward_dt, self.sim_dt)
        image_every = self._steps_per_period(self.image_dt, self.sim_dt)
        sim_cnt = 0
        try:
            while (
                self.sim_env.viewer and self.sim_env.viewer.is_running()
            ) or self.sim_env.viewer is None:
                # Run simulation step
                self.sim_env.sim_step()
                # Update viewer at viewer rate
                if sim_cnt % viewer_every == 0:
                    self.sim_env.update_viewer()
                # Calculate reward at reward rate
                if sim_cnt % reward_every == 0:
                    self.sim_env.update_reward()
                # Update render caches at image rate
                if sim_cnt % image_every == 0:
                    self.sim_env.update_render_caches()
                # Sleep to maintain correct rate
                self.rate.sleep()
                sim_cnt += 1
        except rclpy.exceptions.ROSInterruptException:
            # This is expected when ROS shuts down - exit cleanly
            pass
        except Exception:
            # Still best-effort (close and return), but no longer silent: the
            # original discarded the error, hiding the reason the loop stopped.
            import traceback

            print("Error: simulation loop stopped by an unexpected exception:")
            traceback.print_exc()
            self.close()

    def __del__(self):
        """Clean up resources when simulator is deleted"""
        self.close()

    def reset(self):
        """Reset the simulation. Can be overridden by subclasses."""
        self.sim_env.reset()

    def close(self):
        """Close the simulation. Can be overridden by subclasses.

        Stops the image publishing subprocess, closes the viewer and shuts ROS
        down. Errors are reported as a warning instead of being raised.
        """
        try:
            # sim_env is missing when __init__ failed early and __del__ calls close().
            sim_env = getattr(self, "sim_env", None)
            if sim_env is not None:
                # Stop image publishing subprocess
                image_process = getattr(sim_env, "image_publish_process", None)
                if image_process is not None:
                    image_process.stop()
                # Close viewer
                viewer = getattr(sim_env, "viewer", None)
                if viewer is not None:
                    viewer.close()
            # Shutdown ROS
            if rclpy.ok():
                rclpy.shutdown()
        except Exception as e:
            print(f"Warning during close: {e}")

    def get_privileged_obs(self):
        """Return the environment's privileged observation."""
        obs = self.sim_env.get_privileged_obs()
        # TODO: add ros2 topic to get privileged obs
        return obs

    def handle_keyboard_button(self, key):
        """Forward a keyboard key to the environment (default env only)."""
        # Only handles keyboard buttons for default env.
        if self.env_name == "default":
            self.sim_env.handle_keyboard_button(key)
if __name__ == "__main__":
    # Command line: a single optional --config pointing at the YAML configuration.
    arg_parser = argparse.ArgumentParser(description="Robot")
    arg_parser.add_argument(
        "--config",
        type=str,
        default="./decoupled_wbc/control/main/teleop/configs/g1_29dof_gear_wbc.yaml",
        help="config file",
    )
    cli_args = arg_parser.parse_args()

    # NOTE(review): yaml.load with FullLoader is kept as-is; if config files can come
    # from an untrusted source, consider yaml.safe_load instead.
    with open(cli_args.config, "r") as config_file:
        config = yaml.load(config_file, Loader=yaml.FullLoader)

    # Initialize DDS before building the simulator; the network interface is
    # passed only when the config names one.
    interface = config.get("INTERFACE", None)
    if not interface:
        ChannelFactoryInitialize(config["DOMAIN_ID"])
    else:
        ChannelFactoryInitialize(config["DOMAIN_ID"], config["INTERFACE"])

    # Default environment with default display options, run in a background thread.
    simulation = BaseSimulator(config)
    simulation.start_as_thread()