You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
305 lines
11 KiB
305 lines
11 KiB
from abc import abstractmethod
|
|
import threading
|
|
import time
|
|
from typing import Any, Dict, Tuple
|
|
|
|
import mujoco
|
|
import numpy as np
|
|
import rclpy
|
|
|
|
from gr00t_wbc.control.envs.g1.sim.image_publish_utils import ImagePublishProcess
|
|
from gr00t_wbc.control.envs.robocasa.utils.robocasa_env import (
|
|
Gr00tLocomanipRoboCasaEnv,
|
|
) # noqa: F401
|
|
from gr00t_wbc.control.robot_model.robot_model import RobotModel
|
|
from gr00t_wbc.control.utils.keyboard_dispatcher import KeyboardListenerSubscriber
|
|
|
|
|
|
class RoboCasaEnvServer:
|
|
"""
|
|
This class is responsible for running the simulation environment loop in a separate thread.
|
|
It communicates with the main thread via the `publish_obs` and `get_action` methods through `channel_bridge`.
|
|
It will also handle the viewer sync when `onscreen` is True.
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
env_name: str,
|
|
robot_name: str,
|
|
robot_model: RobotModel,
|
|
env_kwargs: Dict[str, Any],
|
|
**kwargs,
|
|
):
|
|
# initialize environment
|
|
if env_kwargs.get("onscreen", False):
|
|
env_kwargs["onscreen"] = False
|
|
self.onscreen = True # onscreen render in the main thread
|
|
self.render_camera = env_kwargs.get("render_camera", None)
|
|
else:
|
|
self.onscreen = False
|
|
self.env_name = env_name
|
|
self.env = Gr00tLocomanipRoboCasaEnv(env_name, robot_name, robot_model, **env_kwargs)
|
|
self.init_caches()
|
|
self.cache_lock = threading.Lock()
|
|
|
|
# initialize channel
|
|
self.init_channel()
|
|
|
|
# initialize ROS2 node
|
|
if not rclpy.ok():
|
|
rclpy.init()
|
|
self.node = rclpy.create_node("sim_robocasa")
|
|
self.thread = threading.Thread(target=rclpy.spin, args=(self.node,), daemon=True)
|
|
self.thread.start()
|
|
else:
|
|
self.thread = None
|
|
executor = rclpy.get_global_executor()
|
|
self.node = executor.get_nodes()[0] # will only take the first node
|
|
|
|
self.control_freq = env_kwargs.get("control_freq", 1 / 0.02)
|
|
self.sim_freq = kwargs.get("sim_freq", 1 / 0.005)
|
|
self.control_rate = self.node.create_rate(self.control_freq)
|
|
|
|
self.running = False
|
|
self.sim_thread = None
|
|
self.sync_lock = threading.Lock()
|
|
|
|
self.sync_mode = kwargs.get("sync_mode", False)
|
|
self.steps_per_action = kwargs.get("steps_per_action", 1)
|
|
|
|
self.image_dt = kwargs.get("image_dt", 0.04)
|
|
self.image_publish_process = None
|
|
self.viewer_freq = kwargs.get("viewer_freq", 1 / 0.02)
|
|
self.viewer = None
|
|
|
|
self.verbose = kwargs.get("verbose", True)
|
|
|
|
# Initialize keyboard listener for env reset
|
|
self.keyboard_listener = KeyboardListenerSubscriber()
|
|
|
|
self.reset()
|
|
|
|
@property
|
|
def base_env(self):
|
|
return self.env.env
|
|
|
|
def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555):
|
|
"""Initialize image publishing subprocess if cameras are configured"""
|
|
if len(self.env.camera_names) == 0:
|
|
print(
|
|
"Warning: No camera configs provided, image publishing subprocess will not be started"
|
|
)
|
|
return
|
|
|
|
# Build camera configs from env camera settings
|
|
camera_configs = {}
|
|
for env_cam_name in self.env.camera_names:
|
|
camera_config = self.env.camera_key_mapper.get_camera_config(env_cam_name)
|
|
mapped_cam_name, cam_width, cam_height = camera_config
|
|
camera_configs[mapped_cam_name] = {"height": cam_height, "width": cam_width}
|
|
|
|
self.image_publish_process = ImagePublishProcess(
|
|
camera_configs=camera_configs,
|
|
image_dt=self.image_dt,
|
|
zmq_port=camera_port,
|
|
start_method=start_method,
|
|
verbose=self.verbose,
|
|
)
|
|
|
|
self.image_publish_process.start_process()
|
|
|
|
def update_render_caches(self, obs: Dict[str, Any]):
|
|
"""Update render cache and shared memory for subprocess"""
|
|
if self.image_publish_process is None:
|
|
return
|
|
|
|
# Extract image observations from obs dict
|
|
render_caches = {
|
|
k: v for k, v in obs.items() if k.endswith("_image") and isinstance(v, np.ndarray)
|
|
}
|
|
|
|
# Update shared memory if image publishing process is available
|
|
if render_caches:
|
|
self.image_publish_process.update_shared_memory(render_caches)
|
|
|
|
def init_caches(self):
|
|
self.caches = {
|
|
"obs": None,
|
|
"reward": None,
|
|
"terminated": None,
|
|
"truncated": None,
|
|
"info": None,
|
|
}
|
|
|
|
def reset(self, **kwargs):
|
|
if self.viewer is not None:
|
|
self.viewer.close()
|
|
|
|
obs, info = self.env.reset(**kwargs)
|
|
self.caches["obs"] = obs
|
|
self.caches["reward"] = 0
|
|
self.caches["terminated"] = False
|
|
self.caches["truncated"] = False
|
|
self.caches["info"] = info
|
|
|
|
# initialize viewer
|
|
if self.onscreen:
|
|
self.viewer = mujoco.viewer.launch_passive(
|
|
self.base_env.sim.model._model,
|
|
self.base_env.sim.data._data,
|
|
show_left_ui=False,
|
|
show_right_ui=False,
|
|
)
|
|
self.viewer.opt.geomgroup[0] = 0 # disable collision visualization
|
|
if self.render_camera is not None:
|
|
self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED
|
|
self.viewer.cam.fixedcamid = self.base_env.sim.model._model.cam(
|
|
self.render_camera
|
|
).id
|
|
|
|
# self.episode_state.reset_state()
|
|
return obs, info
|
|
|
|
@abstractmethod
|
|
def init_channel(self):
|
|
raise NotImplementedError("init_channel must be implemented by the subclass")
|
|
|
|
@abstractmethod
|
|
def publish_obs(self):
|
|
raise NotImplementedError("publish_obs must be implemented by the subclass")
|
|
|
|
@abstractmethod
|
|
def get_action(self) -> Tuple[Dict[str, Any], bool, bool]:
|
|
raise NotImplementedError("get_action must be implemented by the subclass")
|
|
|
|
def start_as_thread(self):
|
|
"""Start the simulation thread"""
|
|
if self.sim_thread is not None and self.sim_thread.is_alive():
|
|
return
|
|
|
|
self.sim_thread = threading.Thread(target=self.start)
|
|
self.sim_thread.daemon = True
|
|
self.sim_thread.start()
|
|
|
|
def set_sync_mode(self, sync_mode: bool, steps_per_action: int = 4):
|
|
"""Set the sync mode of the environment server"""
|
|
with self.sync_lock:
|
|
self.sync_mode = sync_mode
|
|
self.steps_per_action = steps_per_action
|
|
|
|
def _check_keyboard_input(self):
|
|
"""Check for keyboard input and handle state transitions"""
|
|
key = self.keyboard_listener.read_msg()
|
|
if key == "k":
|
|
print("\033[1;32m[Sim env]\033[0m Resetting sim environment")
|
|
self.reset()
|
|
|
|
def start(self):
|
|
"""Function executed by the simulation thread"""
|
|
iter_idx = 0
|
|
steps_per_cur_action = 0
|
|
t_start = time.monotonic()
|
|
|
|
self.running = True
|
|
|
|
while self.running:
|
|
# Check keyboard input for state transitions
|
|
self._check_keyboard_input()
|
|
|
|
# Publish observations and get new action
|
|
self.publish_obs()
|
|
action, ready, is_new_action = self.get_action()
|
|
# ready is True if the action is received from the control loop
|
|
# is_new_action is True if the action is new (not the same as the previous action)
|
|
with self.sync_lock:
|
|
sync_mode = self.sync_mode
|
|
max_steps_per_action = self.steps_per_action
|
|
|
|
# Process action if ready and within step limits
|
|
action_should_apply = ready and (
|
|
(not sync_mode) or steps_per_cur_action < max_steps_per_action
|
|
)
|
|
if action_should_apply:
|
|
obs, reward, terminated, truncated, info = self.env.step(action)
|
|
with self.cache_lock:
|
|
self.caches["obs"] = obs
|
|
self.caches["reward"] = reward
|
|
self.caches["terminated"] = terminated
|
|
self.caches["truncated"] = truncated
|
|
self.caches["info"] = info
|
|
|
|
if reward == 1.0 and iter_idx % 50 == 0:
|
|
print("\033[92mTask successful. Can save data now.\033[0m")
|
|
|
|
iter_idx += 1
|
|
steps_per_cur_action += 1
|
|
if self.verbose and sync_mode:
|
|
print("steps_per_cur_action: ", steps_per_cur_action)
|
|
|
|
# Update render caches at image publishing rate
|
|
if action_should_apply and iter_idx % int(self.image_dt * self.control_freq) == 0:
|
|
with self.cache_lock:
|
|
obs_copy = self.caches["obs"].copy()
|
|
self.update_render_caches(obs_copy)
|
|
|
|
# Reset step counter for new actions
|
|
if is_new_action:
|
|
steps_per_cur_action = 0
|
|
|
|
# Update viewer at specified frequency
|
|
if self.onscreen and iter_idx % (self.control_freq / self.viewer_freq) == 0:
|
|
self.viewer.sync()
|
|
|
|
# Check if we're meeting the desired control frequency
|
|
if iter_idx % 100 == 0:
|
|
end_time = time.monotonic()
|
|
if self.verbose:
|
|
print(
|
|
f"sim FPS: {100.0 / (end_time - t_start) * (self.sim_freq / self.control_freq)}"
|
|
)
|
|
if (end_time - t_start) > ((110.0 / self.control_freq)): # for tolerance
|
|
print(
|
|
f"Warning: Sim runs at "
|
|
"{100.0/(end_time - t_start) * (self.sim_freq / self.control_freq):.1f}Hz, "
|
|
f"but should run at {self.sim_freq:.1f}Hz"
|
|
)
|
|
t_start = end_time
|
|
|
|
# reset obj pos every 200 steps
|
|
if iter_idx % 200 == 0:
|
|
if hasattr(self.base_env, "reset_obj_pos"):
|
|
self.base_env.reset_obj_pos()
|
|
|
|
self.control_rate.sleep()
|
|
|
|
def get_privileged_obs(self):
|
|
"""Get privileged observation. Should be implemented by subclasses."""
|
|
obs = {}
|
|
with self.cache_lock:
|
|
if hasattr(self.base_env, "get_privileged_obs_keys"):
|
|
for key in self.base_env.get_privileged_obs_keys():
|
|
obs[key] = self.caches["obs"][key]
|
|
|
|
for key in self.caches["obs"].keys():
|
|
if key.endswith("_image"):
|
|
obs[key] = self.caches["obs"][key]
|
|
|
|
return obs
|
|
|
|
def stop(self):
|
|
"""Stop the simulation thread"""
|
|
self.running = False
|
|
if self.sim_thread is not None:
|
|
self.sim_thread.join(timeout=1.0) # Wait for thread to finish with timeout
|
|
self.sim_thread = None
|
|
|
|
def close(self):
|
|
self.stop()
|
|
if self.image_publish_process is not None:
|
|
self.image_publish_process.stop()
|
|
if self.onscreen:
|
|
self.viewer.close()
|
|
self.env.close()
|
|
|
|
def get_reward(self):
|
|
return self.base_env.reward()
|