# NOTE(review): the lines that were here ("You can not select more than 25
# topics...", "150 lines / 5.2 KiB") were repository web-UI chrome captured
# during extraction — they are not Python source and are commented out here.
from isaacgym import gymapi
|
|
from isaacgym import gymutil
|
|
|
|
import math
|
|
import numpy as np
|
|
import matplotlib.pyplot as plt
|
|
|
|
from replay_demo import Player
|
|
|
|
from pathlib import Path
|
|
import h5py
|
|
from tqdm import tqdm
|
|
import time
|
|
import yaml
|
|
import pickle
|
|
import torch
|
|
import cv2
|
|
from collections import deque
|
|
import argparse
|
|
import sys
|
|
sys.path.append("../")
|
|
from act.utils import parse_id
|
|
# from act.imitate_episodes import RECORD_DIR, DATA_DIR, LOG_DIR
|
|
|
|
from pathlib import Path
|
|
current_dir = Path(__file__).parent.resolve()
|
|
DATA_DIR = (current_dir.parent / 'data/').resolve()
|
|
RECORD_DIR = (DATA_DIR / 'recordings/').resolve()
|
|
LOG_DIR = (DATA_DIR / 'logs/').resolve()
|
|
# print(f"\nDATA dir: {DATA_DIR}")
|
|
|
|
def get_norm_stats(data_path):
|
|
# norm_stats = {
|
|
# "action_mean": np.array([]), "action_std": np.array([]),
|
|
# "qpos_mean": np.array([]), "qpos_std": np.array([]),
|
|
# }
|
|
with open(data_path, "rb") as f:
|
|
norm_stats = pickle.load(f)
|
|
return norm_stats
|
|
|
|
|
|
def load_policy(policy_path, device):
|
|
policy = torch.jit.load(policy_path, map_location=device)
|
|
return policy
|
|
|
|
|
|
def normalize_input(state, left_img, right_img, norm_stats, last_action_data=None):
|
|
# import ipdb; ipdb.set_trace()
|
|
# left_img = cv2.resize(left_img, (308, 224))
|
|
# right_img = cv2.resize(right_img, (308, 224))
|
|
image_data = torch.from_numpy(np.stack([left_img, right_img], axis=0)) / 255.0
|
|
qpos_data = (torch.from_numpy(state) - norm_stats["qpos_mean"]) / norm_stats["qpos_std"]
|
|
image_data = image_data.view((1, 2, 3, 480, 640)).to(device='cuda')
|
|
qpos_data = qpos_data.view((1, 26)).to(device='cuda')
|
|
|
|
if last_action_data is not None:
|
|
last_action_data = torch.from_numpy(last_action_data).to(device='cuda').view((1, -1)).to(torch.float)
|
|
qpos_data = torch.cat((qpos_data, last_action_data), dim=1)
|
|
return (qpos_data, image_data)
|
|
|
|
|
|
def merge_act(actions_for_curr_step, k = 0.01):
|
|
actions_populated = np.all(actions_for_curr_step != 0, axis=1)
|
|
actions_for_curr_step = actions_for_curr_step[actions_populated]
|
|
|
|
exp_weights = np.exp(-k * np.arange(actions_for_curr_step.shape[0]))
|
|
exp_weights = (exp_weights / exp_weights.sum()).reshape((-1, 1))
|
|
raw_action = (actions_for_curr_step * exp_weights).sum(axis=0)
|
|
|
|
return raw_action
|
|
|
|
|
|
if __name__ == '__main__':
|
|
parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
|
|
parser.add_argument('--taskid', action='store', type=str, help='task id', required=True)
|
|
parser.add_argument('--exptid', action='store', type=str, help='experiment id', required=True)
|
|
parser.add_argument('--resume_ckpt', action='store', type=str, help='resume checkpoint', required=True)
|
|
args = vars(parser.parse_args())
|
|
|
|
episode_name = "processed_episode_0.hdf5"
|
|
task_dir, task_name = parse_id(RECORD_DIR, args['taskid'])
|
|
episode_path = (Path(task_dir) / 'processed' / episode_name).resolve()
|
|
exp_path, _ = parse_id((Path(LOG_DIR) / task_name).resolve(), args['exptid'])
|
|
|
|
norm_stat_path = Path(exp_path) / "dataset_stats.pkl"
|
|
policy_path = Path(exp_path) / f"traced_jit_{args['resume_ckpt']}.pt"
|
|
|
|
temporal_agg = True
|
|
action_dim = 28
|
|
|
|
chunk_size = 60
|
|
device = "cuda"
|
|
|
|
data = h5py.File(str(episode_path), 'r')
|
|
actions = np.array(data['qpos_action'])
|
|
left_imgs = np.array(data['observation.image.left'])
|
|
right_imgs = np.array(data['observation.image.right'])
|
|
states = np.array(data['observation.state'])
|
|
init_action = np.array(data.attrs['init_action'])
|
|
data.close()
|
|
timestamps = states.shape[0]
|
|
|
|
norm_stats = get_norm_stats(norm_stat_path)
|
|
policy = load_policy(policy_path, device)
|
|
policy.cuda()
|
|
policy.eval()
|
|
|
|
history_stack = 0
|
|
if history_stack > 0:
|
|
last_action_queue = deque(maxlen=history_stack)
|
|
for i in range(history_stack):
|
|
last_action_queue.append(actions[0])
|
|
else:
|
|
last_action_queue = None
|
|
last_action_data = None
|
|
player = Player(dt=1/30)
|
|
|
|
if temporal_agg:
|
|
all_time_actions = np.zeros([timestamps, timestamps+chunk_size, action_dim])
|
|
else:
|
|
num_actions_exe = chunk_size
|
|
|
|
try:
|
|
output = None
|
|
act_index = 0
|
|
for t in tqdm(range(timestamps)):
|
|
if history_stack > 0:
|
|
last_action_data = np.array(last_action_queue)
|
|
|
|
data = normalize_input(states[t], left_imgs[t], right_imgs[t], norm_stats, last_action_data)
|
|
|
|
if temporal_agg:
|
|
output = policy(*data)[0].detach().cpu().numpy() # (1,chuck_size,action_dim)
|
|
all_time_actions[[t], t:t+chunk_size] = output
|
|
act = merge_act(all_time_actions[:, t])
|
|
else:
|
|
if output is None or act_index == num_actions_exe-1:
|
|
print("Inference...")
|
|
output = policy(*data)[0].detach().cpu().numpy()
|
|
act_index = 0
|
|
act = output[act_index]
|
|
act_index += 1
|
|
# import ipdb; ipdb.set_trace()
|
|
if history_stack > 0:
|
|
last_action_queue.append(act)
|
|
act = act * norm_stats["action_std"] + norm_stats["action_mean"]
|
|
player.step(act, left_imgs[t], right_imgs[t])
|
|
except KeyboardInterrupt:
|
|
player.end()
|
|
exit()
|