"""Episode recorder: saves per-item color/depth images, audio arrays, and
state/action records into per-episode directories with a streaming data.json,
using a background worker thread."""
import os
import cv2
import json
import datetime
import numpy as np
import time
from .rerun_visualizer import RerunLogger
from queue import Queue, Empty
from threading import Thread
import logging_mp
logger_mp = logging_mp.getLogger(__name__)
class EpisodeWriter():
    """Record one teleoperation episode at a time to disk.

    Each episode lives in ``task_dir/episode_XXXX`` with ``colors/``,
    ``depths/``, ``audios/`` subdirectories and an incrementally written
    ``data.json``. Items are enqueued by :meth:`add_item` and persisted by a
    background worker thread so the caller is never blocked on disk I/O.

    Lifecycle: ``create_episode()`` -> ``add_item()`` (repeatedly) ->
    ``save_episode()`` -> (optionally more episodes) -> ``close()``.
    """

    def __init__(self, task_dir, task_goal=None, task_desc=None, task_steps=None,
                 frequency=30, image_size=None, rerun_log=True):
        """
        Args:
            task_dir: Root directory that holds the ``episode_XXXX`` folders.
            task_goal: Optional goal text stored in data.json (overrides default).
            task_desc: Optional task description text.
            task_steps: Optional step-by-step task text.
            frequency: Recording rate in Hz, written into the metadata header.
            image_size: ``[width, height]`` of recorded frames; defaults to
                ``[640, 480]``. (A ``None`` sentinel is used to avoid the
                shared-mutable-default-argument pitfall.)
            rerun_log: When True, stream each item to a RerunLogger for live
                visualization.
        """
        logger_mp.info("==> EpisodeWriter initializing...\n")
        self.task_dir = task_dir
        # Default task text; any provided argument overrides its entry.
        self.text = {
            "goal": "Pick up the red cup on the table.",
            "desc": "task description",
            "steps": "step1: do this; step2: do that; ...",
        }
        if task_goal is not None:
            self.text['goal'] = task_goal
        if task_desc is not None:
            self.text['desc'] = task_desc
        if task_steps is not None:
            self.text['steps'] = task_steps
        self.frequency = frequency
        self.image_size = [640, 480] if image_size is None else image_size
        self.rerun_log = rerun_log
        if self.rerun_log:
            logger_mp.info("==> RerunLogger initializing...\n")
            self.rerun_logger = RerunLogger(prefix="online/", IdxRangeBoundary=60, memory_limit="300MB")
            logger_mp.info("==> RerunLogger initializing ok.\n")

        self.item_id = -1
        self.episode_id = -1
        if os.path.exists(self.task_dir):
            episode_dirs = [episode_dir for episode_dir in os.listdir(self.task_dir)
                            if 'episode_' in episode_dir and not episode_dir.endswith('.zip')]
            # Resume numbering from the highest existing index. Compare
            # numerically (not lexicographically) so ordering stays correct
            # even past episode_9999; an empty directory resumes from 0.
            self.episode_id = max((int(d.split('_')[-1]) for d in episode_dirs), default=0)
            logger_mp.info(f"==> task_dir directory already exist, now self.episode_id is:{self.episode_id}\n")
        else:
            os.makedirs(self.task_dir)
            logger_mp.info(f"==> episode directory does not exist, now create one.\n")
        self.data_info()

        self.is_available = True  # Indicates whether the class is available for new operations
        # Initialize the unbounded work queue and the background worker thread.
        self.item_data_queue = Queue(-1)
        self.stop_worker = False
        self.need_save = False  # Flag to indicate when save_episode is triggered
        self.worker_thread = Thread(target=self.process_queue)
        self.worker_thread.start()
        logger_mp.info("==> EpisodeWriter initialized successfully.\n")

    def is_ready(self):
        """Return True when no episode is open, i.e. create_episode() may be called."""
        return self.is_available

    def data_info(self, version='1.0.0', date=None, author=None):
        """Build ``self.info``, the metadata header written at the top of data.json.

        Args:
            version: Dataset format version string; ``None`` falls back to "1.0.0".
            date: ISO date string; defaults to today.
            author: Author tag; defaults to "unitree".
        """
        self.info = {
            "version": "1.0.0" if version is None else version,
            "date": datetime.date.today().strftime('%Y-%m-%d') if date is None else date,
            "author": "unitree" if author is None else author,
            "image": {"width": self.image_size[0], "height": self.image_size[1], "fps": self.frequency},
            "depth": {"width": self.image_size[0], "height": self.image_size[1], "fps": self.frequency},
            "audio": {"sample_rate": 16000, "channels": 1, "format": "PCM", "bits": 16},  # PCM_S16
            "joint_names": {
                "left_arm": [],
                "left_ee": [],
                "right_arm": [],
                "right_ee": [],
                "body": [],
            },
            "tactile_names": {
                "left_ee": [],
                "right_ee": [],
            },
            "sim_state": ""
        }

    def create_episode(self):
        """
        Create a new episode.

        Returns:
            bool: True if the episode is successfully created, False otherwise.

        Note:
            Once successfully created, this function will only be available
            again after save_episode completes its save task.
        """
        if not self.is_available:
            logger_mp.info("==> The class is currently unavailable for new operations. Please wait until ongoing tasks are completed.")
            return False  # Return False if the class is unavailable

        # Reset episode-related data and create necessary directories.
        self.item_id = -1
        self.episode_id = self.episode_id + 1
        self.episode_dir = os.path.join(self.task_dir, f"episode_{str(self.episode_id).zfill(4)}")
        self.color_dir = os.path.join(self.episode_dir, 'colors')
        self.depth_dir = os.path.join(self.episode_dir, 'depths')
        self.audio_dir = os.path.join(self.episode_dir, 'audios')
        self.json_path = os.path.join(self.episode_dir, 'data.json')
        os.makedirs(self.episode_dir, exist_ok=True)
        os.makedirs(self.color_dir, exist_ok=True)
        os.makedirs(self.depth_dir, exist_ok=True)
        os.makedirs(self.audio_dir, exist_ok=True)

        # Write the JSON header by hand and stream the "data" array one item
        # at a time; _save_episode() later appends the closing brackets.
        with open(self.json_path, "w", encoding="utf-8") as f:
            f.write('{\n')
            f.write('"info": ' + json.dumps(self.info, ensure_ascii=False, indent=4) + ',\n')
            f.write('"text": ' + json.dumps(self.text, ensure_ascii=False, indent=4) + ',\n')
            f.write('"data": [\n')
        self.first_item = True  # Flag to handle commas in the streamed JSON array

        if self.rerun_log:
            # NOTE(review): this per-episode logger is created but nothing in
            # this file reads self.online_logger (logging goes through
            # self.rerun_logger from __init__). Kept for external consumers —
            # confirm whether it is still needed.
            self.online_logger = RerunLogger(prefix="online/", IdxRangeBoundary=60, memory_limit="300MB")

        self.is_available = False  # Unavailable until the episode is successfully saved
        logger_mp.info(f"==> New episode created: {self.episode_dir}")
        return True  # Return True if the episode is successfully created

    def add_item(self, colors, depths=None, states=None, actions=None, tactiles=None, audios=None, sim_state=None):
        """Enqueue one timestep of data; the worker thread persists it asynchronously.

        Args:
            colors: Dict of camera-name -> BGR image array.
            depths: Optional dict of camera-name -> depth image array.
            states/actions/tactiles: Optional dicts stored verbatim in data.json.
            audios: Optional dict of mic-name -> int16-convertible sample array.
            sim_state: Optional simulator state stored verbatim.
        """
        self.item_id += 1
        item_data = {
            'idx': self.item_id,
            'colors': colors,
            'depths': depths,
            'states': states,
            'actions': actions,
            'tactiles': tactiles,
            'audios': audios,
            'sim_state': sim_state,
        }
        self.item_data_queue.put(item_data)

    def process_queue(self):
        """Worker-thread loop: drain the queue, then finalize when a save is pending.

        Runs until close() sets stop_worker AND the queue is empty, so no
        enqueued item is ever dropped on shutdown.
        """
        while not self.stop_worker or not self.item_data_queue.empty():
            try:
                item_data = self.item_data_queue.get(timeout=1)
                try:
                    self._process_item_data(item_data)
                except Exception as e:
                    # Log and keep going: one bad item must not kill the worker.
                    logger_mp.info(f"Error processing item_data (idx={item_data['idx']}): {e}")
                self.item_data_queue.task_done()
            except Empty:
                # Timeout with no work; fall through to the save check below.
                pass

            # Finalize the episode only once every queued item has been written.
            if self.need_save and self.item_data_queue.empty():
                self._save_episode()

    def _process_item_data(self, item_data):
        """Persist one item's images/audio to disk and append its record to data.json.

        Replaces in-memory arrays in item_data with the relative file paths
        they were saved under, so the JSON record references files on disk.
        """
        idx = item_data['idx']
        colors = item_data.get('colors', {})
        depths = item_data.get('depths', {})
        audios = item_data.get('audios', {})

        # Save color images as JPEG.
        if colors:
            for color_key, color in colors.items():
                color_name = f'{str(idx).zfill(6)}_{color_key}.jpg'
                if not cv2.imwrite(os.path.join(self.color_dir, color_name), color):
                    logger_mp.info(f"Failed to save color image.")
                item_data['colors'][color_key] = os.path.join('colors', color_name)

        # Save depth images (same lossy JPEG path as colors).
        if depths:
            for depth_key, depth in depths.items():
                depth_name = f'{str(idx).zfill(6)}_{depth_key}.jpg'
                if not cv2.imwrite(os.path.join(self.depth_dir, depth_name), depth):
                    logger_mp.info(f"Failed to save depth image.")
                item_data['depths'][depth_key] = os.path.join('depths', depth_name)

        # Save audio chunks as int16 .npy arrays (matches the PCM_S16 header).
        if audios:
            for mic, audio in audios.items():
                audio_name = f'audio_{str(idx).zfill(6)}_{mic}.npy'
                np.save(os.path.join(self.audio_dir, audio_name), audio.astype(np.int16))
                item_data['audios'][mic] = os.path.join('audios', audio_name)

        # Stream this record into the open JSON "data" array.
        with open(self.json_path, "a", encoding="utf-8") as f:
            if not self.first_item:
                f.write(",\n")
            f.write(json.dumps(item_data, ensure_ascii=False, indent=4))
        self.first_item = False

        # Optional live visualization.
        if self.rerun_log:
            current_record_time = time.time()
            logger_mp.info(f"==> episode_id:{self.episode_id} item_id:{idx} current_time:{current_record_time}")
            self.rerun_logger.log_item_data(item_data)

    def save_episode(self):
        """
        Trigger the save operation. This sets the save flag, and the
        process_queue thread will finalize the file once the queue drains.
        """
        self.need_save = True  # Set the save flag
        logger_mp.info(f"==> Episode saved start...")

    def _save_episode(self):
        """
        Close the streamed JSON (array and object) and mark the writer available.
        Runs on the worker thread only after the queue is empty.
        """
        with open(self.json_path, "a", encoding="utf-8") as f:
            f.write("\n]\n}")  # Close the JSON array and object
        self.need_save = False   # Reset the save flag
        self.is_available = True  # Mark the class as available after saving
        logger_mp.info(f"==> Episode saved successfully to {self.json_path}.")

    def close(self):
        """
        Stop the worker thread, saving any open episode and draining the queue first.
        """
        self.item_data_queue.join()
        if not self.is_available:  # If self.is_available is False, there is still unsaved data.
            self.save_episode()
            # Busy-wait until the worker thread finishes _save_episode().
            while not self.is_available:
                time.sleep(0.01)
        self.stop_worker = True
        self.worker_thread.join()