You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
111 lines
3.8 KiB
111 lines
3.8 KiB
import time
|
|
|
|
import pandas as pd
|
|
|
|
from gr00t_wbc.control.base.policy import Policy
|
|
from gr00t_wbc.control.main.constants import (
|
|
DEFAULT_BASE_HEIGHT,
|
|
DEFAULT_NAV_CMD,
|
|
DEFAULT_WRIST_POSE,
|
|
)
|
|
from gr00t_wbc.control.robot_model.robot_model import RobotModel
|
|
from gr00t_wbc.data.viz.rerun_viz import RerunViz
|
|
|
|
|
|
class LerobotReplayPolicy(Policy):
|
|
"""Replay policy for Lerobot dataset, so we can replay the dataset
|
|
and just use the action from the dataset.
|
|
|
|
Args:
|
|
parquet_path: Path to the parquet file containing the dataset.
|
|
"""
|
|
|
|
is_active = True # by default, the replay policy is active
|
|
|
|
def __init__(self, robot_model: RobotModel, parquet_path: str, use_viz: bool = False):
|
|
# self.dataset = LerobotDataset(dataset_path)
|
|
self.parquet_path = parquet_path
|
|
self._ctr = 0
|
|
# read the parquet file
|
|
self.df = pd.read_parquet(self.parquet_path)
|
|
self._max_ctr = len(self.df)
|
|
# get the action from the dataframe
|
|
self.action = self.df.iloc[self._ctr]["action"]
|
|
self.use_viz = use_viz
|
|
if self.use_viz:
|
|
self.viz = RerunViz(
|
|
image_keys=["egoview_image"],
|
|
tensor_keys=[
|
|
"left_arm_qpos",
|
|
"left_hand_qpos",
|
|
"right_arm_qpos",
|
|
"right_hand_qpos",
|
|
],
|
|
window_size=5.0,
|
|
)
|
|
self.robot_model = robot_model
|
|
self.upper_body_joint_indices = self.robot_model.get_joint_group_indices("upper_body")
|
|
|
|
def get_action(self) -> dict[str, any]:
|
|
# get the action from the dataframe
|
|
action = self.df.iloc[self._ctr]["action"]
|
|
wrist_pose = self.df.iloc[self._ctr]["action.eef"]
|
|
navigate_cmd = self.df.iloc[self._ctr].get("teleop.navigate_command", DEFAULT_NAV_CMD)
|
|
base_height_cmd = self.df.iloc[self._ctr].get(
|
|
"teleop.base_height_command", DEFAULT_BASE_HEIGHT
|
|
)
|
|
|
|
self._ctr += 1
|
|
if self._ctr >= self._max_ctr:
|
|
self._ctr = 0
|
|
# print(f"Replay {self._ctr} / {self._max_ctr}")
|
|
if self.use_viz:
|
|
self.viz.plot_tensors(
|
|
{
|
|
"left_arm_qpos": action[self.robot_model.get_joint_group_indices("left_arm")]
|
|
+ 15,
|
|
"left_hand_qpos": action[self.robot_model.get_joint_group_indices("left_hand")]
|
|
+ 15,
|
|
"right_arm_qpos": action[self.robot_model.get_joint_group_indices("right_arm")]
|
|
+ 15,
|
|
"right_hand_qpos": action[
|
|
self.robot_model.get_joint_group_indices("right_hand")
|
|
]
|
|
+ 15,
|
|
},
|
|
time.monotonic(),
|
|
)
|
|
|
|
return {
|
|
"target_upper_body_pose": action[self.upper_body_joint_indices],
|
|
"wrist_pose": wrist_pose,
|
|
"navigate_cmd": navigate_cmd,
|
|
"base_height_cmd": base_height_cmd,
|
|
"timestamp": time.time(),
|
|
}
|
|
|
|
def action_to_cmd(self, action: dict[str, any]) -> dict[str, any]:
|
|
action["target_upper_body_pose"] = action["q"][
|
|
self.robot_model.get_joint_group_indices("upper_body")
|
|
]
|
|
del action["q"]
|
|
return action
|
|
|
|
def set_observation(self, observation: dict[str, any]):
|
|
pass
|
|
|
|
def get_observation(self) -> dict[str, any]:
|
|
return {
|
|
"wrist_pose": self.df.iloc[self._ctr - 1].get(
|
|
"observation.eef_state", DEFAULT_WRIST_POSE
|
|
),
|
|
"timestamp": time.time(),
|
|
}
|
|
|
|
|
|
if __name__ == "__main__":
|
|
policy = LerobotReplayPolicy(
|
|
parquet_path="outputs/g1-open-hands-may7/data/chunk-000/episode_000000.parquet"
|
|
)
|
|
action = policy.get_action()
|
|
print(action)
|