You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
969 lines
33 KiB
969 lines
33 KiB
"""
|
|
Visualization utilities for VPlanner.
|
|
|
|
Creates prediction plots for training visualization and WandB logging.
|
|
"""
|
|
|
|
import io
|
|
import numpy as np
|
|
import torch
|
|
from typing import Dict, Any, Optional
|
|
from pathlib import Path
|
|
from loguru import logger
|
|
|
|
import matplotlib
|
|
|
|
matplotlib.use("Agg")
|
|
import matplotlib.pyplot as plt
|
|
from mpl_toolkits.mplot3d import Axes3D
|
|
from PIL import Image
|
|
|
|
from groot.rl.isaac_utils.rotations import quat_rotate
|
|
from groot.rl.trl.utils.fk_utils import FKHelper
|
|
|
|
# OpenCV is imported lazily because only SkeletonVisualizer (real-time
# display) needs it; the module-level name stays None until first use.
cv2 = None


def _get_cv2():
    """Return the OpenCV module, importing and caching it on first call."""
    global cv2
    if cv2 is not None:
        return cv2

    import cv2 as _opencv_module

    cv2 = _opencv_module
    return cv2
|
|
|
|
|
|
class SkeletonVisualizer:
    """
    Real-time skeleton visualizer for VPlanner evaluation.

    Renders a root-centric 3D skeleton from DOF predictions using matplotlib,
    then converts to OpenCV image for display.
    """

    # Keypoint indices used to estimate which axis of the FK output is "up"
    # (the pelvis -> torso direction).
    PELVIS_INDEX = 0
    TORSO_INDEX = 7

    def __init__(self, motion_lib, img_size: int = 400):
        """
        Initialize skeleton visualizer.

        Args:
            motion_lib: MotionLibRobot instance for FK
            img_size: Output image size in pixels
        """
        self.fk_helper = FKHelper(motion_lib)
        self.img_size = img_size
        self.device = motion_lib.mesh_parsers.dof_axis.device

        # Quaternion (wxyz) for a -90 degree rotation around the X axis.
        # NOTE: render() does not apply this quaternion; it re-orients the
        # skeleton by swapping axes instead. The attribute is kept because it
        # is part of the object's public state.
        import math

        angle = -math.pi / 2  # -90 degrees
        self.upright_quat = torch.tensor(
            [math.cos(angle / 2), math.sin(angle / 2), 0.0, 0.0],  # w, x, y, z
            device=self.device,
            dtype=torch.float32,
        )

        # Create persistent figure for faster rendering (cleared on every render)
        self.fig = plt.figure(figsize=(4, 4), dpi=100)
        self.ax = self.fig.add_subplot(111, projection="3d")

        logger.info(f"SkeletonVisualizer initialized: {img_size}x{img_size}")

    @staticmethod
    def _orient_z_up(body_pos: np.ndarray, pelvis_idx: int = 0, torso_idx: int = 7) -> np.ndarray:
        """
        Return a copy of ``body_pos`` re-oriented so pelvis->torso points along +Z.

        The axis with the largest absolute pelvis->torso component is treated
        as the current "up" axis and swapped with Z; if that component is
        negative the Z column is negated. The input array is not modified.

        Args:
            body_pos: [num_keypoints, 3] keypoint positions
            pelvis_idx: Row index of the pelvis keypoint
            torso_idx: Row index of the torso keypoint

        Returns:
            [num_keypoints, 3] re-oriented keypoint positions
        """
        up_vec = body_pos[torso_idx] - body_pos[pelvis_idx]
        up_axis = int(np.argmax(np.abs(up_vec)))

        if up_axis == 0:  # X is up -> swap X and Z
            oriented = body_pos[:, [2, 1, 0]]
        elif up_axis == 1:  # Y is up -> swap Y and Z
            oriented = body_pos[:, [0, 2, 1]]
        else:  # Z is already up
            oriented = body_pos.copy()

        if up_vec[up_axis] < 0:  # "up" pointed in the negative direction
            oriented[:, 2] = -oriented[:, 2]
        return oriented

    def _failure_image(self) -> np.ndarray:
        """Return a black BGR image with an "FK Failed" label."""
        cv = _get_cv2()
        blank = np.zeros((self.img_size, self.img_size, 3), dtype=np.uint8)
        cv.putText(
            blank,
            "FK Failed",
            (10, self.img_size // 2),
            cv.FONT_HERSHEY_SIMPLEX,
            0.7,
            (255, 255, 255),
            2,
        )
        return blank

    def render(
        self,
        dof_pos: torch.Tensor,
        title: str = "Predicted Pose",
    ) -> np.ndarray:
        """
        Render a single DOF pose as a root-centric skeleton.

        Args:
            dof_pos: [29] DOF positions (single frame)
            title: Title to display on the image

        Returns:
            OpenCV BGR image [img_size, img_size, 3]. If forward kinematics
            (or the re-orientation step) raises, a blank image labelled
            "FK Failed" is returned instead and a warning is logged.
        """
        # Ensure batch dimension
        if dof_pos.dim() == 1:
            dof_pos = dof_pos.unsqueeze(0)  # [1, 29]
        dof_pos = dof_pos.to(self.device)

        # Root at origin, identity rotation for FK
        root_pos = torch.zeros(1, 3, device=self.device)
        root_rot6d = torch.tensor([[1, 0, 0, 0, 1, 0]], device=self.device, dtype=torch.float)
        identity_quat = torch.tensor([1.0, 0.0, 0.0, 0.0], device=self.device)

        # Compute body positions via FK; best-effort, so any failure degrades
        # to a placeholder image rather than interrupting the caller.
        try:
            body_pos = self.fk_helper.dof_to_body_pos(
                dof_pos, root_pos, root_rot6d, identity_quat
            )  # [1, num_keypoints, 3]
            # detach() so inputs that still track gradients can be converted
            body_pos = body_pos[0].detach().cpu().numpy()  # [num_keypoints, 3]
            body_pos = self._orient_z_up(body_pos, self.PELVIS_INDEX, self.TORSO_INDEX)
        except Exception as e:
            logger.warning(f"FK failed: {e}")
            return self._failure_image()

        # Clear and redraw on the persistent axes
        ax = self.ax
        ax.clear()

        plot_x = body_pos[:, 0]
        plot_y = body_pos[:, 1]
        plot_z = body_pos[:, 2]

        # Draw skeleton bones
        for start, end in FKHelper.SKELETON_BONES:
            ax.plot(
                [plot_x[start], plot_x[end]],
                [plot_y[start], plot_y[end]],
                [plot_z[start], plot_z[end]],
                c="cyan",
                linewidth=2,
            )

        # Regular joints = everything that is neither a foot nor a hand
        regular_mask = np.ones(len(body_pos), dtype=bool)
        regular_mask[FKHelper.FOOT_INDICES + FKHelper.HAND_INDICES] = False
        ax.scatter(
            plot_x[regular_mask], plot_y[regular_mask], plot_z[regular_mask], c="white", s=30
        )

        # Feet (orange triangles)
        ax.scatter(
            plot_x[FKHelper.FOOT_INDICES],
            plot_y[FKHelper.FOOT_INDICES],
            plot_z[FKHelper.FOOT_INDICES],
            c="orange",
            s=50,
            marker="^",
        )

        # Hands (magenta circles)
        ax.scatter(
            plot_x[FKHelper.HAND_INDICES],
            plot_y[FKHelper.HAND_INDICES],
            plot_z[FKHelper.HAND_INDICES],
            c="magenta",
            s=50,
            marker="o",
        )

        # Auto-scale: cube centred on the mean keypoint, half-width >= 0.5
        origin = body_pos.mean(axis=0)
        radius = max(0.5 * (body_pos.max(axis=0) - body_pos.min(axis=0)).max(), 0.5)
        ax.set_xlim([origin[0] - radius, origin[0] + radius])
        ax.set_ylim([origin[1] - radius, origin[1] + radius])
        ax.set_zlim([origin[2] - radius, origin[2] + radius])

        # Fixed camera angle
        ax.view_init(elev=20, azim=-135)

        # Style
        ax.set_facecolor((0.1, 0.1, 0.1))
        ax.set_xlabel("X", color="gray", fontsize=8)
        ax.set_ylabel("Y", color="gray", fontsize=8)
        ax.set_zlabel("Z (up)", color="gray", fontsize=8)
        ax.set_title(title, color="white", fontsize=10)
        try:
            ax.set_box_aspect([1, 1, 1])
        except AttributeError:
            pass  # Older matplotlib
        ax.tick_params(colors="gray")

        # Rasterize the figure and grab its RGBA buffer
        self.fig.tight_layout()
        self.fig.canvas.draw()
        rgba = np.asarray(self.fig.canvas.buffer_rgba())
        # Dropping alpha yields a strided view; make it contiguous for OpenCV
        rgb = np.ascontiguousarray(rgba[:, :, :3])

        cv = _get_cv2()
        # Resize to target size, then convert RGB to BGR for OpenCV
        resized = cv.resize(rgb, (self.img_size, self.img_size))
        return cv.cvtColor(resized, cv.COLOR_RGB2BGR)

    def close(self):
        """Close the matplotlib figure."""
        plt.close(self.fig)
|
|
|
|
|
|
class VPlannerVisualizer:
    """
    Visualization utilities for VPlanner predictions.

    Creates multi-panel figures showing:
    - Input images
    - BEV trajectory with heading arrows
    - 3D skeleton trajectories (GT and Pred)
    - Comparison plots
    """

    # Stride between the future frames drawn in skeleton trajectory plots
    # (the last frame is always drawn as well).
    FRAME_STRIDE = 5

    def __init__(self, fk_helper: "FKHelper"):
        """
        Initialize visualizer.

        Args:
            fk_helper: FKHelper instance for forward kinematics
        """
        self.fk_helper = fk_helper

    # ------------------------------------------------------------------
    # Small shared helpers
    # ------------------------------------------------------------------

    @staticmethod
    def _clamp_num_samples(batch: Dict[str, Any], num_samples: int) -> int:
        """Limit ``num_samples`` to the number of items in ``batch["image"]``."""
        return max(0, min(num_samples, len(batch["image"])))

    @staticmethod
    def _to_numpy(tensor: torch.Tensor) -> np.ndarray:
        """Convert a tensor to numpy; detach() first so grad-tracking tensors work."""
        return tensor.detach().cpu().numpy()

    @staticmethod
    def _frames_to_show(num_frames: int, stride: int = 5) -> list:
        """Return every ``stride``-th frame index plus the final frame (empty if no frames)."""
        frames = list(range(0, num_frames, stride))
        if frames and frames[-1] != num_frames - 1:
            frames.append(num_frames - 1)
        return frames

    @staticmethod
    def _bev_view(bev_bounds: Dict[str, float]):
        """
        Compute the square BEV viewport from the motion bounds.

        Returns:
            (x_center, y_center, extent) where extent is the larger of the
            X/Y spans (at least 0.5) enlarged by 10%.
        """
        x_center = (bev_bounds["x_min"] + bev_bounds["x_max"]) / 2
        y_center = (bev_bounds["y_min"] + bev_bounds["y_max"]) / 2
        extent = (
            max(
                bev_bounds["x_max"] - bev_bounds["x_min"],
                bev_bounds["y_max"] - bev_bounds["y_min"],
                0.5,
            )
            * 1.1
        )
        return x_center, y_center, extent

    @staticmethod
    def _heading_world(current_root_quat: torch.Tensor, rot6d: np.ndarray) -> np.ndarray:
        """
        Rotate the first three rot6d components by the current root quaternion.

        Args:
            current_root_quat: Quaternion of the current frame (passed to
                ``quat_rotate`` with ``w_last=False``)
            rot6d: [6] or [N, 6] array; only ``[..., :3]`` is used

        Returns:
            [N, 3] rotated vectors as a numpy array (N == 1 for a single input)
        """
        fwd_local = torch.tensor(
            np.atleast_2d(rot6d)[:, :3], device=current_root_quat.device, dtype=torch.float
        )
        quat = current_root_quat.unsqueeze(0).expand(fwd_local.shape[0], -1)
        return quat_rotate(quat, fwd_local, w_last=False).detach().cpu().numpy()

    @staticmethod
    def _draw_heading_arrow(ax, pos, fwd_xy, arrow_len, color: str, **arrow_kwargs):
        """Draw a fixed-length arrow at ``pos`` along ``fwd_xy``; skipped if its norm is <= 0.1."""
        norm = np.linalg.norm(fwd_xy)
        if not norm > 0.1:  # also skips NaN headings
            return
        unit = fwd_xy / norm
        ax.arrow(
            pos[0],
            pos[1],
            unit[0] * arrow_len,
            unit[1] * arrow_len,
            head_width=arrow_len * 0.4,
            head_length=arrow_len * 0.3,
            fc=color,
            ec=color,
            **arrow_kwargs,
        )

    @staticmethod
    def _finish_bev_axes(ax, x_center, y_center, extent, title: str):
        """Apply limits, labels, legend and grid shared by both BEV plots."""
        ax.set_xlim(x_center - extent / 2, x_center + extent / 2)
        ax.set_ylim(y_center + extent / 2, y_center - extent / 2)  # Flipped: larger Y at bottom
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_title(title, fontsize=8)
        ax.legend(fontsize=6, loc="upper left")
        ax.set_aspect("equal")
        ax.grid(True, alpha=0.3)

    @staticmethod
    def _regular_joint_mask(num_joints: int) -> np.ndarray:
        """Boolean mask selecting joints that are neither feet nor hands."""
        mask = np.ones(num_joints, dtype=bool)
        mask[FKHelper.FOOT_INDICES + FKHelper.HAND_INDICES] = False
        return mask

    @staticmethod
    def _draw_bones(ax, pos: np.ndarray, **line_kwargs):
        """Draw one line per entry of FKHelper.SKELETON_BONES for a single pose."""
        for start, end in FKHelper.SKELETON_BONES:
            ax.plot(
                [pos[start, 0], pos[end, 0]],
                [pos[start, 1], pos[end, 1]],
                [pos[start, 2], pos[end, 2]],
                **line_kwargs,
            )

    @staticmethod
    def _scatter_joints(ax, pos: np.ndarray, selector, **scatter_kwargs):
        """Scatter the rows of ``pos`` picked by ``selector`` (index list or boolean mask)."""
        ax.scatter(pos[selector, 0], pos[selector, 1], pos[selector, 2], **scatter_kwargs)

    @staticmethod
    def _label_3d_axes(ax, title: str):
        """Set the X/Y/Z labels and title used by all 3D skeleton panels."""
        ax.set_xlabel("X")
        ax.set_ylabel("Y")
        ax.set_zlabel("Z")
        ax.set_title(title, fontsize=8)

    # ------------------------------------------------------------------
    # Trajectory prediction figure
    # ------------------------------------------------------------------

    def create_prediction_plots(
        self,
        batch: Dict[str, Any],
        predictions: Dict[str, torch.Tensor],
        labels: Dict[str, torch.Tensor],
        num_samples: int = 4,
    ) -> "plt.Figure":
        """
        Create a figure with prediction visualizations.

        Shows:
        - Input image
        - Top-down trajectory with heading arrows
        - 3D skeleton trajectory (GT)
        - 3D skeleton trajectory (Pred)
        - 3D comparison (GT + Pred together)

        Args:
            batch: Input batch with images and metadata
            predictions: Model predictions dict
            labels: Ground truth labels dict
            num_samples: Number of samples to visualize (clamped to the
                number of items in ``batch["image"]``)

        Returns:
            matplotlib Figure
        """
        num_cols = 5
        num_samples = self._clamp_num_samples(batch, num_samples)
        # max(..., 1) keeps the figure height positive for an empty batch
        fig = plt.figure(figsize=(5 * num_cols, 5 * max(num_samples, 1)))

        for i in range(num_samples):
            self._plot_sample(fig, i, num_samples, num_cols, batch, predictions, labels)

        # Lay out this figure explicitly rather than pyplot's "current" one
        fig.tight_layout()
        return fig

    def _plot_sample(
        self,
        fig: "plt.Figure",
        i: int,
        num_samples: int,
        num_cols: int,
        batch: Dict[str, Any],
        predictions: Dict[str, torch.Tensor],
        labels: Dict[str, torch.Tensor],
    ):
        """Plot visualizations for a single sample (row ``i`` of the figure)."""
        # Get data for this sample
        image = batch["image"][i]
        seq_name = batch["seq_name"][i]
        frame_idx = batch["frame_idx"][i].item()
        bev_bounds = batch["bev_bounds"][i]

        # Get predictions and labels
        pred_pos = predictions["future_root_pos"][i]
        gt_pos = labels["future_root_pos"][i]
        pred_rot6d = predictions["future_root_rot6d"][i]
        gt_rot6d = labels["future_root_rot6d"][i]
        pred_dof = predictions["future_dof_pos"][i]
        gt_dof = labels["future_dof_pos"][i]

        # Current frame reference
        current_root_pos = labels["current_root_pos"][i]
        current_root_quat = labels["current_root_quat"][i]

        # Transform root positions to world frame
        gt_pos_world = (
            self.fk_helper.transform_to_world(gt_pos, current_root_quat) + current_root_pos
        )
        pred_pos_world = (
            self.fk_helper.transform_to_world(pred_pos, current_root_quat) + current_root_pos
        )

        # Compute body positions via FK (same path for GT and Pred).
        # Best-effort: on failure the 3D panels are titled "FK failed".
        gt_body_pos = pred_body_pos = None
        try:
            gt_body_world = self.fk_helper.dof_to_body_pos(
                gt_dof, gt_pos, gt_rot6d, current_root_quat
            )
            pred_body_world = self.fk_helper.dof_to_body_pos(
                pred_dof, pred_pos, pred_rot6d, current_root_quat
            )
            gt_body_pos = self._to_numpy(gt_body_world + current_root_pos)
            pred_body_pos = self._to_numpy(pred_body_world + current_root_pos)
            fk_success = True
        except Exception as e:
            logger.warning(f"FK failed for sample {i}: {e}")
            fk_success = False

        # Convert to numpy
        gt_pos_np = self._to_numpy(gt_pos_world)
        pred_pos_np = self._to_numpy(pred_pos_world)
        gt_rot_np = self._to_numpy(gt_rot6d)
        pred_rot_np = self._to_numpy(pred_rot6d)
        current_pos_np = self._to_numpy(current_root_pos)
        num_future = len(gt_pos_np)

        first_cell = i * num_cols  # subplot indices are 1-based within the grid

        # --- Plot 1: Input image ---
        ax = fig.add_subplot(num_samples, num_cols, first_cell + 1)
        self._plot_image(ax, image, seq_name, frame_idx)

        # --- Plot 2: Top-down trajectory ---
        ax = fig.add_subplot(num_samples, num_cols, first_cell + 2)
        self._plot_bev_trajectory(
            ax,
            gt_pos_np,
            pred_pos_np,
            gt_rot_np,
            pred_rot_np,
            current_root_quat,
            current_pos_np,
            num_future,
            bev_bounds,
        )

        # --- Plot 3: 3D skeleton GT ---
        ax = fig.add_subplot(num_samples, num_cols, first_cell + 3, projection="3d")
        if fk_success:
            self._plot_skeleton_3d(ax, gt_body_pos, color="green", title="GT")
        else:
            ax.set_title("FK failed", fontsize=8)

        # --- Plot 4: 3D skeleton Pred ---
        ax = fig.add_subplot(num_samples, num_cols, first_cell + 4, projection="3d")
        if fk_success:
            self._plot_skeleton_3d(ax, pred_body_pos, color="red", title="Pred")
        else:
            ax.set_title("FK failed", fontsize=8)

        # --- Plot 5: 3D comparison ---
        ax = fig.add_subplot(num_samples, num_cols, first_cell + 5, projection="3d")
        if fk_success:
            self._plot_skeleton_comparison_3d(ax, gt_body_pos, pred_body_pos)
        else:
            ax.set_title("FK failed", fontsize=8)

    def _plot_image(self, ax, image: torch.Tensor, seq_name: str, frame_idx: int):
        """Plot input image (oldest history on top, current on bottom)."""
        # Denormalize from ImageNet normalization
        mean = torch.tensor([0.485, 0.456, 0.406]).view(3, 1, 1)
        std = torch.tensor([0.229, 0.224, 0.225]).view(3, 1, 1)

        def denormalize(frame: torch.Tensor) -> np.ndarray:
            # [C, H, W] normalized tensor -> [H, W, C] numpy image
            return (frame.cpu() * std + mean).permute(1, 2, 0).numpy()

        if image.dim() == 4:  # [T, C, H, W] - multiple frames
            if image.shape[0] == 0:
                # No images - show placeholder
                ax.text(
                    0.5,
                    0.5,
                    "No images\n(num_history_frames_img=0)",
                    ha="center",
                    va="center",
                    transform=ax.transAxes,
                    fontsize=10,
                )
                ax.set_xlim(0, 1)
                ax.set_ylim(0, 1)
            else:
                # Concatenate vertically: oldest on top, current on bottom
                img = np.concatenate([denormalize(image[0]), denormalize(image[-1])], axis=0)
                ax.imshow(np.clip(img, 0, 1))
        else:  # [C, H, W] - single frame
            ax.imshow(np.clip(denormalize(image), 0, 1))

        ax.set_title(f"{seq_name}\nframe {frame_idx}", fontsize=8)
        ax.axis("off")

    def _plot_bev_trajectory(
        self,
        ax,
        gt_pos: np.ndarray,
        pred_pos: np.ndarray,
        gt_rot: np.ndarray,
        pred_rot: np.ndarray,
        current_root_quat: torch.Tensor,
        current_pos: np.ndarray,
        num_future: int,
        bev_bounds: Dict[str, float],
    ):
        """Plot top-down trajectory with heading arrows."""
        # Plot trajectories
        ax.plot(gt_pos[:, 0], gt_pos[:, 1], "g-", linewidth=2, label="GT", alpha=0.7)
        ax.plot(pred_pos[:, 0], pred_pos[:, 1], "r--", linewidth=2, label="Pred", alpha=0.7)
        ax.scatter(
            [current_pos[0]],
            [current_pos[1]],
            c="blue",
            s=100,
            marker="s",
            zorder=5,
            label="Current",
        )

        # Heading vectors rotated by the current root quaternion
        gt_fwd_world = self._heading_world(current_root_quat, gt_rot)
        pred_fwd_world = self._heading_world(current_root_quat, pred_rot)

        # Axis limits from full motion bounds
        x_center, y_center, extent = self._bev_view(bev_bounds)
        arrow_len = extent * 0.02

        for t in range(num_future):
            # Fade from opaque (first frame) to 0.3 (last frame)
            alpha = 1.0 - 0.7 * (t / max(num_future - 1, 1))
            self._draw_heading_arrow(
                ax, gt_pos[t], gt_fwd_world[t, :2], arrow_len, "green", alpha=alpha, zorder=4
            )
            self._draw_heading_arrow(
                ax, pred_pos[t], pred_fwd_world[t, :2], arrow_len, "red", alpha=alpha, zorder=4
            )

        self._finish_bev_axes(ax, x_center, y_center, extent, "BEV")

    def _plot_skeleton_3d(self, ax, body_pos_seq: np.ndarray, color: str, title: str):
        """Plot 3D skeleton trajectory (every FRAME_STRIDE-th frame plus the last)."""
        num_frames = len(body_pos_seq)

        for t in self._frames_to_show(num_frames, self.FRAME_STRIDE):
            body_pos = body_pos_seq[t]
            # Later frames are drawn more transparent
            alpha = 1.0 - 0.7 * (t / max(num_frames - 1, 1))

            # Joints
            regular_mask = self._regular_joint_mask(len(body_pos))
            self._scatter_joints(ax, body_pos, regular_mask, c=color, s=15, alpha=alpha)
            self._scatter_joints(
                ax, body_pos, FKHelper.FOOT_INDICES, c="orange", s=25, alpha=alpha, marker="^"
            )
            self._scatter_joints(
                ax, body_pos, FKHelper.HAND_INDICES, c="purple", s=25, alpha=alpha, marker="o"
            )

            # Bones
            self._draw_bones(ax, body_pos, c=color, linewidth=1.5, alpha=alpha)

        self._label_3d_axes(ax, title)
        self._set_3d_axes_equal(ax, body_pos_seq.reshape(-1, 3))

    def _plot_skeleton_comparison_3d(self, ax, gt_body_pos: np.ndarray, pred_body_pos: np.ndarray):
        """Plot GT and Pred skeletons together."""
        num_frames = len(gt_body_pos)

        for t in self._frames_to_show(num_frames, self.FRAME_STRIDE):
            alpha = 1.0 - 0.7 * (t / max(num_frames - 1, 1))

            # GT solid green, Pred dashed red
            self._draw_bones(
                ax, gt_body_pos[t], c="green", linewidth=1.5, alpha=alpha, linestyle="-"
            )
            self._draw_bones(
                ax, pred_body_pos[t], c="red", linewidth=1.5, alpha=alpha, linestyle="--"
            )

            # Markers on GT only
            self._scatter_joints(
                ax, gt_body_pos[t], FKHelper.FOOT_INDICES, c="orange", s=20, alpha=alpha, marker="^"
            )
            self._scatter_joints(
                ax, gt_body_pos[t], FKHelper.HAND_INDICES, c="purple", s=20, alpha=alpha, marker="o"
            )

        self._label_3d_axes(ax, "GT (green) vs Pred (red)")
        all_pos = np.concatenate([gt_body_pos.reshape(-1, 3), pred_body_pos.reshape(-1, 3)])
        self._set_3d_axes_equal(ax, all_pos)

    def _set_3d_axes_equal(self, ax, positions: np.ndarray):
        """
        Set 3D axes to equal aspect ratio with Y-axis flipped (larger Y at bottom).

        The view is a cube centred on the mean of the finite points with a
        half-width of at least 0.2. Rows containing NaN/Inf are ignored; if no
        finite point remains the cube is centred on the origin.
        """
        positions = np.asarray(positions).reshape(-1, 3)
        finite = positions[np.isfinite(positions).all(axis=1)]

        if len(finite) == 0:
            origin = np.zeros(3)
            radius = 0.2
        else:
            origin = finite.mean(axis=0)
            radius = max(0.5 * (finite.max(axis=0) - finite.min(axis=0)).max(), 0.2)

        ax.set_xlim3d([origin[0] - radius, origin[0] + radius])
        ax.set_ylim3d([origin[1] + radius, origin[1] - radius])  # Flipped: larger Y at bottom
        ax.set_zlim3d([origin[2] - radius, origin[2] + radius])
        try:
            ax.set_box_aspect([1, 1, 1])
        except AttributeError:
            pass  # Older matplotlib (same guard as SkeletonVisualizer.render)

    # ------------------------------------------------------------------
    # Terminal pose prediction figure
    # ------------------------------------------------------------------

    def create_terminal_prediction_plots(
        self,
        batch: Dict[str, Any],
        predictions: Dict[str, torch.Tensor],
        labels: Dict[str, torch.Tensor],
        num_samples: int = 4,
    ) -> "plt.Figure":
        """
        Create a figure with terminal pose prediction visualizations.

        Shows:
        - Input image
        - BEV with current position and terminal position (GT + Pred) with arrows
        - 3D skeleton comparison (GT + Pred terminal pose)

        Args:
            batch: Input batch with images and metadata
            predictions: Model predictions dict (terminal_*)
            labels: Ground truth labels dict
            num_samples: Number of samples to visualize (clamped to the
                number of items in ``batch["image"]``)

        Returns:
            matplotlib Figure
        """
        num_cols = 3  # Image, BEV, 3D skeleton
        num_samples = self._clamp_num_samples(batch, num_samples)
        # max(..., 1) keeps the figure height positive for an empty batch
        fig = plt.figure(figsize=(5 * num_cols, 5 * max(num_samples, 1)))

        for i in range(num_samples):
            self._plot_terminal_sample(fig, i, num_samples, num_cols, batch, predictions, labels)

        # Lay out this figure explicitly rather than pyplot's "current" one
        fig.tight_layout()
        return fig

    def _plot_terminal_sample(
        self,
        fig: "plt.Figure",
        i: int,
        num_samples: int,
        num_cols: int,
        batch: Dict[str, Any],
        predictions: Dict[str, torch.Tensor],
        labels: Dict[str, torch.Tensor],
    ):
        """Plot visualizations for a single terminal prediction sample."""
        # Get data for this sample
        image = batch["image"][i]
        seq_name = batch["seq_name"][i]
        frame_idx = batch["frame_idx"][i].item()
        bev_bounds = batch["bev_bounds"][i]

        # Get predictions and labels (terminal = single frame, not trajectory)
        pred_pos = predictions["terminal_root_pos"][i]  # [3]
        gt_pos = labels["terminal_root_pos"][i]  # [3]
        pred_rot6d = predictions["terminal_root_rot6d"][i]  # [6]
        gt_rot6d = labels["terminal_root_rot6d"][i]  # [6]
        pred_dof = predictions["terminal_dof_pos"][i]  # [num_dofs]
        gt_dof = labels["terminal_dof_pos"][i]  # [num_dofs]

        # Current frame reference
        current_root_pos = labels["current_root_pos"][i]
        current_root_quat = labels["current_root_quat"][i]

        # Transform terminal positions to world frame
        gt_pos_world = (
            self.fk_helper.transform_to_world(gt_pos.unsqueeze(0), current_root_quat).squeeze(0)
            + current_root_pos
        )
        pred_pos_world = (
            self.fk_helper.transform_to_world(pred_pos.unsqueeze(0), current_root_quat).squeeze(0)
            + current_root_pos
        )

        # Compute body positions via FK (add batch dim for FK).
        # Best-effort: on failure the 3D panel is titled "FK failed".
        gt_body_pos = pred_body_pos = None
        try:
            gt_body_world = self.fk_helper.dof_to_body_pos(
                gt_dof.unsqueeze(0), gt_pos.unsqueeze(0), gt_rot6d.unsqueeze(0), current_root_quat
            )
            pred_body_world = self.fk_helper.dof_to_body_pos(
                pred_dof.unsqueeze(0),
                pred_pos.unsqueeze(0),
                pred_rot6d.unsqueeze(0),
                current_root_quat,
            )
            # [num_keypoints, 3]
            gt_body_pos = self._to_numpy(gt_body_world[0] + current_root_pos)
            pred_body_pos = self._to_numpy(pred_body_world[0] + current_root_pos)
            fk_success = True
        except Exception as e:
            logger.warning(f"FK failed for sample {i}: {e}")
            fk_success = False

        # Convert to numpy
        gt_pos_np = self._to_numpy(gt_pos_world)
        pred_pos_np = self._to_numpy(pred_pos_world)
        gt_rot_np = self._to_numpy(gt_rot6d)
        pred_rot_np = self._to_numpy(pred_rot6d)
        current_pos_np = self._to_numpy(current_root_pos)

        first_cell = i * num_cols  # subplot indices are 1-based within the grid

        # --- Plot 1: Input image ---
        ax = fig.add_subplot(num_samples, num_cols, first_cell + 1)
        self._plot_image(ax, image, seq_name, frame_idx)

        # --- Plot 2: BEV with terminal positions and arrows ---
        ax = fig.add_subplot(num_samples, num_cols, first_cell + 2)
        self._plot_terminal_bev(
            ax,
            gt_pos_np,
            pred_pos_np,
            gt_rot_np,
            pred_rot_np,
            current_root_quat,
            current_pos_np,
            bev_bounds,
        )

        # --- Plot 3: 3D skeleton comparison ---
        ax = fig.add_subplot(num_samples, num_cols, first_cell + 3, projection="3d")
        if fk_success:
            self._plot_terminal_skeleton_3d(ax, gt_body_pos, pred_body_pos)
        else:
            ax.set_title("FK failed", fontsize=8)

    def _plot_terminal_bev(
        self,
        ax,
        gt_pos: np.ndarray,
        pred_pos: np.ndarray,
        gt_rot: np.ndarray,
        pred_rot: np.ndarray,
        current_root_quat: torch.Tensor,
        current_pos: np.ndarray,
        bev_bounds: Dict[str, float],
    ):
        """Plot BEV with current position, terminal GT and terminal Pred with heading arrows."""
        # Plot current position
        ax.scatter(
            [current_pos[0]],
            [current_pos[1]],
            c="blue",
            s=150,
            marker="s",
            zorder=5,
            label="Current",
        )

        # Plot terminal positions
        ax.scatter(
            [gt_pos[0]], [gt_pos[1]], c="green", s=150, marker="*", zorder=5, label="GT Terminal"
        )
        ax.scatter(
            [pred_pos[0]],
            [pred_pos[1]],
            c="red",
            s=150,
            marker="*",
            zorder=5,
            label="Pred Terminal",
        )

        # Draw lines from current to terminal
        ax.plot(
            [current_pos[0], gt_pos[0]], [current_pos[1], gt_pos[1]], "g--", linewidth=2, alpha=0.5
        )
        ax.plot(
            [current_pos[0], pred_pos[0]],
            [current_pos[1], pred_pos[1]],
            "r--",
            linewidth=2,
            alpha=0.5,
        )

        # Axis limits from full motion bounds
        x_center, y_center, extent = self._bev_view(bev_bounds)
        arrow_len = extent * 0.05

        # Heading arrows: GT first, then Pred
        gt_fwd = self._heading_world(current_root_quat, gt_rot)[0][:2]
        self._draw_heading_arrow(ax, gt_pos, gt_fwd, arrow_len, "green", zorder=6, linewidth=2)

        pred_fwd = self._heading_world(current_root_quat, pred_rot)[0][:2]
        self._draw_heading_arrow(ax, pred_pos, pred_fwd, arrow_len, "red", zorder=6, linewidth=2)

        self._finish_bev_axes(ax, x_center, y_center, extent, "Terminal BEV")

    def _plot_terminal_skeleton_3d(self, ax, gt_body_pos: np.ndarray, pred_body_pos: np.ndarray):
        """Plot GT and Pred terminal skeletons together."""
        for pos, color in ((gt_body_pos, "green"), (pred_body_pos, "red")):
            # Bones
            self._draw_bones(ax, pos, c=color, linewidth=2, alpha=0.8)

            # Joints
            regular_mask = self._regular_joint_mask(len(pos))
            self._scatter_joints(ax, pos, regular_mask, c=color, s=20, alpha=0.8)

        # Special markers on GT
        self._scatter_joints(
            ax, gt_body_pos, FKHelper.FOOT_INDICES, c="orange", s=40, marker="^", label="Feet"
        )
        self._scatter_joints(
            ax, gt_body_pos, FKHelper.HAND_INDICES, c="purple", s=40, marker="o", label="Hands"
        )

        self._label_3d_axes(ax, "GT (green) vs Pred (red)")
        all_pos = np.concatenate([gt_body_pos, pred_body_pos])
        self._set_3d_axes_equal(ax, all_pos)

    # ------------------------------------------------------------------
    # Persistence
    # ------------------------------------------------------------------

    def save_figure(self, fig: "plt.Figure", save_dir: Path, step: int, wandb_log: bool = False):
        """
        Save figure to disk and optionally log to wandb.

        The PNG is written to ``<save_dir>/visualizations/predictions_step_<step>.png``.
        The figure is always closed afterwards, even if saving or logging raises.

        Args:
            fig: Figure to save; it is closed by this call
            save_dir: Directory under which ``visualizations/`` is created
            step: Step number used in the file name, caption and wandb step
            wandb_log: If True and a wandb run exists, also log the image
                under the key ``vis/predictions``
        """
        vis_dir = Path(save_dir) / "visualizations"
        vis_dir.mkdir(exist_ok=True, parents=True)
        out_path = vis_dir / f"predictions_step_{step:06d}.png"

        try:
            # Render once into memory and reuse the bytes for disk and wandb
            with io.BytesIO() as buf:
                fig.savefig(buf, format="png", dpi=100, bbox_inches="tight")
                out_path.write_bytes(buf.getvalue())

                if wandb_log:
                    # Imported lazily so wandb is only required when logging
                    import wandb
                    from groot.rl.trl.utils.common import wandb_run_exists

                    if wandb_run_exists():
                        buf.seek(0)
                        wandb.log(
                            {
                                "vis/predictions": wandb.Image(
                                    Image.open(buf), caption=f"Step {step}"
                                )
                            },
                            step=step,
                        )
        finally:
            plt.close(fig)

        logger.info(f"Saved visualization for step {step}")
|