2161 changed files with 291889 additions and 21435 deletions
-
2.gitattributes
-
74.github/workflows/docs.yml
-
21.gitignore
-
93CITATION.cff
-
190LICENSE
-
320README.md
-
0decoupled_wbc/__init__.py
-
0decoupled_wbc/control/__init__.py
-
0decoupled_wbc/control/base/__init__.py
-
0decoupled_wbc/control/base/env.py
-
60decoupled_wbc/control/base/humanoid_env.py
-
0decoupled_wbc/control/base/policy.py
-
0decoupled_wbc/control/base/sensor.py
-
0decoupled_wbc/control/envs/__init__.py
-
0decoupled_wbc/control/envs/g1/__init__.py
-
67decoupled_wbc/control/envs/g1/g1_body.py
-
324decoupled_wbc/control/envs/g1/g1_env.py
-
89decoupled_wbc/control/envs/g1/g1_hand.py
-
0decoupled_wbc/control/envs/g1/sim/__init__.py
-
772decoupled_wbc/control/envs/g1/sim/base_sim.py
-
256decoupled_wbc/control/envs/g1/sim/image_publish_utils.py
-
71decoupled_wbc/control/envs/g1/sim/metric_utils.py
-
63decoupled_wbc/control/envs/g1/sim/robocasa_sim.py
-
0decoupled_wbc/control/envs/g1/sim/sim_utilts.py
-
144decoupled_wbc/control/envs/g1/sim/simulator_factory.py
-
0decoupled_wbc/control/envs/g1/sim/unitree_sdk2py_bridge.py
-
0decoupled_wbc/control/envs/g1/utils/__init__.py
-
0decoupled_wbc/control/envs/g1/utils/command_sender.py
-
534decoupled_wbc/control/envs/g1/utils/joint_safety.py
-
0decoupled_wbc/control/envs/g1/utils/state_processor.py
-
0decoupled_wbc/control/envs/robocasa/__init__.py
-
305decoupled_wbc/control/envs/robocasa/async_env_server.py
-
586decoupled_wbc/control/envs/robocasa/sync_env.py
-
0decoupled_wbc/control/envs/robocasa/utils/__init__.py
-
73decoupled_wbc/control/envs/robocasa/utils/cam_key_converter.py
-
0decoupled_wbc/control/envs/robocasa/utils/controller_utils.py
-
443decoupled_wbc/control/envs/robocasa/utils/robocasa_env.py
-
301decoupled_wbc/control/envs/robocasa/utils/robot_key_converter.py
-
0decoupled_wbc/control/envs/robocasa/utils/sim_utils.py
-
0decoupled_wbc/control/main/__init__.py
-
0decoupled_wbc/control/main/config_template.py
-
0decoupled_wbc/control/main/constants.py
-
0decoupled_wbc/control/main/teleop/__init__.py
-
483decoupled_wbc/control/main/teleop/configs/configs.py
-
421decoupled_wbc/control/main/teleop/configs/g1_29dof_gear_wbc.yaml
-
0decoupled_wbc/control/main/teleop/configs/g1_gear_wbc.yaml
-
0decoupled_wbc/control/main/teleop/configs/identifiers.py
-
627decoupled_wbc/control/main/teleop/playback_sync_sim_data.py
-
249decoupled_wbc/control/main/teleop/run_camera_viewer.py
-
236decoupled_wbc/control/main/teleop/run_g1_control_loop.py
-
364decoupled_wbc/control/main/teleop/run_g1_data_exporter.py
-
68decoupled_wbc/control/main/teleop/run_navigation_policy_loop.py
-
61decoupled_wbc/control/main/teleop/run_sim_loop.py
-
213decoupled_wbc/control/main/teleop/run_sync_sim_data_collection.py
-
110decoupled_wbc/control/main/teleop/run_teleop_policy_loop.py
-
0decoupled_wbc/control/policy/__init__.py
-
157decoupled_wbc/control/policy/g1_decoupled_whole_body_policy.py
-
295decoupled_wbc/control/policy/g1_gear_wbc_policy.py
-
25decoupled_wbc/control/policy/identity_policy.py
-
297decoupled_wbc/control/policy/interpolation_policy.py
-
87decoupled_wbc/control/policy/keyboard_navigation_policy.py
-
111decoupled_wbc/control/policy/lerobot_replay_policy.py
-
207decoupled_wbc/control/policy/teleop_policy.py
-
65decoupled_wbc/control/policy/wbc_policy_factory.py
-
0decoupled_wbc/control/robot_model/__init__.py
-
0decoupled_wbc/control/robot_model/instantiation/__init__.py
-
62decoupled_wbc/control/robot_model/instantiation/g1.py
-
0decoupled_wbc/control/robot_model/model_data/g1/g1_29dof.urdf
-
0decoupled_wbc/control/robot_model/model_data/g1/g1_29dof_old.xml
-
0decoupled_wbc/control/robot_model/model_data/g1/g1_29dof_with_hand.urdf
-
0decoupled_wbc/control/robot_model/model_data/g1/g1_29dof_with_hand.xml
-
0decoupled_wbc/control/robot_model/model_data/g1/g1_29dof_with_hand_rev_1_0_activatedfinger.xml
-
0decoupled_wbc/control/robot_model/model_data/g1/lift_box_43dof.xml
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/head_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_ankle_pitch_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_ankle_roll_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_elbow_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_elbow_link_merge.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_hand_index_0_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_hand_index_1_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_hand_middle_0_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_hand_middle_1_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_hand_palm_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_hand_thumb_0_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_hand_thumb_1_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_hand_thumb_2_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_hip_pitch_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_hip_roll_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_hip_yaw_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_knee_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_rubber_hand.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_shoulder_pitch_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_shoulder_roll_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_shoulder_yaw_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_wrist_pitch_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_wrist_roll_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_wrist_roll_rubber_hand.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/left_wrist_yaw_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/logo_link.STL
-
0decoupled_wbc/control/robot_model/model_data/g1/meshes/pelvis.STL
@ -0,0 +1,74 @@ |
|||
name: Build and Deploy Documentation |
|||
|
|||
on: |
|||
push: |
|||
branches: |
|||
- main |
|||
- gear-sonic |
|||
paths: |
|||
- "docs/**" |
|||
- ".github/workflows/docs.yml" |
|||
- ".gitattributes" |
|||
workflow_dispatch: |
|||
|
|||
# Allow only one concurrent deployment; cancel in-flight runs |
|||
concurrency: |
|||
group: "pages" |
|||
cancel-in-progress: true |
|||
|
|||
jobs: |
|||
build: |
|||
name: Build Sphinx Docs |
|||
runs-on: ubuntu-latest |
|||
steps: |
|||
- name: Checkout repository |
|||
uses: actions/checkout@v4 |
|||
with: |
|||
lfs: false |
|||
|
|||
- name: Restore docs static assets (bypass git-lfs smudge) |
|||
run: | |
|||
# git-lfs on the runner rewrites files tracked by *.png/*.gif |
|||
# even when our .gitattributes override removes filter=lfs. |
|||
# Use git cat-file to write real binary content directly from |
|||
# the object store, bypassing all smudge filters. |
|||
git ls-tree -r HEAD -- docs/source/_static \ |
|||
| awk '{print $3, $4}' \ |
|||
| while IFS=" " read -r hash path; do |
|||
git cat-file blob "$hash" > "$path" |
|||
done |
|||
|
|||
- name: Set up Python |
|||
uses: actions/setup-python@v5 |
|||
with: |
|||
python-version: "3.10" |
|||
cache: "pip" |
|||
cache-dependency-path: "docs/requirements.txt" |
|||
|
|||
- name: Install documentation dependencies |
|||
run: pip install -r docs/requirements.txt |
|||
|
|||
- name: Build HTML documentation |
|||
run: sphinx-build -b html docs/source docs/build/html |
|||
|
|||
- name: Upload Pages artifact |
|||
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/gear-sonic' |
|||
uses: actions/upload-pages-artifact@v3 |
|||
with: |
|||
path: docs/build/html |
|||
|
|||
deploy: |
|||
name: Deploy to GitHub Pages |
|||
needs: build |
|||
if: github.ref == 'refs/heads/main' || github.ref == 'refs/heads/gear-sonic' |
|||
runs-on: ubuntu-latest |
|||
permissions: |
|||
pages: write |
|||
id-token: write |
|||
environment: |
|||
name: github-pages |
|||
url: ${{ steps.deployment.outputs.page_url }} |
|||
steps: |
|||
- name: Deploy to GitHub Pages |
|||
id: deployment |
|||
uses: actions/deploy-pages@v4 |
|||
@ -0,0 +1,93 @@ |
|||
cff-version: 1.2.0 |
|||
message: "If you use this software, please cite it as below." |
|||
title: "GR00T Whole-Body Control" |
|||
authors: |
|||
- family-names: "Luo" |
|||
given-names: "Zhengyi" |
|||
- family-names: "Yuan" |
|||
given-names: "Ye" |
|||
- family-names: "Wang" |
|||
given-names: "Tingwu" |
|||
- family-names: "Li" |
|||
given-names: "Chenran" |
|||
- family-names: "Chen" |
|||
given-names: "Sirui" |
|||
- family-names: "Castañeda" |
|||
given-names: "Fernando" |
|||
- family-names: "Cao" |
|||
given-names: "Zi-Ang" |
|||
- family-names: "Li" |
|||
given-names: "Jiefeng" |
|||
- family-names: "Zhu" |
|||
given-names: "Yuke" |
|||
url: "https://github.com/NVlabs/GR00T-WholeBodyControl" |
|||
repository-code: "https://github.com/NVlabs/GR00T-WholeBodyControl" |
|||
type: software |
|||
keywords: |
|||
- humanoid-robotics |
|||
- reinforcement-learning |
|||
- whole-body-control |
|||
- motion-tracking |
|||
- teleoperation |
|||
- robotics |
|||
- pytorch |
|||
license: Apache-2.0 |
|||
preferred-citation: |
|||
type: article |
|||
title: "SONIC: Supersizing Motion Tracking for Natural Humanoid Whole-Body Control" |
|||
authors: |
|||
- family-names: "Luo" |
|||
given-names: "Zhengyi" |
|||
- family-names: "Yuan" |
|||
given-names: "Ye" |
|||
- family-names: "Wang" |
|||
given-names: "Tingwu" |
|||
- family-names: "Li" |
|||
given-names: "Chenran" |
|||
- family-names: "Chen" |
|||
given-names: "Sirui" |
|||
- family-names: "Castañeda" |
|||
given-names: "Fernando" |
|||
- family-names: "Cao" |
|||
given-names: "Zi-Ang" |
|||
- family-names: "Li" |
|||
given-names: "Jiefeng" |
|||
- family-names: "Minor" |
|||
given-names: "David" |
|||
- family-names: "Ben" |
|||
given-names: "Qingwei" |
|||
- family-names: "Da" |
|||
given-names: "Xingye" |
|||
- family-names: "Ding" |
|||
given-names: "Runyu" |
|||
- family-names: "Hogg" |
|||
given-names: "Cyrus" |
|||
- family-names: "Song" |
|||
given-names: "Lina" |
|||
- family-names: "Lim" |
|||
given-names: "Edy" |
|||
- family-names: "Jeong" |
|||
given-names: "Eugene" |
|||
- family-names: "He" |
|||
given-names: "Tairan" |
|||
- family-names: "Xue" |
|||
given-names: "Haoru" |
|||
- family-names: "Xiao" |
|||
given-names: "Wenli" |
|||
- family-names: "Wang" |
|||
given-names: "Zi" |
|||
- family-names: "Yuen" |
|||
given-names: "Simon" |
|||
- family-names: "Kautz" |
|||
given-names: "Jan" |
|||
- family-names: "Chang" |
|||
given-names: "Yan" |
|||
- family-names: "Iqbal" |
|||
given-names: "Umar" |
|||
- family-names: "Fan" |
|||
given-names: "Linxi" |
|||
- family-names: "Zhu" |
|||
given-names: "Yuke" |
|||
journal: "arXiv preprint" |
|||
year: 2025 |
|||
url: "https://arxiv.org/abs/2511.07820" |
|||
@ -1,38 +1,186 @@ |
|||
NVIDIA License |
|||
================================================================================ |
|||
DUAL LICENSE NOTICE |
|||
================================================================================ |
|||
|
|||
1. Definitions |
|||
This repository is dual-licensed. Different components are under different terms: |
|||
|
|||
“Licensor” means any person or entity that distributes its Work. |
|||
“Work” means (a) the original work of authorship made available under this license, which may include software, documentation, or other files, and (b) any additions to or derivative works thereof that are made available under this license. |
|||
The terms “reproduce,” “reproduction,” “derivative works,” and “distribution” have the meaning as provided under U.S. copyright law; provided, however, that for the purposes of this license, derivative works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work. |
|||
Works are “made available” under this license by including in or with the Work either (a) a copyright notice referencing the applicability of this license to the Work, or (b) a copy of this license. |
|||
1. SOURCE CODE - Apache License 2.0 |
|||
All source code, scripts, and software components |
|||
|
|||
2. License Grant |
|||
2. MODEL WEIGHTS - NVIDIA Open Model License |
|||
All trained model checkpoints and weights |
|||
|
|||
2.1 Copyright Grant. Subject to the terms and conditions of this license, each Licensor grants to you a perpetual, worldwide, non-exclusive, royalty-free, copyright license to use, reproduce, prepare derivative works of, publicly display, publicly perform, sublicense and distribute its Work and any resulting derivative works in any form. |
|||
See below for the full text of each license. |
|||
|
|||
3. Limitations |
|||
================================================================================ |
|||
PART 1: SOURCE CODE LICENSE (Apache License 2.0) |
|||
================================================================================ |
|||
|
|||
3.1 Redistribution. You may reproduce or distribute the Work only if (a) you do so under this license, (b) you include a complete copy of this license with your distribution, and (c) you retain without modification any copyright, patent, trademark, or attribution notices that are present in the Work. |
|||
Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. |
|||
|
|||
3.2 Derivative Works. You may specify that additional or different terms apply to the use, reproduction, and distribution of your derivative works of the Work (“Your Terms”) only if (a) Your Terms provide that the use limitation in Section 3.3 applies to your derivative works; (b) you comply with Other Licenses, and (c) you identify the specific derivative works that are subject to Your Terms and Other Licenses, as applicable. Notwithstanding Your Terms, this license (including the redistribution requirements in Section 3.1) will continue to apply to the Work itself. |
|||
Licensed under the Apache License, Version 2.0 (the "License"); |
|||
you may not use this file except in compliance with the License. |
|||
You may obtain a copy of the License at |
|||
|
|||
3.3 Use Limitation. The Work and any derivative works thereof only may be used or intended for use non-commercially. As used herein, “non-commercially” means for non-commercial research purposes only, and excludes any military, surveillance, service of nuclear technology or biometric processing purposes. |
|||
http://www.apache.org/licenses/LICENSE-2.0 |
|||
|
|||
3.4 Patent Claims. If you bring or threaten to bring a patent claim against any Licensor (including any claim, cross-claim or counterclaim in a lawsuit) to enforce any patents that you allege are infringed by any Work, then your rights under this license from such Licensor (including the grant in Section 2.1) will terminate immediately. |
|||
Unless required by applicable law or agreed to in writing, software |
|||
distributed under the License is distributed on an "AS IS" BASIS, |
|||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
|||
See the License for the specific language governing permissions and |
|||
limitations under the License. |
|||
|
|||
3.5 Trademarks. This license does not grant any rights to use any Licensor’s or its affiliates’ names, logos, or trademarks, except as necessary to reproduce the notices described in this license. |
|||
|
|||
3.6 Termination. If you violate any term of this license, then your rights under this license (including the grant in Section 2.1) will terminate immediately. |
|||
================================================================================ |
|||
PART 2: MODEL WEIGHTS LICENSE (NVIDIA Open Model License) |
|||
================================================================================ |
|||
|
|||
3.7 Components Under Other Licenses. The Work may include or be distributed with components provided with separate legal notices or terms that accompany the components, such as open source software licenses and other license terms, including but not limited to the Meta OPT-IML 175B License Agreement (“Other Licenses”). The components are subject to the applicable Other Licenses, including any proprietary notices, disclaimers, requirements and extended use rights; except that this Agreement will prevail regarding the use of third-party software, unless a third-party software license requires it license terms to prevail. |
|||
NVIDIA OPEN MODEL LICENSE AGREEMENT |
|||
|
|||
4. Disclaimer of Warranty. |
|||
Last Modified: October 24, 2025 |
|||
|
|||
THE WORK IS PROVIDED “AS IS” WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF |
|||
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER THIS LICENSE. |
|||
NVIDIA Corporation and its affiliates ("NVIDIA") grants permission to use machine |
|||
learning models under specific conditions. Key permissions include creating |
|||
derivative models and distributing them, with NVIDIA retaining no ownership claims |
|||
over outputs generated by users. |
|||
|
|||
5. Limitation of Liability. |
|||
SECTION 1: DEFINITIONS |
|||
|
|||
EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION, LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF THE POSSIBILITY OF SUCH DAMAGES. |
|||
1.1 "Derivative Model" means any modification of, or works based on or derived |
|||
from, the Model, excluding outputs. |
|||
|
|||
1.2 "Legal Entity" means the union of the acting entity and all other entities |
|||
that control, are controlled by, or are under common control with that entity. |
|||
|
|||
1.3 "Model" means the machine learning model, software, and any checkpoints, |
|||
weights, algorithms, parameters, configuration files, and documentation that NVIDIA |
|||
makes available under this Agreement. |
|||
|
|||
1.4 "NVIDIA Cosmos Model" means a multimodal Model that is covered by this Agreement. |
|||
|
|||
1.5 "Special-Purpose Model" means a Model that is limited to narrow, |
|||
purpose-specific tasks. |
|||
|
|||
1.6 "You" or "Your" means an individual or Legal Entity exercising permissions |
|||
granted by this Agreement. |
|||
|
|||
SECTION 2: CONDITIONS FOR USE, LICENSE GRANT, AI ETHICS AND IP OWNERSHIP |
|||
|
|||
2.1 Conditions for Use. You must comply with all terms and conditions of this |
|||
Agreement. If You initiate copyright or patent litigation against any entity |
|||
(including a cross-claim or counterclaim in a lawsuit) alleging that the Model |
|||
constitutes direct or contributory infringement, then Your licenses under this |
|||
Agreement shall terminate. If You circumvent any safety guardrails or safety |
|||
measures built in to the Model without providing comparable alternatives, Your |
|||
rights under this Agreement shall terminate. NVIDIA may update this Agreement at |
|||
any time to comply with applicable law; Your continued use constitutes Your |
|||
acceptance of the updated terms. |
|||
|
|||
2.2 License Grant. Subject to the terms and conditions of this Agreement, NVIDIA |
|||
hereby grants You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, |
|||
revocable license to publicly perform, publicly display, reproduce, use, create |
|||
derivative works of, make, have made, sell, offer for sale, distribute and import |
|||
the Model. |
|||
|
|||
2.3 AI Ethics. Your use of the Model must be in accordance with NVIDIA's |
|||
Trustworthy AI terms, which can be found at |
|||
https://www.nvidia.com/en-us/agreements/trustworthy-ai/terms/. |
|||
|
|||
2.4 IP Ownership. NVIDIA owns the original Model and NVIDIA's Derivative Models. |
|||
You own Your Derivative Models. NVIDIA makes no claim of ownership to outputs. You |
|||
are responsible for outputs and their subsequent uses. |
|||
|
|||
SECTION 3: REDISTRIBUTION |
|||
|
|||
You may reproduce and distribute copies of the Model or Derivative Models thereof, |
|||
with or without modifications, provided that You meet the following conditions: |
|||
|
|||
a. You must include a copy of this Agreement. |
|||
|
|||
b. You must include the following attribution notice, which can appear in the same |
|||
location as other third-party notices or license information: "Licensed by NVIDIA |
|||
Corporation under the NVIDIA Open Model License." |
|||
|
|||
c. If You are distributing a NVIDIA Cosmos Model, You must also include the phrase |
|||
"Built on NVIDIA Cosmos" on the applicable website, in the user interface, in a |
|||
blog, in an "about" page, or in product documentation. |
|||
|
|||
d. You may add Your own copyright statement to Your modifications and may provide |
|||
additional or different license terms and conditions for use, reproduction, or |
|||
distribution of Your modifications or for any Derivative Models as a whole, |
|||
provided Your use, reproduction, and distribution otherwise complies with this |
|||
Agreement. |
|||
|
|||
SECTION 4: SEPARATE COMPONENTS |
|||
|
|||
The Model may contain components that are subject to separate legal notices or |
|||
governed by separate licenses (including Open Source Software Licenses), as may be |
|||
described in any files made available with the Model. Your use of those separate |
|||
components is subject to the applicable license. This Agreement shall control over |
|||
the separate licenses for third-party Open Source Software to the extent that the |
|||
separate license imposes additional restrictions. "Open Source Software License" |
|||
means any software license approved by the Open Source Initiative, Free Software |
|||
Foundation, or similar recognized organization, or a license identified by SPDX. |
|||
|
|||
SECTION 5: TRADEMARKS |
|||
|
|||
This Agreement does not grant permission to use the trade names, trademarks, |
|||
service marks, or product names of NVIDIA, except as required for reasonable and |
|||
customary use in describing the origin of the Model and reproducing the content of |
|||
the notice. |
|||
|
|||
SECTION 6: DISCLAIMER OF WARRANTY |
|||
|
|||
NVIDIA provides the Model on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF |
|||
ANY KIND, either express or implied, including, without limitation, any warranties |
|||
or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A |
|||
PARTICULAR PURPOSE. You are solely responsible for reviewing the documentation |
|||
accompanying the Model and determining the appropriateness of using the Model, and |
|||
You understand that Special-Purpose Models are limited to narrow, purpose-specific |
|||
tasks and must not be deployed for uses that are beyond such tasks. |
|||
|
|||
SECTION 7: LIMITATION OF LIABILITY |
|||
|
|||
In no event and under no legal theory, whether in tort (including negligence), |
|||
contract, or otherwise, unless required by applicable law (such as deliberate and |
|||
grossly negligent acts) or agreed to in writing, will NVIDIA be liable to You for |
|||
damages, including any direct, indirect, special, incidental, or consequential |
|||
damages of any character arising as a result of this Agreement or out of the use |
|||
or inability to use the Model or Derivative Models or outputs (including but not |
|||
limited to damages for loss of goodwill, work stoppage, computer failure or |
|||
malfunction, or any and all other commercial damages or losses), even if NVIDIA has |
|||
been advised of the possibility of such damages. |
|||
|
|||
SECTION 8: INDEMNITY |
|||
|
|||
You will defend, indemnify and hold harmless NVIDIA and its affiliates, and their |
|||
respective employees, contractors, directors, officers and agents, from and against |
|||
any and all claims, damages, obligations, losses, liabilities, costs or debt, and |
|||
expenses (including but not limited to attorney's fees) arising from Your use or |
|||
distribution of the Model or Derivative Models or outputs. |
|||
|
|||
SECTION 9: FEEDBACK |
|||
|
|||
NVIDIA may use feedback You provide without restriction and without any |
|||
compensation to You. |
|||
|
|||
SECTION 10: GOVERNING LAW |
|||
|
|||
This Agreement will be governed in all respects by the laws of the United States |
|||
and of the State of Delaware, without regard to conflict of laws provisions. The |
|||
federal and state courts residing in Santa Clara County, California shall have |
|||
exclusive jurisdiction over any dispute arising out of this Agreement, and You |
|||
hereby consent to the personal jurisdiction of such courts. However, NVIDIA shall |
|||
have the right to seek injunctive relief in any court of competent jurisdiction. |
|||
|
|||
SECTION 11: TRADE AND COMPLIANCE |
|||
|
|||
You shall comply with all applicable import, export, trade, and economic sanctions |
|||
laws, including without limitation the Export Administration Regulations and |
|||
economic sanctions laws implemented by the Office of Foreign Assets Control, that |
|||
restrict or govern the destination, end-user and end-use of NVIDIA products, |
|||
technology, software, and services. |
|||
|
|||
--- |
|||
|
|||
Version Release Date: October 24, 2025 |
|||
@ -1,140 +1,246 @@ |
|||
# gr00t_wbc |
|||
<div align="center"> |
|||
|
|||
Software stack for loco-manipulation experiments across multiple humanoid platforms, with primary support for the Unitree G1. This repository provides whole-body control policies, a teleoperation stack, and a data exporter. |
|||
<img src="media/groot_wbc.png" width="800" alt="GEAR SONIC Header"> |
|||
|
|||
<!-- --- --> |
|||
|
|||
|
|||
</div> |
|||
|
|||
<div align="center"> |
|||
|
|||
[](LICENSE) |
|||
[](https://github.com/isaac-sim/IsaacLab/releases/tag/v2.3.0) |
|||
[](https://nvlabs.github.io/GR00T-WholeBodyControl/) |
|||
|
|||
</div> |
|||
|
|||
--- |
|||
|
|||
## System Installation |
|||
|
|||
### Prerequisites |
|||
- Ubuntu 22.04 |
|||
- NVIDIA GPU with a recent driver |
|||
- Docker and NVIDIA Container Toolkit (required for GPU access inside the container) |
|||
|
|||
### Repository Setup |
|||
Install Git and Git LFS: |
|||
```bash |
|||
sudo apt update |
|||
sudo apt install git git-lfs |
|||
git lfs install |
|||
``` |
|||
|
|||
Clone the repository: |
|||
# GR00T-WholeBodyControl |
|||
|
|||
This is the codebase for the **GR00T Whole-Body Control (WBC)** projects. It hosts model checkpoints and scripts for training, evaluating, and deploying advanced whole-body controllers for humanoid robots. We currently support: |
|||
|
|||
- **Decoupled WBC**: the decoupled controller (RL for lower body, and IK for upper body) used in NVIDIA GR00T [N1.5](https://research.nvidia.com/labs/gear/gr00t-n1_5/) and [N1.6](https://research.nvidia.com/labs/gear/gr00t-n1_6/) models; |
|||
- **GEAR-SONIC Series**: our latest iteration of generalist humanoid whole-body controllers (see our [whitepaper](https://nvlabs.github.io/GEAR-SONIC/)). |
|||
|
|||
## Table of Contents |
|||
|
|||
- [GEAR-SONIC](#gear-sonic) |
|||
- [VR Whole-Body Teleoperation](#vr-whole-body-teleoperation) |
|||
- [Kinematic Planner](#kinematic-planner) |
|||
- [TODOs](#todos) |
|||
- [What's Included](#whats-included) |
|||
- [Setup](#setup) |
|||
- [Documentation](#documentation) |
|||
- [Citation](#citation) |
|||
- [License](#license) |
|||
- [Support](#support) |
|||
- [Decoupled WBC](#decoupled-wbc) |
|||
|
|||
|
|||
## GEAR-SONIC |
|||
|
|||
<p style="font-size: 1.2em;"> |
|||
<a href="https://nvlabs.github.io/GEAR-SONIC/"><strong>Website</strong></a> | |
|||
<a href="https://huggingface.co/"><strong>Model</strong></a> | |
|||
<a href="https://arxiv.org/abs/2511.07820"><strong>Paper</strong></a> |
|||
</p> |
|||
|
|||
<div align="center"> |
|||
<img src="docs/source/_static/sonic-preview-gif-480P.gif" width="800" > |
|||
|
|||
</div> |
|||
|
|||
SONIC is a humanoid behavior foundation model that gives robots a core set of motor skills learned from large-scale human motion data. Rather than building separate controllers for predefined motions, SONIC uses motion tracking as a scalable training task, enabling a single unified policy to produce natural, whole-body movement and support a wide range of behaviors — from walking and crawling to teleoperation and multi-modal control. It is designed to generalize beyond the motions it has seen during training and to serve as a foundation for higher-level planning and interaction. |
|||
|
|||
In this repo, we will release SONIC's training code, deployment framework, model checkpoints, and teleoperation stack for data collection. |
|||
|
|||
|
|||
## VR Whole-Body Teleoperation |
|||
|
|||
SONIC supports real-time whole-body teleoperation via PICO VR headset, enabling natural human-to-robot motion transfer for data collection and interactive control. |
|||
|
|||
<div align="center"> |
|||
<table> |
|||
<tr> |
|||
<td align="center"><b>Walking</b></td> |
|||
<td align="center"><b>Running</b></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><img src="media/teleop_walking.gif" width="400"></td> |
|||
<td align="center"><img src="media/teleop_running.gif" width="400"></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><b>Sideways Movement</b></td> |
|||
<td align="center"><b>Kneeling</b></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><img src="media/teleop_sideways.gif" width="400"></td> |
|||
<td align="center"><img src="media/teleop_kneeling.gif" width="400"></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><b>Getting Up</b></td> |
|||
<td align="center"><b>Jumping</b></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><img src="media/teleop_getup.gif" width="400"></td> |
|||
<td align="center"><img src="media/teleop_jumping.gif" width="400"></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><b>Bimanual Manipulation</b></td> |
|||
<td align="center"><b>Object Hand-off</b></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><img src="media/teleop_bimanual.gif" width="400"></td> |
|||
<td align="center"><img src="media/teleop_switch_hands.gif" width="400"></td> |
|||
</tr> |
|||
</table> |
|||
</div> |
|||
|
|||
## Kinematic Planner |
|||
|
|||
SONIC includes a kinematic planner for real-time locomotion generation — choose a movement style, steer with keyboard/gamepad, and adjust speed and height on the fly. |
|||
|
|||
<div align="center"> |
|||
<table> |
|||
<tr> |
|||
<td align="center" colspan="2"><b>In-the-Wild Navigation</b></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center" colspan="2"><img src="media/planner/planner_in_the_wild_navigation.gif" width="800"></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><b>Run</b></td> |
|||
<td align="center"><b>Happy</b></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><img src="media/planner/planner_run.gif" width="400"></td> |
|||
<td align="center"><img src="media/planner/planner_happy.gif" width="400"></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><b>Stealth</b></td> |
|||
<td align="center"><b>Injured</b></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><img src="media/planner/planner_stealth.gif" width="400"></td> |
|||
<td align="center"><img src="media/planner/planner_injured.gif" width="400"></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><b>Kneeling</b></td> |
|||
<td align="center"><b>Hand Crawling</b></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><img src="media/planner/planner_kneeling.gif" width="400"></td> |
|||
<td align="center"><img src="media/planner/planner_hand_crawling.gif" width="400"></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><b>Elbow Crawling</b></td> |
|||
<td align="center"><b>Boxing</b></td> |
|||
</tr> |
|||
<tr> |
|||
<td align="center"><img src="media/planner/planner_elbow_crawling.gif" width="400"></td> |
|||
<td align="center"><img src="media/planner/planner_boxing.gif" width="400"></td> |
|||
</tr> |
|||
</table> |
|||
</div> |
|||
|
|||
## TODOs |
|||
|
|||
- [x] Release pretrained SONIC policy checkpoints |
|||
- [x] Open source C++ inference stack |
|||
- [x] Setup documentation |
|||
- [x] Open source teleoperation stack and demonstration scripts |
|||
- [ ] Release training scripts and recipes for motion imitation and fine-tuning |
|||
- [ ] Open source large-scale data collection workflows and fine-tuning VLA scripts. |
|||
- [ ] Publish additional preprocessed large-scale human motion datasets |
|||
|
|||
|
|||
|
|||
## What's Included |
|||
|
|||
This release includes: |
|||
|
|||
- **`gear_sonic_deploy`**: C++ inference stack for deploying SONIC policies on real hardware |
|||
- **`gear_sonic`**: Teleoperation stack for collecting demonstration data (no training code, YET.) |
|||
|
|||
### Setup |
|||
|
|||
Clone the repository with Git LFS: |
|||
```bash |
|||
mkdir -p ~/Projects |
|||
cd ~/Projects |
|||
git clone https://github.com/NVlabs/gr00t_wbc.git |
|||
cd gr00t_wbc |
|||
git clone https://github.com/NVlabs/GR00T-WholeBodyControl.git |
|||
cd GR00T-WholeBodyControl |
|||
git lfs pull |
|||
``` |
|||
|
|||
### Docker Environment |
|||
We provide a Docker image with all dependencies pre-installed. |
|||
## Documentation |
|||
|
|||
Install a fresh image and start a container: |
|||
```bash |
|||
./docker/run_docker.sh --install --root |
|||
``` |
|||
This pulls the latest `gr00t_wbc` image from `docker.io/nvgear`. |
|||
📚 **[Full Documentation](https://nvlabs.github.io/GR00T-WholeBodyControl/)** |
|||
|
|||
Start or re-enter a container: |
|||
```bash |
|||
./docker/run_docker.sh --root |
|||
``` |
|||
### Getting Started |
|||
- [Installation Guide](https://nvlabs.github.io/GR00T-WholeBodyControl/getting_started/installation_deploy.html) |
|||
- [Quick Start](https://nvlabs.github.io/GR00T-WholeBodyControl/getting_started/quickstart.html) |
|||
- [VR Teleoperation Setup](https://nvlabs.github.io/GR00T-WholeBodyControl/getting_started/vr_teleop_setup.html) |
|||
|
|||
Use `--root` to run as the `root` user. To run as a normal user, build the image locally: |
|||
```bash |
|||
./docker/run_docker.sh --build |
|||
``` |
|||
--- |
|||
### Tutorials |
|||
- [Keyboard Control](https://nvlabs.github.io/GR00T-WholeBodyControl/tutorials/keyboard.html) |
|||
- [Gamepad Control](https://nvlabs.github.io/GR00T-WholeBodyControl/tutorials/gamepad.html) |
|||
- [ZMQ Communication](https://nvlabs.github.io/GR00T-WholeBodyControl/tutorials/zmq.html) |
|||
- [ZMQ Manager / PICO VR](https://nvlabs.github.io/GR00T-WholeBodyControl/tutorials/vr_wholebody_teleop.html) |
|||
|
|||
## Running the Control Stack |
|||
|
|||
Once inside the container, the control policies can be launched directly. |
|||
|
|||
- Simulation: |
|||
```bash |
|||
python gr00t_wbc/control/main/teleop/run_g1_control_loop.py |
|||
``` |
|||
- Real robot: Ensure the host machine network is configured per the [G1 SDK Development Guide](https://support.unitree.com/home/en/G1_developer) and set a static IP at `192.168.123.222`, subnet mask `255.255.255.0`: |
|||
```bash |
|||
python gr00t_wbc/control/main/teleop/run_g1_control_loop.py --interface real |
|||
``` |
|||
|
|||
Keyboard shortcuts (terminal window): |
|||
- `]`: Activate policy |
|||
- `o`: Deactivate policy |
|||
- `9`: Release / Hold the robot |
|||
- `w` / `s`: Move forward / backward |
|||
- `a` / `d`: Strafe left / right |
|||
- `q` / `e`: Rotate left / right |
|||
- `z`: Zero navigation commands |
|||
- `1` / `2`: Raise / lower the base height |
|||
- `backspace` (viewer): Reset the robot in the visualizer |
|||
### Best Practices |
|||
- [Teleoperation](https://nvlabs.github.io/GR00T-WholeBodyControl/user_guide/teleoperation.html) |
|||
|
|||
--- |
|||
|
|||
## Running the Teleoperation Stack |
|||
|
|||
The teleoperation policy primarily uses Pico controllers for coordinated hand and body control. It also supports other teleoperation devices, including LeapMotion and HTC Vive with Nintendo Switch Joy-Con controllers. |
|||
|
|||
Keep `run_g1_control_loop.py` running, and in another terminal run: |
|||
|
|||
```bash |
|||
python gr00t_wbc/control/main/teleop/run_teleop_policy_loop.py --hand_control_device=pico --body_control_device=pico |
|||
|
|||
--- |
|||
|
|||
## Citation |
|||
|
|||
If you use GEAR-SONIC in your research, please cite: |
|||
|
|||
```bibtex |
|||
@article{luo2025sonic, |
|||
title={SONIC: Supersizing Motion Tracking for Natural Humanoid Whole-Body Control}, |
|||
author={Luo, Zhengyi and Yuan, Ye and Wang, Tingwu and Li, Chenran and Chen, Sirui and Casta\~neda, Fernando and Cao, Zi-Ang and Li, Jiefeng and Minor, David and Ben, Qingwei and Da, Xingye and Ding, Runyu and Hogg, Cyrus and Song, Lina and Lim, Edy and Jeong, Eugene and He, Tairan and Xue, Haoru and Xiao, Wenli and Wang, Zi and Yuen, Simon and Kautz, Jan and Chang, Yan and Iqbal, Umar and Fan, Linxi and Zhu, Yuke}, |
|||
journal={arXiv preprint arXiv:2511.07820}, |
|||
year={2025} |
|||
} |
|||
``` |
|||
|
|||
### Pico Setup and Controls |
|||
Configure the teleop app on your Pico headset by following the [XR Robotics guidelines](https://github.com/XR-Robotics). |
|||
--- |
|||
|
|||
The necessary PC software is pre-installed in the Docker container. Only the [XRoboToolkit-PC-Service](https://github.com/XR-Robotics/XRoboToolkit-PC-Service) component is needed. |
|||
## License |
|||
|
|||
Prerequisites: Connect the Pico to the same network as the host computer. |
|||
This project uses dual licensing: |
|||
|
|||
Controller bindings: |
|||
- `menu + left trigger`: Toggle lower-body policy |
|||
- `menu + right trigger`: Toggle upper-body policy |
|||
- `Left stick`: X/Y translation |
|||
- `Right stick`: Yaw rotation |
|||
- `L/R triggers`: Control hand grippers |
|||
- **Source Code**: Licensed under Apache License 2.0 - applies to all code, scripts, and software components in this repository |
|||
- **Model Weights**: Licensed under NVIDIA Open Model License - applies to all trained model checkpoints and weights |
|||
|
|||
Pico unit test: |
|||
```bash |
|||
python gr00t_wbc/control/teleop/streamers/pico_streamer.py |
|||
``` |
|||
See [LICENSE](LICENSE) for the complete dual-license text. |
|||
|
|||
Please review both licenses before using this project. The NVIDIA Open Model License permits commercial use with attribution and requires compliance with NVIDIA's Trustworthy AI terms. |
|||
|
|||
All required legal documents, including the Apache 2.0 license, 3rd-party attributions, and DCO language, are consolidated in the /legal folder of this repository. |
|||
|
|||
--- |
|||
|
|||
## Running the Data Collection Stack |
|||
## Support |
|||
|
|||
Run the full stack (control loop, teleop policy, and camera forwarder) via the deployment helper: |
|||
```bash |
|||
python scripts/deploy_g1.py \ |
|||
--interface sim \ |
|||
--camera_host localhost \ |
|||
--sim_in_single_process \ |
|||
--simulator robocasa \ |
|||
--image-publish \ |
|||
--enable-offscreen \ |
|||
--env_name PnPBottle \ |
|||
--hand_control_device=pico \ |
|||
--body_control_device=pico |
|||
``` |
|||
For questions and issues, please contact the GEAR WBC team at [gear-wbc@nvidia.com](gear-wbc@nvidia.com) to provide feedback! |
|||
|
|||
The `tmux` session `g1_deployment` is created with panes for: |
|||
- `control_data_teleop`: Main control loop, data collection, and teleoperation policy |
|||
- `camera`: Camera forwarder |
|||
- `camera_viewer`: Optional live camera feed |
|||
## Decoupled WBC |
|||
|
|||
Operations in the `controller` window (`control_data_teleop` pane, left): |
|||
- `]`: Activate policy |
|||
- `o`: Deactivate policy |
|||
- `k`: Reset the simulation and policies |
|||
- `` ` ``: Terminate the tmux session |
|||
- `ctrl + d`: Exit the shell in the pane |
|||
For the Decoupled WBC used in GR00T N1.5 and N1.6 models, please refer to the [Decoupled WBC documentation](docs/source/references/decoupled_wbc.md). |
|||
|
|||
Operations in the `data exporter` window (`control_data_teleop` pane, right top): |
|||
- Enter the task prompt |
|||
|
|||
Operations on Pico controllers: |
|||
- `A`: Start/Stop recording |
|||
- `B`: Discard trajectory |
|||
## Acknowledgments |
|||
We would like to acknowledge the following projects from which parts of the code in this repo are derived from: |
|||
- [Beyond Mimic](https://github.com/HybridRobotics/whole_body_tracking) |
|||
- [Isaac Lab](https://github.com/isaac-sim/IsaacLab) |
|||
@ -0,0 +1,60 @@ |
|||
from abc import abstractmethod |
|||
|
|||
from decoupled_wbc.control.base.env import Env |
|||
from decoupled_wbc.control.base.sensor import Sensor |
|||
from decoupled_wbc.control.robot_model.robot_model import RobotModel |
|||
|
|||
|
|||
class Hands: |
|||
"""Container class for left and right hand environments. |
|||
|
|||
Attributes: |
|||
left: Environment for the left hand |
|||
right: Environment for the right hand |
|||
""" |
|||
|
|||
left: Env |
|||
right: Env |
|||
|
|||
|
|||
class HumanoidEnv(Env): |
|||
"""Base class for humanoid robot environments. |
|||
|
|||
This class provides the interface for accessing the robot's body, hands, and sensors. |
|||
""" |
|||
|
|||
def body(self) -> Env: |
|||
"""Get the robot's body environment. |
|||
|
|||
Returns: |
|||
Env: The body environment |
|||
""" |
|||
pass |
|||
|
|||
def hands(self) -> Hands: |
|||
"""Get the robot's hands. |
|||
|
|||
Returns: |
|||
Hands: Container with left and right hand environments |
|||
""" |
|||
pass |
|||
|
|||
def sensors(self) -> dict[str, Sensor]: |
|||
"""Get the sensors of this environment |
|||
|
|||
Returns: |
|||
dict: A dictionary of sensors |
|||
""" |
|||
pass |
|||
|
|||
@abstractmethod |
|||
def robot_model(self) -> RobotModel: |
|||
"""Get the robot model of this environment |
|||
This robot model is used to dispatch whole body actions to body |
|||
and hand actuators and to reconstruct proprioceptive |
|||
observations from body and hands. |
|||
|
|||
Returns: |
|||
RobotModel: The robot model |
|||
""" |
|||
pass |
|||
@ -0,0 +1,67 @@ |
|||
from typing import Any, Dict |
|||
|
|||
import gymnasium as gym |
|||
import numpy as np |
|||
|
|||
from decoupled_wbc.control.base.env import Env |
|||
from decoupled_wbc.control.envs.g1.utils.command_sender import BodyCommandSender |
|||
from decoupled_wbc.control.envs.g1.utils.state_processor import BodyStateProcessor |
|||
|
|||
|
|||
class G1Body(Env): |
|||
def __init__(self, config: Dict[str, Any]): |
|||
super().__init__() |
|||
self.body_state_processor = BodyStateProcessor(config=config) |
|||
self.body_command_sender = BodyCommandSender(config=config) |
|||
|
|||
def observe(self) -> dict[str, any]: |
|||
body_state = self.body_state_processor._prepare_low_state() # (1, 148) |
|||
assert body_state.shape == (1, 148) |
|||
body_q = body_state[ |
|||
0, 7 : 7 + 12 + 3 + 7 + 7 |
|||
] # leg (12) + waist (3) + left arm (7) + right arm (7) |
|||
body_dq = body_state[0, 42 : 42 + 12 + 3 + 7 + 7] |
|||
body_ddq = body_state[0, 112 : 112 + 12 + 3 + 7 + 7] |
|||
body_tau_est = body_state[0, 77 : 77 + 12 + 3 + 7 + 7] |
|||
floating_base_pose = body_state[0, 0:7] |
|||
floating_base_vel = body_state[0, 36:42] |
|||
floating_base_acc = body_state[0, 106:112] |
|||
torso_quat = body_state[0, 141:145] |
|||
torso_ang_vel = body_state[0, 145:148] |
|||
|
|||
return { |
|||
"body_q": body_q, |
|||
"body_dq": body_dq, |
|||
"body_ddq": body_ddq, |
|||
"body_tau_est": body_tau_est, |
|||
"floating_base_pose": floating_base_pose, |
|||
"floating_base_vel": floating_base_vel, |
|||
"floating_base_acc": floating_base_acc, |
|||
"torso_quat": torso_quat, |
|||
"torso_ang_vel": torso_ang_vel, |
|||
} |
|||
|
|||
def queue_action(self, action: dict[str, any]): |
|||
# action should contain body_q, body_dq, body_tau |
|||
self.body_command_sender.send_command( |
|||
action["body_q"], action["body_dq"], action["body_tau"] |
|||
) |
|||
|
|||
def observation_space(self) -> gym.Space: |
|||
return gym.spaces.Dict( |
|||
{ |
|||
"body_q": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(29,)), |
|||
"body_dq": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(29,)), |
|||
"floating_base_pose": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7,)), |
|||
"floating_base_vel": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(6,)), |
|||
} |
|||
) |
|||
|
|||
def action_space(self) -> gym.Space: |
|||
return gym.spaces.Dict( |
|||
{ |
|||
"body_q": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(29,)), |
|||
"body_dq": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(29,)), |
|||
"body_tau": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(29,)), |
|||
} |
|||
) |
|||
@ -0,0 +1,324 @@ |
|||
from copy import deepcopy |
|||
from typing import Dict |
|||
|
|||
import gymnasium as gym |
|||
import numpy as np |
|||
from scipy.spatial.transform import Rotation as R |
|||
|
|||
from decoupled_wbc.control.base.humanoid_env import Hands, HumanoidEnv |
|||
from decoupled_wbc.control.envs.g1.g1_body import G1Body |
|||
from decoupled_wbc.control.envs.g1.g1_hand import G1ThreeFingerHand |
|||
from decoupled_wbc.control.envs.g1.sim.simulator_factory import SimulatorFactory, init_channel |
|||
from decoupled_wbc.control.envs.g1.utils.joint_safety import JointSafetyMonitor |
|||
from decoupled_wbc.control.robot_model.instantiation.g1 import instantiate_g1_robot_model |
|||
from decoupled_wbc.control.robot_model.robot_model import RobotModel |
|||
from decoupled_wbc.control.utils.ros_utils import ROSManager |
|||
|
|||
|
|||
class G1Env(HumanoidEnv): |
|||
def __init__( |
|||
self, |
|||
env_name: str = "default", |
|||
robot_model: RobotModel = None, |
|||
wbc_version: str = "v2", |
|||
config: Dict[str, any] = None, |
|||
**kwargs, |
|||
): |
|||
super().__init__() |
|||
self.robot_model = deepcopy(robot_model) # need to cache FK results |
|||
self.config = config |
|||
|
|||
# Initialize safety monitor (visualization disabled) |
|||
self.safety_monitor = JointSafetyMonitor( |
|||
robot_model, enable_viz=False, env_type=self.config.get("ENV_TYPE", "real") |
|||
) |
|||
self.last_obs = None |
|||
self.last_safety_ok = True # Track last safety status from queue_action |
|||
|
|||
init_channel(config=self.config) |
|||
|
|||
# Initialize body and hands |
|||
self._body = G1Body(config=self.config) |
|||
|
|||
self.with_hands = config.get("with_hands", False) |
|||
|
|||
# Gravity compensation settings |
|||
self.enable_gravity_compensation = config.get("enable_gravity_compensation", False) |
|||
self.gravity_compensation_joints = config.get("gravity_compensation_joints", ["arms"]) |
|||
|
|||
if self.enable_gravity_compensation: |
|||
print( |
|||
f"Gravity compensation enabled for joint groups: {self.gravity_compensation_joints}" |
|||
) |
|||
if self.with_hands: |
|||
self._hands = Hands() |
|||
self._hands.left = G1ThreeFingerHand(is_left=True) |
|||
self._hands.right = G1ThreeFingerHand(is_left=False) |
|||
|
|||
# Initialize simulator if in simulation mode |
|||
self.use_sim = self.config.get("ENV_TYPE") == "sim" |
|||
|
|||
if self.use_sim: |
|||
# Create simulator using factory |
|||
|
|||
kwargs.update( |
|||
{ |
|||
"onscreen": self.config.get("ENABLE_ONSCREEN", True), |
|||
"offscreen": self.config.get("ENABLE_OFFSCREEN", False), |
|||
} |
|||
) |
|||
self.sim = SimulatorFactory.create_simulator( |
|||
config=self.config, |
|||
env_name=env_name, |
|||
wbc_version=wbc_version, |
|||
body_ik_solver_settings_type=kwargs.get("body_ik_solver_settings_type", "default"), |
|||
**kwargs, |
|||
) |
|||
else: |
|||
self.sim = None |
|||
|
|||
# using the real robot |
|||
self.calibrate_hands() |
|||
|
|||
# Initialize ROS 2 node |
|||
self.ros_manager = ROSManager(node_name="g1_env") |
|||
self.ros_node = self.ros_manager.node |
|||
|
|||
self.delay_list = [] |
|||
self.visualize_delay = False |
|||
self.print_delay_interval = 100 |
|||
self.cnt = 0 |
|||
|
|||
def start_simulator(self): |
|||
# imag epublish disabled since the sim is running in a sub-thread |
|||
SimulatorFactory.start_simulator(self.sim, as_thread=True, enable_image_publish=False) |
|||
|
|||
def step_simulator(self): |
|||
sim_num_steps = int(self.config["REWARD_DT"] / self.config["SIMULATE_DT"]) |
|||
for _ in range(sim_num_steps): |
|||
self.sim.sim_env.sim_step() |
|||
self.sim.sim_env.update_viewer() |
|||
|
|||
def body(self) -> G1Body: |
|||
return self._body |
|||
|
|||
def hands(self) -> Hands: |
|||
if not self.with_hands: |
|||
raise RuntimeError( |
|||
"Hands not initialized. Use --with_hands True to enable hand functionality." |
|||
) |
|||
return self._hands |
|||
|
|||
def observe(self) -> Dict[str, any]: |
|||
# Get observations from body and hands |
|||
body_obs = self.body().observe() |
|||
|
|||
body_q = body_obs["body_q"] |
|||
body_dq = body_obs["body_dq"] |
|||
body_ddq = body_obs["body_ddq"] |
|||
body_tau_est = body_obs["body_tau_est"] |
|||
|
|||
if self.with_hands: |
|||
left_hand_obs = self.hands().left.observe() |
|||
right_hand_obs = self.hands().right.observe() |
|||
left_hand_q = left_hand_obs["hand_q"] |
|||
right_hand_q = right_hand_obs["hand_q"] |
|||
left_hand_dq = left_hand_obs["hand_dq"] |
|||
right_hand_dq = right_hand_obs["hand_dq"] |
|||
left_hand_ddq = left_hand_obs["hand_ddq"] |
|||
right_hand_ddq = right_hand_obs["hand_ddq"] |
|||
left_hand_tau_est = left_hand_obs["hand_tau_est"] |
|||
right_hand_tau_est = right_hand_obs["hand_tau_est"] |
|||
|
|||
# Body and hand joint measurements come in actuator order, so we need to convert them to joint order |
|||
whole_q = self.robot_model.get_configuration_from_actuated_joints( |
|||
body_actuated_joint_values=body_q, |
|||
left_hand_actuated_joint_values=left_hand_q, |
|||
right_hand_actuated_joint_values=right_hand_q, |
|||
) |
|||
whole_dq = self.robot_model.get_configuration_from_actuated_joints( |
|||
body_actuated_joint_values=body_dq, |
|||
left_hand_actuated_joint_values=left_hand_dq, |
|||
right_hand_actuated_joint_values=right_hand_dq, |
|||
) |
|||
whole_ddq = self.robot_model.get_configuration_from_actuated_joints( |
|||
body_actuated_joint_values=body_ddq, |
|||
left_hand_actuated_joint_values=left_hand_ddq, |
|||
right_hand_actuated_joint_values=right_hand_ddq, |
|||
) |
|||
whole_tau_est = self.robot_model.get_configuration_from_actuated_joints( |
|||
body_actuated_joint_values=body_tau_est, |
|||
left_hand_actuated_joint_values=left_hand_tau_est, |
|||
right_hand_actuated_joint_values=right_hand_tau_est, |
|||
) |
|||
else: |
|||
# Body and hand joint measurements come in actuator order, so we need to convert them to joint order |
|||
whole_q = self.robot_model.get_configuration_from_actuated_joints( |
|||
body_actuated_joint_values=body_q, |
|||
) |
|||
whole_dq = self.robot_model.get_configuration_from_actuated_joints( |
|||
body_actuated_joint_values=body_dq, |
|||
) |
|||
whole_ddq = self.robot_model.get_configuration_from_actuated_joints( |
|||
body_actuated_joint_values=body_ddq, |
|||
) |
|||
whole_tau_est = self.robot_model.get_configuration_from_actuated_joints( |
|||
body_actuated_joint_values=body_tau_est, |
|||
) |
|||
|
|||
eef_obs = self.get_eef_obs(whole_q) |
|||
|
|||
obs = { |
|||
"q": whole_q, |
|||
"dq": whole_dq, |
|||
"ddq": whole_ddq, |
|||
"tau_est": whole_tau_est, |
|||
"floating_base_pose": body_obs["floating_base_pose"], |
|||
"floating_base_vel": body_obs["floating_base_vel"], |
|||
"floating_base_acc": body_obs["floating_base_acc"], |
|||
"wrist_pose": np.concatenate([eef_obs["left_wrist_pose"], eef_obs["right_wrist_pose"]]), |
|||
"torso_quat": body_obs["torso_quat"], |
|||
"torso_ang_vel": body_obs["torso_ang_vel"], |
|||
} |
|||
|
|||
if self.use_sim and self.sim: |
|||
obs.update(self.sim.get_privileged_obs()) |
|||
|
|||
# Store last observation for safety checking |
|||
self.last_obs = obs |
|||
|
|||
return obs |
|||
|
|||
@property |
|||
def observation_space(self) -> gym.Space: |
|||
# @todo: check if the low and high bounds are correct for body_obs. |
|||
q_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,)) |
|||
dq_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,)) |
|||
ddq_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,)) |
|||
tau_est_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,)) |
|||
floating_base_pose_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7,)) |
|||
floating_base_vel_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(6,)) |
|||
floating_base_acc_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(6,)) |
|||
wrist_pose_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7 + 7,)) |
|||
return gym.spaces.Dict( |
|||
{ |
|||
"floating_base_pose": floating_base_pose_space, |
|||
"floating_base_vel": floating_base_vel_space, |
|||
"floating_base_acc": floating_base_acc_space, |
|||
"q": q_space, |
|||
"dq": dq_space, |
|||
"ddq": ddq_space, |
|||
"tau_est": tau_est_space, |
|||
"wrist_pose": wrist_pose_space, |
|||
} |
|||
) |
|||
|
|||
def queue_action(self, action: Dict[str, any]): |
|||
# Safety check |
|||
if self.last_obs is not None: |
|||
safety_result = self.safety_monitor.handle_violations(self.last_obs, action) |
|||
action = safety_result["action"] |
|||
|
|||
# Map action from joint order to actuator order |
|||
body_actuator_q = self.robot_model.get_body_actuated_joints(action["q"]) |
|||
|
|||
self.body().queue_action( |
|||
{ |
|||
"body_q": body_actuator_q, |
|||
"body_dq": np.zeros_like(body_actuator_q), |
|||
"body_tau": np.zeros_like(body_actuator_q), |
|||
} |
|||
) |
|||
|
|||
if self.with_hands: |
|||
left_hand_actuator_q = self.robot_model.get_hand_actuated_joints( |
|||
action["q"], side="left" |
|||
) |
|||
right_hand_actuator_q = self.robot_model.get_hand_actuated_joints( |
|||
action["q"], side="right" |
|||
) |
|||
self.hands().left.queue_action({"hand_q": left_hand_actuator_q}) |
|||
self.hands().right.queue_action({"hand_q": right_hand_actuator_q}) |
|||
|
|||
def action_space(self) -> gym.Space: |
|||
return gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,)) |
|||
|
|||
def calibrate_hands(self): |
|||
"""Calibrate the hand joint qpos if real robot""" |
|||
if self.with_hands: |
|||
print("calibrating left hand") |
|||
self.hands().left.calibrate_hand() |
|||
print("calibrating right hand") |
|||
self.hands().right.calibrate_hand() |
|||
else: |
|||
print("Skipping hand calibration - hands disabled") |
|||
|
|||
def set_ik_indicator(self, teleop_cmd): |
|||
"""Set the IK indicators for the simulator""" |
|||
if self.config["SIMULATOR"] == "robocasa": |
|||
if "left_wrist" in teleop_cmd and "right_wrist" in teleop_cmd: |
|||
left_wrist_input_pose = teleop_cmd["left_wrist"] |
|||
right_wrist_input_pose = teleop_cmd["right_wrist"] |
|||
ik_wrapper = self.sim.env.env.unwrapped.env |
|||
ik_wrapper.set_target_poses_outside_env( |
|||
[left_wrist_input_pose, right_wrist_input_pose] |
|||
) |
|||
else: |
|||
raise NotImplementedError("IK indicators are only implemented for robocasa simulator") |
|||
|
|||
def set_sync_mode(self, sync_mode: bool, steps_per_action: int = 4): |
|||
"""When set to True, the simulator will wait for the action to be sent to it""" |
|||
if self.config["SIMULATOR"] == "robocasa": |
|||
self.sim.set_sync_mode(sync_mode, steps_per_action) |
|||
|
|||
def reset(self): |
|||
if self.sim: |
|||
self.sim.reset() |
|||
|
|||
def close(self): |
|||
if self.sim: |
|||
self.sim.close() |
|||
|
|||
def robot_model(self) -> RobotModel: |
|||
return self.robot_model |
|||
|
|||
def get_reward(self): |
|||
if self.sim: |
|||
return self.sim.get_reward() |
|||
|
|||
def reset_obj_pos(self): |
|||
if hasattr(self.sim, "base_env") and hasattr(self.sim.base_env, "reset_obj_pos"): |
|||
self.sim.base_env.reset_obj_pos() |
|||
|
|||
def get_eef_obs(self, q: np.ndarray) -> Dict[str, np.ndarray]: |
|||
self.robot_model.cache_forward_kinematics(q) |
|||
eef_obs = {} |
|||
for side in ["left", "right"]: |
|||
wrist_placement = self.robot_model.frame_placement( |
|||
self.robot_model.supplemental_info.hand_frame_names[side] |
|||
) |
|||
wrist_pos, wrist_quat = wrist_placement.translation[:3], R.from_matrix( |
|||
wrist_placement.rotation |
|||
).as_quat(scalar_first=True) |
|||
eef_obs[f"{side}_wrist_pose"] = np.concatenate([wrist_pos, wrist_quat]) |
|||
|
|||
return eef_obs |
|||
|
|||
def get_joint_safety_status(self) -> bool: |
|||
"""Get current joint safety status from the last queue_action safety check. |
|||
|
|||
Returns: |
|||
bool: True if joints are safe (no shutdown required), False if unsafe |
|||
""" |
|||
return self.last_safety_ok |
|||
|
|||
def handle_keyboard_button(self, key): |
|||
# Only handles keyboard buttons for the mujoco simulator for now. |
|||
if self.use_sim and self.config.get("SIMULATOR", "mujoco") == "mujoco": |
|||
self.sim.handle_keyboard_button(key) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
env = G1Env(robot_model=instantiate_g1_robot_model(), wbc_version="gear_wbc") |
|||
while True: |
|||
print(env.observe()) |
|||
@ -0,0 +1,89 @@ |
|||
import time |
|||
|
|||
import gymnasium as gym |
|||
import numpy as np |
|||
|
|||
from decoupled_wbc.control.base.env import Env |
|||
from decoupled_wbc.control.envs.g1.utils.command_sender import HandCommandSender |
|||
from decoupled_wbc.control.envs.g1.utils.state_processor import HandStateProcessor |
|||
|
|||
|
|||
class G1ThreeFingerHand(Env): |
|||
def __init__(self, is_left: bool = True): |
|||
super().__init__() |
|||
self.is_left = is_left |
|||
self.hand_state_processor = HandStateProcessor(is_left=self.is_left) |
|||
self.hand_command_sender = HandCommandSender(is_left=self.is_left) |
|||
self.hand_q_offset = np.zeros(7) |
|||
|
|||
def observe(self) -> dict[str, any]: |
|||
hand_state = self.hand_state_processor._prepare_low_state() # (1, 28) |
|||
assert hand_state.shape == (1, 28) |
|||
|
|||
# Apply offset to the hand state |
|||
hand_state[0, :7] = hand_state[0, :7] + self.hand_q_offset |
|||
|
|||
hand_q = hand_state[0, :7] |
|||
hand_dq = hand_state[0, 7:14] |
|||
hand_ddq = hand_state[0, 21:28] |
|||
hand_tau_est = hand_state[0, 14:21] |
|||
|
|||
# Return the state for this specific hand (left or right) |
|||
return { |
|||
"hand_q": hand_q, |
|||
"hand_dq": hand_dq, |
|||
"hand_ddq": hand_ddq, |
|||
"hand_tau_est": hand_tau_est, |
|||
} |
|||
|
|||
def queue_action(self, action: dict[str, any]): |
|||
# Apply offset to the hand target |
|||
action["hand_q"] = action["hand_q"] - self.hand_q_offset |
|||
|
|||
# action should contain hand_q |
|||
self.hand_command_sender.send_command(action["hand_q"]) |
|||
|
|||
def observation_space(self) -> gym.Space: |
|||
return gym.spaces.Dict( |
|||
{ |
|||
"hand_q": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7,)), |
|||
"hand_dq": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7,)), |
|||
"hand_ddq": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7,)), |
|||
"hand_tau_est": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7,)), |
|||
} |
|||
) |
|||
|
|||
def action_space(self) -> gym.Space: |
|||
return gym.spaces.Dict({"hand_q": gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7,))}) |
|||
|
|||
def calibrate_hand(self): |
|||
hand_obs = self.observe() |
|||
hand_q = hand_obs["hand_q"] |
|||
|
|||
hand_q_target = np.zeros_like(hand_q) |
|||
hand_q_target[0] = hand_q[0] |
|||
|
|||
# joint limit |
|||
hand_q0_upper_limit = np.deg2rad(60) # lower limit is -60 |
|||
|
|||
# move the figure counterclockwise until the limit |
|||
while True: |
|||
|
|||
if hand_q_target[0] - hand_q[0] < np.deg2rad(60): |
|||
hand_q_target[0] += np.deg2rad(10) |
|||
else: |
|||
self.hand_q_offset[0] = hand_q0_upper_limit - hand_q[0] |
|||
break |
|||
|
|||
self.queue_action({"hand_q": hand_q_target}) |
|||
|
|||
hand_obs = self.observe() |
|||
hand_q = hand_obs["hand_q"] |
|||
|
|||
time.sleep(0.1) |
|||
|
|||
print("done calibration, q0 offset (deg):", np.rad2deg(self.hand_q_offset[0])) |
|||
|
|||
# done calibrating, set target to zero |
|||
self.hand_q_target = np.zeros_like(hand_q) |
|||
self.queue_action({"hand_q": self.hand_q_target}) |
|||
@ -0,0 +1,772 @@ |
|||
import argparse |
|||
import pathlib |
|||
from pathlib import Path |
|||
import threading |
|||
from threading import Lock, Thread |
|||
from typing import Dict |
|||
|
|||
import mujoco |
|||
import mujoco.viewer |
|||
import numpy as np |
|||
import rclpy |
|||
from unitree_sdk2py.core.channel import ChannelFactoryInitialize |
|||
import yaml |
|||
|
|||
from decoupled_wbc.control.envs.g1.sim.image_publish_utils import ImagePublishProcess |
|||
from decoupled_wbc.control.envs.g1.sim.metric_utils import check_contact, check_height |
|||
from decoupled_wbc.control.envs.g1.sim.sim_utilts import get_subtree_body_names |
|||
from decoupled_wbc.control.envs.g1.sim.unitree_sdk2py_bridge import ElasticBand, UnitreeSdk2Bridge |
|||
|
|||
DECOUPLED_WBC_ROOT = Path(__file__).resolve().parent.parent.parent.parent.parent.parent |
|||
|
|||
|
|||
class DefaultEnv: |
|||
"""Base environment class that handles simulation environment setup and step""" |
|||
|
|||
def __init__( |
|||
self, |
|||
config: Dict[str, any], |
|||
env_name: str = "default", |
|||
camera_configs: Dict[str, any] = {}, |
|||
onscreen: bool = False, |
|||
offscreen: bool = False, |
|||
enable_image_publish: bool = False, |
|||
): |
|||
# global_view is only set up for this specifc scene for now. |
|||
if config["ROBOT_SCENE"] == "decoupled_wbc/control/robot_model/model_data/g1/scene_29dof.xml": |
|||
camera_configs["global_view"] = { |
|||
"height": 400, |
|||
"width": 400, |
|||
} |
|||
self.config = config |
|||
self.env_name = env_name |
|||
self.num_body_dof = self.config["NUM_JOINTS"] |
|||
self.num_hand_dof = self.config["NUM_HAND_JOINTS"] |
|||
self.sim_dt = self.config["SIMULATE_DT"] |
|||
self.obs = None |
|||
self.torques = np.zeros(self.num_body_dof + self.num_hand_dof * 2) |
|||
self.torque_limit = np.array(self.config["motor_effort_limit_list"]) |
|||
self.camera_configs = camera_configs |
|||
|
|||
# Thread safety lock |
|||
self.reward_lock = Lock() |
|||
|
|||
# Unitree bridge will be initialized by the simulator |
|||
self.unitree_bridge = None |
|||
|
|||
# Store display mode |
|||
self.onscreen = onscreen |
|||
|
|||
# Initialize scene (defined in subclasses) |
|||
self.init_scene() |
|||
self.last_reward = 0 |
|||
|
|||
# Setup offscreen rendering if needed |
|||
self.offscreen = offscreen |
|||
if self.offscreen: |
|||
self.init_renderers() |
|||
self.image_dt = self.config.get("IMAGE_DT", 0.033333) |
|||
self.image_publish_process = None |
|||
|
|||
def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555): |
|||
# Use spawn method for better GIL isolation, or configured method |
|||
if len(self.camera_configs) == 0: |
|||
print( |
|||
"Warning: No camera configs provided, image publishing subprocess will not be started" |
|||
) |
|||
return |
|||
start_method = self.config.get("MP_START_METHOD", "spawn") |
|||
self.image_publish_process = ImagePublishProcess( |
|||
camera_configs=self.camera_configs, |
|||
image_dt=self.image_dt, |
|||
zmq_port=camera_port, |
|||
start_method=start_method, |
|||
verbose=self.config.get("verbose", False), |
|||
) |
|||
self.image_publish_process.start_process() |
|||
|
|||
def init_scene(self): |
|||
"""Initialize the default robot scene""" |
|||
self.mj_model = mujoco.MjModel.from_xml_path( |
|||
str(pathlib.Path(DECOUPLED_WBC_ROOT) / self.config["ROBOT_SCENE"]) |
|||
) |
|||
self.mj_data = mujoco.MjData(self.mj_model) |
|||
self.mj_model.opt.timestep = self.sim_dt |
|||
self.torso_index = mujoco.mj_name2id(self.mj_model, mujoco.mjtObj.mjOBJ_BODY, "torso_link") |
|||
self.root_body = "pelvis" |
|||
# Enable the elastic band |
|||
if self.config["ENABLE_ELASTIC_BAND"]: |
|||
self.elastic_band = ElasticBand() |
|||
if "g1" in self.config["ROBOT_TYPE"]: |
|||
if self.config["enable_waist"]: |
|||
self.band_attached_link = self.mj_model.body("pelvis").id |
|||
else: |
|||
self.band_attached_link = self.mj_model.body("torso_link").id |
|||
elif "h1" in self.config["ROBOT_TYPE"]: |
|||
self.band_attached_link = self.mj_model.body("torso_link").id |
|||
else: |
|||
self.band_attached_link = self.mj_model.body("base_link").id |
|||
|
|||
if self.onscreen: |
|||
self.viewer = mujoco.viewer.launch_passive( |
|||
self.mj_model, |
|||
self.mj_data, |
|||
key_callback=self.elastic_band.MujuocoKeyCallback, |
|||
show_left_ui=False, |
|||
show_right_ui=False, |
|||
) |
|||
else: |
|||
mujoco.mj_forward(self.mj_model, self.mj_data) |
|||
self.viewer = None |
|||
else: |
|||
if self.onscreen: |
|||
self.viewer = mujoco.viewer.launch_passive( |
|||
self.mj_model, self.mj_data, show_left_ui=False, show_right_ui=False |
|||
) |
|||
else: |
|||
mujoco.mj_forward(self.mj_model, self.mj_data) |
|||
self.viewer = None |
|||
|
|||
if self.viewer: |
|||
# viewer camera |
|||
self.viewer.cam.azimuth = 120 # Horizontal rotation in degrees |
|||
self.viewer.cam.elevation = -30 # Vertical tilt in degrees |
|||
self.viewer.cam.distance = 2.0 # Distance from camera to target |
|||
self.viewer.cam.lookat = np.array([0, 0, 0.5]) # Point the camera is looking at |
|||
|
|||
# Note that the actuator order is the same as the joint order in the mujoco model. |
|||
self.body_joint_index = [] |
|||
self.left_hand_index = [] |
|||
self.right_hand_index = [] |
|||
for i in range(self.mj_model.njnt): |
|||
name = self.mj_model.joint(i).name |
|||
if any( |
|||
[ |
|||
part_name in name |
|||
for part_name in ["hip", "knee", "ankle", "waist", "shoulder", "elbow", "wrist"] |
|||
] |
|||
): |
|||
self.body_joint_index.append(i) |
|||
elif "left_hand" in name: |
|||
self.left_hand_index.append(i) |
|||
elif "right_hand" in name: |
|||
self.right_hand_index.append(i) |
|||
|
|||
assert len(self.body_joint_index) == self.config["NUM_JOINTS"] |
|||
assert len(self.left_hand_index) == self.config["NUM_HAND_JOINTS"] |
|||
assert len(self.right_hand_index) == self.config["NUM_HAND_JOINTS"] |
|||
|
|||
self.body_joint_index = np.array(self.body_joint_index) |
|||
self.left_hand_index = np.array(self.left_hand_index) |
|||
self.right_hand_index = np.array(self.right_hand_index) |
|||
|
|||
def init_renderers(self): |
|||
# Initialize camera renderers |
|||
self.renderers = {} |
|||
for camera_name, camera_config in self.camera_configs.items(): |
|||
renderer = mujoco.Renderer( |
|||
self.mj_model, height=camera_config["height"], width=camera_config["width"] |
|||
) |
|||
self.renderers[camera_name] = renderer |
|||
|
|||
def compute_body_torques(self) -> np.ndarray: |
|||
"""Compute body torques based on the current robot state""" |
|||
body_torques = np.zeros(self.num_body_dof) |
|||
if self.unitree_bridge is not None and self.unitree_bridge.low_cmd: |
|||
for i in range(self.unitree_bridge.num_body_motor): |
|||
if self.unitree_bridge.use_sensor: |
|||
body_torques[i] = ( |
|||
self.unitree_bridge.low_cmd.motor_cmd[i].tau |
|||
+ self.unitree_bridge.low_cmd.motor_cmd[i].kp |
|||
* (self.unitree_bridge.low_cmd.motor_cmd[i].q - self.mj_data.sensordata[i]) |
|||
+ self.unitree_bridge.low_cmd.motor_cmd[i].kd |
|||
* ( |
|||
self.unitree_bridge.low_cmd.motor_cmd[i].dq |
|||
- self.mj_data.sensordata[i + self.unitree_bridge.num_body_motor] |
|||
) |
|||
) |
|||
else: |
|||
body_torques[i] = ( |
|||
self.unitree_bridge.low_cmd.motor_cmd[i].tau |
|||
+ self.unitree_bridge.low_cmd.motor_cmd[i].kp |
|||
* ( |
|||
self.unitree_bridge.low_cmd.motor_cmd[i].q |
|||
- self.mj_data.qpos[self.body_joint_index[i] + 7 - 1] |
|||
) |
|||
+ self.unitree_bridge.low_cmd.motor_cmd[i].kd |
|||
* ( |
|||
self.unitree_bridge.low_cmd.motor_cmd[i].dq |
|||
- self.mj_data.qvel[self.body_joint_index[i] + 6 - 1] |
|||
) |
|||
) |
|||
return body_torques |
|||
|
|||
def compute_hand_torques(self) -> np.ndarray: |
|||
"""Compute hand torques based on the current robot state""" |
|||
left_hand_torques = np.zeros(self.num_hand_dof) |
|||
right_hand_torques = np.zeros(self.num_hand_dof) |
|||
if self.unitree_bridge is not None and self.unitree_bridge.low_cmd: |
|||
for i in range(self.unitree_bridge.num_hand_motor): |
|||
left_hand_torques[i] = ( |
|||
self.unitree_bridge.left_hand_cmd.motor_cmd[i].tau |
|||
+ self.unitree_bridge.left_hand_cmd.motor_cmd[i].kp |
|||
* ( |
|||
self.unitree_bridge.left_hand_cmd.motor_cmd[i].q |
|||
- self.mj_data.qpos[self.left_hand_index[i] + 7 - 1] |
|||
) |
|||
+ self.unitree_bridge.left_hand_cmd.motor_cmd[i].kd |
|||
* ( |
|||
self.unitree_bridge.left_hand_cmd.motor_cmd[i].dq |
|||
- self.mj_data.qvel[self.left_hand_index[i] + 6 - 1] |
|||
) |
|||
) |
|||
right_hand_torques[i] = ( |
|||
self.unitree_bridge.right_hand_cmd.motor_cmd[i].tau |
|||
+ self.unitree_bridge.right_hand_cmd.motor_cmd[i].kp |
|||
* ( |
|||
self.unitree_bridge.right_hand_cmd.motor_cmd[i].q |
|||
- self.mj_data.qpos[self.right_hand_index[i] + 7 - 1] |
|||
) |
|||
+ self.unitree_bridge.right_hand_cmd.motor_cmd[i].kd |
|||
* ( |
|||
self.unitree_bridge.right_hand_cmd.motor_cmd[i].dq |
|||
- self.mj_data.qvel[self.right_hand_index[i] + 6 - 1] |
|||
) |
|||
) |
|||
return np.concatenate((left_hand_torques, right_hand_torques)) |
|||
|
|||
def compute_body_qpos(self) -> np.ndarray: |
|||
"""Compute body joint positions based on the current command""" |
|||
body_qpos = np.zeros(self.num_body_dof) |
|||
if self.unitree_bridge is not None and self.unitree_bridge.low_cmd: |
|||
for i in range(self.unitree_bridge.num_body_motor): |
|||
body_qpos[i] = self.unitree_bridge.low_cmd.motor_cmd[i].q |
|||
return body_qpos |
|||
|
|||
def compute_hand_qpos(self) -> np.ndarray: |
|||
"""Compute hand joint positions based on the current command""" |
|||
hand_qpos = np.zeros(self.num_hand_dof * 2) |
|||
if self.unitree_bridge is not None and self.unitree_bridge.low_cmd: |
|||
for i in range(self.unitree_bridge.num_hand_motor): |
|||
hand_qpos[i] = self.unitree_bridge.left_hand_cmd.motor_cmd[i].q |
|||
hand_qpos[i + self.num_hand_dof] = self.unitree_bridge.right_hand_cmd.motor_cmd[i].q |
|||
return hand_qpos |
|||
|
|||
def prepare_obs(self) -> Dict[str, any]: |
|||
"""Prepare observation dictionary from the current robot state""" |
|||
obs = {} |
|||
obs["floating_base_pose"] = self.mj_data.qpos[:7] |
|||
obs["floating_base_vel"] = self.mj_data.qvel[:6] |
|||
obs["floating_base_acc"] = self.mj_data.qacc[:6] |
|||
obs["secondary_imu_quat"] = self.mj_data.xquat[self.torso_index] |
|||
obs["secondary_imu_vel"] = self.mj_data.cvel[self.torso_index] |
|||
obs["body_q"] = self.mj_data.qpos[self.body_joint_index + 7 - 1] |
|||
obs["body_dq"] = self.mj_data.qvel[self.body_joint_index + 6 - 1] |
|||
obs["body_ddq"] = self.mj_data.qacc[self.body_joint_index + 6 - 1] |
|||
obs["body_tau_est"] = self.mj_data.actuator_force[self.body_joint_index - 1] |
|||
if self.num_hand_dof > 0: |
|||
obs["left_hand_q"] = self.mj_data.qpos[self.left_hand_index + 7 - 1] |
|||
obs["left_hand_dq"] = self.mj_data.qvel[self.left_hand_index + 6 - 1] |
|||
obs["left_hand_ddq"] = self.mj_data.qacc[self.left_hand_index + 6 - 1] |
|||
obs["left_hand_tau_est"] = self.mj_data.actuator_force[self.left_hand_index - 1] |
|||
obs["right_hand_q"] = self.mj_data.qpos[self.right_hand_index + 7 - 1] |
|||
obs["right_hand_dq"] = self.mj_data.qvel[self.right_hand_index + 6 - 1] |
|||
obs["right_hand_ddq"] = self.mj_data.qacc[self.right_hand_index + 6 - 1] |
|||
obs["right_hand_tau_est"] = self.mj_data.actuator_force[self.right_hand_index - 1] |
|||
obs["time"] = self.mj_data.time |
|||
return obs |
|||
|
|||
def sim_step(self): |
|||
self.obs = self.prepare_obs() |
|||
self.unitree_bridge.PublishLowState(self.obs) |
|||
if self.unitree_bridge.joystick: |
|||
self.unitree_bridge.PublishWirelessController() |
|||
if self.config["ENABLE_ELASTIC_BAND"]: |
|||
if self.elastic_band.enable: |
|||
# Get Cartesian pose and velocity of the band_attached_link |
|||
pose = np.concatenate( |
|||
[ |
|||
self.mj_data.xpos[self.band_attached_link], # link position in world |
|||
self.mj_data.xquat[ |
|||
self.band_attached_link |
|||
], # link quaternion in world [w,x,y,z] |
|||
np.zeros(6), # placeholder for velocity |
|||
] |
|||
) |
|||
|
|||
# Get velocity in world frame |
|||
mujoco.mj_objectVelocity( |
|||
self.mj_model, |
|||
self.mj_data, |
|||
mujoco.mjtObj.mjOBJ_BODY, |
|||
self.band_attached_link, |
|||
pose[7:13], |
|||
0, # 0 for world frame |
|||
) |
|||
|
|||
# Reorder velocity from [ang, lin] to [lin, ang] |
|||
pose[7:10], pose[10:13] = pose[10:13], pose[7:10].copy() |
|||
self.mj_data.xfrc_applied[self.band_attached_link] = self.elastic_band.Advance(pose) |
|||
else: |
|||
# explicitly resetting the force when the band is not enabled |
|||
self.mj_data.xfrc_applied[self.band_attached_link] = np.zeros(6) |
|||
body_torques = self.compute_body_torques() |
|||
hand_torques = self.compute_hand_torques() |
|||
self.torques[self.body_joint_index - 1] = body_torques |
|||
if self.num_hand_dof > 0: |
|||
self.torques[self.left_hand_index - 1] = hand_torques[: self.num_hand_dof] |
|||
self.torques[self.right_hand_index - 1] = hand_torques[self.num_hand_dof :] |
|||
|
|||
self.torques = np.clip(self.torques, -self.torque_limit, self.torque_limit) |
|||
|
|||
if self.config["FREE_BASE"]: |
|||
self.mj_data.ctrl = np.concatenate((np.zeros(6), self.torques)) |
|||
else: |
|||
self.mj_data.ctrl = self.torques |
|||
mujoco.mj_step(self.mj_model, self.mj_data) |
|||
# self.check_self_collision() |
|||
|
|||
def kinematics_step(self): |
|||
""" |
|||
Run kinematics only: compute the qpos of the robot and directly set the qpos. |
|||
For debugging purposes. |
|||
""" |
|||
if self.unitree_bridge is not None: |
|||
self.unitree_bridge.PublishLowState(self.prepare_obs()) |
|||
if self.unitree_bridge.joystick: |
|||
self.unitree_bridge.PublishWirelessController() |
|||
|
|||
if self.config["ENABLE_ELASTIC_BAND"]: |
|||
if self.elastic_band.enable: |
|||
# Get Cartesian pose and velocity of the band_attached_link |
|||
pose = np.concatenate( |
|||
[ |
|||
self.mj_data.xpos[self.band_attached_link], # link position in world |
|||
self.mj_data.xquat[ |
|||
self.band_attached_link |
|||
], # link quaternion in world [w,x,y,z] |
|||
np.zeros(6), # placeholder for velocity |
|||
] |
|||
) |
|||
|
|||
# Get velocity in world frame |
|||
mujoco.mj_objectVelocity( |
|||
self.mj_model, |
|||
self.mj_data, |
|||
mujoco.mjtObj.mjOBJ_BODY, |
|||
self.band_attached_link, |
|||
pose[7:13], |
|||
0, # 0 for world frame |
|||
) |
|||
|
|||
# Reorder velocity from [ang, lin] to [lin, ang] |
|||
pose[7:10], pose[10:13] = pose[10:13], pose[7:10].copy() |
|||
|
|||
self.mj_data.xfrc_applied[self.band_attached_link] = self.elastic_band.Advance(pose) |
|||
else: |
|||
# explicitly resetting the force when the band is not enabled |
|||
self.mj_data.xfrc_applied[self.band_attached_link] = np.zeros(6) |
|||
|
|||
body_qpos = self.compute_body_qpos() # (num_body_dof,) |
|||
hand_qpos = self.compute_hand_qpos() # (num_hand_dof * 2,) |
|||
|
|||
self.mj_data.qpos[self.body_joint_index + 7 - 1] = body_qpos |
|||
self.mj_data.qpos[self.left_hand_index + 7 - 1] = hand_qpos[: self.num_hand_dof] |
|||
self.mj_data.qpos[self.right_hand_index + 7 - 1] = hand_qpos[self.num_hand_dof :] |
|||
|
|||
mujoco.mj_kinematics(self.mj_model, self.mj_data) |
|||
mujoco.mj_comPos(self.mj_model, self.mj_data) |
|||
|
|||
def apply_perturbation(self, key): |
|||
"""Apply perturbation to the robot""" |
|||
# Add velocity perturbations in body frame |
|||
perturbation_x_body = 0.0 # forward/backward in body frame |
|||
perturbation_y_body = 0.0 # left/right in body frame |
|||
if key == "up": |
|||
perturbation_x_body = 1.0 # forward |
|||
elif key == "down": |
|||
perturbation_x_body = -1.0 # backward |
|||
elif key == "left": |
|||
perturbation_y_body = 1.0 # left |
|||
elif key == "right": |
|||
perturbation_y_body = -1.0 # right |
|||
|
|||
# Transform body frame velocity to world frame using MuJoCo's rotation |
|||
vel_body = np.array([perturbation_x_body, perturbation_y_body, 0.0]) |
|||
vel_world = np.zeros(3) |
|||
base_quat = self.mj_data.qpos[3:7] # [w, x, y, z] quaternion |
|||
|
|||
# Use MuJoCo's robust quaternion rotation (handles invalid quaternions automatically) |
|||
mujoco.mju_rotVecQuat(vel_world, vel_body, base_quat) |
|||
|
|||
# Apply to base linear velocity in world frame |
|||
self.mj_data.qvel[0] += vel_world[0] # world X velocity |
|||
self.mj_data.qvel[1] += vel_world[1] # world Y velocity |
|||
|
|||
# Update dynamics after velocity change |
|||
mujoco.mj_forward(self.mj_model, self.mj_data) |
|||
|
|||
def update_viewer(self): |
|||
if self.viewer is not None: |
|||
self.viewer.sync() |
|||
|
|||
def update_viewer_camera(self): |
|||
if self.viewer is not None: |
|||
if self.viewer.cam.type == mujoco.mjtCamera.mjCAMERA_TRACKING: |
|||
self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FREE |
|||
else: |
|||
self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_TRACKING |
|||
|
|||
def update_reward(self): |
|||
"""Calculate reward. Should be implemented by subclasses.""" |
|||
with self.reward_lock: |
|||
self.last_reward = 0 |
|||
|
|||
def get_reward(self): |
|||
"""Thread-safe way to get the last calculated reward.""" |
|||
with self.reward_lock: |
|||
return self.last_reward |
|||
|
|||
def set_unitree_bridge(self, unitree_bridge): |
|||
"""Set the unitree bridge from the simulator""" |
|||
self.unitree_bridge = unitree_bridge |
|||
|
|||
def get_privileged_obs(self): |
|||
"""Get privileged observation. Should be implemented by subclasses.""" |
|||
return {} |
|||
|
|||
def update_render_caches(self): |
|||
"""Update render cache and shared memory for subprocess.""" |
|||
render_caches = {} |
|||
for camera_name, camera_config in self.camera_configs.items(): |
|||
renderer = self.renderers[camera_name] |
|||
if "params" in camera_config: |
|||
renderer.update_scene(self.mj_data, camera=camera_config["params"]) |
|||
else: |
|||
renderer.update_scene(self.mj_data, camera=camera_name) |
|||
render_caches[camera_name + "_image"] = renderer.render() |
|||
|
|||
# Update shared memory if image publishing process is available |
|||
if self.image_publish_process is not None: |
|||
self.image_publish_process.update_shared_memory(render_caches) |
|||
|
|||
return render_caches |
|||
|
|||
def handle_keyboard_button(self, key): |
|||
if self.elastic_band is not None: |
|||
self.elastic_band.handle_keyboard_button(key) |
|||
|
|||
if key == "backspace": |
|||
self.reset() |
|||
if key == "v": |
|||
self.update_viewer_camera() |
|||
if key in ["up", "down", "left", "right"]: |
|||
self.apply_perturbation(key) |
|||
|
|||
def check_fall(self): |
|||
"""Check if the robot has fallen""" |
|||
self.fall = False |
|||
if self.mj_data.qpos[2] < 0.2: |
|||
self.fall = True |
|||
print(f"Warning: Robot has fallen, height: {self.mj_data.qpos[2]:.3f} m") |
|||
|
|||
if self.fall: |
|||
self.reset() |
|||
|
|||
def check_self_collision(self): |
|||
"""Check for self-collision of the robot""" |
|||
robot_bodies = get_subtree_body_names(self.mj_model, self.mj_model.body(self.root_body).id) |
|||
self_collision, contact_bodies = check_contact( |
|||
self.mj_model, self.mj_data, robot_bodies, robot_bodies, return_all_contact_bodies=True |
|||
) |
|||
if self_collision: |
|||
print(f"Warning: Self-collision detected: {contact_bodies}") |
|||
return self_collision |
|||
|
|||
def reset(self): |
|||
mujoco.mj_resetData(self.mj_model, self.mj_data) |
|||
|
|||
|
|||
class CubeEnv(DefaultEnv): |
|||
"""Environment with a cube object for pick and place tasks""" |
|||
|
|||
def __init__( |
|||
self, |
|||
config: Dict[str, any], |
|||
onscreen: bool = False, |
|||
offscreen: bool = False, |
|||
enable_image_publish: bool = False, |
|||
): |
|||
# Override the robot scene |
|||
config = config.copy() # Create a copy to avoid modifying the original |
|||
config["ROBOT_SCENE"] = "decoupled_wbc/control/robot_model/model_data/g1/pnp_cube_43dof.xml" |
|||
super().__init__(config, "cube", {}, onscreen, offscreen, enable_image_publish) |
|||
|
|||
def update_reward(self): |
|||
"""Calculate reward based on gripper contact with cube and cube height""" |
|||
right_hand_body = [ |
|||
"right_hand_thumb_2_link", |
|||
"right_hand_middle_1_link", |
|||
"right_hand_index_1_link", |
|||
] |
|||
gripper_cube_contact = check_contact( |
|||
self.mj_model, self.mj_data, right_hand_body, "cube_body" |
|||
) |
|||
cube_lifted = check_height(self.mj_model, self.mj_data, "cube", 0.85, 2.0) |
|||
|
|||
with self.reward_lock: |
|||
self.last_reward = gripper_cube_contact & cube_lifted |
|||
|
|||
|
|||
class BoxEnv(DefaultEnv): |
|||
"""Environment with a box object for manipulation tasks""" |
|||
|
|||
def __init__( |
|||
self, |
|||
config: Dict[str, any], |
|||
onscreen: bool = False, |
|||
offscreen: bool = False, |
|||
enable_image_publish: bool = False, |
|||
): |
|||
# Override the robot scene |
|||
config = config.copy() # Create a copy to avoid modifying the original |
|||
config["ROBOT_SCENE"] = "decoupled_wbc/control/robot_model/model_data/g1/lift_box_43dof.xml" |
|||
super().__init__(config, "box", {}, onscreen, offscreen, enable_image_publish) |
|||
|
|||
def reward(self): |
|||
"""Calculate reward based on gripper contact with cube and cube height""" |
|||
left_hand_body = [ |
|||
"left_hand_thumb_2_link", |
|||
"left_hand_middle_1_link", |
|||
"left_hand_index_1_link", |
|||
] |
|||
right_hand_body = [ |
|||
"right_hand_thumb_2_link", |
|||
"right_hand_middle_1_link", |
|||
"right_hand_index_1_link", |
|||
] |
|||
gripper_box_contact = check_contact(self.mj_model, self.mj_data, left_hand_body, "box_body") |
|||
gripper_box_contact &= check_contact( |
|||
self.mj_model, self.mj_data, right_hand_body, "box_body" |
|||
) |
|||
box_lifted = check_height(self.mj_model, self.mj_data, "box", 0.92, 2.0) |
|||
|
|||
print("gripper_box_contact: ", gripper_box_contact, "box_lifted: ", box_lifted) |
|||
|
|||
with self.reward_lock: |
|||
self.last_reward = gripper_box_contact & box_lifted |
|||
return self.last_reward |
|||
|
|||
|
|||
class BottleEnv(DefaultEnv): |
|||
"""Environment with a cylinder object for manipulation tasks""" |
|||
|
|||
def __init__( |
|||
self, |
|||
config: Dict[str, any], |
|||
onscreen: bool = False, |
|||
offscreen: bool = False, |
|||
enable_image_publish: bool = False, |
|||
): |
|||
# Override the robot scene |
|||
config = config.copy() # Create a copy to avoid modifying the original |
|||
config["ROBOT_SCENE"] = "decoupled_wbc/control/robot_model/model_data/g1/pnp_bottle_43dof.xml" |
|||
camera_configs = { |
|||
"egoview": { |
|||
"height": 400, |
|||
"width": 400, |
|||
}, |
|||
} |
|||
super().__init__( |
|||
config, "cylinder", camera_configs, onscreen, offscreen, enable_image_publish |
|||
) |
|||
|
|||
self.bottle_body = self.mj_model.body("bottle_body") |
|||
self.bottle_geom = self.mj_model.geom("bottle") |
|||
|
|||
if self.viewer is not None: |
|||
self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED |
|||
self.viewer.cam.fixedcamid = self.mj_model.camera("egoview").id |
|||
|
|||
def update_reward(self): |
|||
"""Calculate reward based on gripper contact with cylinder and cylinder height""" |
|||
pass |
|||
|
|||
def get_privileged_obs(self): |
|||
obs_pos = self.mj_data.xpos[self.bottle_body.id] |
|||
obs_quat = self.mj_data.xquat[self.bottle_body.id] |
|||
return {"bottle_pos": obs_pos, "bottle_quat": obs_quat} |
|||
|
|||
|
|||
class BaseSimulator: |
|||
"""Base simulator class that handles initialization and running of simulations""" |
|||
|
|||
def __init__(self, config: Dict[str, any], env_name: str = "default", **kwargs): |
|||
self.config = config |
|||
self.env_name = env_name |
|||
|
|||
# Initialize ROS 2 node |
|||
if not rclpy.ok(): |
|||
rclpy.init() |
|||
self.node = rclpy.create_node("sim_mujoco") |
|||
self.thread = threading.Thread(target=rclpy.spin, args=(self.node,), daemon=True) |
|||
self.thread.start() |
|||
else: |
|||
self.thread = None |
|||
executor = rclpy.get_global_executor() |
|||
self.node = executor.get_nodes()[0] # will only take the first node |
|||
|
|||
# Create rate objects for different update frequencies |
|||
self.sim_dt = self.config["SIMULATE_DT"] |
|||
self.reward_dt = self.config.get("REWARD_DT", 0.02) |
|||
self.image_dt = self.config.get("IMAGE_DT", 0.033333) |
|||
self.viewer_dt = self.config.get("VIEWER_DT", 0.02) |
|||
self.rate = self.node.create_rate(1 / self.sim_dt) |
|||
|
|||
# Create the appropriate environment based on name |
|||
if env_name == "default": |
|||
self.sim_env = DefaultEnv(config, env_name, **kwargs) |
|||
elif env_name == "pnp_cube": |
|||
self.sim_env = CubeEnv(config, **kwargs) |
|||
elif env_name == "lift_box": |
|||
self.sim_env = BoxEnv(config, **kwargs) |
|||
elif env_name == "pnp_bottle": |
|||
self.sim_env = BottleEnv(config, **kwargs) |
|||
else: |
|||
raise ValueError(f"Invalid environment name: {env_name}") |
|||
|
|||
# Initialize the DDS communication layer - should be safe to call multiple times |
|||
|
|||
try: |
|||
if self.config.get("INTERFACE", None): |
|||
ChannelFactoryInitialize(self.config["DOMAIN_ID"], self.config["INTERFACE"]) |
|||
else: |
|||
ChannelFactoryInitialize(self.config["DOMAIN_ID"]) |
|||
except Exception as e: |
|||
# If it fails because it's already initialized, that's okay |
|||
print(f"Note: Channel factory initialization attempt: {e}") |
|||
|
|||
# Initialize the unitree bridge and pass it to the environment |
|||
self.init_unitree_bridge() |
|||
self.sim_env.set_unitree_bridge(self.unitree_bridge) |
|||
|
|||
# Initialize additional components |
|||
self.init_subscriber() |
|||
self.init_publisher() |
|||
|
|||
self.sim_thread = None |
|||
|
|||
def start_as_thread(self): |
|||
# Create simulation thread |
|||
self.sim_thread = Thread(target=self.start) |
|||
self.sim_thread.start() |
|||
|
|||
def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555): |
|||
"""Start the image publish subprocess""" |
|||
self.sim_env.start_image_publish_subprocess(start_method, camera_port) |
|||
|
|||
def init_subscriber(self): |
|||
"""Initialize subscribers. Can be overridden by subclasses.""" |
|||
pass |
|||
|
|||
def init_publisher(self): |
|||
"""Initialize publishers. Can be overridden by subclasses.""" |
|||
pass |
|||
|
|||
def init_unitree_bridge(self): |
|||
"""Initialize the unitree SDK bridge""" |
|||
self.unitree_bridge = UnitreeSdk2Bridge(self.config) |
|||
if self.config["USE_JOYSTICK"]: |
|||
self.unitree_bridge.SetupJoystick( |
|||
device_id=self.config["JOYSTICK_DEVICE"], js_type=self.config["JOYSTICK_TYPE"] |
|||
) |
|||
|
|||
def start(self): |
|||
"""Main simulation loop""" |
|||
sim_cnt = 0 |
|||
|
|||
try: |
|||
while ( |
|||
self.sim_env.viewer and self.sim_env.viewer.is_running() |
|||
) or self.sim_env.viewer is None: |
|||
# Run simulation step |
|||
self.sim_env.sim_step() |
|||
|
|||
# Update viewer at viewer rate |
|||
if sim_cnt % int(self.viewer_dt / self.sim_dt) == 0: |
|||
self.sim_env.update_viewer() |
|||
|
|||
# Calculate reward at reward rate |
|||
if sim_cnt % int(self.reward_dt / self.sim_dt) == 0: |
|||
self.sim_env.update_reward() |
|||
|
|||
# Update render caches at image rate |
|||
if sim_cnt % int(self.image_dt / self.sim_dt) == 0: |
|||
self.sim_env.update_render_caches() |
|||
|
|||
# Sleep to maintain correct rate |
|||
self.rate.sleep() |
|||
|
|||
sim_cnt += 1 |
|||
except rclpy.exceptions.ROSInterruptException: |
|||
# This is expected when ROS shuts down - exit cleanly |
|||
pass |
|||
except Exception: |
|||
self.close() |
|||
|
|||
def __del__(self): |
|||
"""Clean up resources when simulator is deleted""" |
|||
self.close() |
|||
|
|||
def reset(self): |
|||
"""Reset the simulation. Can be overridden by subclasses.""" |
|||
self.sim_env.reset() |
|||
|
|||
def close(self): |
|||
"""Close the simulation. Can be overridden by subclasses.""" |
|||
try: |
|||
# Stop image publishing subprocess |
|||
if self.sim_env.image_publish_process is not None: |
|||
self.sim_env.image_publish_process.stop() |
|||
|
|||
# Close viewer |
|||
if hasattr(self.sim_env, "viewer") and self.sim_env.viewer is not None: |
|||
self.sim_env.viewer.close() |
|||
|
|||
# Shutdown ROS |
|||
if rclpy.ok(): |
|||
rclpy.shutdown() |
|||
except Exception as e: |
|||
print(f"Warning during close: {e}") |
|||
|
|||
def get_privileged_obs(self): |
|||
obs = self.sim_env.get_privileged_obs() |
|||
# TODO: add ros2 topic to get privileged obs |
|||
return obs |
|||
|
|||
def handle_keyboard_button(self, key): |
|||
# Only handles keyboard buttons for default env. |
|||
if self.env_name == "default": |
|||
self.sim_env.handle_keyboard_button(key) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
parser = argparse.ArgumentParser(description="Robot") |
|||
parser.add_argument( |
|||
"--config", |
|||
type=str, |
|||
default="./decoupled_wbc/control/main/teleop/configs/g1_29dof_gear_wbc.yaml", |
|||
help="config file", |
|||
) |
|||
args = parser.parse_args() |
|||
|
|||
with open(args.config, "r") as file: |
|||
config = yaml.load(file, Loader=yaml.FullLoader) |
|||
|
|||
if config.get("INTERFACE", None): |
|||
ChannelFactoryInitialize(config["DOMAIN_ID"], config["INTERFACE"]) |
|||
else: |
|||
ChannelFactoryInitialize(config["DOMAIN_ID"]) |
|||
|
|||
simulation = BaseSimulator(config) |
|||
simulation.start_as_thread() |
|||
@ -0,0 +1,256 @@ |
|||
import multiprocessing as mp |
|||
from multiprocessing import shared_memory |
|||
import time |
|||
from typing import Any, Dict |
|||
|
|||
import numpy as np |
|||
|
|||
from decoupled_wbc.control.sensor.sensor_server import ImageMessageSchema, SensorServer |
|||
|
|||
|
|||
def get_multiprocessing_info(verbose: bool = True): |
|||
"""Get information about multiprocessing start methods""" |
|||
|
|||
if verbose: |
|||
print(f"Available start methods: {mp.get_all_start_methods()}") |
|||
return mp.get_start_method() |
|||
|
|||
|
|||
class ImagePublishProcess: |
|||
"""Subprocess for publishing images using shared memory and ZMQ""" |
|||
|
|||
def __init__( |
|||
self, |
|||
camera_configs: Dict[str, Any], |
|||
image_dt: float, |
|||
zmq_port: int = 5555, |
|||
start_method: str = "spawn", |
|||
verbose: bool = False, |
|||
): |
|||
self.camera_configs = camera_configs |
|||
self.image_dt = image_dt |
|||
self.zmq_port = zmq_port |
|||
self.verbose = verbose |
|||
self.shared_memory_blocks = {} |
|||
self.shared_memory_info = {} |
|||
self.process = None |
|||
|
|||
# Use specific context to avoid global state pollution |
|||
self.mp_context = mp.get_context(start_method) |
|||
if self.verbose: |
|||
print(f"Using multiprocessing context: {start_method}") |
|||
|
|||
self.stop_event = self.mp_context.Event() |
|||
self.data_ready_event = self.mp_context.Event() |
|||
|
|||
# Ensure events start in correct state |
|||
self.stop_event.clear() |
|||
self.data_ready_event.clear() |
|||
|
|||
if self.verbose: |
|||
print(f"Initial stop_event state: {self.stop_event.is_set()}") |
|||
print(f"Initial data_ready_event state: {self.data_ready_event.is_set()}") |
|||
|
|||
# Calculate shared memory requirements for each camera |
|||
for camera_name, camera_config in camera_configs.items(): |
|||
height = camera_config["height"] |
|||
width = camera_config["width"] |
|||
# RGB image: height * width * 3 (uint8) |
|||
size = height * width * 3 |
|||
|
|||
# Create shared memory block |
|||
shm = shared_memory.SharedMemory(create=True, size=size) |
|||
self.shared_memory_blocks[camera_name] = shm |
|||
self.shared_memory_info[camera_name] = { |
|||
"name": shm.name, |
|||
"size": size, |
|||
"shape": (height, width, 3), |
|||
"dtype": np.uint8, |
|||
} |
|||
|
|||
def start_process(self): |
|||
"""Start the image publishing subprocess""" |
|||
if self.verbose: |
|||
print(f"Starting subprocess with stop_event state: {self.stop_event.is_set()}") |
|||
self.process = self.mp_context.Process( |
|||
target=self._image_publish_worker, |
|||
args=( |
|||
self.shared_memory_info, |
|||
self.image_dt, |
|||
self.zmq_port, |
|||
self.stop_event, |
|||
self.data_ready_event, |
|||
self.verbose, |
|||
), |
|||
) |
|||
self.process.start() |
|||
if self.verbose: |
|||
print(f"Subprocess started, PID: {self.process.pid}") |
|||
|
|||
def update_shared_memory(self, render_caches: Dict[str, np.ndarray]): |
|||
"""Update shared memory with new rendered images""" |
|||
images_updated = 0 |
|||
for camera_name in self.camera_configs.keys(): |
|||
image_key = f"{camera_name}_image" |
|||
if image_key in render_caches: |
|||
image = render_caches[image_key] |
|||
|
|||
# Ensure image is uint8 and has correct shape |
|||
if image.dtype != np.uint8: |
|||
image = (image * 255).astype(np.uint8) |
|||
|
|||
# Get shared memory array |
|||
shm = self.shared_memory_blocks[camera_name] |
|||
shared_array = np.ndarray( |
|||
self.shared_memory_info[camera_name]["shape"], |
|||
dtype=self.shared_memory_info[camera_name]["dtype"], |
|||
buffer=shm.buf, |
|||
) |
|||
|
|||
# Copy image data to shared memory atomically |
|||
np.copyto(shared_array, image) |
|||
images_updated += 1 |
|||
|
|||
# Signal that new data is ready only after all images are written |
|||
if images_updated > 0: |
|||
if self.verbose: |
|||
print(f"Main process: Updated {images_updated} images, setting data_ready_event") |
|||
self.data_ready_event.set() |
|||
elif self.verbose: |
|||
print( |
|||
"Main process: No images to update. " |
|||
"please check if camera configs are provided and the renderer is properly initialized" |
|||
) |
|||
|
|||
def stop(self): |
|||
"""Stop the image publishing subprocess""" |
|||
if self.verbose: |
|||
print("Stopping image publishing subprocess...") |
|||
self.stop_event.set() |
|||
|
|||
if self.process and self.process.is_alive(): |
|||
# Give the process time to clean up gracefully |
|||
self.process.join(timeout=5) |
|||
if self.process.is_alive(): |
|||
if self.verbose: |
|||
print("Subprocess didn't stop gracefully, terminating...") |
|||
self.process.terminate() |
|||
self.process.join(timeout=2) |
|||
if self.process.is_alive(): |
|||
if self.verbose: |
|||
print("Force killing subprocess...") |
|||
self.process.kill() |
|||
self.process.join() |
|||
|
|||
# Clean up shared memory |
|||
for camera_name, shm in self.shared_memory_blocks.items(): |
|||
try: |
|||
shm.close() |
|||
shm.unlink() |
|||
if self.verbose: |
|||
print(f"Cleaned up shared memory for {camera_name}") |
|||
except Exception as e: |
|||
if self.verbose: |
|||
print(f"Warning: Failed to cleanup shared memory for {camera_name}: {e}") |
|||
|
|||
self.shared_memory_blocks.clear() |
|||
if self.verbose: |
|||
print("Image publishing subprocess stopped and cleaned up") |
|||
|
|||
@staticmethod |
|||
def _image_publish_worker( |
|||
shared_memory_info, image_dt, zmq_port, stop_event, data_ready_event, verbose |
|||
): |
|||
"""Worker function that runs in the subprocess""" |
|||
if verbose: |
|||
print(f"Worker started! PID: {__import__('os').getpid()}") |
|||
print(f"Worker stop_event state at start: {stop_event.is_set()}") |
|||
print(f"Worker data_ready_event state at start: {data_ready_event.is_set()}") |
|||
|
|||
try: |
|||
# Initialize ZMQ sensor server |
|||
sensor_server = SensorServer() |
|||
sensor_server.start_server(port=zmq_port) |
|||
|
|||
# Connect to shared memory blocks |
|||
shared_arrays = {} |
|||
shm_blocks = {} |
|||
for camera_name, info in shared_memory_info.items(): |
|||
shm = shared_memory.SharedMemory(name=info["name"]) |
|||
shm_blocks[camera_name] = shm |
|||
shared_arrays[camera_name] = np.ndarray( |
|||
info["shape"], dtype=info["dtype"], buffer=shm.buf |
|||
) |
|||
|
|||
print( |
|||
f"Image publishing subprocess started with {len(shared_arrays)} cameras on ZMQ port {zmq_port}" |
|||
) |
|||
|
|||
loop_count = 0 |
|||
last_data_time = time.time() |
|||
|
|||
while not stop_event.is_set(): |
|||
loop_count += 1 |
|||
|
|||
# Wait for new data with shorter timeout for better responsiveness |
|||
timeout = min(image_dt, 0.1) # Max 100ms timeout |
|||
data_available = data_ready_event.wait(timeout=timeout) |
|||
|
|||
current_time = time.time() |
|||
|
|||
if data_available: |
|||
data_ready_event.clear() |
|||
if loop_count % 50 == 0: |
|||
print("Image publish frequency: ", 1 / (current_time - last_data_time)) |
|||
last_data_time = current_time |
|||
|
|||
# Collect all camera images and serialize them |
|||
try: |
|||
from decoupled_wbc.control.sensor.sensor_server import ImageUtils |
|||
|
|||
# Copy all images atomically at once |
|||
image_copies = {name: arr.copy() for name, arr in shared_arrays.items()} |
|||
|
|||
# Create message with all camera images |
|||
message_dict = { |
|||
"images": image_copies, |
|||
"timestamps": {name: current_time for name in image_copies.keys()}, |
|||
} |
|||
|
|||
# Create ImageMessageSchema and serialize |
|||
image_msg = ImageMessageSchema( |
|||
timestamps=message_dict.get("timestamps"), |
|||
images=message_dict.get("images", None), |
|||
) |
|||
|
|||
# Serialize and send via ZMQ |
|||
serialized_data = image_msg.serialize() |
|||
|
|||
# Add individual camera images to the message |
|||
for camera_name, image_copy in image_copies.items(): |
|||
serialized_data[f"{camera_name}"] = ImageUtils.encode_image(image_copy) |
|||
|
|||
sensor_server.send_message(serialized_data) |
|||
|
|||
except Exception as e: |
|||
print(f"Error publishing images: {e}") |
|||
|
|||
elif verbose and loop_count % 10 == 0: |
|||
print(f"Subprocess: Still waiting for data... (iteration {loop_count})") |
|||
|
|||
# Small sleep to prevent busy waiting when no data |
|||
if not data_available: |
|||
time.sleep(0.001) |
|||
|
|||
except KeyboardInterrupt: |
|||
print("Image publisher interrupted by user") |
|||
finally: |
|||
# Clean up |
|||
try: |
|||
for shm in shm_blocks.values(): |
|||
shm.close() |
|||
sensor_server.stop_server() |
|||
except Exception as e: |
|||
print(f"Error during subprocess cleanup: {e}") |
|||
if verbose: |
|||
print("Image publish subprocess stopped") |
|||
@ -0,0 +1,71 @@ |
|||
from typing import List, Tuple |
|||
|
|||
import mujoco |
|||
|
|||
from decoupled_wbc.control.envs.g1.sim.sim_utilts import get_body_geom_ids |
|||
|
|||
|
|||
def check_contact( |
|||
mj_model: mujoco.MjModel, |
|||
mj_data: mujoco.MjData, |
|||
bodies_1: List[str] | str, |
|||
bodies_2: List[str] | str, |
|||
return_all_contact_bodies: bool = False, |
|||
) -> Tuple[bool, List[Tuple[str, str]]] | bool: |
|||
""" |
|||
Finds contact between two body groups. Any geom in the body is considered to be in contact. |
|||
Args: |
|||
mj_model (MujocoModel): Current simulation object |
|||
mj_data (MjData): Current simulation data |
|||
bodies_1 (str or list of int): an individual body name or list of body names. |
|||
bodies_2 (str or list of int): another individual body name or list of body names. |
|||
Returns: |
|||
bool: True if any body in @bodies_1 is in contact with any body in @bodies_2. |
|||
""" |
|||
if isinstance(bodies_1, str): |
|||
bodies_1 = [bodies_1] |
|||
if isinstance(bodies_2, str): |
|||
bodies_2 = [bodies_2] |
|||
|
|||
geoms_1 = [get_body_geom_ids(mj_model, mj_model.body(g).id) for g in bodies_1] |
|||
geoms_1 = [g for geom_list in geoms_1 for g in geom_list] |
|||
geoms_2 = [get_body_geom_ids(mj_model, mj_model.body(g).id) for g in bodies_2] |
|||
geoms_2 = [g for geom_list in geoms_2 for g in geom_list] |
|||
contact_bodies = [] |
|||
for i in range(mj_data.ncon): |
|||
contact = mj_data.contact[i] |
|||
# check contact geom in geoms |
|||
c1_in_g1 = contact.geom1 in geoms_1 |
|||
c2_in_g2 = contact.geom2 in geoms_2 if geoms_2 is not None else True |
|||
# check contact geom in geoms (flipped) |
|||
c2_in_g1 = contact.geom2 in geoms_1 |
|||
c1_in_g2 = contact.geom1 in geoms_2 if geoms_2 is not None else True |
|||
if (c1_in_g1 and c2_in_g2) or (c1_in_g2 and c2_in_g1): |
|||
contact_bodies.append( |
|||
( |
|||
mj_model.body(mj_model.geom(contact.geom1).bodyid).name, |
|||
mj_model.body(mj_model.geom(contact.geom2).bodyid).name, |
|||
) |
|||
) |
|||
if not return_all_contact_bodies: |
|||
break |
|||
if return_all_contact_bodies: |
|||
return len(contact_bodies) > 0, set(contact_bodies) |
|||
else: |
|||
return len(contact_bodies) > 0 |
|||
|
|||
|
|||
def check_height( |
|||
mj_model: mujoco.MjModel, |
|||
mj_data: mujoco.MjData, |
|||
geom_name: str, |
|||
lower_bound: float = -float("inf"), |
|||
upper_bound: float = float("inf"), |
|||
): |
|||
""" |
|||
Checks if the height of a geom is greater than a given height. |
|||
""" |
|||
geom_id = mj_model.geom(geom_name).id |
|||
return ( |
|||
mj_data.geom_xpos[geom_id][2] < upper_bound and mj_data.geom_xpos[geom_id][2] > lower_bound |
|||
) |
|||
@ -0,0 +1,63 @@ |
|||
from typing import Any, Dict, Tuple |
|||
|
|||
from unitree_sdk2py.core.channel import ChannelFactoryInitialize |
|||
|
|||
from decoupled_wbc.control.envs.g1.sim.unitree_sdk2py_bridge import UnitreeSdk2Bridge |
|||
from decoupled_wbc.control.envs.robocasa.async_env_server import RoboCasaEnvServer |
|||
from decoupled_wbc.control.robot_model.instantiation import get_robot_type_and_model |
|||
|
|||
|
|||
class RoboCasaG1EnvServer(RoboCasaEnvServer): |
|||
def __init__( |
|||
self, env_name: str, wbc_config: Dict[str, Any], env_kwargs: Dict[str, Any], **kwargs |
|||
): |
|||
if UnitreeSdk2Bridge is None: |
|||
raise ImportError("UnitreeSdk2Bridge is required for RoboCasaG1EnvServer") |
|||
self.wbc_config = wbc_config |
|||
_, robot_model = get_robot_type_and_model( |
|||
"G1", |
|||
enable_waist_ik=wbc_config["enable_waist"], |
|||
) |
|||
if env_kwargs.get("camera_names", None) is None: |
|||
env_kwargs["camera_names"] = [ |
|||
"robot0_oak_egoview", |
|||
"robot0_oak_left_monoview", |
|||
"robot0_oak_right_monoview", |
|||
"robot0_rs_tppview", |
|||
] |
|||
if env_kwargs.get("render_camera", None) is None: |
|||
if env_kwargs.get("renderer", "mjviewer") == "mjviewer": |
|||
env_kwargs["render_camera"] = "robot0_oak_egoview" |
|||
else: |
|||
env_kwargs["render_camera"] = [ |
|||
"robot0_oak_egoview", |
|||
"robot0_rs_tppview", |
|||
] |
|||
|
|||
super().__init__(env_name, "G1", robot_model, env_kwargs=env_kwargs, **kwargs) |
|||
|
|||
def init_channel(self): |
|||
|
|||
try: |
|||
if self.wbc_config.get("INTERFACE", None): |
|||
ChannelFactoryInitialize(self.wbc_config["DOMAIN_ID"], self.wbc_config["INTERFACE"]) |
|||
else: |
|||
ChannelFactoryInitialize(self.wbc_config["DOMAIN_ID"]) |
|||
except Exception: |
|||
# If it fails because it's already initialized, that's okay |
|||
pass |
|||
|
|||
self.channel_bridge = UnitreeSdk2Bridge(config=self.wbc_config) |
|||
|
|||
def publish_obs(self): |
|||
# with self.cache_lock: |
|||
obs = self.caches["obs"] |
|||
self.channel_bridge.PublishLowState(obs) |
|||
|
|||
def get_action(self) -> Tuple[Dict[str, Any], bool, bool]: |
|||
q, ready, is_new_action = self.channel_bridge.GetAction() |
|||
return {"q": q}, ready, is_new_action |
|||
|
|||
def reset(self): |
|||
super().reset() |
|||
self.channel_bridge.reset() |
|||
@ -0,0 +1,144 @@ |
|||
import time |
|||
from typing import Any, Dict |
|||
|
|||
from unitree_sdk2py.core.channel import ChannelFactoryInitialize |
|||
|
|||
from decoupled_wbc.control.envs.g1.sim.base_sim import BaseSimulator |
|||
|
|||
|
|||
def init_channel(config: Dict[str, Any]) -> None: |
|||
""" |
|||
Initialize the communication channel for simulator/robot communication. |
|||
|
|||
Args: |
|||
config: Configuration dictionary containing DOMAIN_ID and optionally INTERFACE |
|||
""" |
|||
if config.get("INTERFACE", None): |
|||
ChannelFactoryInitialize(config["DOMAIN_ID"], config["INTERFACE"]) |
|||
else: |
|||
ChannelFactoryInitialize(config["DOMAIN_ID"]) |
|||
|
|||
|
|||
class SimulatorFactory: |
|||
"""Factory class for creating different types of simulators.""" |
|||
|
|||
@staticmethod |
|||
def create_simulator(config: Dict[str, Any], env_name: str = "default", **kwargs): |
|||
""" |
|||
Create a simulator based on the configuration. |
|||
|
|||
Args: |
|||
config: Configuration dictionary containing SIMULATOR type |
|||
env_name: Environment name |
|||
**kwargs: Additional keyword arguments for specific simulators |
|||
""" |
|||
simulator_type = config.get("SIMULATOR", "mujoco") |
|||
if simulator_type == "mujoco": |
|||
return SimulatorFactory._create_mujoco_simulator(config, env_name, **kwargs) |
|||
elif simulator_type == "robocasa": |
|||
return SimulatorFactory._create_robocasa_simulator(config, env_name, **kwargs) |
|||
else: |
|||
print( |
|||
f"Warning: Invalid simulator type: {simulator_type}. " |
|||
"If you are using run_sim_loop, please ignore this warning." |
|||
) |
|||
return None |
|||
|
|||
@staticmethod |
|||
def _create_mujoco_simulator(config: Dict[str, Any], env_name: str = "default", **kwargs): |
|||
"""Create a MuJoCo simulator instance.""" |
|||
env_kwargs = dict( |
|||
onscreen=kwargs.pop("onscreen", True), |
|||
offscreen=kwargs.pop("offscreen", False), |
|||
enable_image_publish=kwargs.get("enable_image_publish", False), |
|||
) |
|||
return BaseSimulator(config=config, env_name=env_name, **env_kwargs) |
|||
|
|||
@staticmethod |
|||
def _create_robocasa_simulator(config: Dict[str, Any], env_name: str = "default", **kwargs): |
|||
"""Create a RoboCasa simulator instance.""" |
|||
from decoupled_wbc.control.envs.g1.sim.robocasa_sim import RoboCasaG1EnvServer |
|||
from decoupled_wbc.control.envs.robocasa.utils.controller_utils import ( |
|||
update_robosuite_controller_configs, |
|||
) |
|||
from decoupled_wbc.control.envs.robocasa.utils.sim_utils import change_simulation_timestep |
|||
|
|||
change_simulation_timestep(config["SIMULATE_DT"]) |
|||
|
|||
# Use default environment if not specified |
|||
if env_name == "default": |
|||
env_name = "GroundOnly" |
|||
|
|||
# Get or create controller configurations |
|||
controller_configs = kwargs.get("controller_configs") |
|||
if controller_configs is None: |
|||
wbc_version = kwargs.get("wbc_version", "gear_wbc") |
|||
controller_configs = update_robosuite_controller_configs("G1", wbc_version) |
|||
|
|||
# Build environment kwargs |
|||
env_kwargs = dict( |
|||
onscreen=kwargs.pop("onscreen", True), |
|||
offscreen=kwargs.pop("offscreen", False), |
|||
camera_names=kwargs.pop("camera_names", None), |
|||
camera_heights=kwargs.pop("camera_heights", None), |
|||
camera_widths=kwargs.pop("camera_widths", None), |
|||
control_freq=kwargs.pop("control_freq", 50), |
|||
controller_configs=controller_configs, |
|||
ik_indicator=kwargs.pop("ik_indicator", False), |
|||
randomize_cameras=kwargs.pop("randomize_cameras", True), |
|||
) |
|||
|
|||
kwargs.update( |
|||
{ |
|||
"verbose": config.pop("verbose", False), |
|||
"sim_freq": 1 / config.pop("SIMULATE_DT"), |
|||
} |
|||
) |
|||
|
|||
return RoboCasaG1EnvServer( |
|||
env_name=env_name, |
|||
wbc_config=config, |
|||
env_kwargs=env_kwargs, |
|||
**kwargs, |
|||
) |
|||
|
|||
@staticmethod |
|||
def start_simulator( |
|||
simulator, |
|||
as_thread: bool = True, |
|||
enable_image_publish: bool = False, |
|||
mp_start_method: str = "spawn", |
|||
camera_port: int = 5555, |
|||
): |
|||
""" |
|||
Start the simulator either as a thread or as a separate process. |
|||
|
|||
Args: |
|||
simulator: The simulator instance to start |
|||
config: Configuration dictionary |
|||
as_thread: If True, start as thread; if False, start as subprocess |
|||
enable_offscreen: If True and not as_thread, start image publishing |
|||
""" |
|||
|
|||
if as_thread: |
|||
simulator.start_as_thread() |
|||
else: |
|||
# Wrap in try-except to make sure simulator is properly closed upon exit. |
|||
try: |
|||
if enable_image_publish: |
|||
simulator.start_image_publish_subprocess( |
|||
start_method=mp_start_method, |
|||
camera_port=camera_port, |
|||
) |
|||
time.sleep(1) |
|||
simulator.start() |
|||
except KeyboardInterrupt: |
|||
print("+++++Simulator interrupted by user.") |
|||
except Exception as e: |
|||
print(f"++++error in simulator: {e} ++++") |
|||
finally: |
|||
print("++++closing simulator ++++") |
|||
simulator.close() |
|||
|
|||
# Allow simulator to initialize |
|||
time.sleep(1) |
|||
@ -0,0 +1,534 @@ |
|||
"""Joint safety monitor for G1 robot. |
|||
|
|||
This module implements safety monitoring for arm and finger joint velocities using |
|||
joint groups defined in the robot model's supplemental info. Leg joints are not monitored. |
|||
""" |
|||
|
|||
from datetime import datetime |
|||
import sys |
|||
import time |
|||
from typing import Dict, List, Optional, Tuple |
|||
|
|||
import numpy as np |
|||
|
|||
from decoupled_wbc.data.viz.rerun_viz import RerunViz |
|||
|
|||
|
|||
class JointSafetyMonitor: |
|||
"""Monitor joint velocities for G1 robot arms and hands.""" |
|||
|
|||
# Velocity limits in rad/s |
|||
ARM_VELOCITY_LIMIT = 6.0 # rad/s for arm joints |
|||
HAND_VELOCITY_LIMIT = 50.0 # rad/s for finger joints |
|||
|
|||
def __init__(self, robot_model, enable_viz: bool = False, env_type: str = "real"): |
|||
"""Initialize joint safety monitor. |
|||
|
|||
Args: |
|||
robot_model: The robot model containing joint information |
|||
enable_viz: If True, enable rerun visualization (default False) |
|||
env_type: Environment type - "sim" or "real" (default "real") |
|||
""" |
|||
self.robot_model = robot_model |
|||
self.safety_margin = 1.0 # Hardcoded safety margin |
|||
self.enable_viz = enable_viz |
|||
self.env_type = env_type |
|||
|
|||
# Startup ramping parameters |
|||
self.control_frequency = 50 # Hz, hardcoded from run_g1_control_loop.py |
|||
self.ramp_duration_steps = int(2.0 * self.control_frequency) # 2 seconds * 50Hz = 100 steps |
|||
self.startup_counter = 0 |
|||
self.initial_positions = None |
|||
self.startup_complete = False |
|||
|
|||
# Initialize velocity and position limits for monitored joints |
|||
self.velocity_limits = {} |
|||
self.position_limits = {} |
|||
self._initialize_limits() |
|||
|
|||
# Track violations for reporting |
|||
self.violations = [] |
|||
|
|||
# Initialize visualization |
|||
self.right_arm_indices = None |
|||
self.right_arm_joint_names = [] |
|||
self.left_arm_indices = None |
|||
self.left_arm_joint_names = [] |
|||
self.right_hand_indices = None |
|||
self.right_hand_joint_names = [] |
|||
self.left_hand_indices = None |
|||
self.left_hand_joint_names = [] |
|||
try: |
|||
arm_indices = self.robot_model.get_joint_group_indices("arms") |
|||
all_joint_names = [self.robot_model.joint_names[i] for i in arm_indices] |
|||
# Filter for right and left arm joints |
|||
self.right_arm_joint_names = [ |
|||
name for name in all_joint_names if name.startswith("right_") |
|||
] |
|||
self.right_arm_indices = [ |
|||
self.robot_model.joint_to_dof_index[name] for name in self.right_arm_joint_names |
|||
] |
|||
self.left_arm_joint_names = [ |
|||
name for name in all_joint_names if name.startswith("left_") |
|||
] |
|||
self.left_arm_indices = [ |
|||
self.robot_model.joint_to_dof_index[name] for name in self.left_arm_joint_names |
|||
] |
|||
# Hand joints |
|||
hand_indices = self.robot_model.get_joint_group_indices("hands") |
|||
all_hand_names = [self.robot_model.joint_names[i] for i in hand_indices] |
|||
self.right_hand_joint_names = [ |
|||
name for name in all_hand_names if name.startswith("right_") |
|||
] |
|||
self.right_hand_indices = [ |
|||
self.robot_model.joint_to_dof_index[name] for name in self.right_hand_joint_names |
|||
] |
|||
self.left_hand_joint_names = [ |
|||
name for name in all_hand_names if name.startswith("left_") |
|||
] |
|||
self.left_hand_indices = [ |
|||
self.robot_model.joint_to_dof_index[name] for name in self.left_hand_joint_names |
|||
] |
|||
except ValueError as e: |
|||
print(f"[JointSafetyMonitor] Warning: Could not initialize arm/hand visualization: {e}") |
|||
except Exception: |
|||
pass |
|||
|
|||
# Use single tensor_key for each plot |
|||
self.right_arm_pos_key = "right_arm_qpos" |
|||
self.left_arm_pos_key = "left_arm_qpos" |
|||
self.right_arm_vel_key = "right_arm_dq" |
|||
self.left_arm_vel_key = "left_arm_dq" |
|||
self.right_hand_pos_key = "right_hand_qpos" |
|||
self.left_hand_pos_key = "left_hand_qpos" |
|||
self.right_hand_vel_key = "right_hand_dq" |
|||
self.left_hand_vel_key = "left_hand_dq" |
|||
|
|||
# Define a consistent color palette for up to 8 joints (tab10 + extra) |
|||
self.joint_colors = [ |
|||
[31, 119, 180], # blue |
|||
[255, 127, 14], # orange |
|||
[44, 160, 44], # green |
|||
[214, 39, 40], # red |
|||
[148, 103, 189], # purple |
|||
[140, 86, 75], # brown |
|||
[227, 119, 194], # pink |
|||
[127, 127, 127], # gray (for 8th joint if needed) |
|||
] |
|||
|
|||
# Initialize Rerun visualization only if enabled |
|||
self.viz = None |
|||
if self.enable_viz: |
|||
try: |
|||
self.viz = RerunViz( |
|||
image_keys=[], |
|||
tensor_keys=[ |
|||
self.right_arm_pos_key, |
|||
self.left_arm_pos_key, |
|||
self.right_arm_vel_key, |
|||
self.left_arm_vel_key, |
|||
self.right_hand_pos_key, |
|||
self.left_hand_pos_key, |
|||
self.right_hand_vel_key, |
|||
self.left_hand_vel_key, |
|||
], |
|||
window_size=10.0, |
|||
app_name="joint_safety_monitor", |
|||
) |
|||
except Exception: |
|||
self.viz = None |
|||
|
|||
def _initialize_limits(self): |
|||
"""Initialize velocity and position limits for arm and hand joints using robot model joint groups.""" |
|||
if self.robot_model.supplemental_info is None: |
|||
raise ValueError("Robot model must have supplemental_info to use joint groups") |
|||
|
|||
# Get arm joint indices from robot model joint groups |
|||
try: |
|||
arm_indices = self.robot_model.get_joint_group_indices("arms") |
|||
arm_joint_names = [self.robot_model.joint_names[i] for i in arm_indices] |
|||
|
|||
for joint_name in arm_joint_names: |
|||
# Set velocity limits |
|||
vel_limit = self.ARM_VELOCITY_LIMIT * self.safety_margin |
|||
self.velocity_limits[joint_name] = {"min": -vel_limit, "max": vel_limit} |
|||
|
|||
# Set position limits from robot model |
|||
if joint_name in self.robot_model.joint_to_dof_index: |
|||
joint_idx = self.robot_model.joint_to_dof_index[joint_name] |
|||
# Adjust index for floating base if present |
|||
limit_idx = joint_idx - (7 if self.robot_model.is_floating_base_model else 0) |
|||
|
|||
if 0 <= limit_idx < len(self.robot_model.lower_joint_limits): |
|||
pos_min = self.robot_model.lower_joint_limits[limit_idx] |
|||
pos_max = self.robot_model.upper_joint_limits[limit_idx] |
|||
|
|||
# Apply safety margin to position limits |
|||
pos_range = pos_max - pos_min |
|||
margin = pos_range * (1.0 - self.safety_margin) / 2.0 |
|||
|
|||
self.position_limits[joint_name] = { |
|||
"min": pos_min + margin, |
|||
"max": pos_max - margin, |
|||
} |
|||
except ValueError as e: |
|||
print(f"[JointSafetyMonitor] Warning: Could not find 'arms' joint group: {e}") |
|||
|
|||
# Get hand joint indices from robot model joint groups |
|||
try: |
|||
hand_indices = self.robot_model.get_joint_group_indices("hands") |
|||
hand_joint_names = [self.robot_model.joint_names[i] for i in hand_indices] |
|||
|
|||
for joint_name in hand_joint_names: |
|||
# Set velocity limits only for hands (no position limits for now) |
|||
vel_limit = self.HAND_VELOCITY_LIMIT * self.safety_margin |
|||
self.velocity_limits[joint_name] = {"min": -vel_limit, "max": vel_limit} |
|||
except ValueError as e: |
|||
print(f"[JointSafetyMonitor] Warning: Could not find 'hands' joint group: {e}") |
|||
|
|||
def check_safety(self, obs: Dict, action: Dict) -> Tuple[bool, List[Dict]]: |
|||
"""Check if current velocities and positions are within safe bounds. |
|||
|
|||
Args: |
|||
obs: Observation dictionary containing joint positions and velocities |
|||
action: Action dictionary containing target positions |
|||
|
|||
Returns: |
|||
(is_safe, violations): Tuple of safety status and list of violations |
|||
Note: is_safe=False only for velocity violations (triggers shutdown) |
|||
Position violations are warnings only (don't affect is_safe) |
|||
""" |
|||
self.violations = [] |
|||
is_safe = True |
|||
joint_names = self.robot_model.joint_names |
|||
|
|||
# Check current joint velocities (critical - triggers shutdown) |
|||
if "dq" in obs: |
|||
joint_velocities = obs["dq"] |
|||
|
|||
for i, joint_name in enumerate(joint_names): |
|||
# Only check monitored joints |
|||
if joint_name not in self.velocity_limits: |
|||
continue |
|||
|
|||
if i < len(joint_velocities): |
|||
velocity = joint_velocities[i] |
|||
limits = self.velocity_limits[joint_name] |
|||
|
|||
if velocity < limits["min"] or velocity > limits["max"]: |
|||
violation = { |
|||
"joint": joint_name, |
|||
"type": "velocity", |
|||
"value": velocity, |
|||
"limit_min": limits["min"], |
|||
"limit_max": limits["max"], |
|||
"exceeded_by": self._calculate_exceeded_percentage( |
|||
velocity, limits["min"], limits["max"] |
|||
), |
|||
"critical": True, # Velocity violations are critical |
|||
} |
|||
self.violations.append(violation) |
|||
is_safe = False |
|||
|
|||
# Check current joint positions (warning only - no shutdown) |
|||
if "q" in obs: |
|||
joint_positions = obs["q"] |
|||
|
|||
for i, joint_name in enumerate(joint_names): |
|||
# Only check joints with position limits (arms) |
|||
if joint_name not in self.position_limits: |
|||
continue |
|||
|
|||
if i < len(joint_positions): |
|||
position = joint_positions[i] |
|||
limits = self.position_limits[joint_name] |
|||
|
|||
if position < limits["min"] or position > limits["max"]: |
|||
violation = { |
|||
"joint": joint_name, |
|||
"type": "position", |
|||
"value": position, |
|||
"limit_min": limits["min"], |
|||
"limit_max": limits["max"], |
|||
"exceeded_by": self._calculate_exceeded_percentage( |
|||
position, limits["min"], limits["max"] |
|||
), |
|||
"critical": False, # Position violations are warnings only |
|||
} |
|||
self.violations.append(violation) |
|||
# Don't set is_safe = False for position violations |
|||
|
|||
return is_safe, self.violations |
|||
|
|||
def _calculate_exceeded_percentage( |
|||
self, value: float, limit_min: float, limit_max: float |
|||
) -> float: |
|||
"""Calculate by how much percentage a value exceeds the limits.""" |
|||
if value < limit_min: |
|||
return abs((value - limit_min) / limit_min) * 100 |
|||
elif value > limit_max: |
|||
return abs((value - limit_max) / limit_max) * 100 |
|||
return 0.0 |
|||
|
|||
def get_safe_action(self, obs: Dict, original_action: Dict) -> Dict: |
|||
"""Generate a safe action with startup ramping for smooth initialization. |
|||
|
|||
Args: |
|||
obs: Observation dictionary containing current joint positions |
|||
original_action: The original action that may cause violations |
|||
|
|||
Returns: |
|||
Safe action with startup ramping applied if within ramp duration |
|||
""" |
|||
safe_action = original_action.copy() |
|||
|
|||
# Handle startup ramping for arm joints |
|||
if not self.startup_complete: |
|||
if self.initial_positions is None and "q" in obs: |
|||
# Store initial positions from first observation |
|||
self.initial_positions = obs["q"].copy() |
|||
|
|||
if ( |
|||
self.startup_counter < self.ramp_duration_steps |
|||
and self.initial_positions is not None |
|||
and "q" in safe_action |
|||
): |
|||
# Ramp factor: 0.0 at start → 1.0 at end |
|||
ramp_factor = self.startup_counter / self.ramp_duration_steps |
|||
|
|||
# Apply ramping only to monitored arm joints |
|||
for joint_name in self.velocity_limits: # Only monitored arm joints |
|||
if joint_name in self.robot_model.joint_to_dof_index: |
|||
joint_idx = self.robot_model.joint_to_dof_index[joint_name] |
|||
if joint_idx < len(safe_action["q"]) and joint_idx < len( |
|||
self.initial_positions |
|||
): |
|||
initial_pos = self.initial_positions[joint_idx] |
|||
target_pos = original_action["q"][joint_idx] |
|||
# Linear interpolation: initial + ramp_factor * (target - initial) |
|||
safe_action["q"][joint_idx] = initial_pos + ramp_factor * ( |
|||
target_pos - initial_pos |
|||
) |
|||
|
|||
# Increment counter for next iteration |
|||
self.startup_counter += 1 |
|||
else: |
|||
# Ramping complete - use original actions |
|||
self.startup_complete = True |
|||
|
|||
return safe_action |
|||
|
|||
def get_violation_report(self, violations: Optional[List[Dict]] = None) -> str: |
|||
"""Generate a formatted error report for violations. |
|||
|
|||
Args: |
|||
violations: List of violations to report (uses self.violations if None) |
|||
|
|||
Returns: |
|||
Formatted error message string |
|||
""" |
|||
if violations is None: |
|||
violations = self.violations |
|||
|
|||
if not violations: |
|||
return "No violations detected." |
|||
|
|||
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3] |
|||
|
|||
# Check if these are critical violations or warnings |
|||
critical_violations = [v for v in violations if v.get("critical", True)] |
|||
warning_violations = [v for v in violations if not v.get("critical", True)] |
|||
|
|||
if critical_violations and warning_violations: |
|||
report = f"Joint safety bounds exceeded!\nTimestamp: {timestamp}\nViolations:\n" |
|||
elif critical_violations: |
|||
report = f"Joint safety bounds exceeded!\nTimestamp: {timestamp}\nViolations:\n" |
|||
else: |
|||
report = f"Joint position warnings!\nTimestamp: {timestamp}\nWarnings:\n" |
|||
|
|||
for violation in violations: |
|||
joint = violation["joint"] |
|||
vtype = violation["type"] |
|||
value = violation["value"] |
|||
exceeded = violation["exceeded_by"] |
|||
limit_min = violation["limit_min"] |
|||
limit_max = violation["limit_max"] |
|||
|
|||
if vtype == "velocity": |
|||
report += f" - {joint}: {vtype}={value:.3f} rad/s " |
|||
report += f"(limit: ±{limit_max:.3f} rad/s) - " |
|||
report += f"EXCEEDED by {exceeded:.1f}%\n" |
|||
elif vtype == "position": |
|||
report += f" - {joint}: {vtype}={value:.3f} rad " |
|||
report += f"(limits: [{limit_min:.3f}, {limit_max:.3f}] rad) - " |
|||
report += f"EXCEEDED by {exceeded:.1f}%\n" |
|||
|
|||
# Add appropriate action message |
|||
if critical_violations: |
|||
report += "Action: Safe mode engaged (kp=0, tau=0). System shutdown initiated.\n" |
|||
report += "Please restart Docker container to resume operation." |
|||
else: |
|||
report += "Action: Position warning only. Robot continues operation." |
|||
|
|||
return report |
|||
|
|||
def handle_violations(self, obs: Dict, action: Dict) -> Dict: |
|||
"""Check safety and handle violations appropriately. |
|||
|
|||
Args: |
|||
obs: Observation dictionary |
|||
action: Action dictionary |
|||
|
|||
Returns: |
|||
Dict with keys: |
|||
- 'safe_to_continue': bool - whether robot should continue operation |
|||
- 'action': Dict - potentially modified safe action |
|||
- 'shutdown_required': bool - whether system shutdown is needed |
|||
""" |
|||
is_safe, violations = self.check_safety(obs, action) |
|||
|
|||
# Apply startup ramping (always, regardless of violations) |
|||
safe_action = self.get_safe_action(obs, action) |
|||
|
|||
# Visualize arm and hand joint positions and velocities if enabled |
|||
if self.enable_viz: |
|||
if ( |
|||
self.right_arm_indices is not None |
|||
and self.left_arm_indices is not None |
|||
and self.right_hand_indices is not None |
|||
and self.left_hand_indices is not None |
|||
and "q" in obs |
|||
and "dq" in obs |
|||
and self.viz is not None |
|||
): |
|||
try: |
|||
right_arm_positions = obs["q"][self.right_arm_indices] |
|||
left_arm_positions = obs["q"][self.left_arm_indices] |
|||
right_arm_velocities = obs["dq"][self.right_arm_indices] |
|||
left_arm_velocities = obs["dq"][self.left_arm_indices] |
|||
right_hand_positions = obs["q"][self.right_hand_indices] |
|||
left_hand_positions = obs["q"][self.left_hand_indices] |
|||
right_hand_velocities = obs["dq"][self.right_hand_indices] |
|||
left_hand_velocities = obs["dq"][self.left_hand_indices] |
|||
tensor_dict = { |
|||
self.right_arm_pos_key: right_arm_positions, |
|||
self.left_arm_pos_key: left_arm_positions, |
|||
self.right_arm_vel_key: right_arm_velocities, |
|||
self.left_arm_vel_key: left_arm_velocities, |
|||
self.right_hand_pos_key: right_hand_positions, |
|||
self.left_hand_pos_key: left_hand_positions, |
|||
self.right_hand_vel_key: right_hand_velocities, |
|||
self.left_hand_vel_key: left_hand_velocities, |
|||
} |
|||
self.viz.plot_tensors(tensor_dict, time.time()) |
|||
except Exception: |
|||
pass |
|||
|
|||
if not violations: |
|||
return {"safe_to_continue": True, "action": safe_action, "shutdown_required": False} |
|||
|
|||
# Separate critical (velocity) and warning (position) violations |
|||
critical_violations = [v for v in violations if v.get("critical", True)] |
|||
# warning_violations = [v for v in violations if not v.get('critical', True)] |
|||
|
|||
# Print warnings for position violations |
|||
# if warning_violations: |
|||
# warning_msg = self.get_violation_report(warning_violations) |
|||
# print(f"[SAFETY WARNING] {warning_msg}") |
|||
|
|||
# Handle critical violations (velocity) - trigger shutdown |
|||
if not is_safe and critical_violations: |
|||
error_msg = self.get_violation_report(critical_violations) |
|||
if self.env_type == "real": |
|||
print(f"[SAFETY VIOLATION] {error_msg}") |
|||
self.trigger_system_shutdown() |
|||
|
|||
return {"safe_to_continue": False, "action": safe_action, "shutdown_required": True} |
|||
|
|||
# Only position violations - continue with safe action |
|||
return {"safe_to_continue": True, "action": safe_action, "shutdown_required": False} |
|||
|
|||
def trigger_system_shutdown(self): |
|||
"""Trigger system shutdown after safety violation.""" |
|||
print("\n[SAFETY] Initiating system shutdown due to safety violation...") |
|||
sys.exit(1) |
|||
|
|||
|
|||
def main(): |
|||
"""Test the joint safety monitor with joint groups.""" |
|||
print("Testing joint safety monitor with joint groups...") |
|||
|
|||
try: |
|||
from decoupled_wbc.control.robot_model.instantiation.g1 import instantiate_g1_robot_model |
|||
|
|||
# Instantiate robot model |
|||
robot_model = instantiate_g1_robot_model() |
|||
print(f"Robot model created with {len(robot_model.joint_names)} joints") |
|||
|
|||
# Create safety monitor |
|||
safety_monitor = JointSafetyMonitor(robot_model) |
|||
print("Safety monitor created successfully!") |
|||
print(f"Monitoring {len(safety_monitor.velocity_limits)} joints") |
|||
|
|||
# Print monitored joints |
|||
print("\nVelocity limits:") |
|||
for joint_name, limits in safety_monitor.velocity_limits.items(): |
|||
print(f" - {joint_name}: ±{limits['max']:.2f} rad/s") |
|||
|
|||
print(f"\nPosition limits (arms only): {len(safety_monitor.position_limits)} joints") |
|||
for joint_name, limits in safety_monitor.position_limits.items(): |
|||
print(f" - {joint_name}: [{limits['min']:.3f}, {limits['max']:.3f}] rad") |
|||
|
|||
# Test safety checking with safe values |
|||
print("\n--- Testing Safety Checking ---") |
|||
|
|||
# Create mock observation with safe values |
|||
safe_obs = { |
|||
"q": np.zeros(robot_model.num_dofs), # All joints at zero position |
|||
"dq": np.zeros(robot_model.num_dofs), # All joints at zero velocity |
|||
} |
|||
safe_action = {"q": np.zeros(robot_model.num_dofs)} |
|||
|
|||
# Test handle_violations method |
|||
result = safety_monitor.handle_violations(safe_obs, safe_action) |
|||
print( |
|||
f"Safe values test: safe_to_continue={result['safe_to_continue']}, " |
|||
f"shutdown_required={result['shutdown_required']}" |
|||
) |
|||
|
|||
# Test with unsafe velocity |
|||
unsafe_obs = safe_obs.copy() |
|||
unsafe_obs["dq"] = np.zeros(robot_model.num_dofs) |
|||
# Set left shoulder pitch velocity to exceed limit |
|||
left_shoulder_idx = robot_model.dof_index("left_shoulder_pitch_joint") |
|||
unsafe_obs["dq"][left_shoulder_idx] = 6.0 # Exceeds 5.0 rad/s limit |
|||
|
|||
print("\nUnsafe velocity test:") |
|||
result = safety_monitor.handle_violations(unsafe_obs, safe_action) |
|||
print( |
|||
f" safe_to_continue={result['safe_to_continue']}, shutdown_required={result['shutdown_required']}" |
|||
) |
|||
|
|||
# Test with unsafe position only |
|||
unsafe_pos_obs = safe_obs.copy() |
|||
unsafe_pos_obs["q"] = np.zeros(robot_model.num_dofs) |
|||
# Set left shoulder pitch position to exceed limit |
|||
unsafe_pos_obs["q"][left_shoulder_idx] = -4.0 # Exceeds lower limit of -3.089 |
|||
|
|||
print("\nUnsafe position test:") |
|||
result = safety_monitor.handle_violations(unsafe_pos_obs, safe_action) |
|||
print( |
|||
f" safe_to_continue={result['safe_to_continue']}, shutdown_required={result['shutdown_required']}" |
|||
) |
|||
|
|||
print("\nAll tests completed successfully!") |
|||
|
|||
except Exception as e: |
|||
print(f"Test failed with error: {e}") |
|||
import traceback |
|||
|
|||
traceback.print_exc() |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
main() |
|||
@ -0,0 +1,305 @@ |
|||
from abc import abstractmethod |
|||
import threading |
|||
import time |
|||
from typing import Any, Dict, Tuple |
|||
|
|||
import mujoco |
|||
import numpy as np |
|||
import rclpy |
|||
|
|||
from decoupled_wbc.control.envs.g1.sim.image_publish_utils import ImagePublishProcess |
|||
from decoupled_wbc.control.envs.robocasa.utils.robocasa_env import ( |
|||
Gr00tLocomanipRoboCasaEnv, |
|||
) # noqa: F401 |
|||
from decoupled_wbc.control.robot_model.robot_model import RobotModel |
|||
from decoupled_wbc.control.utils.keyboard_dispatcher import KeyboardListenerSubscriber |
|||
|
|||
|
|||
class RoboCasaEnvServer: |
|||
""" |
|||
This class is responsible for running the simulation environment loop in a separate thread. |
|||
It communicates with the main thread via the `publish_obs` and `get_action` methods through `channel_bridge`. |
|||
It will also handle the viewer sync when `onscreen` is True. |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
env_name: str, |
|||
robot_name: str, |
|||
robot_model: RobotModel, |
|||
env_kwargs: Dict[str, Any], |
|||
**kwargs, |
|||
): |
|||
# initialize environment |
|||
if env_kwargs.get("onscreen", False): |
|||
env_kwargs["onscreen"] = False |
|||
self.onscreen = True # onscreen render in the main thread |
|||
self.render_camera = env_kwargs.get("render_camera", None) |
|||
else: |
|||
self.onscreen = False |
|||
self.env_name = env_name |
|||
self.env = Gr00tLocomanipRoboCasaEnv(env_name, robot_name, robot_model, **env_kwargs) |
|||
self.init_caches() |
|||
self.cache_lock = threading.Lock() |
|||
|
|||
# initialize channel |
|||
self.init_channel() |
|||
|
|||
# initialize ROS2 node |
|||
if not rclpy.ok(): |
|||
rclpy.init() |
|||
self.node = rclpy.create_node("sim_robocasa") |
|||
self.thread = threading.Thread(target=rclpy.spin, args=(self.node,), daemon=True) |
|||
self.thread.start() |
|||
else: |
|||
self.thread = None |
|||
executor = rclpy.get_global_executor() |
|||
self.node = executor.get_nodes()[0] # will only take the first node |
|||
|
|||
self.control_freq = env_kwargs.get("control_freq", 1 / 0.02) |
|||
self.sim_freq = kwargs.get("sim_freq", 1 / 0.005) |
|||
self.control_rate = self.node.create_rate(self.control_freq) |
|||
|
|||
self.running = False |
|||
self.sim_thread = None |
|||
self.sync_lock = threading.Lock() |
|||
|
|||
self.sync_mode = kwargs.get("sync_mode", False) |
|||
self.steps_per_action = kwargs.get("steps_per_action", 1) |
|||
|
|||
self.image_dt = kwargs.get("image_dt", 0.04) |
|||
self.image_publish_process = None |
|||
self.viewer_freq = kwargs.get("viewer_freq", 1 / 0.02) |
|||
self.viewer = None |
|||
|
|||
self.verbose = kwargs.get("verbose", True) |
|||
|
|||
# Initialize keyboard listener for env reset |
|||
self.keyboard_listener = KeyboardListenerSubscriber() |
|||
|
|||
self.reset() |
|||
|
|||
@property |
|||
def base_env(self): |
|||
return self.env.env |
|||
|
|||
def start_image_publish_subprocess(self, start_method: str = "spawn", camera_port: int = 5555): |
|||
"""Initialize image publishing subprocess if cameras are configured""" |
|||
if len(self.env.camera_names) == 0: |
|||
print( |
|||
"Warning: No camera configs provided, image publishing subprocess will not be started" |
|||
) |
|||
return |
|||
|
|||
# Build camera configs from env camera settings |
|||
camera_configs = {} |
|||
for env_cam_name in self.env.camera_names: |
|||
camera_config = self.env.camera_key_mapper.get_camera_config(env_cam_name) |
|||
mapped_cam_name, cam_width, cam_height = camera_config |
|||
camera_configs[mapped_cam_name] = {"height": cam_height, "width": cam_width} |
|||
|
|||
self.image_publish_process = ImagePublishProcess( |
|||
camera_configs=camera_configs, |
|||
image_dt=self.image_dt, |
|||
zmq_port=camera_port, |
|||
start_method=start_method, |
|||
verbose=self.verbose, |
|||
) |
|||
|
|||
self.image_publish_process.start_process() |
|||
|
|||
def update_render_caches(self, obs: Dict[str, Any]): |
|||
"""Update render cache and shared memory for subprocess""" |
|||
if self.image_publish_process is None: |
|||
return |
|||
|
|||
# Extract image observations from obs dict |
|||
render_caches = { |
|||
k: v for k, v in obs.items() if k.endswith("_image") and isinstance(v, np.ndarray) |
|||
} |
|||
|
|||
# Update shared memory if image publishing process is available |
|||
if render_caches: |
|||
self.image_publish_process.update_shared_memory(render_caches) |
|||
|
|||
def init_caches(self): |
|||
self.caches = { |
|||
"obs": None, |
|||
"reward": None, |
|||
"terminated": None, |
|||
"truncated": None, |
|||
"info": None, |
|||
} |
|||
|
|||
def reset(self, **kwargs): |
|||
if self.viewer is not None: |
|||
self.viewer.close() |
|||
|
|||
obs, info = self.env.reset(**kwargs) |
|||
self.caches["obs"] = obs |
|||
self.caches["reward"] = 0 |
|||
self.caches["terminated"] = False |
|||
self.caches["truncated"] = False |
|||
self.caches["info"] = info |
|||
|
|||
# initialize viewer |
|||
if self.onscreen: |
|||
self.viewer = mujoco.viewer.launch_passive( |
|||
self.base_env.sim.model._model, |
|||
self.base_env.sim.data._data, |
|||
show_left_ui=False, |
|||
show_right_ui=False, |
|||
) |
|||
self.viewer.opt.geomgroup[0] = 0 # disable collision visualization |
|||
if self.render_camera is not None: |
|||
self.viewer.cam.type = mujoco.mjtCamera.mjCAMERA_FIXED |
|||
self.viewer.cam.fixedcamid = self.base_env.sim.model._model.cam( |
|||
self.render_camera |
|||
).id |
|||
|
|||
# self.episode_state.reset_state() |
|||
return obs, info |
|||
|
|||
@abstractmethod |
|||
def init_channel(self): |
|||
raise NotImplementedError("init_channel must be implemented by the subclass") |
|||
|
|||
@abstractmethod |
|||
def publish_obs(self): |
|||
raise NotImplementedError("publish_obs must be implemented by the subclass") |
|||
|
|||
@abstractmethod |
|||
def get_action(self) -> Tuple[Dict[str, Any], bool, bool]: |
|||
raise NotImplementedError("get_action must be implemented by the subclass") |
|||
|
|||
def start_as_thread(self): |
|||
"""Start the simulation thread""" |
|||
if self.sim_thread is not None and self.sim_thread.is_alive(): |
|||
return |
|||
|
|||
self.sim_thread = threading.Thread(target=self.start) |
|||
self.sim_thread.daemon = True |
|||
self.sim_thread.start() |
|||
|
|||
def set_sync_mode(self, sync_mode: bool, steps_per_action: int = 4): |
|||
"""Set the sync mode of the environment server""" |
|||
with self.sync_lock: |
|||
self.sync_mode = sync_mode |
|||
self.steps_per_action = steps_per_action |
|||
|
|||
def _check_keyboard_input(self): |
|||
"""Check for keyboard input and handle state transitions""" |
|||
key = self.keyboard_listener.read_msg() |
|||
if key == "k": |
|||
print("\033[1;32m[Sim env]\033[0m Resetting sim environment") |
|||
self.reset() |
|||
|
|||
def start(self): |
|||
"""Function executed by the simulation thread""" |
|||
iter_idx = 0 |
|||
steps_per_cur_action = 0 |
|||
t_start = time.monotonic() |
|||
|
|||
self.running = True |
|||
|
|||
while self.running: |
|||
# Check keyboard input for state transitions |
|||
self._check_keyboard_input() |
|||
|
|||
# Publish observations and get new action |
|||
self.publish_obs() |
|||
action, ready, is_new_action = self.get_action() |
|||
# ready is True if the action is received from the control loop |
|||
# is_new_action is True if the action is new (not the same as the previous action) |
|||
with self.sync_lock: |
|||
sync_mode = self.sync_mode |
|||
max_steps_per_action = self.steps_per_action |
|||
|
|||
# Process action if ready and within step limits |
|||
action_should_apply = ready and ( |
|||
(not sync_mode) or steps_per_cur_action < max_steps_per_action |
|||
) |
|||
if action_should_apply: |
|||
obs, reward, terminated, truncated, info = self.env.step(action) |
|||
with self.cache_lock: |
|||
self.caches["obs"] = obs |
|||
self.caches["reward"] = reward |
|||
self.caches["terminated"] = terminated |
|||
self.caches["truncated"] = truncated |
|||
self.caches["info"] = info |
|||
|
|||
if reward == 1.0 and iter_idx % 50 == 0: |
|||
print("\033[92mTask successful. Can save data now.\033[0m") |
|||
|
|||
iter_idx += 1 |
|||
steps_per_cur_action += 1 |
|||
if self.verbose and sync_mode: |
|||
print("steps_per_cur_action: ", steps_per_cur_action) |
|||
|
|||
# Update render caches at image publishing rate |
|||
if action_should_apply and iter_idx % int(self.image_dt * self.control_freq) == 0: |
|||
with self.cache_lock: |
|||
obs_copy = self.caches["obs"].copy() |
|||
self.update_render_caches(obs_copy) |
|||
|
|||
# Reset step counter for new actions |
|||
if is_new_action: |
|||
steps_per_cur_action = 0 |
|||
|
|||
# Update viewer at specified frequency |
|||
if self.onscreen and iter_idx % (self.control_freq / self.viewer_freq) == 0: |
|||
self.viewer.sync() |
|||
|
|||
# Check if we're meeting the desired control frequency |
|||
if iter_idx % 100 == 0: |
|||
end_time = time.monotonic() |
|||
if self.verbose: |
|||
print( |
|||
f"sim FPS: {100.0 / (end_time - t_start) * (self.sim_freq / self.control_freq)}" |
|||
) |
|||
if (end_time - t_start) > ((110.0 / self.control_freq)): # for tolerance |
|||
print( |
|||
f"Warning: Sim runs at " |
|||
"{100.0/(end_time - t_start) * (self.sim_freq / self.control_freq):.1f}Hz, " |
|||
f"but should run at {self.sim_freq:.1f}Hz" |
|||
) |
|||
t_start = end_time |
|||
|
|||
# reset obj pos every 200 steps |
|||
if iter_idx % 200 == 0: |
|||
if hasattr(self.base_env, "reset_obj_pos"): |
|||
self.base_env.reset_obj_pos() |
|||
|
|||
self.control_rate.sleep() |
|||
|
|||
def get_privileged_obs(self): |
|||
"""Get privileged observation. Should be implemented by subclasses.""" |
|||
obs = {} |
|||
with self.cache_lock: |
|||
if hasattr(self.base_env, "get_privileged_obs_keys"): |
|||
for key in self.base_env.get_privileged_obs_keys(): |
|||
obs[key] = self.caches["obs"][key] |
|||
|
|||
for key in self.caches["obs"].keys(): |
|||
if key.endswith("_image"): |
|||
obs[key] = self.caches["obs"][key] |
|||
|
|||
return obs |
|||
|
|||
def stop(self): |
|||
"""Stop the simulation thread""" |
|||
self.running = False |
|||
if self.sim_thread is not None: |
|||
self.sim_thread.join(timeout=1.0) # Wait for thread to finish with timeout |
|||
self.sim_thread = None |
|||
|
|||
def close(self): |
|||
self.stop() |
|||
if self.image_publish_process is not None: |
|||
self.image_publish_process.stop() |
|||
if self.onscreen: |
|||
self.viewer.close() |
|||
self.env.close() |
|||
|
|||
def get_reward(self): |
|||
return self.base_env.reward() |
|||
@ -0,0 +1,586 @@ |
|||
import sys |
|||
from typing import Any, Dict, Tuple |
|||
|
|||
import gymnasium as gym |
|||
from gymnasium.envs.registration import register |
|||
import numpy as np |
|||
from robocasa.environments.locomanipulation import REGISTERED_LOCOMANIPULATION_ENVS |
|||
from robocasa.models.robots import GR00T_LOCOMANIP_ENVS_ROBOTS |
|||
from robosuite.environments.robot_env import RobotEnv |
|||
from scipy.spatial.transform import Rotation as R |
|||
|
|||
from decoupled_wbc.control.envs.g1.utils.joint_safety import JointSafetyMonitor |
|||
from decoupled_wbc.control.envs.robocasa.utils.controller_utils import ( |
|||
update_robosuite_controller_configs, |
|||
) |
|||
from decoupled_wbc.control.envs.robocasa.utils.robocasa_env import ( # noqa: F401 |
|||
ALLOWED_LANGUAGE_CHARSET, |
|||
Gr00tLocomanipRoboCasaEnv, |
|||
) |
|||
from decoupled_wbc.control.robot_model.instantiation import get_robot_type_and_model |
|||
from decoupled_wbc.control.utils.n1_utils import ( |
|||
prepare_gym_space_for_eval, |
|||
prepare_observation_for_eval, |
|||
) |
|||
from decoupled_wbc.data.constants import RS_VIEW_CAMERA_HEIGHT, RS_VIEW_CAMERA_WIDTH |
|||
|
|||
|
|||
class SyncEnv(gym.Env): |
|||
MAX_MUJOCO_STATE_LEN = 800 |
|||
|
|||
def __init__(self, env_name, robot_name, **kwargs): |
|||
self.env_name = env_name |
|||
self.robot_name = robot_name |
|||
self.onscreen = kwargs.get("onscreen", True) |
|||
self.enable_gravity_compensation = kwargs.pop("enable_gravity_compensation", False) |
|||
self.gravity_compensation_joints = kwargs.pop("gravity_compensation_joints", ["arms"]) |
|||
_, self.robot_model = get_robot_type_and_model( |
|||
robot_name, enable_waist_ik=kwargs.pop("enable_waist", False) |
|||
) |
|||
|
|||
env_kwargs = { |
|||
"onscreen": kwargs.get("onscreen", True), |
|||
"offscreen": kwargs.get("offscreen", False), |
|||
"renderer": kwargs.get("renderer", "mjviewer"), |
|||
"render_camera": kwargs.get("render_camera", "frontview"), |
|||
"camera_names": kwargs.get("camera_names", ["frontview"]), |
|||
"camera_heights": kwargs.get("camera_heights", None), |
|||
"camera_widths": kwargs.get("camera_widths", None), |
|||
"controller_configs": kwargs["controller_configs"], |
|||
"control_freq": kwargs.get("control_freq", 50), |
|||
"translucent_robot": kwargs.get("translucent_robot", True), |
|||
"ik_indicator": kwargs.get("ik_indicator", False), |
|||
"randomize_cameras": kwargs.get("randomize_cameras", True), |
|||
} |
|||
self.env = Gr00tLocomanipRoboCasaEnv( |
|||
env_name, robot_name, robot_model=self.robot_model, **env_kwargs |
|||
) |
|||
self.init_cache() |
|||
|
|||
self.reset() |
|||
|
|||
@property |
|||
def base_env(self) -> RobotEnv: |
|||
return self.env.env |
|||
|
|||
def overwrite_floating_base_action(self, navigate_cmd): |
|||
if self.base_env.robots[0].robot_model.default_base == "FloatingLeggedBase": |
|||
self.env.unwrapped.overridden_floating_base_action = navigate_cmd |
|||
|
|||
def get_mujoco_state_info(self): |
|||
mujoco_state = self.base_env.sim.get_state().flatten() |
|||
assert len(mujoco_state) < SyncEnv.MAX_MUJOCO_STATE_LEN |
|||
padding_width = SyncEnv.MAX_MUJOCO_STATE_LEN - len(mujoco_state) |
|||
padded_mujoco_state = np.pad( |
|||
mujoco_state, (0, padding_width), mode="constant", constant_values=0 |
|||
) |
|||
max_mujoco_state_len = SyncEnv.MAX_MUJOCO_STATE_LEN |
|||
mujoco_state_len = len(mujoco_state) |
|||
mujoco_state = padded_mujoco_state.copy() |
|||
return max_mujoco_state_len, mujoco_state_len, mujoco_state |
|||
|
|||
def reset_to(self, state: Dict[str, Any]) -> Dict[str, Any] | None: |
|||
if hasattr(self.base_env, "reset_to"): |
|||
self.base_env.reset_to(state) |
|||
else: |
|||
# todo: maybe update robosuite to have reset_to() |
|||
env = self.base_env |
|||
if "model_file" in state: |
|||
xml = env.edit_model_xml(state["model_file"]) |
|||
env.reset_from_xml_string(xml) |
|||
env.sim.reset() |
|||
if "states" in state: |
|||
env.sim.set_state_from_flattened(state["states"]) |
|||
env.sim.forward() |
|||
|
|||
obs = self.env.force_update_observation(timestep=0) |
|||
self.cache["obs"] = obs |
|||
return |
|||
|
|||
def get_state(self) -> Dict[str, Any]: |
|||
return self.base_env.get_state() |
|||
|
|||
def is_success(self): |
|||
""" |
|||
Check if the task condition(s) is reached. Should return a dictionary |
|||
{ str: bool } with at least a "task" key for the overall task success, |
|||
and additional optional keys corresponding to other task criteria. |
|||
""" |
|||
# First, try to use the base environment's is_success method if it exists |
|||
if hasattr(self.base_env, "is_success"): |
|||
return self.base_env.is_success() |
|||
|
|||
# Fall back to using _check_success if available |
|||
elif hasattr(self.base_env, "_check_success"): |
|||
succ = self.base_env._check_success() |
|||
if isinstance(succ, dict): |
|||
assert "task" in succ |
|||
return succ |
|||
return {"task": succ} |
|||
|
|||
# If neither method exists, return failure |
|||
else: |
|||
return {"task": False} |
|||
|
|||
def init_cache(self): |
|||
self.cache = { |
|||
"obs": None, |
|||
"reward": None, |
|||
"terminated": None, |
|||
"truncated": None, |
|||
"info": None, |
|||
} |
|||
|
|||
def reset(self, seed=None, options=None) -> Tuple[Dict[str, any], Dict[str, any]]: |
|||
self.init_cache() |
|||
obs, info = self.env.reset(seed=seed, options=options) |
|||
self.cache["obs"] = obs |
|||
self.cache["reward"] = 0 |
|||
self.cache["terminated"] = False |
|||
self.cache["truncated"] = False |
|||
self.cache["info"] = info |
|||
return self.observe(), info |
|||
|
|||
def observe(self) -> Dict[str, any]: |
|||
# Get observations from body and hands |
|||
assert ( |
|||
self.cache["obs"] is not None |
|||
), "Observation cache is not initialized, please reset the environment first" |
|||
raw_obs = self.cache["obs"] |
|||
|
|||
# Body and hand joint measurements come in actuator order, so we need to convert them to joint order |
|||
whole_q = self.robot_model.get_configuration_from_actuated_joints( |
|||
body_actuated_joint_values=raw_obs["body_q"], |
|||
left_hand_actuated_joint_values=raw_obs["left_hand_q"], |
|||
right_hand_actuated_joint_values=raw_obs["right_hand_q"], |
|||
) |
|||
whole_dq = self.robot_model.get_configuration_from_actuated_joints( |
|||
body_actuated_joint_values=raw_obs["body_dq"], |
|||
left_hand_actuated_joint_values=raw_obs["left_hand_dq"], |
|||
right_hand_actuated_joint_values=raw_obs["right_hand_dq"], |
|||
) |
|||
whole_ddq = self.robot_model.get_configuration_from_actuated_joints( |
|||
body_actuated_joint_values=raw_obs["body_ddq"], |
|||
left_hand_actuated_joint_values=raw_obs["left_hand_ddq"], |
|||
right_hand_actuated_joint_values=raw_obs["right_hand_ddq"], |
|||
) |
|||
whole_tau_est = self.robot_model.get_configuration_from_actuated_joints( |
|||
body_actuated_joint_values=raw_obs["body_tau_est"], |
|||
left_hand_actuated_joint_values=raw_obs["left_hand_tau_est"], |
|||
right_hand_actuated_joint_values=raw_obs["right_hand_tau_est"], |
|||
) |
|||
eef_obs = self.get_eef_obs(whole_q) |
|||
|
|||
obs = { |
|||
"q": whole_q, |
|||
"dq": whole_dq, |
|||
"ddq": whole_ddq, |
|||
"tau_est": whole_tau_est, |
|||
"floating_base_pose": raw_obs["floating_base_pose"], |
|||
"floating_base_vel": raw_obs["floating_base_vel"], |
|||
"floating_base_acc": raw_obs["floating_base_acc"], |
|||
"wrist_pose": np.concatenate([eef_obs["left_wrist_pose"], eef_obs["right_wrist_pose"]]), |
|||
} |
|||
|
|||
# Add state keys for model input |
|||
obs = prepare_observation_for_eval(self.robot_model, obs) |
|||
|
|||
obs["annotation.human.task_description"] = raw_obs["language.language_instruction"] |
|||
|
|||
if hasattr(self.base_env, "get_privileged_obs_keys"): |
|||
for key in self.base_env.get_privileged_obs_keys(): |
|||
obs[key] = raw_obs[key] |
|||
|
|||
for key in raw_obs.keys(): |
|||
if key.endswith("_image"): |
|||
obs[key] = raw_obs[key] |
|||
# TODO: add video.key without _image suffix for evaluation, convert to uint8, remove later |
|||
obs[f"video.{key.replace('_image', '')}"] = raw_obs[key] |
|||
return obs |
|||
|
|||
def step( |
|||
self, action: Dict[str, any] |
|||
) -> Tuple[Dict[str, any], float, bool, bool, Dict[str, any]]: |
|||
self.queue_action(action) |
|||
return self.get_step_info() |
|||
|
|||
def get_observation(self): |
|||
return self.base_env._get_observations() # assumes base env is robosuite |
|||
|
|||
def get_step_info(self) -> Tuple[Dict[str, any], float, bool, bool, Dict[str, any]]: |
|||
return ( |
|||
self.observe(), |
|||
self.cache["reward"], |
|||
self.cache["terminated"], |
|||
self.cache["truncated"], |
|||
self.cache["info"], |
|||
) |
|||
|
|||
def convert_q_to_actuated_joint_order(self, q: np.ndarray) -> np.ndarray: |
|||
body_q = self.robot_model.get_body_actuated_joints(q) |
|||
left_hand_q = self.robot_model.get_hand_actuated_joints(q, side="left") |
|||
right_hand_q = self.robot_model.get_hand_actuated_joints(q, side="right") |
|||
whole_q = np.zeros_like(q) |
|||
whole_q[self.robot_model.get_joint_group_indices("body")] = body_q |
|||
whole_q[self.robot_model.get_joint_group_indices("left_hand")] = left_hand_q |
|||
whole_q[self.robot_model.get_joint_group_indices("right_hand")] = right_hand_q |
|||
|
|||
return whole_q |
|||
|
|||
def set_ik_indicator(self, teleop_cmd): |
|||
"""Set the IK indicators for the simulator""" |
|||
if "left_wrist" in teleop_cmd and "right_wrist" in teleop_cmd: |
|||
left_wrist_input_pose = teleop_cmd["left_wrist"] |
|||
right_wrist_input_pose = teleop_cmd["right_wrist"] |
|||
ik_wrapper = self.base_env |
|||
ik_wrapper.set_target_poses_outside_env([left_wrist_input_pose, right_wrist_input_pose]) |
|||
|
|||
def render(self): |
|||
if self.base_env.viewer is not None: |
|||
self.base_env.viewer.update() |
|||
if self.onscreen: |
|||
self.base_env.render() |
|||
|
|||
def queue_action(self, action: Dict[str, any]): |
|||
# action is in pinocchio joint order, we need to convert it to actuator order |
|||
action_q = self.convert_q_to_actuated_joint_order(action["q"]) |
|||
|
|||
# Compute gravity compensation torques if enabled |
|||
tau_q = np.zeros_like(action_q) |
|||
if self.enable_gravity_compensation and self.robot_model is not None: |
|||
try: |
|||
# Get current robot configuration from cache (more efficient than observe()) |
|||
raw_obs = self.cache["obs"] |
|||
|
|||
# Convert from actuated joint order to joint order for Pinocchio |
|||
current_q_joint_order = self.robot_model.get_configuration_from_actuated_joints( |
|||
body_actuated_joint_values=raw_obs["body_q"], |
|||
left_hand_actuated_joint_values=raw_obs["left_hand_q"], |
|||
right_hand_actuated_joint_values=raw_obs["right_hand_q"], |
|||
) |
|||
|
|||
# Compute gravity compensation in joint order using current robot configuration |
|||
gravity_torques_joint_order = self.robot_model.compute_gravity_compensation_torques( |
|||
current_q_joint_order, joint_groups=self.gravity_compensation_joints |
|||
) |
|||
|
|||
# Convert gravity torques to actuated joint order |
|||
gravity_torques_actuated = self.convert_q_to_actuated_joint_order( |
|||
gravity_torques_joint_order |
|||
) |
|||
|
|||
# Add gravity compensation to torques |
|||
tau_q += gravity_torques_actuated |
|||
|
|||
except Exception as e: |
|||
print(f"Error applying gravity compensation in sync_env: {e}") |
|||
|
|||
obs, reward, terminated, truncated, info = self.env.step({"q": action_q, "tau": tau_q}) |
|||
self.cache["obs"] = obs |
|||
self.cache["reward"] = reward |
|||
self.cache["terminated"] = terminated |
|||
self.cache["truncated"] = truncated |
|||
self.cache["info"] = info |
|||
|
|||
def queue_state(self, state: Dict[str, any]): |
|||
# This function is for debugging or cross-playback between sim and real only. |
|||
state_q = self.convert_q_to_actuated_joint_order(state["q"]) |
|||
obs, reward, terminated, truncated, info = self.env.unwrapped.step_only_kinematics( |
|||
{"q": state_q} |
|||
) |
|||
self.cache["obs"] = obs |
|||
self.cache["reward"] = reward |
|||
self.cache["terminated"] = terminated |
|||
self.cache["truncated"] = truncated |
|||
self.cache["info"] = info |
|||
|
|||
@property |
|||
def observation_space(self) -> gym.Space: |
|||
# @todo: check if the low and high bounds are correct for body_obs. |
|||
q_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,)) |
|||
dq_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,)) |
|||
ddq_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,)) |
|||
tau_est_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(self.robot_model.num_dofs,)) |
|||
floating_base_pose_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7,)) |
|||
floating_base_vel_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(6,)) |
|||
floating_base_acc_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(6,)) |
|||
wrist_pose_space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(7 + 7,)) |
|||
|
|||
obs_space = gym.spaces.Dict( |
|||
{ |
|||
"floating_base_pose": floating_base_pose_space, |
|||
"floating_base_vel": floating_base_vel_space, |
|||
"floating_base_acc": floating_base_acc_space, |
|||
"q": q_space, |
|||
"dq": dq_space, |
|||
"ddq": ddq_space, |
|||
"tau_est": tau_est_space, |
|||
"wrist_pose": wrist_pose_space, |
|||
} |
|||
) |
|||
|
|||
obs_space = prepare_gym_space_for_eval(self.robot_model, obs_space) |
|||
|
|||
obs_space["annotation.human.task_description"] = gym.spaces.Text( |
|||
max_length=256, charset=ALLOWED_LANGUAGE_CHARSET |
|||
) |
|||
|
|||
if hasattr(self.base_env, "get_privileged_obs_keys"): |
|||
for key, shape in self.base_env.get_privileged_obs_keys().items(): |
|||
space = gym.spaces.Box(low=-np.inf, high=np.inf, shape=shape) |
|||
obs_space[key] = space |
|||
|
|||
robocasa_obs_space = self.env.observation_space |
|||
for key in robocasa_obs_space.keys(): |
|||
if key.endswith("_image"): |
|||
space = gym.spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=robocasa_obs_space[key].shape |
|||
) |
|||
obs_space[key] = space |
|||
# TODO: add video.key without _image suffix for evaluation, remove later |
|||
space_uint = gym.spaces.Box(low=0, high=255, shape=space.shape, dtype=np.uint8) |
|||
obs_space[f"video.{key.replace('_image', '')}"] = space_uint |
|||
|
|||
return obs_space |
|||
|
|||
def reset_obj_pos(self): |
|||
# For Tairan's goal-reaching task, a hacky way to reset the object position is needed. |
|||
if hasattr(self.base_env, "reset_obj_pos"): |
|||
self.base_env.reset_obj_pos() |
|||
|
|||
@property |
|||
def action_space(self) -> gym.Space: |
|||
return self.env.action_space |
|||
|
|||
def close(self): |
|||
self.env.close() |
|||
|
|||
def __repr__(self): |
|||
return ( |
|||
f"SyncEnv(env_name={self.env_name}, \n" |
|||
f" observation_space={self.observation_space}, \n" |
|||
f" action_space={self.action_space})" |
|||
) |
|||
|
|||
def get_joint_gains(self): |
|||
controller = self.base_env.robots[0].composite_controller |
|||
|
|||
gains = {} |
|||
key_mapping = { |
|||
"left": "left_arm", |
|||
"right": "right_arm", |
|||
"legs": "legs", |
|||
"torso": "waist", |
|||
"head": "neck", |
|||
} |
|||
for k in controller.part_controllers.keys(): |
|||
if hasattr(controller.part_controllers[k], "kp"): |
|||
if k in key_mapping: |
|||
gains[key_mapping[k]] = controller.part_controllers[k].kp |
|||
else: |
|||
gains[k] = controller.part_controllers[k].kp |
|||
gains.update( |
|||
{ |
|||
"left_hand": self.base_env.sim.model.actuator_gainprm[ |
|||
self.base_env.robots[0]._ref_actuators_indexes_dict["left_gripper"], 0 |
|||
], |
|||
"right_hand": self.base_env.sim.model.actuator_gainprm[ |
|||
self.base_env.robots[0]._ref_actuators_indexes_dict["right_gripper"], 0 |
|||
], |
|||
} |
|||
) |
|||
joint_gains = np.zeros(self.robot_model.num_dofs) |
|||
for k in gains.keys(): |
|||
joint_gains[self.robot_model.get_joint_group_indices(k)] = gains[k] |
|||
return joint_gains |
|||
|
|||
def get_joint_damping(self): |
|||
controller = self.base_env.robots[0].composite_controller |
|||
damping = {} |
|||
key_mapping = { |
|||
"left": "left_arm", |
|||
"right": "right_arm", |
|||
"legs": "legs", |
|||
"torso": "waist", |
|||
"head": "neck", |
|||
} |
|||
for k in controller.part_controllers.keys(): |
|||
if hasattr(controller.part_controllers[k], "kd"): |
|||
if k in key_mapping: |
|||
damping[key_mapping[k]] = controller.part_controllers[k].kd |
|||
else: |
|||
damping[k] = controller.part_controllers[k].kd |
|||
damping.update( |
|||
{ |
|||
"left_hand": -self.base_env.sim.model.actuator_biasprm[ |
|||
self.base_env.robots[0]._ref_actuators_indexes_dict["left_gripper"], 2 |
|||
], |
|||
"right_hand": -self.base_env.sim.model.actuator_biasprm[ |
|||
self.base_env.robots[0]._ref_actuators_indexes_dict["right_gripper"], 2 |
|||
], |
|||
} |
|||
) |
|||
joint_damping = np.zeros(self.robot_model.num_dofs) |
|||
for k in damping.keys(): |
|||
joint_damping[self.robot_model.get_joint_group_indices(k)] = damping[k] |
|||
return joint_damping |
|||
|
|||
def get_eef_obs(self, q: np.ndarray) -> Dict[str, np.ndarray]: |
|||
self.robot_model.cache_forward_kinematics(q) |
|||
eef_obs = {} |
|||
for side in ["left", "right"]: |
|||
wrist_placement = self.robot_model.frame_placement( |
|||
self.robot_model.supplemental_info.hand_frame_names[side] |
|||
) |
|||
wrist_pos, wrist_quat = wrist_placement.translation[:3], R.from_matrix( |
|||
wrist_placement.rotation |
|||
).as_quat(scalar_first=True) |
|||
eef_obs[f"{side}_wrist_pose"] = np.concatenate([wrist_pos, wrist_quat]) |
|||
|
|||
return eef_obs |
|||
|
|||
|
|||
class G1SyncEnv(SyncEnv): |
|||
def __init__( |
|||
self, |
|||
env_name, |
|||
robot_name, |
|||
**kwargs, |
|||
): |
|||
renderer = kwargs.get("renderer", "mjviewer") |
|||
if renderer == "mjviewer": |
|||
default_render_camera = ["robot0_oak_egoview"] |
|||
elif renderer in ["mujoco", "rerun"]: |
|||
default_render_camera = [ |
|||
"robot0_oak_egoview", |
|||
"robot0_oak_left_monoview", |
|||
"robot0_oak_right_monoview", |
|||
] |
|||
else: |
|||
raise NotImplementedError |
|||
default_camera_names = [ |
|||
"robot0_oak_egoview", |
|||
"robot0_oak_left_monoview", |
|||
"robot0_oak_right_monoview", |
|||
] |
|||
default_camera_heights = [ |
|||
RS_VIEW_CAMERA_HEIGHT, |
|||
RS_VIEW_CAMERA_HEIGHT, |
|||
RS_VIEW_CAMERA_HEIGHT, |
|||
] |
|||
default_camera_widths = [ |
|||
RS_VIEW_CAMERA_WIDTH, |
|||
RS_VIEW_CAMERA_WIDTH, |
|||
RS_VIEW_CAMERA_WIDTH, |
|||
] |
|||
|
|||
kwargs.update( |
|||
{ |
|||
"onscreen": kwargs.get("onscreen", True), |
|||
"offscreen": kwargs.get("offscreen", False), |
|||
"render_camera": kwargs.get("render_camera", default_render_camera), |
|||
"camera_names": kwargs.get("camera_names", default_camera_names), |
|||
"camera_heights": kwargs.get("camera_heights", default_camera_heights), |
|||
"camera_widths": kwargs.get("camera_widths", default_camera_widths), |
|||
"translucent_robot": kwargs.get("translucent_robot", False), |
|||
} |
|||
) |
|||
super().__init__(env_name=env_name, robot_name=robot_name, **kwargs) |
|||
|
|||
# Initialize safety monitor (visualization disabled) - G1 specific |
|||
self.safety_monitor = JointSafetyMonitor( |
|||
self.robot_model, |
|||
enable_viz=False, |
|||
env_type="sim", # G1SyncEnv is always simulation |
|||
) |
|||
self.safety_monitor.ramp_duration_steps = 0 |
|||
self.safety_monitor.startup_complete = True |
|||
self.safety_monitor.LOWER_BODY_VELOCITY_LIMIT = ( |
|||
1e10 # disable lower body velocity limits since it impacts WBC |
|||
) |
|||
self.last_safety_ok = True # Track last safety status from queue_action |
|||
|
|||
@property |
|||
def observation_space(self): |
|||
obs_space = super().observation_space |
|||
obs_space["torso_quat"] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(4,)) |
|||
obs_space["torso_ang_vel"] = gym.spaces.Box(low=-np.inf, high=np.inf, shape=(3,)) |
|||
return obs_space |
|||
|
|||
def observe(self): |
|||
obs = super().observe() |
|||
obs["torso_quat"] = self.cache["obs"]["secondary_imu_quat"] |
|||
obs["torso_ang_vel"] = self.cache["obs"]["secondary_imu_vel"][3:6] |
|||
return obs |
|||
|
|||
def queue_action(self, action: Dict[str, any]): |
|||
# Safety check before queuing action |
|||
obs = self.observe() |
|||
safety_result = self.safety_monitor.handle_violations(obs, action) |
|||
action = safety_result["action"] |
|||
# Save safety status for efficient access |
|||
self.last_safety_ok = not safety_result.get("shutdown_required", False) |
|||
# Check if shutdown is required |
|||
if safety_result["shutdown_required"]: |
|||
self.safety_monitor.trigger_system_shutdown() |
|||
|
|||
# Call parent queue_action with potentially modified action |
|||
super().queue_action(action) |
|||
|
|||
def get_joint_safety_status(self) -> bool: |
|||
"""Get current joint safety status from the last queue_action safety check. |
|||
|
|||
Returns: |
|||
bool: True if joints are safe (no shutdown required), False if unsafe |
|||
""" |
|||
return self.last_safety_ok |
|||
|
|||
|
|||
def create_gym_sync_env_class(env, robot, robot_alias, wbc_version): |
|||
class_name = f"{env}_{robot}_{wbc_version}" |
|||
id_name = f"gr00tlocomanip_{robot_alias}/{class_name}" |
|||
|
|||
if robot_alias.startswith("g1"): |
|||
env_class_type = G1SyncEnv |
|||
elif robot_alias.startswith("gr1"): |
|||
env_class_type = globals().get("GR1SyncEnv", SyncEnv) |
|||
else: |
|||
env_class_type = SyncEnv |
|||
|
|||
controller_configs = update_robosuite_controller_configs( |
|||
robot=robot, |
|||
wbc_version=wbc_version, |
|||
) |
|||
|
|||
env_class_type = type( |
|||
class_name, |
|||
(env_class_type,), |
|||
{ |
|||
"__init__": lambda self, **kwargs: super(self.__class__, self).__init__( |
|||
env_name=env, |
|||
robot_name=robot, |
|||
controller_configs=controller_configs, |
|||
**kwargs, |
|||
) |
|||
}, |
|||
) |
|||
|
|||
current_module = sys.modules["decoupled_wbc.control.envs.robocasa.sync_env"] |
|||
setattr(current_module, class_name, env_class_type) |
|||
register( |
|||
id=id_name, # Unique ID for the environment |
|||
entry_point=f"decoupled_wbc.control.envs.robocasa.sync_env:{class_name}", |
|||
) |
|||
|
|||
|
|||
WBC_VERSION = "gear_wbc" |
|||
|
|||
for ENV in REGISTERED_LOCOMANIPULATION_ENVS: |
|||
for ROBOT, ROBOT_ALIAS in GR00T_LOCOMANIP_ENVS_ROBOTS.items(): |
|||
create_gym_sync_env_class(ENV, ROBOT, ROBOT_ALIAS, WBC_VERSION) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
|
|||
env = gym.make("gr00tlocomanip_g1_sim/PnPBottle_g1_gear_wbc") |
|||
print(env.observation_space) |
|||
@ -0,0 +1,73 @@ |
|||
from dataclasses import dataclass |
|||
from typing import Dict, Optional, Tuple |
|||
|
|||
from decoupled_wbc.data.constants import RS_VIEW_CAMERA_HEIGHT, RS_VIEW_CAMERA_WIDTH |
|||
|
|||
|
|||
@dataclass |
|||
class CameraConfig: |
|||
width: int |
|||
height: int |
|||
mapped_key: str |
|||
|
|||
|
|||
class CameraKeyMapper: |
|||
def __init__(self): |
|||
# Default camera dimensions |
|||
self.default_width = RS_VIEW_CAMERA_WIDTH |
|||
self.default_height = RS_VIEW_CAMERA_HEIGHT |
|||
|
|||
# Camera key mapping with custom dimensions |
|||
self.camera_configs: Dict[str, CameraConfig] = { |
|||
# GR1 |
|||
"egoview": CameraConfig(self.default_width, self.default_height, "ego_view"), |
|||
"frontview": CameraConfig(self.default_width, self.default_height, "front_view"), |
|||
# G1 |
|||
"robot0_rs_egoview": CameraConfig(self.default_width, self.default_height, "ego_view"), |
|||
"robot0_rs_tppview": CameraConfig(self.default_width, self.default_height, "tpp_view"), |
|||
"robot0_oak_egoview": CameraConfig(self.default_width, self.default_height, "ego_view"), |
|||
"robot0_oak_left_monoview": CameraConfig( |
|||
self.default_width, self.default_height, "ego_view_left_mono" |
|||
), |
|||
"robot0_oak_right_monoview": CameraConfig( |
|||
self.default_width, self.default_height, "ego_view_right_mono" |
|||
), |
|||
} |
|||
|
|||
def get_camera_config(self, key: str) -> Optional[Tuple[str, int, int]]: |
|||
""" |
|||
Get the mapped camera key and dimensions for a given camera key. |
|||
|
|||
Args: |
|||
key: The input camera key |
|||
|
|||
Returns: |
|||
Tuple of (mapped_key, width, height) if key exists, None otherwise |
|||
""" |
|||
config = self.camera_configs.get(key.lower()) |
|||
if config is None: |
|||
return None |
|||
return config.mapped_key, config.width, config.height |
|||
|
|||
def add_camera_config( |
|||
self, key: str, mapped_key: str, width: int = 256, height: int = 256 |
|||
) -> None: |
|||
""" |
|||
Add a new camera configuration or update an existing one. |
|||
|
|||
Args: |
|||
key: The camera key to add/update |
|||
mapped_key: The actual camera key to map to |
|||
width: Camera width in pixels |
|||
height: Camera height in pixels |
|||
""" |
|||
self.camera_configs[key.lower()] = CameraConfig(width, height, mapped_key) |
|||
|
|||
def get_all_camera_keys(self) -> list: |
|||
""" |
|||
Get all available camera keys. |
|||
|
|||
Returns: |
|||
List of all camera keys |
|||
""" |
|||
return list(self.camera_configs.keys()) |
|||
@ -0,0 +1,443 @@ |
|||
import os |
|||
from typing import Any, Dict, List, Tuple |
|||
|
|||
from gymnasium import spaces |
|||
import mujoco |
|||
import numpy as np |
|||
import robocasa |
|||
from robocasa.utils.gym_utils.gymnasium_basic import ( |
|||
RoboCasaEnv, |
|||
create_env_robosuite, |
|||
) |
|||
from robocasa.wrappers.ik_wrapper import IKWrapper |
|||
from robosuite.controllers import load_composite_controller_config |
|||
from robosuite.utils.log_utils import ROBOSUITE_DEFAULT_LOGGER |
|||
|
|||
from decoupled_wbc.control.envs.robocasa.utils.cam_key_converter import CameraKeyMapper |
|||
from decoupled_wbc.control.envs.robocasa.utils.robot_key_converter import Gr00tObsActionConverter |
|||
from decoupled_wbc.control.robot_model.robot_model import RobotModel |
|||
|
|||
ALLOWED_LANGUAGE_CHARSET = ( |
|||
"abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789 ,.\n\t[]{}()!?'_:" |
|||
) |
|||
|
|||
|
|||
class Gr00tLocomanipRoboCasaEnv(RoboCasaEnv): |
|||
def __init__( |
|||
self, |
|||
env_name: str, |
|||
robots_name: str, |
|||
robot_model: RobotModel, # gr00t robot model |
|||
input_space: str = "JOINT_SPACE", # either "JOINT_SPACE" or "EEF_SPACE" |
|||
camera_names: List[str] = ["egoview"], |
|||
camera_heights: List[int] | None = None, |
|||
camera_widths: List[int] | None = None, |
|||
onscreen: bool = False, |
|||
offscreen: bool = False, |
|||
dump_rollout_dataset_dir: str | None = None, |
|||
rollout_hdf5: str | None = None, |
|||
rollout_trainset: int | None = None, |
|||
controller_configs: str | None = None, |
|||
ik_indicator: bool = False, |
|||
**kwargs, |
|||
): |
|||
# ========= Create env ========= |
|||
if controller_configs is None: |
|||
if "G1" in robots_name: |
|||
controller_configs = ( |
|||
"robocasa/examples/third_party_controller/default_mink_ik_g1_wbc.json" |
|||
) |
|||
elif "GR1" in robots_name: |
|||
controller_configs = ( |
|||
"robocasa/examples/third_party_controller/default_mink_ik_gr1_smallkd.json" |
|||
) |
|||
else: |
|||
assert False, f"Unsupported robot name: {robots_name}" |
|||
controller_configs = os.path.join( |
|||
os.path.dirname(robocasa.__file__), |
|||
"../", |
|||
controller_configs, |
|||
) |
|||
controller_configs = load_composite_controller_config( |
|||
controller=controller_configs, |
|||
robot=robots_name.split("_")[0], |
|||
) |
|||
if input_space == "JOINT_SPACE": |
|||
controller_configs["type"] = "BASIC" |
|||
controller_configs["composite_controller_specific_configs"] = {} |
|||
controller_configs["control_delta"] = False |
|||
|
|||
self.camera_key_mapper = CameraKeyMapper() |
|||
self.camera_names = camera_names |
|||
|
|||
if camera_widths is None: |
|||
self.camera_widths = [ |
|||
self.camera_key_mapper.get_camera_config(name)[1] for name in camera_names |
|||
] |
|||
else: |
|||
self.camera_widths = camera_widths |
|||
if camera_heights is None: |
|||
self.camera_heights = [ |
|||
self.camera_key_mapper.get_camera_config(name)[2] for name in camera_names |
|||
] |
|||
else: |
|||
self.camera_heights = camera_heights |
|||
|
|||
self.env, self.env_kwargs = create_env_robosuite( |
|||
env_name=env_name, |
|||
robots=robots_name.split("_"), |
|||
controller_configs=controller_configs, |
|||
camera_names=camera_names, |
|||
camera_widths=self.camera_widths, |
|||
camera_heights=self.camera_heights, |
|||
enable_render=offscreen, |
|||
onscreen=onscreen, |
|||
**kwargs, # Forward kwargs to create_env_robosuite |
|||
) |
|||
|
|||
if ik_indicator: |
|||
self.env = IKWrapper(self.env, ik_indicator=True) |
|||
|
|||
# ========= create converters first to get total DOFs ========= |
|||
# For now, assume single robot (multi-robot support can be added later) |
|||
self.obs_action_converter: List[Gr00tObsActionConverter] = [ |
|||
Gr00tObsActionConverter( |
|||
robot_model=robot_model, |
|||
robosuite_robot_model=self.env.robots[i], |
|||
) |
|||
for i in range(len(self.env.robots)) |
|||
] |
|||
|
|||
self.body_dofs = sum(converter.body_dof for converter in self.obs_action_converter) |
|||
self.gripper_dofs = sum(converter.gripper_dof for converter in self.obs_action_converter) |
|||
self.total_dofs = self.body_dofs + self.gripper_dofs |
|||
self.body_nu = sum(converter.body_nu for converter in self.obs_action_converter) |
|||
self.gripper_nu = sum(converter.gripper_nu for converter in self.obs_action_converter) |
|||
self.total_nu = self.body_nu + self.gripper_nu |
|||
|
|||
# ========= create spaces to match total DOFs ========= |
|||
self.get_observation_space() |
|||
self.get_action_space() |
|||
|
|||
self.enable_render = offscreen |
|||
self.render_obs_key = f"{camera_names[0]}_image" |
|||
self.render_cache = None |
|||
|
|||
self.dump_rollout_dataset_dir = dump_rollout_dataset_dir |
|||
self.gr00t_exporter = None |
|||
self.np_exporter = None |
|||
|
|||
self.rollout_hdf5 = rollout_hdf5 |
|||
self.rollout_trainset = rollout_trainset |
|||
self.rollout_initial_state = {} |
|||
|
|||
self.verbose = False |
|||
for k, v in self.observation_space.items(): |
|||
self.verbose and print("{OBS}", k, v) |
|||
for k, v in self.action_space.items(): |
|||
self.verbose and print("{ACTION}", k, v) |
|||
|
|||
self.overridden_floating_base_action = None |
|||
|
|||
def get_observation_space(self): |
|||
self.observation_space = spaces.Dict({}) |
|||
|
|||
# Add all the observation spaces |
|||
self.observation_space["time"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(1,), dtype=np.float32 |
|||
) |
|||
self.observation_space["floating_base_pose"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(7,), dtype=np.float32 |
|||
) |
|||
self.observation_space["floating_base_vel"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(6,), dtype=np.float32 |
|||
) |
|||
self.observation_space["floating_base_acc"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(6,), dtype=np.float32 |
|||
) |
|||
self.observation_space["body_q"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(self.body_dofs,), dtype=np.float32 |
|||
) |
|||
self.observation_space["body_dq"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(self.body_dofs,), dtype=np.float32 |
|||
) |
|||
self.observation_space["body_ddq"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(self.body_dofs,), dtype=np.float32 |
|||
) |
|||
self.observation_space["body_tau_est"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(self.body_nu,), dtype=np.float32 |
|||
) |
|||
self.observation_space["left_hand_q"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(self.gripper_dofs // 2,), dtype=np.float32 |
|||
) |
|||
self.observation_space["left_hand_dq"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(self.gripper_dofs // 2,), dtype=np.float32 |
|||
) |
|||
self.observation_space["left_hand_ddq"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(self.gripper_dofs // 2,), dtype=np.float32 |
|||
) |
|||
self.observation_space["left_hand_tau_est"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(self.gripper_nu // 2,), dtype=np.float32 |
|||
) |
|||
self.observation_space["right_hand_q"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(self.gripper_dofs // 2,), dtype=np.float32 |
|||
) |
|||
self.observation_space["right_hand_dq"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(self.gripper_dofs // 2,), dtype=np.float32 |
|||
) |
|||
self.observation_space["right_hand_ddq"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(self.gripper_dofs // 2,), dtype=np.float32 |
|||
) |
|||
self.observation_space["right_hand_tau_est"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(self.gripper_nu // 2,), dtype=np.float32 |
|||
) |
|||
|
|||
self.observation_space["language.language_instruction"] = spaces.Text( |
|||
max_length=256, charset=ALLOWED_LANGUAGE_CHARSET |
|||
) |
|||
|
|||
# Add camera observation spaces |
|||
for camera_name, w, h in zip(self.camera_names, self.camera_widths, self.camera_heights): |
|||
k = self.camera_key_mapper.get_camera_config(camera_name)[0] |
|||
self.observation_space[f"{k}_image"] = spaces.Box( |
|||
low=0, high=255, shape=(h, w, 3), dtype=np.uint8 |
|||
) |
|||
|
|||
# Add extra privileged observation spaces |
|||
if hasattr(self.env, "get_privileged_obs_keys"): |
|||
for key, shape in self.env.get_privileged_obs_keys().items(): |
|||
self.observation_space[key] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=shape, dtype=np.float32 |
|||
) |
|||
|
|||
# Add robot-specific observation spaces |
|||
if hasattr(self.env.robots[0].robot_model, "torso_body"): |
|||
self.observation_space["secondary_imu_quat"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(4,), dtype=np.float32 |
|||
) |
|||
self.observation_space["secondary_imu_vel"] = spaces.Box( |
|||
low=-np.inf, high=np.inf, shape=(6,), dtype=np.float32 |
|||
) |
|||
|
|||
def get_action_space(self): |
|||
self.action_space = spaces.Dict( |
|||
{"q": spaces.Box(low=-np.inf, high=np.inf, shape=(self.total_dofs,), dtype=np.float32)} |
|||
) |
|||
|
|||
def reset(self, seed=None, options=None): |
|||
raw_obs, info = super().reset(seed=seed, options=options) |
|||
obs = self.get_gr00t_observation(raw_obs) |
|||
|
|||
lang = self.env.get_ep_meta().get("lang", "") |
|||
ROBOSUITE_DEFAULT_LOGGER.info(f"Instruction: {lang}") |
|||
|
|||
return obs, info |
|||
|
|||
def step( |
|||
self, action: Dict[str, Any] |
|||
) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]: |
|||
# action={"q": xxx, "tau": xxx} |
|||
for k, v in action.items(): |
|||
self.verbose and print("<ACTION>", k, v) |
|||
|
|||
joint_actoin_vec = action["q"] |
|||
action_dict = {} |
|||
for ii, robot in enumerate(self.env.robots): |
|||
pf = robot.robot_model.naming_prefix |
|||
_action_dict = self.obs_action_converter[ii].gr00t_to_robocasa_action_dict( |
|||
joint_actoin_vec |
|||
) |
|||
action_dict.update({f"{pf}{k}": v for k, v in _action_dict.items()}) |
|||
if action.get("tau", None) is not None: |
|||
_torque_dict = self.obs_action_converter[ii].gr00t_to_robocasa_action_dict( |
|||
action["tau"] |
|||
) |
|||
action_dict.update({f"{pf}{k}_tau": v for k, v in _torque_dict.items()}) |
|||
if self.overridden_floating_base_action is not None: |
|||
action_dict["robot0_base"] = self.overridden_floating_base_action |
|||
raw_obs, reward, terminated, truncated, info = super().step(action_dict) |
|||
obs = self.get_gr00t_observation(raw_obs) |
|||
|
|||
for k, v in obs.items(): |
|||
self.verbose and print("<OBS>", k, v.shape if k.startswith("video.") else v) |
|||
self.verbose = False |
|||
|
|||
return obs, reward, terminated, truncated, info |
|||
|
|||
def step_only_kinematics( |
|||
self, action: Dict[str, Any] |
|||
) -> Tuple[Dict[str, Any], float, bool, bool, Dict[str, Any]]: |
|||
joint_actoin_vec = action["q"] |
|||
for ii, robot in enumerate(self.env.robots): |
|||
joint_names = np.array(self.env.sim.model.joint_names)[robot._ref_joint_indexes] |
|||
body_q = self.obs_action_converter[ii].gr00t_to_robocasa_joint_order( |
|||
joint_names, joint_actoin_vec |
|||
) |
|||
self.env.sim.data.qpos[robot._ref_joint_pos_indexes] = body_q |
|||
|
|||
for side in ["left", "right"]: |
|||
joint_names = np.array(self.env.sim.model.joint_names)[ |
|||
robot._ref_joints_indexes_dict[side + "_gripper"] |
|||
] |
|||
gripper_q = self.obs_action_converter[ii].gr00t_to_robocasa_joint_order( |
|||
joint_names, joint_actoin_vec |
|||
) |
|||
self.env.sim.data.qpos[robot._ref_gripper_joint_pos_indexes[side]] = gripper_q |
|||
|
|||
mujoco.mj_forward(self.env.sim.model._model, self.env.sim.data._data) |
|||
|
|||
obs = self.force_update_observation() |
|||
return obs, 0, False, False, {"success": False} |
|||
|
|||
def force_update_observation(self, timestep=0): |
|||
raw_obs = self.env._get_observations(force_update=True, timestep=timestep) |
|||
obs = self.get_basic_observation(raw_obs) |
|||
obs = self.get_gr00t_observation(obs) |
|||
return obs |
|||
|
|||
def get_basic_observation(self, raw_obs): |
|||
# this function takes a lot of time, so we disable it for now |
|||
# raw_obs.update(gather_robot_observations(self.env, format_gripper_space=False)) |
|||
|
|||
# Image are in (H, W, C), flip it upside down |
|||
def process_img(img): |
|||
return np.copy(img[::-1, :, :]) |
|||
|
|||
for obs_name, obs_value in raw_obs.items(): |
|||
if obs_name.endswith("_image"): |
|||
# image observations |
|||
raw_obs[obs_name] = process_img(obs_value) |
|||
else: |
|||
# non-image observations |
|||
raw_obs[obs_name] = obs_value.astype(np.float32) |
|||
|
|||
# Return black image if rendering is disabled |
|||
if not self.enable_render: |
|||
for ii, name in enumerate(self.camera_names): |
|||
raw_obs[f"{name}_image"] = np.zeros( |
|||
(self.camera_heights[ii], self.camera_widths[ii], 3), dtype=np.uint8 |
|||
) |
|||
|
|||
self.render_cache = raw_obs[self.render_obs_key] |
|||
raw_obs["language"] = self.env.get_ep_meta().get("lang", "") |
|||
|
|||
return raw_obs |
|||
|
|||
def convert_body_q(self, q: np.ndarray) -> np.ndarray: |
|||
# q is in the order of the joints |
|||
robot = self.env.robots[0] |
|||
joint_names = np.array(self.env.sim.model.joint_names)[robot._ref_joint_indexes] |
|||
# this joint names are in the order of the obs_vec |
|||
actuated_q = self.obs_action_converter[0].robocasa_to_gr00t_actuated_order( |
|||
joint_names, q, "body" |
|||
) |
|||
return actuated_q |
|||
|
|||
def convert_gripper_q(self, q: np.ndarray, side: str = "left") -> np.ndarray: |
|||
# q is in the order of the joints |
|||
robot = self.env.robots[0] |
|||
joint_names = np.array(self.env.sim.model.joint_names)[ |
|||
robot._ref_joints_indexes_dict[side + "_gripper"] |
|||
] |
|||
actuated_q = self.obs_action_converter[0].robocasa_to_gr00t_actuated_order( |
|||
joint_names, q, side + "_gripper" |
|||
) |
|||
return actuated_q |
|||
|
|||
def convert_gripper_tau(self, tau: np.ndarray, side: str = "left") -> np.ndarray: |
|||
# tau is in the order of the actuators |
|||
robot = self.env.robots[0] |
|||
actuator_idx = robot._ref_actuators_indexes_dict[side + "_gripper"] |
|||
actuated_joint_names = [ |
|||
self.env.sim.model.joint_id2name(self.env.sim.model.actuator_trnid[i][0]) |
|||
for i in actuator_idx |
|||
] |
|||
actuated_tau = self.obs_action_converter[0].robocasa_to_gr00t_actuated_order( |
|||
actuated_joint_names, tau, side + "_gripper" |
|||
) |
|||
return actuated_tau |
|||
|
|||
def get_gr00t_observation(self, raw_obs: Dict[str, Any]) -> Dict[str, Any]: |
|||
obs = {} |
|||
|
|||
if self.env.sim.model.jnt_type[0] == mujoco.mjtJoint.mjJNT_FREE: |
|||
# If the first joint is a free joint, use this way to get the floating base data |
|||
obs["floating_base_pose"] = self.env.sim.data.qpos[:7] |
|||
obs["floating_base_vel"] = self.env.sim.data.qvel[:6] |
|||
obs["floating_base_acc"] = self.env.sim.data.qacc[:6] |
|||
else: |
|||
# Otherwise, use self.env.sim.model to fetch the floating base pose |
|||
root_body_id = self.env.sim.model.body_name2id("robot0_base") |
|||
|
|||
# Get position and orientation from body state |
|||
root_pos = self.env.sim.data.body_xpos[root_body_id] |
|||
root_quat = self.env.sim.data.body_xquat[root_body_id] # quaternion in wxyz format |
|||
|
|||
# Combine position and quaternion to form 7-DOF pose |
|||
obs["floating_base_pose"] = np.concatenate([root_pos, root_quat]) |
|||
# set vel and acc to 0 |
|||
obs["floating_base_vel"] = np.zeros(6) |
|||
obs["floating_base_acc"] = np.zeros(6) |
|||
|
|||
obs["body_q"] = self.convert_body_q(raw_obs["robot0_joint_pos"]) |
|||
obs["body_dq"] = self.convert_body_q(raw_obs["robot0_joint_vel"]) |
|||
obs["body_ddq"] = self.convert_body_q(raw_obs["robot0_joint_acc"]) |
|||
|
|||
obs["left_hand_q"] = self.convert_gripper_q(raw_obs["robot0_left_gripper_qpos"], "left") |
|||
obs["left_hand_dq"] = self.convert_gripper_q(raw_obs["robot0_left_gripper_qvel"], "left") |
|||
obs["left_hand_ddq"] = self.convert_gripper_q(raw_obs["robot0_left_gripper_qacc"], "left") |
|||
obs["right_hand_q"] = self.convert_gripper_q(raw_obs["robot0_right_gripper_qpos"], "right") |
|||
obs["right_hand_dq"] = self.convert_gripper_q(raw_obs["robot0_right_gripper_qvel"], "right") |
|||
obs["right_hand_ddq"] = self.convert_gripper_q( |
|||
raw_obs["robot0_right_gripper_qacc"], "right" |
|||
) |
|||
|
|||
robot = self.env.robots[0] |
|||
body_tau_idx_list = [] |
|||
left_gripper_tau_idx_list = [] |
|||
right_gripper_tau_idx_list = [] |
|||
for part_name, actuator_idx in robot._ref_actuators_indexes_dict.items(): |
|||
if "left_gripper" in part_name: |
|||
left_gripper_tau_idx_list.extend(actuator_idx) |
|||
elif "right_gripper" in part_name: |
|||
right_gripper_tau_idx_list.extend(actuator_idx) |
|||
elif "base" in part_name: |
|||
assert ( |
|||
len(actuator_idx) == 0 or robot.robot_model.default_base == "FloatingLeggedBase" |
|||
) |
|||
else: |
|||
body_tau_idx_list.extend(actuator_idx) |
|||
|
|||
body_tau_idx_list = sorted(body_tau_idx_list) |
|||
left_gripper_tau_idx_list = sorted(left_gripper_tau_idx_list) |
|||
right_gripper_tau_idx_list = sorted(right_gripper_tau_idx_list) |
|||
obs["body_tau_est"] = self.convert_body_q( |
|||
self.env.sim.data.actuator_force[body_tau_idx_list] |
|||
) |
|||
obs["right_hand_tau_est"] = self.convert_gripper_tau( |
|||
self.env.sim.data.actuator_force[right_gripper_tau_idx_list], "right" |
|||
) |
|||
obs["left_hand_tau_est"] = self.convert_gripper_tau( |
|||
self.env.sim.data.actuator_force[left_gripper_tau_idx_list], "left" |
|||
) |
|||
|
|||
obs["time"] = self.env.sim.data.time |
|||
|
|||
# Add camera images |
|||
for ii, camera_name in enumerate(self.camera_names): |
|||
mapped_camera_name = self.camera_key_mapper.get_camera_config(camera_name)[0] |
|||
obs[f"{mapped_camera_name}_image"] = raw_obs[f"{camera_name}_image"] |
|||
|
|||
# Add privileged observations |
|||
if hasattr(self.env, "get_privileged_obs_keys"): |
|||
for key in self.env.get_privileged_obs_keys(): |
|||
obs[key] = raw_obs[key] |
|||
|
|||
# Add robot-specific observations |
|||
if hasattr(self.env.robots[0].robot_model, "torso_body"): |
|||
obs["secondary_imu_quat"] = raw_obs["robot0_torso_link_imu_quat"] |
|||
obs["secondary_imu_vel"] = raw_obs["robot0_torso_link_imu_vel"] |
|||
|
|||
obs["language.language_instruction"] = raw_obs["language"] |
|||
|
|||
return obs |
|||
@ -0,0 +1,301 @@ |
|||
from dataclasses import dataclass |
|||
from typing import Any, Dict, List, Tuple |
|||
|
|||
import numpy as np |
|||
from robocasa.models.robots import remove_mimic_joints |
|||
from robosuite.models.robots import RobotModel as RobosuiteRobotModel |
|||
|
|||
from decoupled_wbc.control.robot_model import RobotModel |
|||
|
|||
|
|||
class Gr00tJointInfo: |
|||
""" |
|||
Mapping from decoupled_wbc actuated joint names to robocasa joint names. |
|||
""" |
|||
|
|||
def __init__(self, robot_model: RobosuiteRobotModel): |
|||
self.robocasa_body_prefix = "robot0_" |
|||
self.robocasa_gripper_prefix = "gripper0_" |
|||
|
|||
self.robot_model: RobotModel = robot_model |
|||
self.body_actuated_joint_names: List[str] = ( |
|||
self.robot_model.supplemental_info.body_actuated_joints |
|||
) |
|||
self.left_hand_actuated_joint_names: List[str] = ( |
|||
self.robot_model.supplemental_info.left_hand_actuated_joints |
|||
) |
|||
self.right_hand_actuated_joint_names: List[str] = ( |
|||
self.robot_model.supplemental_info.right_hand_actuated_joints |
|||
) |
|||
|
|||
self.actuated_joint_names: List[str] = self._get_gr00t_actuated_joint_names() |
|||
self.body_actuated_joint_to_index: Dict[str, int] = ( |
|||
self._get_gr00t_body_actuated_joint_name_to_index() |
|||
) |
|||
self.gripper_actuated_joint_to_index: Tuple[Dict[str, int], Dict[str, int]] = ( |
|||
self._get_gr00t_gripper_actuated_joint_name_to_index() |
|||
) |
|||
self.actuated_joint_name_to_index: Dict[str, int] = ( |
|||
self._get_gr00t_actuated_joint_name_to_index() |
|||
) |
|||
|
|||
def _get_gr00t_actuated_joint_names(self) -> List[str]: |
|||
"""Get list of gr00t actuated joint names ordered by their indices.""" |
|||
if self.robot_model.supplemental_info is None: |
|||
raise ValueError("Robot model must have supplemental_info") |
|||
|
|||
# Get joint names and indices |
|||
body_names = self.robot_model.supplemental_info.body_actuated_joints |
|||
left_hand_names = self.robot_model.supplemental_info.left_hand_actuated_joints |
|||
right_hand_names = self.robot_model.supplemental_info.right_hand_actuated_joints |
|||
|
|||
body_indices = self.robot_model.get_joint_group_indices("body") |
|||
left_hand_indices = self.robot_model.get_joint_group_indices("left_hand") |
|||
right_hand_indices = self.robot_model.get_joint_group_indices("right_hand") |
|||
|
|||
# Create a dictionary mapping index to name |
|||
index_to_name = {} |
|||
for name, idx in zip(body_names, body_indices): |
|||
index_to_name[idx] = self.robocasa_body_prefix + name |
|||
for name, idx in zip(left_hand_names, left_hand_indices): |
|||
index_to_name[idx] = self.robocasa_gripper_prefix + "left_" + name |
|||
for name, idx in zip(right_hand_names, right_hand_indices): |
|||
index_to_name[idx] = self.robocasa_gripper_prefix + "right_" + name |
|||
sorted_indices = sorted(index_to_name.keys()) |
|||
all_actuated_joint_names = [index_to_name[idx] for idx in sorted_indices] |
|||
return all_actuated_joint_names |
|||
|
|||
def _get_gr00t_body_actuated_joint_name_to_index(self) -> Dict[str, int]: |
|||
"""Get dictionary mapping gr00t actuated joint names to indices.""" |
|||
if self.robot_model.supplemental_info is None: |
|||
raise ValueError("Robot model must have supplemental_info") |
|||
body_names = self.robot_model.supplemental_info.body_actuated_joints |
|||
body_indices = self.robot_model.get_joint_group_indices("body") |
|||
sorted_indices = np.argsort(body_indices) |
|||
sorted_names = [body_names[i] for i in sorted_indices] |
|||
return {self.robocasa_body_prefix + name: ii for ii, name in enumerate(sorted_names)} |
|||
|
|||
def _get_gr00t_gripper_actuated_joint_name_to_index( |
|||
self, |
|||
) -> Tuple[Dict[str, int], Dict[str, int]]: |
|||
"""Get dictionary mapping gr00t actuated joint names to indices.""" |
|||
if self.robot_model.supplemental_info is None: |
|||
raise ValueError("Robot model must have supplemental_info") |
|||
left_hand_names = self.robot_model.supplemental_info.left_hand_actuated_joints |
|||
right_hand_names = self.robot_model.supplemental_info.right_hand_actuated_joints |
|||
left_hand_indices = self.robot_model.get_joint_group_indices("left_hand") |
|||
right_hand_indices = self.robot_model.get_joint_group_indices("right_hand") |
|||
sorted_left_hand_indices = np.argsort(left_hand_indices) |
|||
sorted_right_hand_indices = np.argsort(right_hand_indices) |
|||
sorted_left_hand_names = [left_hand_names[i] for i in sorted_left_hand_indices] |
|||
sorted_right_hand_names = [right_hand_names[i] for i in sorted_right_hand_indices] |
|||
return ( |
|||
{ |
|||
self.robocasa_gripper_prefix + "left_" + name: ii |
|||
for ii, name in enumerate(sorted_left_hand_names) |
|||
}, |
|||
{ |
|||
self.robocasa_gripper_prefix + "right_" + name: ii |
|||
for ii, name in enumerate(sorted_right_hand_names) |
|||
}, |
|||
) |
|||
|
|||
def _get_gr00t_actuated_joint_name_to_index(self) -> Dict[str, int]: |
|||
"""Get dictionary mapping gr00t actuated joint names to indices.""" |
|||
return {name: ii for ii, name in enumerate(self.actuated_joint_names)} |
|||
|
|||
|
|||
@dataclass |
|||
class Gr00tObsActionConverter: |
|||
""" |
|||
Converter to align simulation environment joint action space with real environment joint action space. |
|||
Handles joint order and range conversion. |
|||
""" |
|||
|
|||
robot_model: RobotModel |
|||
robosuite_robot_model: RobosuiteRobotModel |
|||
robocasa_body_prefix: str = "robot0_" |
|||
robocasa_gripper_prefix: str = "gripper0_" |
|||
|
|||
def __post_init__(self): |
|||
"""Initialize converter with robot configuration.""" |
|||
|
|||
self.robot_key = self.robot_model.supplemental_info.name |
|||
self.gr00t_joint_info = Gr00tJointInfo(self.robot_model) |
|||
self.robocasa_joint_names_for_each_part: Dict[str, List[str]] = ( |
|||
self._get_robocasa_joint_names_for_each_part() |
|||
) |
|||
self.robocasa_actuator_names_for_each_part: Dict[str, List[str]] = ( |
|||
self._get_robotcasa_actuator_names_for_each_part() |
|||
) |
|||
|
|||
# Store mappings directly as class attributes |
|||
self.gr00t_joint_name_to_index = self.gr00t_joint_info.actuated_joint_name_to_index |
|||
self.gr00t_body_joint_name_to_index = self.gr00t_joint_info.body_actuated_joint_to_index |
|||
self.gr00t_gripper_joint_name_to_index = { |
|||
"left": self.gr00t_joint_info.gripper_actuated_joint_to_index[0], |
|||
"right": self.gr00t_joint_info.gripper_actuated_joint_to_index[1], |
|||
} |
|||
self.gr00t_to_robocasa_actuator_indices = self._get_actuator_mapping() |
|||
|
|||
if self.robot_key == "GR1_Fourier": |
|||
self.joint_multiplier = ( |
|||
lambda x: np.array([-1, 1, 1, -1, -1, -1, -1, -1, -1, -1, -1]) * x |
|||
) |
|||
self.actuator_multiplier = ( |
|||
lambda x: np.array([-1, -1, -1, -1, -1, -1, -1, -1, 1, 1, -1]) * x |
|||
) |
|||
else: |
|||
self.joint_multiplier = lambda x: x |
|||
self.actuator_multiplier = lambda x: x |
|||
|
|||
# Store DOF counts directly |
|||
self.body_dof = len(self.gr00t_joint_info.body_actuated_joint_names) |
|||
self.gripper_dof = len(self.gr00t_joint_info.left_hand_actuated_joint_names) + len( |
|||
self.gr00t_joint_info.right_hand_actuated_joint_names |
|||
) |
|||
self.whole_dof = self.body_dof + self.gripper_dof |
|||
self.body_nu = len(self.gr00t_joint_info.body_actuated_joint_names) |
|||
self.gripper_nu = len(self.gr00t_joint_info.left_hand_actuated_joint_names) + len( |
|||
self.gr00t_joint_info.right_hand_actuated_joint_names |
|||
) |
|||
self.whole_nu = self.body_nu + self.gripper_nu |
|||
|
|||
def _get_robocasa_joint_names_for_each_part(self) -> Dict[str, List[str]]: |
|||
part_names = self.robosuite_robot_model._ref_joints_indexes_dict.keys() |
|||
robocasa_joint_names_for_each_part = {} |
|||
for part_name in part_names: |
|||
joint_indices = self.robosuite_robot_model._ref_joints_indexes_dict[part_name] |
|||
joint_names = [ |
|||
self.robosuite_robot_model.sim.model.joint_id2name(j) for j in joint_indices |
|||
] |
|||
robocasa_joint_names_for_each_part[part_name] = joint_names |
|||
return robocasa_joint_names_for_each_part |
|||
|
|||
def _get_robotcasa_actuator_names_for_each_part(self) -> Dict[str, List[str]]: |
|||
part_names = self.robosuite_robot_model._ref_actuators_indexes_dict.keys() |
|||
robocasa_actuator_names_for_each_part = {} |
|||
for part_name in part_names: |
|||
if part_name == "base": |
|||
continue |
|||
actuator_indices = self.robosuite_robot_model._ref_actuators_indexes_dict[part_name] |
|||
actuator_names = [ |
|||
self.robosuite_robot_model.sim.model.actuator_id2name(j) for j in actuator_indices |
|||
] |
|||
robocasa_actuator_names_for_each_part[part_name] = actuator_names |
|||
return robocasa_actuator_names_for_each_part |
|||
|
|||
def _get_actuator_mapping(self) -> Dict[str, List[int]]: |
|||
"""Get mapping from decoupled_wbc actuatored joint order to robocasa actuatored joint order for whole body.""" |
|||
return { |
|||
part_name: [ |
|||
self.gr00t_joint_info.actuated_joint_name_to_index[j] |
|||
for j in self.robocasa_actuator_names_for_each_part[part_name] |
|||
] |
|||
for part_name in self.robocasa_actuator_names_for_each_part.keys() |
|||
} |
|||
|
|||
def check_action_dim_match(self, vec_dim: int) -> bool: |
|||
""" |
|||
Check if input vector dimension matches expected dimension. |
|||
|
|||
Args: |
|||
vec_dim: Dimension of input vector |
|||
|
|||
Returns: |
|||
bool: True if dimensions match |
|||
""" |
|||
return vec_dim == self.whole_dof |
|||
|
|||
def gr00t_to_robocasa_action_dict(self, action_vec: np.ndarray) -> Dict[str, Any]: |
|||
""" |
|||
Convert gr00t flat action vector to robocasa dictionary mapping part names to actions. |
|||
|
|||
Args: |
|||
robot: Robocasa robot model instance |
|||
action_vec: Full action vector array in gr00t actuated joint order |
|||
|
|||
Returns: |
|||
dict: Mapping from part names to action vectors for robocasa |
|||
""" |
|||
if not self.check_action_dim_match(len(action_vec)): |
|||
raise ValueError( |
|||
f"Action vector dimension mismatch: {len(action_vec)} != {self.whole_dof}" |
|||
) |
|||
|
|||
action_dict = {} |
|||
cc = self.robosuite_robot_model.composite_controller |
|||
|
|||
for part_name, controller in cc.part_controllers.items(): |
|||
if "gripper" in part_name: |
|||
robocasa_action = action_vec[self.gr00t_to_robocasa_actuator_indices[part_name]] |
|||
if self.actuator_multiplier is not None: |
|||
robocasa_action = self.actuator_multiplier(robocasa_action) |
|||
action_dict[part_name] = remove_mimic_joints( |
|||
cc.grippers[part_name], robocasa_action |
|||
) |
|||
elif "base" in part_name: |
|||
assert ( |
|||
len(self.gr00t_to_robocasa_actuator_indices.get(part_name, [])) == 0 |
|||
or self.robosuite_robot_model.default_base == "FloatingLeggedBase" |
|||
) |
|||
else: |
|||
action_dict[part_name] = action_vec[ |
|||
self.gr00t_to_robocasa_actuator_indices[part_name] |
|||
] |
|||
|
|||
return action_dict |
|||
|
|||
def robocasa_to_gr00t_actuated_order( |
|||
self, joint_names: List[str], q: np.ndarray, obs_type: str = "body" |
|||
) -> np.ndarray: |
|||
""" |
|||
Convert observation from robocasa joint order to gr00t actuated joint order. |
|||
|
|||
Args: |
|||
joint_names: List of joint names in robocasa order (with prefixes) |
|||
q: Joint positions corresponding to joint_names |
|||
obs_type: Type of observation ("body", "left_gripper", "right_gripper", or "whole") |
|||
|
|||
Returns: |
|||
Joint positions in gr00t actuated joint order |
|||
""" |
|||
assert len(joint_names) == len(q), "Joint names and q must have the same length" |
|||
|
|||
if obs_type == "body": |
|||
actuated_q = np.zeros(self.body_dof) |
|||
for i, jn in enumerate(joint_names): |
|||
actuated_q[self.gr00t_body_joint_name_to_index[jn]] = q[i] |
|||
elif obs_type == "left_gripper": |
|||
actuated_q = np.zeros(self.gripper_dof // 2) |
|||
for i, jn in enumerate(joint_names): |
|||
actuated_q[self.gr00t_gripper_joint_name_to_index["left"][jn]] = q[i] |
|||
elif obs_type == "right_gripper": |
|||
actuated_q = np.zeros(self.gripper_dof // 2) |
|||
for i, jn in enumerate(joint_names): |
|||
actuated_q[self.gr00t_gripper_joint_name_to_index["right"][jn]] = q[i] |
|||
elif obs_type == "whole": |
|||
actuated_q = np.zeros(self.whole_dof) |
|||
for i, jn in enumerate(joint_names): |
|||
actuated_q[self.gr00t_joint_name_to_index[jn]] = q[i] |
|||
else: |
|||
raise ValueError(f"Unknown observation type: {obs_type}") |
|||
return actuated_q |
|||
|
|||
def gr00t_to_robocasa_joint_order( |
|||
self, joint_names: List[str], q_in_actuated_order: np.ndarray |
|||
) -> np.ndarray: |
|||
""" |
|||
Convert gr00t actuated joint order to robocasa joint order. |
|||
|
|||
Args: |
|||
joint_names: List of joint names in robocasa order (with prefixes) |
|||
q_in_actuated_order: Joint positions corresponding to joint_names in gr00t actuated joint order |
|||
|
|||
Returns: |
|||
Joint positions in robocasa joint order |
|||
""" |
|||
q = np.zeros(len(joint_names)) |
|||
for i, jn in enumerate(joint_names): |
|||
q[i] = q_in_actuated_order[self.gr00t_joint_name_to_index[jn]] |
|||
return q |
|||
@ -0,0 +1,483 @@ |
|||
from dataclasses import dataclass |
|||
import os |
|||
from pathlib import Path |
|||
from typing import Literal, Optional |
|||
|
|||
import yaml |
|||
|
|||
import decoupled_wbc |
|||
from decoupled_wbc.control.main.config_template import ArgsConfig as ArgsConfigTemplate |
|||
from decoupled_wbc.control.policy.wbc_policy_factory import WBC_VERSIONS |
|||
from decoupled_wbc.control.utils.network_utils import resolve_interface |
|||
|
|||
|
|||
def override_wbc_config( |
|||
wbc_config: dict, config: "BaseConfig", missed_keys_only: bool = False |
|||
) -> dict: |
|||
"""Override WBC YAML values with dataclass values. |
|||
|
|||
Args: |
|||
wbc_config: The loaded WBC YAML configuration dictionary |
|||
config: The BaseConfig dataclass instance with override values |
|||
missed_keys_only: If True, only add keys that don't exist in wbc_config. |
|||
If False, validate all keys exist and override all. |
|||
|
|||
Returns: |
|||
Updated wbc_config dictionary with overridden values |
|||
|
|||
Raises: |
|||
KeyError: If any required keys are missing from the WBC YAML configuration |
|||
(only when missed_keys_only=False) |
|||
""" |
|||
# Override yaml values with dataclass values |
|||
key_to_value = { |
|||
"INTERFACE": config.interface, |
|||
"ENV_TYPE": config.env_type, |
|||
"VERSION": config.wbc_version, |
|||
"SIMULATOR": config.simulator, |
|||
"SIMULATE_DT": 1 / float(config.sim_frequency), |
|||
"ENABLE_OFFSCREEN": config.enable_offscreen, |
|||
"ENABLE_ONSCREEN": config.enable_onscreen, |
|||
"model_path": config.wbc_model_path, |
|||
"enable_waist": config.enable_waist, |
|||
"with_hands": config.with_hands, |
|||
"verbose": config.verbose, |
|||
"verbose_timing": config.verbose_timing, |
|||
"upper_body_max_joint_speed": config.upper_body_joint_speed, |
|||
"keyboard_dispatcher_type": config.keyboard_dispatcher_type, |
|||
"enable_gravity_compensation": config.enable_gravity_compensation, |
|||
"gravity_compensation_joints": config.gravity_compensation_joints, |
|||
"high_elbow_pose": config.high_elbow_pose, |
|||
} |
|||
|
|||
if missed_keys_only: |
|||
# Only add keys that don't exist in wbc_config |
|||
for key in key_to_value: |
|||
if key not in wbc_config: |
|||
wbc_config[key] = key_to_value[key] |
|||
else: |
|||
# Set all keys (overwrite existing) |
|||
for key in key_to_value: |
|||
wbc_config[key] = key_to_value[key] |
|||
|
|||
# g1 kp, kd, sim2real gap |
|||
if config.env_type == "real": |
|||
# update waist pitch damping, index 14 |
|||
wbc_config["MOTOR_KD"][14] = wbc_config["MOTOR_KD"][14] - 10 |
|||
|
|||
return wbc_config |
|||
|
|||
|
|||
@dataclass |
|||
class BaseConfig(ArgsConfigTemplate): |
|||
"""Base config inherited by all G1 control loops""" |
|||
|
|||
# WBC Configuration |
|||
wbc_version: Literal[tuple(WBC_VERSIONS)] = "gear_wbc" |
|||
"""Version of the whole body controller.""" |
|||
|
|||
wbc_model_path: str = ( |
|||
"policy/GR00T-WholeBodyControl-Balance.onnx," "policy/GR00T-WholeBodyControl-Walk.onnx" |
|||
) |
|||
"""Path to WBC model file (relative to decoupled_wbc/sim2mujoco/resources/robots/g1)""" |
|||
"""gear_wbc model path: policy/GR00T-WholeBodyControl-Balance.onnx,policy/GR00T-WholeBodyControl-Walk.onnx""" |
|||
|
|||
wbc_policy_class: str = "G1DecoupledWholeBodyPolicy" |
|||
"""Whole body policy class.""" |
|||
|
|||
# System Configuration |
|||
interface: str = "sim" |
|||
"""Interface to use for the control loop. [sim, real, lo, enxe8ea6a9c4e09]""" |
|||
|
|||
simulator: str = "mujoco" |
|||
"""Simulator to use.""" |
|||
|
|||
sim_sync_mode: bool = False |
|||
"""Whether to run the control loop in sync mode.""" |
|||
|
|||
control_frequency: int = 50 |
|||
"""Frequency of the control loop.""" |
|||
|
|||
sim_frequency: int = 200 |
|||
"""Frequency of the simulation loop.""" |
|||
|
|||
# Robot Configuration |
|||
enable_waist: bool = False |
|||
"""Whether to include waist joints in IK.""" |
|||
|
|||
with_hands: bool = True |
|||
"""Enable hand functionality. When False, robot operates without hands.""" |
|||
|
|||
high_elbow_pose: bool = False |
|||
"""Enable high elbow pose configuration for default joint positions.""" |
|||
|
|||
verbose: bool = True |
|||
"""Whether to print verbose output.""" |
|||
|
|||
# Additional common fields |
|||
enable_offscreen: bool = False |
|||
"""Whether to enable offscreen rendering.""" |
|||
|
|||
enable_onscreen: bool = True |
|||
"""Whether to enable onscreen rendering.""" |
|||
|
|||
upper_body_joint_speed: float = 1000 |
|||
"""Upper body joint speed.""" |
|||
|
|||
env_name: str = "default" |
|||
"""Environment name.""" |
|||
|
|||
ik_indicator: bool = False |
|||
"""Whether to draw IK indicators.""" |
|||
|
|||
verbose_timing: bool = False |
|||
"""Enable verbose timing output every iteration.""" |
|||
|
|||
keyboard_dispatcher_type: str = "raw" |
|||
"""Keyboard dispatcher to use. [raw, ros]""" |
|||
|
|||
# Gravity Compensation Configuration |
|||
enable_gravity_compensation: bool = False |
|||
"""Enable gravity compensation using pinocchio dynamics.""" |
|||
|
|||
gravity_compensation_joints: Optional[list[str]] = None |
|||
"""Joint groups to apply gravity compensation to (e.g., ['arms', 'left_arm', 'right_arm']).""" |
|||
# Teleop/Device Configuration |
|||
body_control_device: str = "dummy" |
|||
"""Device to use for body control. Options: dummy, vive, iphone, leapmotion, joycon.""" |
|||
|
|||
hand_control_device: Optional[str] = "dummy" |
|||
"""Device to use for hand control. Options: None, manus, joycon, iphone.""" |
|||
|
|||
body_streamer_ip: str = "10.112.210.229" |
|||
"""IP address for body streamer (vive only).""" |
|||
|
|||
body_streamer_keyword: str = "knee" |
|||
"""Body streamer keyword (vive only).""" |
|||
|
|||
enable_visualization: bool = False |
|||
"""Whether to enable visualization.""" |
|||
|
|||
enable_real_device: bool = True |
|||
"""Whether to enable real device.""" |
|||
|
|||
teleop_frequency: int = 20 |
|||
"""Teleoperation frequency (Hz).""" |
|||
|
|||
teleop_replay_path: Optional[str] = None |
|||
"""Path to teleop replay data.""" |
|||
|
|||
# Deployment/Camera Configuration |
|||
robot_ip: str = "192.168.123.164" |
|||
"""Robot IP address""" |
|||
# Data collection settings |
|||
data_collection: bool = True |
|||
"""Enable data collection""" |
|||
|
|||
data_collection_frequency: int = 20 |
|||
"""Data collection frequency (Hz)""" |
|||
|
|||
root_output_dir: str = "outputs" |
|||
"""Root output directory""" |
|||
|
|||
# Policy settings |
|||
enable_upper_body_operation: bool = True |
|||
"""Enable upper body operation""" |
|||
|
|||
upper_body_operation_mode: Literal["teleop", "inference"] = "teleop" |
|||
"""Upper body operation mode""" |
|||
|
|||
def __post_init__(self): |
|||
# Resolve interface (handles sim/real shortcuts, platform differences, and error handling) |
|||
self.interface, self.env_type = resolve_interface(self.interface) |
|||
|
|||
def load_wbc_yaml(self) -> dict: |
|||
"""Load and merge wbc yaml with dataclass overrides""" |
|||
# Get the base path to decoupled_wbc and convert to Path object |
|||
package_path = Path(os.path.dirname(decoupled_wbc.__file__)) |
|||
|
|||
if self.wbc_version == "gear_wbc": |
|||
config_path = str(package_path / "control/main/teleop/configs/g1_29dof_gear_wbc.yaml") |
|||
else: |
|||
raise ValueError( |
|||
f"Invalid wbc_version: {self.wbc_version}, please use one of: " f"gear_wbc" |
|||
) |
|||
|
|||
with open(config_path) as file: |
|||
wbc_config = yaml.load(file, Loader=yaml.FullLoader) |
|||
|
|||
# Override yaml values with dataclass values |
|||
wbc_config = override_wbc_config(wbc_config, self) |
|||
|
|||
return wbc_config |
|||
|
|||
|
|||
@dataclass |
|||
class ControlLoopConfig(BaseConfig): |
|||
"""Config for running the G1 control loop.""" |
|||
|
|||
pass |
|||
|
|||
|
|||
@dataclass |
|||
class TeleopConfig(BaseConfig): |
|||
"""Config for running the G1 teleop policy loop.""" |
|||
|
|||
robot: Literal["g1"] = "g1" |
|||
"""Name of the robot to use, e.g., 'g1'.""" |
|||
|
|||
lerobot_replay_path: Optional[str] = None |
|||
"""Path to lerobot replay data.""" |
|||
|
|||
# Override defaults for teleop-specific values |
|||
body_streamer_ip: str = "10.110.67.24" |
|||
"""IP address for body streamer (vive only).""" |
|||
|
|||
body_streamer_keyword: str = "foot" |
|||
"""Keyword for body streamer (vive only).""" |
|||
|
|||
teleop_frequency: float = 20 # Override to be float instead of int |
|||
"""Frequency of the teleop loop.""" |
|||
|
|||
binary_hand_ik: bool = True |
|||
"""Whether to use binary IK.""" |
|||
|
|||
|
|||
@dataclass |
|||
class ComposedCameraClientConfig: |
|||
"""Config for running the composed camera client.""" |
|||
|
|||
camera_port: int = 5555 |
|||
"""Port number""" |
|||
|
|||
camera_host: str = "localhost" |
|||
"""Host IP address""" |
|||
|
|||
fps: float = 20.0 |
|||
"""FPS of the camera viewer""" |
|||
|
|||
|
|||
@dataclass |
|||
class DataExporterConfig(BaseConfig, ComposedCameraClientConfig): |
|||
"""Config for running the G1 data exporter.""" |
|||
|
|||
dataset_name: Optional[str] = None |
|||
"""Name of the dataset to save the data to. If the dataset already exists, |
|||
the new episodes will be appended to existing dataset. If the dataset does not exist, |
|||
episodes will be saved under root_output_dir/dataset_name. |
|||
""" |
|||
|
|||
task_prompt: str = "demo" |
|||
"""Language Task prompt for the dataset.""" |
|||
|
|||
state_dim: int = 43 |
|||
"""Size of the state.""" |
|||
|
|||
action_dim: int = 43 |
|||
"""Size of the action.""" |
|||
|
|||
teleoperator_username: Optional[str] = None |
|||
"""Teleoperator username.""" |
|||
|
|||
support_operator_username: Optional[str] = None |
|||
"""Support operator username.""" |
|||
|
|||
robot_id: Optional[str] = None |
|||
"""Robot ID.""" |
|||
|
|||
lower_body_policy: Optional[str] = None |
|||
"""Lower body policy.""" |
|||
|
|||
img_stream_viewer: bool = False |
|||
"""Whether to open a matplot lib window to view the camera images.""" |
|||
|
|||
text_to_speech: bool = True |
|||
"""Whether to use text-to-speech for voice feedback.""" |
|||
|
|||
add_stereo_camera: bool = True |
|||
"""Whether to add stereo camera for data collection. If False, only use a signle ego view camera.""" |
|||
|
|||
|
|||
@dataclass |
|||
class SyncSimDataCollectionConfig(ControlLoopConfig, TeleopConfig): |
|||
"""Args Config for running the data collection loop.""" |
|||
|
|||
robot: str = "G1" |
|||
"""Name of the robot to collect data for (e.g., G1 variants).""" |
|||
|
|||
task_name: str = "GroundOnly" |
|||
"""Name of the task to collect data for. [PnPBottle, GroundOnly, ...]""" |
|||
|
|||
body_control_device: str = "dummy" |
|||
"""Device to use for body control. Options: dummy, vive, iphone, leapmotion, joycon.""" |
|||
|
|||
hand_control_device: Optional[str] = "dummy" |
|||
"""Device to use for hand control. Options: None, manus, joycon, iphone.""" |
|||
|
|||
remove_existing_dir: bool = False |
|||
"""Whether to remove existing output directory if it exists.""" |
|||
|
|||
hardcode_teleop_cmd: bool = False |
|||
"""Whether to hardcode the teleop command for testing purposes.""" |
|||
|
|||
ik_indicator: bool = False |
|||
"""Whether to draw IK indicators.""" |
|||
|
|||
enable_onscreen: bool = True |
|||
"""Whether to show the onscreen rendering.""" |
|||
|
|||
save_img_obs: bool = False |
|||
"""Whether to save image observations.""" |
|||
|
|||
success_hold_steps: int = 50 |
|||
"""Number of steps to collect after task completion before saving.""" |
|||
|
|||
renderer: Literal["mjviewer", "mujoco", "rerun"] = "mjviewer" |
|||
"""Renderer to use for the environment. """ |
|||
|
|||
replay_data_path: str | None = None |
|||
"""Path to the data (.pkl) to replay. If None, will not replay. Used for CI/CD.""" |
|||
|
|||
replay_speed: float = 2.5 |
|||
"""Speed multiplier for replay data. Higher values make replay slower (e.g., 2.5 for sync sim tests).""" |
|||
|
|||
ci_test: bool = False |
|||
"""Whether to run the CI test.""" |
|||
|
|||
ci_test_mode: Literal["unit", "pre_merge"] = "pre_merge" |
|||
"""'unit' for fast 50-step tests, 'pre_merge' for 500-step test with tracking checks.""" |
|||
|
|||
manual_control: bool = False |
|||
"""Enable manual control of data collection start/save. When True, use toggle_data_collection |
|||
to manually control episode states (idle -> recording -> need_to_save -> idle). |
|||
When False (default), automatically starts and stops data collection based on task completion.""" |
|||
|
|||
|
|||
@dataclass |
|||
class SyncSimPlaybackConfig(SyncSimDataCollectionConfig): |
|||
"""Configuration class for playback script arguments.""" |
|||
|
|||
enable_real_device: bool = False |
|||
"""Whether to enable real device""" |
|||
|
|||
dataset: str | None = None |
|||
"""Path to the demonstration dataset, either an HDF5 file or a LeRobot folder path.""" |
|||
|
|||
use_actions: bool = False |
|||
"""Whether to use actions for playback""" |
|||
|
|||
use_wbc_goals: bool = False |
|||
"""Whether to use WBC goals for control""" |
|||
|
|||
use_teleop_cmd: bool = False |
|||
"""Whether to use teleop IK for action generation""" |
|||
|
|||
# Video recording arguments. |
|||
# Warning: enabling this key will leads to divergence between playback and recording. |
|||
save_video: bool = False |
|||
"""Whether to save video of the playback""" |
|||
|
|||
# Saving to LeRobot dataset. |
|||
# Warning: enabling this key will leads to divergence between playback and recording. |
|||
save_lerobot: bool = False |
|||
"""Whether to save the playback as a new LeRobot dataset""" |
|||
|
|||
video_path: str | None = None |
|||
"""Path to save the output video. If not specified, |
|||
will use the nearest folder to dataset and save as playback_video.mp4""" |
|||
|
|||
num_episodes: int = 1 |
|||
"""Number of episodes to load and playback/record (loads only the first N episodes from dataset)""" |
|||
|
|||
intervention: bool = False |
|||
"""Whether to denote intervention timesteps with colored borders in video frames""" |
|||
|
|||
ci_test: bool = False |
|||
"""Whether this is a CI test run, which limits the number of steps for testing purposes""" |
|||
|
|||
def validate_args(self): |
|||
# Validate argument combinations |
|||
if self.use_teleop_cmd and not self.use_actions: |
|||
raise ValueError("--use-teleop-cmd requires --use-actions to be set") |
|||
|
|||
# Note: using teleop cmd has playback divergence unlike using wbc goals, as TeleopPolicy has a warmup loop |
|||
if self.use_teleop_cmd and self.use_wbc_goals: |
|||
raise ValueError("--use-teleop-cmd and --use-wbc-goals are mutually exclusive") |
|||
|
|||
if (self.use_teleop_cmd or self.use_wbc_goals) and not self.use_actions: |
|||
raise ValueError( |
|||
"You are using --use-teleop-cmd or --use-wbc-goals but not --use-actions. " |
|||
"This will not play back actions whether via teleop or wbc goals. " |
|||
"Instead, it'll play back states only." |
|||
) |
|||
|
|||
if self.save_img_obs and not self.save_lerobot: |
|||
raise ValueError("--save-img-obs is only supported with --save-lerobot") |
|||
|
|||
if self.intervention and not self.save_video: |
|||
raise ValueError("--intervention requires --save-video to be enabled for visualization") |
|||
|
|||
|
|||
@dataclass |
|||
class WebcamRecorderConfig(BaseConfig): |
|||
"""Config for running the webcam recorder.""" |
|||
|
|||
output_dir: str = "logs_experiment" |
|||
"""Output directory for webcam recordings""" |
|||
|
|||
device_id: int = 0 |
|||
"""Camera device ID""" |
|||
|
|||
fps: int = 30 |
|||
"""Recording frame rate""" |
|||
|
|||
duration: Optional[int] = None |
|||
"""Recording duration in seconds (None for continuous)""" |
|||
|
|||
|
|||
@dataclass |
|||
class SimLoopConfig(BaseConfig): |
|||
"""Config for running the simulation loop.""" |
|||
|
|||
mp_start_method: str = "spawn" |
|||
"""Multiprocessing start method""" |
|||
|
|||
enable_image_publish: bool = False |
|||
"""Enable image publishing in simulation""" |
|||
|
|||
camera_port: int = 5555 |
|||
"""Camera port for image publishing""" |
|||
|
|||
verbose: bool = False |
|||
"""Verbose output, override the base config verbose""" |
|||
|
|||
|
|||
@dataclass |
|||
class DeploymentConfig(BaseConfig, ComposedCameraClientConfig): |
|||
"""G1 Robot Deployment Configuration |
|||
|
|||
Simplified deployment config that inherits all common fields from G1BaseConfig. |
|||
All deployment settings are now available in the base config. |
|||
""" |
|||
|
|||
camera_publish_rate: float = 30.0 |
|||
"""Camera publish rate (Hz)""" |
|||
|
|||
view_camera: bool = True |
|||
"""Enable camera viewer""" |
|||
# Webcam recording settings |
|||
enable_webcam_recording: bool = True |
|||
"""Enable webcam recording for real robot deployment monitoring""" |
|||
|
|||
webcam_output_dir: str = "logs_experiment" |
|||
"""Output directory for webcam recordings""" |
|||
|
|||
skip_img_transform: bool = False |
|||
"""Skip image transformation in the model (for faster internet)""" |
|||
|
|||
sim_in_single_process: bool = False |
|||
"""Run simulator in a separate process. When True, sets simulator to None in main control loop |
|||
and launches run_sim_loop.py separately.""" |
|||
|
|||
image_publish: bool = False |
|||
"""Enable image publishing in simulation loop (passed to run_sim_loop.py)""" |
|||
@ -0,0 +1,421 @@ |
|||
GEAR_WBC_CONFIG: "decoupled_wbc/sim2mujoco/resources/robots/g1/g1_gear_wbc.yaml" |
|||
|
|||
# copy from g1_43dof_hist.yaml |
|||
ROBOT_TYPE: 'g1_29dof' # Robot name, "go2", "b2", "b2w", "h1", "go2w", "g1" |
|||
ROBOT_SCENE: "decoupled_wbc/control/robot_model/model_data/g1/scene_43dof.xml" # Robot scene, for Sim2Sim |
|||
# ROBOT_SCENE: "decoupled_wbc/control/robot_model/model_data/g1/scene_29dof_activated3dex.xml" |
|||
|
|||
DOMAIN_ID: 0 # Domain id |
|||
# Network Interface, "lo" for simulation and the one with "192.168.123.222" for real robot |
|||
# INTERFACE: "enxe8ea6a9c4e09" |
|||
# INTERFACE: "enxc8a3623c9cb7" |
|||
INTERFACE: "lo" |
|||
SIMULATOR: "mujoco" # "robocasa" |
|||
|
|||
USE_JOYSTICK: 0 # Simulate Unitree WirelessController using a gamepad (0: disable, 1: enable) |
|||
JOYSTICK_TYPE: "xbox" # support "xbox" and "switch" gamepad layout |
|||
JOYSTICK_DEVICE: 0 # Joystick number |
|||
|
|||
FREE_BASE: False |
|||
|
|||
PRINT_SCENE_INFORMATION: True # Print link, joint and sensors information of robot |
|||
ENABLE_ELASTIC_BAND: True # Virtual spring band, used for lifting h1 |
|||
|
|||
SIMULATE_DT: 0.005 # Need to be larger than the runtime of viewer.sync() |
|||
VIEWER_DT: 0.02 # Viewer update time |
|||
REWARD_DT: 0.02 |
|||
USE_SENSOR: False |
|||
USE_HISTORY: True |
|||
USE_HISTORY_LOCO: True |
|||
USE_HISTORY_MIMIC: True |
|||
|
|||
GAIT_PERIOD: 0.9 # 1.25 |
|||
|
|||
MOTOR2JOINT: [0, 1, 2, 3, 4, 5, |
|||
6, 7, 8, 9, 10, 11, |
|||
12, 13, 14, |
|||
15, 16, 17, 18, 19, 20, 21, |
|||
22, 23, 24, 25, 26, 27, 28] |
|||
|
|||
JOINT2MOTOR: [0, 1, 2, 3, 4, 5, |
|||
6, 7, 8, 9, 10, 11, |
|||
12, 13, 14, |
|||
15, 16, 17, 18, 19, 20, 21, |
|||
22, 23, 24, 25, 26, 27, 28] |
|||
|
|||
|
|||
UNITREE_LEGGED_CONST: |
|||
HIGHLEVEL: 0xEE |
|||
LOWLEVEL: 0xFF |
|||
TRIGERLEVEL: 0xF0 |
|||
PosStopF: 2146000000.0 |
|||
VelStopF: 16000.0 |
|||
MODE_MACHINE: 5 |
|||
MODE_PR: 0 |
|||
|
|||
JOINT_KP: [ |
|||
100, 100, 100, 200, 20, 20, |
|||
100, 100, 100, 200, 20, 20, |
|||
400, 400, 400, |
|||
90, 60, 20, 60, 4, 4, 4, |
|||
90, 60, 20, 60, 4, 4, 4 |
|||
] |
|||
|
|||
|
|||
JOINT_KD: [ |
|||
2.5, 2.5, 2.5, 5, 0.2, 0.1, |
|||
2.5, 2.5, 2.5, 5, 0.2, 0.1, |
|||
5.0, 5.0, 5.0, |
|||
2.0, 1.0, 0.4, 1.0, 0.2, 0.2, 0.2, |
|||
2.0, 1.0, 0.4, 1.0, 0.2, 0.2, 0.2 |
|||
] |
|||
|
|||
# arm kp |
|||
# soft kp, safe, test it first |
|||
# 50, 50, 20, 20, 10, 10, 10 |
|||
# hard kp, use only if policy is safe |
|||
# 200, 200, 80, 80, 50, 50, 50, |
|||
|
|||
# MOTOR_KP: [ |
|||
# 100, 100, 100, 200, 20, 20, |
|||
# 100, 100, 100, 200, 20, 20, |
|||
# 400, 400, 400, |
|||
# 50, 50, 20, 20, 10, 10, 10, |
|||
# 50, 50, 20, 20, 10, 10, 10 |
|||
# ] |
|||
|
|||
MOTOR_KP: [ |
|||
150, 150, 150, 200, 40, 40, |
|||
150, 150, 150, 200, 40, 40, |
|||
250, 250, 250, |
|||
100, 100, 40, 40, 20, 20, 20, |
|||
100, 100, 40, 40, 20, 20, 20 |
|||
] |
|||
|
|||
MOTOR_KD: [ |
|||
2, 2, 2, 4, 2, 2, |
|||
2, 2, 2, 4, 2, 2, |
|||
5, 5, 5, |
|||
5, 5, 2, 2, 2, 2, 2, |
|||
5, 5, 2, 2, 2, 2, 2 |
|||
] |
|||
|
|||
# MOTOR_KP: [ |
|||
# 100, 100, 100, 200, 20, 20, |
|||
# 100, 100, 100, 200, 20, 20, |
|||
# 400, 400, 400, |
|||
# 90, 60, 20, 60, 4, 4, 4, |
|||
# 90, 60, 20, 60, 4, 4, 4 |
|||
# ] |
|||
|
|||
|
|||
# MOTOR_KD: [ |
|||
# 2.5, 2.5, 2.5, 5, 0.2, 0.1, |
|||
# 2.5, 2.5, 2.5, 5, 0.2, 0.1, |
|||
# 5.0, 5.0, 5.0, |
|||
# 2.0, 1.0, 0.4, 1.0, 0.2, 0.2, 0.2, |
|||
# 2.0, 1.0, 0.4, 1.0, 0.2, 0.2, 0.2 |
|||
# ] |
|||
|
|||
|
|||
WeakMotorJointIndex: |
|||
left_hip_yaw_joint: 0 |
|||
left_hip_roll_joint: 1 |
|||
left_hip_pitch_joint: 2 |
|||
left_knee_joint: 3 |
|||
left_ankle_pitch_joint: 4 |
|||
left_ankle_roll_joint: 5 |
|||
right_hip_yaw_joint: 6 |
|||
right_hip_roll_joint: 7 |
|||
right_hip_pitch_joint: 8 |
|||
right_knee_joint: 9 |
|||
right_ankle_pitch_joint: 10 |
|||
right_ankle_roll_joint: 11 |
|||
waist_yaw_joint : 12 |
|||
waist_roll_joint : 13 |
|||
waist_pitch_joint : 14 |
|||
left_shoulder_pitch_joint: 15 |
|||
left_shoulder_roll_joint: 16 |
|||
left_shoulder_yaw_joint: 17 |
|||
left_elbow_joint: 18 |
|||
left_wrist_roll_joint: 19 |
|||
left_wrist_pitch_joint: 20 |
|||
left_wrist_yaw_joint: 21 |
|||
right_shoulder_pitch_joint: 22 |
|||
right_shoulder_roll_joint: 23 |
|||
right_shoulder_yaw_joint: 24 |
|||
right_elbow_joint: 25 |
|||
right_wrist_roll_joint: 26 |
|||
right_wrist_pitch_joint: 27 |
|||
right_wrist_yaw_joint: 28 |
|||
|
|||
NUM_MOTORS: 29 |
|||
NUM_JOINTS: 29 |
|||
NUM_HAND_MOTORS: 7 |
|||
NUM_HAND_JOINTS: 7 |
|||
NUM_UPPER_BODY_JOINTS: 17 |
|||
|
|||
DEFAULT_DOF_ANGLES: [ |
|||
-0.1, # left_hip_pitch_joint |
|||
0.0, # left_hip_roll_joint |
|||
0.0, # left_hip_yaw_joint |
|||
0.3, # left_knee_joint |
|||
-0.2, # left_ankle_pitch_joint |
|||
0.0, # left_ankle_roll_joint |
|||
-0.1, # right_hip_pitch_joint |
|||
0.0, # right_hip_roll_joint |
|||
0.0, # right_hip_yaw_joint |
|||
0.3, # right_knee_joint |
|||
-0.2, # right_ankle_pitch_joint |
|||
0.0, # right_ankle_roll_joint |
|||
0.0, # waist_yaw_joint |
|||
0.0, # waist_roll_joint |
|||
0.0, # waist_pitch_joint |
|||
0.0, # left_shoulder_pitch_joint |
|||
0.0, # left_shoulder_roll_joint |
|||
0.0, # left_shoulder_yaw_joint |
|||
0.0, # left_elbow_joint |
|||
0.0, # left_wrist_roll_joint |
|||
0.0, # left_wrist_pitch_joint |
|||
0.0, # left_wrist_yaw_joint |
|||
0.0, # right_shoulder_pitch_joint |
|||
0.0, # right_shoulder_roll_joint |
|||
0.0, # right_shoulder_yaw_joint |
|||
0.0, # right_elbow_joint |
|||
0.0, # right_wrist_roll_joint |
|||
0.0, # right_wrist_pitch_joint |
|||
0.0 # right_wrist_yaw_joint |
|||
] |
|||
|
|||
DEFAULT_MOTOR_ANGLES: [ |
|||
-0.1, # left_hip_pitch_joint |
|||
0.0, # left_hip_roll_joint |
|||
0.0, # left_hip_yaw_joint |
|||
0.3, # left_knee_joint |
|||
-0.2, # left_ankle_pitch_joint |
|||
0.0, # left_ankle_roll_joint |
|||
-0.1, # right_hip_pitch_joint |
|||
0.0, # right_hip_roll_joint |
|||
0.0, # right_hip_yaw_joint |
|||
0.3, # right_knee_joint |
|||
-0.2, # right_ankle_pitch_joint |
|||
0.0, # right_ankle_roll_joint |
|||
0.0, # waist_yaw_joint |
|||
0.0, # waist_roll_joint |
|||
0.0, # waist_pitch_joint |
|||
0.0, # left_shoulder_pitch_joint |
|||
0.0, # left_shoulder_roll_joint |
|||
0.0, # left_shoulder_yaw_joint |
|||
0.0, # left_elbow_joint |
|||
0.0, # left_wrist_roll_joint |
|||
0.0, # left_wrist_pitch_joint |
|||
0.0, # left_wrist_yaw_joint |
|||
0.0, # right_shoulder_pitch_joint |
|||
0.0, # right_shoulder_roll_joint |
|||
0.0, # right_shoulder_yaw_joint |
|||
0.0, # right_elbow_joint |
|||
0.0, # right_wrist_roll_joint |
|||
0.0, # right_wrist_pitch_joint |
|||
0.0 # right_wrist_yaw_joint |
|||
] |
|||
|
|||
motor_pos_lower_limit_list: [-2.5307, -0.5236, -2.7576, -0.087267, -0.87267, -0.2618, |
|||
-2.5307, -2.9671, -2.7576, -0.087267, -0.87267, -0.2618, |
|||
-2.618, -0.52, -0.52, |
|||
-3.0892, -1.5882, -2.618, -1.0472, |
|||
-1.972222054, -1.61443, -1.61443, |
|||
-3.0892, -2.2515, -2.618, -1.0472, |
|||
-1.972222054, -1.61443, -1.61443] |
|||
motor_pos_upper_limit_list: [2.8798, 2.9671, 2.7576, 2.8798, 0.5236, 0.2618, |
|||
2.8798, 0.5236, 2.7576, 2.8798, 0.5236, 0.2618, |
|||
2.618, 0.52, 0.52, |
|||
2.6704, 2.2515, 2.618, 2.0944, |
|||
1.972222054, 1.61443, 1.61443, |
|||
2.6704, 1.5882, 2.618, 2.0944, |
|||
1.972222054, 1.61443, 1.61443] |
|||
motor_vel_limit_list: [32.0, 32.0, 32.0, 20.0, 37.0, 37.0, |
|||
32.0, 32.0, 32.0, 20.0, 37.0, 37.0, |
|||
32.0, 37.0, 37.0, |
|||
37.0, 37.0, 37.0, 37.0, |
|||
37.0, 22.0, 22.0, |
|||
37.0, 37.0, 37.0, 37.0, |
|||
37.0, 22.0, 22.0] |
|||
motor_effort_limit_list: [88.0, 88.0, 88.0, 139.0, 50.0, 50.0, |
|||
88.0, 88.0, 88.0, 139.0, 50.0, 50.0, |
|||
88.0, 50.0, 50.0, |
|||
25.0, 25.0, 25.0, 25.0, |
|||
25.0, 5.0, 5.0, |
|||
2.45, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7, |
|||
25.0, 25.0, 25.0, 25.0, |
|||
25.0, 5.0, 5.0, |
|||
2.45, 0.7, 0.7, 0.7, 0.7, 0.7, 0.7] |
|||
history_config: { |
|||
base_ang_vel: 4, |
|||
projected_gravity: 4, |
|||
command_lin_vel: 4, |
|||
command_ang_vel: 4, |
|||
command_base_height: 4, |
|||
command_stand: 4, |
|||
ref_upper_dof_pos: 4, |
|||
dof_pos: 4, |
|||
dof_vel: 4, |
|||
actions: 4, |
|||
# phase_time: 4, |
|||
ref_motion_phase: 4, |
|||
sin_phase: 4, |
|||
cos_phase: 4 |
|||
} |
|||
history_loco_config: { |
|||
base_ang_vel: 4, |
|||
projected_gravity: 4, |
|||
command_lin_vel: 4, |
|||
command_ang_vel: 4, |
|||
# command_base_height: 4, |
|||
command_stand: 4, |
|||
ref_upper_dof_pos: 4, |
|||
dof_pos: 4, |
|||
dof_vel: 4, |
|||
actions: 4, |
|||
# phase_time: 4, |
|||
sin_phase: 4, |
|||
cos_phase: 4 |
|||
} |
|||
history_loco_height_config: { |
|||
base_ang_vel: 4, |
|||
projected_gravity: 4, |
|||
command_lin_vel: 4, |
|||
command_ang_vel: 4, |
|||
command_base_height: 4, |
|||
command_stand: 4, |
|||
ref_upper_dof_pos: 4, |
|||
dof_pos: 4, |
|||
dof_vel: 4, |
|||
actions: 4, |
|||
# phase_time: 4, |
|||
sin_phase: 4, |
|||
cos_phase: 4 |
|||
} |
|||
history_mimic_config: { |
|||
base_ang_vel: 4, |
|||
projected_gravity: 4, |
|||
dof_pos: 4, |
|||
dof_vel: 4, |
|||
actions: 4, |
|||
ref_motion_phase: 4, |
|||
} |
|||
obs_dims: { |
|||
base_lin_vel: 3, |
|||
base_ang_vel: 3, |
|||
projected_gravity: 3, |
|||
command_lin_vel: 2, |
|||
command_ang_vel: 1, |
|||
command_stand: 1, |
|||
command_base_height: 1, |
|||
ref_upper_dof_pos: 17, # upper body actions |
|||
dof_pos: 29, |
|||
dof_vel: 29, |
|||
# actions: 12, # lower body actions |
|||
actions: 29, # full body actions |
|||
phase_time: 1, |
|||
ref_motion_phase: 1, # mimic motion phase |
|||
sin_phase: 1, |
|||
cos_phase: 1, |
|||
} |
|||
obs_loco_dims: { |
|||
base_lin_vel: 3, |
|||
base_ang_vel: 3, |
|||
projected_gravity: 3, |
|||
command_lin_vel: 2, |
|||
command_ang_vel: 1, |
|||
command_stand: 1, |
|||
command_base_height: 1, |
|||
ref_upper_dof_pos: 17, # upper body actions |
|||
dof_pos: 29, |
|||
dof_vel: 29, |
|||
actions: 12, # lower body actions |
|||
phase_time: 1, |
|||
sin_phase: 1, |
|||
cos_phase: 1, |
|||
} |
|||
obs_mimic_dims: { |
|||
base_lin_vel: 3, |
|||
base_ang_vel: 3, |
|||
projected_gravity: 3, |
|||
dof_pos: 29, |
|||
dof_vel: 29, |
|||
actions: 29, # full body actions |
|||
ref_motion_phase: 1, # mimic motion phase |
|||
} |
|||
obs_scales: { |
|||
base_lin_vel: 2.0, |
|||
base_ang_vel: 0.25, |
|||
projected_gravity: 1.0, |
|||
command_lin_vel: 1, |
|||
command_ang_vel: 1, |
|||
command_stand: 1, |
|||
command_base_height: 2, # Yuanhang: it's 2, not 1! |
|||
ref_upper_dof_pos: 1.0, |
|||
dof_pos: 1.0, |
|||
dof_vel: 0.05, |
|||
history: 1.0, |
|||
history_loco: 1.0, |
|||
history_mimic: 1.0, |
|||
actions: 1.0, |
|||
phase_time: 1.0, |
|||
ref_motion_phase: 1.0, |
|||
sin_phase: 1.0, |
|||
cos_phase: 1.0 |
|||
} |
|||
|
|||
loco_upper_body_dof_pos: [ |
|||
0.0, 0.0, 0.0, # waist |
|||
0.0, 0.3, 0.0, 1.0, # left shoulder and elbow |
|||
0.0, 0.0, 0.0, # left wrist |
|||
0.0, -0.3, 0.0, 1.0, # right shoulder and elbow |
|||
0.0, 0.0, 0.0 # right wrist |
|||
] |
|||
|
|||
robot_dofs: { |
|||
"g1_29dof": [1, 1, 1, 1, 1, 1, |
|||
1, 1, 1, 1, 1, 1, |
|||
1, 1, 1, |
|||
1, 1, 1, 1, 1, 1, 1, |
|||
1, 1, 1, 1, 1, 1, 1], |
|||
"g1_29dof_anneal_23dof": [1, 1, 1, 1, 1, 1, |
|||
1, 1, 1, 1, 1, 1, |
|||
1, 1, 1, |
|||
1, 1, 1, 1, 0, 0, 0, |
|||
1, 1, 1, 1, 0, 0, 0], |
|||
} |
|||
|
|||
mimic_robot_types: { |
|||
|
|||
"APT_level1": "g1_29dof_anneal_23dof", |
|||
} |
|||
|
|||
|
|||
|
|||
|
|||
# 01281657 |
|||
mimic_models: { |
|||
"APT_level1": "20250116_225127-TairanTestbed_G129dofANNEAL23dof_dm_APT_video_APT_level1_MinimalFriction-0.3_RfiTrue_Far0.325_RESUME_LARGENOISE-motion_tracking-g1_29dof_anneal_23dof/exported/model_176500.onnx", |
|||
|
|||
} |
|||
|
|||
|
|||
|
|||
start_upper_body_dof_pos: { |
|||
|
|||
"APT_level1": |
|||
[0.19964170455932617, 0.07710712403059006, -0.2882401943206787, |
|||
0.21672365069389343, 0.15629297494888306, -0.5167576670646667, 0.5782126784324646, |
|||
0.0, 0.0, 0.0, |
|||
0.25740593671798706, -0.2504104673862457, 0.22500675916671753, 0.5127624273300171, |
|||
0.0, 0.0, 0.0], |
|||
|
|||
} |
|||
|
|||
motion_length_s: { |
|||
"APT_level1": 7.66, |
|||
|
|||
} |
|||
@ -0,0 +1,627 @@ |
|||
""" |
|||
A convenience script to playback random demonstrations using the decoupled_wbc controller from |
|||
a set of demonstrations stored in a hdf5 file. |
|||
|
|||
Arguments: |
|||
--dataset (str): Path to demonstrations |
|||
--use-actions (optional): If this flag is provided, the actions are played back |
|||
through the MuJoCo simulator, instead of loading the simulator states |
|||
one by one. |
|||
--use-wbc-goals (optional): If set, will use the stored WBC goals to control the robot, |
|||
otherwise will use the actions directly. Only relevant if --use-actions is set. |
|||
--use-teleop-cmd (optional): If set, will use teleop IK directly with WBC timing |
|||
for action generation. Only relevant if --use-actions is set. |
|||
--visualize-gripper (optional): If set, will visualize the gripper site |
|||
--save-video (optional): If set, will save video of the playback using offscreen rendering |
|||
--video-path (optional): Path to save the output video. If not specified, will use the nearest |
|||
folder to dataset and save as playback_video.mp4 |
|||
--num-episodes (optional): Number of episodes to playback/record (if None, plays random episodes) |
|||
|
|||
Example: |
|||
$ python decoupled_wbc/control/main/teleop/playback_sync_sim_data.py --dataset output/robocasa_datasets/ |
|||
--use-actions --use-wbc-goals |
|||
|
|||
$ python decoupled_wbc/control/main/teleop/playback_sync_sim_data.py --dataset output/robocasa_datasets/ |
|||
--use-actions --use-teleop-cmd |
|||
|
|||
# Record video of the first 5 episodes using WBC goals |
|||
$ python decoupled_wbc/control/main/teleop/playback_sync_sim_data.py --dataset output/robocasa_datasets/ |
|||
--use-actions --use-wbc-goals --save-video --num-episodes 5 |
|||
""" |
|||
|
|||
import json |
|||
import os |
|||
from pathlib import Path |
|||
import time |
|||
from typing import Optional |
|||
|
|||
import cv2 |
|||
import numpy as np |
|||
import rclpy |
|||
from robosuite.environments.robot_env import RobotEnv |
|||
from tqdm import tqdm |
|||
import tyro |
|||
|
|||
from decoupled_wbc.control.main.teleop.configs.configs import SyncSimPlaybackConfig |
|||
from decoupled_wbc.control.robot_model.instantiation import get_robot_type_and_model |
|||
from decoupled_wbc.control.utils.sync_sim_utils import ( |
|||
generate_frame, |
|||
get_data_exporter, |
|||
get_env, |
|||
get_policies, |
|||
) |
|||
from decoupled_wbc.data.constants import RS_VIEW_CAMERA_HEIGHT, RS_VIEW_CAMERA_WIDTH |
|||
from decoupled_wbc.data.exporter import TypedLeRobotDataset |
|||
|
|||
CONTROL_NODE_NAME = "playback_node" |
|||
GREEN_BOLD = "\033[1;32m" |
|||
RED_BOLD = "\033[1;31m" |
|||
RESET = "\033[0m" |
|||
|
|||
|
|||
def load_lerobot_dataset(root_path, max_episodes=None): |
|||
task_name = None |
|||
episodes = [] |
|||
start_index = 0 |
|||
with open(Path(root_path) / "meta/episodes.jsonl", "r") as f: |
|||
for line in f: |
|||
episode = json.loads(line) |
|||
episode["start_index"] = start_index |
|||
start_index += episode["length"] |
|||
assert ( |
|||
task_name is None or task_name == episode["tasks"][0] |
|||
), "All episodes should have the same task name" |
|||
task_name = episode["tasks"][0] |
|||
episodes.append(episode) |
|||
|
|||
dataset = TypedLeRobotDataset( |
|||
repo_id="tmp/test", |
|||
root=root_path, |
|||
load_video=False, |
|||
) |
|||
|
|||
script_config = dataset.meta.info["script_config"] |
|||
|
|||
assert len(dataset) == start_index, "Dataset length does not match expected length" |
|||
|
|||
# Limit episodes if specified |
|||
if max_episodes is not None: |
|||
episodes = episodes[:max_episodes] |
|||
print( |
|||
f"Loading only first {len(episodes)} episodes (limited by max_episodes={max_episodes})" |
|||
) |
|||
|
|||
f = {} |
|||
seeds = [] |
|||
for ep in tqdm(range(len(episodes))): |
|||
seed = None |
|||
f[f"data/demo_{ep + 1}/states"] = [] |
|||
f[f"data/demo_{ep + 1}/actions"] = [] |
|||
f[f"data/demo_{ep + 1}/teleop_cmd"] = [] |
|||
f[f"data/demo_{ep + 1}/wbc_goal"] = [] |
|||
start_index = episodes[ep]["start_index"] |
|||
end_index = start_index + episodes[ep]["length"] |
|||
for i in tqdm(range(start_index, end_index)): |
|||
frame = dataset[i] |
|||
# load the seed |
|||
assert ( |
|||
seed is None or seed == np.array(frame["observation.sim.seed"]).item() |
|||
), "All observations in an episode should have the same seed" |
|||
seed = np.array(frame["observation.sim.seed"]).item() |
|||
# load the state |
|||
mujoco_state_len = frame["observation.sim.mujoco_state_len"] |
|||
mujoco_state = frame["observation.sim.mujoco_state"] |
|||
f[f"data/demo_{ep + 1}/states"].append(np.array(mujoco_state[:mujoco_state_len])) |
|||
# load the action |
|||
action = frame["action"] |
|||
f[f"data/demo_{ep + 1}/actions"].append(np.array(action)) |
|||
|
|||
# load the teleop command |
|||
teleop_cmd = { |
|||
"left_wrist": np.array(frame["observation.sim.left_wrist"].reshape(4, 4)), |
|||
"right_wrist": np.array(frame["observation.sim.right_wrist"].reshape(4, 4)), |
|||
"left_fingers": { |
|||
"position": np.array(frame["observation.sim.left_fingers"].reshape(25, 4, 4)), |
|||
}, |
|||
"right_fingers": { |
|||
"position": np.array(frame["observation.sim.right_fingers"].reshape(25, 4, 4)), |
|||
}, |
|||
"target_upper_body_pose": np.array(frame["observation.sim.target_upper_body_pose"]), |
|||
"base_height_command": np.array(frame["teleop.base_height_command"]), |
|||
"navigate_cmd": np.array(frame["teleop.navigate_command"]), |
|||
} |
|||
f[f"data/demo_{ep + 1}/teleop_cmd"].append(teleop_cmd) |
|||
# load the WBC goal |
|||
wbc_goal = { |
|||
"wrist_pose": np.array(frame["action.eef"]), |
|||
"target_upper_body_pose": np.array(frame["observation.sim.target_upper_body_pose"]), |
|||
"navigate_cmd": np.array(frame["teleop.navigate_command"]), |
|||
"base_height_command": np.array(frame["teleop.base_height_command"]), |
|||
} |
|||
f[f"data/demo_{ep + 1}/wbc_goal"].append(wbc_goal) |
|||
|
|||
seeds.append(seed) |
|||
|
|||
return seeds, f, script_config |
|||
|
|||
|
|||
def validate_state(recorded_state, playback_state, ep, step, tolerance=1e-5): |
|||
"""Validate that playback state matches recorded state within tolerance.""" |
|||
if not np.allclose(recorded_state, playback_state, atol=tolerance): |
|||
err = np.linalg.norm(recorded_state - playback_state) |
|||
print(f"[warning] state diverged by {err:.12f} for ep {ep} at step {step}") |
|||
return False |
|||
return True |
|||
|
|||
|
|||
def generate_and_save_frame( |
|||
config, sync_env, obs, wbc_action, seed, teleop_cmd, wbc_goal, gr00t_exporter |
|||
): |
|||
"""Generate and save a frame to LeRobot dataset if enabled.""" |
|||
if config.save_lerobot: |
|||
max_mujoco_state_len, mujoco_state_len, mujoco_state = sync_env.get_mujoco_state_info() |
|||
frame = generate_frame( |
|||
obs, |
|||
wbc_action, |
|||
seed, |
|||
mujoco_state, |
|||
mujoco_state_len, |
|||
max_mujoco_state_len, |
|||
teleop_cmd, |
|||
wbc_goal, |
|||
config.save_img_obs, |
|||
) |
|||
gr00t_exporter.add_frame(frame) |
|||
|
|||
|
|||
def playback_wbc_goals( |
|||
sync_env, |
|||
wbc_policy, |
|||
wbc_goals, |
|||
teleop_cmds, |
|||
states, |
|||
env, |
|||
onscreen, |
|||
config, |
|||
video_writer, |
|||
ep, |
|||
seed, |
|||
gr00t_exporter, |
|||
end_steps, |
|||
): |
|||
"""Playback using WBC goals to control the robot.""" |
|||
ret = True |
|||
num_wbc_goals = len(wbc_goals) if end_steps == -1 else min(end_steps, len(wbc_goals)) |
|||
|
|||
for jj in range(num_wbc_goals): |
|||
wbc_goal = wbc_goals[jj] |
|||
obs = sync_env.observe() |
|||
wbc_policy.set_observation(obs) |
|||
wbc_policy.set_goal(wbc_goal) |
|||
wbc_action = wbc_policy.get_action() |
|||
sync_env.queue_action(wbc_action) |
|||
|
|||
# Save frame if needed |
|||
if config.save_lerobot: |
|||
teleop_cmd = teleop_cmds[jj] |
|||
generate_and_save_frame( |
|||
config, sync_env, obs, wbc_action, seed, teleop_cmd, wbc_goal, gr00t_exporter |
|||
) |
|||
|
|||
capture_or_render_frame(env, onscreen, config, video_writer) |
|||
|
|||
if jj < len(states) - 1: |
|||
state_playback = env.sim.get_state().flatten() |
|||
if not validate_state(states[jj + 1], state_playback, ep, jj): |
|||
ret = False |
|||
|
|||
return ret |
|||
|
|||
|
|||
def playback_teleop_cmd( |
|||
sync_env, |
|||
wbc_policy, |
|||
teleop_policy, |
|||
wbc_goals, |
|||
teleop_cmds, |
|||
states, |
|||
env, |
|||
onscreen, |
|||
config, |
|||
video_writer, |
|||
ep, |
|||
seed, |
|||
gr00t_exporter, |
|||
end_steps, |
|||
): |
|||
"""Playback using teleop commands to control the robot.""" |
|||
ret = True |
|||
num_steps = len(wbc_goals) if end_steps == -1 else min(end_steps, len(wbc_goals)) |
|||
|
|||
for jj in range(num_steps): |
|||
wbc_goal = wbc_goals[jj] |
|||
teleop_cmd = teleop_cmds[jj] |
|||
|
|||
# Set IK goal from teleop command |
|||
ik_data = { |
|||
"body_data": { |
|||
teleop_policy.retargeting_ik.body.supplemental_info.hand_frame_names[ |
|||
"left" |
|||
]: teleop_cmd["left_wrist"], |
|||
teleop_policy.retargeting_ik.body.supplemental_info.hand_frame_names[ |
|||
"right" |
|||
]: teleop_cmd["right_wrist"], |
|||
}, |
|||
"left_hand_data": teleop_cmd["left_fingers"], |
|||
"right_hand_data": teleop_cmd["right_fingers"], |
|||
} |
|||
teleop_policy.retargeting_ik.set_goal(ik_data) |
|||
|
|||
# Store original and get new upper body pose |
|||
target_upper_body_pose = wbc_goal["target_upper_body_pose"].copy() |
|||
wbc_goal["target_upper_body_pose"] = teleop_policy.retargeting_ik.get_action() |
|||
|
|||
# Execute WBC policy |
|||
obs = sync_env.observe() |
|||
wbc_policy.set_observation(obs) |
|||
wbc_policy.set_goal(wbc_goal) |
|||
wbc_action = wbc_policy.get_action() |
|||
sync_env.queue_action(wbc_action) |
|||
|
|||
# Save frame if needed |
|||
generate_and_save_frame( |
|||
config, sync_env, obs, wbc_action, seed, teleop_cmd, wbc_goal, gr00t_exporter |
|||
) |
|||
|
|||
# Render or capture frame |
|||
capture_or_render_frame(env, onscreen, config, video_writer) |
|||
|
|||
# Validate states |
|||
if jj < len(states) - 1: |
|||
if not np.allclose( |
|||
target_upper_body_pose, wbc_goal["target_upper_body_pose"], atol=1e-5 |
|||
): |
|||
err = np.linalg.norm(target_upper_body_pose - wbc_goal["target_upper_body_pose"]) |
|||
print( |
|||
f"[warning] target_upper_body_pose diverged by {err:.12f} for ep {ep} at step {jj}" |
|||
) |
|||
ret = False |
|||
|
|||
state_playback = env.sim.get_state().flatten() |
|||
if not validate_state(states[jj + 1], state_playback, ep, jj): |
|||
ret = False |
|||
|
|||
return ret |
|||
|
|||
|
|||
def playback_actions( |
|||
sync_env, |
|||
actions, |
|||
teleop_cmds, |
|||
wbc_goals, |
|||
states, |
|||
env, |
|||
onscreen, |
|||
config, |
|||
video_writer, |
|||
ep, |
|||
seed, |
|||
gr00t_exporter, |
|||
end_steps, |
|||
): |
|||
"""Playback using actions directly.""" |
|||
ret = True |
|||
num_actions = len(actions) if end_steps == -1 else min(end_steps, len(actions)) |
|||
|
|||
for j in range(num_actions): |
|||
sync_env.queue_action({"q": actions[j]}) |
|||
|
|||
# Save frame if needed |
|||
if config.save_lerobot: |
|||
obs = sync_env.observe() |
|||
teleop_cmd = teleop_cmds[j] |
|||
wbc_goal = wbc_goals[j] |
|||
wbc_action = {"q": actions[j]} |
|||
generate_and_save_frame( |
|||
config, sync_env, obs, wbc_action, seed, teleop_cmd, wbc_goal, gr00t_exporter |
|||
) |
|||
|
|||
capture_or_render_frame(env, onscreen, config, video_writer) |
|||
|
|||
if j < len(states) - 1: |
|||
state_playback = env.sim.get_state().flatten() |
|||
if not validate_state(states[j + 1], state_playback, ep, j): |
|||
ret = False |
|||
|
|||
return ret |
|||
|
|||
|
|||
def playback_states( |
|||
sync_env, |
|||
states, |
|||
actions, |
|||
teleop_cmds, |
|||
wbc_goals, |
|||
env, |
|||
onscreen, |
|||
config, |
|||
video_writer, |
|||
seed, |
|||
gr00t_exporter, |
|||
end_steps, |
|||
ep, |
|||
): |
|||
"""Playback by forcing mujoco states directly.""" |
|||
ret = True |
|||
num_states = len(states) if end_steps == -1 else min(end_steps, len(states)) |
|||
|
|||
for i in range(num_states): |
|||
sync_env.reset_to({"states": states[i]}) |
|||
sync_env.render() |
|||
|
|||
# Validate that the state was set correctly |
|||
if i < len(states): |
|||
state_playback = env.sim.get_state().flatten() |
|||
if not validate_state(states[i], state_playback, ep, i): |
|||
ret = False |
|||
|
|||
# Save frame if needed |
|||
if config.save_lerobot: |
|||
obs = sync_env.observe() |
|||
teleop_cmd = teleop_cmds[i] |
|||
wbc_goal = wbc_goals[i] |
|||
wbc_action = {"q": actions[i]} |
|||
generate_and_save_frame( |
|||
config, sync_env, obs, wbc_action, seed, teleop_cmd, wbc_goal, gr00t_exporter |
|||
) |
|||
|
|||
capture_or_render_frame(env, onscreen, config, video_writer) |
|||
|
|||
return ret |
|||
|
|||
|
|||
def main(config: SyncSimPlaybackConfig): |
|||
ret = True |
|||
start_time = time.time() |
|||
|
|||
np.set_printoptions(precision=5, suppress=True, linewidth=120) |
|||
|
|||
assert config.dataset is not None, "Folder must be specified for playback" |
|||
|
|||
seeds, f, script_config = load_lerobot_dataset(config.dataset) |
|||
|
|||
config.update( |
|||
script_config, |
|||
allowed_keys=[ |
|||
"wbc_version", |
|||
"wbc_model_path", |
|||
"wbc_policy_class", |
|||
"control_frequency", |
|||
"enable_waist", |
|||
"with_hands", |
|||
"env_name", |
|||
"robot", |
|||
"task_name", |
|||
"teleop_frequency", |
|||
"data_collection_frequency", |
|||
"enable_gravity_compensation", |
|||
"gravity_compensation_joints", |
|||
], |
|||
) |
|||
config.validate_args() |
|||
|
|||
robot_type, robot_model = get_robot_type_and_model(config.robot, config.enable_waist) |
|||
|
|||
# Setup rendering |
|||
if config.save_video or config.save_img_obs: |
|||
onscreen = False |
|||
offscreen = True |
|||
else: |
|||
onscreen = True |
|||
offscreen = False |
|||
|
|||
# Set default video path if not specified |
|||
if config.save_video and config.video_path is None: |
|||
if os.path.isfile(config.dataset): |
|||
video_folder = Path(config.dataset).parent |
|||
else: |
|||
video_folder = Path(config.dataset) |
|||
video_folder.mkdir(parents=True, exist_ok=True) |
|||
config.video_path = str(video_folder / "playback_video.mp4") |
|||
print(f"Video recording enabled. Output: {config.video_path}") |
|||
|
|||
sync_env = get_env(config, onscreen=onscreen, offscreen=offscreen) |
|||
|
|||
gr00t_exporter = None |
|||
if config.save_lerobot: |
|||
obs = sync_env.observe() |
|||
gr00t_exporter = get_data_exporter(config, obs, robot_model) |
|||
|
|||
# Initialize policies |
|||
wbc_policy, teleop_policy = get_policies( |
|||
config, robot_type, robot_model, activate_keyboard_listener=False |
|||
) |
|||
|
|||
# List of all demonstrations episodes |
|||
demos = [f"demo_{i + 1}" for i in range(len(seeds))] |
|||
print(f"Loaded and will playback {len(demos)} episodes") |
|||
env = sync_env.base_env |
|||
|
|||
# Setup video writer |
|||
video_writer = None |
|||
fourcc = None |
|||
if config.save_video: |
|||
fourcc = cv2.VideoWriter_fourcc(*"mp4v") |
|||
video_writer = cv2.VideoWriter( |
|||
config.video_path, fourcc, 20, (RS_VIEW_CAMERA_WIDTH, RS_VIEW_CAMERA_HEIGHT) |
|||
) |
|||
|
|||
print("Loaded {} episodes from {}".format(len(demos), config.dataset)) |
|||
print("seeds:", seeds) |
|||
print("demos:", demos, "\n\n") |
|||
|
|||
# Handle episode selection - either limited number or infinite random |
|||
max_episodes = len(demos) |
|||
episode_count = 0 |
|||
while True: |
|||
if episode_count >= max_episodes: |
|||
break |
|||
ep = demos[episode_count] |
|||
print(f"Playing back episode: {ep}") |
|||
episode_count += 1 |
|||
|
|||
# read the model xml, using the metadata stored in the attribute for this episode |
|||
seed = seeds[int(ep.split("_")[-1]) - 1] |
|||
sync_env.reset(seed=seed) |
|||
|
|||
# load the actions and states |
|||
states = f["data/{}/states".format(ep)] |
|||
actions = f["data/{}/actions".format(ep)] |
|||
teleop_cmds = f["data/{}/teleop_cmd".format(ep)] |
|||
wbc_goals = f["data/{}/wbc_goal".format(ep)] |
|||
|
|||
# reset the policies |
|||
wbc_policy, teleop_policy, _ = get_policies( |
|||
config, robot_type, robot_model, activate_keyboard_listener=False |
|||
) |
|||
end_steps = 20 if config.ci_test else -1 |
|||
|
|||
if config.use_actions: |
|||
# load the initial state |
|||
sync_env.reset_to({"states": states[0]}) |
|||
# load the actions and play them back open-loop |
|||
if config.use_wbc_goals: |
|||
# use the wbc_goals to control the robot |
|||
episode_ret = playback_wbc_goals( |
|||
sync_env, |
|||
wbc_policy, |
|||
wbc_goals, |
|||
teleop_cmds, |
|||
states, |
|||
env, |
|||
onscreen, |
|||
config, |
|||
video_writer, |
|||
ep, |
|||
seed, |
|||
gr00t_exporter, |
|||
end_steps, |
|||
) |
|||
ret = ret and episode_ret |
|||
elif config.use_teleop_cmd: |
|||
# use the teleop commands to control the robot |
|||
episode_ret = playback_teleop_cmd( |
|||
sync_env, |
|||
wbc_policy, |
|||
teleop_policy, |
|||
wbc_goals, |
|||
teleop_cmds, |
|||
states, |
|||
env, |
|||
onscreen, |
|||
config, |
|||
video_writer, |
|||
ep, |
|||
seed, |
|||
gr00t_exporter, |
|||
end_steps, |
|||
) |
|||
ret = ret and episode_ret |
|||
else: |
|||
episode_ret = playback_actions( |
|||
sync_env, |
|||
actions, |
|||
teleop_cmds, |
|||
wbc_goals, |
|||
states, |
|||
env, |
|||
onscreen, |
|||
config, |
|||
video_writer, |
|||
ep, |
|||
seed, |
|||
gr00t_exporter, |
|||
end_steps, |
|||
) |
|||
ret = ret and episode_ret |
|||
else: |
|||
# force the sequence of internal mujoco states one by one |
|||
episode_ret = playback_states( |
|||
sync_env, |
|||
states, |
|||
actions, |
|||
teleop_cmds, |
|||
wbc_goals, |
|||
env, |
|||
onscreen, |
|||
config, |
|||
video_writer, |
|||
seed, |
|||
gr00t_exporter, |
|||
end_steps, |
|||
ep, |
|||
) |
|||
ret = ret and episode_ret |
|||
|
|||
if config.save_lerobot: |
|||
gr00t_exporter.save_episode() |
|||
|
|||
print(f"Episode {ep} playback finished.\n\n") |
|||
|
|||
# close the env |
|||
sync_env.close() |
|||
|
|||
# Cleanup |
|||
if video_writer is not None: |
|||
video_writer.release() |
|||
print(f"Video saved to: {config.video_path}") |
|||
|
|||
end_time = time.time() |
|||
elapsed_time = end_time - start_time |
|||
|
|||
if config.save_lerobot: |
|||
print(f"LeRobot dataset saved to: {gr00t_exporter.root}") |
|||
|
|||
print( |
|||
f"{GREEN_BOLD}Playback with WBC version: {config.wbc_version}, {config.wbc_model_path}, " |
|||
f"{config.wbc_policy_class}, use_actions: {config.use_actions}, use_wbc_goals: {config.use_wbc_goals}, " |
|||
f"use_teleop_cmd: {config.use_teleop_cmd}{RESET}" |
|||
) |
|||
if ret: |
|||
print(f"{GREEN_BOLD}Playback completed successfully in {elapsed_time:.2f} seconds!{RESET}") |
|||
else: |
|||
print(f"{RED_BOLD}Playback encountered an error in {elapsed_time:.2f} seconds!{RESET}") |
|||
|
|||
return ret |
|||
|
|||
|
|||
def capture_or_render_frame( |
|||
env: RobotEnv, |
|||
onscreen: bool, |
|||
config: SyncSimPlaybackConfig, |
|||
video_writer: Optional[cv2.VideoWriter], |
|||
): |
|||
"""Capture frame for video recording if enabled, or render the environment.""" |
|||
if config.save_video: |
|||
if hasattr(env, "sim") and hasattr(env.sim, "render"): |
|||
img = env.sim.render( |
|||
width=RS_VIEW_CAMERA_WIDTH, |
|||
height=RS_VIEW_CAMERA_HEIGHT, |
|||
camera_name=env.render_camera[0], |
|||
) |
|||
img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) |
|||
img_bgr = np.flipud(img_bgr) |
|||
video_writer.write(img_bgr) |
|||
elif onscreen: |
|||
env.render() |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
config = tyro.cli(SyncSimPlaybackConfig) |
|||
|
|||
rclpy.init(args=None) |
|||
node = rclpy.create_node("playback_decoupled_wbc_control") |
|||
|
|||
main(config) |
|||
|
|||
rclpy.shutdown() |
|||
@ -0,0 +1,249 @@ |
|||
""" |
|||
Camera viewer with manual recording support. |
|||
|
|||
This script provides a camera viewer that can display multiple camera streams |
|||
and record them to video files with manual start/stop controls. |
|||
|
|||
Features: |
|||
- Onscreen mode: Display camera feeds with optional recording |
|||
- Offscreen mode: No display, recording only when triggered |
|||
- Manual recording control with keyboard (R key to start/stop) |
|||
|
|||
Usage Examples: |
|||
|
|||
1. Basic onscreen viewing (with recording capability): |
|||
python run_camera_viewer.py --camera-host localhost --camera-port 5555 |
|||
|
|||
2. Offscreen mode (no display, recording only): |
|||
python run_camera_viewer.py --offscreen --camera-host localhost --camera-port 5555 |
|||
|
|||
3. Custom output directory: |
|||
python run_camera_viewer.py --output-path ./my_recordings --camera-host localhost |
|||
|
|||
Controls: |
|||
- R key: Start/Stop recording |
|||
- Q key: Quit application |
|||
|
|||
Output Structure: |
|||
camera_output_20241211_143052/ |
|||
├── rec_143205/ |
|||
│ ├── ego_view_color_image.mp4 |
|||
│ ├── head_left_color_image.mp4 |
|||
│ └── head_right_color_image.mp4 |
|||
└── rec_143410/ |
|||
├── ego_view_color_image.mp4 |
|||
└── head_left_color_image.mp4 |
|||
""" |
|||
|
|||
from dataclasses import dataclass |
|||
from pathlib import Path |
|||
import threading |
|||
import time |
|||
from typing import Any, Optional |
|||
|
|||
import cv2 |
|||
import rclpy |
|||
from sshkeyboard import listen_keyboard, stop_listening |
|||
import tyro |
|||
|
|||
from decoupled_wbc.control.main.teleop.configs.configs import ComposedCameraClientConfig |
|||
from decoupled_wbc.control.sensor.composed_camera import ComposedCameraClientSensor |
|||
from decoupled_wbc.control.utils.img_viewer import ImageViewer |
|||
|
|||
|
|||
@dataclass |
|||
class CameraViewerConfig(ComposedCameraClientConfig): |
|||
"""Config for running the camera viewer with recording support.""" |
|||
|
|||
offscreen: bool = False |
|||
"""Run in offscreen mode (no display, manual recording with R key).""" |
|||
|
|||
output_path: Optional[str] = None |
|||
"""Output path for saving videos. If None, auto-generates path.""" |
|||
|
|||
codec: str = "mp4v" |
|||
"""Video codec to use for saving (e.g., 'mp4v', 'XVID').""" |
|||
|
|||
|
|||
ArgsConfig = CameraViewerConfig |
|||
|
|||
|
|||
def _get_camera_titles(image_data: dict[str, Any]) -> list[str]: |
|||
""" |
|||
Detect all the individual camera streams from the image data. |
|||
|
|||
schema format: |
|||
{ |
|||
"timestamps": {"ego_view": 123.45, "ego_view_left_mono": 123.46}, |
|||
"images": {"ego_view": np.ndarray, "ego_view_left_mono": np.ndarray} |
|||
} |
|||
|
|||
Returns list of camera keys (e.g., ["ego_view", "ego_view_left_mono", "ego_view_right_mono"]) |
|||
""" |
|||
# Extract all camera keys from the images dictionary |
|||
camera_titles = list(image_data.get("images", {}).keys()) |
|||
return camera_titles |
|||
|
|||
|
|||
def main(config: ArgsConfig): |
|||
"""Main function to run the camera viewer.""" |
|||
# Initialize ROS |
|||
rclpy.init(args=None) |
|||
node = rclpy.create_node("camera_viewer") |
|||
|
|||
# Start ROS spin in a separate thread |
|||
thread = threading.Thread(target=rclpy.spin, args=(node,), daemon=True) |
|||
thread.start() |
|||
|
|||
image_sub = ComposedCameraClientSensor(server_ip=config.camera_host, port=config.camera_port) |
|||
|
|||
# pre-fetch a sample image to get the number of camera angles |
|||
retry_count = 0 |
|||
while True: |
|||
_sample_image = image_sub.read() |
|||
if _sample_image: |
|||
break |
|||
retry_count += 1 |
|||
time.sleep(0.1) |
|||
if retry_count > 10: |
|||
raise Exception("Failed to get sample image") |
|||
|
|||
camera_titles = _get_camera_titles(_sample_image) |
|||
|
|||
# Setup output directory |
|||
if config.output_path is None: |
|||
output_dir = Path("camera_recordings") |
|||
else: |
|||
output_dir = Path(config.output_path) |
|||
|
|||
# Recording state |
|||
is_recording = False |
|||
video_writers = {} |
|||
frame_count = 0 |
|||
recording_start_time = None |
|||
should_quit = False |
|||
|
|||
def on_press(key): |
|||
nonlocal is_recording, video_writers, frame_count, recording_start_time, should_quit |
|||
|
|||
if key == "r": |
|||
if not is_recording: |
|||
# Start recording |
|||
recording_dir = output_dir / f"rec_{time.strftime('%Y%m%d_%H%M%S')}" |
|||
recording_dir.mkdir(parents=True, exist_ok=True) |
|||
|
|||
# Create video writers |
|||
fourcc = cv2.VideoWriter_fourcc(*config.codec) |
|||
video_writers = {} |
|||
|
|||
for title in camera_titles: |
|||
img = _sample_image["images"].get(title) |
|||
if img is not None: |
|||
height, width = img.shape[:2] |
|||
video_path = recording_dir / f"{title}.mp4" |
|||
writer = cv2.VideoWriter( |
|||
str(video_path), fourcc, config.fps, (width, height) |
|||
) |
|||
video_writers[title] = writer |
|||
|
|||
is_recording = True |
|||
recording_start_time = time.time() |
|||
frame_count = 0 |
|||
print(f"🔴 Recording started: {recording_dir}") |
|||
else: |
|||
# Stop recording |
|||
is_recording = False |
|||
for title, writer in video_writers.items(): |
|||
writer.release() |
|||
video_writers = {} |
|||
|
|||
duration = time.time() - recording_start_time if recording_start_time else 0 |
|||
print(f"⏹️ Recording stopped - {duration:.1f}s, {frame_count} frames") |
|||
elif key == "q": |
|||
should_quit = True |
|||
stop_listening() |
|||
|
|||
# Setup keyboard listener in a separate thread |
|||
keyboard_thread = threading.Thread( |
|||
target=lambda: listen_keyboard(on_press=on_press), daemon=True |
|||
) |
|||
keyboard_thread.start() |
|||
|
|||
# Setup viewer for onscreen mode |
|||
viewer = None |
|||
if not config.offscreen: |
|||
viewer = ImageViewer( |
|||
title="Camera Viewer", |
|||
figsize=(10, 8), |
|||
num_images=len(camera_titles), |
|||
image_titles=camera_titles, |
|||
) |
|||
|
|||
# Print instructions |
|||
mode = "Offscreen" if config.offscreen else "Onscreen" |
|||
print(f"{mode} mode - Target FPS: {config.fps}") |
|||
print(f"Videos will be saved to: {output_dir}") |
|||
print("Controls: R key to start/stop recording, Q key to quit, Ctrl+C to exit") |
|||
|
|||
# Create ROS rate controller |
|||
rate = node.create_rate(config.fps) |
|||
|
|||
try: |
|||
while rclpy.ok() and not should_quit: |
|||
# Get images from all subscribers |
|||
images = [] |
|||
image_data = image_sub.read() |
|||
if image_data: |
|||
for title in camera_titles: |
|||
img = image_data["images"].get(title) |
|||
images.append(img) |
|||
|
|||
# Save frame if recording |
|||
if is_recording and img is not None and title in video_writers: |
|||
# Convert from RGB to BGR for OpenCV |
|||
if len(img.shape) == 3 and img.shape[2] == 3: |
|||
img_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) |
|||
else: |
|||
img_bgr = img |
|||
video_writers[title].write(img_bgr) |
|||
|
|||
# Display images if not offscreen |
|||
if not config.offscreen and viewer and any(img is not None for img in images): |
|||
status = "🔴 REC" if is_recording else "⏸️ Ready" |
|||
viewer._fig.suptitle(f"Camera Viewer - {status}") |
|||
viewer.show_multiple(images) |
|||
|
|||
# Progress feedback |
|||
if is_recording: |
|||
frame_count += 1 |
|||
if frame_count % 100 == 0: |
|||
duration = time.time() - recording_start_time |
|||
print(f"Recording: {frame_count} frames ({duration:.1f}s)") |
|||
|
|||
rate.sleep() |
|||
|
|||
except KeyboardInterrupt: |
|||
print("\nExiting...") |
|||
finally: |
|||
# Cleanup |
|||
try: |
|||
stop_listening() |
|||
except Exception: |
|||
pass |
|||
|
|||
if video_writers: |
|||
for title, writer in video_writers.items(): |
|||
writer.release() |
|||
if is_recording: |
|||
duration = time.time() - recording_start_time |
|||
print(f"Final: {duration:.1f}s, {frame_count} frames") |
|||
|
|||
if viewer: |
|||
viewer.close() |
|||
|
|||
rclpy.shutdown() |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
config = tyro.cli(ArgsConfig) |
|||
main(config) |
|||
@ -0,0 +1,236 @@ |
|||
from copy import deepcopy |
|||
import time |
|||
|
|||
import tyro |
|||
|
|||
from decoupled_wbc.control.envs.g1.g1_env import G1Env |
|||
from decoupled_wbc.control.main.constants import ( |
|||
CONTROL_GOAL_TOPIC, |
|||
DEFAULT_BASE_HEIGHT, |
|||
DEFAULT_NAV_CMD, |
|||
DEFAULT_WRIST_POSE, |
|||
JOINT_SAFETY_STATUS_TOPIC, |
|||
LOWER_BODY_POLICY_STATUS_TOPIC, |
|||
ROBOT_CONFIG_TOPIC, |
|||
STATE_TOPIC_NAME, |
|||
) |
|||
from decoupled_wbc.control.main.teleop.configs.configs import ControlLoopConfig |
|||
from decoupled_wbc.control.policy.wbc_policy_factory import get_wbc_policy |
|||
from decoupled_wbc.control.robot_model.instantiation.g1 import ( |
|||
instantiate_g1_robot_model, |
|||
) |
|||
from decoupled_wbc.control.utils.keyboard_dispatcher import ( |
|||
KeyboardDispatcher, |
|||
KeyboardEStop, |
|||
KeyboardListenerPublisher, |
|||
ROSKeyboardDispatcher, |
|||
) |
|||
from decoupled_wbc.control.utils.ros_utils import ( |
|||
ROSManager, |
|||
ROSMsgPublisher, |
|||
ROSMsgSubscriber, |
|||
ROSServiceServer, |
|||
) |
|||
from decoupled_wbc.control.utils.telemetry import Telemetry |
|||
|
|||
CONTROL_NODE_NAME = "ControlPolicy" |
|||
|
|||
|
|||
def main(config: ControlLoopConfig): |
|||
ros_manager = ROSManager(node_name=CONTROL_NODE_NAME) |
|||
node = ros_manager.node |
|||
|
|||
# start the robot config server |
|||
ROSServiceServer(ROBOT_CONFIG_TOPIC, config.to_dict()) |
|||
|
|||
wbc_config = config.load_wbc_yaml() |
|||
|
|||
data_exp_pub = ROSMsgPublisher(STATE_TOPIC_NAME) |
|||
lower_body_policy_status_pub = ROSMsgPublisher(LOWER_BODY_POLICY_STATUS_TOPIC) |
|||
joint_safety_status_pub = ROSMsgPublisher(JOINT_SAFETY_STATUS_TOPIC) |
|||
|
|||
# Initialize telemetry |
|||
telemetry = Telemetry(window_size=100) |
|||
|
|||
waist_location = "lower_and_upper_body" if config.enable_waist else "lower_body" |
|||
robot_model = instantiate_g1_robot_model( |
|||
waist_location=waist_location, high_elbow_pose=config.high_elbow_pose |
|||
) |
|||
|
|||
env = G1Env( |
|||
env_name=config.env_name, |
|||
robot_model=robot_model, |
|||
config=wbc_config, |
|||
wbc_version=config.wbc_version, |
|||
) |
|||
if env.sim and not config.sim_sync_mode: |
|||
env.start_simulator() |
|||
|
|||
wbc_policy = get_wbc_policy("g1", robot_model, wbc_config, config.upper_body_joint_speed) |
|||
|
|||
keyboard_listener_pub = KeyboardListenerPublisher() |
|||
keyboard_estop = KeyboardEStop() |
|||
if config.keyboard_dispatcher_type == "raw": |
|||
dispatcher = KeyboardDispatcher() |
|||
elif config.keyboard_dispatcher_type == "ros": |
|||
dispatcher = ROSKeyboardDispatcher() |
|||
else: |
|||
raise ValueError( |
|||
f"Invalid keyboard dispatcher: {config.keyboard_dispatcher_type}, please use 'raw' or 'ros'" |
|||
) |
|||
dispatcher.register(env) |
|||
dispatcher.register(wbc_policy) |
|||
dispatcher.register(keyboard_listener_pub) |
|||
dispatcher.register(keyboard_estop) |
|||
dispatcher.start() |
|||
|
|||
rate = node.create_rate(config.control_frequency) |
|||
|
|||
upper_body_policy_subscriber = ROSMsgSubscriber(CONTROL_GOAL_TOPIC) |
|||
|
|||
last_teleop_cmd = None |
|||
try: |
|||
while ros_manager.ok(): |
|||
t_start = time.monotonic() |
|||
with telemetry.timer("total_loop"): |
|||
# Step simulator if in sync mode |
|||
with telemetry.timer("step_simulator"): |
|||
if env.sim and config.sim_sync_mode: |
|||
env.step_simulator() |
|||
|
|||
# Measure observation time |
|||
with telemetry.timer("observe"): |
|||
obs = env.observe() |
|||
wbc_policy.set_observation(obs) |
|||
|
|||
# Measure policy setup time |
|||
with telemetry.timer("policy_setup"): |
|||
upper_body_cmd = upper_body_policy_subscriber.get_msg() |
|||
|
|||
t_now = time.monotonic() |
|||
|
|||
wbc_goal = {} |
|||
if upper_body_cmd: |
|||
wbc_goal = upper_body_cmd.copy() |
|||
last_teleop_cmd = upper_body_cmd.copy() |
|||
if config.ik_indicator: |
|||
env.set_ik_indicator(upper_body_cmd) |
|||
# Send goal to policy |
|||
if wbc_goal: |
|||
wbc_goal["interpolation_garbage_collection_time"] = t_now - 2 * ( |
|||
1 / config.control_frequency |
|||
) |
|||
wbc_policy.set_goal(wbc_goal) |
|||
|
|||
# Measure policy action calculation time |
|||
with telemetry.timer("policy_action"): |
|||
wbc_action = wbc_policy.get_action(time=t_now) |
|||
|
|||
# Measure action queue time |
|||
with telemetry.timer("queue_action"): |
|||
env.queue_action(wbc_action) |
|||
|
|||
# Publish status information for InteractiveModeController |
|||
with telemetry.timer("publish_status"): |
|||
# Get policy status - check if the lower body policy has use_policy_action enabled |
|||
policy_use_action = False |
|||
try: |
|||
# Access the lower body policy through the decoupled whole body policy |
|||
if hasattr(wbc_policy, "lower_body_policy"): |
|||
policy_use_action = getattr( |
|||
wbc_policy.lower_body_policy, "use_policy_action", False |
|||
) |
|||
except (AttributeError, TypeError): |
|||
policy_use_action = False |
|||
|
|||
policy_status_msg = {"use_policy_action": policy_use_action, "timestamp": t_now} |
|||
lower_body_policy_status_pub.publish(policy_status_msg) |
|||
|
|||
# Get joint safety status from G1Env (which already runs the safety monitor) |
|||
joint_safety_ok = env.get_joint_safety_status() |
|||
|
|||
joint_safety_status_msg = { |
|||
"joint_safety_ok": joint_safety_ok, |
|||
"timestamp": t_now, |
|||
} |
|||
joint_safety_status_pub.publish(joint_safety_status_msg) |
|||
|
|||
# Start or Stop data collection |
|||
if wbc_goal.get("toggle_data_collection", False): |
|||
dispatcher.handle_key("c") |
|||
|
|||
# Abort the current episode |
|||
if wbc_goal.get("toggle_data_abort", False): |
|||
dispatcher.handle_key("x") |
|||
|
|||
if env.use_sim and wbc_goal.get("reset_env_and_policy", False): |
|||
print("Resetting sim environment and policy") |
|||
# Reset teleop policy & sim env |
|||
dispatcher.handle_key("k") |
|||
|
|||
# Clear upper body commands |
|||
upper_body_policy_subscriber._msg = None |
|||
upper_body_cmd = { |
|||
"target_upper_body_pose": obs["q"][ |
|||
robot_model.get_joint_group_indices("upper_body") |
|||
], |
|||
"wrist_pose": DEFAULT_WRIST_POSE, |
|||
"base_height_command": DEFAULT_BASE_HEIGHT, |
|||
"navigate_cmd": DEFAULT_NAV_CMD, |
|||
} |
|||
last_teleop_cmd = upper_body_cmd.copy() |
|||
|
|||
time.sleep(0.5) |
|||
|
|||
msg = deepcopy(obs) |
|||
for key in obs.keys(): |
|||
if key.endswith("_image"): |
|||
del msg[key] |
|||
|
|||
# exporting data |
|||
if last_teleop_cmd: |
|||
msg.update( |
|||
{ |
|||
"action": wbc_action["q"], |
|||
"action.eef": last_teleop_cmd.get("wrist_pose", DEFAULT_WRIST_POSE), |
|||
"base_height_command": last_teleop_cmd.get( |
|||
"base_height_command", DEFAULT_BASE_HEIGHT |
|||
), |
|||
"navigate_command": last_teleop_cmd.get( |
|||
"navigate_cmd", DEFAULT_NAV_CMD |
|||
), |
|||
"timestamps": { |
|||
"main_loop": time.time(), |
|||
"proprio": time.time(), |
|||
}, |
|||
} |
|||
) |
|||
data_exp_pub.publish(msg) |
|||
end_time = time.monotonic() |
|||
|
|||
if env.sim and (not env.sim.sim_thread or not env.sim.sim_thread.is_alive()): |
|||
raise RuntimeError("Simulator thread is not alive") |
|||
|
|||
rate.sleep() |
|||
|
|||
# Log timing information every 100 iterations (roughly every 2 seconds at 50Hz) |
|||
if config.verbose_timing: |
|||
# When verbose timing is enabled, always show timing |
|||
telemetry.log_timing_info(context="G1 Control Loop", threshold=0.0) |
|||
elif (end_time - t_start) > (1 / config.control_frequency) and not config.sim_sync_mode: |
|||
# Only show timing when loop is slow and verbose_timing is disabled |
|||
telemetry.log_timing_info(context="G1 Control Loop Missed", threshold=0.001) |
|||
|
|||
except ros_manager.exceptions() as e: |
|||
print(f"ROSManager interrupted by user: {e}") |
|||
finally: |
|||
print("Cleaning up...") |
|||
# the order of the following is important |
|||
dispatcher.stop() |
|||
ros_manager.shutdown() |
|||
env.close() |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
config = tyro.cli(ControlLoopConfig) |
|||
main(config) |
|||
@ -0,0 +1,364 @@ |
|||
from collections import deque |
|||
from datetime import datetime |
|||
import threading |
|||
import time |
|||
|
|||
import numpy as np |
|||
import rclpy |
|||
import tyro |
|||
|
|||
from decoupled_wbc.control.main.constants import ROBOT_CONFIG_TOPIC, STATE_TOPIC_NAME |
|||
from decoupled_wbc.control.main.teleop.configs.configs import DataExporterConfig |
|||
from decoupled_wbc.control.robot_model.instantiation import g1 |
|||
from decoupled_wbc.control.sensor.composed_camera import ComposedCameraClientSensor |
|||
from decoupled_wbc.control.utils.episode_state import EpisodeState |
|||
from decoupled_wbc.control.utils.keyboard_dispatcher import KeyboardListenerSubscriber |
|||
from decoupled_wbc.control.utils.ros_utils import ROSMsgSubscriber, ROSServiceClient |
|||
from decoupled_wbc.control.utils.telemetry import Telemetry |
|||
from decoupled_wbc.control.utils.text_to_speech import TextToSpeech |
|||
from decoupled_wbc.data.constants import BUCKET_BASE_PATH |
|||
from decoupled_wbc.data.exporter import DataCollectionInfo, Gr00tDataExporter |
|||
from decoupled_wbc.data.utils import get_dataset_features, get_modality_config |
|||
|
|||
|
|||
class TimeDeltaException(Exception): |
|||
def __init__(self, failure_count: int, reset_timeout_sec: float): |
|||
""" |
|||
Exception raised when the time delta between two messages exceeds |
|||
a threshold for a consecutive number of times |
|||
""" |
|||
self.failure_count = failure_count |
|||
self.reset_timeout_sec = reset_timeout_sec |
|||
self.message = f"{self.failure_count} failures in {self.reset_timeout_sec} seconds" |
|||
super().__init__(self.message) |
|||
|
|||
|
|||
class TimingThresholdMonitor: |
|||
def __init__(self, max_failures=3, reset_timeout_sec=5, time_delta=0.2, raise_exception=False): |
|||
""" |
|||
Monitor the time diff (between two messages) and optionally raise an exception |
|||
if there is a consistent violations |
|||
""" |
|||
self.max_failures = max_failures |
|||
self.reset_timeout_sec = reset_timeout_sec |
|||
self.failure_count = 0 |
|||
self.last_failure_time = 0 |
|||
self.time_delta = time_delta |
|||
self.raise_exception = raise_exception |
|||
|
|||
def reset(self): |
|||
self.failure_count = 0 |
|||
self.last_failure_time = 0 |
|||
|
|||
def log_time_delta(self, time_delta_sec: float): |
|||
time_delta = abs(time_delta_sec) |
|||
if time_delta > self.time_delta: |
|||
self.failure_count += 1 |
|||
self.last_failure_time = time.monotonic() |
|||
|
|||
if self.is_threshold_exceeded(): |
|||
print( |
|||
f"Time delta exception: {self.failure_count} failures in {self.reset_timeout_sec} seconds" |
|||
f", time delta: {time_delta}" |
|||
) |
|||
if self.raise_exception: |
|||
raise TimeDeltaException(self.failure_count, self.reset_timeout_sec) |
|||
|
|||
def is_threshold_exceeded(self): |
|||
if self.failure_count >= self.max_failures: |
|||
return True |
|||
if time.monotonic() - self.last_failure_time > self.reset_timeout_sec: |
|||
self.reset() |
|||
return False |
|||
|
|||
|
|||
class Gr00tDataCollector: |
|||
def __init__( |
|||
self, |
|||
node, |
|||
camera_host: str, |
|||
camera_port: int, |
|||
state_topic_name: str, |
|||
data_exporter: Gr00tDataExporter, |
|||
text_to_speech=None, |
|||
frequency=20, |
|||
state_act_msg_frequency=50, |
|||
): |
|||
|
|||
self.text_to_speech = text_to_speech |
|||
self.frequency = frequency |
|||
self.data_exporter = data_exporter |
|||
|
|||
self.node = node |
|||
|
|||
thread = threading.Thread(target=rclpy.spin, args=(self.node,), daemon=True) |
|||
thread.start() |
|||
time.sleep(0.5) |
|||
|
|||
self._episode_state = EpisodeState() |
|||
self._keyboard_listener = KeyboardListenerSubscriber() |
|||
self._state_subscriber = ROSMsgSubscriber(state_topic_name) |
|||
self._image_subscriber = ComposedCameraClientSensor(server_ip=camera_host, port=camera_port) |
|||
self.rate = self.node.create_rate(self.frequency) |
|||
|
|||
self.obs_act_buffer = deque(maxlen=100) |
|||
self.latest_image_msg = None |
|||
self.latest_proprio_msg = None |
|||
|
|||
self.state_polling_rate = 1 / state_act_msg_frequency |
|||
self.last_state_poll_time = time.monotonic() |
|||
|
|||
self.telemetry = Telemetry(window_size=100) |
|||
self.timing_threshold_monitor = TimingThresholdMonitor() |
|||
|
|||
print(f"Recording to {self.data_exporter.meta.root}") |
|||
|
|||
@property |
|||
def current_episode_index(self): |
|||
return self.data_exporter.episode_buffer["episode_index"] |
|||
|
|||
def _print_and_say(self, message: str, say: bool = True): |
|||
"""Helper to use TextToSpeech print_and_say or fallback to print.""" |
|||
if self.text_to_speech is not None: |
|||
self.text_to_speech.print_and_say(message, say) |
|||
else: |
|||
print(message) |
|||
|
|||
def _check_keyboard_input(self): |
|||
key = self._keyboard_listener.read_msg() |
|||
if key == "c": |
|||
self._episode_state.change_state() |
|||
if self._episode_state.get_state() == self._episode_state.RECORDING: |
|||
self._print_and_say(f"Started recording {self.current_episode_index}") |
|||
elif self._episode_state.get_state() == self._episode_state.NEED_TO_SAVE: |
|||
self._print_and_say("Stopping recording, preparing to save") |
|||
elif self._episode_state.get_state() == self._episode_state.IDLE: |
|||
self._print_and_say("Saved episode and back to idle state") |
|||
elif key == "x": |
|||
if self._episode_state.get_state() == self._episode_state.RECORDING: |
|||
self.data_exporter.save_episode_as_discarded() |
|||
self._episode_state.reset_state() |
|||
self._print_and_say("Discarded episode") |
|||
|
|||
def _add_data_frame(self): |
|||
t_start = time.monotonic() |
|||
|
|||
if self.latest_proprio_msg is None or self.latest_image_msg is None: |
|||
self._print_and_say( |
|||
f"Waiting for message. " |
|||
f"Avail msg: proprio {self.latest_proprio_msg is not None} | " |
|||
f"image {self.latest_image_msg is not None}", |
|||
say=False, |
|||
) |
|||
return False |
|||
|
|||
if self._episode_state.get_state() == self._episode_state.RECORDING: |
|||
|
|||
# Calculate max time delta between images and proprio |
|||
max_time_delta = 0 |
|||
for _, image_time in self.latest_image_msg["timestamps"].items(): |
|||
time_delta = abs(image_time - self.latest_proprio_msg["timestamps"]["proprio"]) |
|||
max_time_delta = max(max_time_delta, time_delta) |
|||
|
|||
self.timing_threshold_monitor.log_time_delta(max_time_delta) |
|||
if (self.timing_threshold_monitor.failure_count + 1) % 100 == 0: |
|||
self._print_and_say("Image state delta too high, please discard data") |
|||
|
|||
frame_data = { |
|||
"observation.state": self.latest_proprio_msg["q"], |
|||
"observation.eef_state": self.latest_proprio_msg["wrist_pose"], |
|||
"action": self.latest_proprio_msg["action"], |
|||
"action.eef": self.latest_proprio_msg["action.eef"], |
|||
"observation.img_state_delta": ( |
|||
np.array( |
|||
[max_time_delta], |
|||
dtype=np.float32, |
|||
) |
|||
), # lerobot only supports adding numpy arrays |
|||
"teleop.navigate_command": np.array( |
|||
self.latest_proprio_msg["navigate_command"], dtype=np.float64 |
|||
), |
|||
"teleop.base_height_command": np.array( |
|||
[self.latest_proprio_msg["base_height_command"]], dtype=np.float64 |
|||
), |
|||
} |
|||
|
|||
# Add images based on dataset features |
|||
images = self.latest_image_msg["images"] |
|||
for feature_name, feature_info in self.data_exporter.features.items(): |
|||
if feature_info.get("dtype") in ["image", "video"]: |
|||
# Extract image key from feature name (e.g., "observation.images.ego_view" -> "ego_view") |
|||
image_key = feature_name.split(".")[-1] |
|||
|
|||
if image_key not in images: |
|||
raise ValueError( |
|||
f"Required image '{image_key}' for feature '{feature_name}' " |
|||
f"not found in image message. Available images: {list(images.keys())}" |
|||
) |
|||
frame_data[feature_name] = images[image_key] |
|||
|
|||
self.data_exporter.add_frame(frame_data) |
|||
|
|||
t_end = time.monotonic() |
|||
if t_end - t_start > (1 / self.frequency): |
|||
print(f"DataExporter Missed: {t_end - t_start} sec") |
|||
|
|||
if self._episode_state.get_state() == self._episode_state.NEED_TO_SAVE: |
|||
self.data_exporter.save_episode() |
|||
self.timing_threshold_monitor.reset() |
|||
self._print_and_say("Finished saving episode") |
|||
self._episode_state.change_state() |
|||
|
|||
return True |
|||
|
|||
def save_and_cleanup(self): |
|||
try: |
|||
self._print_and_say("saving episode done") |
|||
# save on going episode if any |
|||
buffer_size = self.data_exporter.episode_buffer.get("size", 0) |
|||
if buffer_size > 0: |
|||
self.data_exporter.save_episode() |
|||
self._print_and_say(f"Recording complete: {self.data_exporter.meta.root}", say=False) |
|||
except Exception as e: |
|||
self._print_and_say(f"Error saving episode: {e}") |
|||
|
|||
self.node.destroy_node() |
|||
rclpy.shutdown() |
|||
self._print_and_say("Shutting down data exporter...", say=False) |
|||
|
|||
def run(self): |
|||
try: |
|||
while rclpy.ok(): |
|||
t_start = time.monotonic() |
|||
with self.telemetry.timer("total_loop"): |
|||
# 1. poll proprio msg |
|||
with self.telemetry.timer("poll_state"): |
|||
msg = self._state_subscriber.get_msg() |
|||
if msg is not None: |
|||
self.latest_proprio_msg = msg |
|||
|
|||
# 2. poll image msg |
|||
with self.telemetry.timer("poll_image"): |
|||
msg = self._image_subscriber.read() |
|||
if msg is not None: |
|||
self.latest_image_msg = msg |
|||
|
|||
# 3. check keyboard input |
|||
with self.telemetry.timer("check_keyboard"): |
|||
self._check_keyboard_input() |
|||
|
|||
# 4. add frame |
|||
with self.telemetry.timer("add_frame"): |
|||
self._add_data_frame() |
|||
|
|||
end_time = time.monotonic() |
|||
|
|||
self.rate.sleep() |
|||
|
|||
# Log timing information if we missed our target frequency |
|||
if (end_time - t_start) > (1 / self.frequency): |
|||
self.telemetry.log_timing_info( |
|||
context="Data Exporter Loop Missed", threshold=0.001 |
|||
) |
|||
|
|||
except KeyboardInterrupt: |
|||
print("Data exporter terminated by user") |
|||
# The user will trigger a keyboard interrupt if there's something wrong, |
|||
# so we flag the ongoing episode as discarded |
|||
buffer_size = self.data_exporter.episode_buffer.get("size", 0) |
|||
if buffer_size > 0: |
|||
self.data_exporter.save_episode_as_discarded() |
|||
|
|||
finally: |
|||
self.save_and_cleanup() |
|||
|
|||
|
|||
def main(config: DataExporterConfig): |
|||
|
|||
rclpy.init(args=None) |
|||
node = rclpy.create_node("data_exporter") |
|||
|
|||
waist_location = "lower_and_upper_body" if config.enable_waist else "lower_body" |
|||
g1_rm = g1.instantiate_g1_robot_model( |
|||
waist_location=waist_location, high_elbow_pose=config.high_elbow_pose |
|||
) |
|||
|
|||
dataset_features = get_dataset_features(g1_rm, config.add_stereo_camera) |
|||
modality_config = get_modality_config(g1_rm, config.add_stereo_camera) |
|||
|
|||
text_to_speech = TextToSpeech() if config.text_to_speech else None |
|||
|
|||
# Only set DataCollectionInfo if we're creating a new dataset |
|||
# When adding to existing dataset, DataCollectionInfo will be ignored |
|||
if config.robot_id is not None: |
|||
data_collection_info = DataCollectionInfo( |
|||
teleoperator_username=config.teleoperator_username, |
|||
support_operator_username=config.support_operator_username, |
|||
robot_type="g1", |
|||
robot_id=config.robot_id, |
|||
lower_body_policy=config.lower_body_policy, |
|||
wbc_model_path=config.wbc_model_path, |
|||
) |
|||
else: |
|||
# Use default DataCollectionInfo when adding to existing dataset |
|||
# This will be ignored if the dataset already exists |
|||
data_collection_info = DataCollectionInfo() |
|||
|
|||
robot_config_client = ROSServiceClient(ROBOT_CONFIG_TOPIC) |
|||
robot_config = robot_config_client.get_config() |
|||
|
|||
data_exporter = Gr00tDataExporter.create( |
|||
save_root=f"{config.root_output_dir}/{config.dataset_name}", |
|||
fps=config.data_collection_frequency, |
|||
features=dataset_features, |
|||
modality_config=modality_config, |
|||
task=config.task_prompt, |
|||
upload_bucket_path=BUCKET_BASE_PATH, |
|||
data_collection_info=data_collection_info, |
|||
script_config=robot_config, |
|||
) |
|||
|
|||
data_collector = Gr00tDataCollector( |
|||
node=node, |
|||
frequency=config.data_collection_frequency, |
|||
data_exporter=data_exporter, |
|||
state_topic_name=STATE_TOPIC_NAME, |
|||
camera_host=config.camera_host, |
|||
camera_port=config.camera_port, |
|||
text_to_speech=text_to_speech, |
|||
) |
|||
data_collector.run() |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
config = tyro.cli(DataExporterConfig) |
|||
config.task_prompt = input("Enter the task prompt: ").strip().lower() |
|||
add_to_existing_dataset = input("Add to existing dataset? (y/n): ").strip().lower() |
|||
|
|||
if add_to_existing_dataset == "y": |
|||
config.dataset_name = input("Enter the dataset name: ").strip().lower() |
|||
# When adding to existing dataset, we don't need robot_id or operator usernames |
|||
# as they should already be set in the existing dataset |
|||
elif add_to_existing_dataset == "n": |
|||
# robot_id = input("Enter the robot ID: ").strip().lower() |
|||
# if robot_id not in G1_ROBOT_IDS: |
|||
# raise ValueError(f"Invalid robot ID: {robot_id}. Available robot IDs: {G1_ROBOT_IDS}") |
|||
config.robot_id = "sim" |
|||
config.dataset_name = f"{datetime.now().strftime('%Y-%m-%d-%H-%M-%S')}-G1-{config.robot_id}" |
|||
|
|||
# Only ask for operator usernames when creating a new dataset |
|||
# print("Available teleoperator usernames:") |
|||
# for i, username in enumerate(OPERATOR_USERNAMES): |
|||
# print(f"{i}: {username}") |
|||
# teleop_idx = int(input("Select teleoperator username index: ")) |
|||
# config.teleoperator_username = OPERATOR_USERNAMES[teleop_idx] |
|||
config.teleoperator_username = "NEW_USER" |
|||
|
|||
# print("\nAvailable support operator usernames:") |
|||
# for i, username in enumerate(OPERATOR_USERNAMES): |
|||
# print(f"{i}: {username}") |
|||
# support_idx = int(input("Select support operator username index: ")) |
|||
# config.support_operator_username = OPERATOR_USERNAMES[support_idx] |
|||
config.support_operator_username = "NEW_USER" |
|||
|
|||
main(config) |
|||
@ -0,0 +1,68 @@ |
|||
import threading |
|||
import time |
|||
|
|||
import rclpy |
|||
|
|||
from decoupled_wbc.control.main.constants import NAV_CMD_TOPIC |
|||
from decoupled_wbc.control.policy.keyboard_navigation_policy import KeyboardNavigationPolicy |
|||
from decoupled_wbc.control.utils.keyboard_dispatcher import KeyboardListenerSubscriber |
|||
from decoupled_wbc.control.utils.ros_utils import ROSMsgPublisher |
|||
|
|||
FREQUENCY = 10 |
|||
NAV_NODE_NAME = "NavigationPolicy" |
|||
|
|||
|
|||
def main(): |
|||
rclpy.init(args=None) |
|||
node = rclpy.create_node(NAV_NODE_NAME) |
|||
|
|||
# Start ROS spin in a separate thread |
|||
thread = threading.Thread(target=rclpy.spin, args=(node,), daemon=True) |
|||
thread.start() |
|||
time.sleep(0.5) |
|||
|
|||
dict_publisher = ROSMsgPublisher(NAV_CMD_TOPIC) |
|||
keyboard_listener = KeyboardListenerSubscriber() |
|||
|
|||
# Initialize navigation policy |
|||
navigation_policy = KeyboardNavigationPolicy() |
|||
|
|||
# Create rate controller |
|||
rate = node.create_rate(FREQUENCY) |
|||
|
|||
try: |
|||
while rclpy.ok(): |
|||
t_now = time.monotonic() |
|||
# get keyboard input |
|||
|
|||
navigation_policy.handle_keyboard_button(keyboard_listener.read_msg()) |
|||
# Get action from navigation policy |
|||
action = navigation_policy.get_action(time=t_now) |
|||
|
|||
# Add timestamp to the data |
|||
action["timestamp"] = t_now |
|||
|
|||
# Create and publish ByteMultiArray message |
|||
dict_publisher.publish(action) |
|||
|
|||
# Print status periodically (optional) |
|||
if int(t_now * 10) % 10 == 0: |
|||
nav_cmd = action["navigate_cmd"] |
|||
node.get_logger().info( |
|||
f"Nav cmd: linear=({nav_cmd[0]:.2f}, {nav_cmd[1]:.2f}), " |
|||
f"angular={nav_cmd[2]:.2f}" |
|||
) |
|||
|
|||
rate.sleep() |
|||
|
|||
except KeyboardInterrupt: |
|||
print("Navigation control loop terminated by user") |
|||
|
|||
finally: |
|||
# Clean shutdown |
|||
node.destroy_node() |
|||
rclpy.shutdown() |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
main() |
|||
@ -0,0 +1,61 @@ |
|||
from typing import Dict |
|||
|
|||
import tyro |
|||
|
|||
from decoupled_wbc.control.envs.g1.sim.simulator_factory import SimulatorFactory, init_channel |
|||
from decoupled_wbc.control.main.teleop.configs.configs import SimLoopConfig |
|||
from decoupled_wbc.control.robot_model.instantiation.g1 import ( |
|||
instantiate_g1_robot_model, |
|||
) |
|||
from decoupled_wbc.control.robot_model.robot_model import RobotModel |
|||
|
|||
ArgsConfig = SimLoopConfig |
|||
|
|||
|
|||
class SimWrapper: |
|||
def __init__(self, robot_model: RobotModel, env_name: str, config: Dict[str, any], **kwargs): |
|||
self.robot_model = robot_model |
|||
self.config = config |
|||
|
|||
init_channel(config=self.config) |
|||
|
|||
# Create simulator using factory |
|||
self.sim = SimulatorFactory.create_simulator( |
|||
config=self.config, |
|||
env_name=env_name, |
|||
**kwargs, |
|||
) |
|||
|
|||
|
|||
def main(config: ArgsConfig): |
|||
wbc_config = config.load_wbc_yaml() |
|||
# NOTE: we will override the interface to local if it is not specified |
|||
wbc_config["ENV_NAME"] = config.env_name |
|||
|
|||
if config.enable_image_publish: |
|||
assert ( |
|||
config.enable_offscreen |
|||
), "enable_offscreen must be True when enable_image_publish is True" |
|||
|
|||
robot_model = instantiate_g1_robot_model() |
|||
|
|||
sim_wrapper = SimWrapper( |
|||
robot_model=robot_model, |
|||
env_name=config.env_name, |
|||
config=wbc_config, |
|||
onscreen=wbc_config.get("ENABLE_ONSCREEN", True), |
|||
offscreen=wbc_config.get("ENABLE_OFFSCREEN", False), |
|||
) |
|||
# Start simulator as independent process |
|||
SimulatorFactory.start_simulator( |
|||
sim_wrapper.sim, |
|||
as_thread=False, |
|||
enable_image_publish=config.enable_image_publish, |
|||
mp_start_method=config.mp_start_method, |
|||
camera_port=config.camera_port, |
|||
) |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
config = tyro.cli(ArgsConfig) |
|||
main(config) |
|||
@ -0,0 +1,213 @@ |
|||
from pathlib import Path |
|||
import time |
|||
|
|||
import tyro |
|||
|
|||
from decoupled_wbc.control.main.teleop.configs.configs import SyncSimDataCollectionConfig |
|||
from decoupled_wbc.control.robot_model.instantiation import get_robot_type_and_model |
|||
from decoupled_wbc.control.utils.keyboard_dispatcher import ( |
|||
KeyboardDispatcher, |
|||
KeyboardListener, |
|||
) |
|||
from decoupled_wbc.control.utils.ros_utils import ROSManager |
|||
from decoupled_wbc.control.utils.sync_sim_utils import ( |
|||
COLLECTION_KEY, |
|||
SKIP_KEY, |
|||
CITestManager, |
|||
EpisodeManager, |
|||
generate_frame, |
|||
get_data_exporter, |
|||
get_env, |
|||
get_policies, |
|||
) |
|||
from decoupled_wbc.control.utils.telemetry import Telemetry |
|||
|
|||
CONTROL_NODE_NAME = "ControlPolicy" |
|||
|
|||
|
|||
CONTROL_CMD_TOPIC = CONTROL_NODE_NAME + "/q_target" |
|||
ENV_NODE_NAME = "SyncEnv" |
|||
ENV_OBS_TOPIC = ENV_NODE_NAME + "/obs" |
|||
|
|||
|
|||
def display_controls(config: SyncSimDataCollectionConfig): |
|||
""" |
|||
Method to pretty print controls. |
|||
""" |
|||
|
|||
def print_command(char, info): |
|||
char += " " * (30 - len(char)) |
|||
print("{}\t{}".format(char, info)) |
|||
|
|||
print("") |
|||
print_command("Keys", "Command") |
|||
if config.manual_control: |
|||
print_command(COLLECTION_KEY, "start/stop data collection") |
|||
print_command(SKIP_KEY, "skip and collect new episodes") |
|||
print_command("w-s-a-d", "move horizontally in x-y plane (press '=' first to enable)") |
|||
print_command("q", "rotate (counter-clockwise)") |
|||
print_command("e", "rotate (clockwise)") |
|||
print_command("space", "reset all velocity to zero") |
|||
print("") |
|||
|
|||
|
|||
def main(config: SyncSimDataCollectionConfig): |
|||
ros_manager = ROSManager(node_name=CONTROL_NODE_NAME) |
|||
node = ros_manager.node |
|||
|
|||
# Initialize telemetry |
|||
telemetry = Telemetry(window_size=100) |
|||
|
|||
# Initialize robot model |
|||
robot_type, robot_model = get_robot_type_and_model( |
|||
config.robot, |
|||
enable_waist_ik=config.enable_waist, |
|||
) |
|||
|
|||
# Initialize sim env |
|||
env = get_env(config, onscreen=config.enable_onscreen, offscreen=config.save_img_obs) |
|||
seed = int(time.time()) |
|||
env.reset(seed) |
|||
env.render() |
|||
obs = env.observe() |
|||
robot_model.set_initial_body_pose(obs["q"]) |
|||
|
|||
# Initialize data exporter |
|||
exporter = get_data_exporter( |
|||
config, |
|||
obs, |
|||
robot_model, |
|||
save_path=Path("./outputs/ci_test/") if config.ci_test else None, |
|||
) |
|||
|
|||
# Display control signals |
|||
display_controls(config) |
|||
|
|||
# Initialize policies |
|||
wbc_policy, teleop_policy = get_policies(config, robot_type, robot_model) |
|||
|
|||
dispatcher = KeyboardDispatcher() |
|||
keyboard_listener = KeyboardListener() # for data collection keys |
|||
dispatcher.register(keyboard_listener) |
|||
dispatcher.register(wbc_policy) |
|||
dispatcher.register(teleop_policy) |
|||
|
|||
dispatcher.start() |
|||
|
|||
rate = node.create_rate(config.control_frequency) |
|||
|
|||
# Initialize episode manager to handle state transitions and data collection |
|||
episode_manager = EpisodeManager(config) |
|||
|
|||
# Initialize CI test manager |
|||
ci_test_manager = CITestManager(config) if config.ci_test else None |
|||
|
|||
try: |
|||
while ros_manager.ok(): |
|||
|
|||
need_reset = False |
|||
keyboard_input = keyboard_listener.pop_key() |
|||
|
|||
with telemetry.timer("total_loop"): |
|||
max_mujoco_state_len, mujoco_state_len, mujoco_state = env.get_mujoco_state_info() |
|||
|
|||
# Measure observation time |
|||
with telemetry.timer("observe"): |
|||
obs = env.observe() |
|||
wbc_policy.set_observation(obs) |
|||
|
|||
# Measure policy setup time |
|||
with telemetry.timer("policy_setup"): |
|||
teleop_cmd = teleop_policy.get_action() |
|||
|
|||
wbc_goal = {} |
|||
|
|||
# Note that wbc_goal["navigation_cmd'] could be overwritten by teleop_cmd |
|||
if teleop_cmd: |
|||
for key, value in teleop_cmd.items(): |
|||
wbc_goal[key] = value |
|||
# Draw IK indicators |
|||
if config.ik_indicator: |
|||
env.set_ik_indicator(teleop_cmd) |
|||
if wbc_goal: |
|||
wbc_policy.set_goal(wbc_goal) |
|||
|
|||
# Measure policy action calculation time |
|||
with telemetry.timer("policy_action"): |
|||
wbc_action = wbc_policy.get_action() |
|||
|
|||
if config.ci_test: |
|||
ci_test_manager.check_upper_body_motion(robot_model, wbc_action, config) |
|||
|
|||
# Measure action queue time |
|||
with telemetry.timer("step"): |
|||
obs, _, _, _, step_info = env.step(wbc_action) |
|||
env.render() |
|||
episode_manager.increment_step() |
|||
|
|||
if config.ci_test and config.ci_test_mode == "pre_merge": |
|||
ci_test_manager.check_end_effector_tracking( |
|||
teleop_cmd, obs, config, episode_manager.get_step_count() |
|||
) |
|||
|
|||
# Handle data collection trigger |
|||
episode_manager.handle_collection_trigger(wbc_goal, keyboard_input, step_info) |
|||
|
|||
# Collect data frame |
|||
if episode_manager.should_collect_data(): |
|||
frame = generate_frame( |
|||
obs, |
|||
wbc_action, |
|||
seed, |
|||
mujoco_state, |
|||
mujoco_state_len, |
|||
max_mujoco_state_len, |
|||
teleop_cmd, |
|||
wbc_goal, |
|||
config.save_img_obs, |
|||
) |
|||
# exporting data |
|||
exporter.add_frame(frame) |
|||
|
|||
# if done and task_completion_hold_count is 0, save the episode |
|||
need_reset = episode_manager.check_export_and_completion(exporter) |
|||
|
|||
# check data abort |
|||
if episode_manager.handle_skip(wbc_goal, keyboard_input, exporter): |
|||
need_reset = True |
|||
|
|||
if need_reset: |
|||
if config.ci_test: |
|||
print("CI test: Completed...") |
|||
raise KeyboardInterrupt |
|||
|
|||
seed = int(time.time()) |
|||
env.reset(seed) |
|||
env.render() |
|||
|
|||
print("Sleeping for 3 seconds before resetting teleop policy...") |
|||
for j in range(3, 0, -1): |
|||
print(f"Starting in {j}...") |
|||
time.sleep(1) |
|||
|
|||
wbc_policy, teleop_policy = get_policies(config, robot_type, robot_model) |
|||
episode_manager.reset_step_count() |
|||
|
|||
rate.sleep() |
|||
|
|||
except ros_manager.exceptions() as e: |
|||
print(f"ROSManager interrupted by user: {e}") |
|||
finally: |
|||
# Cleanup resources |
|||
teleop_policy.close() |
|||
dispatcher.stop() |
|||
ros_manager.shutdown() |
|||
env.close() |
|||
print("Sync sim data collection loop terminated.") |
|||
|
|||
return True |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
config = tyro.cli(SyncSimDataCollectionConfig) |
|||
main(config) |
|||
@ -0,0 +1,110 @@ |
|||
import time |
|||
|
|||
import rclpy |
|||
import tyro |
|||
|
|||
from decoupled_wbc.control.main.constants import CONTROL_GOAL_TOPIC |
|||
from decoupled_wbc.control.main.teleop.configs.configs import TeleopConfig |
|||
from decoupled_wbc.control.policy.lerobot_replay_policy import LerobotReplayPolicy |
|||
from decoupled_wbc.control.policy.teleop_policy import TeleopPolicy |
|||
from decoupled_wbc.control.robot_model.instantiation.g1 import instantiate_g1_robot_model |
|||
from decoupled_wbc.control.teleop.solver.hand.instantiation.g1_hand_ik_instantiation import ( |
|||
instantiate_g1_hand_ik_solver, |
|||
) |
|||
from decoupled_wbc.control.teleop.teleop_retargeting_ik import TeleopRetargetingIK |
|||
from decoupled_wbc.control.utils.ros_utils import ROSManager, ROSMsgPublisher |
|||
from decoupled_wbc.control.utils.telemetry import Telemetry |
|||
|
|||
TELEOP_NODE_NAME = "TeleopPolicy" |
|||
|
|||
|
|||
def main(config: TeleopConfig): |
|||
ros_manager = ROSManager(node_name=TELEOP_NODE_NAME) |
|||
node = ros_manager.node |
|||
|
|||
if config.robot == "g1": |
|||
waist_location = "lower_and_upper_body" if config.enable_waist else "lower_body" |
|||
robot_model = instantiate_g1_robot_model( |
|||
waist_location=waist_location, high_elbow_pose=config.high_elbow_pose |
|||
) |
|||
left_hand_ik_solver, right_hand_ik_solver = instantiate_g1_hand_ik_solver() |
|||
else: |
|||
raise ValueError(f"Unsupported robot name: {config.robot}") |
|||
|
|||
if config.lerobot_replay_path: |
|||
teleop_policy = LerobotReplayPolicy( |
|||
robot_model=robot_model, parquet_path=config.lerobot_replay_path |
|||
) |
|||
else: |
|||
print("running teleop policy, waiting teleop policy to be initialized...") |
|||
retargeting_ik = TeleopRetargetingIK( |
|||
robot_model=robot_model, |
|||
left_hand_ik_solver=left_hand_ik_solver, |
|||
right_hand_ik_solver=right_hand_ik_solver, |
|||
enable_visualization=config.enable_visualization, |
|||
body_active_joint_groups=["upper_body"], |
|||
) |
|||
teleop_policy = TeleopPolicy( |
|||
robot_model=robot_model, |
|||
retargeting_ik=retargeting_ik, |
|||
body_control_device=config.body_control_device, |
|||
hand_control_device=config.hand_control_device, |
|||
body_streamer_ip=config.body_streamer_ip, # vive tracker, leap motion does not require |
|||
body_streamer_keyword=config.body_streamer_keyword, |
|||
enable_real_device=config.enable_real_device, |
|||
replay_data_path=config.teleop_replay_path, |
|||
) |
|||
|
|||
# Create a publisher for the navigation commands |
|||
control_publisher = ROSMsgPublisher(CONTROL_GOAL_TOPIC) |
|||
|
|||
# Create rate controller |
|||
rate = node.create_rate(config.teleop_frequency) |
|||
iteration = 0 |
|||
time_to_get_to_initial_pose = 2 # seconds |
|||
|
|||
telemetry = Telemetry(window_size=100) |
|||
|
|||
try: |
|||
while rclpy.ok(): |
|||
with telemetry.timer("total_loop"): |
|||
t_start = time.monotonic() |
|||
# Get the current teleop action |
|||
with telemetry.timer("get_action"): |
|||
data = teleop_policy.get_action() |
|||
|
|||
# Add timing information to the message |
|||
t_now = time.monotonic() |
|||
data["timestamp"] = t_now |
|||
|
|||
# Set target completion time - longer for initial pose, then match control frequency |
|||
if iteration == 0: |
|||
data["target_time"] = t_now + time_to_get_to_initial_pose |
|||
else: |
|||
data["target_time"] = t_now + (1 / config.teleop_frequency) |
|||
|
|||
# Publish the teleop command |
|||
with telemetry.timer("publish_teleop_command"): |
|||
control_publisher.publish(data) |
|||
|
|||
# For the initial pose, wait the full duration before continuing |
|||
if iteration == 0: |
|||
print(f"Moving to initial pose for {time_to_get_to_initial_pose} seconds") |
|||
time.sleep(time_to_get_to_initial_pose) |
|||
iteration += 1 |
|||
end_time = time.monotonic() |
|||
if (end_time - t_start) > (1 / config.teleop_frequency): |
|||
telemetry.log_timing_info(context="Teleop Policy Loop Missed", threshold=0.001) |
|||
rate.sleep() |
|||
|
|||
except ros_manager.exceptions() as e: |
|||
print(f"ROSManager interrupted by user: {e}") |
|||
|
|||
finally: |
|||
print("Cleaning up...") |
|||
ros_manager.shutdown() |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
config = tyro.cli(TeleopConfig) |
|||
main(config) |
|||
@ -0,0 +1,157 @@ |
|||
import time as time_module |
|||
from typing import Optional |
|||
|
|||
import numpy as np |
|||
from pinocchio import rpy |
|||
|
|||
from decoupled_wbc.control.base.policy import Policy |
|||
from decoupled_wbc.control.main.constants import DEFAULT_NAV_CMD |
|||
|
|||
|
|||
class G1DecoupledWholeBodyPolicy(Policy): |
|||
""" |
|||
This class implements a whole-body policy for the G1 robot by combining an upper-body |
|||
policy and a lower-body RL-based policy. |
|||
It is designed to work with the G1 robot's specific configuration and control requirements. |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
robot_model, |
|||
lower_body_policy: Policy, |
|||
upper_body_policy: Policy, |
|||
): |
|||
self.robot_model = robot_model |
|||
self.lower_body_policy = lower_body_policy |
|||
self.upper_body_policy = upper_body_policy |
|||
self.last_goal_time = time_module.monotonic() |
|||
self.is_in_teleop_mode = False # Track if lower body is in teleop mode |
|||
|
|||
def set_observation(self, observation): |
|||
# Upper body policy is open loop (just interpolation), so we don't need to set the observation |
|||
self.lower_body_policy.set_observation(observation) |
|||
|
|||
def set_goal(self, goal): |
|||
""" |
|||
Set the goal for both upper and lower body policies. |
|||
|
|||
Args: |
|||
goal: Command from the planners |
|||
goal["target_upper_body_pose"]: Target pose for the upper body policy |
|||
goal["target_time"]: Target goal time |
|||
goal["interpolation_garbage_collection_time"]: Waypoints earlier than this time are removed |
|||
goal["navigate_cmd"]: Target navigation velocities for the lower body policy |
|||
goal["base_height_command"]: Target base height for both upper and lower body policies |
|||
""" |
|||
# Update goal timestamp for timeout safety |
|||
self.last_goal_time = time_module.monotonic() |
|||
|
|||
upper_body_goal = {} |
|||
lower_body_goal = {} |
|||
|
|||
# Upper body goal keys |
|||
upper_body_keys = [ |
|||
"target_upper_body_pose", |
|||
"base_height_command", |
|||
"target_time", |
|||
"interpolation_garbage_collection_time", |
|||
"navigate_cmd", |
|||
] |
|||
for key in upper_body_keys: |
|||
if key in goal: |
|||
upper_body_goal[key] = goal[key] |
|||
|
|||
# Always ensure navigate_cmd is present to prevent interpolation from old dangerous values |
|||
if "navigate_cmd" not in goal: |
|||
# Safety: Inject safe default navigate_cmd to ensure interpolation goes to stop |
|||
if "target_time" in goal and isinstance(goal["target_time"], list): |
|||
upper_body_goal["navigate_cmd"] = [np.array(DEFAULT_NAV_CMD)] * len( |
|||
goal["target_time"] |
|||
) |
|||
else: |
|||
upper_body_goal["navigate_cmd"] = np.array(DEFAULT_NAV_CMD) |
|||
|
|||
# Set teleop policy command flag |
|||
has_teleop_commands = ("navigate_cmd" in goal) or ("base_height_command" in goal) |
|||
self.is_in_teleop_mode = has_teleop_commands # Track teleop state for timeout safety |
|||
self.lower_body_policy.set_use_teleop_policy_cmd(has_teleop_commands) |
|||
|
|||
# Lower body goal keys |
|||
lower_body_keys = [ |
|||
"toggle_stand_command", |
|||
"toggle_policy_action", |
|||
] |
|||
for key in lower_body_keys: |
|||
if key in goal: |
|||
lower_body_goal[key] = goal[key] |
|||
|
|||
self.upper_body_policy.set_goal(upper_body_goal) |
|||
self.lower_body_policy.set_goal(lower_body_goal) |
|||
|
|||
def get_action(self, time: Optional[float] = None): |
|||
current_time = time if time is not None else time_module.monotonic() |
|||
|
|||
# Safety timeout: Only apply when in teleop mode (communication loss dangerous) |
|||
# When in keyboard mode, no timeout needed (user controls directly) |
|||
if self.is_in_teleop_mode: |
|||
time_since_goal = current_time - self.last_goal_time |
|||
if time_since_goal > 1.0: # 1 second timeout |
|||
print( |
|||
f"SAFETY: Teleop mode timeout after {time_since_goal:.1f}s, injecting safe goal" |
|||
) |
|||
# Inject safe goal to trigger all safety mechanisms (gear_wbc reset + interpolation reset) |
|||
safe_goal = { |
|||
"target_time": current_time + 0.1, |
|||
"interpolation_garbage_collection_time": current_time - 1.0, |
|||
} |
|||
self.set_goal( |
|||
safe_goal |
|||
) # This will reset is_in_teleop_mode to False and trigger all safety |
|||
|
|||
# Get indices for groups |
|||
lower_body_indices = self.robot_model.get_joint_group_indices("lower_body") |
|||
upper_body_indices = self.robot_model.get_joint_group_indices("upper_body") |
|||
|
|||
# Initialize full configuration with zeros |
|||
q = np.zeros(self.robot_model.num_dofs) |
|||
|
|||
upper_body_action = self.upper_body_policy.get_action(time) |
|||
q[upper_body_indices] = upper_body_action["target_upper_body_pose"] |
|||
q_arms = q[self.robot_model.get_joint_group_indices("arms")] |
|||
base_height_command = upper_body_action.get("base_height_command", None) |
|||
interpolated_navigate_cmd = upper_body_action.get("navigate_cmd", None) |
|||
|
|||
# Compute torso orientation relative to waist, to pass to lower body policy |
|||
self.robot_model.cache_forward_kinematics(q, auto_clip=False) |
|||
torso_orientation = self.robot_model.frame_placement("torso_link").rotation |
|||
waist_orientation = self.robot_model.frame_placement("pelvis").rotation |
|||
# Extract yaw from rotation matrix and create a rotation with only yaw |
|||
# The rotation property is a 3x3 numpy array |
|||
waist_yaw = np.arctan2(waist_orientation[1, 0], waist_orientation[0, 0]) |
|||
# Create a rotation matrix with only yaw using Pinocchio's rpy functions |
|||
waist_yaw_only_rotation = rpy.rpyToMatrix(0, 0, waist_yaw) |
|||
yaw_only_waist_from_torso = waist_yaw_only_rotation.T @ torso_orientation |
|||
torso_orientation_rpy = rpy.matrixToRpy(yaw_only_waist_from_torso) |
|||
|
|||
lower_body_action = self.lower_body_policy.get_action( |
|||
time, q_arms, base_height_command, torso_orientation_rpy, interpolated_navigate_cmd |
|||
) |
|||
|
|||
# If pelvis is both in upper and lower body, lower body policy takes preference |
|||
q[lower_body_indices] = lower_body_action["body_action"][0][ |
|||
: len(lower_body_indices) |
|||
] # lower body (legs + waist) |
|||
|
|||
self.last_action = {"q": q} |
|||
|
|||
return {"q": q} |
|||
|
|||
def handle_keyboard_button(self, key): |
|||
try: |
|||
self.lower_body_policy.locomotion_policy.handle_keyboard_button(key) |
|||
except AttributeError: |
|||
# Only catch AttributeError, let other exceptions propagate |
|||
self.lower_body_policy.handle_keyboard_button(key) |
|||
|
|||
def activate_policy(self): |
|||
self.handle_keyboard_button("]") |
|||
@ -0,0 +1,295 @@ |
|||
import collections |
|||
from pathlib import Path |
|||
from typing import Any, Dict, Optional |
|||
|
|||
import numpy as np |
|||
import onnxruntime as ort |
|||
import torch |
|||
|
|||
from decoupled_wbc.control.base.policy import Policy |
|||
from decoupled_wbc.control.utils.gear_wbc_utils import get_gravity_orientation, load_config |
|||
|
|||
|
|||
class G1GearWbcPolicy(Policy): |
|||
"""Simple G1 robot policy using OpenGearWbc trained neural network.""" |
|||
|
|||
def __init__(self, robot_model, config: str, model_path: str): |
|||
"""Initialize G1GearWbcPolicy. |
|||
|
|||
Args: |
|||
config_path: Path to gear_wbc YAML configuration file |
|||
""" |
|||
self.config, self.LEGGED_GYM_ROOT_DIR = load_config(config) |
|||
self.robot_model = robot_model |
|||
self.use_teleop_policy_cmd = False |
|||
|
|||
package_root = Path(__file__).resolve().parents[2] |
|||
self.sim2mujoco_root_dir = str(package_root / "sim2mujoco") |
|||
model_path_1, model_path_2 = model_path.split(",") |
|||
|
|||
self.policy_1 = self.load_onnx_policy( |
|||
self.sim2mujoco_root_dir + "/resources/robots/g1/" + model_path_1 |
|||
) |
|||
self.policy_2 = self.load_onnx_policy( |
|||
self.sim2mujoco_root_dir + "/resources/robots/g1/" + model_path_2 |
|||
) |
|||
|
|||
# Initialize observation history buffer |
|||
self.observation = None |
|||
self.obs_history = collections.deque(maxlen=self.config["obs_history_len"]) |
|||
self.obs_buffer = np.zeros(self.config["num_obs"], dtype=np.float32) |
|||
self.counter = 0 |
|||
|
|||
# Initialize state variables |
|||
self.use_policy_action = False |
|||
self.action = np.zeros(self.config["num_actions"], dtype=np.float32) |
|||
self.target_dof_pos = self.config["default_angles"].copy() |
|||
self.cmd = self.config["cmd_init"].copy() |
|||
self.height_cmd = self.config["height_cmd"] |
|||
self.freq_cmd = self.config["freq_cmd"] |
|||
self.roll_cmd = self.config["rpy_cmd"][0] |
|||
self.pitch_cmd = self.config["rpy_cmd"][1] |
|||
self.yaw_cmd = self.config["rpy_cmd"][2] |
|||
self.gait_indices = torch.zeros((1), dtype=torch.float32) |
|||
|
|||
def load_onnx_policy(self, model_path: str): |
|||
print(f"Loading ONNX policy from {model_path}") |
|||
model = ort.InferenceSession(model_path) |
|||
|
|||
def run_inference(input_tensor): |
|||
ort_inputs = {model.get_inputs()[0].name: input_tensor.cpu().numpy()} |
|||
ort_outs = model.run(None, ort_inputs) |
|||
return torch.tensor(ort_outs[0], device="cpu") |
|||
|
|||
print(f"Successfully loaded ONNX policy from {model_path}") |
|||
|
|||
return run_inference |
|||
|
|||
def compute_observation(self, observation: Dict[str, Any]) -> tuple[np.ndarray, int]: |
|||
"""Compute the observation vector from current state""" |
|||
# Get body joint indices (excluding waist roll and pitch) |
|||
self.gait_indices = torch.remainder(self.gait_indices + 0.02 * self.freq_cmd, 1.0) |
|||
durations = torch.full_like(self.gait_indices, 0.5) |
|||
phases = 0.5 |
|||
foot_indices = [ |
|||
self.gait_indices + phases, # FL |
|||
self.gait_indices, # FR |
|||
] |
|||
self.foot_indices = torch.remainder( |
|||
torch.cat([foot_indices[i].unsqueeze(1) for i in range(2)], dim=1), 1.0 |
|||
) |
|||
for fi in foot_indices: |
|||
stance = fi < durations |
|||
swing = fi >= durations |
|||
fi[stance] = fi[stance] * (0.5 / durations[stance]) |
|||
fi[swing] = 0.5 + (fi[swing] - durations[swing]) * (0.5 / (1 - durations[swing])) |
|||
|
|||
self.clock_inputs = torch.stack([torch.sin(2 * np.pi * fi) for fi in foot_indices], dim=1) |
|||
|
|||
body_indices = self.robot_model.get_joint_group_indices("body") |
|||
body_indices = [idx for idx in body_indices] |
|||
|
|||
n_joints = len(body_indices) |
|||
|
|||
# Extract joint data |
|||
qj = observation["q"][body_indices].copy() |
|||
dqj = observation["dq"][body_indices].copy() |
|||
|
|||
# Extract floating base data |
|||
quat = observation["floating_base_pose"][3:7].copy() # quaternion |
|||
omega = observation["floating_base_vel"][3:6].copy() # angular velocity |
|||
|
|||
# Handle default angles padding |
|||
if len(self.config["default_angles"]) < n_joints: |
|||
padded_defaults = np.zeros(n_joints, dtype=np.float32) |
|||
padded_defaults[: len(self.config["default_angles"])] = self.config["default_angles"] |
|||
else: |
|||
padded_defaults = self.config["default_angles"][:n_joints] |
|||
|
|||
# Scale the values |
|||
qj_scaled = (qj - padded_defaults) * self.config["dof_pos_scale"] |
|||
dqj_scaled = dqj * self.config["dof_vel_scale"] |
|||
gravity_orientation = get_gravity_orientation(quat) |
|||
omega_scaled = omega * self.config["ang_vel_scale"] |
|||
|
|||
# Calculate single observation dimension |
|||
single_obs_dim = 86 # 3 + 1 + 3 + 3 + 3 + n_joints + n_joints + 15, n_joints = 29 |
|||
|
|||
# Create single observation |
|||
single_obs = np.zeros(single_obs_dim, dtype=np.float32) |
|||
single_obs[0:3] = self.cmd[:3] * self.config["cmd_scale"] |
|||
single_obs[3:4] = np.array([self.height_cmd]) |
|||
single_obs[4:7] = np.array([self.roll_cmd, self.pitch_cmd, self.yaw_cmd]) |
|||
single_obs[7:10] = omega_scaled |
|||
single_obs[10:13] = gravity_orientation |
|||
# single_obs[14:17] = omega_scaled_torso |
|||
# single_obs[17:20] = gravity_torso |
|||
single_obs[13 : 13 + n_joints] = qj_scaled |
|||
single_obs[13 + n_joints : 13 + 2 * n_joints] = dqj_scaled |
|||
single_obs[13 + 2 * n_joints : 13 + 2 * n_joints + 15] = self.action |
|||
# single_obs[13 + 2 * n_joints + 15 : 13 + 2 * n_joints + 15 + 2] = ( |
|||
# processed_clock_inputs.detach().cpu().numpy() |
|||
# ) |
|||
return single_obs, single_obs_dim |
|||
|
|||
def set_observation(self, observation: Dict[str, Any]): |
|||
"""Update the policy's current observation of the environment. |
|||
|
|||
Args: |
|||
observation: Dictionary containing single observation from current state |
|||
Should include 'obs' key with current single observation |
|||
""" |
|||
|
|||
# Extract the single observation |
|||
self.observation = observation |
|||
single_obs, single_obs_dim = self.compute_observation(observation) |
|||
|
|||
# Update observation history every control_decimation steps |
|||
# if self.counter % self.config['control_decimation'] == 0: |
|||
# Add current observation to history |
|||
self.obs_history.append(single_obs) |
|||
|
|||
# Fill history with zeros if not enough observations yet |
|||
while len(self.obs_history) < self.config["obs_history_len"]: |
|||
self.obs_history.appendleft(np.zeros_like(single_obs)) |
|||
|
|||
# Construct full observation with history |
|||
single_obs_dim = len(single_obs) |
|||
for i, hist_obs in enumerate(self.obs_history): |
|||
start_idx = i * single_obs_dim |
|||
end_idx = start_idx + single_obs_dim |
|||
self.obs_buffer[start_idx:end_idx] = hist_obs |
|||
|
|||
# Convert to tensor for policy |
|||
self.obs_tensor = torch.from_numpy(self.obs_buffer).unsqueeze(0) |
|||
# self.counter += 1 |
|||
|
|||
assert self.obs_tensor.shape[1] == self.config["num_obs"] |
|||
|
|||
def set_use_teleop_policy_cmd(self, use_teleop_policy_cmd: bool): |
|||
self.use_teleop_policy_cmd = use_teleop_policy_cmd |
|||
# Safety: When teleop is disabled, reset navigation to stop |
|||
if not use_teleop_policy_cmd: |
|||
self.nav_cmd = self.config["cmd_init"].copy() # Reset to safe default |
|||
|
|||
def set_goal(self, goal: Dict[str, Any]): |
|||
"""Set the goal for the policy. |
|||
|
|||
Args: |
|||
goal: Dictionary containing the goal for the policy |
|||
""" |
|||
|
|||
if "toggle_policy_action" in goal: |
|||
if goal["toggle_policy_action"]: |
|||
self.use_policy_action = not self.use_policy_action |
|||
|
|||
def get_action( |
|||
self, |
|||
time: Optional[float] = None, |
|||
arms_target_pose: Optional[np.ndarray] = None, |
|||
base_height_command: Optional[np.ndarray] = None, |
|||
torso_orientation_rpy: Optional[np.ndarray] = None, |
|||
interpolated_navigate_cmd: Optional[np.ndarray] = None, |
|||
) -> Dict[str, Any]: |
|||
"""Compute and return the next action based on current observation. |
|||
|
|||
Args: |
|||
time: Optional "monotonic time" for time-dependent policies (unused) |
|||
|
|||
Returns: |
|||
Dictionary containing the action to be executed |
|||
""" |
|||
if self.obs_tensor is None: |
|||
raise ValueError("No observation set. Call set_observation() first.") |
|||
|
|||
if base_height_command is not None and self.use_teleop_policy_cmd: |
|||
self.height_cmd = ( |
|||
base_height_command[0] |
|||
if isinstance(base_height_command, list) |
|||
else base_height_command |
|||
) |
|||
|
|||
if interpolated_navigate_cmd is not None and self.use_teleop_policy_cmd: |
|||
self.cmd = interpolated_navigate_cmd |
|||
|
|||
if torso_orientation_rpy is not None and self.use_teleop_policy_cmd: |
|||
self.roll_cmd = torso_orientation_rpy[0] |
|||
self.pitch_cmd = torso_orientation_rpy[1] |
|||
self.yaw_cmd = torso_orientation_rpy[2] |
|||
|
|||
# Run policy inference |
|||
with torch.no_grad(): |
|||
# Select appropriate policy based on command magnitude |
|||
if np.linalg.norm(self.cmd) < 0.05: |
|||
# Use standing policy for small commands |
|||
policy = self.policy_1 |
|||
else: |
|||
# Use walking policy for movement commands |
|||
policy = self.policy_2 |
|||
|
|||
self.action = policy(self.obs_tensor).detach().numpy().squeeze() |
|||
|
|||
# Transform action to target_dof_pos |
|||
if self.use_policy_action: |
|||
cmd_q = self.action * self.config["action_scale"] + self.config["default_angles"] |
|||
else: |
|||
cmd_q = self.observation["q"][self.robot_model.get_joint_group_indices("lower_body")] |
|||
|
|||
cmd_dq = np.zeros(self.config["num_actions"]) |
|||
cmd_tau = np.zeros(self.config["num_actions"]) |
|||
|
|||
return {"body_action": (cmd_q, cmd_dq, cmd_tau)} |
|||
|
|||
def handle_keyboard_button(self, key): |
|||
if key == "]": |
|||
self.use_policy_action = True |
|||
elif key == "o": |
|||
self.use_policy_action = False |
|||
elif key == "w": |
|||
self.cmd[0] += 0.2 |
|||
elif key == "s": |
|||
self.cmd[0] -= 0.2 |
|||
elif key == "a": |
|||
self.cmd[1] += 0.2 |
|||
elif key == "d": |
|||
self.cmd[1] -= 0.2 |
|||
elif key == "q": |
|||
self.cmd[2] += 0.2 |
|||
elif key == "e": |
|||
self.cmd[2] -= 0.2 |
|||
elif key == "z": |
|||
self.cmd[0] = 0.0 |
|||
self.cmd[1] = 0.0 |
|||
self.cmd[2] = 0.0 |
|||
elif key == "1": |
|||
self.height_cmd += 0.1 |
|||
elif key == "2": |
|||
self.height_cmd -= 0.1 |
|||
elif key == "n": |
|||
self.freq_cmd -= 0.1 |
|||
self.freq_cmd = max(1.0, self.freq_cmd) |
|||
elif key == "m": |
|||
self.freq_cmd += 0.1 |
|||
self.freq_cmd = min(2.0, self.freq_cmd) |
|||
elif key == "3": |
|||
self.roll_cmd -= np.deg2rad(10) |
|||
elif key == "4": |
|||
self.roll_cmd += np.deg2rad(10) |
|||
elif key == "5": |
|||
self.pitch_cmd -= np.deg2rad(10) |
|||
elif key == "6": |
|||
self.pitch_cmd += np.deg2rad(10) |
|||
elif key == "7": |
|||
self.yaw_cmd -= np.deg2rad(10) |
|||
elif key == "8": |
|||
self.yaw_cmd += np.deg2rad(10) |
|||
|
|||
if key: |
|||
print("--------------------------------") |
|||
print(f"Linear velocity command: {self.cmd}") |
|||
print(f"Base height command: {self.height_cmd}") |
|||
print(f"Use policy action: {self.use_policy_action}") |
|||
print(f"roll deg angle: {np.rad2deg(self.roll_cmd)}") |
|||
print(f"pitch deg angle: {np.rad2deg(self.pitch_cmd)}") |
|||
print(f"yaw deg angle: {np.rad2deg(self.yaw_cmd)}") |
|||
print(f"Gait frequency: {self.freq_cmd}") |
|||
@ -0,0 +1,25 @@ |
|||
from copy import deepcopy |
|||
from typing import Optional |
|||
|
|||
import gymnasium as gym |
|||
|
|||
from decoupled_wbc.control.base.policy import Policy |
|||
|
|||
|
|||
class IdentityPolicy(Policy): |
|||
def __init__(self): |
|||
self.reset() |
|||
|
|||
def get_action(self, time: Optional[float] = None) -> dict[str, any]: |
|||
return self.goal |
|||
|
|||
def set_goal(self, goal: dict[str, any]) -> None: |
|||
self.goal = deepcopy(goal) |
|||
self.goal.pop("interpolation_garbage_collection_time", None) |
|||
self.goal.pop("target_time", None) |
|||
|
|||
def observation_space(self) -> gym.spaces.Dict: |
|||
return gym.spaces.Dict() |
|||
|
|||
def action_space(self) -> gym.spaces.Dict: |
|||
return gym.spaces.Dict() |
|||
@ -0,0 +1,297 @@ |
|||
import numbers |
|||
import time as time_module |
|||
from typing import Any, Dict, Optional, Union |
|||
|
|||
import gymnasium as gym |
|||
import numpy as np |
|||
import scipy.interpolate as si |
|||
|
|||
from decoupled_wbc.control.base.policy import Policy |
|||
|
|||
|
|||
class InterpolationPolicy(Policy): |
|||
def __init__( |
|||
self, |
|||
init_time: float, |
|||
init_values: dict[str, np.ndarray], |
|||
max_change_rate: float, |
|||
): |
|||
""" |
|||
Args: |
|||
init_time: The time of recording the initial values. |
|||
init_values: The initial values of the features. |
|||
The keys are the names of the features, and the values |
|||
are the initial values of the features (1D array). |
|||
max_change_rate: The maximum change rate. |
|||
""" |
|||
super().__init__() |
|||
self.last_action = init_values # Vecs are 1D arrays |
|||
self.concat_order = sorted(init_values.keys()) |
|||
self.concat_dims = [] |
|||
for key in self.concat_order: |
|||
vec = np.array(init_values[key]) |
|||
if vec.ndim == 2 and vec.shape[0] == 1: |
|||
vec = vec[0] |
|||
init_values[key] = vec |
|||
assert vec.ndim == 1, f"The shape of {key} should be (D,). Got {vec.shape}." |
|||
self.concat_dims.append(vec.shape[0]) |
|||
|
|||
self.init_values_concat = self._concat_vecs(init_values, 1) |
|||
self.max_change_rate = max_change_rate |
|||
self.reset(init_time) |
|||
|
|||
def reset(self, init_time: float = time_module.monotonic()): |
|||
self.interp = PoseTrajectoryInterpolator(np.array([init_time]), self.init_values_concat) |
|||
self.last_waypoint_time = init_time |
|||
self.max_change_rate = self.max_change_rate |
|||
|
|||
def _concat_vecs(self, values: dict[str, np.ndarray], length: int) -> np.ndarray: |
|||
""" |
|||
Concatenate the vectors into a 2D array to be used for interpolation. |
|||
Args: |
|||
values: The values to concatenate. |
|||
length: The length of the concatenated vectors (time dimension). |
|||
Returns: |
|||
The concatenated vectors (T, D) arrays. |
|||
""" |
|||
concat_vecs = [] |
|||
for key in self.concat_order: |
|||
if key in values: |
|||
vec = np.array(values[key]) |
|||
if vec.ndim == 1: |
|||
# If the vector is 1D, tile it to the length of the time dimension |
|||
vec = np.tile(vec, (length, 1)) |
|||
assert vec.ndim == 2, f"The shape of {key} should be (T, D). Got {vec.shape}." |
|||
concat_vecs.append(vec) |
|||
else: |
|||
# If the vector is not in the values, use the last action |
|||
# Since the last action is 1D, we need to tile it to the length of the time dimension |
|||
concat_vecs.append(np.tile(self.last_action[key], (length, 1))) |
|||
return np.concatenate(concat_vecs, axis=1) # Vecs are 2D (T, D) arrays |
|||
|
|||
def _unconcat_vecs(self, concat_vec: np.ndarray) -> dict[str, np.ndarray]: |
|||
curr_idx = 0 |
|||
action = {} |
|||
assert ( |
|||
concat_vec.ndim == 1 |
|||
), f"The shape of the concatenated vectors should be (T, D). Got {concat_vec.shape}." |
|||
for key, dim in zip(self.concat_order, self.concat_dims): |
|||
action[key] = concat_vec[curr_idx : curr_idx + dim] |
|||
curr_idx += dim |
|||
return action # Vecs are 1D arrays |
|||
|
|||
def __call__( |
|||
self, observation: Dict[str, Any], goal: Dict[str, Any], time: float |
|||
) -> Dict[str, np.ndarray]: |
|||
raise NotImplementedError( |
|||
"`InterpolationPolicy` accepts goal and provide action in two separate methods." |
|||
) |
|||
|
|||
def set_goal(self, goal: Dict[str, Any]) -> None: |
|||
if "target_time" not in goal: |
|||
return |
|||
assert ( |
|||
"interpolation_garbage_collection_time" in goal |
|||
), "`interpolation_garbage_collection_time` is required." |
|||
target_time = goal.pop("target_time") |
|||
interpolation_garbage_collection_time = goal.pop("interpolation_garbage_collection_time") |
|||
|
|||
if isinstance(target_time, list): |
|||
for key, vec in goal.items(): |
|||
assert isinstance(vec, list) |
|||
assert len(vec) == len(target_time), ( |
|||
f"The length of {key} and `target_time` should be the same. " |
|||
f"Got {len(vec)} and {len(target_time)}." |
|||
) |
|||
else: |
|||
target_time = [target_time] |
|||
for key in goal: |
|||
goal[key] = [goal[key]] |
|||
|
|||
# Concatenate all vectors in goal |
|||
concat_vecs = self._concat_vecs(goal, len(target_time)) |
|||
assert concat_vecs.shape[0] == len(target_time), ( |
|||
f"The length of the concatenated goal and `target_time` should be the same. " |
|||
f"Got {concat_vecs.shape[0]} and {len(target_time)}." |
|||
) |
|||
|
|||
for tt, vec in zip(target_time, concat_vecs): |
|||
if tt < interpolation_garbage_collection_time: |
|||
continue |
|||
self.interp = self.interp.schedule_waypoint( |
|||
pose=vec, |
|||
time=tt, |
|||
max_change_rate=self.max_change_rate, |
|||
interpolation_garbage_collection_time=interpolation_garbage_collection_time, |
|||
last_waypoint_time=self.last_waypoint_time, |
|||
) |
|||
self.last_waypoint_time = tt |
|||
|
|||
def get_action(self, time: Optional[float] = None) -> dict[str, Any]: |
|||
"""Get the next action based on the (current) monotonic time.""" |
|||
if time is None: |
|||
time = time_module.monotonic() |
|||
concat_vec = self.interp(time) |
|||
self.last_action.update(self._unconcat_vecs(concat_vec)) |
|||
return self.last_action |
|||
|
|||
def observation_space(self) -> gym.spaces.Dict: |
|||
"""Return the observation space.""" |
|||
pass |
|||
|
|||
def action_space(self) -> gym.spaces.Dict: |
|||
"""Return the action space.""" |
|||
pass |
|||
|
|||
def close(self) -> None: |
|||
"""Clean up resources.""" |
|||
pass |
|||
|
|||
|
|||
class PoseTrajectoryInterpolator: |
|||
def __init__(self, times: np.ndarray, poses: np.ndarray): |
|||
assert len(times) >= 1 |
|||
assert len(poses) == len(times) |
|||
|
|||
times = np.asarray(times) |
|||
poses = np.asarray(poses) |
|||
|
|||
self.num_joint = len(poses[0]) |
|||
|
|||
if len(times) == 1: |
|||
# special treatment for single step interpolation |
|||
self.single_step = True |
|||
self._times = times |
|||
self._poses = poses |
|||
else: |
|||
self.single_step = False |
|||
assert np.all(times[1:] >= times[:-1]) |
|||
self.pose_interp = si.interp1d(times, poses, axis=0, assume_sorted=True) |
|||
|
|||
@property |
|||
def times(self) -> np.ndarray: |
|||
if self.single_step: |
|||
return self._times |
|||
else: |
|||
return self.pose_interp.x |
|||
|
|||
@property |
|||
def poses(self) -> np.ndarray: |
|||
if self.single_step: |
|||
return self._poses |
|||
else: |
|||
return self.pose_interp.y |
|||
|
|||
def trim(self, start_t: float, end_t: float) -> "PoseTrajectoryInterpolator": |
|||
assert start_t <= end_t |
|||
times = self.times |
|||
should_keep = (start_t < times) & (times < end_t) |
|||
keep_times = times[should_keep] |
|||
all_times = np.concatenate([[start_t], keep_times, [end_t]]) |
|||
# remove duplicates, Slerp requires strictly increasing x |
|||
all_times = np.unique(all_times) |
|||
# interpolate |
|||
all_poses = self(all_times) |
|||
return PoseTrajectoryInterpolator(times=all_times, poses=all_poses) |
|||
|
|||
def schedule_waypoint( |
|||
self, |
|||
pose, |
|||
time, |
|||
max_change_rate=np.inf, |
|||
interpolation_garbage_collection_time=None, |
|||
last_waypoint_time=None, |
|||
) -> "PoseTrajectoryInterpolator": |
|||
if not isinstance(max_change_rate, np.ndarray): |
|||
max_change_rate = np.array([max_change_rate] * self.num_joint) |
|||
|
|||
assert len(max_change_rate) == self.num_joint |
|||
assert np.max(max_change_rate) > 0 |
|||
|
|||
if last_waypoint_time is not None: |
|||
assert interpolation_garbage_collection_time is not None |
|||
|
|||
# trim current interpolator to between interpolation_garbage_collection_time and last_waypoint_time |
|||
start_time = self.times[0] |
|||
end_time = self.times[-1] |
|||
assert start_time <= end_time |
|||
if interpolation_garbage_collection_time is not None: |
|||
if time <= interpolation_garbage_collection_time: |
|||
# if insert time is earlier than current time |
|||
# no effect should be done to the interpolator |
|||
return self |
|||
# now, interpolation_garbage_collection_time < time |
|||
start_time = max(interpolation_garbage_collection_time, start_time) |
|||
|
|||
if last_waypoint_time is not None: |
|||
# if last_waypoint_time is earlier than start_time |
|||
# use start_time |
|||
if time <= last_waypoint_time: |
|||
end_time = interpolation_garbage_collection_time |
|||
else: |
|||
end_time = max(last_waypoint_time, interpolation_garbage_collection_time) |
|||
else: |
|||
end_time = interpolation_garbage_collection_time |
|||
|
|||
end_time = min(end_time, time) |
|||
start_time = min(start_time, end_time) |
|||
# end time should be the latest of all times except time |
|||
# after this we can assume order (proven by zhenjia, due to the 2 min operations) |
|||
|
|||
# Constraints: |
|||
# start_time <= end_time <= time (proven by zhenjia) |
|||
# interpolation_garbage_collection_time <= start_time (proven by zhenjia) |
|||
# interpolation_garbage_collection_time <= time (proven by zhenjia) |
|||
|
|||
# time can't change |
|||
# last_waypoint_time can't change |
|||
# interpolation_garbage_collection_time can't change |
|||
assert start_time <= end_time |
|||
assert end_time <= time |
|||
if last_waypoint_time is not None: |
|||
if time <= last_waypoint_time: |
|||
assert end_time == interpolation_garbage_collection_time |
|||
else: |
|||
assert end_time == max(last_waypoint_time, interpolation_garbage_collection_time) |
|||
|
|||
if interpolation_garbage_collection_time is not None: |
|||
assert interpolation_garbage_collection_time <= start_time |
|||
assert interpolation_garbage_collection_time <= time |
|||
trimmed_interp = self.trim(start_time, end_time) |
|||
# after this, all waypoints in trimmed_interp is within start_time and end_time |
|||
# and is earlier than time |
|||
|
|||
# determine speed |
|||
duration = time - end_time |
|||
end_pose = trimmed_interp(end_time) |
|||
pose_min_duration = np.max(np.abs(end_pose - pose) / max_change_rate) |
|||
duration = max(duration, pose_min_duration) |
|||
assert duration >= 0 |
|||
last_waypoint_time = end_time + duration |
|||
|
|||
# insert new pose |
|||
times = np.append(trimmed_interp.times, [last_waypoint_time], axis=0) |
|||
poses = np.append(trimmed_interp.poses, [pose], axis=0) |
|||
|
|||
# create new interpolator |
|||
final_interp = PoseTrajectoryInterpolator(times, poses) |
|||
return final_interp |
|||
|
|||
def __call__(self, t: Union[numbers.Number, np.ndarray]) -> np.ndarray: |
|||
is_single = False |
|||
if isinstance(t, numbers.Number): |
|||
is_single = True |
|||
t = np.array([t]) |
|||
|
|||
pose = np.zeros((len(t), self.num_joint)) |
|||
if self.single_step: |
|||
pose[:] = self._poses[0] |
|||
else: |
|||
start_time = self.times[0] |
|||
end_time = self.times[-1] |
|||
t = np.clip(t, start_time, end_time) |
|||
pose = self.pose_interp(t) |
|||
|
|||
if is_single: |
|||
pose = pose[0] |
|||
return pose |
|||
@ -0,0 +1,87 @@ |
|||
from typing import Any, Dict, Optional |
|||
|
|||
import numpy as np |
|||
|
|||
from decoupled_wbc.control.base.policy import Policy |
|||
|
|||
|
|||
class KeyboardNavigationPolicy(Policy): |
|||
def __init__( |
|||
self, |
|||
max_linear_velocity: float = 0.5, |
|||
max_angular_velocity: float = 0.5, |
|||
verbose: bool = True, |
|||
**kwargs, |
|||
): |
|||
""" |
|||
Initialize the navigation policy. |
|||
|
|||
Args: |
|||
max_linear_velocity: Maximum linear velocity in m/s (for x and y components) |
|||
max_angular_velocity: Maximum angular velocity in rad/s (for yaw component) |
|||
**kwargs: Additional arguments passed to the base Policy class |
|||
""" |
|||
super().__init__(**kwargs) |
|||
self.max_linear_velocity = max_linear_velocity |
|||
self.max_angular_velocity = max_angular_velocity |
|||
self.verbose = verbose |
|||
|
|||
# Initialize velocity commands |
|||
self.lin_vel_command = np.zeros(2, dtype=np.float32) # [vx, vy] |
|||
self.ang_vel_command = np.zeros(1, dtype=np.float32) # [wz] |
|||
|
|||
def get_action(self, time: Optional[float] = None) -> Dict[str, Any]: |
|||
""" |
|||
Get the action to execute based on current state. |
|||
|
|||
Args: |
|||
time: Current time (optional) |
|||
|
|||
Returns: |
|||
Dict containing the action to execute with: |
|||
- navigate_cmd: np.array([vx, vy, wz]) where: |
|||
- vx: linear velocity in x direction (m/s) |
|||
- vy: linear velocity in y direction (m/s) |
|||
- wz: angular velocity around z axis (rad/s) |
|||
""" |
|||
# Combine linear and angular velocities into a single command |
|||
# Ensure velocities are within limits |
|||
vx = np.clip(self.lin_vel_command[0], -self.max_linear_velocity, self.max_linear_velocity) |
|||
vy = np.clip(self.lin_vel_command[1], -self.max_linear_velocity, self.max_linear_velocity) |
|||
wz = np.clip(self.ang_vel_command[0], -self.max_angular_velocity, self.max_angular_velocity) |
|||
|
|||
navigate_cmd = np.array([vx, vy, wz], dtype=np.float32) |
|||
|
|||
action = {"navigate_cmd": navigate_cmd} |
|||
return action |
|||
|
|||
def handle_keyboard_button(self, keycode: str): |
|||
""" |
|||
Handle keyboard inputs for navigation control. |
|||
|
|||
Args: |
|||
keycode: The key that was pressed |
|||
""" |
|||
if keycode == "w": |
|||
self.lin_vel_command[0] += 0.1 # Increase forward velocity |
|||
elif keycode == "s": |
|||
self.lin_vel_command[0] -= 0.1 # Increase backward velocity |
|||
elif keycode == "a": |
|||
self.lin_vel_command[1] += 0.1 # Increase left velocity |
|||
elif keycode == "d": |
|||
self.lin_vel_command[1] -= 0.1 # Increase right velocity |
|||
elif keycode == "q": |
|||
self.ang_vel_command[0] += 0.1 # Increase counter-clockwise rotation |
|||
elif keycode == "e": |
|||
self.ang_vel_command[0] -= 0.1 # Increase clockwise rotation |
|||
elif keycode == "z": |
|||
# Reset all velocities |
|||
self.lin_vel_command[:] = 0.0 |
|||
self.ang_vel_command[:] = 0.0 |
|||
if self.verbose: |
|||
print("Navigation policy: Reset all velocity commands to zero") |
|||
|
|||
# Print current velocities after any keyboard input |
|||
if self.verbose: |
|||
print(f"Nav lin vel: ({self.lin_vel_command[0]:.2f}, {self.lin_vel_command[1]:.2f})") |
|||
print(f"Nav ang vel: {self.ang_vel_command[0]:.2f}") |
|||
@ -0,0 +1,111 @@ |
|||
import time |
|||
|
|||
import pandas as pd |
|||
|
|||
from decoupled_wbc.control.base.policy import Policy |
|||
from decoupled_wbc.control.main.constants import ( |
|||
DEFAULT_BASE_HEIGHT, |
|||
DEFAULT_NAV_CMD, |
|||
DEFAULT_WRIST_POSE, |
|||
) |
|||
from decoupled_wbc.control.robot_model.robot_model import RobotModel |
|||
from decoupled_wbc.data.viz.rerun_viz import RerunViz |
|||
|
|||
|
|||
class LerobotReplayPolicy(Policy): |
|||
"""Replay policy for Lerobot dataset, so we can replay the dataset |
|||
and just use the action from the dataset. |
|||
|
|||
Args: |
|||
parquet_path: Path to the parquet file containing the dataset. |
|||
""" |
|||
|
|||
is_active = True # by default, the replay policy is active |
|||
|
|||
def __init__(self, robot_model: RobotModel, parquet_path: str, use_viz: bool = False): |
|||
# self.dataset = LerobotDataset(dataset_path) |
|||
self.parquet_path = parquet_path |
|||
self._ctr = 0 |
|||
# read the parquet file |
|||
self.df = pd.read_parquet(self.parquet_path) |
|||
self._max_ctr = len(self.df) |
|||
# get the action from the dataframe |
|||
self.action = self.df.iloc[self._ctr]["action"] |
|||
self.use_viz = use_viz |
|||
if self.use_viz: |
|||
self.viz = RerunViz( |
|||
image_keys=["egoview_image"], |
|||
tensor_keys=[ |
|||
"left_arm_qpos", |
|||
"left_hand_qpos", |
|||
"right_arm_qpos", |
|||
"right_hand_qpos", |
|||
], |
|||
window_size=5.0, |
|||
) |
|||
self.robot_model = robot_model |
|||
self.upper_body_joint_indices = self.robot_model.get_joint_group_indices("upper_body") |
|||
|
|||
def get_action(self) -> dict[str, any]: |
|||
# get the action from the dataframe |
|||
action = self.df.iloc[self._ctr]["action"] |
|||
wrist_pose = self.df.iloc[self._ctr]["action.eef"] |
|||
navigate_cmd = self.df.iloc[self._ctr].get("teleop.navigate_command", DEFAULT_NAV_CMD) |
|||
base_height_cmd = self.df.iloc[self._ctr].get( |
|||
"teleop.base_height_command", DEFAULT_BASE_HEIGHT |
|||
) |
|||
|
|||
self._ctr += 1 |
|||
if self._ctr >= self._max_ctr: |
|||
self._ctr = 0 |
|||
# print(f"Replay {self._ctr} / {self._max_ctr}") |
|||
if self.use_viz: |
|||
self.viz.plot_tensors( |
|||
{ |
|||
"left_arm_qpos": action[self.robot_model.get_joint_group_indices("left_arm")] |
|||
+ 15, |
|||
"left_hand_qpos": action[self.robot_model.get_joint_group_indices("left_hand")] |
|||
+ 15, |
|||
"right_arm_qpos": action[self.robot_model.get_joint_group_indices("right_arm")] |
|||
+ 15, |
|||
"right_hand_qpos": action[ |
|||
self.robot_model.get_joint_group_indices("right_hand") |
|||
] |
|||
+ 15, |
|||
}, |
|||
time.monotonic(), |
|||
) |
|||
|
|||
return { |
|||
"target_upper_body_pose": action[self.upper_body_joint_indices], |
|||
"wrist_pose": wrist_pose, |
|||
"navigate_cmd": navigate_cmd, |
|||
"base_height_cmd": base_height_cmd, |
|||
"timestamp": time.time(), |
|||
} |
|||
|
|||
def action_to_cmd(self, action: dict[str, any]) -> dict[str, any]: |
|||
action["target_upper_body_pose"] = action["q"][ |
|||
self.robot_model.get_joint_group_indices("upper_body") |
|||
] |
|||
del action["q"] |
|||
return action |
|||
|
|||
def set_observation(self, observation: dict[str, any]): |
|||
pass |
|||
|
|||
def get_observation(self) -> dict[str, any]: |
|||
return { |
|||
"wrist_pose": self.df.iloc[self._ctr - 1].get( |
|||
"observation.eef_state", DEFAULT_WRIST_POSE |
|||
), |
|||
"timestamp": time.time(), |
|||
} |
|||
|
|||
|
|||
if __name__ == "__main__": |
|||
policy = LerobotReplayPolicy( |
|||
parquet_path="outputs/g1-open-hands-may7/data/chunk-000/episode_000000.parquet" |
|||
) |
|||
action = policy.get_action() |
|||
print(action) |
|||
@ -0,0 +1,207 @@ |
|||
from contextlib import contextmanager |
|||
import time |
|||
from typing import Optional |
|||
|
|||
import numpy as np |
|||
from scipy.spatial.transform import Rotation as R |
|||
|
|||
from decoupled_wbc.control.base.policy import Policy |
|||
from decoupled_wbc.control.robot_model import RobotModel |
|||
from decoupled_wbc.control.teleop.teleop_retargeting_ik import TeleopRetargetingIK |
|||
from decoupled_wbc.control.teleop.teleop_streamer import TeleopStreamer |
|||
|
|||
|
|||
class TeleopPolicy(Policy): |
|||
""" |
|||
Robot-agnostic teleop policy. |
|||
Clean separation: IK processing vs command passing. |
|||
All robot-specific properties are abstracted through robot_model and hand_ik_solvers. |
|||
""" |
|||
|
|||
def __init__( |
|||
self, |
|||
body_control_device: str, |
|||
hand_control_device: str, |
|||
robot_model: RobotModel, |
|||
retargeting_ik: TeleopRetargetingIK, |
|||
body_streamer_ip: str = "192.168.?.?", |
|||
body_streamer_keyword: str = "shoulder", |
|||
enable_real_device: bool = True, |
|||
replay_data_path: Optional[str] = None, |
|||
replay_speed: float = 1.0, |
|||
wait_for_activation: int = 5, |
|||
activate_keyboard_listener: bool = True, |
|||
): |
|||
if activate_keyboard_listener: |
|||
from decoupled_wbc.control.utils.keyboard_dispatcher import KeyboardListenerSubscriber |
|||
|
|||
self.keyboard_listener = KeyboardListenerSubscriber() |
|||
else: |
|||
self.keyboard_listener = None |
|||
|
|||
self.wait_for_activation = wait_for_activation |
|||
|
|||
self.teleop_streamer = TeleopStreamer( |
|||
robot_model=robot_model, |
|||
body_control_device=body_control_device, |
|||
hand_control_device=hand_control_device, |
|||
enable_real_device=enable_real_device, |
|||
body_streamer_ip=body_streamer_ip, |
|||
body_streamer_keyword=body_streamer_keyword, |
|||
replay_data_path=replay_data_path, |
|||
replay_speed=replay_speed, |
|||
) |
|||
self.robot_model = robot_model |
|||
self.retargeting_ik = retargeting_ik |
|||
self.is_active = False |
|||
|
|||
self.latest_left_wrist_data = np.eye(4) |
|||
self.latest_right_wrist_data = np.eye(4) |
|||
self.latest_left_fingers_data = {"position": np.zeros((25, 4, 4))} |
|||
self.latest_right_fingers_data = {"position": np.zeros((25, 4, 4))} |
|||
|
|||
def set_goal(self, goal: dict[str, any]): |
|||
# The current teleop policy doesn't take higher level commands yet. |
|||
pass |
|||
|
|||
def get_action(self) -> dict[str, any]: |
|||
# Get structured data |
|||
streamer_output = self.teleop_streamer.get_streamer_data() |
|||
|
|||
# Handle activation using teleop_data commands |
|||
self.check_activation( |
|||
streamer_output.teleop_data, wait_for_activation=self.wait_for_activation |
|||
) |
|||
|
|||
action = {} |
|||
|
|||
# Process streamer data if active |
|||
if self.is_active and streamer_output.ik_data: |
|||
body_data = streamer_output.ik_data["body_data"] |
|||
left_hand_data = streamer_output.ik_data["left_hand_data"] |
|||
right_hand_data = streamer_output.ik_data["right_hand_data"] |
|||
|
|||
left_wrist_name = self.robot_model.supplemental_info.hand_frame_names["left"] |
|||
right_wrist_name = self.robot_model.supplemental_info.hand_frame_names["right"] |
|||
self.latest_left_wrist_data = body_data[left_wrist_name] |
|||
self.latest_right_wrist_data = body_data[right_wrist_name] |
|||
self.latest_left_fingers_data = left_hand_data |
|||
self.latest_right_fingers_data = right_hand_data |
|||
|
|||
# TODO: This stores the same data again |
|||
ik_data = { |
|||
"body_data": body_data, |
|||
"left_hand_data": left_hand_data, |
|||
"right_hand_data": right_hand_data, |
|||
} |
|||
action["ik_data"] = ik_data |
|||
|
|||
# Wrist poses (pos and quat) |
|||
# TODO: This stores the same wrist poses in two different formats |
|||
left_wrist_matrix = self.latest_left_wrist_data |
|||
right_wrist_matrix = self.latest_right_wrist_data |
|||
left_wrist_pose = np.concatenate( |
|||
[ |
|||
left_wrist_matrix[:3, 3], |
|||
R.from_matrix(left_wrist_matrix[:3, :3]).as_quat(scalar_first=True), |
|||
] |
|||
) |
|||
right_wrist_pose = np.concatenate( |
|||
[ |
|||
right_wrist_matrix[:3, 3], |
|||
R.from_matrix(right_wrist_matrix[:3, :3]).as_quat(scalar_first=True), |
|||
] |
|||
) |
|||
|
|||
# Combine IK results with control commands (no teleop_data commands) |
|||
action.update( |
|||
{ |
|||
"left_wrist": self.latest_left_wrist_data, |
|||
"right_wrist": self.latest_right_wrist_data, |
|||
"left_fingers": self.latest_left_fingers_data, |
|||
"right_fingers": self.latest_right_fingers_data, |
|||
"wrist_pose": np.concatenate([left_wrist_pose, right_wrist_pose]), |
|||
**streamer_output.control_data, # Only control & data collection commands pass through |
|||
**streamer_output.data_collection_data, |
|||
} |
|||
) |
|||
|
|||
# Run retargeting IK |
|||
if "ik_data" in action: |
|||
self.retargeting_ik.set_goal(action["ik_data"]) |
|||
action["target_upper_body_pose"] = self.retargeting_ik.get_action() |
|||
|
|||
return action |
|||
|
|||
def close(self) -> bool: |
|||
self.teleop_streamer.stop_streaming() |
|||
return True |
|||
|
|||
def check_activation(self, teleop_data: dict, wait_for_activation: int = 5): |
|||
"""Activation logic only looks at teleop data commands""" |
|||
key = self.keyboard_listener.read_msg() if self.keyboard_listener else "" |
|||
toggle_activation_by_keyboard = key == "l" |
|||
reset_teleop_policy_by_keyboard = key == "k" |
|||
toggle_activation_by_teleop = teleop_data.get("toggle_activation", False) |
|||
|
|||
if reset_teleop_policy_by_keyboard: |
|||
print("Resetting teleop policy") |
|||
self.reset() |
|||
|
|||
if toggle_activation_by_keyboard or toggle_activation_by_teleop: |
|||
self.is_active = not self.is_active |
|||
if self.is_active: |
|||
print("Starting teleop policy") |
|||
|
|||
if wait_for_activation > 0 and toggle_activation_by_keyboard: |
|||
print(f"Sleeping for {wait_for_activation} seconds before starting teleop...") |
|||
for i in range(wait_for_activation, 0, -1): |
|||
print(f"Starting in {i}...") |
|||
time.sleep(1) |
|||
|
|||
# dda: calibration logic should use current IK data |
|||
self.teleop_streamer.calibrate() |
|||
print("Teleop policy calibrated") |
|||
else: |
|||
print("Stopping teleop policy") |
|||
|
|||
@contextmanager |
|||
def activate(self): |
|||
try: |
|||
yield self |
|||
finally: |
|||
self.close() |
|||
|
|||
def handle_keyboard_button(self, keycode): |
|||
""" |
|||
Handle keyboard input with proper state toggle. |
|||
""" |
|||
if keycode == "l": |
|||
# Toggle start state |
|||
self.is_active = not self.is_active |
|||
# Reset initialization when stopping |
|||
if not self.is_active: |
|||
self._initialized = False |
|||
if keycode == "k": |
|||
print("Resetting teleop policy") |
|||
self.reset() |
|||
|
|||
def activate_policy(self, wait_for_activation: int = 5): |
|||
"""activate the teleop policy""" |
|||
self.is_active = False |
|||
self.check_activation( |
|||
teleop_data={"toggle_activation": True}, wait_for_activation=wait_for_activation |
|||
) |
|||
|
|||
def reset(self, wait_for_activation: int = 5, auto_activate: bool = False): |
|||
"""Reset the teleop policy to the initial state, and re-activate it.""" |
|||
self.teleop_streamer.reset() |
|||
self.retargeting_ik.reset() |
|||
self.is_active = False |
|||
self.latest_left_wrist_data = np.eye(4) |
|||
self.latest_right_wrist_data = np.eye(4) |
|||
self.latest_left_fingers_data = {"position": np.zeros((25, 4, 4))} |
|||
self.latest_right_fingers_data = {"position": np.zeros((25, 4, 4))} |
|||
|
|||
if auto_activate: |
|||
self.activate_policy(wait_for_activation) |
|||
@ -0,0 +1,65 @@ |
|||
import os |
|||
from pathlib import Path |
|||
import time |
|||
|
|||
import numpy as np |
|||
|
|||
import decoupled_wbc |
|||
from decoupled_wbc.control.main.constants import DEFAULT_BASE_HEIGHT, DEFAULT_NAV_CMD |
|||
from decoupled_wbc.control.policy.g1_gear_wbc_policy import G1GearWbcPolicy |
|||
from decoupled_wbc.control.policy.identity_policy import IdentityPolicy |
|||
from decoupled_wbc.control.policy.interpolation_policy import InterpolationPolicy |
|||
|
|||
from .g1_decoupled_whole_body_policy import G1DecoupledWholeBodyPolicy |
|||
|
|||
WBC_VERSIONS = ["gear_wbc"] |
|||
|
|||
|
|||
def get_wbc_policy( |
|||
robot_type, |
|||
robot_model, |
|||
wbc_config, |
|||
init_time=time.monotonic(), |
|||
): |
|||
current_upper_body_pose = robot_model.get_initial_upper_body_pose() |
|||
|
|||
if robot_type == "g1": |
|||
upper_body_policy_type = wbc_config.get("upper_body_policy_type", "interpolation") |
|||
if upper_body_policy_type == "identity": |
|||
upper_body_policy = IdentityPolicy() |
|||
else: |
|||
upper_body_policy = InterpolationPolicy( |
|||
init_time=init_time, |
|||
init_values={ |
|||
"target_upper_body_pose": current_upper_body_pose, |
|||
"base_height_command": np.array([DEFAULT_BASE_HEIGHT]), |
|||
"navigate_cmd": np.array([DEFAULT_NAV_CMD]), |
|||
}, |
|||
max_change_rate=wbc_config["upper_body_max_joint_speed"], |
|||
) |
|||
|
|||
lower_body_policy_type = wbc_config.get("VERSION", "gear_wbc") |
|||
if lower_body_policy_type not in ["gear_wbc"]: |
|||
raise ValueError( |
|||
f"Invalid lower body policy version: {lower_body_policy_type}. " |
|||
f"Only 'gear_wbc' is supported." |
|||
) |
|||
|
|||
# Get the base path to decoupled_wbc and convert to Path object |
|||
package_path = Path(os.path.dirname(decoupled_wbc.__file__)) |
|||
gear_wbc_config = str(package_path / ".." / wbc_config["GEAR_WBC_CONFIG"]) |
|||
if lower_body_policy_type == "gear_wbc": |
|||
lower_body_policy = G1GearWbcPolicy( |
|||
robot_model=robot_model, |
|||
config=gear_wbc_config, |
|||
model_path=wbc_config["model_path"], |
|||
) |
|||
|
|||
wbc_policy = G1DecoupledWholeBodyPolicy( |
|||
robot_model=robot_model, |
|||
upper_body_policy=upper_body_policy, |
|||
lower_body_policy=lower_body_policy, |
|||
) |
|||
else: |
|||
raise ValueError(f"Invalid robot type: {robot_type}") |
|||
return wbc_policy |
|||
@ -0,0 +1,62 @@ |
|||
import os |
|||
from pathlib import Path |
|||
from typing import Literal |
|||
|
|||
from decoupled_wbc.control.robot_model.robot_model import RobotModel |
|||
from decoupled_wbc.control.robot_model.supplemental_info.g1.g1_supplemental_info import ( |
|||
ElbowPose, |
|||
G1SupplementalInfo, |
|||
WaistLocation, |
|||
) |
|||
|
|||
|
|||
def instantiate_g1_robot_model( |
|||
waist_location: Literal["lower_body", "upper_body", "lower_and_upper_body"] = "lower_body", |
|||
high_elbow_pose: bool = False, |
|||
): |
|||
""" |
|||
Instantiate a G1 robot model with configurable waist location and pose. |
|||
|
|||
Args: |
|||
waist_location: Whether to put waist in "lower_body" (default G1 behavior), |
|||
"upper_body" (waist controlled with arms/manipulation via IK), |
|||
or "lower_and_upper_body" (waist reference from arms/manipulation |
|||
via IK then passed to lower body policy) |
|||
high_elbow_pose: Whether to use high elbow pose configuration for default joint positions |
|||
|
|||
Returns: |
|||
RobotModel: Configured G1 robot model |
|||
""" |
|||
project_root = Path(__file__).resolve().parent.parent.parent.parent.parent |
|||
robot_model_config = { |
|||
"asset_path": os.path.join(project_root, "decoupled_wbc/control/robot_model/model_data/g1"), |
|||
"urdf_path": os.path.join( |
|||
project_root, "decoupled_wbc/control/robot_model/model_data/g1/g1_29dof_with_hand.urdf" |
|||
), |
|||
} |
|||
assert waist_location in [ |
|||
"lower_body", |
|||
"upper_body", |
|||
"lower_and_upper_body", |
|||
], f"Invalid waist_location: {waist_location}. Must be 'lower_body' or 'upper_body' or 'lower_and_upper_body'" |
|||
|
|||
# Map string values to enums |
|||
waist_location_enum = { |
|||
"lower_body": WaistLocation.LOWER_BODY, |
|||
"upper_body": WaistLocation.UPPER_BODY, |
|||
"lower_and_upper_body": WaistLocation.LOWER_AND_UPPER_BODY, |
|||
}[waist_location] |
|||
|
|||
elbow_pose_enum = ElbowPose.HIGH if high_elbow_pose else ElbowPose.LOW |
|||
|
|||
# Create single configurable supplemental info instance |
|||
robot_model_supplemental_info = G1SupplementalInfo( |
|||
waist_location=waist_location_enum, elbow_pose=elbow_pose_enum |
|||
) |
|||
|
|||
robot_model = RobotModel( |
|||
robot_model_config["urdf_path"], |
|||
robot_model_config["asset_path"], |
|||
supplemental_info=robot_model_supplemental_info, |
|||
) |
|||
return robot_model |
|||
Some files were not shown because too many files changed in this diff
Write
Preview
Loading…
Cancel
Save
Reference in new issue