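# NOTE: the following lines are residue from the repository web viewer, not module content.
# You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
#
# 115 lines
# 4.7 KiB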

import torch.nn as nn
from torch.nn import functional as F
from torchvision.transforms import v2
import torch
from detr.main import build_ACT_model_and_optimizer, build_CNNMLP_model_and_optimizer
import IPython
e = IPython.embed
class ACTPolicy(nn.Module):
    """Policy wrapper around the model built by ``build_ACT_model_and_optimizer``.

    Calling the policy with ``actions`` runs the training path and returns a
    dict of losses; calling it without ``actions`` runs the inference path and
    returns the predicted actions.
    """

    # Images are resized to (_PATCH_H * _PATCH_PX, _PATCH_W * _PATCH_PX).
    # NOTE(review): 14 looks like a ViT patch edge length -- confirm against
    # the image backbone actually configured in detr.
    _PATCH_H = 16
    _PATCH_W = 22
    _PATCH_PX = 14
    # Per-channel normalization statistics (the usual ImageNet values).
    _NORM_MEAN = (0.485, 0.456, 0.406)
    _NORM_STD = (0.229, 0.224, 0.225)

    def __init__(self, args_override):
        """Build the model/optimizer pair and cache the loss hyper-parameters.

        Args:
            args_override: mapping forwarded to ``build_ACT_model_and_optimizer``;
                must also contain the keys ``'kl_weight'`` and
                ``'qpos_noise_std'``.
        """
        super().__init__()
        model, optimizer = build_ACT_model_and_optimizer(args_override)
        self.model = model  # CVAE decoder
        self.optimizer = optimizer
        self.kl_weight = args_override['kl_weight']
        self.qpos_noise_std = args_override['qpos_noise_std']
        print(f'KL Weight {self.kl_weight}')

    def _build_transform(self, augment):
        """Return the image pipeline; ``augment`` prepends the random training transforms."""
        target_size = (
            self._PATCH_H * self._PATCH_PX,
            self._PATCH_W * self._PATCH_PX,
        )
        steps = []
        if augment:
            steps.extend([
                v2.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5, hue=0.5),
                v2.RandomPerspective(distortion_scale=0.5),
                v2.RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.9, 1.1)),
                v2.GaussianBlur(kernel_size=(9, 9), sigma=(0.1, 2.0)),
            ])
        steps.extend([
            v2.Resize(target_size),
            v2.Normalize(mean=self._NORM_MEAN, std=self._NORM_STD),
        ])
        return v2.Compose(steps)

    def __call__(self, qpos, image, actions=None, is_pad=None):
        """Run a training step (losses) or an inference step (actions).

        Args:
            qpos: proprioceptive state tensor passed to the model. It is never
                modified in place.
            image: image tensor; resized and normalized (and randomly
                augmented at training time) before being passed to the model.
            actions: target actions. ``None`` selects the inference path.
            is_pad: boolean padding mask aligned with ``actions``; required
                whenever ``actions`` is given.

        Returns:
            Training: dict with keys ``'l1'``, ``'kl'`` and ``'loss'`` where
            ``loss = l1 + kl * kl_weight``.
            Inference: the actions predicted by the model.
        """
        env_state = None
        training = actions is not None

        if training:
            # Out-of-place add: the original ``qpos += ...`` wrote the noise
            # back into the caller's tensor.
            # NOTE(review): the square root means ``qpos_noise_std`` is being
            # treated as a variance here -- confirm which was intended.
            qpos = qpos + (self.qpos_noise_std ** 0.5) * torch.randn_like(qpos)

        image = self._build_transform(augment=training)(image)

        if not training:
            # no action, sample from prior
            a_hat, _, (_, _) = self.model(qpos, image, env_state)
            return a_hat

        # Only the first ``num_queries`` steps are predicted by the model.
        actions = actions[:, :self.model.num_queries]
        is_pad = is_pad[:, :self.model.num_queries]

        a_hat, _, (mu, logvar) = self.model(qpos, image, env_state, actions, is_pad)
        total_kld, _, _ = kl_divergence(mu, logvar)

        # Padded steps are zeroed out but still counted in the mean's denominator.
        all_l1 = F.l1_loss(actions, a_hat, reduction='none')
        l1 = (all_l1 * ~is_pad.unsqueeze(-1)).mean()

        loss_dict = dict()
        loss_dict['l1'] = l1
        loss_dict['kl'] = total_kld[0]
        loss_dict['loss'] = loss_dict['l1'] + loss_dict['kl'] * self.kl_weight
        return loss_dict

    def configure_optimizers(self):
        """Return the optimizer created alongside the model."""
        return self.optimizer
class CNNMLPPolicy(nn.Module):
    """Policy wrapper around the model built by ``build_CNNMLP_model_and_optimizer``.

    Calling the policy with ``actions`` returns a dict of losses; calling it
    without ``actions`` returns the predicted action.
    """

    def __init__(self, args_override):
        """Build the model/optimizer pair.

        Args:
            args_override: mapping forwarded to
                ``build_CNNMLP_model_and_optimizer``.
        """
        super().__init__()
        model, optimizer = build_CNNMLP_model_and_optimizer(args_override)
        self.model = model  # decoder
        self.optimizer = optimizer

    def __call__(self, qpos, image, actions=None, is_pad=None):
        """Run a training step (losses) or an inference step (action).

        Args:
            qpos: proprioceptive state tensor passed to the model.
            image: image tensor; channel-normalized before being passed on.
            actions: target actions; only the first step (``actions[:, 0]``)
                is used. ``None`` selects the inference path.
            is_pad: accepted for signature parity with ``ACTPolicy``; unused.

        Returns:
            Training: dict with keys ``'mse'`` and ``'loss'`` (both the MSE
            between the target and predicted action).
            Inference: the action predicted by the model.
        """
        env_state = None  # TODO
        # Only ``v2`` is imported from torchvision in this module; the former
        # ``transforms.Normalize`` reference raised NameError on every call.
        normalize = v2.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        image = normalize(image)

        if actions is None:
            # inference time
            a_hat = self.model(qpos, image, env_state)  # no action, sample from prior
            return a_hat

        # training time
        actions = actions[:, 0]
        a_hat = self.model(qpos, image, env_state, actions)
        mse = F.mse_loss(actions, a_hat)
        loss_dict = dict()
        loss_dict['mse'] = mse
        loss_dict['loss'] = loss_dict['mse']
        return loss_dict

    def configure_optimizers(self):
        """Return the optimizer created alongside the model."""
        return self.optimizer
def kl_divergence(mu, logvar):
    """Compute the KL divergence of N(mu, exp(logvar)) from a standard normal.

    Args:
        mu: mean tensor; the first dimension is the batch and must be
            non-empty. A 4-D input is viewed as ``(size(0), size(1))``.
        logvar: log-variance tensor, handled the same way as ``mu``.

    Returns:
        A tuple ``(total_kld, dimension_wise_kld, mean_kld)``:
        ``total_kld`` sums over dim 1 and averages over the batch (dim 0 kept),
        ``dimension_wise_kld`` averages each latent dimension over the batch,
        ``mean_kld`` averages over dim 1 and then over the batch (dim 0 kept).
    """
    n_samples = mu.size(0)
    assert n_samples != 0

    # Collapse 4-D inputs down to (batch, latent).
    if mu.dim() == 4:
        mu = mu.view(*mu.shape[:2])
    if logvar.dim() == 4:
        logvar = logvar.view(*logvar.shape[:2])

    # Element-wise closed-form KL term; operand order kept as-is for
    # bit-identical floating point results.
    per_element = -0.5 * (1 + logvar - mu ** 2 - torch.exp(logvar))

    summed_over_latent = per_element.sum(dim=1)
    averaged_over_latent = per_element.mean(dim=1)

    total_kld = summed_over_latent.mean(dim=0, keepdim=True)
    dimension_wise_kld = per_element.mean(dim=0)
    mean_kld = averaged_over_latent.mean(dim=0, keepdim=True)
    return total_kld, dimension_wise_kld, mean_kld