You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
346 lines
10 KiB
346 lines
10 KiB
"""Dataclass-based configuration for MuJoCo simulation, WBC deployment, and teleop.
|
|
|
|
BaseConfig holds all tuneable knobs (interface, frequency, safety limits, etc.)
|
|
and can load/override values from a WBC YAML file. SimLoopConfig extends it
|
|
with multiprocessing and image-publish settings for the sim loop entry point.
|
|
"""
|
|
|
|
from dataclasses import asdict, dataclass
|
|
import os
|
|
from pathlib import Path
|
|
import subprocess
|
|
from typing import Any, Literal, Optional
|
|
|
|
import yaml
|
|
|
|
from gear_sonic.utils.network.network_utils import resolve_interface
|
|
|
|
WBC_VERSIONS = ["sonic_model12"]
|
|
|
|
@dataclass
|
|
class ArgsConfigTemplate:
|
|
"""Args Config for running the data collection loop."""
|
|
|
|
def update(
|
|
self,
|
|
config_dict: dict,
|
|
strict: bool = False,
|
|
skip_keys: list[str] = [],
|
|
allowed_keys: list[str] | None = None,
|
|
):
|
|
for k, v in config_dict.items():
|
|
if k in skip_keys:
|
|
continue
|
|
if allowed_keys is not None and k not in allowed_keys:
|
|
continue
|
|
if strict and not hasattr(self, k):
|
|
raise ValueError(f"Config {k} not found in {self.__class__.__name__}")
|
|
if not strict and not hasattr(self, k):
|
|
continue
|
|
setattr(self, k, v)
|
|
|
|
@classmethod
|
|
def from_dict(
|
|
cls,
|
|
config_dict: dict,
|
|
strict: bool = False,
|
|
skip_keys: list[str] = [],
|
|
allowed_keys: list[str] | None = None,
|
|
):
|
|
instance = cls()
|
|
instance.update(
|
|
config_dict=config_dict, strict=strict, skip_keys=skip_keys, allowed_keys=allowed_keys
|
|
)
|
|
return instance
|
|
|
|
def to_dict(self):
|
|
return asdict(self)
|
|
|
|
def get(self, key: str, default: Any = None):
|
|
return getattr(self, key) if hasattr(self, key) else default
|
|
|
|
|
|
def override_wbc_config(
|
|
wbc_config: dict, config: "BaseConfig", missed_keys_only: bool = False
|
|
) -> dict:
|
|
"""Override WBC YAML values with dataclass values."""
|
|
key_to_value = {
|
|
"INTERFACE": config.interface,
|
|
"ENV_TYPE": config.env_type,
|
|
"VERSION": config.wbc_version,
|
|
"SIMULATOR": config.simulator,
|
|
"SIMULATE_DT": 1 / float(config.sim_frequency),
|
|
"ENABLE_OFFSCREEN": config.enable_offscreen,
|
|
"ENABLE_ONSCREEN": config.enable_onscreen,
|
|
"model_path": config.wbc_model_path,
|
|
"enable_waist": config.enable_waist,
|
|
"with_hands": config.with_hands,
|
|
"verbose": config.verbose,
|
|
"verbose_timing": config.verbose_timing,
|
|
"upper_body_max_joint_speed": config.upper_body_joint_speed,
|
|
"keyboard_dispatcher_type": config.keyboard_dispatcher_type,
|
|
"enable_gravity_compensation": config.enable_gravity_compensation,
|
|
"gravity_compensation_joints": config.gravity_compensation_joints,
|
|
"high_elbow_pose": config.high_elbow_pose,
|
|
"joint_safety_mode": config.joint_safety_mode,
|
|
"arm_velocity_limit": config.arm_velocity_limit,
|
|
"hand_velocity_limit": config.hand_velocity_limit,
|
|
"lower_body_velocity_limit": config.lower_body_velocity_limit,
|
|
"waist_pitch_limit": config.waist_pitch_limit,
|
|
"hand_torque_limit": config.hand_torque_limit,
|
|
"enable_natural_walk": config.enable_natural_walk,
|
|
}
|
|
|
|
if missed_keys_only:
|
|
for key in key_to_value:
|
|
if key not in wbc_config:
|
|
wbc_config[key] = key_to_value[key]
|
|
else:
|
|
for key in key_to_value:
|
|
wbc_config[key] = key_to_value[key]
|
|
|
|
# Sim-to-real KD gap: waist pitch (index 14) is over-damped in sim;
|
|
# reduce KD by 10 on the real robot to avoid sluggish response
|
|
if config.env_type == "real":
|
|
wbc_config["MOTOR_KD"][14] = wbc_config["MOTOR_KD"][14] - 10
|
|
|
|
return wbc_config
|
|
|
|
|
|
@dataclass
|
|
class BaseConfig(ArgsConfigTemplate):
|
|
"""Base config inherited by all G1 control loops"""
|
|
|
|
dataset_version: str = "sonic_model12"
|
|
|
|
# WBC Configuration
|
|
wbc_version: Literal[tuple(WBC_VERSIONS)] = "sonic_model12"
|
|
"""Version of the whole body controller."""
|
|
|
|
wbc_model_path: str = "policy/stand.onnx,policy/walk.onnx"
|
|
"""Path to WBC model file (relative to GearCheckpoints/wbc)"""
|
|
|
|
wbc_policy_class: str = "G1DecoupledWholeBodyPolicy"
|
|
"""Whole body policy class."""
|
|
|
|
# System Configuration
|
|
interface: str = "sim"
|
|
"""Interface to use for the control loop. [sim, real, lo, enxe8ea6a9c4e09]"""
|
|
|
|
simulator: str = "mujoco"
|
|
"""Simulator to use."""
|
|
|
|
sim_sync_mode: bool = False
|
|
"""Whether to run the control loop in sync mode."""
|
|
|
|
control_frequency: int = 50
|
|
"""Frequency of the control loop."""
|
|
|
|
sim_frequency: int = 200
|
|
"""Frequency of the simulation loop."""
|
|
|
|
# Robot Configuration
|
|
enable_waist: bool = True
|
|
"""Whether to include waist joints in IK."""
|
|
|
|
with_hands: bool = True
|
|
"""Enable hand functionality."""
|
|
|
|
high_elbow_pose: bool = False
|
|
"""Enable high elbow pose configuration."""
|
|
|
|
verbose: bool = True
|
|
"""Whether to print verbose output."""
|
|
|
|
enable_offscreen: bool = False
|
|
"""Whether to enable offscreen rendering."""
|
|
|
|
enable_onscreen: bool = True
|
|
"""Whether to enable onscreen rendering."""
|
|
|
|
enable_teleop_evaluator: bool = False
|
|
"""Whether to enable teleop evaluator."""
|
|
|
|
upper_body_joint_speed: float = 1000
|
|
"""Upper body joint speed."""
|
|
|
|
env_name: str = "default"
|
|
"""Environment name."""
|
|
|
|
ik_indicator: bool = False
|
|
"""Whether to draw IK indicators."""
|
|
|
|
verbose_timing: bool = False
|
|
"""Enable verbose timing output every iteration."""
|
|
|
|
keyboard_dispatcher_type: str = "raw"
|
|
"""Keyboard dispatcher to use. [raw, ros]"""
|
|
|
|
# Gravity Compensation Configuration
|
|
enable_gravity_compensation: bool = False
|
|
"""Enable gravity compensation using pinocchio dynamics."""
|
|
|
|
gravity_compensation_joints: Optional[list[str]] = None
|
|
"""Joint groups to apply gravity compensation to."""
|
|
|
|
# Joint Safety Configuration
|
|
joint_safety_mode: Literal["kill", "freeze"] = "kill"
|
|
"""Joint safety violation mode."""
|
|
|
|
arm_velocity_limit: float = 25.0
|
|
"""Arm joint velocity limit in rad/s."""
|
|
|
|
hand_velocity_limit: float = 1000.0
|
|
"""Hand/finger joint velocity limit in rad/s."""
|
|
|
|
lower_body_velocity_limit: float = 20.0
|
|
"""Lower body joint velocity limit in rad/s."""
|
|
|
|
waist_pitch_limit: float = 15.0
|
|
"""Waist pitch position limit in degrees."""
|
|
|
|
hand_torque_limit: float = 0.1
|
|
"""Hand torque limit in radians."""
|
|
|
|
enable_natural_walk: bool = False
|
|
"""Enable natural walk mode."""
|
|
|
|
# Teleop/Device Configuration
|
|
body_control_device: str = "dummy"
|
|
"""Device to use for body control."""
|
|
|
|
hand_control_device: Optional[str] = "dummy"
|
|
"""Device to use for hand control."""
|
|
|
|
body_streamer_ip: str = "10.112.210.229"
|
|
"""IP address for body streamer (vive only)."""
|
|
|
|
body_streamer_keyword: str = "knee"
|
|
"""Body streamer keyword (vive only)."""
|
|
|
|
enable_visualization: bool = False
|
|
"""Whether to enable visualization."""
|
|
|
|
enable_real_device: bool = True
|
|
"""Whether to enable real device."""
|
|
|
|
teleop_frequency: int = 20
|
|
"""Teleoperation frequency (Hz)."""
|
|
|
|
teleop_replay_path: Optional[str] = None
|
|
"""Path to teleop replay data."""
|
|
|
|
# Deployment/Camera Configuration
|
|
robot_ip: str = "192.168.123.164"
|
|
"""Robot IP address"""
|
|
|
|
data_collection: bool = True
|
|
"""Enable data collection"""
|
|
|
|
data_collection_frequency: int = 20
|
|
"""Data collection frequency (Hz)"""
|
|
|
|
root_output_dir: str = "outputs"
|
|
"""Root output directory"""
|
|
|
|
offline_dc: bool = False
|
|
"""Offline data collection."""
|
|
|
|
enable_upper_body_operation: bool = True
|
|
"""Enable upper body operation"""
|
|
|
|
upper_body_operation_mode: Literal["teleop", "inference"] = "teleop"
|
|
"""Upper body operation mode"""
|
|
|
|
inference_host: str = "localhost"
|
|
"""Inference server host"""
|
|
|
|
inference_port: int = 5555
|
|
"""Inference server port"""
|
|
|
|
inference_on_osmo: bool = False
|
|
"""Whether to run inference on osmo."""
|
|
|
|
inference_prompt: str = "Pick up apple from table to plate"
|
|
"""Inference prompt"""
|
|
|
|
inference_action_horizon: int = 16
|
|
"""Inference action horizon"""
|
|
|
|
inference_control_freq: int = 20
|
|
"""Inference control frequency (Hz)"""
|
|
|
|
inference_rate: float = 2.5
|
|
"""Inference rate (Hz)"""
|
|
|
|
inference_plot_rerun: bool = False
|
|
"""Enable inference plot rerun"""
|
|
|
|
inference_push_evals: bool = True
|
|
"""Whether to push evals."""
|
|
|
|
inference_publish_single_action: bool = False
|
|
"""Whether to publish only a single action."""
|
|
|
|
commit_id: str = ""
|
|
"""Commit ID for the current codebase"""
|
|
|
|
enable_mode_switch: bool = False
|
|
"""Enable operation mode switching via ROS topic."""
|
|
|
|
initial_mode: str = "idle"
|
|
"""Initial operation mode."""
|
|
|
|
def __post_init__(self):
|
|
# Resolve interface
|
|
self.interface, self.env_type = resolve_interface(self.interface)
|
|
|
|
# Set default gravity compensation joints if not specified
|
|
if self.gravity_compensation_joints is None:
|
|
self.gravity_compensation_joints = ["arms"]
|
|
|
|
try:
|
|
self.commit_id = (
|
|
subprocess.check_output(["git", "rev-parse", "HEAD"]).decode("utf-8").strip()
|
|
)
|
|
except Exception:
|
|
self.commit_id = ""
|
|
|
|
def load_wbc_yaml(self) -> dict:
|
|
"""Load and merge wbc yaml with dataclass overrides"""
|
|
import gear_sonic
|
|
|
|
gear_sonic_path = Path(os.path.dirname(gear_sonic.__file__))
|
|
configs_dir = gear_sonic_path / "utils" / "mujoco_sim" / "wbc_configs"
|
|
|
|
if self.wbc_version == "sonic_model12":
|
|
config_path = str(configs_dir / "g1_29dof_sonic_model12.yaml")
|
|
else:
|
|
raise ValueError(
|
|
f"Invalid wbc_version: {self.wbc_version}, please use one of: "
|
|
f"sonic_model12"
|
|
)
|
|
|
|
with open(config_path) as file:
|
|
wbc_config = yaml.load(file, Loader=yaml.FullLoader)
|
|
|
|
wbc_config = override_wbc_config(wbc_config, self)
|
|
|
|
return wbc_config
|
|
|
|
|
|
@dataclass
|
|
class SimLoopConfig(BaseConfig):
|
|
"""Config for running the simulation loop."""
|
|
|
|
mp_start_method: str = "spawn"
|
|
"""Multiprocessing start method"""
|
|
|
|
enable_image_publish: bool = False
|
|
"""Enable image publishing in simulation"""
|
|
|
|
camera_port: int = 5555
|
|
"""Camera port for image publishing"""
|
|
|
|
verbose: bool = False
|
|
"""Verbose output, override the base config verbose"""
|