You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
236 lines
9.1 KiB
236 lines
9.1 KiB
from copy import deepcopy
|
|
import time
|
|
|
|
import tyro
|
|
|
|
from decoupled_wbc.control.envs.g1.g1_env import G1Env
|
|
from decoupled_wbc.control.main.constants import (
|
|
CONTROL_GOAL_TOPIC,
|
|
DEFAULT_BASE_HEIGHT,
|
|
DEFAULT_NAV_CMD,
|
|
DEFAULT_WRIST_POSE,
|
|
JOINT_SAFETY_STATUS_TOPIC,
|
|
LOWER_BODY_POLICY_STATUS_TOPIC,
|
|
ROBOT_CONFIG_TOPIC,
|
|
STATE_TOPIC_NAME,
|
|
)
|
|
from decoupled_wbc.control.main.teleop.configs.configs import ControlLoopConfig
|
|
from decoupled_wbc.control.policy.wbc_policy_factory import get_wbc_policy
|
|
from decoupled_wbc.control.robot_model.instantiation.g1 import (
|
|
instantiate_g1_robot_model,
|
|
)
|
|
from decoupled_wbc.control.utils.keyboard_dispatcher import (
|
|
KeyboardDispatcher,
|
|
KeyboardEStop,
|
|
KeyboardListenerPublisher,
|
|
ROSKeyboardDispatcher,
|
|
)
|
|
from decoupled_wbc.control.utils.ros_utils import (
|
|
ROSManager,
|
|
ROSMsgPublisher,
|
|
ROSMsgSubscriber,
|
|
ROSServiceServer,
|
|
)
|
|
from decoupled_wbc.control.utils.telemetry import Telemetry
|
|
|
|
# Node name handed to ROSManager when the control loop starts.
CONTROL_NODE_NAME = "ControlPolicy"
|
|
|
|
|
|
def main(config: ControlLoopConfig):
    """Run the G1 whole-body control loop until the ROS manager stops.

    Builds the robot model, ``G1Env`` and whole-body-control policy from
    ``config``, registers them (plus a keyboard listener publisher and a
    keyboard e-stop) with a keyboard dispatcher, and then loops at
    ``config.control_frequency``. Each tick it:

    1. steps the simulator when running in sim sync mode,
    2. observes the environment and hands the observation to the policy,
    3. forwards the latest message from ``CONTROL_GOAL_TOPIC`` (if any) to
       the policy as its goal,
    4. queries the policy for an action and queues it on the environment,
    5. publishes lower-body policy status and joint-safety status,
    6. handles data-collection / abort / reset requests carried in the goal,
    7. publishes the observation (without ``*_image`` entries), merged with
       the last teleop command once one exists, on ``STATE_TOPIC_NAME``.

    Exceptions listed by ``ros_manager.exceptions()`` are caught and reported;
    the dispatcher, ROS manager and environment are always shut down on exit.

    Args:
        config: Control-loop configuration supplying the environment, policy,
            keyboard-dispatcher and timing options read below.

    Raises:
        ValueError: If ``config.keyboard_dispatcher_type`` is neither
            ``"raw"`` nor ``"ros"``.
        RuntimeError: If a simulator is attached and its thread is missing or
            no longer alive.
    """
    # Validate the dispatcher choice before any ROS node, environment or
    # simulator is created, so a bad config cannot leave them running.
    if config.keyboard_dispatcher_type == "raw":
        dispatcher_cls = KeyboardDispatcher
    elif config.keyboard_dispatcher_type == "ros":
        dispatcher_cls = ROSKeyboardDispatcher
    else:
        raise ValueError(
            f"Invalid keyboard dispatcher: {config.keyboard_dispatcher_type}, please use 'raw' or 'ros'"
        )

    ros_manager = ROSManager(node_name=CONTROL_NODE_NAME)
    node = ros_manager.node

    # start the robot config server
    ROSServiceServer(ROBOT_CONFIG_TOPIC, config.to_dict())

    wbc_config = config.load_wbc_yaml()

    data_exp_pub = ROSMsgPublisher(STATE_TOPIC_NAME)
    lower_body_policy_status_pub = ROSMsgPublisher(LOWER_BODY_POLICY_STATUS_TOPIC)
    joint_safety_status_pub = ROSMsgPublisher(JOINT_SAFETY_STATUS_TOPIC)

    # Initialize telemetry
    telemetry = Telemetry(window_size=100)

    waist_location = "lower_and_upper_body" if config.enable_waist else "lower_body"
    robot_model = instantiate_g1_robot_model(
        waist_location=waist_location, high_elbow_pose=config.high_elbow_pose
    )

    env = G1Env(
        env_name=config.env_name,
        robot_model=robot_model,
        config=wbc_config,
        wbc_version=config.wbc_version,
    )
    # In sync mode the simulator is stepped from the loop below instead.
    if env.sim and not config.sim_sync_mode:
        env.start_simulator()

    wbc_policy = get_wbc_policy("g1", robot_model, wbc_config, config.upper_body_joint_speed)

    keyboard_listener_pub = KeyboardListenerPublisher()
    keyboard_estop = KeyboardEStop()
    dispatcher = dispatcher_cls()
    dispatcher.register(env)
    dispatcher.register(wbc_policy)
    dispatcher.register(keyboard_listener_pub)
    dispatcher.register(keyboard_estop)
    dispatcher.start()

    rate = node.create_rate(config.control_frequency)

    upper_body_policy_subscriber = ROSMsgSubscriber(CONTROL_GOAL_TOPIC)

    # Most recent teleop command; kept across ticks so exported data still
    # carries a command on ticks where no new message arrived.
    last_teleop_cmd = None
    try:
        while ros_manager.ok():
            t_start = time.monotonic()
            with telemetry.timer("total_loop"):
                # Step simulator if in sync mode
                with telemetry.timer("step_simulator"):
                    if env.sim and config.sim_sync_mode:
                        env.step_simulator()

                # Measure observation time
                with telemetry.timer("observe"):
                    obs = env.observe()
                    wbc_policy.set_observation(obs)

                # Measure policy setup time
                with telemetry.timer("policy_setup"):
                    upper_body_cmd = upper_body_policy_subscriber.get_msg()

                    t_now = time.monotonic()

                    wbc_goal = {}
                    if upper_body_cmd:
                        wbc_goal = upper_body_cmd.copy()
                        last_teleop_cmd = upper_body_cmd.copy()
                        if config.ik_indicator:
                            env.set_ik_indicator(upper_body_cmd)
                    # Send goal to policy
                    if wbc_goal:
                        # Cut-off is two control periods before now.
                        wbc_goal["interpolation_garbage_collection_time"] = t_now - 2 * (
                            1 / config.control_frequency
                        )
                        wbc_policy.set_goal(wbc_goal)

                # Measure policy action calculation time
                with telemetry.timer("policy_action"):
                    wbc_action = wbc_policy.get_action(time=t_now)

                # Measure action queue time
                with telemetry.timer("queue_action"):
                    env.queue_action(wbc_action)

                # Publish status information for InteractiveModeController
                with telemetry.timer("publish_status"):
                    # Get policy status - check if the lower body policy has use_policy_action enabled
                    policy_use_action = False
                    try:
                        # Access the lower body policy through the decoupled whole body policy
                        if hasattr(wbc_policy, "lower_body_policy"):
                            policy_use_action = getattr(
                                wbc_policy.lower_body_policy, "use_policy_action", False
                            )
                    except (AttributeError, TypeError):
                        # Best effort: report "not using policy action" rather than fail the tick.
                        policy_use_action = False

                    policy_status_msg = {"use_policy_action": policy_use_action, "timestamp": t_now}
                    lower_body_policy_status_pub.publish(policy_status_msg)

                    # Get joint safety status from G1Env (which already runs the safety monitor)
                    joint_safety_ok = env.get_joint_safety_status()

                    joint_safety_status_msg = {
                        "joint_safety_ok": joint_safety_ok,
                        "timestamp": t_now,
                    }
                    joint_safety_status_pub.publish(joint_safety_status_msg)

                # Start or Stop data collection
                if wbc_goal.get("toggle_data_collection", False):
                    dispatcher.handle_key("c")

                # Abort the current episode
                if wbc_goal.get("toggle_data_abort", False):
                    dispatcher.handle_key("x")

                # NOTE(review): this check reads env.use_sim while every other
                # simulator check in this function reads env.sim — confirm
                # both attributes exist on G1Env and agree.
                if env.use_sim and wbc_goal.get("reset_env_and_policy", False):
                    print("Resetting sim environment and policy")
                    # Reset teleop policy & sim env
                    dispatcher.handle_key("k")

                    # Clear upper body commands
                    upper_body_policy_subscriber._msg = None
                    # Replace the remembered teleop command with the current
                    # upper-body joint positions and the default commands.
                    last_teleop_cmd = {
                        "target_upper_body_pose": obs["q"][
                            robot_model.get_joint_group_indices("upper_body")
                        ],
                        "wrist_pose": DEFAULT_WRIST_POSE,
                        "base_height_command": DEFAULT_BASE_HEIGHT,
                        "navigate_cmd": DEFAULT_NAV_CMD,
                    }

                    time.sleep(0.5)

                # Copy only the non-image observations: image entries are never
                # published, so deep-copying them would be wasted work each tick.
                msg = deepcopy(
                    {key: value for key, value in obs.items() if not key.endswith("_image")}
                )

                # exporting data
                if last_teleop_cmd:
                    msg.update(
                        {
                            "action": wbc_action["q"],
                            "action.eef": last_teleop_cmd.get("wrist_pose", DEFAULT_WRIST_POSE),
                            "base_height_command": last_teleop_cmd.get(
                                "base_height_command", DEFAULT_BASE_HEIGHT
                            ),
                            "navigate_command": last_teleop_cmd.get(
                                "navigate_cmd", DEFAULT_NAV_CMD
                            ),
                            "timestamps": {
                                "main_loop": time.time(),
                                "proprio": time.time(),
                            },
                        }
                    )
                data_exp_pub.publish(msg)
                end_time = time.monotonic()

            # NOTE(review): in sim_sync_mode start_simulator() is never called
            # from this function — confirm sim_thread is still set in that
            # mode, otherwise this raises on the first tick.
            if env.sim and (not env.sim.sim_thread or not env.sim.sim_thread.is_alive()):
                raise RuntimeError("Simulator thread is not alive")

            rate.sleep()

            # Report timing: every iteration when verbose_timing is set,
            # otherwise only when the loop body (measured before rate.sleep)
            # took longer than one control period outside sim sync mode.
            if config.verbose_timing:
                # When verbose timing is enabled, always show timing
                telemetry.log_timing_info(context="G1 Control Loop", threshold=0.0)
            elif (end_time - t_start) > (1 / config.control_frequency) and not config.sim_sync_mode:
                # Only show timing when loop is slow and verbose_timing is disabled
                telemetry.log_timing_info(context="G1 Control Loop Missed", threshold=0.001)

    except ros_manager.exceptions() as e:
        print(f"ROSManager interrupted by user: {e}")
    finally:
        print("Cleaning up...")
        # the order of the following is important; the nested try/finally
        # keeps that order while guaranteeing a failure in one step does not
        # skip the remaining shutdown steps.
        try:
            dispatcher.stop()
        finally:
            try:
                ros_manager.shutdown()
            finally:
                env.close()
|
|
|
|
if __name__ == "__main__":
    # Build the ControlLoopConfig from command-line arguments via tyro and
    # hand it straight to the control loop.
    main(tyro.cli(ControlLoopConfig))