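# Replay a recorded episode through a trained, traced ACT-style policy:
# observations are read from an HDF5 episode file, the policy predicts action
# chunks, and the resulting actions are played back with replay_demo.Player.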
from isaacgym import gymapi
from isaacgym import gymutil
import math
import numpy as np
import matplotlib.pyplot as plt
from replay_demo import Player
from pathlib import Path
import h5py
from tqdm import tqdm
import time
import yaml
import pickle
import torch
import cv2
from collections import deque
import argparse
import sys
sys.path.append("../")
from act.utils import parse_id
# from act.imitate_episodes import RECORD_DIR, DATA_DIR, LOG_DIR
current_dir = Path(__file__).parent.resolve()
DATA_DIR = (current_dir.parent / 'data/').resolve()
RECORD_DIR = (DATA_DIR / 'recordings/').resolve()
LOG_DIR = (DATA_DIR / 'logs/').resolve()
# print(f"\nDATA dir: {DATA_DIR}")
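

# get_norm_stats loads the dataset statistics pickled at training time
# (action_mean, action_std, qpos_mean, qpos_std); they are used below to
# normalize observations and to un-normalize predicted actions.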
def get_norm_stats(data_path):
    with open(data_path, "rb") as f:
        norm_stats = pickle.load(f)
    return norm_stats


def load_policy(policy_path, device):
    # Load the TorchScript policy traced/saved after training.
    policy = torch.jit.load(policy_path, map_location=device)
    return policy
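

# normalize_input packs one observation into the (qpos, images) tuple the traced
# policy expects: the stereo pair is stacked and scaled to [0, 1], and the 26-dim
# robot state is normalized with the training statistics; an optional action
# history is concatenated onto the state vector.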
def normalize_input(state, left_img, right_img, norm_stats, last_action_data=None):
    # left_img = cv2.resize(left_img, (308, 224))
    # right_img = cv2.resize(right_img, (308, 224))
    image_data = torch.from_numpy(np.stack([left_img, right_img], axis=0)) / 255.0
    qpos_data = (torch.from_numpy(state) - norm_stats["qpos_mean"]) / norm_stats["qpos_std"]
    image_data = image_data.view((1, 2, 3, 480, 640)).to(device='cuda')
    qpos_data = qpos_data.view((1, 26)).to(device='cuda')
    if last_action_data is not None:
        last_action_data = torch.from_numpy(last_action_data).to(device='cuda').view((1, -1)).to(torch.float)
        qpos_data = torch.cat((qpos_data, last_action_data), dim=1)
    return (qpos_data, image_data)
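

# merge_act implements temporal ensembling: it averages every chunk prediction
# that covers the current timestep, weighted by exp(-k * i) over prediction age,
# so the oldest prediction gets the largest weight (k=0.01 keeps the weighting
# close to uniform).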
def merge_act(actions_for_curr_step, k=0.01):
    actions_populated = np.all(actions_for_curr_step != 0, axis=1)
    actions_for_curr_step = actions_for_curr_step[actions_populated]
    exp_weights = np.exp(-k * np.arange(actions_for_curr_step.shape[0]))
    exp_weights = (exp_weights / exp_weights.sum()).reshape((-1, 1))
    raw_action = (actions_for_curr_step * exp_weights).sum(axis=0)
    return raw_action
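

# Entry point: replay one processed episode through the trained policy and step
# the Player with the predicted actions. Example invocation (script name and IDs
# are illustrative):
#   python replay_policy.py --taskid 00 --exptid 01 --resume_ckpt 25000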
if __name__ == '__main__':
    parser = argparse.ArgumentParser('Set transformer detector', add_help=False)
    parser.add_argument('--taskid', action='store', type=str, help='task id', required=True)
    parser.add_argument('--exptid', action='store', type=str, help='experiment id', required=True)
    parser.add_argument('--resume_ckpt', action='store', type=str, help='resume checkpoint', required=True)
    args = vars(parser.parse_args())

    # Resolve the recorded episode, normalization stats, and traced-policy paths
    # from the task and experiment IDs.
    episode_name = "processed_episode_0.hdf5"
    task_dir, task_name = parse_id(RECORD_DIR, args['taskid'])
    episode_path = (Path(task_dir) / 'processed' / episode_name).resolve()
    exp_path, _ = parse_id((Path(LOG_DIR) / task_name).resolve(), args['exptid'])
    norm_stat_path = Path(exp_path) / "dataset_stats.pkl"
    policy_path = Path(exp_path) / f"traced_jit_{args['resume_ckpt']}.pt"

    temporal_agg = True
    action_dim = 28
    chunk_size = 60
    device = "cuda"
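
    # Load the recorded episode: per-step actions, stereo camera images, and
    # robot proprioceptive states.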
    data = h5py.File(str(episode_path), 'r')
    actions = np.array(data['qpos_action'])
    left_imgs = np.array(data['observation.image.left'])
    right_imgs = np.array(data['observation.image.right'])
    states = np.array(data['observation.state'])
    init_action = np.array(data.attrs['init_action'])
    data.close()
    timestamps = states.shape[0]

    norm_stats = get_norm_stats(norm_stat_path)
    policy = load_policy(policy_path, device)
    policy.cuda()
    policy.eval()
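
    # Optional action-history stacking; disabled here (history_stack = 0), so the
    # policy only conditions on the current state and images.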
    history_stack = 0
    if history_stack > 0:
        last_action_queue = deque(maxlen=history_stack)
        for i in range(history_stack):
            last_action_queue.append(actions[0])
    else:
        last_action_queue = None
        last_action_data = None

    player = Player(dt=1/30)

    if temporal_agg:
        all_time_actions = np.zeros([timestamps, timestamps + chunk_size, action_dim])
    else:
        num_actions_exe = chunk_size
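
    # Replay loop: run the policy on each recorded observation and step the
    # Player at 30 Hz. With temporal_agg, every predicted chunk is written into
    # all_time_actions and merge_act() blends all predictions covering the
    # current step; otherwise a new chunk is queried only once the current one
    # is nearly exhausted.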
    try:
        output = None
        act_index = 0
        for t in tqdm(range(timestamps)):
            if history_stack > 0:
                last_action_data = np.array(last_action_queue)
            data = normalize_input(states[t], left_imgs[t], right_imgs[t], norm_stats, last_action_data)
            if temporal_agg:
                output = policy(*data)[0].detach().cpu().numpy()  # (1, chunk_size, action_dim)
                all_time_actions[[t], t:t + chunk_size] = output
                act = merge_act(all_time_actions[:, t])
            else:
                if output is None or act_index == num_actions_exe - 1:
                    print("Inference...")
                    output = policy(*data)[0].detach().cpu().numpy()
                    act_index = 0
                act = output[act_index]
                act_index += 1
            if history_stack > 0:
                last_action_queue.append(act)
            # Un-normalize the predicted action before sending it to the player.
            act = act * norm_stats["action_std"] + norm_stats["action_mean"]
            player.step(act, left_imgs[t], right_imgs[t])
    except KeyboardInterrupt:
        player.end()
        exit()