You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
403 lines
15 KiB
403 lines
15 KiB
import glob
|
|
import os
|
|
import tempfile
|
|
from unittest.mock import MagicMock, patch
|
|
|
|
import numpy as np
|
|
import pytest
|
|
|
|
try:
|
|
from gr00t_wbc.control.main.teleop.run_g1_data_exporter import Gr00tDataCollector
|
|
from gr00t_wbc.control.robot_model.instantiation.g1 import instantiate_g1_robot_model
|
|
from gr00t_wbc.data.constants import RS_VIEW_CAMERA_HEIGHT, RS_VIEW_CAMERA_WIDTH
|
|
from gr00t_wbc.data.exporter import Gr00tDataExporter
|
|
from gr00t_wbc.data.utils import get_dataset_features
|
|
except ModuleNotFoundError as e:
|
|
if "No module named 'rclpy'" in str(e):
|
|
pytestmark = pytest.mark.skip(reason="ROS (rclpy) is not installed")
|
|
else:
|
|
raise e
|
|
|
|
|
|
import json
|
|
|
|
# How does mocking ROS work?
#
# This test file uses mocking to simulate a ROS environment without requiring ROS or robot hardware:
#
# 1. ros_ok_side_effect: Controls how long the ROS loop runs by returning a sequence of
#    True/False values. [True, True, False] means "run for 2 iterations, then stop".
#
# 2. MockROSMsgSubscriber: Simulates sensors (camera/state) by returning pre-defined data,
#    one item per call, and None once the data is exhausted.
#
# 3. MockKeyboardListenerSubscriber: Simulates user input:
#    - 'c' = start/stop recording
#    - 'x' = discard episode
#    - KeyboardInterrupt = simulate Ctrl+C
#    - None = no input
#
# 4. MockROSEnvironment: A context manager that patches all ROS dependencies to use our mocks,
#    allowing us to test ROS-dependent code without actual ROS running.
|
|
|
|
|
|
class MockROSMsgSubscriber:
    """Stand-in for a ROS subscriber that replays a fixed list of messages.

    ``get_image`` and ``get_msg`` share a single cursor: every call to either one
    hands out the next item of ``return_value``, and ``None`` once the list is used up.
    """

    def __init__(self, return_value: list[dict]):
        self.return_value = return_value
        self.counter = 0

    def _advance(self):
        """Return the item under the cursor and move past it, or ``None`` when exhausted."""
        if self.counter >= len(self.return_value):
            return None
        item = self.return_value[self.counter]
        self.counter += 1
        return item

    def get_image(self):
        """Next replayed camera message, or ``None`` when the stream is exhausted."""
        return self._advance()

    def get_msg(self):
        """Next replayed state message, or ``None`` when the stream is exhausted."""
        return self._advance()
|
|
|
|
|
|
class MockKeyboardListenerSubscriber:
    """Stand-in for the keyboard listener that replays scripted key presses.

    Each entry of ``return_value`` is a key string (e.g. ``"c"``), ``None`` for
    "no key pressed", or a ``KeyboardInterrupt`` instance that ``read_msg`` raises
    to simulate Ctrl+C.
    """

    def __init__(self, return_value: list[str]):
        self.return_value = return_value
        self.counter = 0

    def get_keyboard_input(self):
        """Return the entry under the cursor without advancing it.

        Returns ``None`` once the script is exhausted, matching ``read_msg``,
        instead of raising ``IndexError``.
        """
        if self.counter < len(self.return_value):
            return self.return_value[self.counter]
        return None

    def read_msg(self):
        """Return the next scripted entry and advance the cursor.

        Returns:
            The next entry, or ``None`` once the script is exhausted.

        Raises:
            KeyboardInterrupt: if the current entry is a ``KeyboardInterrupt``
                instance. The cursor is not advanced in that case, so a repeated
                call raises again.
        """
        if self.counter >= len(self.return_value):
            return None
        result = self.return_value[self.counter]
        if isinstance(result, KeyboardInterrupt):
            raise result
        self.counter += 1
        return result
|
|
|
|
|
|
class MockROSEnvironment:
    """Context manager for mocking ROS environment and subscribers.

    On entry, ``rclpy.init``, ``rclpy.create_node``, ``rclpy.spin`` and
    ``rclpy.shutdown`` are replaced with mocks. ``rclpy.ok`` returns successive
    values from ``ok_side_effect``. The ``KeyboardListenerSubscriber``,
    ``ROSImgMsgSubscriber`` and ``ROSMsgSubscriber`` names in
    ``run_g1_data_exporter`` are patched so that constructing them returns the
    supplied mock objects. All patches are undone on exit.
    """

    def __init__(self, ok_side_effect, keyboard_listener, img_subscriber, state_subscriber):
        self.ok_side_effect = ok_side_effect
        self.keyboard_listener = keyboard_listener
        self.img_subscriber = img_subscriber
        self.state_subscriber = state_subscriber
        self.patches = []

    def __enter__(self):
        self.patches = [
            patch("rclpy.init"),
            patch("rclpy.create_node"),
            patch("rclpy.spin"),
            patch("rclpy.ok", side_effect=self.ok_side_effect),
            patch("rclpy.shutdown"),
            patch(
                "gr00t_wbc.control.main.teleop.run_g1_data_exporter.KeyboardListenerSubscriber",
                return_value=self.keyboard_listener,
            ),
            patch(
                "gr00t_wbc.control.main.teleop.run_g1_data_exporter.ROSImgMsgSubscriber",
                return_value=self.img_subscriber,
            ),
            patch(
                "gr00t_wbc.control.main.teleop.run_g1_data_exporter.ROSMsgSubscriber",
                return_value=self.state_subscriber,
            ),
        ]

        started = []
        try:
            for p in self.patches:
                p.start()
                started.append(p)
        except BaseException:
            # A patch target can fail to resolve (e.g. its module is not importable).
            # __exit__ is never called when __enter__ raises, so undo the patches that
            # were already applied; otherwise they leak into later tests.
            for p in reversed(started):
                p.stop()
            raise
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        for p in reversed(self.patches):
            p.stop()
        # Do not suppress exceptions raised inside the ``with`` body.
        return False
|
|
|
|
|
|
def verify_parquet_files_exist(file_path: str, num_episodes: int):
    """Assert that the dataset at ``file_path`` holds exactly ``num_episodes`` parquet files.

    Files are matched by ``data/chunk-*/episode_*.parquet`` relative to ``file_path``.
    """
    pattern = os.path.join(file_path, "data/chunk-*/episode_*.parquet")
    found = len(glob.glob(pattern))
    assert found == num_episodes, f"Expected {num_episodes} parquet files, but found {found}"
|
|
|
|
|
|
def verify_video_files_exist(file_path: str, observation_keys: list[str], num_episodes: int):
    """Assert that every video key has exactly ``num_episodes`` episode videos on disk.

    Args:
        file_path: Root directory of the exported dataset.
        observation_keys: Video keys; each is looked up as a directory under
            ``videos/chunk-*/``.
        num_episodes: Expected number of ``episode_*.mp4`` files per key.

    Raises:
        AssertionError: If any key has a different number of video files.
    """
    for observation_key in observation_keys:
        pattern = os.path.join(file_path, f"videos/chunk-*/{observation_key}/episode_*.mp4")
        video_files = glob.glob(pattern)
        # Name the key in the message so a failure points at the offending video stream.
        assert len(video_files) == num_episodes, (
            f"Expected {num_episodes} video files for '{observation_key}', "
            f"but found {len(video_files)}"
        )
|
|
|
|
|
|
def verify_metadata_files(file_path: str):
    """Assert that the expected metadata files were written under ``<file_path>/meta``."""
    meta_dir = os.path.join(file_path, "meta")
    for name in ("episodes.jsonl", "info.json", "tasks.jsonl", "modality.json"):
        assert os.path.exists(os.path.join(meta_dir, name)), f"meta/{name} not created"
|
|
|
|
|
|
@pytest.fixture
def lerobot_features():
    """Dataset feature spec built from a freshly instantiated G1 robot model."""
    return get_dataset_features(instantiate_g1_robot_model())
|
|
|
|
|
|
@pytest.fixture
def modality_config():
    """Small modality layout used by the exporter tests.

    ``state`` and ``action`` share the same two-feature slice layout (each gets its
    own dict), plus one video mapping and one annotation mapping.
    """

    def feature_slices():
        # A new dict per call so "state" and "action" never alias each other.
        return {
            "feature1": {"start": 0, "end": 4},
            "feature2": {"start": 4, "end": 9},
        }

    return {
        "state": feature_slices(),
        "action": feature_slices(),
        "video": {"rs_view": {"original_key": "observation.images.ego_view"}},
        "annotation": {"human.task_description": {"original_key": "task_index"}},
    }
|
|
|
|
|
|
def _get_image_stream_data(episode_length: int, frame_rate: int, img_height: int, img_width: int):
|
|
return [
|
|
{
|
|
"image": np.zeros((img_height, img_width, 3), dtype=np.uint8),
|
|
"timestamp": (i * 1 / frame_rate),
|
|
}
|
|
for i in range(episode_length)
|
|
]
|
|
|
|
|
|
def _get_state_act_stream_data(
|
|
episode_length: int, frame_rate: int, state_dim: int, action_dim: int
|
|
):
|
|
return [
|
|
{
|
|
"q": np.zeros(state_dim),
|
|
"action": np.zeros(action_dim),
|
|
"timestamp": (i * 1 / frame_rate),
|
|
"navigate_command": np.zeros(3, dtype=np.float64),
|
|
"base_height_command": 0.0,
|
|
"wrist_pose": np.zeros(14, dtype=np.float64),
|
|
"action.eef": np.zeros(14, dtype=np.float64),
|
|
}
|
|
for i in range(episode_length)
|
|
]
|
|
|
|
|
|
def test_control_loop_happy_path_workflow(lerobot_features, modality_config):
    """
    Record an episode into a fresh dataset, then record a second one into the same dataset.

    Both passes press "c" on the first keyboard read (start recording) and on the last
    one (stop and save). After each pass, the parquet, video and metadata files on disk
    are checked against the number of episodes recorded so far.
    """
    episode_length = 10
    frame_rate = 20
    img_stream_data = _get_image_stream_data(
        episode_length, frame_rate, RS_VIEW_CAMERA_HEIGHT, RS_VIEW_CAMERA_WIDTH
    )
    robot_model = instantiate_g1_robot_model()
    num_joints = robot_model.num_joints
    state_act_stream_data = _get_state_act_stream_data(
        episode_length, frame_rate, num_joints, num_joints
    )

    keyboard_sub_output = [None] * episode_length
    keyboard_sub_output[0] = "c"  # Start recording
    keyboard_sub_output[-1] = "c"  # Stop recording and save

    with tempfile.TemporaryDirectory() as temp_dir:
        dataset_dir = os.path.join(temp_dir, "dataset")
        data_exporter = Gr00tDataExporter.create(
            save_root=dataset_dir,
            fps=frame_rate,
            features=lerobot_features,
            modality_config=modality_config,
            task="test",
        )

        # Pass 1 writes the first episode. Pass 2 rebuilds every mock and the collector
        # on the same exporter, imitating a user who records a new episode on an
        # existing dataset.
        for episodes_on_disk in (1, 2):
            mock_img_sub = MockROSMsgSubscriber(img_stream_data)
            mock_state_sub = MockROSMsgSubscriber(state_act_stream_data)
            mock_keyboard_listener = MockKeyboardListenerSubscriber(keyboard_sub_output)
            # The mocked rclpy.ok() returns True for episode_length + 1 calls, then False.
            ros_ok_side_effect = [True] * (episode_length + 1) + [False]

            with MockROSEnvironment(
                ros_ok_side_effect, mock_keyboard_listener, mock_img_sub, mock_state_sub
            ):
                data_collector = Gr00tDataCollector(
                    camera_topic_name="mock_camera_topic",
                    state_topic_name="mock_state_topic",
                    data_exporter=data_exporter,
                    frequency=frame_rate,
                )
                # mocking to avoid actual sleeping
                data_collector.rate = MagicMock()
                data_collector.run()

            verify_parquet_files_exist(dataset_dir, episodes_on_disk)
            verify_video_files_exist(
                dataset_dir, data_exporter.meta.video_keys, episodes_on_disk
            )
            verify_metadata_files(dataset_dir)
|
|
|
|
|
|
def test_control_loop_keyboard_interrupt_workflow(lerobot_features, modality_config):
    """
    This test simulates a keyboard interruption in the middle of recording.

    Expected behavior:
    - The episode is saved to disk
    - The episode is marked as discarded
    """
    episode_length = 15
    frame_rate = 20
    img_stream_data = _get_image_stream_data(
        episode_length, frame_rate, RS_VIEW_CAMERA_HEIGHT, RS_VIEW_CAMERA_WIDTH
    )
    robot_model = instantiate_g1_robot_model()
    state_act_stream_data = _get_state_act_stream_data(
        episode_length, frame_rate, robot_model.num_joints, robot_model.num_joints
    )

    keyboard_sub_output = [None for _ in range(episode_length)]
    keyboard_sub_output[0] = "c"  # Start recording
    keyboard_sub_output[5] = KeyboardInterrupt()  # keyboard interruption in the middle of recording

    mock_img_sub = MockROSMsgSubscriber(img_stream_data)
    mock_state_sub = MockROSMsgSubscriber(state_act_stream_data)
    mock_keyboard_listener = MockKeyboardListenerSubscriber(keyboard_sub_output)

    with tempfile.TemporaryDirectory() as temp_dir:
        dataset_dir = os.path.join(temp_dir, "dataset")

        data_exporter = Gr00tDataExporter.create(
            save_root=dataset_dir,
            fps=frame_rate,
            features=lerobot_features,
            modality_config=modality_config,
            task="test",
        )

        ros_ok_side_effect = [True] * episode_length + [False]
        with MockROSEnvironment(
            ros_ok_side_effect, mock_keyboard_listener, mock_img_sub, mock_state_sub
        ):
            data_collector = Gr00tDataCollector(
                camera_topic_name="mock_camera_topic",
                state_topic_name="mock_state_topic",
                data_exporter=data_exporter,
                frequency=frame_rate,
            )

            # mocking to avoid actual sleeping
            data_collector.rate = MagicMock()
            # Deliberately not wrapped in try/except: this test expects run() to handle
            # the KeyboardInterrupt itself and still persist the partial episode.
            data_collector.run()

        verify_parquet_files_exist(dataset_dir, 1)
        verify_video_files_exist(dataset_dir, data_exporter.meta.video_keys, 1)
        verify_metadata_files(dataset_dir)

        # verify that the episode is marked as discarded
        with open(os.path.join(dataset_dir, "meta", "info.json"), encoding="utf-8") as f:
            ep_info = json.load(f)
        assert ep_info["discarded_episode_indices"][0] == 0
        assert ep_info["total_frames"] == 5
        assert ep_info["total_episodes"] == 1
|
|
|
|
|
|
def test_discarded_episode_workflow(lerobot_features, modality_config):
    """
    This test simulates a case where the user discards an episode in the middle of recording.

    Expected behavior:
    - Record 3 episodes, discard episode 0 and 2
    - There should be 3 episodes saved to disk
    - Episode 0 and 2 should be flagged as discarded
    """
    episode_length = 17
    frame_rate = 20
    robot_model = instantiate_g1_robot_model()
    state_dim = robot_model.num_joints
    action_dim = robot_model.num_joints
    img_stream_data = _get_image_stream_data(
        episode_length, frame_rate, RS_VIEW_CAMERA_HEIGHT, RS_VIEW_CAMERA_WIDTH
    )
    state_act_stream_data = _get_state_act_stream_data(
        episode_length, frame_rate, state_dim, action_dim
    )

    keyboard_sub_output = [None for _ in range(episode_length)]
    keyboard_sub_output[0] = "c"  # Start recording episode index 0
    keyboard_sub_output[5] = "x"  # Discard episode index 0
    keyboard_sub_output[7] = "c"  # Start recording episode index 1
    keyboard_sub_output[10] = "c"  # Stop recording and save episode index 1
    keyboard_sub_output[12] = "c"  # Start recording episode index 2
    keyboard_sub_output[15] = "x"  # Discard episode index 2

    mock_img_sub = MockROSMsgSubscriber(img_stream_data)
    mock_state_sub = MockROSMsgSubscriber(state_act_stream_data)
    mock_keyboard_listener = MockKeyboardListenerSubscriber(keyboard_sub_output)

    with tempfile.TemporaryDirectory() as temp_dir:
        dataset_dir = os.path.join(temp_dir, "dataset")

        data_exporter = Gr00tDataExporter.create(
            save_root=dataset_dir,
            fps=frame_rate,
            features=lerobot_features,
            modality_config=modality_config,
            task="test",
        )

        ros_ok_side_effect = [True] * episode_length + [False]
        with MockROSEnvironment(
            ros_ok_side_effect, mock_keyboard_listener, mock_img_sub, mock_state_sub
        ):
            data_collector = Gr00tDataCollector(
                camera_topic_name="mock_camera_topic",
                state_topic_name="mock_state_topic",
                data_exporter=data_exporter,
                frequency=frame_rate,
            )

            # mocking to avoid actual sleeping
            data_collector.rate = MagicMock()
            # NOTE(review): every exception from run() is swallowed here, so a crash in the
            # collector only surfaces through the assertions below. The reason is not
            # visible in this file (possibly the mocked rclpy.ok() side effect running
            # out) — confirm and narrow this to the expected exception type.
            try:
                data_collector.run()
            except Exception:
                pass

        # verify that episodes 0 and 2 are marked as discarded
        with open(os.path.join(dataset_dir, "meta", "info.json"), encoding="utf-8") as f:
            ep_info = json.load(f)
        assert len(ep_info["discarded_episode_indices"]) == 2
        assert ep_info["discarded_episode_indices"][0] == 0
        assert ep_info["discarded_episode_indices"][1] == 2

        # verify that all episodes are saved regardless of being discarded
        verify_parquet_files_exist(dataset_dir, 3)
        verify_video_files_exist(dataset_dir, data_exporter.meta.video_keys, 3)
        verify_metadata_files(dataset_dir)
|