You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
469 lines
16 KiB
469 lines
16 KiB
import argparse
|
|
import os
|
|
from pathlib import Path
|
|
import signal
|
|
import subprocess
|
|
import threading
|
|
import time
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import rclpy
|
|
from scipy.spatial.transform import Rotation as R
|
|
from std_msgs.msg import String as RosStringMsg
|
|
|
|
from gr00t_wbc.control.main.constants import (
|
|
CONTROL_GOAL_TOPIC,
|
|
KEYBOARD_INPUT_TOPIC,
|
|
STATE_TOPIC_NAME,
|
|
)
|
|
from gr00t_wbc.control.utils.ros_utils import ROSMsgPublisher, ROSMsgSubscriber
|
|
from gr00t_wbc.control.utils.term_color_constants import GREEN_BOLD, RESET, YELLOW_BOLD
|
|
from gr00t_wbc.data.viz.rerun_viz import RerunViz
|
|
|
|
|
|
class KeyboardPublisher:
|
|
def __init__(self, topic_name: str = KEYBOARD_INPUT_TOPIC):
|
|
assert rclpy.ok(), "Expected ROS2 to be initialized in this process..."
|
|
executor = rclpy.get_global_executor()
|
|
self.node = executor.get_nodes()[0]
|
|
self.publisher = self.node.create_publisher(RosStringMsg, topic_name, 1)
|
|
|
|
def publish(self, key: str):
|
|
msg = RosStringMsg()
|
|
msg.data = key
|
|
self.publisher.publish(msg)
|
|
|
|
|
|
def is_robot_fallen_from_quat(mujoco_quat):
|
|
# Convert MuJoCo [w, x, y, z] → SciPy [x, y, z, w]
|
|
w, x, y, z = mujoco_quat
|
|
scipy_quat = [x, y, z, w]
|
|
|
|
r = R.from_quat(scipy_quat)
|
|
roll, pitch, _ = r.as_euler("xyz", degrees=False)
|
|
|
|
MAX_ROLL_PITCH = np.radians(60)
|
|
print(f"[Fall Check] roll={roll:.3f} rad, pitch={pitch:.3f} rad")
|
|
return abs(roll) > MAX_ROLL_PITCH or abs(pitch) > MAX_ROLL_PITCH
|
|
|
|
|
|
class LocomotionRunner:
|
|
def __init__(self, test_mode: str = "squat"):
|
|
self.test_mode = test_mode
|
|
if not rclpy.ok():
|
|
rclpy.init(args=None)
|
|
self.node = rclpy.create_node(f"EvalDriver_{test_mode}_{int(time.time())}")
|
|
|
|
# gracefully shutdown the spin thread when the test is done
|
|
self._stop_event = threading.Event()
|
|
|
|
self.spin_thread = threading.Thread(target=self._spin_loop, daemon=False)
|
|
self.spin_thread.start()
|
|
|
|
self.keyboard_event_publisher = KeyboardPublisher(KEYBOARD_INPUT_TOPIC)
|
|
self.control_publisher = ROSMsgPublisher(CONTROL_GOAL_TOPIC)
|
|
self.state_subscriber = ROSMsgSubscriber(STATE_TOPIC_NAME)
|
|
print(f"{test_mode} test initialized...")
|
|
|
|
def _spin_loop(self):
|
|
try:
|
|
while rclpy.ok() and not self._stop_event.is_set():
|
|
rclpy.spin_once(self.node)
|
|
except rclpy.executors.ExternalShutdownException:
|
|
print("[INFO] Spin thread exiting due to shutdown.")
|
|
finally:
|
|
print("spin loop stopped...")
|
|
|
|
def warm_up(self):
|
|
"""Stabilize and release the robot."""
|
|
print("waiting for 2 seconds...")
|
|
time.sleep(2)
|
|
print(f"running {self.test_mode} test...")
|
|
self.activate()
|
|
print("activated...")
|
|
time.sleep(1)
|
|
self.release()
|
|
print("released...")
|
|
time.sleep(5)
|
|
|
|
def _run_walk_test(self):
|
|
self.walk_forward() # speed up to 0.2 m/s
|
|
time.sleep(1)
|
|
self.walk_forward() # speed up to 0.4 m/s
|
|
|
|
rate = self.node.create_rate(0.5)
|
|
start_time = time.time()
|
|
while rclpy.ok() and (time.time() - start_time) < 10.0:
|
|
obs = self.state_subscriber.get_msg()
|
|
|
|
if is_robot_fallen_from_quat(obs["torso_quat"]):
|
|
print("robot fallen...")
|
|
return 0
|
|
elif self._check_success_condition(obs):
|
|
print(f"robot reaching target ({self.test_mode})...")
|
|
return 1, {}
|
|
else:
|
|
rate.sleep()
|
|
|
|
print("test timed out after 10 seconds...")
|
|
return 0, {}
|
|
|
|
def _run_squat_test(self):
|
|
rate = self.node.create_rate(0.5)
|
|
start_time = time.time()
|
|
while rclpy.ok() and (time.time() - start_time) < 10.0:
|
|
obs = self.state_subscriber.get_msg()
|
|
|
|
if is_robot_fallen_from_quat(obs["torso_quat"]):
|
|
print("robot fallen...")
|
|
return 0, {}
|
|
elif self._check_success_condition(obs):
|
|
print(f"robot reaching target ({self.test_mode})...")
|
|
return 1, {}
|
|
else:
|
|
self.go_down()
|
|
rate.sleep()
|
|
|
|
print("test timed out after 10 seconds...")
|
|
return 0, {}
|
|
|
|
def cmd_to_velocity(self, cmd_list):
|
|
cmd_to_velocity = {
|
|
"w": np.array([0.2, 0.0, 0.0]),
|
|
"s": np.array([-0.2, 0.0, 0.0]),
|
|
"q": np.array([0.0, 0.2, 0.0]),
|
|
"e": np.array([0.0, -0.2, 0.0]),
|
|
"z": np.array([0.0, 0.0, 0.0]),
|
|
}
|
|
|
|
accumulated_velocity = np.array([0.0, 0.0, 0.0])
|
|
velocity_list = []
|
|
for cmd in cmd_list:
|
|
if cmd == "z":
|
|
accumulated_velocity = [0.0, 0.0, 0.0]
|
|
elif cmd in ["CHECK", "SKIP"]:
|
|
accumulated_velocity = velocity_list[-1]
|
|
else:
|
|
accumulated_velocity += cmd_to_velocity[cmd]
|
|
velocity_list.append(accumulated_velocity.copy())
|
|
|
|
return velocity_list
|
|
|
|
def _run_stop_test(self):
|
|
base_vel_thres = 0.25
|
|
|
|
cmd_list = (
|
|
["w", "w", "w", "w", "s", "s", "s", "z", "SKIP", "CHECK"]
|
|
+ ["s", "s", "q", "w", "w", "w", "e", "s", "s", "z", "SKIP", "CHECK"]
|
|
+ ["q", "q", "w", "q", "e", "s", "s", "e", "w", "z", "SKIP", "CHECK"]
|
|
+ ["w", "w", "w", "w", "w", "s", "s", "s", "s", "z", "SKIP", "CHECK"]
|
|
)
|
|
|
|
success_flag = 1
|
|
|
|
statistics = {
|
|
"floating_base_pose": {"state": []},
|
|
"floating_base_vel": {"state": [], "cmd": []},
|
|
"timestamp": [],
|
|
}
|
|
for cmd in cmd_list:
|
|
self.keyboard_event_publisher.publish(cmd)
|
|
time.sleep(0.5)
|
|
obs = self.state_subscriber.get_msg()
|
|
statistics["floating_base_pose"]["state"].append(
|
|
np.linalg.norm(obs["floating_base_pose"])
|
|
)
|
|
statistics["floating_base_vel"]["state"].append(
|
|
np.linalg.norm(obs["floating_base_vel"])
|
|
)
|
|
statistics["timestamp"].append(time.time())
|
|
|
|
if cmd == "CHECK" and np.linalg.norm(obs["floating_base_vel"]) > base_vel_thres:
|
|
print(
|
|
f" [{YELLOW_BOLD}WARNING{RESET}] robot is not stopped fully. "
|
|
f"Current base velocity: {np.linalg.norm(obs['floating_base_vel']):.3f} > {base_vel_thres:.3f}"
|
|
)
|
|
# success_flag = 0 # robot is not stopped
|
|
|
|
time.sleep(0.5)
|
|
|
|
vel_cmd = self.cmd_to_velocity(cmd_list)
|
|
vel_cmd = [np.linalg.norm(v) for v in vel_cmd]
|
|
statistics["floating_base_vel"]["cmd"] = vel_cmd
|
|
return success_flag, statistics
|
|
|
|
def _run_eef_track_test(self):
|
|
from gr00t_wbc.control.policy.lerobot_replay_policy import LerobotReplayPolicy
|
|
|
|
parquet_path = (
|
|
Path(__file__).parent.parent.parent.parent / "replay_data" / "g1_pnpbottle.parquet"
|
|
)
|
|
replay_policy = LerobotReplayPolicy(parquet_path=str(parquet_path))
|
|
|
|
freq = 50
|
|
rate = self.node.create_rate(freq)
|
|
|
|
statistics = {
|
|
# "floating_base_pose": {"state": [], "cmd": []},
|
|
"eef_base_pose": {"state": [], "cmd": []},
|
|
"timestamp": [],
|
|
}
|
|
|
|
for ii in range(500):
|
|
action = replay_policy.get_action()
|
|
action = replay_policy.action_to_cmd(action)
|
|
action["timestamp"] = time.monotonic()
|
|
action["target_time"] = time.monotonic() + ii / freq
|
|
self.control_publisher.publish(action)
|
|
obs = self.state_subscriber.get_msg()
|
|
if obs is None:
|
|
print("no obs...")
|
|
continue
|
|
gt_obs = replay_policy.get_observation()
|
|
|
|
# statistics["floating_base_pose"]["state"].append(obs["floating_base_pose"])
|
|
# statistics["floating_base_pose"]["cmd"].append(np.zeros_like(obs["floating_base_pose"]))
|
|
statistics["eef_base_pose"]["state"].append(obs["wrist_pose"])
|
|
statistics["eef_base_pose"]["cmd"].append(gt_obs["wrist_pose"])
|
|
statistics["timestamp"].append(time.time())
|
|
|
|
pos_err = np.linalg.norm(obs["wrist_pose"][:3] - gt_obs["wrist_pose"][:3])
|
|
if pos_err > 1e-1:
|
|
print(
|
|
f" [{YELLOW_BOLD}WARNING{RESET}] robot failed to track the eef, "
|
|
f"error: {pos_err:.3f} ({self.test_mode})..."
|
|
)
|
|
return 0, statistics
|
|
|
|
if is_robot_fallen_from_quat(obs["torso_quat"]):
|
|
print("robot fallen...")
|
|
return 0, statistics
|
|
else:
|
|
rate.sleep()
|
|
|
|
return 1, statistics
|
|
|
|
def run(self):
|
|
self.warm_up()
|
|
|
|
test_mode_to_func = {
|
|
"squat": self._run_squat_test,
|
|
"walk": self._run_walk_test,
|
|
"stop": self._run_stop_test,
|
|
"eef_track": self._run_eef_track_test,
|
|
}
|
|
|
|
result, statistics = test_mode_to_func[self.test_mode]()
|
|
|
|
self.post_process(statistics)
|
|
return result
|
|
|
|
def _check_success_condition(self, obs):
|
|
if self.test_mode == "squat":
|
|
return obs["floating_base_pose"][2] < 0.4
|
|
elif self.test_mode == "walk":
|
|
return np.linalg.norm(obs["floating_base_pose"][0:2]) > 1.0
|
|
return False
|
|
|
|
def activate(self):
|
|
self.keyboard_event_publisher.publish("]")
|
|
|
|
def release(self):
|
|
self.keyboard_event_publisher.publish("9")
|
|
|
|
def go_down(self):
|
|
self.keyboard_event_publisher.publish("2")
|
|
|
|
def walk_forward(self):
|
|
self.keyboard_event_publisher.publish("w")
|
|
|
|
def walk_stop(self):
|
|
self.keyboard_event_publisher.publish("z")
|
|
|
|
def post_process(self, statistics):
|
|
if len(statistics) == 0:
|
|
return
|
|
|
|
# plot the statistics
|
|
plot_keys = [key for key in statistics.keys() if key != "timestamp"]
|
|
viz = RerunViz(
|
|
image_keys=[],
|
|
tensor_keys=plot_keys,
|
|
window_size=10.0,
|
|
app_name=f"{self.test_mode}_test",
|
|
)
|
|
|
|
for ii in range(len(statistics[plot_keys[0]]["state"])):
|
|
tensor_data = {}
|
|
for k in plot_keys:
|
|
if "state" in statistics[k] and "cmd" in statistics[k]:
|
|
tensor_data[k] = np.array(
|
|
(statistics[k]["state"][ii], statistics[k]["cmd"][ii])
|
|
).reshape(2, -1)
|
|
else:
|
|
tensor_data[k] = np.asarray(statistics[k]["state"][ii]).reshape(1, -1)
|
|
viz.plot_tensors(
|
|
tensor_data,
|
|
statistics["timestamp"][ii],
|
|
)
|
|
|
|
if self.test_mode == "stop":
|
|
base_velocity = statistics["floating_base_vel"]["state"]
|
|
base_velocity_cmd = statistics["floating_base_vel"]["cmd"]
|
|
|
|
base_velocity_tracking_err = []
|
|
for v_cmd, v in zip(base_velocity_cmd, base_velocity): # TODO: check if this is correct
|
|
if v_cmd.max() < 1e-4:
|
|
base_velocity_tracking_err.append(v)
|
|
print(
|
|
f" [{GREEN_BOLD}INFO{RESET}] Base velocity tracking when stopped: "
|
|
f"{np.mean(base_velocity_tracking_err):.3f}"
|
|
)
|
|
|
|
if self.test_mode == "eef_track":
|
|
eef_pose = statistics["eef_base_pose"]["state"]
|
|
eef_pose_cmd = statistics["eef_base_pose"]["cmd"]
|
|
eef_pose_tracking_err = []
|
|
for p_cmd, p in zip(eef_pose_cmd, eef_pose):
|
|
eef_pose_tracking_err.append(np.linalg.norm(p - p_cmd))
|
|
print(
|
|
f" [{GREEN_BOLD}INFO{RESET}] Eef pose tracking error: {np.mean(eef_pose_tracking_err):.3f}"
|
|
)
|
|
|
|
def shutdown(self):
|
|
self._stop_event.set()
|
|
self.spin_thread.join()
|
|
del self.state_subscriber
|
|
del self.keyboard_event_publisher
|
|
# Don't shutdown ROS between tests - let pytest handle it
|
|
|
|
|
|
def start_g1_control_loop():
|
|
proc = subprocess.Popen(
|
|
[
|
|
"python3",
|
|
"gr00t_wbc/control/main/teleop/run_g1_control_loop.py",
|
|
"--keyboard_dispatcher_type",
|
|
"ros",
|
|
"--enable-offscreen",
|
|
],
|
|
preexec_fn=os.setsid,
|
|
)
|
|
time.sleep(10)
|
|
return proc
|
|
|
|
|
|
def run_test(test_mode: str):
|
|
"""Run a single test with the specified mode."""
|
|
proc = start_g1_control_loop()
|
|
print(f"G1 control loop started for {test_mode} test...")
|
|
|
|
test = LocomotionRunner(test_mode)
|
|
result = test.run()
|
|
|
|
print("Shutting down...")
|
|
test.shutdown()
|
|
proc.send_signal(signal.SIGKILL)
|
|
proc.wait()
|
|
|
|
return result
|
|
|
|
|
|
def test_squat():
|
|
"""Pytest function for squat test."""
|
|
result = run_test("squat")
|
|
assert result == 1, "Squat test failed - robot either fell or didn't reach target height"
|
|
|
|
|
|
def test_walk():
|
|
"""Pytest function for walk test."""
|
|
result = run_test("walk")
|
|
assert result == 1, "Walk test failed - robot either fell or didn't reach target distance"
|
|
|
|
|
|
@pytest.mark.skip(reason="skipping test for now, cicd test always gets killed")
|
|
def test_stop():
|
|
"""Pytest function for walking to a nearby position and stop test."""
|
|
result = run_test("stop")
|
|
assert result == 1, "Stop test failed - robot either fell or didn't reach target distance"
|
|
|
|
|
|
@pytest.mark.skip(reason="skipping test for now, cicd test always gets killed")
|
|
def test_eef_track():
|
|
"""Pytest function for eef track test."""
|
|
result = run_test("eef_track")
|
|
assert result == 1, "Eef track test failed - robot either fell or didn't reach target distance"
|
|
|
|
|
|
def main():
|
|
parser = argparse.ArgumentParser(description="Run locomotion tests")
|
|
parser.add_argument("--squat", action="store_true", help="Run squat test only")
|
|
parser.add_argument("--walk", action="store_true", help="Run walk test only")
|
|
parser.add_argument("--stop", action="store_true", help="Run stop test only")
|
|
parser.add_argument("--eef_track", action="store_true", help="Run eef track test only")
|
|
|
|
args = parser.parse_args()
|
|
|
|
if args.squat and args.walk:
|
|
print("Error: Cannot specify both --squat and --walk")
|
|
return 1
|
|
|
|
if args.squat:
|
|
print("Running squat test only...")
|
|
result = run_test("squat")
|
|
if result == 1:
|
|
print("✓ Squat test PASSED")
|
|
return 0
|
|
else:
|
|
print("✗ Squat test FAILED")
|
|
return 1
|
|
|
|
elif args.walk:
|
|
print("Running walk test only...")
|
|
result = run_test("walk")
|
|
if result == 1:
|
|
print("✓ Walk test PASSED")
|
|
return 0
|
|
else:
|
|
print("✗ Walk test FAILED")
|
|
return 1
|
|
|
|
elif args.stop:
|
|
print("Running stop test only...")
|
|
result = run_test("stop")
|
|
if result == 1:
|
|
print("✓ Stop test PASSED")
|
|
return 0
|
|
else:
|
|
print("✗ Stop test FAILED")
|
|
return 1
|
|
|
|
elif args.eef_track:
|
|
print("Running eef track test only...")
|
|
result = run_test("eef_track")
|
|
if result == 1:
|
|
print("✓ Eef track test PASSED")
|
|
return 0
|
|
else:
|
|
print("✗ Eef track test FAILED")
|
|
return 1
|
|
|
|
else:
|
|
print("Running both tests...")
|
|
squat_result = run_test("squat")
|
|
walk_result = run_test("walk")
|
|
|
|
if squat_result == 1 and walk_result == 1:
|
|
print("✓ All tests PASSED")
|
|
return 0
|
|
else:
|
|
print(
|
|
f"✗ Test results: squat={'PASSED' if squat_result == 1 else 'FAILED'}, "
|
|
f"walk={'PASSED' if walk_result == 1 else 'FAILED'}"
|
|
)
|
|
return 1
|
|
|
|
|
|
if __name__ == "__main__":
|
|
exit(main())
|