import copy
from dataclasses import dataclass
from functools import partial
import json
import os
from pathlib import Path
import shutil
from typing import Optional

import datasets
from datasets import load_dataset
from datasets.utils import disable_progress_bars
from huggingface_hub.errors import RepositoryNotFoundError
from lerobot.common.datasets.lerobot_dataset import (
    LeRobotDataset,
    LeRobotDatasetMetadata,
    compute_episode_stats,
)
from lerobot.common.datasets.utils import (
    check_timestamps_sync,
    get_episode_data_index,
    validate_episode_buffer,
    validate_frame,
)
import numpy as np
from PIL import Image as PILImage
import torch
from torchvision import transforms

from gr00t_wbc.control.main.config_template import ArgsConfig
from gr00t_wbc.data.video_writer import VideoWriter

disable_progress_bars()  # Disable HuggingFace progress bars


@dataclass
class DataCollectionInfo:
    """
    This dataclass stores additional information that is relevant to the data collection process.
    """

    lower_body_policy: Optional[str] = None
    wbc_model_path: Optional[str] = None
    teleoperator_username: Optional[str] = None
    support_operator_username: Optional[str] = None
    robot_type: Optional[str] = None
    robot_id: Optional[str] = None

    def to_dict(self) -> dict:
        """Convert the dataclass to a dictionary for JSON serialization."""
        return {
            "lower_body_policy": self.lower_body_policy,
            "wbc_model_path": self.wbc_model_path,
            "teleoperator_username": self.teleoperator_username,
            "support_operator_username": self.support_operator_username,
            "robot_type": self.robot_type,
            "robot_id": self.robot_id,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "DataCollectionInfo":
        """Create a DataCollectionInfo instance from a dictionary."""
        return cls(**data)
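
    # Round-trip sketch: `DataCollectionInfo.from_dict(info.to_dict()) == info`,
    # which is how the block stored under "data_collection_info" in meta/info.json
    # can be restored when resuming a session.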


class Gr00tDatasetMetadata(LeRobotDatasetMetadata):
    """
    Additional metadata on top of LeRobotDatasetMetadata:
    - modality_config: Written to `meta/modality.json`
    - discarded_episode_indices: List of episode indices that were discarded. Written to `meta/info.json`
    """

    MODALITY_CONFIG_REL_PATH = Path("meta/modality.json")

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        with open(self.root / self.MODALITY_CONFIG_REL_PATH, "rb") as f:
            self.modality_config = json.load(f)

    @classmethod
    def create(
        cls,
        modality_config: dict,
        script_config: dict,
        data_collection_info: DataCollectionInfo,
        *args,
        **kwargs,
    ):
        cls.validate_modality_config(modality_config)

        # Create base metadata object using parent class
        obj = super().create(*args, **kwargs)

        # We also need to initialize the discarded_episode_indices
        obj.info["script_config"] = script_config
        obj.info["discarded_episode_indices"] = []
        obj.info["data_collection_info"] = data_collection_info.to_dict()
        with open(obj.root / "meta" / "info.json", "w") as f:
            json.dump(obj.info, f, indent=4)

        obj.__class__ = cls
        with open(obj.root / cls.MODALITY_CONFIG_REL_PATH, "w") as f:
            json.dump(modality_config, f, indent=4)
        obj.modality_config = modality_config
        return obj

    @staticmethod
    def validate_modality_config(modality_config: dict) -> None:
        # Verify that the config contains all of the state, action, video, annotation keys
        valid_keys = ["state", "action", "video", "annotation"]
        if not all(key in modality_config for key in valid_keys):
            raise ValueError(
                f"Modality config must contain all of the following keys: {valid_keys}"
            )

        # Verify that each key maps to a config dict
        for key in valid_keys:
            if not isinstance(modality_config[key], dict):
                raise ValueError(f"Modality config entry '{key}' must be a dict")
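
    # A hypothetical modality config that would pass `validate_modality_config`;
    # only the four top-level keys are enforced here, and the nested layout below
    # is illustrative (it is defined by the downstream GR00T data pipeline):
    #
    #   {
    #       "state": {"left_arm": {"start": 0, "end": 7}},
    #       "action": {"left_arm": {"start": 0, "end": 7}},
    #       "video": {"ego_view": {"original_key": "observation.images.ego_view"}},
    #       "annotation": {"human.task_description": {}},
    #   }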


class Gr00tDataExporter(LeRobotDataset):
    """
    A class for exporting data collected for a single session to a LeRobot dataset.

    Intended life cycle:
    1. Create a Gr00tDataExporter object
    2. Add frames using add_frame()
    3. Save the episode using save_episode()
        - This will flush the episode buffer to disk
        - This will also close the video writers
        - Create a new video writer and episode buffer to start a new episode

    If interrupted, here's the intended behavior:
    - Interruption before save_episode() is called: loses the current episode
    - Interruption after save_episode() is called: keeps completed episodes
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.video_writers = self.create_video_writer()

    @property
    def repo_id(self):
        return self.meta.repo_id

    @property
    def root(self):
        return self.meta.root

    @property
    def local_files_only(self):
        return self.meta.local_files_only

    @property
    def video_keys(self):
        return self.meta.video_keys

    @classmethod
    def create(
        cls,
        save_root: str | Path,
        fps: int,
        features: dict,
        modality_config: dict,
        task: str,
        script_config: ArgsConfig | dict | None = None,
        data_collection_info: DataCollectionInfo | None = None,
        robot_type: str | None = None,
        tolerance_s: float = 1e-4,
        vcodec: str = "h264",
        overwrite_existing: bool = False,
        upload_bucket_path: str | None = None,
    ) -> "Gr00tDataExporter":
        """
        Create a Gr00tDataExporter object.

        Args:
            save_root: The root directory to save the dataset.
            fps: The frame rate of the dataset.
            features: The features of the dataset.
            modality_config: The modality config of the dataset.
            task: The task performed during the data collection session.
            script_config: The config of the data collection script, stored in `meta/info.json`.
                If not provided, it will be set to an empty ArgsConfig object.
            data_collection_info: The data collection info.
                If the dataset already exists, this argument will be ignored.
                If data_collection_info is not provided, it will be set to an empty DataCollectionInfo object.
            robot_type: The type of robot.
            tolerance_s: Tolerance in seconds used to check that frame timestamps are in sync with the fps.
            vcodec: The codec to use for the video writer.
            overwrite_existing: If True, delete any existing dataset at save_root before creating a new one.
            upload_bucket_path: Optional bucket path to upload the exported dataset to.
        """
        # Default to fresh config objects per call (avoids shared mutable defaults)
        if script_config is None:
            script_config = ArgsConfig()
        if data_collection_info is None:
            data_collection_info = DataCollectionInfo()

        obj = cls.__new__(cls)
        repo_id = (
            "tmp/tmp_dataset"  # NOTE(fengyuanh): Not relevant since we are not pushing to the hub
        )
        if overwrite_existing and Path(save_root).exists():
            print(
                f"Found existing dataset at {save_root}.",
                "Cleaning up this directory since overwrite_existing is True.",
            )
            shutil.rmtree(save_root)

        if Path(save_root).exists():
            # Try to resume from the existing dataset
            try:
                # Load the metadata
                obj.meta = Gr00tDatasetMetadata(
                    repo_id=repo_id,
                    root=save_root,
                )

            except RepositoryNotFoundError as e:
                raise ValueError(
                    f"Failed to resume from corrupted dataset. Please manually check the dataset at {save_root}"
                ) from e
        else:
            if not isinstance(script_config, dict):
                script_config = script_config.to_dict()
            obj.meta = Gr00tDatasetMetadata.create(
                repo_id=repo_id,
                fps=fps,
                root=save_root,
                # NOTE(fengyuanh): We use "robot_type" instead of this field which requires a Robot object
                robot=None,
                robot_type=robot_type,
                features=features,
                modality_config=modality_config,
                script_config=script_config,
                # NOTE(fengyuanh): Always use videos for exporting
                use_videos=True,
                data_collection_info=data_collection_info,
            )
        obj.tolerance_s = tolerance_s
        obj.video_backend = (
            "pyav"  # NOTE(fengyuanh): Only used in training, not relevant for exporting
        )
        obj.vcodec = vcodec
        obj.task = task
        obj.image_writer = None

        obj.episode_buffer = obj.create_episode_buffer()

        obj.episodes = None
        obj.hf_dataset = obj.create_hf_dataset()
        obj.image_transforms = None
        obj.delta_timestamps = None
        obj.delta_indices = None
        obj.episode_data_index = None
        obj.upload_bucket_path = upload_bucket_path
        obj.video_writers = obj.create_video_writer()
        return obj

    def create_video_writer(self) -> dict[str, VideoWriter]:
        video_writers = {}
        for key in self.meta.video_keys:
            # meta.shapes[key] is (height, width, channels): pass the width
            # (index 1) and then the height (index 0) to the VideoWriter
            video_writers[key] = VideoWriter(
                self.root
                / self.meta.get_video_file_path(self.episode_buffer["episode_index"], key),
                self.meta.shapes[key][1],
                self.meta.shapes[key][0],
                self.fps,
                self.vcodec,
            )
        return video_writers

    # @note (k2): This function is copied from LeRobotDataset.add_frame.
    # This is done because we want to bypass lerobot's
    # image_writer and use our own VideoWriter class.
    def add_frame(self, frame: dict) -> None:
        """
        This function only adds the frame to the episode_buffer. Videos are handled by the video_writer,
        which uses a stream writer to write to disk.
        """
        frame = copy.deepcopy(frame)
        frame["task"] = frame.get("task", self.task)

        # Convert torch to numpy if needed
        for name in frame:
            if isinstance(frame[name], torch.Tensor):
                frame[name] = frame[name].numpy()

        validate_frame(frame, self.features)

        if self.episode_buffer is None:
            self.episode_buffer = self.create_episode_buffer()

        # Automatically add frame_index and timestamp to episode buffer
        frame_index = self.episode_buffer["size"]
        timestamp = frame.pop("timestamp") if "timestamp" in frame else frame_index / self.fps
        self.episode_buffer["frame_index"].append(frame_index)
        self.episode_buffer["timestamp"].append(timestamp)

        # Add frame features to episode_buffer
        for key in frame:
            if key == "task":
                # Note: we associate the task in natural language to its task index during `save_episode`
                self.episode_buffer["task"].append(frame["task"])
                continue

            if key not in self.features:
                raise ValueError(
                    f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
                )

            if self.features[key]["dtype"] in ["image", "video"]:
                img_path = self._get_image_file_path(
                    episode_index=self.episode_buffer["episode_index"],
                    image_key=key,
                    frame_index=frame_index,
                )
                if frame_index == 0:
                    img_path.parent.mkdir(parents=True, exist_ok=True)

                # @note (k2): using our own VideoWriter class, bypassing the image_writer.
                # The image path is stored only as a placeholder; the pixels go to the
                # stream writer, and video paths replace these entries in save_episode().
                self.video_writers[key].add_frame(frame[key])
                self.episode_buffer[key].append(str(img_path))
            else:
                self.episode_buffer[key].append(frame[key])

        self.episode_buffer["size"] += 1

    def stop_video_writers(self):
        if not hasattr(self, "video_writers"):
            raise RuntimeError(
                "Can't stop video writers because they haven't been initialized. Call create() first."
            )
        for key in self.video_writers:
            self.video_writers[key].stop()

    def skip_and_start_new_episode(self) -> None:
        """
        Skip the current episode and start a new one.
        """
        self.stop_video_writers()
        self.episode_buffer = self.create_episode_buffer()
        self.video_writers = self.create_video_writer()

    # @note (k2): Code copied from LeRobotDataset.save_episode
    # We override this function because we want to bypass lerobot's `compute_episode_stats` on video features
    # since `compute_episode_stats` only works when images are written to disk.
    def save_episode(self, episode_data: dict | None = None) -> None:
        # Fall back to the exporter's own buffer when no episode data is passed in
        episode_buffer = episode_data if episode_data is not None else self.episode_buffer

        validate_episode_buffer(episode_buffer, self.meta.total_episodes, self.features)

        # size and task are special cases that won't be added to hf_dataset
        episode_length = episode_buffer.pop("size")
        tasks = episode_buffer.pop("task")
        episode_tasks = list(set(tasks))
        episode_index = episode_buffer["episode_index"]

        episode_buffer["index"] = np.arange(
            self.meta.total_frames, self.meta.total_frames + episode_length
        )
        episode_buffer["episode_index"] = np.full((episode_length,), episode_index)

        # Add new tasks to the tasks dictionary
        for task in episode_tasks:
            task_index = self.meta.get_task_index(task)
            if task_index is None:
                self.meta.add_task(task)

        # Given tasks in natural language, find their corresponding task indices
        episode_buffer["task_index"] = np.array([self.meta.get_task_index(task) for task in tasks])

        for key, ft in self.features.items():
            # index, episode_index, task_index are already processed above, and image and video
            # are processed separately by storing image path and frame info as meta data
            if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
                continue
            episode_buffer[key] = np.stack(episode_buffer[key])

        self._wait_image_writer()
        self._save_episode_table(episode_buffer, episode_index)

        # @note (k2): computing only non-video features stats
        non_video_features = {k: v for k, v in self.features.items() if v["dtype"] not in ["video"]}
        non_vid_ep_buffer = {
            k: v for k, v in episode_buffer.items() if k in non_video_features.keys()
        }
        ep_stats = compute_episode_stats(non_vid_ep_buffer, non_video_features)

        if len(self.meta.video_keys) > 0:
            video_paths = self.encode_episode_videos(episode_index)
            for key in self.meta.video_keys:
                episode_buffer[key] = video_paths[key]

        # `meta.save_episode` must be executed after encoding the videos
        self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats)

        ep_data_index = get_episode_data_index(self.meta.episodes, [episode_index])
        ep_data_index_np = {k: t.numpy() for k, t in ep_data_index.items()}
        check_timestamps_sync(
            episode_buffer["timestamp"],
            episode_buffer["episode_index"],
            ep_data_index_np,
            self.fps,
            self.tolerance_s,
        )

        video_files = list(self.root.rglob("*.mp4"))
        assert len(video_files) == self.num_episodes * len(self.meta.video_keys)

        parquet_files = list(self.root.rglob("*.parquet"))
        assert len(parquet_files) == self.num_episodes

        # delete images
        img_dir = self.root / "images"
        if img_dir.is_dir():
            shutil.rmtree(img_dir)

        if not episode_data:  # Reset the buffer and create new video writers
            self.episode_buffer = self.create_episode_buffer()
            self.video_writers = self.create_video_writer()

        # check that all video and parquet files exist
        for key in self.meta.video_keys:
            video_path = os.path.join(self.root, self.meta.get_video_file_path(episode_index, key))
            if not os.path.exists(video_path):
                raise FileNotFoundError(
                    f"Video path: {video_path} does not exist for episode {episode_index}"
                )

        parquet_path = os.path.join(self.root, self.meta.get_data_file_path(episode_index))
        if not os.path.exists(parquet_path):
            raise FileNotFoundError(
                f"Parquet path: {parquet_path} does not exist for episode {episode_index}"
            )

    # @note (k2): Overriding LeRobotDataset.encode_episode_videos to use our own VideoWriter class
    def encode_episode_videos(self, episode_index: int) -> dict:
        video_paths = {}
        for key in self.meta.video_keys:
            video_paths[key] = self.video_writers[key].stop()
        return video_paths

    def save_episode_as_discarded(self) -> None:
        """
        Flag the ongoing episode as discarded and save it to disk. Failed episodes (e.g. a failed
        grasp or manipulation) are flagged this way. The episode index is appended to the
        discarded_episode_indices list in info.json.
        """
        self.meta.info["discarded_episode_indices"] = self.meta.info.get(
            "discarded_episode_indices", []
        ) + [self.episode_buffer["episode_index"]]
        self.save_episode()


def hf_transform_to_torch_by_features(
    features: datasets.Features, items_dict: dict[str, list]
):
    """Transform items from a Hugging Face dataset (pyarrow) into torch tensors.
    Importantly, images are converted from PIL, which corresponds to a channel-last
    representation (h w c) of uint8 type, to a torch image representation with
    channel first (c h w) of float32 type in range [0, 1].
    """
    for key in items_dict:
        first_item = items_dict[key][0]
        if isinstance(first_item, PILImage.Image):
            to_tensor = transforms.ToTensor()
            items_dict[key] = [to_tensor(img) for img in items_dict[key]]
        elif first_item is None:
            pass
        else:
            if isinstance(features[key], datasets.Value):
                dtype_str = features[key].dtype
            elif isinstance(features[key], datasets.Sequence):
                assert isinstance(features[key].feature, datasets.Value)
                dtype_str = features[key].feature.dtype
            else:
                raise ValueError(f"Unsupported feature type for key '{key}': {features[key]}")
            dtype_mapping = {
                "float32": torch.float32,
                "float64": torch.float64,
                "int32": torch.int32,
                "int64": torch.int64,
            }
            items_dict[key] = [
                torch.tensor(x, dtype=dtype_mapping[dtype_str]) for x in items_dict[key]
            ]
    return items_dict
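

# Shape sketch for the image branch above (sizes illustrative): a (480, 640, 3)
# uint8 PIL image becomes a float32 tensor of shape (3, 480, 640) in [0, 1].
# The PIL branch never touches `features`, so an empty mapping suffices here:
#
#   img = PILImage.fromarray(np.zeros((480, 640, 3), dtype=np.uint8))
#   out = hf_transform_to_torch_by_features({}, {"image": [img]})
#   out["image"][0].shape  # torch.Size([3, 480, 640])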


# This is a subclass of LeRobotDataset that only fixes the data types when loading.
# By default, LeRobotDataset automatically converts float64 to float32; this class
# preserves the dtypes declared in the dataset's features.
class TypedLeRobotDataset(LeRobotDataset):
    def __init__(self, load_video=True, *args, **kwargs):
        super().__init__(*args, **kwargs)
        if not load_video:
            # Drop all video features so they are neither loaded nor decoded
            video_keys = []
            for key in self.meta.features.keys():
                if self.meta.features[key]["dtype"] == "video":
                    video_keys.append(key)
            for key in video_keys:
                self.meta.features.pop(key)

    def load_hf_dataset(self) -> datasets.Dataset:
        """hf_dataset contains all the observations, states, actions, rewards, etc."""
        if self.episodes is None:
            path = str(self.root / "data")
            hf_dataset = load_dataset("parquet", data_dir=path, split="train")
        else:
            files = [
                str(self.root / self.meta.get_data_file_path(ep_idx)) for ep_idx in self.episodes
            ]
            hf_dataset = load_dataset("parquet", data_files=files, split="train")

        # TODO(aliberts): hf_dataset.set_format("torch")
        hf_dataset.set_transform(partial(hf_transform_to_torch_by_features, hf_dataset.features))
        return hf_dataset
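

# Minimal usage sketch (the repo_id/root values are illustrative; the remaining
# arguments are forwarded to LeRobotDataset):
#
#   dataset = TypedLeRobotDataset(
#       load_video=False, repo_id="tmp/tmp_dataset", root="/tmp/session_0"
#   )
#   sample = dataset[0]  # float64 columns stay float64 instead of float32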