You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
 
 
 
 
 
 

201 lines
5.9 KiB

import base64
import signal
import threading
from typing import Optional
import msgpack
import msgpack_numpy as mnp
import rclpy
from rclpy.executors import SingleThreadedExecutor
from rclpy.node import Node
from sensor_msgs.msg import Image
from std_msgs.msg import ByteMultiArray
from std_srvs.srv import Trigger
# Module-level guard so the handlers are installed at most once per process.
_signal_registered = False


def register_keyboard_interrupt_handler():
    """
    Register a signal handler for SIGINT (Ctrl+C) and SIGTERM that raises KeyboardInterrupt.

    This ensures consistent exception handling across different termination signals.
    The handlers are installed at most once per process; later calls are no-ops.

    ``signal.signal`` may only be called from the main thread. A call made from
    any other thread therefore returns without installing anything and without
    marking the handlers as registered, so a later call from the main thread
    can still install them.
    """
    global _signal_registered
    if _signal_registered:
        return
    if threading.current_thread() is not threading.main_thread():
        # signal.signal() would raise ValueError here; skip instead of crashing.
        return

    def signal_handler(signum, frame):
        raise KeyboardInterrupt

    signal.signal(signal.SIGINT, signal_handler)
    signal.signal(signal.SIGTERM, signal_handler)
    _signal_registered = True
class ROSManager:
    """
    Manages the ROS2 node and executor.

    If rclpy is not yet initialized, constructing a ROSManager initializes it,
    creates a node named ``node_name`` and spins that node in a daemon thread.
    If rclpy is already initialized, the first node registered with the global
    executor is reused; when the executor has no nodes, a new node is created.

    Attributes:
        node: The node exposed by this manager.
        thread: The daemon thread spinning ``node`` when this instance started
            one, otherwise ``None``.

    Usage example:
    ```python
    import time

    def main():
        ros_manager = ROSManager()
        node = ros_manager.node
        try:
            while ros_manager.ok():
                time.sleep(0.1)
        except ros_manager.exceptions() as e:
            print(f"ROSManager interrupted by user: {e}")
        finally:
            ros_manager.shutdown()
    ```
    """

    def __init__(self, node_name: str = "ros_manager"):
        # Always defined so callers can inspect it regardless of the branch taken.
        self.thread: Optional[threading.Thread] = None

        if not rclpy.ok():
            rclpy.init()
            self.node = rclpy.create_node(node_name)
            self.thread = threading.Thread(target=rclpy.spin, args=(self.node,), daemon=True)
            self.thread.start()
        else:
            # Take a single snapshot: the node list can change between calls
            # while another thread is spinning the global executor.
            nodes = rclpy.get_global_executor().get_nodes()
            if nodes:
                self.node = nodes[0]
            else:
                # NOTE(review): this node is not added to any executor here;
                # confirm that whoever initialized rclpy spins it.
                self.node = rclpy.create_node(node_name)

        register_keyboard_interrupt_handler()

    @staticmethod
    def ok() -> bool:
        """Return ``rclpy.ok()``."""
        return rclpy.ok()

    @staticmethod
    def shutdown() -> None:
        """Shut rclpy down if it is still running; otherwise do nothing."""
        if rclpy.ok():
            rclpy.shutdown()

    @staticmethod
    def exceptions() -> tuple:
        """Return the exception types to catch around a ROSManager-driven loop."""
        return (rclpy.exceptions.ROSInterruptException, KeyboardInterrupt)
class ROSMsgPublisher:
    """
    Publishes any serializable dict to a topic.

    The dict is packed with msgpack (numpy values go through ``mnp.encode``)
    and sent as a ``ByteMultiArray``.
    """

    def __init__(self, topic_name: str):
        self.node = ROSManager().node
        self.publisher = self.node.create_publisher(ByteMultiArray, topic_name, 1)

    def publish(self, msg: dict):
        """Pack ``msg`` with msgpack and publish it on the topic."""
        packed = msgpack.packb(msg, default=mnp.encode)
        # ByteMultiArray.data is filled with one single-byte ``bytes`` object
        # per payload byte.
        single_bytes = tuple(bytes([value]) for value in packed)
        ros_msg = ByteMultiArray()
        ros_msg.data = single_bytes
        self.publisher.publish(ros_msg)
class ROSMsgSubscriber:
    """
    Subscribes to any topics published by a ROSMsgPublisher.

    Only the most recently received message is kept; reading it consumes it.
    """

    def __init__(self, topic_name: str):
        self.node = ROSManager().node
        self._msg = None
        self.subscription = self.node.create_subscription(
            ByteMultiArray, topic_name, self._callback, 1
        )

    def _callback(self, msg: ByteMultiArray):
        # Keep only the latest message, replacing any unread one.
        self._msg = msg

    def get_msg(self) -> Optional[dict]:
        """
        Return the latest unread message decoded with msgpack, or None if
        nothing has arrived since the previous call.
        """
        pending = self._msg
        if pending is not None:
            # Mark the message as consumed before decoding it.
            self._msg = None
            raw = bytes(value for chunk in pending.data for value in chunk)
            return msgpack.unpackb(raw, object_hook=mnp.decode)
        return None
class ROSImgMsgSubscriber:
    """
    Subscribes to an `Image` topic and returns the image as a numpy array and timestamp.
    """

    def __init__(self, topic_name: str):
        self.node = ROSManager().node
        self._msg = None
        self.subscription = self.node.create_subscription(Image, topic_name, self._callback, 1)

        # Imported locally, only when a subscriber is actually constructed.
        from decoupled_wbc.control.utils.cv_bridge import CvBridge

        self.bridge = CvBridge()

    def _callback(self, msg: Image):
        # Keep only the most recent frame.
        self._msg = msg

    def get_image(self) -> Optional[dict]:
        """
        Returns the image as a numpy array and the timestamp.

        The result is a dict with keys ``"image"`` and ``"timestamp"`` (seconds
        as a float, built from the header stamp), or None when no frame has
        been received yet. The stored frame is not cleared, so repeated calls
        return the same frame until a new one arrives.
        """
        latest = self._msg
        if latest is None:
            return None

        frame = self.bridge.imgmsg_to_cv2(latest)
        stamp = latest.header.stamp
        return {"image": frame, "timestamp": stamp.sec + stamp.nanosec * 1e-9}
class ROSServiceServer:
    """
    Generic ROS2 Service server that stores and serves a config dict.

    The config is packed with msgpack and base64-encoded once at construction;
    every Trigger request is answered with that ASCII string as the message.
    """

    def __init__(self, service_name: str, config: dict):
        self.node = ROSManager().node
        encoded = base64.b64encode(msgpack.packb(config, default=mnp.encode))
        self.message = encoded.decode("ascii")
        self.server = self.node.create_service(Trigger, service_name, self._callback)

    def _callback(self, request, response):
        """Fill ``response`` with the stored payload, or the error text on failure."""
        try:
            response.success = True
            response.message = self.message
            print("Sending encoded message of length:", len(response.message))
        except Exception as err:
            # Report the failure to the client instead of raising in the callback.
            response.success = False
            response.message = str(err)
        return response
class ROSServiceClient(Node):
    """
    Node that fetches a config dict from a Trigger service.

    The service reply's ``message`` is decoded as base64 and then unpacked with
    msgpack, i.e. the encoding produced by ROSServiceServer in this file.
    """

    def __init__(self, service_name: str, node_name: str = "service_client"):
        super().__init__(node_name)
        self.cli = self.create_client(Trigger, service_name)
        # Block until the service is available, logging after each 1 s wait.
        while not self.cli.wait_for_service(timeout_sec=1.0):
            self.get_logger().info("service not available, waiting again...")
        self.req = Trigger.Request()

    def get_config(self, timeout_sec: float = 1.0):
        """
        Call the service and return the decoded config.

        Args:
            timeout_sec: Maximum time to wait for the response, in seconds.

        Returns:
            The object unpacked from the service's message.

        Raises:
            RuntimeError: If no response arrives within ``timeout_sec`` or the
                service reports ``success == False``.
        """
        future = self.cli.call_async(self.req)
        executor = SingleThreadedExecutor()
        try:
            executor.add_node(self)
            executor.spin_until_future_complete(future, timeout_sec=timeout_sec)
        finally:
            # Always release the temporary executor, even if spinning is interrupted.
            executor.remove_node(self)
            executor.shutdown()

        if not future.done():
            # Timed out: drop the pending request so it does not linger on the client.
            self.cli.remove_pending_request(future)
            raise RuntimeError(f"Service call failed: no response within {timeout_sec} s")

        result = future.result()
        if result is None:
            raise RuntimeError("Service call failed: no result returned")
        if not result.success:
            raise RuntimeError(f"Service call failed: {result.message}")

        decoded = base64.b64decode(result.message.encode("ascii"))
        return msgpack.unpackb(decoded, object_hook=mnp.decode)