You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
367 lines
14 KiB
367 lines
14 KiB
import torch
|
|
import numpy as np
|
|
import os
|
|
import pickle
|
|
import argparse
|
|
import matplotlib.pyplot as plt
|
|
from copy import deepcopy
|
|
from tqdm import tqdm
|
|
from einops import rearrange
|
|
|
|
# from .constants import DT
|
|
# from .constants import PUPPET_GRIPPER_JOINT_OPEN
|
|
from utils import load_data # data functions
|
|
from utils import compute_dict_mean, set_seed, detach_dict, parse_id, find_all_ckpt # helper functions
|
|
from policy import ACTPolicy, CNNMLPPolicy
|
|
# from .visualize_episodes import save_videos
|
|
import wandb
|
|
|
|
# from sim_env import BOX_POSE
|
|
# from constants import SIM_TASK_CONFIGS
|
|
import IPython
|
|
e = IPython.embed  # debugging alias for IPython.embed: call e() to open an interactive shell (not referenced in the visible code)
|
|
import time
|
|
from itertools import repeat
|
|
|
|
def repeater(data_loader):
    """Yield batches from ``data_loader`` endlessly, restarting it after every full pass.

    After each complete pass a line ``Epoch {n} done`` is printed (``n`` counts passes from 0).

    Args:
        data_loader: any re-iterable source of batches (e.g. a torch DataLoader).

    Yields:
        The batches of ``data_loader``, cycling forever.

    Raises:
        ValueError: if a pass over ``data_loader`` produces no batches. Without this check the
            generator would loop forever without yielding, hanging the caller's ``next()``.
    """
    epoch = 0
    while True:
        produced_any = False
        for data in data_loader:
            produced_any = True
            yield data
        if not produced_any:
            raise ValueError('data_loader produced no batches; cannot repeat an empty loader')
        print(f'Epoch {epoch} done')
        epoch += 1
|
|
|
|
from pathlib import Path
|
|
|
|
|
|
def main(args):
    """Train (or trace) a policy according to the parsed command-line options.

    The dataset directory is resolved from ``args['taskid']`` and the checkpoint directory
    from ``args['exptid']``, using the module-level ``RECORD_DIR`` / ``LOG_DIR`` globals
    (these are only defined when the file runs as a script). The function then builds the
    policy config, initialises wandb, loads the data and pickles the dataset statistics into
    the checkpoint directory. With ``args['save_jit']`` it traces the policy via ``save_jit``
    and returns. Otherwise it trains via ``train_bc`` and saves the best weights as
    ``policy_best.ckpt``.

    Args:
        args: dict of command-line options (``vars()`` of the argparse namespace).

    Raises:
        NotImplementedError: if ``args['policy_class']`` is neither 'ACT' nor 'CNNMLP'.
    """
    # NOTE(review): seed is fixed to 1 here (it governs data loading); train_bc re-seeds
    # with args['seed'] before building the policy.
    set_seed(1)

    # command line parameters
    policy_class = args['policy_class']
    onscreen_render = args['onscreen_render']
    batch_size_train = args['batch_size']
    batch_size_val = args['batch_size']
    num_epochs = args['num_epochs']

    # get task parameters
    task_dir, task_name = parse_id(RECORD_DIR, args['taskid'])
    dataset_dir = (Path(task_dir) / 'processed').resolve()
    ckpt_dir = (LOG_DIR / task_name / args['exptid']).resolve()
    print("*" * 20)
    print(f"Task name: {task_name}")
    print("*" * 20)

    camera_names = ['left', 'right']

    # fixed parameters
    state_dim = 26
    action_dim = 28
    lr_backbone = 1e-5
    backbone = 'dino_v2'
    if policy_class == 'ACT':
        enc_layers = 4
        dec_layers = 7
        nheads = 8
        policy_config = {'lr': args['lr'],
                         'num_queries': args['chunk_size'],
                         'kl_weight': args['kl_weight'],
                         'hidden_dim': args['hidden_dim'],
                         'dim_feedforward': args['dim_feedforward'],
                         'lr_backbone': lr_backbone,
                         'backbone': backbone,
                         'enc_layers': enc_layers,
                         'dec_layers': dec_layers,
                         'nheads': nheads,
                         'camera_names': camera_names,
                         'state_dim': state_dim,
                         'action_dim': action_dim,
                         'qpos_noise_std': args['qpos_noise_std'],
                         }
    elif policy_class == 'CNNMLP':
        policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone': backbone, 'num_queries': 1,
                         'camera_names': camera_names}
    else:
        raise NotImplementedError

    config = {
        'num_epochs': num_epochs,
        'ckpt_dir': ckpt_dir,
        'state_dim': state_dim,
        'action_dim': action_dim,
        'lr': args['lr'],
        'policy_class': policy_class,
        'onscreen_render': onscreen_render,
        'policy_config': policy_config,
        'seed': args['seed'],
        'temporal_agg': args['temporal_agg'],
        'camera_names': camera_names,
        'resumeid': args['resumeid'],
        'resume_ckpt': args['resume_ckpt'],
        'task_name': task_name,
        'exptid': args['exptid'],
    }

    # Logging is disabled for JIT export or when explicitly requested.
    # NOTE(review): dir="../data/logs" is relative to the current working directory.
    mode = "disabled" if args["no_wandb"] or args["save_jit"] else "online"
    wandb.init(project="television", name=args['exptid'], group=task_name, entity="cxx", mode=mode, dir="../data/logs")
    wandb.config.update(config)
    train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, camera_names, batch_size_train, batch_size_val)

    # save dataset stats (exist_ok avoids the check-then-create race of isdir + makedirs)
    os.makedirs(ckpt_dir, exist_ok=True)
    stats_path = os.path.join(ckpt_dir, 'dataset_stats.pkl')
    with open(stats_path, 'wb') as f:
        pickle.dump(stats, f)

    if args['save_jit']:
        save_jit(config)
        return

    best_epoch, min_val_loss, best_state_dict = train_bc(train_dataloader, val_dataloader, config)

    # save best checkpoint
    ckpt_path = os.path.join(ckpt_dir, 'policy_best.ckpt')
    torch.save(best_state_dict, ckpt_path)
    print(f'Best ckpt, val loss {min_val_loss:.6f} @ epoch {best_epoch}')
    wandb.finish()
|
|
|
|
def make_policy(policy_class, policy_config):
    """Instantiate the policy named by ``policy_class`` with ``policy_config``.

    'ACT' builds an ``ACTPolicy`` and 'CNNMLP' builds a ``CNNMLPPolicy``; any other name
    raises ``NotImplementedError``.
    """
    if policy_class == 'ACT':
        return ACTPolicy(policy_config)
    if policy_class == 'CNNMLP':
        return CNNMLPPolicy(policy_config)
    raise NotImplementedError
|
|
|
|
|
|
def make_optimizer(policy_class, policy):
    """Return the optimizer that ``policy`` configures for itself.

    Both supported classes ('ACT' and 'CNNMLP') delegate to ``policy.configure_optimizers()``.
    Any other class name raises ``NotImplementedError``.
    """
    supported_classes = ('ACT', 'CNNMLP')
    if policy_class not in supported_classes:
        raise NotImplementedError
    return policy.configure_optimizers()
|
|
|
|
|
|
def get_image(ts, camera_names):
    """Collect the current frame of each camera in ``ts`` into one CUDA tensor.

    Each frame ``ts.observation['images'][name]`` is rearranged from ``h w c`` to ``c h w``.
    The frames are stacked along a new camera axis and scaled by 1/255. A leading batch
    dimension of 1 is added, and the result is a float32 tensor on the GPU.
    """
    chw_frames = [
        rearrange(ts.observation['images'][name], 'h w c -> c h w')
        for name in camera_names
    ]
    stacked = np.stack(chw_frames, axis=0)
    scaled = stacked / 255.0
    return torch.from_numpy(scaled).float().cuda().unsqueeze(0)
|
|
|
|
|
|
def forward_pass(data, policy):
    """Move one ``(image, qpos, action, is_pad)`` batch to the GPU and run ``policy`` on it.

    The policy is called as ``policy(qpos, image, action, is_pad)``, and whatever it returns
    is passed back unchanged (train_bc reads a ``'loss'`` entry from it).
    """
    images, qpos, actions, pad_mask = data
    images = images.cuda()
    qpos = qpos.cuda()
    actions = actions.cuda()
    pad_mask = pad_mask.cuda()
    return policy(qpos, images, actions, pad_mask)
|
|
|
|
|
|
|
|
def _summary_string(summary):
    """Format ``summary`` as space-separated ``'key: value '`` pairs with 3 decimals (values need ``.item()``)."""
    return ''.join(f'{k}: {v.item():.3f} ' for k, v in summary.items())


def _validate(policy, val_dataloader, max_batches=22):
    """Run ``policy`` in eval/inference mode on up to ``max_batches`` validation batches.

    Returns ``compute_dict_mean`` of the per-batch forward dicts. The caller must put the
    policy back into train mode afterwards.
    """
    with torch.inference_mode():
        policy.eval()
        validation_dicts = []
        for batch_idx, data in enumerate(val_dataloader):
            validation_dicts.append(forward_pass(data, policy))
            if batch_idx + 1 >= max_batches:
                break
        return compute_dict_mean(validation_dicts)


def train_bc(train_dataloader, val_dataloader, config):
    """Behaviour-cloning training loop.

    Each "epoch" is a single optimizer step on one batch drawn from an endlessly repeating
    ``train_dataloader``.

    Validation runs every ``config.get('val_interval', 500)`` epochs, before that epoch's
    training step, on up to 22 batches. When the mean validation loss improves, a copy of
    the weights is kept as the best checkpoint.

    The current weights are saved every ``config.get('save_interval', 1000)`` epochs, after
    the training step; epoch 0 is skipped. At the end, the last weights are written to
    ``policy_last.ckpt`` and the best weights to ``policy_epoch_{best_epoch}_seed_{seed}.ckpt``.

    Args:
        train_dataloader: iterable of ``(image, qpos, action, is_pad)`` training batches.
        val_dataloader: iterable of validation batches with the same layout.
        config: run configuration built by ``main``. Reads ``num_epochs``, ``ckpt_dir``,
            ``seed``, ``policy_class``, ``policy_config``, ``resumeid``, ``resume_ckpt`` and
            ``task_name``, plus optionally ``val_interval`` / ``save_interval``.

    Returns:
        ``(best_epoch, min_val_loss, best_state_dict)``.

    Raises:
        RuntimeError: if no validation loss ever beat ``inf`` (e.g. ``num_epochs == 0`` or
            every validation loss was NaN), so there is no best checkpoint. The last weights
            are still saved before raising.
    """
    num_epochs = config['num_epochs']
    ckpt_dir = config['ckpt_dir']
    seed = config['seed']
    policy_class = config['policy_class']
    policy_config = config['policy_config']
    val_interval = config.get('val_interval', 500)
    save_interval = config.get('save_interval', 1000)

    set_seed(seed)

    policy = make_policy(policy_class, policy_config)
    policy.cuda()
    optimizer = make_optimizer(policy_class, policy)

    if config['resumeid']:
        # Only the policy weights are restored; the optimizer state starts fresh.
        exp_dir, _ = parse_id((LOG_DIR / config['task_name']).resolve(), config['resumeid'])
        policy, _, _ = load_ckpt(policy, exp_dir, config['resume_ckpt'])

    min_val_loss = np.inf
    best_ckpt_info = None

    train_batches = repeater(train_dataloader)
    for epoch in tqdm(range(num_epochs)):
        print(f'\nEpoch {epoch}')

        if epoch % val_interval == 0:
            validation_summary = _validate(policy, val_dataloader)
            epoch_val_loss = validation_summary['loss']
            # A NaN loss never compares less than min_val_loss, so it cannot become "best".
            if epoch_val_loss < min_val_loss:
                min_val_loss = epoch_val_loss
                best_ckpt_info = (epoch, min_val_loss, deepcopy(policy.state_dict()))
            validation_summary = {f'val/{k}': v for k, v in validation_summary.items()}
            wandb.log(validation_summary, step=epoch)
            print(f'Val loss: {epoch_val_loss:.5f}')
            print(_summary_string(validation_summary))

        # training: one optimizer step on the next batch
        policy.train()
        optimizer.zero_grad()
        forward_dict = forward_pass(next(train_batches), policy)
        loss = forward_dict['loss']
        loss.backward()
        optimizer.step()

        epoch_summary = detach_dict(forward_dict)
        print(f"Train loss: {epoch_summary['loss']:.5f}")
        print(_summary_string(epoch_summary))
        wandb.log(epoch_summary, step=epoch)

        if epoch > 0 and epoch % save_interval == 0:
            ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{epoch}_seed_{seed}.ckpt')
            torch.save(policy.state_dict(), ckpt_path)

    ckpt_path = os.path.join(ckpt_dir, 'policy_last.ckpt')
    torch.save(policy.state_dict(), ckpt_path)

    if best_ckpt_info is None:
        raise RuntimeError(
            'No validation loss improved on inf (num_epochs == 0 or every validation loss was NaN); '
            f'only the last weights were saved to {ckpt_path}')

    best_epoch, min_val_loss, best_state_dict = best_ckpt_info
    # If best_epoch is a multiple of save_interval, this overwrites that epoch's periodic
    # checkpoint (saved after its training step) with the pre-step weights that were validated.
    ckpt_path = os.path.join(ckpt_dir, f'policy_epoch_{best_epoch}_seed_{seed}.ckpt')
    torch.save(best_state_dict, ckpt_path)
    print(f'Training finished:\nSeed {seed}, val loss {min_val_loss:.6f} at epoch {best_epoch}')

    return best_ckpt_info
|
|
|
|
|
|
def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
    """Save one train-vs-validation curve per metric as ``train_val_{key}_seed_{seed}.png``.

    Args:
        train_history: list of per-step summary dicts. Values must support ``.item()``, and
            the keys of the first entry determine which metrics are plotted.
        validation_history: list of summary dicts containing the same keys.
        num_epochs: total number of epochs; both curves are spread evenly over ``[0, num_epochs - 1]``.
        ckpt_dir: directory the PNG files are written to.
        seed: seed embedded in the file names.

    An empty ``train_history`` produces no plots.
    """
    if not train_history:
        return
    for key in train_history[0]:
        plot_path = os.path.join(ckpt_dir, f'train_val_{key}_seed_{seed}.png')
        train_values = [summary[key].item() for summary in train_history]
        val_values = [summary[key].item() for summary in validation_history]
        fig, ax = plt.subplots()
        ax.plot(np.linspace(0, num_epochs - 1, len(train_history)), train_values, label='train')
        ax.plot(np.linspace(0, num_epochs - 1, len(validation_history)), val_values, label='validation')
        ax.legend()
        ax.set_title(key)
        # Lay out after the legend and title exist so they are not clipped.
        fig.tight_layout()
        fig.savefig(plot_path)
        # Close each figure; otherwise pyplot keeps every figure alive for the process lifetime.
        plt.close(fig)
    print(f'Saved plots to {ckpt_dir}')
|
|
|
|
def load_ckpt(policy, exp_dir, ckpt_name):
    """Load checkpoint weights from ``exp_dir`` into ``policy``.

    If ``ckpt_name`` is given, it is an epoch identifier and the file
    ``policy_epoch_{ckpt_name}_seed_0.ckpt`` is loaded. If that file does not exist, the unique
    ``policy_epoch_{ckpt_name}_seed_*.ckpt`` is used instead, so that checkpoints trained with
    a non-zero ``--seed`` are found. If ``ckpt_name`` is empty, the checkpoint name and epoch
    come from ``find_all_ckpt(exp_dir)``.

    Returns:
        ``(policy, ckpt_name, epoch)``: the same policy object with weights loaded, the
        checkpoint file name, and the epoch (the given string, or the value from ``find_all_ckpt``).

    Raises:
        ValueError: if several seeds have a checkpoint for the requested epoch.
        FileNotFoundError: (from ``torch.load``) if no matching checkpoint exists.
    """
    if ckpt_name:
        epoch = ckpt_name
        ckpt_name = f"policy_epoch_{epoch}_seed_0.ckpt"
        if not (Path(exp_dir) / ckpt_name).exists():
            candidates = sorted(Path(exp_dir).glob(f"policy_epoch_{epoch}_seed_*.ckpt"))
            if len(candidates) == 1:
                ckpt_name = candidates[0].name
            elif len(candidates) > 1:
                raise ValueError(
                    f"Ambiguous checkpoint for epoch {epoch} in {exp_dir}: "
                    f"{[c.name for c in candidates]}")
    else:
        ckpt_name, epoch = find_all_ckpt(exp_dir)
    resume_path = (Path(exp_dir) / ckpt_name).resolve()
    print("*" * 20)
    print(f"Resuming from {resume_path}")
    print("*" * 20)
    policy.load_state_dict(torch.load(resume_path))
    return policy, ckpt_name, epoch
|
|
|
|
def save_jit(config):
    """Trace the trained policy of experiment ``config['exptid']`` with TorchScript and save it.

    The experiment directory is resolved under ``LOG_DIR / config['task_name']``. The
    checkpoint selected by ``config['resume_ckpt']`` (see ``load_ckpt``) is loaded, and the
    policy is traced in eval mode on random CUDA inputs. The result is saved as
    ``traced_jit_{epoch}.pt`` in the experiment directory and reloaded once as a sanity check.
    """
    policy_class = config['policy_class']
    policy_config = config['policy_config']

    exp_dir, _ = parse_id((LOG_DIR / config['task_name']).resolve(), config['exptid'])

    policy = make_policy(policy_class, policy_config)
    policy.cuda()
    policy, _, epoch = load_ckpt(policy, exp_dir, config['resume_ckpt'])
    policy.eval()

    # Example inputs for tracing, laid out like get_image's output:
    # (batch, num_cameras, channels, height, width). The 480x640 resolution is a fixed assumption.
    num_cameras = len(config['camera_names'])
    image_data = torch.rand((1, num_cameras, 3, 480, 640), device='cuda')
    qpos_data = torch.rand((1, config['state_dim']), device='cuda')

    traced_policy = torch.jit.trace(policy, (qpos_data, image_data))
    save_path = os.path.join(exp_dir, f"traced_jit_{epoch}.pt")
    traced_policy.save(save_path)
    print("Saved traced actor at ", save_path)

    # Reload once so a broken archive fails here rather than at deployment time.
    torch.jit.load(save_path)
|
|
|
|
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--eval', action='store_true')
    parser.add_argument('--onscreen_render', action='store_true')
    parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True)
    parser.add_argument('--batch_size', action='store', type=int, help='batch_size', required=True)
    parser.add_argument('--seed', action='store', type=int, help='seed', required=True)
    parser.add_argument('--num_epochs', action='store', type=int, help='num_epochs', required=True)
    parser.add_argument('--lr', action='store', type=float, help='lr', required=True)
    parser.add_argument('--qpos_noise_std', action='store', default=0, type=float, help='qpos_noise_std', required=False)

    # for ACT
    parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
    parser.add_argument('--chunk_size', action='store', type=int, help='chunk_size', required=False)
    parser.add_argument('--hidden_dim', action='store', type=int, help='hidden_dim', required=False)
    parser.add_argument('--dim_feedforward', action='store', type=int, help='dim_feedforward', required=False)
    parser.add_argument('--temporal_agg', action='store_true')
    parser.add_argument('--save_jit', action='store_true')
    parser.add_argument('--no_wandb', action='store_true')
    parser.add_argument('--resumeid', action='store', default="", type=str, help='resume id', required=False)
    parser.add_argument('--resume_ckpt', action='store', default="", type=str, help='resume ckpt', required=False)
    parser.add_argument('--taskid', action='store', type=str, help='task id', required=True)
    parser.add_argument('--exptid', action='store', type=str, help='experiment id', required=True)
    parser.add_argument('--source', choices=['self', 'ssd'], default='self')
    args = vars(parser.parse_args())

    # Resolve the data layout. RECORD_DIR and LOG_DIR become module-level globals that
    # main, train_bc and save_jit read, so those functions only work when run as a script.
    if args['source'] == 'self':
        current_dir = Path(__file__).parent.resolve()
    else:
        current_dir = Path("/media/cxx/Extreme Pro/human2robot/data/").resolve()
    DATA_DIR = (current_dir.parent / 'data/').resolve()
    RECORD_DIR = (DATA_DIR / 'recordings/').resolve()
    LOG_DIR = (DATA_DIR / 'logs/').resolve()

    main(args)
|