You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
126 lines
4.6 KiB
126 lines
4.6 KiB
"""Minimal Linear Blend Skinning (LBS) for sparse sensor-point vertices on SMPLH."""
|
|
|
|
import numpy as np
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from smplx.utils import Struct, to_np, to_tensor
|
|
|
|
from hmr4d.utils.smplx_utils import forward_kinematics_motion
|
|
from motiondiff.models.mdm.rotation_conversions import axis_angle_to_matrix
|
|
|
|
|
|
class MinimalLBS(nn.Module):
|
|
def __init__(self, sp_ids, bm_dir="models/smplh", num_betas=16, model_type="smplh", **kwargs):
|
|
super().__init__()
|
|
self.num_betas = num_betas
|
|
self.sensor_point_vid = torch.tensor(sp_ids)
|
|
|
|
# load struct data on predefined sensor-point
|
|
self.load_struct_on_sp(f"{bm_dir}/male/model.npz", prefix="male")
|
|
self.load_struct_on_sp(f"{bm_dir}/female/model.npz", prefix="female")
|
|
|
|
def load_struct_on_sp(self, bm_path, prefix="m"):
|
|
"""
|
|
Load 4 weights from body-model-struct.
|
|
Keep the sensor points only. Use prefix to label different bm.
|
|
"""
|
|
num_betas = self.num_betas
|
|
sp_vid = self.sensor_point_vid
|
|
# load data
|
|
data_struct = Struct(**np.load(bm_path, encoding="latin1"))
|
|
|
|
# v-template
|
|
v_template = to_tensor(to_np(data_struct.v_template)) # (V, 3)
|
|
v_template_sp = v_template[sp_vid] # (N, 3)
|
|
self.register_buffer(f"{prefix}_v_template_sp", v_template_sp, False)
|
|
|
|
# shapedirs
|
|
shapedirs = to_tensor(to_np(data_struct.shapedirs[:, :, :num_betas])) # (V, 3, NB)
|
|
shapedirs_sp = shapedirs[sp_vid]
|
|
self.register_buffer(f"{prefix}_shapedirs_sp", shapedirs_sp, False)
|
|
|
|
# posedirs
|
|
posedirs = to_tensor(to_np(data_struct.posedirs)) # (V, 3, 51*9)
|
|
posedirs_sp = posedirs[sp_vid]
|
|
posedirs_sp = posedirs_sp.reshape(len(sp_vid) * 3, -1).T # (51*9, N*3)
|
|
self.register_buffer(f"{prefix}_posedirs_sp", posedirs_sp, False)
|
|
|
|
# lbs_weights
|
|
lbs_weights = to_tensor(to_np(data_struct.weights)) # (V, J+1)
|
|
lbs_weights_sp = lbs_weights[sp_vid]
|
|
self.register_buffer(f"{prefix}_lbs_weights_sp", lbs_weights_sp, False)
|
|
|
|
def forward(
|
|
self,
|
|
root_orient=None,
|
|
pose_body=None,
|
|
trans=None,
|
|
betas=None,
|
|
A=None,
|
|
recompute_A=False,
|
|
genders=None,
|
|
joints_zero=None,
|
|
):
|
|
"""
|
|
Args:
|
|
root_orient, Optional: (B, T, 3)
|
|
pose_body: (B, T, J*3)
|
|
trans: (B, T, 3)
|
|
betas: (B, T, 16)
|
|
A, Optional: (B, T, J+1, 4, 4)
|
|
recompute_A: if True, root_orient should be given, otherwise use A
|
|
genders, List: ['male', 'female', ...]
|
|
joints_zero: (B, J+1, 3), required when recompute_A is True
|
|
Returns:
|
|
sensor_verts: (B, T, N, 3)
|
|
"""
|
|
B, T = pose_body.shape[:2]
|
|
|
|
v_template = torch.stack(
|
|
[getattr(self, f"{g}_v_template_sp") for g in genders]
|
|
) # (B, N, 3)
|
|
shapedirs = torch.stack(
|
|
[getattr(self, f"{g}_shapedirs_sp") for g in genders]
|
|
) # (B, N, 3, NB)
|
|
posedirs = torch.stack(
|
|
[getattr(self, f"{g}_posedirs_sp") for g in genders]
|
|
) # (B, 51*9, N*3)
|
|
lbs_weights = torch.stack(
|
|
[getattr(self, f"{g}_lbs_weights_sp") for g in genders]
|
|
) # (B, N, J+1)
|
|
|
|
# ===== LBS, handle T ===== #
|
|
# 2. Add shape contribution
|
|
if betas.shape[1] == 1:
|
|
betas = betas.expand(-1, T, -1)
|
|
blend_shape = torch.einsum("btl,bmkl->btmk", [betas, shapedirs])
|
|
v_shaped = v_template[:, None] + blend_shape
|
|
|
|
# 3. Add pose blend shapes
|
|
ident = torch.eye(3).to(pose_body)
|
|
aa = pose_body.reshape(B, T, -1, 3)
|
|
R = axis_angle_to_matrix(aa)
|
|
pose_feature = (R - ident).view(B, T, -1)
|
|
dim_pf = pose_feature.shape[-1]
|
|
# (B, T, P) @ (B, P, N*3) -> (B, T, N, 3)
|
|
pose_offsets = torch.matmul(pose_feature, posedirs[:, :dim_pf]).view(B, T, -1, 3)
|
|
v_posed = pose_offsets + v_shaped
|
|
|
|
# 4. Compute A
|
|
if recompute_A:
|
|
_, _, A = forward_kinematics_motion(root_orient, pose_body, trans, joints_zero)
|
|
|
|
# 5. Skinning
|
|
W = lbs_weights
|
|
# (B, 1, N, J+1)) @ (B, T, J+1, 16)
|
|
num_joints = A.shape[-3] # 22
|
|
Ts = torch.matmul(W[:, None, :, :num_joints], A.view(B, T, num_joints, 16))
|
|
Ts = Ts.view(B, T, -1, 4, 4) # (B, T, N, 4, 4)
|
|
v_posed_homo = F.pad(v_posed, (0, 1), value=1) # (B, T, N, 4)
|
|
v_homo = torch.matmul(Ts, torch.unsqueeze(v_posed_homo, dim=-1))
|
|
|
|
# 6. translate
|
|
sensor_verts = v_homo[:, :, :, :3, 0] + trans[:, :, None]
|
|
|
|
return sensor_verts
|