You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

236 lines
9.1 KiB

from copy import deepcopy
import time
import tyro
from decoupled_wbc.control.envs.g1.g1_env import G1Env
from decoupled_wbc.control.main.constants import (
CONTROL_GOAL_TOPIC,
DEFAULT_BASE_HEIGHT,
DEFAULT_NAV_CMD,
DEFAULT_WRIST_POSE,
JOINT_SAFETY_STATUS_TOPIC,
LOWER_BODY_POLICY_STATUS_TOPIC,
ROBOT_CONFIG_TOPIC,
STATE_TOPIC_NAME,
)
from decoupled_wbc.control.main.teleop.configs.configs import ControlLoopConfig
from decoupled_wbc.control.policy.wbc_policy_factory import get_wbc_policy
from decoupled_wbc.control.robot_model.instantiation.g1 import (
instantiate_g1_robot_model,
)
from decoupled_wbc.control.utils.keyboard_dispatcher import (
KeyboardDispatcher,
KeyboardEStop,
KeyboardListenerPublisher,
ROSKeyboardDispatcher,
)
from decoupled_wbc.control.utils.ros_utils import (
ROSManager,
ROSMsgPublisher,
ROSMsgSubscriber,
ROSServiceServer,
)
from decoupled_wbc.control.utils.telemetry import Telemetry
# Node name handed to ROSManager when the control loop in main() starts up.
CONTROL_NODE_NAME = "ControlPolicy"
def _make_keyboard_dispatcher(dispatcher_type):
    """Return the keyboard dispatcher selected by ``dispatcher_type``.

    Args:
        dispatcher_type: ``"raw"`` selects ``KeyboardDispatcher``; ``"ros"``
            selects ``ROSKeyboardDispatcher``.

    Raises:
        ValueError: If ``dispatcher_type`` is neither ``"raw"`` nor ``"ros"``.
    """
    if dispatcher_type == "raw":
        return KeyboardDispatcher()
    if dispatcher_type == "ros":
        return ROSKeyboardDispatcher()
    raise ValueError(
        f"Invalid keyboard dispatcher: {dispatcher_type}, please use 'raw' or 'ros'"
    )


def _lower_body_policy_uses_action(wbc_policy):
    """Return ``wbc_policy.lower_body_policy.use_policy_action``, or False if unavailable.

    The lookup is best-effort: a missing ``lower_body_policy`` or
    ``use_policy_action`` attribute, or an AttributeError/TypeError raised while
    reading them, yields False instead of propagating.
    """
    try:
        lower_body_policy = getattr(wbc_policy, "lower_body_policy", None)
        return getattr(lower_body_policy, "use_policy_action", False)
    except (AttributeError, TypeError):
        return False


def _copy_without_images(obs):
    """Return a deep copy of ``obs`` that omits every key ending in ``"_image"``.

    Image entries are skipped before copying rather than deleted afterwards, so
    their payloads are never duplicated.
    """
    return {
        key: deepcopy(value) for key, value in obs.items() if not key.endswith("_image")
    }


def main(config: ControlLoopConfig):
    """Run the G1 whole-body control loop until ROS shuts down or is interrupted.

    Sets up the ROS node, publishers, robot model, environment, policy and
    keyboard dispatcher, then loops at ``config.control_frequency``: observe,
    forward the latest teleop goal to the policy, queue the resulting action,
    publish status messages and export the state message.

    Args:
        config: Control loop configuration parsed from the command line.

    Raises:
        ValueError: If ``config.keyboard_dispatcher_type`` is not 'raw' or 'ros'.
        RuntimeError: If a simulator is attached and its thread is not alive.
    """
    ros_manager = ROSManager(node_name=CONTROL_NODE_NAME)
    node = ros_manager.node

    # start the robot config server
    ROSServiceServer(ROBOT_CONFIG_TOPIC, config.to_dict())

    wbc_config = config.load_wbc_yaml()

    data_exp_pub = ROSMsgPublisher(STATE_TOPIC_NAME)
    lower_body_policy_status_pub = ROSMsgPublisher(LOWER_BODY_POLICY_STATUS_TOPIC)
    joint_safety_status_pub = ROSMsgPublisher(JOINT_SAFETY_STATUS_TOPIC)

    # Initialize telemetry
    telemetry = Telemetry(window_size=100)

    waist_location = "lower_and_upper_body" if config.enable_waist else "lower_body"
    robot_model = instantiate_g1_robot_model(
        waist_location=waist_location, high_elbow_pose=config.high_elbow_pose
    )

    env = G1Env(
        env_name=config.env_name,
        robot_model=robot_model,
        config=wbc_config,
        wbc_version=config.wbc_version,
    )
    # In sync mode the simulator is stepped explicitly from the loop below instead.
    if env.sim and not config.sim_sync_mode:
        env.start_simulator()

    wbc_policy = get_wbc_policy("g1", robot_model, wbc_config, config.upper_body_joint_speed)

    keyboard_listener_pub = KeyboardListenerPublisher()
    keyboard_estop = KeyboardEStop()
    dispatcher = _make_keyboard_dispatcher(config.keyboard_dispatcher_type)
    dispatcher.register(env)
    dispatcher.register(wbc_policy)
    dispatcher.register(keyboard_listener_pub)
    dispatcher.register(keyboard_estop)
    dispatcher.start()

    rate = node.create_rate(config.control_frequency)

    upper_body_policy_subscriber = ROSMsgSubscriber(CONTROL_GOAL_TOPIC)

    # Most recent teleop command; kept across iterations for the exported state message.
    last_teleop_cmd = None
    try:
        while ros_manager.ok():
            t_start = time.monotonic()
            with telemetry.timer("total_loop"):
                # Step simulator if in sync mode
                with telemetry.timer("step_simulator"):
                    if env.sim and config.sim_sync_mode:
                        env.step_simulator()

                # Measure observation time
                with telemetry.timer("observe"):
                    obs = env.observe()
                    wbc_policy.set_observation(obs)

                # Measure policy setup time
                with telemetry.timer("policy_setup"):
                    upper_body_cmd = upper_body_policy_subscriber.get_msg()

                    t_now = time.monotonic()

                    wbc_goal = {}
                    if upper_body_cmd:
                        wbc_goal = upper_body_cmd.copy()
                        last_teleop_cmd = upper_body_cmd.copy()
                        if config.ik_indicator:
                            env.set_ik_indicator(upper_body_cmd)
                    # Send goal to policy
                    if wbc_goal:
                        # Cutoff is two control periods before now.
                        wbc_goal["interpolation_garbage_collection_time"] = t_now - 2 * (
                            1 / config.control_frequency
                        )
                        wbc_policy.set_goal(wbc_goal)

                # Measure policy action calculation time
                with telemetry.timer("policy_action"):
                    wbc_action = wbc_policy.get_action(time=t_now)

                # Measure action queue time
                with telemetry.timer("queue_action"):
                    env.queue_action(wbc_action)

                # Publish status information for InteractiveModeController
                with telemetry.timer("publish_status"):
                    # Policy status: whether the lower body policy has use_policy_action enabled
                    policy_status_msg = {
                        "use_policy_action": _lower_body_policy_uses_action(wbc_policy),
                        "timestamp": t_now,
                    }
                    lower_body_policy_status_pub.publish(policy_status_msg)

                    # Get joint safety status from G1Env (which already runs the safety monitor)
                    joint_safety_status_msg = {
                        "joint_safety_ok": env.get_joint_safety_status(),
                        "timestamp": t_now,
                    }
                    joint_safety_status_pub.publish(joint_safety_status_msg)

                # Start or Stop data collection
                if wbc_goal.get("toggle_data_collection", False):
                    dispatcher.handle_key("c")

                # Abort the current episode
                if wbc_goal.get("toggle_data_abort", False):
                    dispatcher.handle_key("x")

                # NOTE(review): every other simulator check here uses env.sim; confirm that
                # G1Env also exposes use_sim.
                if env.use_sim and wbc_goal.get("reset_env_and_policy", False):
                    print("Resetting sim environment and policy")
                    # Reset teleop policy & sim env
                    dispatcher.handle_key("k")

                    # Clear upper body commands
                    upper_body_policy_subscriber._msg = None
                    # Seed the exported command with the current upper-body joints and defaults.
                    last_teleop_cmd = {
                        "target_upper_body_pose": obs["q"][
                            robot_model.get_joint_group_indices("upper_body")
                        ],
                        "wrist_pose": DEFAULT_WRIST_POSE,
                        "base_height_command": DEFAULT_BASE_HEIGHT,
                        "navigate_cmd": DEFAULT_NAV_CMD,
                    }

                    time.sleep(0.5)

                # Image entries are left out of the exported state message.
                msg = _copy_without_images(obs)

                # exporting data
                if last_teleop_cmd:
                    msg.update(
                        {
                            "action": wbc_action["q"],
                            "action.eef": last_teleop_cmd.get("wrist_pose", DEFAULT_WRIST_POSE),
                            "base_height_command": last_teleop_cmd.get(
                                "base_height_command", DEFAULT_BASE_HEIGHT
                            ),
                            "navigate_command": last_teleop_cmd.get(
                                "navigate_cmd", DEFAULT_NAV_CMD
                            ),
                            "timestamps": {
                                "main_loop": time.time(),
                                "proprio": time.time(),
                            },
                        }
                    )
                data_exp_pub.publish(msg)
                end_time = time.monotonic()

            # NOTE(review): start_simulator() is skipped in sim_sync_mode above; confirm
            # sim_thread is still populated in that mode, otherwise this raises.
            if env.sim and (not env.sim.sim_thread or not env.sim.sim_thread.is_alive()):
                raise RuntimeError("Simulator thread is not alive")

            rate.sleep()

            # Timing is logged on every iteration when verbose_timing is set; otherwise only
            # when the loop body overran one control period (and not in sync mode).
            if config.verbose_timing:
                telemetry.log_timing_info(context="G1 Control Loop", threshold=0.0)
            elif (end_time - t_start) > (1 / config.control_frequency) and not config.sim_sync_mode:
                telemetry.log_timing_info(context="G1 Control Loop Missed", threshold=0.001)

    except ros_manager.exceptions() as e:
        print(f"ROSManager interrupted by user: {e}")

    finally:
        print("Cleaning up...")
        # the order of the following is important
        dispatcher.stop()
        ros_manager.shutdown()
        env.close()
if __name__ == "__main__":
    # Parse ControlLoopConfig from the command line and hand it to the control loop.
    main(tyro.cli(ControlLoopConfig))