You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 
 
 
 
 

307 lines
13 KiB

import sys
import time
import collections
import yaml
import torch
import numpy as np
import sys
import time
import collections
import yaml
import torch
import numpy as np
import mujoco
import mujoco.viewer
import onnxruntime as ort
import threading
from pynput import keyboard as pkb
import os
class GearWbcController:
def __init__(self, config_path):
self.CONFIG_PATH = config_path
self.cmd_lock = threading.Lock()
self.config = self.load_config(os.path.join(self.CONFIG_PATH, "g1_gear_wbc.yaml"))
self.control_dict = {
'loco_cmd': self.config['cmd_init'],
"height_cmd": self.config['height_cmd'],
"rpy_cmd": self.config.get('rpy_cmd', [0.0, 0.0, 0.0]),
"freq_cmd": self.config.get('freq_cmd', 1.5)
}
self.model = mujoco.MjModel.from_xml_path(self.config['xml_path'])
self.data = mujoco.MjData(self.model)
self.model.opt.timestep = self.config['simulation_dt']
self.n_joints = self.data.qpos.shape[0] - 7
self.torso_index = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, "torso_link")
self.base_index = mujoco.mj_name2id(self.model, mujoco.mjtObj.mjOBJ_BODY, "pelvis")
self.action = np.zeros(self.config['num_actions'], dtype=np.float32)
self.target_dof_pos = self.config['default_angles'].copy()
self.policy = self.load_onnx_policy(self.config['policy_path'])
self.walk_policy = self.load_onnx_policy(self.config['walk_policy_path'])
self.gait_indices = torch.zeros((1), dtype=torch.float32)
self.counter = 0
self.just_started = 0.0
self.walking_mask = False
self.frozen_FL = False
self.frozen_FR = False
self.single_obs, self.single_obs_dim = self.compute_observation(
self.data, self.config, self.action, self.control_dict, self.n_joints
)
self.obs_history = collections.deque(
[np.zeros(self.single_obs_dim, dtype=np.float32)] * self.config['obs_history_len'],
maxlen=self.config['obs_history_len']
)
self.obs = np.zeros(self.config['num_obs'], dtype=np.float32)
self.keyboard_listener(self.control_dict, self.config)
def keyboard_listener(self, control_dict, config):
"""Listen to key press events and update cmd and height_cmd"""
def on_press(key):
try:
k = key.char
except AttributeError:
return # Special keys ignored
with self.cmd_lock:
if k == 'w':
control_dict['loco_cmd'][0] += 0.1
elif k == 's':
control_dict['loco_cmd'][0] -= 0.1
elif k == 'a':
control_dict['loco_cmd'][1] += 0.1
elif k == 'd':
control_dict['loco_cmd'][1] -= 0.1
elif k == 'q':
control_dict['loco_cmd'][2] += 0.1
elif k == 'e':
control_dict['loco_cmd'][2] -= 0.1
elif k == 'z':
control_dict['loco_cmd'][:] = config['cmd_init']
control_dict["height_cmd"] = config['height_cmd']
control_dict['rpy_cmd'][:] = config['rpy_cmd']
control_dict['freq_cmd'] = config['freq_cmd']
elif k == '1':
control_dict["height_cmd"] += 0.05
elif k == '2':
control_dict["height_cmd"] -= 0.05
elif k == '3':
control_dict['rpy_cmd'][0] += 0.2
elif k == '4':
control_dict['rpy_cmd'][0] -= 0.2
elif k == '5':
control_dict['rpy_cmd'][1] += 0.2
elif k == '6':
control_dict['rpy_cmd'][1] -= 0.2
elif k == '7':
control_dict['rpy_cmd'][2] += 0.2
elif k == '8':
control_dict['rpy_cmd'][2] -= 0.2
elif k == 'm':
control_dict['freq_cmd'] += 0.1
elif k == 'n':
control_dict['freq_cmd'] -= 0.1
print(f"Current Commands: loco_cmd = {control_dict['loco_cmd']}, height_cmd = {control_dict['height_cmd']}, rpy_cmd = {control_dict['rpy_cmd']}, freq_cmd = {control_dict['freq_cmd']}")
listener = pkb.Listener(on_press=on_press)
listener.daemon = True
listener.start()
def load_config(self, config_path):
with open(config_path, 'r') as f:
config = yaml.safe_load(f)
for path_key in ['policy_path', 'xml_path', 'walk_policy_path']:
config[path_key] = os.path.join(CONFIG_PATH, config[path_key])
array_keys = ['kps', 'kds', 'default_angles', 'cmd_scale', 'cmd_init']
for key in array_keys:
config[key] = np.array(config[key], dtype=np.float32)
return config
def pd_control(self, target_q, q, kp, target_dq, dq, kd):
return (target_q - q) * kp + (target_dq - dq) * kd
def quat_rotate_inverse(self, q, v):
w, x, y, z = q
q_conj = np.array([w, -x, -y, -z])
return np.array([
v[0] * (q_conj[0]**2 + q_conj[1]**2 - q_conj[2]**2 - q_conj[3]**2) +
v[1] * 2 * (q_conj[1]*q_conj[2] - q_conj[0]*q_conj[3]) +
v[2] * 2 * (q_conj[1]*q_conj[3] + q_conj[0]*q_conj[2]),
v[0] * 2 * (q_conj[1]*q_conj[2] + q_conj[0]*q_conj[3]) +
v[1] * (q_conj[0]**2 - q_conj[1]**2 + q_conj[2]**2 - q_conj[3]**2) +
v[2] * 2 * (q_conj[2]*q_conj[3] - q_conj[0]*q_conj[1]),
v[0] * 2 * (q_conj[1]*q_conj[3] - q_conj[0]*q_conj[2]) +
v[1] * 2 * (q_conj[2]*q_conj[3] + q_conj[0]*q_conj[1]) +
v[2] * (q_conj[0]**2 - q_conj[1]**2 - q_conj[2]**2 + q_conj[3]**2)
])
def get_gravity_orientation(self, quat):
gravity_vec = np.array([0.0, 0.0, -1.0])
return self.quat_rotate_inverse(quat, gravity_vec)
def compute_observation(self, d, config, action, control_dict, n_joints):
command = np.zeros(7, dtype=np.float32)
command[:3] = control_dict['loco_cmd'][:3] * config['cmd_scale']
command[3] = control_dict['height_cmd']
# command[4] = control_dict['freq_cmd']
command[4:7] = control_dict['rpy_cmd']
# # gait indice
# is_static = np.linalg.norm(command[:3]) < 0.1
# just_entered_walk = (not is_static) and (not self.walking_mask)
# self.walking_mask = not is_static
# if just_entered_walk:
# self.just_started = 0.0
# self.gait_indices = torch.tensor([-0.25])
# if not is_static:
# self.just_started += 0.02
# else:
# self.just_started = 0.0
# if not is_static:
# self.frozen_FL = False
# self.frozen_FR = False
# self.gait_indices = torch.remainder(self.gait_indices + 0.02 * command[4], 1.0)
# # Parameters
# duration = 0.5
# phase = 0.5
# # Gait indices
# gait_FR = self.gait_indices.clone()
# gait_FL = torch.remainder(gait_FR + phase, 1.0)
# if self.just_started < (0.5 / command[4]):
# gait_FR = torch.tensor([0.25])
# gait_pair = [gait_FL.clone(), gait_FR.clone()]
# for i, fi in enumerate(gait_pair):
# if fi.item() < duration:
# gait_pair[i] = fi * (0.5 / duration)
# else:
# gait_pair[i] = 0.5 + (fi - duration) * (0.5 / (1 - duration))
# # Clock signal
# clock = [torch.sin(2 * np.pi * fi) for fi in gait_pair]
# for i, (clk, frozen_mask_attr) in enumerate(
# zip(clock, ['frozen_FL', 'frozen_FR'])
# ):
# frozen_mask = getattr(self, frozen_mask_attr)
# # Freeze condition: static and at sin peak
# if is_static and (not frozen_mask) and clk.item() > 0.98:
# setattr(self, frozen_mask_attr, True)
# clk = torch.tensor([1.0])
# if getattr(self, frozen_mask_attr):
# clk = torch.tensor([1.0])
# clock[i] = clk
# self.clock_inputs = torch.stack(clock).unsqueeze(0)
qj = d.qpos[7:7+n_joints].copy()
dqj = d.qvel[6:6+n_joints].copy()
quat = d.qpos[3:7].copy()
omega = d.qvel[3:6].copy()
# omega = self.data.xmat[self.base_index].reshape(3, 3).T @ self.data.cvel[self.base_index][3:6]
padded_defaults = np.zeros(n_joints, dtype=np.float32)
L = min(len(config['default_angles']), n_joints)
padded_defaults[:L] = config['default_angles'][:L]
qj_scaled = (qj - padded_defaults) * config['dof_pos_scale']
dqj_scaled = dqj * config['dof_vel_scale']
gravity_orientation = self.get_gravity_orientation(quat)
omega_scaled = omega * config['ang_vel_scale']
torso_quat = self.data.xquat[self.torso_index]
torso_omega = self.data.xmat[self.torso_index].reshape(3, 3).T @ self.data.cvel[self.torso_index][3:6]
torso_omega_scaled = torso_omega * config['ang_vel_scale']
torso_gravity_orientation = self.get_gravity_orientation(torso_quat)
single_obs_dim = 86
single_obs = np.zeros(single_obs_dim, dtype=np.float32)
single_obs[0:7] = command[:7]
single_obs[7:10] = omega_scaled
single_obs[10:13] = gravity_orientation
# single_obs[14:17] = 0.#torso_omega_scaled
# single_obs[17:20] = 0.#torso_gravity_orientation
single_obs[13:13+n_joints] = qj_scaled
single_obs[13+n_joints:13+2*n_joints] = dqj_scaled
single_obs[13+2*n_joints:13+2*n_joints+15] = action
# single_obs[20+2*n_joints+15:20+2*n_joints+15+2] = self.clock_inputs.cpu().numpy().reshape(2)
return single_obs, single_obs_dim
def load_onnx_policy(self, path):
model = ort.InferenceSession(path)
def run_inference(input_tensor):
ort_inputs = {model.get_inputs()[0].name: input_tensor.cpu().numpy()}
ort_outs = model.run(None, ort_inputs)
return torch.tensor(ort_outs[0], device="cuda:0")
return run_inference
def run(self):
self.counter = 0
with mujoco.viewer.launch_passive(self.model, self.data) as viewer:
start = time.time()
while viewer.is_running() and time.time() - start < self.config['simulation_duration']:
step_start = time.time()
leg_tau = self.pd_control(
self.target_dof_pos,
self.data.qpos[7:7+self.config['num_actions']],
self.config['kps'],
np.zeros_like(self.config['kps']),
self.data.qvel[6:6+self.config['num_actions']],
self.config['kds']
)
self.data.ctrl[:self.config['num_actions']] = leg_tau
if self.n_joints > self.config['num_actions']:
arm_tau = self.pd_control(
np.zeros(self.n_joints - self.config['num_actions'], dtype=np.float32),
self.data.qpos[7+self.config['num_actions']:7+self.n_joints],
np.full(self.n_joints - self.config['num_actions'], 100.0),
np.zeros(self.n_joints - self.config['num_actions']),
self.data.qvel[6+self.config['num_actions']:6+self.n_joints],
np.full(self.n_joints - self.config['num_actions'], 0.5)
)
self.data.ctrl[self.config['num_actions']:] = arm_tau
mujoco.mj_step(self.model, self.data)
self.counter += 1
if self.counter % self.config['control_decimation'] == 0:
with self.cmd_lock:
current_cmd = self.control_dict
single_obs, _ = self.compute_observation(self.data, self.config, self.action, current_cmd, self.n_joints)
self.obs_history.append(single_obs)
for i, hist_obs in enumerate(self.obs_history):
self.obs[i * self.single_obs_dim:(i + 1) * self.single_obs_dim] = hist_obs
obs_tensor = torch.from_numpy(self.obs).unsqueeze(0)
if (np.linalg.norm(np.array(current_cmd['loco_cmd']))) <= 0.05:
self.action = self.policy(obs_tensor).cpu().detach().numpy().squeeze()
else:
self.action = self.walk_policy(obs_tensor).cpu().detach().numpy().squeeze()
self.target_dof_pos = self.action * self.config['action_scale'] + self.config['default_angles']
viewer.sync()
# time.sleep(max(0, self.model.opt.timestep - (time.time() - step_start)))
if __name__ == "__main__":
CONFIG_PATH = os.path.join(os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "resources", "robots", "g1")
controller = GearWbcController(CONFIG_PATH)
controller.run()