You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
201 lines
5.8 KiB
201 lines
5.8 KiB
import base64
|
|
import signal
|
|
import threading
|
|
from typing import Optional
|
|
|
|
import msgpack
|
|
import msgpack_numpy as mnp
|
|
import rclpy
|
|
from rclpy.executors import SingleThreadedExecutor
|
|
from rclpy.node import Node
|
|
from sensor_msgs.msg import Image
|
|
from std_msgs.msg import ByteMultiArray
|
|
from std_srvs.srv import Trigger
|
|
|
|
# True once the SIGINT/SIGTERM handlers below have been installed in this process.
_signal_registered = False


def register_keyboard_interrupt_handler():
    """
    Register a signal handler for SIGINT (Ctrl+C) and SIGTERM that raises KeyboardInterrupt.

    This ensures consistent exception handling across different termination signals.
    The handlers are installed at most once per process. Python only permits
    installing signal handlers from the main thread, so a call made from any
    other thread is a no-op. Such a call neither raises nor marks the handlers
    as registered, which leaves a later main-thread call free to install them.
    """
    global _signal_registered
    if _signal_registered:
        return
    if threading.current_thread() is not threading.main_thread():
        # signal.signal() raises ValueError outside the main thread; helpers that
        # construct ROSManager from worker threads must not crash because of that.
        return

    def signal_handler(signum, frame):
        raise KeyboardInterrupt

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    _signal_registered = True
|
|
|
|
|
|
class ROSManager:
    """
    Manages the ROS2 node and executor.

    The first instance in a process initialises rclpy, creates a node, and
    spins it on rclpy's global executor in a daemon thread. Later instances
    reuse the first node registered with the global executor. As a result,
    every helper built on ROSManager shares that one node.

    Usage example:
    ```python
    def main():
        ros_manager = ROSManager()
        node = ros_manager.node

        try:
            while ros_manager.ok():
                time.sleep(0.1)
        except ros_manager.exceptions() as e:
            print(f"ROSManager interrupted by user: {e}")
        finally:
            ros_manager.shutdown()
    ```
    """

    def __init__(self, node_name: str = "ros_manager"):
        """
        Args:
            node_name: Name used when a new node has to be created. It is ignored
                when an existing node is reused from the global executor.
        """
        if not rclpy.ok():
            rclpy.init()
            self.node = rclpy.create_node(node_name)
            executor = rclpy.get_global_executor()
            # Register the node synchronously. rclpy.spin would only add it once
            # the thread below is running, so a ROSManager constructed right
            # after this one could otherwise see an empty executor and create a
            # second, never-spun node. rclpy.spin's own add_node is then a no-op.
            executor.add_node(self.node)
            self.thread = threading.Thread(target=rclpy.spin, args=(self.node,), daemon=True)
            self.thread.start()
        else:
            executor = rclpy.get_global_executor()
            nodes = executor.get_nodes()
            if nodes:
                self.node = nodes[0]
            else:
                self.node = rclpy.create_node(node_name)
                # Attach to the global executor, for two reasons. Whoever spins
                # it then also services this node's callbacks. And later
                # instances reuse this node instead of creating more of them.
                executor.add_node(self.node)

        register_keyboard_interrupt_handler()

    @staticmethod
    def ok():
        """Return whether the default rclpy context is still running."""
        return rclpy.ok()

    @staticmethod
    def shutdown():
        """Shut rclpy down if it is still running; safe to call repeatedly."""
        if rclpy.ok():
            rclpy.shutdown()

    @staticmethod
    def exceptions():
        """Exception types a main loop should catch to exit cleanly."""
        return (rclpy.exceptions.ROSInterruptException, KeyboardInterrupt)
|
|
|
|
|
|
class ROSMsgPublisher:
    """
    Publishes msgpack-serialised dicts on a ``ByteMultiArray`` topic.

    Values are packed with ``msgpack`` using ``msgpack_numpy``'s encoder, so any
    dict that combination can serialise is accepted. Use ``ROSMsgSubscriber`` on
    the same topic to decode.
    """

    def __init__(self, topic_name: str):
        # Reuse the process-wide node rather than creating one per publisher.
        self.node = ROSManager().node
        queue_depth = 1
        self.publisher = self.node.create_publisher(ByteMultiArray, topic_name, queue_depth)

    def publish(self, msg: dict):
        """
        Serialise ``msg`` and publish it as a single ``ByteMultiArray``.

        Args:
            msg: Dict to send; must be encodable by msgpack + msgpack_numpy.
        """
        packed = msgpack.packb(msg, default=mnp.encode)
        # ``data`` receives one length-1 ``bytes`` object per payload byte,
        # the layout the subscriber flattens back into a single buffer.
        chunks = tuple(packed[i : i + 1] for i in range(len(packed)))
        ros_msg = ByteMultiArray()
        ros_msg.data = chunks
        self.publisher.publish(ros_msg)
|
|
|
|
|
|
class ROSMsgSubscriber:
    """
    Subscribes to any topics published by a ROSMsgPublisher.

    Keeps only the most recent ``ByteMultiArray``. ``get_msg`` decodes it back
    into a dict and clears it, so each received message is returned at most once.
    """

    def __init__(self, topic_name: str):
        ros_manager = ROSManager()
        self.node = ros_manager.node
        self._msg = None
        # _callback runs on whichever thread spins the node (ROSManager's daemon
        # thread when it created the node), while get_msg runs on the caller's
        # thread. The lock makes "take latest and clear" atomic. It is created
        # before subscribing so the callback can never see it missing.
        self._lock = threading.Lock()
        self.subscription = self.node.create_subscription(
            ByteMultiArray, topic_name, self._callback, 1
        )

    def _callback(self, msg: ByteMultiArray):
        """Store the latest raw message, replacing any unread one."""
        with self._lock:
            self._msg = msg

    def get_msg(self) -> Optional[dict]:
        """
        Return the latest unread message decoded to a dict, or ``None``.

        The stored message is cleared atomically as it is taken. Without that,
        a message arriving between the read and the clear would be silently
        discarded.
        """
        with self._lock:
            msg, self._msg = self._msg, None
        if msg is None:
            return None
        # Each element of ``data`` is a ``bytes`` chunk; joining them restores
        # the packed msgpack buffer.
        return msgpack.unpackb(b"".join(msg.data), object_hook=mnp.decode)
|
|
|
|
|
|
class ROSImgMsgSubscriber:
    """
    Subscribes to an `Image` topic and returns the image as a numpy array and timestamp.
    """

    def __init__(self, topic_name: str):
        # Import and build the bridge before touching ROS. If the import fails,
        # no subscription is left dangling on the shared node pointing at this
        # half-constructed object.
        from gr00t_wbc.control.utils.cv_bridge import CvBridge

        self.bridge = CvBridge()

        ros_manager = ROSManager()
        self.node = ros_manager.node
        self._msg = None
        self.subscription = self.node.create_subscription(Image, topic_name, self._callback, 1)

    def _callback(self, msg: Image):
        """Keep only the most recent image message."""
        self._msg = msg

    def get_image(self) -> Optional[dict]:
        """
        Returns the image as a numpy array and the timestamp.

        Returns:
            ``None`` if no image has been received yet. Otherwise a dict with:
            ``"image"``, the frame converted by ``CvBridge.imgmsg_to_cv2``; and
            ``"timestamp"``, the header stamp as float seconds.
            The stored message is not cleared, so repeated calls return the same
            frame until a newer one arrives.
        """
        msg = self._msg  # single read: the callback may replace it concurrently
        if msg is None:
            return None
        return {
            "image": self.bridge.imgmsg_to_cv2(msg),
            "timestamp": msg.header.stamp.sec + msg.header.stamp.nanosec * 1e-9,
        }
|
|
|
|
|
|
class ROSServiceServer:
    """
    Serves a fixed config dict over a ``std_srvs/Trigger`` service.

    The dict is msgpack-packed (with msgpack_numpy) and base64-encoded once at
    construction. Every request receives that same ASCII string in
    ``response.message``.
    """

    def __init__(self, service_name: str, config: dict):
        self.node = ROSManager().node
        # Encode once up front; every request is answered with the same string.
        encoded = base64.b64encode(msgpack.packb(config, default=mnp.encode))
        self.message = encoded.decode("ascii")
        self.server = self.node.create_service(Trigger, service_name, self._callback)

    def _callback(self, request, response):
        """Fill ``response`` with the pre-encoded config; ``request`` is unused."""
        try:
            response.success, response.message = True, self.message
            print("Sending encoded message of length:", len(response.message))
        except Exception as err:
            # Report the failure to the caller instead of crashing the service.
            response.success, response.message = False, str(err)
        return response
|
|
|
|
|
|
class ROSServiceClient(Node):
    """
    Client node for a ``std_srvs/Trigger`` service that serves a config dict.

    The expected response carries a base64-encoded msgpack payload in
    ``message``, which ``get_config`` decodes.
    """

    def __init__(self, service_name: str, node_name: str = "service_client"):
        super().__init__(node_name)
        self.cli = self.create_client(Trigger, service_name)
        # Blocks until the service shows up; there is no overall timeout here.
        while not self.cli.wait_for_service(timeout_sec=1.0):
            self.get_logger().info("service not available, waiting again...")
        self.req = Trigger.Request()

    def get_config(self, timeout_sec: float = 1.0):
        """
        Call the service once and return the decoded config dict.

        Args:
            timeout_sec: Maximum time to spin while waiting for the response.

        Returns:
            The dict unpacked (msgpack + msgpack_numpy) from the base64 payload.

        Raises:
            RuntimeError: If no response arrives within ``timeout_sec``, or if
                the server reports ``success=False``.
        """
        future = self.cli.call_async(self.req)
        executor = SingleThreadedExecutor()
        executor.add_node(self)
        try:
            executor.spin_until_future_complete(future, timeout_sec=timeout_sec)
        finally:
            # Always detach and release the temporary executor, even if spinning raised.
            executor.remove_node(self)
            executor.shutdown()

        if not future.done():
            # spin_until_future_complete returns silently on timeout; without this
            # check future.result() is None and the caller gets an AttributeError.
            future.cancel()
            raise RuntimeError(f"Service call timed out after {timeout_sec} s")

        result = future.result()
        if not result.success:
            raise RuntimeError(f"Service call failed: {result.message}")
        decoded = base64.b64decode(result.message.encode("ascii"))
        return msgpack.unpackb(decoded, object_hook=mnp.decode)
|