|
|
@ -1,21 +1,16 @@ |
|
|
import sys |
|
|
|
|
|
import time |
|
|
|
|
|
import collections |
|
|
import collections |
|
|
import yaml |
|
|
|
|
|
import torch |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
import sys |
|
|
|
|
|
|
|
|
import os |
|
|
|
|
|
import threading |
|
|
import time |
|
|
import time |
|
|
import collections |
|
|
|
|
|
import yaml |
|
|
|
|
|
import torch |
|
|
|
|
|
import numpy as np |
|
|
|
|
|
|
|
|
|
|
|
import mujoco |
|
|
import mujoco |
|
|
import mujoco.viewer |
|
|
import mujoco.viewer |
|
|
|
|
|
import numpy as np |
|
|
import onnxruntime as ort |
|
|
import onnxruntime as ort |
|
|
import threading |
|
|
|
|
|
from pynput import keyboard as pkb |
|
|
from pynput import keyboard as pkb |
|
|
import os |
|
|
|
|
|
|
|
|
import torch |
|
|
|
|
|
import yaml |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class GearWbcController: |
|
|
class GearWbcController: |
|
|
def __init__(self, config_path): |
|
|
def __init__(self, config_path): |
|
|
@ -24,21 +19,21 @@ class GearWbcController: |
|
|
self.config = self.load_config(os.path.join(self.CONFIG_PATH, "g1_gear_wbc.yaml")) |
|
|
self.config = self.load_config(os.path.join(self.CONFIG_PATH, "g1_gear_wbc.yaml")) |
|
|
|
|
|
|
|
|
self.control_dict = { |
|
|
self.control_dict = { |
|
|
'loco_cmd': self.config['cmd_init'], |
|
|
|
|
|
"height_cmd": self.config['height_cmd'], |
|
|
|
|
|
"rpy_cmd": self.config.get('rpy_cmd', [0.0, 0.0, 0.0]), |
|
|
|
|
|
"freq_cmd": self.config.get('freq_cmd', 1.5) |
|
|
|
|
|
|
|
|
"loco_cmd": self.config["cmd_init"], |
|
|
|
|
|
"height_cmd": self.config["height_cmd"], |
|
|
|
|
|
"rpy_cmd": self.config.get("rpy_cmd", [0.0, 0.0, 0.0]), |
|
|
|
|
|
"freq_cmd": self.config.get("freq_cmd", 1.5), |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
self.model = mujoco.MjModel.from_xml_path(self.config['xml_path']) |
|
|
|
|
|
|
|
|
self.model = mujoco.MjModel.from_xml_path(self.config["xml_path"]) |
|
|
self.data = mujoco.MjData(self.model) |
|
|
self.data = mujoco.MjData(self.model) |
|
|
self.model.opt.timestep = self.config['simulation_dt'] |
|
|
|
|
|
|
|
|
self.model.opt.timestep = self.config["simulation_dt"] |
|
|
self.n_joints = self.data.qpos.shape[0] - 7 |
|
|
self.n_joints = self.data.qpos.shape[0] - 7 |
|
|
self.torso_index = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, "torso_link") |
|
|
self.torso_index = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, "torso_link") |
|
|
self.base_index = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, "pelvis") |
|
|
self.base_index = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, "pelvis") |
|
|
self.action = np.zeros(self.config['num_actions'], dtype=np.float32) |
|
|
|
|
|
self.target_dof_pos = self.config['default_angles'].copy() |
|
|
|
|
|
self.policy = self.load_onnx_policy(self.config['policy_path']) |
|
|
|
|
|
|
|
|
self.action = np.zeros(self.config["num_actions"], dtype=np.float32) |
|
|
|
|
|
self.target_dof_pos = self.config["default_angles"].copy() |
|
|
|
|
|
self.policy = self.load_onnx_policy(self.config["policy_path"]) |
|
|
self.gait_indices = torch.zeros((1), dtype=torch.float32) |
|
|
self.gait_indices = torch.zeros((1), dtype=torch.float32) |
|
|
self.counter = 0 |
|
|
self.counter = 0 |
|
|
self.just_started = 0.0 |
|
|
self.just_started = 0.0 |
|
|
@ -49,15 +44,15 @@ class GearWbcController: |
|
|
self.data, self.config, self.action, self.control_dict, self.n_joints |
|
|
self.data, self.config, self.action, self.control_dict, self.n_joints |
|
|
) |
|
|
) |
|
|
self.obs_history = collections.deque( |
|
|
self.obs_history = collections.deque( |
|
|
[np.zeros(self.single_obs_dim, dtype=np.float32)] * self.config['obs_history_len'], |
|
|
|
|
|
maxlen=self.config['obs_history_len'] |
|
|
|
|
|
|
|
|
[np.zeros(self.single_obs_dim, dtype=np.float32)] * self.config["obs_history_len"], |
|
|
|
|
|
maxlen=self.config["obs_history_len"], |
|
|
) |
|
|
) |
|
|
self.obs = np.zeros(self.config['num_obs'], dtype=np.float32) |
|
|
|
|
|
|
|
|
self.obs = np.zeros(self.config["num_obs"], dtype=np.float32) |
|
|
self.keyboard_listener(self.control_dict, self.config) |
|
|
self.keyboard_listener(self.control_dict, self.config) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def keyboard_listener(self, control_dict, config): |
|
|
def keyboard_listener(self, control_dict, config): |
|
|
"""Listen to key press events and update cmd and height_cmd""" |
|
|
"""Listen to key press events and update cmd and height_cmd""" |
|
|
|
|
|
|
|
|
def on_press(key): |
|
|
def on_press(key): |
|
|
try: |
|
|
try: |
|
|
k = key.char |
|
|
k = key.char |
|
|
@ -65,58 +60,60 @@ class GearWbcController: |
|
|
return # Special keys ignored |
|
|
return # Special keys ignored |
|
|
|
|
|
|
|
|
with self.cmd_lock: |
|
|
with self.cmd_lock: |
|
|
if k == 'w': |
|
|
|
|
|
control_dict['loco_cmd'][0] += 0.2 |
|
|
|
|
|
elif k == 's': |
|
|
|
|
|
control_dict['loco_cmd'][0] -= 0.2 |
|
|
|
|
|
elif k == 'a': |
|
|
|
|
|
control_dict['loco_cmd'][1] += 0.5 |
|
|
|
|
|
elif k == 'd': |
|
|
|
|
|
control_dict['loco_cmd'][1] -= 0.5 |
|
|
|
|
|
elif k == 'q': |
|
|
|
|
|
control_dict['loco_cmd'][2] += 0.5 |
|
|
|
|
|
elif k == 'e': |
|
|
|
|
|
control_dict['loco_cmd'][2] -= 0.5 |
|
|
|
|
|
elif k == 'z': |
|
|
|
|
|
control_dict['loco_cmd'][:] = config['cmd_init'] |
|
|
|
|
|
control_dict["height_cmd"] = config['height_cmd'] |
|
|
|
|
|
control_dict['rpy_cmd'][:] = config['rpy_cmd'] |
|
|
|
|
|
control_dict['freq_cmd'] = config['freq_cmd'] |
|
|
|
|
|
elif k == '1': |
|
|
|
|
|
|
|
|
if k == "w": |
|
|
|
|
|
control_dict["loco_cmd"][0] += 0.2 |
|
|
|
|
|
elif k == "s": |
|
|
|
|
|
control_dict["loco_cmd"][0] -= 0.2 |
|
|
|
|
|
elif k == "a": |
|
|
|
|
|
control_dict["loco_cmd"][1] += 0.5 |
|
|
|
|
|
elif k == "d": |
|
|
|
|
|
control_dict["loco_cmd"][1] -= 0.5 |
|
|
|
|
|
elif k == "q": |
|
|
|
|
|
control_dict["loco_cmd"][2] += 0.5 |
|
|
|
|
|
elif k == "e": |
|
|
|
|
|
control_dict["loco_cmd"][2] -= 0.5 |
|
|
|
|
|
elif k == "z": |
|
|
|
|
|
control_dict["loco_cmd"][:] = config["cmd_init"] |
|
|
|
|
|
control_dict["height_cmd"] = config["height_cmd"] |
|
|
|
|
|
control_dict["rpy_cmd"][:] = config["rpy_cmd"] |
|
|
|
|
|
control_dict["freq_cmd"] = config["freq_cmd"] |
|
|
|
|
|
elif k == "1": |
|
|
control_dict["height_cmd"] += 0.05 |
|
|
control_dict["height_cmd"] += 0.05 |
|
|
elif k == '2': |
|
|
|
|
|
|
|
|
elif k == "2": |
|
|
control_dict["height_cmd"] -= 0.05 |
|
|
control_dict["height_cmd"] -= 0.05 |
|
|
elif k == '3': |
|
|
|
|
|
control_dict['rpy_cmd'][0] += 0.2 |
|
|
|
|
|
elif k == '4': |
|
|
|
|
|
control_dict['rpy_cmd'][0] -= 0.2 |
|
|
|
|
|
elif k == '5': |
|
|
|
|
|
control_dict['rpy_cmd'][1] += 0.2 |
|
|
|
|
|
elif k == '6': |
|
|
|
|
|
control_dict['rpy_cmd'][1] -= 0.2 |
|
|
|
|
|
elif k == '7': |
|
|
|
|
|
control_dict['rpy_cmd'][2] += 0.2 |
|
|
|
|
|
elif k == '8': |
|
|
|
|
|
control_dict['rpy_cmd'][2] -= 0.2 |
|
|
|
|
|
elif k == 'm': |
|
|
|
|
|
control_dict['freq_cmd'] += 0.1 |
|
|
|
|
|
elif k == 'n': |
|
|
|
|
|
control_dict['freq_cmd'] -= 0.1 |
|
|
|
|
|
|
|
|
|
|
|
print(f"Current Commands: loco_cmd = {control_dict['loco_cmd']}, height_cmd = {control_dict['height_cmd']}, rpy_cmd = {control_dict['rpy_cmd']}, freq_cmd = {control_dict['freq_cmd']}") |
|
|
|
|
|
|
|
|
elif k == "3": |
|
|
|
|
|
control_dict["rpy_cmd"][0] += 0.2 |
|
|
|
|
|
elif k == "4": |
|
|
|
|
|
control_dict["rpy_cmd"][0] -= 0.2 |
|
|
|
|
|
elif k == "5": |
|
|
|
|
|
control_dict["rpy_cmd"][1] += 0.2 |
|
|
|
|
|
elif k == "6": |
|
|
|
|
|
control_dict["rpy_cmd"][1] -= 0.2 |
|
|
|
|
|
elif k == "7": |
|
|
|
|
|
control_dict["rpy_cmd"][2] += 0.2 |
|
|
|
|
|
elif k == "8": |
|
|
|
|
|
control_dict["rpy_cmd"][2] -= 0.2 |
|
|
|
|
|
elif k == "m": |
|
|
|
|
|
control_dict["freq_cmd"] += 0.1 |
|
|
|
|
|
elif k == "n": |
|
|
|
|
|
control_dict["freq_cmd"] -= 0.1 |
|
|
|
|
|
|
|
|
|
|
|
print( |
|
|
|
|
|
f"Current Commands: loco_cmd = {control_dict['loco_cmd']}, height_cmd = {control_dict['height_cmd']}, rpy_cmd = {control_dict['rpy_cmd']}, freq_cmd = {control_dict['freq_cmd']}" |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
listener = pkb.Listener(on_press=on_press) |
|
|
listener = pkb.Listener(on_press=on_press) |
|
|
listener.daemon = True |
|
|
listener.daemon = True |
|
|
listener.start() |
|
|
listener.start() |
|
|
|
|
|
|
|
|
def load_config(self, config_path): |
|
|
def load_config(self, config_path): |
|
|
with open(config_path, 'r') as f: |
|
|
|
|
|
|
|
|
with open(config_path, "r") as f: |
|
|
config = yaml.safe_load(f) |
|
|
config = yaml.safe_load(f) |
|
|
|
|
|
|
|
|
for path_key in ['policy_path', 'xml_path']: |
|
|
|
|
|
|
|
|
for path_key in ["policy_path", "xml_path"]: |
|
|
config[path_key] = os.path.join(CONFIG_PATH, config[path_key]) |
|
|
config[path_key] = os.path.join(CONFIG_PATH, config[path_key]) |
|
|
|
|
|
|
|
|
array_keys = ['kps', 'kds', 'default_angles', 'cmd_scale', 'cmd_init'] |
|
|
|
|
|
|
|
|
array_keys = ["kps", "kds", "default_angles", "cmd_scale", "cmd_init"] |
|
|
for key in array_keys: |
|
|
for key in array_keys: |
|
|
config[key] = np.array(config[key], dtype=np.float32) |
|
|
config[key] = np.array(config[key], dtype=np.float32) |
|
|
|
|
|
|
|
|
@ -128,17 +125,19 @@ class GearWbcController: |
|
|
def quat_rotate_inverse(self, q, v):
    """Rotate vector ``v`` by the conjugate of quaternion ``q``.

    ``q`` is unpacked as ``(w, x, y, z)`` (scalar part first). Its vector
    part is negated and the rotation matrix of that conjugate quaternion is
    expanded term by term and applied to ``v``. For a unit quaternion this is
    the inverse rotation; no normalisation is performed here.

    Args:
        q: Sequence of four quaternion components ``(w, x, y, z)``.
        v: Indexable holding the three vector components.

    Returns:
        NumPy array with the three rotated components.
    """
    w, x, y, z = q
    # Conjugate quaternion components, scalar part first.
    a, b, c, d = np.array([w, -x, -y, -z])
    aa, bb, cc, dd = a**2, b**2, c**2, d**2

    # Each output component is one row of the rotation matrix dotted with v.
    out_x = (
        v[0] * (aa + bb - cc - dd)
        + v[1] * 2 * (b * c - a * d)
        + v[2] * 2 * (b * d + a * c)
    )
    out_y = (
        v[0] * 2 * (b * c + a * d)
        + v[1] * (aa - bb + cc - dd)
        + v[2] * 2 * (c * d - a * b)
    )
    out_z = (
        v[0] * 2 * (b * d - a * c)
        + v[1] * 2 * (c * d + a * b)
        + v[2] * (aa - bb - cc + dd)
    )
    return np.array([out_x, out_y, out_z])
|
|
|
|
|
|
|
|
def get_gravity_orientation(self, quat): |
|
|
def get_gravity_orientation(self, quat): |
|
|
gravity_vec = np.array([0.0, 0.0, -1.0]) |
|
|
gravity_vec = np.array([0.0, 0.0, -1.0]) |
|
|
@ -146,10 +145,10 @@ class GearWbcController: |
|
|
|
|
|
|
|
|
def compute_observation(self, d, config, action, control_dict, n_joints): |
|
|
def compute_observation(self, d, config, action, control_dict, n_joints): |
|
|
command = np.zeros(8, dtype=np.float32) |
|
|
command = np.zeros(8, dtype=np.float32) |
|
|
command[:3] = control_dict['loco_cmd'][:3] * config['cmd_scale'] |
|
|
|
|
|
command[3] = control_dict['height_cmd'] |
|
|
|
|
|
command[4] = control_dict['freq_cmd'] |
|
|
|
|
|
command[5:8] = control_dict['rpy_cmd'] |
|
|
|
|
|
|
|
|
command[:3] = control_dict["loco_cmd"][:3] * config["cmd_scale"] |
|
|
|
|
|
command[3] = control_dict["height_cmd"] |
|
|
|
|
|
command[4] = control_dict["freq_cmd"] |
|
|
|
|
|
command[5:8] = control_dict["rpy_cmd"] |
|
|
|
|
|
|
|
|
# gait indice |
|
|
# gait indice |
|
|
is_static = np.linalg.norm(command[:3]) < 0.1 |
|
|
is_static = np.linalg.norm(command[:3]) < 0.1 |
|
|
@ -191,10 +190,7 @@ class GearWbcController: |
|
|
# Clock signal |
|
|
# Clock signal |
|
|
clock = [torch.sin(2 * np.pi * fi) for fi in gait_pair] |
|
|
clock = [torch.sin(2 * np.pi * fi) for fi in gait_pair] |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
for i, (clk, frozen_mask_attr) in enumerate( |
|
|
|
|
|
zip(clock, ['frozen_FL', 'frozen_FR']) |
|
|
|
|
|
): |
|
|
|
|
|
|
|
|
for i, (clk, frozen_mask_attr) in enumerate(zip(clock, ["frozen_FL", "frozen_FR"])): |
|
|
frozen_mask = getattr(self, frozen_mask_attr) |
|
|
frozen_mask = getattr(self, frozen_mask_attr) |
|
|
# Freeze condition: static and at sin peak |
|
|
# Freeze condition: static and at sin peak |
|
|
if is_static and (not frozen_mask) and clk.item() > 0.98: |
|
|
if is_static and (not frozen_mask) and clk.item() > 0.98: |
|
|
@ -205,23 +201,25 @@ class GearWbcController: |
|
|
clock[i] = clk |
|
|
clock[i] = clk |
|
|
|
|
|
|
|
|
self.clock_inputs = torch.stack(clock).unsqueeze(0) |
|
|
self.clock_inputs = torch.stack(clock).unsqueeze(0) |
|
|
qj = d.qpos[7:7+n_joints].copy() |
|
|
|
|
|
dqj = d.qvel[6:6+n_joints].copy() |
|
|
|
|
|
|
|
|
qj = d.qpos[7 : 7 + n_joints].copy() |
|
|
|
|
|
dqj = d.qvel[6 : 6 + n_joints].copy() |
|
|
quat = d.qpos[3:7].copy() |
|
|
quat = d.qpos[3:7].copy() |
|
|
omega = d.qvel[3:6].copy() |
|
|
omega = d.qvel[3:6].copy() |
|
|
# omega = self.data.xmat[self.base_index].reshape(3, 3).T @ self.data.cvel[self.base_index][3:6] |
|
|
# omega = self.data.xmat[self.base_index].reshape(3, 3).T @ self.data.cvel[self.base_index][3:6] |
|
|
padded_defaults = np.zeros(n_joints, dtype=np.float32) |
|
|
padded_defaults = np.zeros(n_joints, dtype=np.float32) |
|
|
L = min(len(config['default_angles']), n_joints) |
|
|
|
|
|
padded_defaults[:L] = config['default_angles'][:L] |
|
|
|
|
|
|
|
|
L = min(len(config["default_angles"]), n_joints) |
|
|
|
|
|
padded_defaults[:L] = config["default_angles"][:L] |
|
|
|
|
|
|
|
|
qj_scaled = (qj - padded_defaults) * config['dof_pos_scale'] |
|
|
|
|
|
dqj_scaled = dqj * config['dof_vel_scale'] |
|
|
|
|
|
|
|
|
qj_scaled = (qj - padded_defaults) * config["dof_pos_scale"] |
|
|
|
|
|
dqj_scaled = dqj * config["dof_vel_scale"] |
|
|
gravity_orientation = self.get_gravity_orientation(quat) |
|
|
gravity_orientation = self.get_gravity_orientation(quat) |
|
|
omega_scaled = omega * config['ang_vel_scale'] |
|
|
|
|
|
|
|
|
omega_scaled = omega * config["ang_vel_scale"] |
|
|
|
|
|
|
|
|
torso_quat = self.data.xquat[self.torso_index] |
|
|
torso_quat = self.data.xquat[self.torso_index] |
|
|
torso_omega = self.data.xmat[self.torso_index].reshape(3, 3).T @ self.data.cvel[self.torso_index][3:6] |
|
|
|
|
|
torso_omega_scaled = torso_omega * config['ang_vel_scale'] |
|
|
|
|
|
|
|
|
torso_omega = ( |
|
|
|
|
|
self.data.xmat[self.torso_index].reshape(3, 3).T @ self.data.cvel[self.torso_index][3:6] |
|
|
|
|
|
) |
|
|
|
|
|
torso_omega_scaled = torso_omega * config["ang_vel_scale"] |
|
|
torso_gravity_orientation = self.get_gravity_orientation(torso_quat) |
|
|
torso_gravity_orientation = self.get_gravity_orientation(torso_quat) |
|
|
|
|
|
|
|
|
single_obs_dim = 95 |
|
|
single_obs_dim = 95 |
|
|
@ -229,21 +227,25 @@ class GearWbcController: |
|
|
single_obs[0:8] = command[:8] |
|
|
single_obs[0:8] = command[:8] |
|
|
single_obs[8:11] = omega_scaled |
|
|
single_obs[8:11] = omega_scaled |
|
|
single_obs[11:14] = gravity_orientation |
|
|
single_obs[11:14] = gravity_orientation |
|
|
single_obs[14:17] = 0.#torso_omega_scaled |
|
|
|
|
|
single_obs[17:20] = 0.#torso_gravity_orientation |
|
|
|
|
|
single_obs[20:20+n_joints] = qj_scaled |
|
|
|
|
|
single_obs[20+n_joints:20+2*n_joints] = dqj_scaled |
|
|
|
|
|
single_obs[20+2*n_joints:20+2*n_joints+15] = action |
|
|
|
|
|
single_obs[20+2*n_joints+15:20+2*n_joints+15+2] = self.clock_inputs.cpu().numpy().reshape(2) |
|
|
|
|
|
|
|
|
single_obs[14:17] = 0.0 # torso_omega_scaled |
|
|
|
|
|
single_obs[17:20] = 0.0 # torso_gravity_orientation |
|
|
|
|
|
single_obs[20 : 20 + n_joints] = qj_scaled |
|
|
|
|
|
single_obs[20 + n_joints : 20 + 2 * n_joints] = dqj_scaled |
|
|
|
|
|
single_obs[20 + 2 * n_joints : 20 + 2 * n_joints + 15] = action |
|
|
|
|
|
single_obs[20 + 2 * n_joints + 15 : 20 + 2 * n_joints + 15 + 2] = ( |
|
|
|
|
|
self.clock_inputs.cpu().numpy().reshape(2) |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
return single_obs, single_obs_dim |
|
|
return single_obs, single_obs_dim |
|
|
|
|
|
|
|
|
def load_onnx_policy(self, path):
    """Load an ONNX policy and wrap it as a tensor-in / tensor-out callable.

    Args:
        path: Path of the ONNX model file passed to ``ort.InferenceSession``.

    Returns:
        A function that takes a torch tensor, feeds it to the session's
        first input and returns the session's first output as a torch
        tensor.
    """
    model = ort.InferenceSession(path)
    # The input name does not change between calls, so look it up once
    # rather than on every inference step.
    input_name = model.get_inputs()[0].name
    # A hard-coded "cuda:0" raises on CPU-only hosts; keep using the GPU
    # when one is present and fall back to the CPU otherwise.
    device = "cuda:0" if torch.cuda.is_available() else "cpu"

    def run_inference(input_tensor):
        ort_inputs = {input_name: input_tensor.cpu().numpy()}
        ort_outs = model.run(None, ort_inputs)
        return torch.tensor(ort_outs[0], device=device)

    return run_inference
|
|
|
|
|
|
|
|
def run(self): |
|
|
def run(self): |
|
|
@ -252,52 +254,58 @@ class GearWbcController: |
|
|
|
|
|
|
|
|
with mujoco.viewer.launch_passive(self.model, self.data) as viewer: |
|
|
with mujoco.viewer.launch_passive(self.model, self.data) as viewer: |
|
|
start = time.time() |
|
|
start = time.time() |
|
|
while viewer.is_running() and time.time() - start < self.config['simulation_duration']: |
|
|
|
|
|
|
|
|
while viewer.is_running() and time.time() - start < self.config["simulation_duration"]: |
|
|
step_start = time.time() |
|
|
step_start = time.time() |
|
|
|
|
|
|
|
|
leg_tau = self.pd_control( |
|
|
leg_tau = self.pd_control( |
|
|
self.target_dof_pos, |
|
|
self.target_dof_pos, |
|
|
self.data.qpos[7:7+self.config['num_actions']], |
|
|
|
|
|
self.config['kps'], |
|
|
|
|
|
np.zeros_like(self.config['kps']), |
|
|
|
|
|
self.data.qvel[6:6+self.config['num_actions']], |
|
|
|
|
|
self.config['kds'] |
|
|
|
|
|
|
|
|
self.data.qpos[7 : 7 + self.config["num_actions"]], |
|
|
|
|
|
self.config["kps"], |
|
|
|
|
|
np.zeros_like(self.config["kps"]), |
|
|
|
|
|
self.data.qvel[6 : 6 + self.config["num_actions"]], |
|
|
|
|
|
self.config["kds"], |
|
|
) |
|
|
) |
|
|
self.data.ctrl[:self.config['num_actions']] = leg_tau |
|
|
|
|
|
|
|
|
self.data.ctrl[: self.config["num_actions"]] = leg_tau |
|
|
|
|
|
|
|
|
if self.n_joints > self.config['num_actions']: |
|
|
|
|
|
|
|
|
if self.n_joints > self.config["num_actions"]: |
|
|
arm_tau = self.pd_control( |
|
|
arm_tau = self.pd_control( |
|
|
np.zeros(self.n_joints - self.config['num_actions'], dtype=np.float32), |
|
|
|
|
|
self.data.qpos[7+self.config['num_actions']:7+self.n_joints], |
|
|
|
|
|
np.full(self.n_joints - self.config['num_actions'], 100.0), |
|
|
|
|
|
np.zeros(self.n_joints - self.config['num_actions']), |
|
|
|
|
|
self.data.qvel[6+self.config['num_actions']:6+self.n_joints], |
|
|
|
|
|
np.full(self.n_joints - self.config['num_actions'], 0.5) |
|
|
|
|
|
|
|
|
np.zeros(self.n_joints - self.config["num_actions"], dtype=np.float32), |
|
|
|
|
|
self.data.qpos[7 + self.config["num_actions"] : 7 + self.n_joints], |
|
|
|
|
|
np.full(self.n_joints - self.config["num_actions"], 100.0), |
|
|
|
|
|
np.zeros(self.n_joints - self.config["num_actions"]), |
|
|
|
|
|
self.data.qvel[6 + self.config["num_actions"] : 6 + self.n_joints], |
|
|
|
|
|
np.full(self.n_joints - self.config["num_actions"], 0.5), |
|
|
) |
|
|
) |
|
|
self.data.ctrl[self.config['num_actions']:] = arm_tau |
|
|
|
|
|
|
|
|
self.data.ctrl[self.config["num_actions"] :] = arm_tau |
|
|
|
|
|
|
|
|
mujoco.mj_step(self.model, self.data) |
|
|
mujoco.mj_step(self.model, self.data) |
|
|
|
|
|
|
|
|
self.counter += 1 |
|
|
self.counter += 1 |
|
|
if self.counter % self.config['control_decimation'] == 0: |
|
|
|
|
|
|
|
|
if self.counter % self.config["control_decimation"] == 0: |
|
|
with self.cmd_lock: |
|
|
with self.cmd_lock: |
|
|
current_cmd = self.control_dict |
|
|
current_cmd = self.control_dict |
|
|
|
|
|
|
|
|
single_obs, _ = self.compute_observation(self.data, self.config, self.action, current_cmd, self.n_joints) |
|
|
|
|
|
|
|
|
single_obs, _ = self.compute_observation( |
|
|
|
|
|
self.data, self.config, self.action, current_cmd, self.n_joints |
|
|
|
|
|
) |
|
|
self.obs_history.append(single_obs) |
|
|
self.obs_history.append(single_obs) |
|
|
|
|
|
|
|
|
for i, hist_obs in enumerate(self.obs_history): |
|
|
for i, hist_obs in enumerate(self.obs_history): |
|
|
self.obs[i * self.single_obs_dim:(i + 1) * self.single_obs_dim] = hist_obs |
|
|
|
|
|
|
|
|
self.obs[i * self.single_obs_dim : (i + 1) * self.single_obs_dim] = hist_obs |
|
|
|
|
|
|
|
|
obs_tensor = torch.from_numpy(self.obs).unsqueeze(0) |
|
|
obs_tensor = torch.from_numpy(self.obs).unsqueeze(0) |
|
|
self.action = self.policy(obs_tensor).cpu().detach().numpy().squeeze() |
|
|
self.action = self.policy(obs_tensor).cpu().detach().numpy().squeeze() |
|
|
self.target_dof_pos = self.action * self.config['action_scale'] + self.config['default_angles'] |
|
|
|
|
|
|
|
|
self.target_dof_pos = ( |
|
|
|
|
|
self.action * self.config["action_scale"] + self.config["default_angles"] |
|
|
|
|
|
) |
|
|
|
|
|
|
|
|
viewer.sync() |
|
|
viewer.sync() |
|
|
# time.sleep(max(0, self.model.opt.timestep - (time.time() - step_start))) |
|
|
# time.sleep(max(0, self.model.opt.timestep - (time.time() - step_start))) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Robot resources live two directory levels above this script, under
    # resources/robots/g1. CONFIG_PATH stays a module-level name because
    # load_config reads it as a global.
    _parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    CONFIG_PATH = os.path.join(_parent_dir, "resources", "robots", "g1")

    # Build the controller and enter the simulation loop.
    controller = GearWbcController(CONFIG_PATH)
    controller.run()