You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
875 lines
36 KiB
875 lines
36 KiB
"""
|
|
MuJoCo Visualizer Class
|
|
|
|
A standalone visualizer for MuJoCo simulations that supports both interactive viewing
|
|
and offline rendering for video recording. Extracted and refactored from the
|
|
MetricNeuralRetarget callback.
|
|
|
|
Features:
|
|
- Interactive viewer with keyboard controls
|
|
- Offline rendering for video recording
|
|
- SMPL joints visualization as 3D spheres
|
|
- Side-by-side comparison of ground truth and predicted poses
|
|
- Headless rendering support (EGL/OSMesa)
|
|
- Configurable camera settings and rendering parameters
|
|
"""
|
|
|
|
import logging
|
|
import os
|
|
import tempfile
|
|
import threading
|
|
import time
|
|
from typing import Dict, List, Optional, Union
|
|
import xml.etree.ElementTree as ET
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
|
|
# Configure Mesa for headless rendering before any MuJoCo imports
|
|
def _configure_headless_rendering():
    """Prepare environment variables for offscreen MuJoCo rendering.

    Values already present in the environment are respected: ``MUJOCO_GL`` and
    ``PYOPENGL_PLATFORM`` are only given the default ``"egl"`` when they are
    unset. When ``MUJOCO_GL`` is ``"osmesa"`` (after applying the default),
    ``PYOPENGL_PLATFORM`` is forced to ``"osmesa"`` and
    ``LIBGL_ALWAYS_SOFTWARE`` is set to ``"1"``.

    Note that this function does not probe whether EGL actually works; the
    OSMesa settings are applied only when OSMesa was requested explicitly.
    """
    env = os.environ

    # Default both the MuJoCo backend and the PyOpenGL platform to EGL, without
    # overriding anything the caller exported beforehand.
    env.setdefault("MUJOCO_GL", "egl")
    env.setdefault("PYOPENGL_PLATFORM", "egl")

    if env.get("MUJOCO_GL") != "osmesa":
        return

    # OSMesa was requested: align PyOpenGL with it and ask for software GL.
    env["PYOPENGL_PLATFORM"] = "osmesa"
    env["LIBGL_ALWAYS_SOFTWARE"] = "1"
|
|
|
|
|
|
# The rendering environment variables have to be in place before MuJoCo is
# imported, so apply them first.
_configure_headless_rendering()

# Optional visualization dependencies. MUJOCO_AVAILABLE records whether all of
# them could be imported; it stays False if any import fails for any reason.
MUJOCO_AVAILABLE = False
try:
    import imageio
    import mujoco
    import mujoco.viewer
except ImportError as import_error:
    logging.warning(f"MuJoCo not available, visualization will be disabled: {import_error}")
except Exception as unexpected_error:
    # Anything other than ImportError raised while importing is treated the
    # same way: visualization is disabled rather than aborting the module.
    logging.warning(f"MuJoCo import failed, visualization will be disabled: {unexpected_error}")
else:
    MUJOCO_AVAILABLE = True
    logging.info(
        f"MuJoCo available with rendering backend: {os.environ.get('MUJOCO_GL', 'default')}"
    )
|
|
|
|
|
|
class MuJoCoVisualizer:
    """
    Standalone MuJoCo visualizer supporting interactive viewing and offline rendering.

    The code assumes the loaded model stores two robots back to back in ``qpos``
    (ground truth first, prediction second), each occupying ``ROBOT_QPOS_DIM``
    entries. Models with fewer entries are tolerated: writes are clamped to the
    size of ``qpos``.

    Features:
    - Interactive viewer with keyboard controls:
        - R: Reset to first frame
        - Space: Pause/unpause animation
        - N/P: Next/previous frame
        - G: Toggle ground truth robot visibility
        - T: Toggle predicted robot visibility
        - S: Toggle SMPL joints visibility
    - Offline rendering for video recording
    - SMPL joints visualization as 3D spheres
    - Side-by-side comparison support
    """

    # Number of SMPL joint marker sites added per robot to the scene XML.
    NUM_SMPL_JOINTS = 24
    # qpos entries per robot: 3 (translation) + 4 (quaternion) + 29 (joints).
    ROBOT_QPOS_DIM = 36
    # Distance along x by which GT (negative) and prediction (positive) are separated.
    SIDE_OFFSET_X = 1.0
    # Height added to SMPL joint markers before they are placed in the scene.
    SMPL_HEIGHT_OFFSET = 0.793

    def __init__(
        self,
        xml_path: str,
        enable_interactive: bool = True,
        enable_video_recording: bool = False,
        video_output_dir: str = "./videos",
        video_width: int = 1280,
        video_height: int = 720,
        video_fps: int = 30,
        smpl_sphere_radius: float = 0.02,
        fps: int = 30,
        realtime_mode: bool = False,
    ):
        """
        Initialize MuJoCo visualizer.

        Args:
            xml_path: Path to MuJoCo XML model file
            enable_interactive: Enable interactive viewer
            enable_video_recording: Enable video recording
            video_output_dir: Directory for video output
            video_width: Video width in pixels
            video_height: Video height in pixels
            video_fps: Video frame rate
            smpl_sphere_radius: Radius of SMPL joint spheres
            fps: Simulation/animation frame rate
            realtime_mode: If True, only visualize latest frame without buffering (default: False)

        Raises:
            ValueError: If ``fps`` is not positive.
        """
        if fps <= 0:
            raise ValueError(f"fps must be positive, got {fps}")

        self.xml_path = xml_path
        self.enable_interactive = enable_interactive and MUJOCO_AVAILABLE
        self.enable_video_recording = enable_video_recording and MUJOCO_AVAILABLE
        self.realtime_mode = realtime_mode

        # Video recording parameters
        self.video_output_dir = video_output_dir
        self.video_width = video_width
        self.video_height = video_height
        self.video_fps = video_fps
        self.video_writer = None
        self.offscreen_renderer = None
        self.camera = None

        # MuJoCo visualization state
        self.mj_model = None
        self.mj_data = None
        self.viewer = None
        self.viewer_thread = None

        # Animation data. Both storage styles are always created so that every
        # method (key callback, create_video, clear_buffers, ...) can safely
        # touch them regardless of the selected mode.
        self.latest_qpos_gt = None
        self.latest_qpos_pred = None
        self.latest_smpl_joints_gt = None
        self.latest_smpl_joints_pred = None
        self.qpos_gt_buffer = []  # Ground truth qpos (translation + quaternion + joint positions)
        self.qpos_pred_buffer = []  # Predicted qpos (translation + quaternion + joint positions)
        self.smpl_joints_gt_buffer = []  # Ground truth SMPL joints, one (J, 3) array per frame
        self.smpl_joints_pred_buffer = []  # Predicted SMPL joints, one (J, 3) array per frame
        if self.realtime_mode:
            logging.info("MuJoCo visualizer initialized in REAL-TIME mode (latest frame only)")
        else:
            logging.info("MuJoCo visualizer initialized in BUFFER mode (full trajectory)")

        # Animation control
        self.current_frame = 0
        self.paused = False
        self.fps = fps
        self.dt = 1.0 / self.fps

        # Visibility toggles
        self.show_gt = True  # Show ground truth robot
        self.show_pred = True  # Show predicted robot
        self.show_smpl_joints = True  # Show SMPL joints as spheres

        # SMPL visualization
        self.sphere_radius = smpl_sphere_radius
        self.smpl_sphere_sites = []  # List to store SMPL joint sphere site IDs

        # Initialize MuJoCo model
        self._init_mujoco_model()

    # ------------------------------------------------------------------
    # Model / renderer setup
    # ------------------------------------------------------------------

    def _create_xml_with_smpl_sites(self) -> str:
        """Create a modified XML file with SMPL joint sites.

        Relative ``<include file=...>`` paths directly under the root are made
        absolute, and ``NUM_SMPL_JOINTS`` blue (ground truth) plus the same
        number of red (prediction) sphere sites are appended to ``<worldbody>``.

        Returns:
            Path of a temporary XML file written next to the original, or
            ``self.xml_path`` unchanged when the model has no top-level
            ``<worldbody>``. The caller owns (and must delete) the temp file.
        """
        tree = ET.parse(self.xml_path)
        root = tree.getroot()

        # Fix include paths to be absolute
        xml_dir = os.path.dirname(os.path.abspath(self.xml_path))
        for include_elem in root.findall("include"):
            file_attr = include_elem.get("file")
            if file_attr and not os.path.isabs(file_attr):
                include_elem.set("file", os.path.join(xml_dir, file_attr))

        worldbody = root.find("worldbody")
        if worldbody is None:
            return self.xml_path  # Return original if no worldbody found

        # (site name prefix, rgba): blue for ground truth, red for predictions.
        site_styles = (("smpl_gt_joint", "0 0 1 0.8"), ("smpl_pred_joint", "1 0 0 0.8"))
        for prefix, rgba in site_styles:
            for j in range(self.NUM_SMPL_JOINTS):
                ET.SubElement(
                    worldbody,
                    "site",
                    {
                        "name": f"{prefix}_{j}",
                        "pos": "0 0 0",  # Will be updated dynamically
                        "size": str(self.sphere_radius),
                        "rgba": rgba,
                        "type": "sphere",
                    },
                )

        # Write through the handle we already hold (instead of reopening the
        # file by name) so the file is flushed and closed exactly once.
        with tempfile.NamedTemporaryFile(
            mode="w", suffix=".xml", delete=False, dir=xml_dir, encoding="utf-8"
        ) as temp_xml:
            tree.write(temp_xml, encoding="unicode", xml_declaration=True)

        return temp_xml.name

    def _load_model(self, xml_path: str):
        """Load model and data from ``xml_path`` and set the model timestep to ``self.dt``."""
        model = mujoco.MjModel.from_xml_path(xml_path)
        data = mujoco.MjData(model)
        model.opt.timestep = self.dt
        self.mj_model = model
        self.mj_data = data

    def _load_model_with_smpl_sites(self):
        """Load the model from a temporary XML that includes the SMPL marker sites.

        The temporary file is removed afterwards whether or not loading succeeded.
        """
        xml_path = self._create_xml_with_smpl_sites()
        logging.info(f"Created modified XML with SMPL sites: {xml_path}")
        try:
            self._load_model(xml_path)
            logging.info("MuJoCo model loaded successfully with SMPL joint sites")
        finally:
            # Clean up temporary XML file if it's different from original
            if xml_path != self.xml_path:
                try:
                    os.unlink(xml_path)
                    logging.info(f"Cleaned up temporary XML file: {xml_path}")
                except OSError as cleanup_e:
                    logging.warning(
                        f"Failed to clean up temporary XML file {xml_path}: {cleanup_e}"
                    )

    def _init_mujoco_model(self):
        """Initialize MuJoCo model and data.

        Tries the XML augmented with SMPL sites first and falls back to the
        original XML (with SMPL joint display turned off). If neither can be
        loaded, both interactive viewing and video recording are disabled,
        because nothing can be displayed or rendered without a model.
        """
        if not self.enable_interactive and not self.enable_video_recording:
            return

        # Log current rendering configuration
        current_backend = os.environ.get("MUJOCO_GL", "default")
        logging.info(f"Initializing MuJoCo model with rendering backend: {current_backend}")

        try:
            self._load_model_with_smpl_sites()
        except Exception as e:
            logging.error(f"Failed to load MuJoCo model with {current_backend}: {e}")
            logging.info(f"Falling back to original XML file: {self.xml_path}")
            try:
                self._load_model(self.xml_path)
            except Exception as fallback_e:
                logging.error(f"Failed to load original MuJoCo model as fallback: {fallback_e}")
                self.mj_model = None
                self.mj_data = None
                self.enable_interactive = False
                self.enable_video_recording = False
                return
            logging.info("Successfully loaded original MuJoCo model (without SMPL sites)")
            # Disable SMPL joints visualization since sites weren't added
            self.show_smpl_joints = False

        # Initialize offline renderer for video recording
        if self.enable_video_recording:
            self._init_offscreen_renderer()

    def _create_renderer_and_camera(self, azimuth: float):
        """Create the offscreen renderer and its camera; assign both only on success."""
        renderer = mujoco.Renderer(
            self.mj_model, height=self.video_height, width=self.video_width
        )
        camera = mujoco.MjvCamera()
        mujoco.mjv_defaultCamera(camera)

        # Camera parameters for side-by-side view
        camera.distance = 3.5
        camera.azimuth = azimuth
        camera.elevation = -0.0
        camera.lookat[:] = [0.0, 0.0, 0.5]  # Look at center between robots

        self.offscreen_renderer = renderer
        self.camera = camera

    def _init_offscreen_renderer(self):
        """Initialize MuJoCo offscreen renderer for video recording.

        If creation fails while ``MUJOCO_GL`` is ``"egl"``, the environment is
        switched to OSMesa and creation is retried once. On final failure video
        recording is disabled.

        NOTE(review): MuJoCo may already have selected its GL backend when it
        was imported, in which case changing ``MUJOCO_GL`` here has no effect —
        confirm against the installed MuJoCo version.
        NOTE(review): the primary path uses azimuth 180 while the OSMesa
        fallback uses 90 (as in the original code) — confirm which is intended.
        """
        if not self.enable_video_recording or self.mj_model is None:
            return

        current_backend = os.environ.get("MUJOCO_GL", "default")
        logging.info(f"Initializing offscreen renderer with backend: {current_backend}")

        try:
            self._create_renderer_and_camera(azimuth=180.0)
        except Exception as e:
            logging.error(f"Failed to initialize offscreen renderer with {current_backend}: {e}")
        else:
            logging.info(
                f"Offscreen renderer initialized successfully - "
                f"Resolution: {self.video_width}x{self.video_height} @ "
                f"{self.video_fps} FPS"
            )
            return

        # Try fallback to OSMesa only if EGL was the backend that failed
        if current_backend != "egl":
            self.enable_video_recording = False
            return

        logging.info("Attempting fallback to OSMesa for software rendering...")
        try:
            os.environ["MUJOCO_GL"] = "osmesa"
            os.environ["PYOPENGL_PLATFORM"] = "osmesa"
            os.environ["LIBGL_ALWAYS_SOFTWARE"] = "1"
            self._create_renderer_and_camera(azimuth=90.0)
        except Exception as fallback_e:
            logging.error(f"OSMesa fallback also failed: {fallback_e}")
            self.enable_video_recording = False
        else:
            logging.info(
                f"OSMesa fallback successful - Resolution: "
                f"{self.video_width}x{self.video_height} @ "
                f"{self.video_fps} FPS"
            )

    # ------------------------------------------------------------------
    # Interactive controls
    # ------------------------------------------------------------------

    def _num_frames(self) -> int:
        """Return the number of buffered frames (longest of the GT and predicted qpos buffers)."""
        return max(len(self.qpos_gt_buffer), len(self.qpos_pred_buffer))

    def _key_callback(self, keycode):
        """Keyboard callback for MuJoCo viewer.

        Args:
            keycode: Integer key code delivered by the viewer. Codes that do
                not map to a character (e.g. negative "unknown key" values)
                just print the help text.
        """
        try:
            key = chr(keycode)
        except (ValueError, TypeError, OverflowError):
            key = ""  # Not a character key: fall through to the help text

        if key == "R":
            print("Reset")
            self.current_frame = 0
        elif key == " ":
            self.paused = not self.paused
            print("Paused" if self.paused else "Resumed")
        elif key == "N":
            print("Next frame")
            if self.current_frame < self._num_frames() - 1:
                self.current_frame += 1
        elif key == "P":
            print("Previous frame")
            if self.current_frame > 0:
                self.current_frame -= 1
        elif key == "G":
            self.show_gt = not self.show_gt
            print(f"Ground truth robot: {'ON' if self.show_gt else 'OFF'}")
        elif key == "T":
            self.show_pred = not self.show_pred
            print(f"Predicted robot: {'ON' if self.show_pred else 'OFF'}")
        elif key == "S":
            self.show_smpl_joints = not self.show_smpl_joints
            print(f"SMPL joints: {'ON' if self.show_smpl_joints else 'OFF'}")
        else:
            print(
                "Controls: R=Reset, Space=Pause, N=Next frame, P=Previous frame, "
                "G=Toggle GT robot, T=Toggle predicted robot, S=Toggle SMPL joints"
            )

    # ------------------------------------------------------------------
    # Scene updates
    # ------------------------------------------------------------------

    def _place_smpl_sites(self, name_prefix: str, joints: np.ndarray, x_offset: float):
        """Write world positions of the ``{name_prefix}_{j}`` sites from ``joints``.

        Each joint is shifted by ``x_offset`` along x and ``SMPL_HEIGHT_OFFSET``
        along z. Sites missing from the model are skipped.
        """
        offset = np.array([x_offset, 0.0, self.SMPL_HEIGHT_OFFSET])
        for j in range(min(joints.shape[0], self.NUM_SMPL_JOINTS)):
            site_id = mujoco.mj_name2id(
                self.mj_model, mujoco.mjtObj.mjOBJ_SITE, f"{name_prefix}_{j}"
            )
            if site_id >= 0:
                self.mj_data.site_xpos[site_id] = joints[j] + offset

    def _update_smpl_joints(self, frame_idx):
        """Update SMPL joint positions for current frame.

        In realtime mode ``frame_idx`` is ignored and the latest joints are
        used. Errors are logged at debug level and otherwise ignored so that a
        bad frame cannot take down the viewer.

        NOTE(review): buffered joints are centered on joint 0 when they are
        added, realtime joints are stored as given — confirm this asymmetry is
        intended.
        """
        if not self.show_smpl_joints or self.mj_data is None:
            return

        if self.realtime_mode:
            gt_joints = self.latest_smpl_joints_gt
            pred_joints = self.latest_smpl_joints_pred
        else:
            gt_buffer = self.smpl_joints_gt_buffer
            pred_buffer = self.smpl_joints_pred_buffer
            gt_joints = gt_buffer[frame_idx] if frame_idx < len(gt_buffer) else None
            pred_joints = pred_buffer[frame_idx] if frame_idx < len(pred_buffer) else None

        try:
            # GT markers (blue) follow the GT robot on the left side
            if gt_joints is not None and self.show_gt:
                self._place_smpl_sites("smpl_gt_joint", gt_joints, -self.SIDE_OFFSET_X)
            # Predicted markers (red) follow the predicted robot on the right side
            if pred_joints is not None and self.show_pred:
                self._place_smpl_sites("smpl_pred_joint", pred_joints, self.SIDE_OFFSET_X)
        except Exception as e:
            # Best effort: never crash the viewer because of marker placement
            logging.debug(f"Skipping SMPL joint update for frame {frame_idx}: {e}")

    def _apply_robot_qpos(self, qpos, start, visible, x_position, hidden_position):
        """Write one robot's pose into ``mj_data.qpos`` starting at index ``start``.

        Args:
            qpos: Pose vector for this robot, or None when no data is available.
            start: Index of the robot's first qpos entry.
            visible: Whether the robot is currently shown.
            x_position: x coordinate forced onto a visible robot (side-by-side layout).
            hidden_position: xyz used to move a hidden robot out of view.
        """
        target = self.mj_data.qpos
        available = target.shape[0] - start
        if available <= 0:
            return  # The model has no robot at this qpos offset

        if not visible:
            # Hide robot by moving it far away
            count = min(3, available)
            target[start : start + count] = hidden_position[:count]
            return

        if qpos is None:
            return

        # Clamp to the per-robot size and to what the model actually provides.
        count = min(qpos.shape[0], self.ROBOT_QPOS_DIM, available)
        target[start : start + count] = qpos[:count]
        target[start] = x_position

    def _update_robot_poses(self, frame_idx):
        """Update robot poses for current frame.

        In realtime mode ``frame_idx`` is ignored and the latest poses are
        used; in buffered mode nothing happens when ``frame_idx`` is past the
        end of both buffers.
        """
        if self.mj_data is None:
            return

        if self.realtime_mode:
            qpos_gt = self.latest_qpos_gt
            qpos_pred = self.latest_qpos_pred
        else:
            if frame_idx >= self._num_frames():
                return
            gt_buffer = self.qpos_gt_buffer
            pred_buffer = self.qpos_pred_buffer
            qpos_gt = gt_buffer[frame_idx] if frame_idx < len(gt_buffer) else None
            qpos_pred = pred_buffer[frame_idx] if frame_idx < len(pred_buffer) else None

        # Ground truth robot (left side) - first robot in the model
        self._apply_robot_qpos(qpos_gt, 0, self.show_gt, -self.SIDE_OFFSET_X, [-100, 0, -10])
        # Predicted robot (right side) - second robot in the model
        self._apply_robot_qpos(
            qpos_pred, self.ROBOT_QPOS_DIM, self.show_pred, self.SIDE_OFFSET_X, [100, 0, -10]
        )

    def _step_viewer_frame(self) -> bool:
        """Push the current frame into ``mj_data``; return False when there is no data yet."""
        if self.realtime_mode:
            if self.latest_qpos_gt is None and self.latest_qpos_pred is None:
                return False
            frame_idx = 0  # frame_idx not used in realtime mode
        else:
            if not self.qpos_gt_buffer and not self.qpos_pred_buffer:
                return False
            frame_idx = self.current_frame

        self._update_robot_poses(frame_idx)
        # Forward simulation to update visualization
        mujoco.mj_forward(self.mj_model, self.mj_data)
        # Site positions are overwritten after mj_forward, which recomputes them
        self._update_smpl_joints(frame_idx)

        if not self.realtime_mode:
            # Auto-advance frames if not paused
            max_frames = self._num_frames()
            if not self.paused and max_frames > 1:
                self.current_frame = (self.current_frame + 1) % max_frames
        return True

    def _run_interactive_viewer(self):
        """Run MuJoCo viewer in a separate thread.

        The passive viewer renders from its own thread, so the viewer lock is
        held whenever the camera or ``mj_data`` is modified. ``sync()`` is
        called outside the lock.
        """
        if not self.enable_interactive or self.mj_model is None:
            return

        try:
            with mujoco.viewer.launch_passive(
                self.mj_model, self.mj_data, key_callback=self._key_callback
            ) as viewer:
                self.viewer = viewer

                # Set camera position
                with viewer.lock():
                    viewer.cam.distance = 15.0
                    viewer.cam.azimuth = 90.0
                    viewer.cam.elevation = -20.0

                while viewer.is_running():
                    step_start = time.time()

                    with viewer.lock():
                        frame_drawn = self._step_viewer_frame()
                    if frame_drawn:
                        viewer.sync()

                    # Control frame rate
                    time_until_next_step = self.dt - (time.time() - step_start)
                    if time_until_next_step > 0:
                        time.sleep(time_until_next_step)

        except Exception as e:
            logging.error(f"Failed to launch MuJoCo viewer: {e}")
            logging.info("Disabling interactive visualization, keeping video recording if enabled")
            self.enable_interactive = False
        finally:
            # The handle is invalid once the viewer context has exited
            self.viewer = None

    def _render_offline_frame(self, frame_idx):
        """Render a single frame using offline renderer for video recording.

        Returns:
            The rendered frame, or None when recording is disabled, no renderer
            exists, or rendering raised (the error is logged).
        """
        if not self.enable_video_recording or self.offscreen_renderer is None:
            return None

        try:
            self._update_robot_poses(frame_idx)
            # Forward simulation to update visualization
            mujoco.mj_forward(self.mj_model, self.mj_data)
            self._update_smpl_joints(frame_idx)

            # Update scene and render frame
            self.offscreen_renderer.update_scene(self.mj_data, camera=self.camera)
            return self.offscreen_renderer.render()
        except Exception as e:
            logging.error(f"Error rendering offline frame {frame_idx}: {e}")
            return None

    # ------------------------------------------------------------------
    # Data ingestion
    # ------------------------------------------------------------------

    @staticmethod
    def _to_numpy(data) -> np.ndarray:
        """Convert a torch tensor (detached, on CPU) or array-like to a numpy array."""
        if torch.is_tensor(data):
            return data.detach().cpu().numpy()
        return np.asarray(data)

    @staticmethod
    def _flatten_frames(data: np.ndarray, frame_ndim: int) -> np.ndarray:
        """Collapse all leading axes so the result is ``(num_frames, *frame_shape)``.

        ``frame_ndim`` is the number of trailing axes that make up one frame.
        """
        split = data.ndim - frame_ndim
        # Computed explicitly (rather than with -1) so zero-sized arrays reshape too
        num_frames = int(np.prod(data.shape[:split], dtype=np.int64))
        return data.reshape((num_frames,) + data.shape[split:])

    def _store_latest(self, attr_name: str, data: np.ndarray, frame_ndim: int):
        """Store the last frame of ``data`` in attribute ``attr_name`` (realtime mode).

        Input with too few dimensions or without any frame is ignored with a
        warning so that a live loop is not interrupted.
        """
        if data.ndim < frame_ndim:
            logging.warning(f"Ignoring {attr_name} update with unsupported shape {data.shape}")
            return
        frames = self._flatten_frames(data, frame_ndim)
        if frames.shape[0] == 0:
            logging.warning(f"Ignoring empty {attr_name} update with shape {data.shape}")
            return
        setattr(self, attr_name, frames[-1])

    def add_trajectory_data(
        self,
        qpos_gt: Optional[Union[np.ndarray, torch.Tensor]] = None,
        qpos_pred: Optional[Union[np.ndarray, torch.Tensor]] = None,
        smpl_joints_gt: Optional[Union[np.ndarray, torch.Tensor]] = None,
        smpl_joints_pred: Optional[Union[np.ndarray, torch.Tensor]] = None,
    ):
        """
        Add trajectory data to visualization buffers or update latest frame (realtime mode).

        Leading batch/time axes are flattened in order; in realtime mode only
        the last resulting frame is kept.

        Args:
            qpos_gt: Ground truth joint positions (B, T, DOF) or (T, DOF) or (DOF,)
            qpos_pred: Predicted joint positions (B, T, DOF) or (T, DOF) or (DOF,)
            smpl_joints_gt: Ground truth SMPL joints (B, T, 24, 3) or (T, 24, 3) or (24, 3)
            smpl_joints_pred: Predicted SMPL joints (B, T, 24, 3) or (T, 24, 3) or (24, 3)

        Raises:
            ValueError: In buffered mode, if SMPL joints have fewer than two
                dimensions or contain no joints.
        """
        # (data, realtime attribute, axes per frame, buffered adder, buffer)
        streams = (
            (qpos_gt, "latest_qpos_gt", 1, self._add_qpos_data, self.qpos_gt_buffer),
            (qpos_pred, "latest_qpos_pred", 1, self._add_qpos_data, self.qpos_pred_buffer),
            (
                smpl_joints_gt,
                "latest_smpl_joints_gt",
                2,
                self._add_smpl_data,
                self.smpl_joints_gt_buffer,
            ),
            (
                smpl_joints_pred,
                "latest_smpl_joints_pred",
                2,
                self._add_smpl_data,
                self.smpl_joints_pred_buffer,
            ),
        )
        for data, latest_attr, frame_ndim, add_to_buffer, buffer in streams:
            if data is None:
                continue
            data = self._to_numpy(data)
            if self.realtime_mode:
                # Real-time mode: just store the latest frame
                self._store_latest(latest_attr, data, frame_ndim)
            else:
                # Buffered mode: add to buffer
                add_to_buffer(data, buffer)

    def _add_qpos_data(self, qpos_data: np.ndarray, buffer: List):
        """Add qpos data to buffer, handling different dimensions.

        Arrays with two or more dimensions are treated as stacks of frames
        whose last axis is the qpos vector; every frame is appended in order.
        Anything else is appended as a single frame.
        """
        if qpos_data.ndim >= 2:
            buffer.extend(self._flatten_frames(qpos_data, 1))
        else:
            buffer.append(qpos_data)

    def _add_smpl_data(self, smpl_data: np.ndarray, buffer: List):
        """Add SMPL joints data to buffer, handling different dimensions.

        The last two axes are taken as (joints, coordinates); all leading axes
        are flattened into frames. Each frame is centered on its root joint
        (joint 0) before being appended.

        Raises:
            ValueError: If ``smpl_data`` has fewer than two dimensions or no joints.
        """
        if smpl_data.ndim < 2 or smpl_data.shape[-2] == 0:
            raise ValueError(
                f"SMPL joints must have shape (..., joints, 3), got {smpl_data.shape}"
            )

        frames = self._flatten_frames(smpl_data, 2)
        # Center joints relative to root joint (joint 0)
        frames = frames - frames[:, [0], :]
        buffer.extend(frames)

    # ------------------------------------------------------------------
    # Public API
    # ------------------------------------------------------------------

    def start_interactive_viewer(self):
        """Start interactive viewer in a separate daemon thread (no-op if already running)."""
        if not self.enable_interactive:
            return
        if self.viewer_thread is not None and self.viewer_thread.is_alive():
            return
        self.viewer_thread = threading.Thread(target=self._run_interactive_viewer, daemon=True)
        self.viewer_thread.start()
        logging.info("Started MuJoCo interactive visualization thread")

    def create_video(self, output_path: str, clear_buffers: bool = True) -> bool:
        """
        Create video from stored trajectory data using offline rendering.

        Args:
            output_path: Path for output video file
            clear_buffers: Whether to clear buffers after creating video

        Returns:
            True if video was created successfully, False otherwise (recording
            disabled, no buffered data, writer failure, or no frame could be
            rendered)
        """
        max_frames = self._num_frames()
        if not self.enable_video_recording or max_frames == 0:
            logging.warning("Video recording not enabled or no data available")
            return False

        logging.info(f"Creating video: {output_path}")

        # Create output directory if needed; a bare file name has no directory part
        output_dir = os.path.dirname(output_path)
        if output_dir:
            os.makedirs(output_dir, exist_ok=True)

        try:
            video_writer = imageio.get_writer(
                output_path, fps=self.video_fps, codec="libx264", quality=4, pixelformat="yuv420p"
            )
        except Exception as e:
            logging.error(f"Failed to create video writer: {e}")
            return False

        frames_written = 0
        progress_step = max(1, max_frames // 10)  # Log progress every 10% of frames
        try:
            try:
                for frame_idx in range(max_frames):
                    frame = self._render_offline_frame(frame_idx)
                    if frame is not None:
                        video_writer.append_data(frame)
                        frames_written += 1

                    if frame_idx % progress_step == 0:
                        progress = (frame_idx + 1) / max_frames * 100
                        logging.info(
                            f"Video rendering progress: {progress:.1f}% ({frame_idx + 1}/{max_frames})"
                        )
            finally:
                # Close exactly once, on success and on failure alike
                video_writer.close()
        except Exception as e:
            logging.error(f"Error creating video: {e}")
            return False

        if frames_written == 0:
            logging.error(f"No frames could be rendered, video is empty: {output_path}")
            return False

        logging.info(f"Video saved successfully: {output_path}")
        if clear_buffers:
            self.clear_buffers()
        return True

    def clear_buffers(self):
        """Clear all trajectory data (latest frames and buffers) and rewind to frame 0."""
        self.latest_qpos_gt = None
        self.latest_qpos_pred = None
        self.latest_smpl_joints_gt = None
        self.latest_smpl_joints_pred = None
        self.qpos_gt_buffer.clear()
        self.qpos_pred_buffer.clear()
        self.smpl_joints_gt_buffer.clear()
        self.smpl_joints_pred_buffer.clear()
        self.current_frame = 0
        if self.realtime_mode:
            logging.info("Cleared latest frame data (realtime mode)")
        else:
            logging.info("Cleared all trajectory buffers")

    def set_camera_params(
        self,
        distance: float = 3.5,
        azimuth: float = 90.0,
        elevation: float = 0.0,
        lookat: Optional[List[float]] = None,
    ):
        """Set camera parameters for offline rendering.

        Does nothing when no offscreen camera has been created.

        Args:
            distance: Camera distance from the look-at point
            azimuth: Camera azimuth angle
            elevation: Camera elevation angle
            lookat: Point the camera looks at; defaults to [0.0, 0.0, 0.5]
        """
        if self.camera is None:
            return
        if lookat is None:
            lookat = [0.0, 0.0, 0.5]

        self.camera.distance = distance
        self.camera.azimuth = azimuth
        self.camera.elevation = elevation
        self.camera.lookat[:] = lookat
        logging.info(
            f"Camera parameters updated: "
            f"distance={distance}, "
            f"azimuth={azimuth}, "
            f"elevation={elevation}, "
            f"lookat={lookat}"
        )

    def get_status(self) -> Dict:
        """Get current status of the visualizer.

        Returns:
            Dictionary of flags common to both modes plus, depending on the
            mode, either data-availability flags (realtime) or buffer lengths
            and the current frame index (buffered).
        """
        status = {
            "mujoco_available": MUJOCO_AVAILABLE,
            "interactive_enabled": self.enable_interactive,
            "video_recording_enabled": self.enable_video_recording,
            "realtime_mode": self.realtime_mode,
            "model_loaded": self.mj_model is not None,
            "paused": self.paused,
            "show_gt": self.show_gt,
            "show_pred": self.show_pred,
            "show_smpl_joints": self.show_smpl_joints,
            "viewer_running": self.viewer_thread is not None and self.viewer_thread.is_alive(),
        }

        if self.realtime_mode:
            status.update(
                {
                    "has_gt_data": self.latest_qpos_gt is not None,
                    "has_pred_data": self.latest_qpos_pred is not None,
                    "has_smpl_gt_data": self.latest_smpl_joints_gt is not None,
                    "has_smpl_pred_data": self.latest_smpl_joints_pred is not None,
                }
            )
        else:
            status.update(
                {
                    "gt_frames": len(self.qpos_gt_buffer),
                    "pred_frames": len(self.qpos_pred_buffer),
                    "smpl_gt_frames": len(self.smpl_joints_gt_buffer),
                    "smpl_pred_frames": len(self.smpl_joints_pred_buffer),
                    "current_frame": self.current_frame,
                }
            )

        return status

    def __del__(self):
        """Cleanup resources"""
        # getattr: __init__ may have raised before the attribute was created
        video_writer = getattr(self, "video_writer", None)
        if video_writer is not None:
            video_writer.close()
|