You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

469 lines
16 KiB

import argparse
import os
from pathlib import Path
import signal
import subprocess
import threading
import time
import numpy as np
import pytest
import rclpy
from scipy.spatial.transform import Rotation as R
from std_msgs.msg import String as RosStringMsg
from gr00t_wbc.control.main.constants import (
CONTROL_GOAL_TOPIC,
KEYBOARD_INPUT_TOPIC,
STATE_TOPIC_NAME,
)
from gr00t_wbc.control.utils.ros_utils import ROSMsgPublisher, ROSMsgSubscriber
from gr00t_wbc.control.utils.term_color_constants import GREEN_BOLD, RESET, YELLOW_BOLD
from gr00t_wbc.data.viz.rerun_viz import RerunViz
class KeyboardPublisher:
    """Publish single keyboard keys as ``std_msgs/String`` messages on a ROS 2 topic.

    Used to emulate key presses for a control loop that reads keyboard input over ROS.
    """

    def __init__(self, topic_name: str = KEYBOARD_INPUT_TOPIC, node=None):
        """Create the publisher.

        Args:
            topic_name: Topic to publish key strings on.
            node: ROS 2 node that owns the publisher. When ``None`` (the original
                behaviour) the first node registered with rclpy's global executor is
                borrowed. NOTE(review): depending on rclpy version / spin timing that
                list may be empty, so passing ``node`` explicitly is more reliable.

        Raises:
            IndexError: if ``node`` is ``None`` and the global executor has no nodes.
        """
        assert rclpy.ok(), "Expected ROS2 to be initialized in this process..."
        if node is None:
            nodes = rclpy.get_global_executor().get_nodes()
            if not nodes:
                # Same exception type as the original `get_nodes()[0]`, with a clearer message.
                raise IndexError(
                    "KeyboardPublisher: no node is registered with the global executor; "
                    "pass `node=` explicitly"
                )
            node = nodes[0]
        self.node = node
        # Queue depth 1: only the most recent key press is buffered.
        self.publisher = self.node.create_publisher(RosStringMsg, topic_name, 1)

    def publish(self, key: str):
        """Publish ``key`` as the ``data`` field of a ``std_msgs/String`` message."""
        msg = RosStringMsg()
        msg.data = key
        self.publisher.publish(msg)
def is_robot_fallen_from_quat(mujoco_quat, max_roll_pitch_deg: float = 60.0):
    """Return whether the body is tilted past ``max_roll_pitch_deg`` in roll or pitch.

    Args:
        mujoco_quat: Orientation as a 4-sequence in MuJoCo's scalar-first
            ``[w, x, y, z]`` order.
        max_roll_pitch_deg: Threshold in degrees applied to both |roll| and |pitch|
            (defaults to the original 60 degrees).

    Returns:
        Truthy when ``abs(roll)`` or ``abs(pitch)`` exceeds the threshold, where roll and
        pitch are the first two SciPy ``"xyz"`` Euler angles in radians.
    """
    # Convert MuJoCo [w, x, y, z] → SciPy [x, y, z, w]
    w, x, y, z = mujoco_quat
    rotation = R.from_quat([x, y, z, w])
    roll, pitch, _ = rotation.as_euler("xyz", degrees=False)
    max_roll_pitch = np.radians(max_roll_pitch_deg)
    print(f"[Fall Check] roll={roll:.3f} rad, pitch={pitch:.3f} rad")
    return abs(roll) > max_roll_pitch or abs(pitch) > max_roll_pitch
class LocomotionRunner:
    """Drive one test scenario against a running G1 control loop over ROS 2.

    The runner creates its own ROS 2 node, spins it on a background thread, and uses:

    * ``KeyboardPublisher`` on ``KEYBOARD_INPUT_TOPIC`` to emulate key presses,
    * ``ROSMsgPublisher`` on ``CONTROL_GOAL_TOPIC`` to send control goals (eef test),
    * ``ROSMsgSubscriber`` on ``STATE_TOPIC_NAME`` to read robot state.

    Every ``_run_*_test`` method returns ``(result, statistics)``: ``result`` is 1 on
    success and 0 on failure, ``statistics`` is a (possibly empty) dict consumed by
    :meth:`post_process`.
    """

    def __init__(self, test_mode: str = "squat"):
        """Create the node, start the spin thread and open the pub/sub channels.

        Args:
            test_mode: One of ``"squat"``, ``"walk"``, ``"stop"``, ``"eef_track"``;
                selects the scenario executed by :meth:`run`.
        """
        self.test_mode = test_mode
        if not rclpy.ok():
            rclpy.init(args=None)
        # Second-resolution timestamp suffix to distinguish nodes from successive runs.
        self.node = rclpy.create_node(f"EvalDriver_{test_mode}_{int(time.time())}")
        # gracefully shutdown the spin thread when the test is done
        self._stop_event = threading.Event()
        self.spin_thread = threading.Thread(target=self._spin_loop, daemon=False)
        self.spin_thread.start()
        self.keyboard_event_publisher = KeyboardPublisher(KEYBOARD_INPUT_TOPIC)
        self.control_publisher = ROSMsgPublisher(CONTROL_GOAL_TOPIC)
        self.state_subscriber = ROSMsgSubscriber(STATE_TOPIC_NAME)
        print(f"{test_mode} test initialized...")

    def _spin_loop(self):
        """Spin ``self.node`` until ROS shuts down or :meth:`shutdown` sets the stop event."""
        try:
            while rclpy.ok() and not self._stop_event.is_set():
                # A finite timeout lets the loop re-check the stop event even when no
                # callbacks are pending; without it spin_once may block indefinitely and
                # shutdown()'s join() would hang.
                rclpy.spin_once(self.node, timeout_sec=0.1)
        except rclpy.executors.ExternalShutdownException:
            print("[INFO] Spin thread exiting due to shutdown.")
        finally:
            print("spin loop stopped...")

    def warm_up(self):
        """Stabilize and release the robot."""
        print("waiting for 2 seconds...")
        time.sleep(2)
        print(f"running {self.test_mode} test...")
        self.activate()
        print("activated...")
        time.sleep(1)
        self.release()
        print("released...")
        time.sleep(5)

    def _run_walk_test(self):
        """Walk forward until the xy-norm of ``floating_base_pose`` exceeds 1.0.

        Returns:
            ``(1, {})`` on success, ``(0, {})`` if the robot falls or 10 s elapse.
        """
        self.walk_forward()  # speed up to 0.2 m/s
        time.sleep(1)
        self.walk_forward()  # speed up to 0.4 m/s
        rate = self.node.create_rate(0.5)  # 0.5 Hz -> one check every 2 s
        start_time = time.time()
        while rclpy.ok() and (time.time() - start_time) < 10.0:
            obs = self.state_subscriber.get_msg()
            if obs is None:
                # No state received yet; wait instead of indexing into None.
                rate.sleep()
                continue
            if is_robot_fallen_from_quat(obs["torso_quat"]):
                print("robot fallen...")
                # Must be a (result, statistics) pair: run() unpacks two values.
                return 0, {}
            if self._check_success_condition(obs):
                print(f"robot reaching target ({self.test_mode})...")
                return 1, {}
            rate.sleep()
        print("test timed out after 10 seconds...")
        return 0, {}

    def _run_squat_test(self):
        """Keep pressing "go down" until the base height (pose index 2) drops below 0.4.

        Returns:
            ``(1, {})`` on success, ``(0, {})`` if the robot falls or 10 s elapse.
        """
        rate = self.node.create_rate(0.5)  # 0.5 Hz -> one check every 2 s
        start_time = time.time()
        while rclpy.ok() and (time.time() - start_time) < 10.0:
            obs = self.state_subscriber.get_msg()
            if obs is None:
                # No state received yet; wait instead of indexing into None.
                rate.sleep()
                continue
            if is_robot_fallen_from_quat(obs["torso_quat"]):
                print("robot fallen...")
                return 0, {}
            if self._check_success_condition(obs):
                print(f"robot reaching target ({self.test_mode})...")
                return 1, {}
            self.go_down()
            rate.sleep()
        print("test timed out after 10 seconds...")
        return 0, {}

    def cmd_to_velocity(self, cmd_list):
        """Return the accumulated 3-vector velocity command after each entry of ``cmd_list``.

        "w"/"s" add +/-0.2 to component 0, "q"/"e" add +/-0.2 to component 1, "z" resets
        to zero, and "CHECK"/"SKIP" repeat the previous command (zero if first). Every
        returned element is an independent ``np.ndarray`` of shape (3,).

        Raises:
            KeyError: for any other command string.
        """
        key_to_delta = {
            "w": np.array([0.2, 0.0, 0.0]),
            "s": np.array([-0.2, 0.0, 0.0]),
            "q": np.array([0.0, 0.2, 0.0]),
            "e": np.array([0.0, -0.2, 0.0]),
        }
        accumulated_velocity = np.zeros(3)
        velocity_list = []
        for cmd in cmd_list:
            if cmd == "z":
                # Keep this an ndarray: a plain list would make the next `+` / `+=`
                # extend the list instead of adding vectors.
                accumulated_velocity = np.zeros(3)
            elif cmd in ("CHECK", "SKIP"):
                # Copy so later updates never mutate an entry already stored in the list.
                accumulated_velocity = (
                    velocity_list[-1].copy() if velocity_list else np.zeros(3)
                )
            else:
                accumulated_velocity = accumulated_velocity + key_to_delta[cmd]
            velocity_list.append(accumulated_velocity.copy())
        return velocity_list

    def _run_stop_test(self):
        """Replay a scripted key sequence and check the base is nearly still at each CHECK.

        Each entry is published as a key press; the state is sampled ~0.5 s later and
        another ~0.5 s elapses before the next key.

        Returns:
            ``(success_flag, statistics)``. ``success_flag`` is currently always 1: the
            failure assignment is commented out, so residual motion only prints a warning.
        """
        base_vel_thres = 0.25
        cmd_list = (
            ["w", "w", "w", "w", "s", "s", "s", "z", "SKIP", "CHECK"]
            + ["s", "s", "q", "w", "w", "w", "e", "s", "s", "z", "SKIP", "CHECK"]
            + ["q", "q", "w", "q", "e", "s", "s", "e", "w", "z", "SKIP", "CHECK"]
            + ["w", "w", "w", "w", "w", "s", "s", "s", "s", "z", "SKIP", "CHECK"]
        )
        success_flag = 1
        statistics = {
            "floating_base_pose": {"state": []},
            "floating_base_vel": {"state": [], "cmd": []},
            "timestamp": [],
        }
        for cmd in cmd_list:
            self.keyboard_event_publisher.publish(cmd)
            time.sleep(0.5)
            obs = self.state_subscriber.get_msg()
            base_vel_norm = np.linalg.norm(obs["floating_base_vel"])
            statistics["floating_base_pose"]["state"].append(
                np.linalg.norm(obs["floating_base_pose"])
            )
            statistics["floating_base_vel"]["state"].append(base_vel_norm)
            statistics["timestamp"].append(time.time())
            if cmd == "CHECK" and base_vel_norm > base_vel_thres:
                print(
                    f" [{YELLOW_BOLD}WARNING{RESET}] robot is not stopped fully. "
                    f"Current base velocity: {base_vel_norm:.3f} > {base_vel_thres:.3f}"
                )
                # success_flag = 0 # robot is not stopped
            time.sleep(0.5)
        # Commanded speed per step, aligned index-for-index with the sampled states.
        statistics["floating_base_vel"]["cmd"] = [
            np.linalg.norm(v) for v in self.cmd_to_velocity(cmd_list)
        ]
        return success_flag, statistics

    def _run_eef_track_test(self):
        """Replay recorded end-effector commands and compare against the recorded wrist pose.

        Fails as soon as the wrist position error exceeds 0.1 or the robot falls, and also
        fails if no state observation was received at all.

        Returns:
            ``(result, statistics)`` with ``"eef_base_pose"`` state/cmd and timestamps.
        """
        from gr00t_wbc.control.policy.lerobot_replay_policy import LerobotReplayPolicy

        parquet_path = (
            Path(__file__).parent.parent.parent.parent / "replay_data" / "g1_pnpbottle.parquet"
        )
        replay_policy = LerobotReplayPolicy(parquet_path=str(parquet_path))
        freq = 50
        rate = self.node.create_rate(freq)
        statistics = {
            # "floating_base_pose": {"state": [], "cmd": []},
            "eef_base_pose": {"state": [], "cmd": []},
            "timestamp": [],
        }
        for ii in range(500):
            action = replay_policy.get_action()
            action = replay_policy.action_to_cmd(action)
            action["timestamp"] = time.monotonic()
            # NOTE(review): target_time = now + ii / freq grows with the step index (step ii
            # targets ii/freq s ahead); confirm this is intended vs. a fixed 1/freq lookahead.
            action["target_time"] = time.monotonic() + ii / freq
            self.control_publisher.publish(action)
            obs = self.state_subscriber.get_msg()
            if obs is None:
                print("no obs...")
                # Keep pacing the loop; otherwise all steps are published back-to-back
                # and the test could finish instantly without any observation.
                rate.sleep()
                continue
            gt_obs = replay_policy.get_observation()
            # statistics["floating_base_pose"]["state"].append(obs["floating_base_pose"])
            # statistics["floating_base_pose"]["cmd"].append(np.zeros_like(obs["floating_base_pose"]))
            statistics["eef_base_pose"]["state"].append(obs["wrist_pose"])
            statistics["eef_base_pose"]["cmd"].append(gt_obs["wrist_pose"])
            statistics["timestamp"].append(time.time())
            pos_err = np.linalg.norm(obs["wrist_pose"][:3] - gt_obs["wrist_pose"][:3])
            if pos_err > 1e-1:
                print(
                    f" [{YELLOW_BOLD}WARNING{RESET}] robot failed to track the eef, "
                    f"error: {pos_err:.3f} ({self.test_mode})..."
                )
                return 0, statistics
            if is_robot_fallen_from_quat(obs["torso_quat"]):
                print("robot fallen...")
                return 0, statistics
            rate.sleep()
        if not statistics["timestamp"]:
            # Nothing was actually tracked; do not report a vacuous success.
            print(f"no observations received during the {self.test_mode} test...")
            return 0, statistics
        return 1, statistics

    def run(self):
        """Warm up, run the scenario selected by ``test_mode``, plot/report, return result.

        Raises:
            KeyError: if ``test_mode`` is not a known scenario.
        """
        self.warm_up()
        test_mode_to_func = {
            "squat": self._run_squat_test,
            "walk": self._run_walk_test,
            "stop": self._run_stop_test,
            "eef_track": self._run_eef_track_test,
        }
        result, statistics = test_mode_to_func[self.test_mode]()
        self.post_process(statistics)
        return result

    def _check_success_condition(self, obs):
        """Scenario-specific success predicate; always False for modes without one."""
        if self.test_mode == "squat":
            # Base height (index 2 of floating_base_pose) below 0.4.
            return obs["floating_base_pose"][2] < 0.4
        elif self.test_mode == "walk":
            # Planar distance of the base position from the origin above 1.0.
            return np.linalg.norm(obs["floating_base_pose"][0:2]) > 1.0
        return False

    def activate(self):
        """Press the "]" key (activate)."""
        self.keyboard_event_publisher.publish("]")

    def release(self):
        """Press the "9" key (release)."""
        self.keyboard_event_publisher.publish("9")

    def go_down(self):
        """Press the "2" key (go down)."""
        self.keyboard_event_publisher.publish("2")

    def walk_forward(self):
        """Press the "w" key (walk forward)."""
        self.keyboard_event_publisher.publish("w")

    def walk_stop(self):
        """Press the "z" key (stop walking)."""
        self.keyboard_event_publisher.publish("z")

    def post_process(self, statistics):
        """Plot collected statistics with RerunViz and print mode-specific summaries.

        Does nothing for an empty ``statistics`` dict.
        """
        if not statistics:
            return
        # plot the statistics
        plot_keys = [key for key in statistics if key != "timestamp"]
        viz = RerunViz(
            image_keys=[],
            tensor_keys=plot_keys,
            window_size=10.0,
            app_name=f"{self.test_mode}_test",
        )
        for ii in range(len(statistics[plot_keys[0]]["state"])):
            tensor_data = {}
            for k in plot_keys:
                if "state" in statistics[k] and "cmd" in statistics[k]:
                    # Row 0 = measured state, row 1 = command.
                    tensor_data[k] = np.array(
                        (statistics[k]["state"][ii], statistics[k]["cmd"][ii])
                    ).reshape(2, -1)
                else:
                    tensor_data[k] = np.asarray(statistics[k]["state"][ii]).reshape(1, -1)
            viz.plot_tensors(
                tensor_data,
                statistics["timestamp"][ii],
            )
        if self.test_mode == "stop":
            base_velocity = statistics["floating_base_vel"]["state"]
            base_velocity_cmd = statistics["floating_base_vel"]["cmd"]
            # Residual base speed at samples where the commanded speed is ~zero.
            base_velocity_tracking_err = [
                v for v_cmd, v in zip(base_velocity_cmd, base_velocity) if np.max(v_cmd) < 1e-4
            ]
            mean_err = (
                np.mean(base_velocity_tracking_err) if base_velocity_tracking_err else float("nan")
            )
            print(
                f" [{GREEN_BOLD}INFO{RESET}] Base velocity tracking when stopped: "
                f"{mean_err:.3f}"
            )
        if self.test_mode == "eef_track":
            eef_pose = statistics["eef_base_pose"]["state"]
            eef_pose_cmd = statistics["eef_base_pose"]["cmd"]
            eef_pose_tracking_err = [
                np.linalg.norm(p - p_cmd) for p_cmd, p in zip(eef_pose_cmd, eef_pose)
            ]
            mean_err = np.mean(eef_pose_tracking_err) if eef_pose_tracking_err else float("nan")
            print(
                f" [{GREEN_BOLD}INFO{RESET}] Eef pose tracking error: {mean_err:.3f}"
            )

    def shutdown(self):
        """Stop and join the spin thread, then drop the keyboard/state channels."""
        self._stop_event.set()
        self.spin_thread.join()
        del self.state_subscriber
        del self.keyboard_event_publisher
        # Don't shutdown ROS between tests - let pytest handle it
def start_g1_control_loop(startup_delay: float = 10.0):
    """Launch the G1 control loop as a subprocess and wait a fixed time for it to start.

    Args:
        startup_delay: Seconds to sleep after spawning before returning (there is no
            readiness check). Defaults to the original 10 s.

    Returns:
        The ``subprocess.Popen`` handle. The child runs under ``os.setsid``, so it leads
        its own session/process group.
    """
    import sys

    proc = subprocess.Popen(
        [
            # Use the interpreter running these tests so the child sees the same
            # environment, rather than whichever `python3` is first on PATH.
            sys.executable,
            # NOTE(review): relative path -- assumes the working directory is the repo root.
            "gr00t_wbc/control/main/teleop/run_g1_control_loop.py",
            "--keyboard_dispatcher_type",
            "ros",
            "--enable-offscreen",
        ],
        preexec_fn=os.setsid,
    )
    time.sleep(startup_delay)
    return proc
def _kill_process_group(proc):
    """SIGKILL the process group led by ``proc`` (falling back to ``proc`` alone) and reap it."""
    try:
        # A child started with os.setsid leads a group whose id equals its pid, so this
        # also kills anything the control loop spawned.
        os.killpg(proc.pid, signal.SIGKILL)
    except ProcessLookupError:
        # Not a group leader (or already gone): signal just the child.
        proc.send_signal(signal.SIGKILL)
    proc.wait()


def run_test(test_mode: str):
    """Run a single test with the specified mode.

    Starts the control loop, runs a ``LocomotionRunner`` for ``test_mode`` and always
    tears both down, even when the runner raises.

    Returns:
        The runner's result (1 on success, 0 on failure).
    """
    proc = start_g1_control_loop()
    print(f"G1 control loop started for {test_mode} test...")
    test = None
    try:
        test = LocomotionRunner(test_mode)
        return test.run()
    finally:
        print("Shutting down...")
        if test is not None:
            test.shutdown()
        _kill_process_group(proc)
def test_squat():
    """Pytest entry: the squat scenario must report success."""
    assert run_test("squat") == 1, "Squat test failed - robot either fell or didn't reach target height"
def test_walk():
    """Pytest entry: the walk scenario must report success."""
    assert run_test("walk") == 1, "Walk test failed - robot either fell or didn't reach target distance"
@pytest.mark.skip(reason="skipping test for now, cicd test always gets killed")
def test_stop():
    """Pytest function for walking to a nearby position and stop test."""
    result = run_test("stop")
    # The stop scenario checks residual base velocity, not distance travelled.
    assert result == 1, "Stop test failed - robot did not stop within the base velocity threshold"
@pytest.mark.skip(reason="skipping test for now, cicd test always gets killed")
def test_eef_track():
    """Pytest function for eef track test."""
    result = run_test("eef_track")
    # The eef scenario fails on a fall or on wrist-position tracking error, not on distance.
    assert result == 1, "Eef track test failed - robot either fell or failed to track the end-effector"
def main(argv=None):
    """Command-line entry point: run one selected test, or squat + walk when none is selected.

    Args:
        argv: Argument list to parse (without the program name); ``None`` uses ``sys.argv``.

    Returns:
        Exit code: 0 if every executed test passed, 1 otherwise (including when more
        than one test flag is given).
    """
    parser = argparse.ArgumentParser(description="Run locomotion tests")
    parser.add_argument("--squat", action="store_true", help="Run squat test only")
    parser.add_argument("--walk", action="store_true", help="Run walk test only")
    parser.add_argument("--stop", action="store_true", help="Run stop test only")
    parser.add_argument("--eef_track", action="store_true", help="Run eef track test only")
    args = parser.parse_args(argv)

    # (test mode == argparse attribute name, label used in printed messages)
    single_tests = [
        ("squat", "squat"),
        ("walk", "walk"),
        ("stop", "stop"),
        ("eef_track", "eef track"),
    ]
    selected = [(mode, label) for mode, label in single_tests if getattr(args, mode)]
    if len(selected) > 1:
        # Any combination is ambiguous; reject it instead of silently running only one.
        flags = ", ".join(f"--{mode}" for mode, _ in selected)
        print(f"Error: Cannot specify more than one test flag (got {flags})")
        return 1

    if selected:
        mode, label = selected[0]
        print(f"Running {label} test only...")
        result = run_test(mode)
        if result == 1:
            print(f"✓ {label.capitalize()} test PASSED")
            return 0
        print(f"✗ {label.capitalize()} test FAILED")
        return 1

    print("Running both tests...")
    squat_result = run_test("squat")
    walk_result = run_test("walk")
    if squat_result == 1 and walk_result == 1:
        print("✓ All tests PASSED")
        return 0
    print(
        f"✗ Test results: squat={'PASSED' if squat_result == 1 else 'FAILED'}, "
        f"walk={'PASSED' if walk_result == 1 else 'FAILED'}"
    )
    return 1
if __name__ == "__main__":
    # `exit` is injected by the `site` module and is not guaranteed to exist (e.g. `python -S`);
    # raising SystemExit is the portable way to return main()'s exit code.
    raise SystemExit(main())