diff --git a/demos/pycram_bullet_world_demo/demo.py b/demos/pycram_bullet_world_demo/demo.py index 16aba3828..12679e57d 100644 --- a/demos/pycram_bullet_world_demo/demo.py +++ b/demos/pycram_bullet_world_demo/demo.py @@ -2,7 +2,7 @@ from pycram.designators.action_designator import * from pycram.designators.location_designator import * from pycram.designators.object_designator import * -from pycram.datastructures.enums import ObjectType, WorldMode +from pycram.datastructures.enums import WorldMode from pycram.datastructures.pose import Pose from pycram.process_module import simulated_robot, with_simulated_robot from pycram.object_descriptors.urdf import ObjectDescription @@ -12,6 +12,7 @@ from pycrap import Robot, Apartment, Milk, Cereal, Spoon, Bowl import numpy as np + np.random.seed(420) extension = ObjectDescription.get_file_extension() diff --git a/src/pycram/datastructures/enums.py b/src/pycram/datastructures/enums.py index 3875ae910..0457d7f7e 100644 --- a/src/pycram/datastructures/enums.py +++ b/src/pycram/datastructures/enums.py @@ -282,3 +282,11 @@ def from_pycram_joint_type(cls, joint_type: JointType) -> 'MultiverseJointCMD': return MultiverseJointCMD.PRISMATIC_JOINT_CMD else: raise UnsupportedJointType(joint_type) + + +class FilterConfig(Enum): + """ + Declare existing filter methods. + Currently supported: Butterworth + """ + butterworth = 1 diff --git a/src/pycram/ros/filter.py b/src/pycram/ros/filter.py new file mode 100644 index 000000000..48097b356 --- /dev/null +++ b/src/pycram/ros/filter.py @@ -0,0 +1,42 @@ +from abc import abstractmethod + +from scipy.signal import butter, lfilter + + +class Filter: + """ + Abstract class to ensure that every supported filter needs to implement the filter method. + + :method filter: Abstract method to filter the given data. + """ + + @abstractmethod + def filter(self, data): + pass + + +class Butterworth(Filter): + """ + Implementation for a Butterworth filter. + + :param order: The order of the filter (default is 4). + :param cutoff: The cutoff frequency of the filter (default is 10). + :param fs: The sampling frequency of the data (default is 60). + """ + + def __init__(self, order=4, cutoff=10, fs=60): + self.order = order + self.cutoff = cutoff + self.fs = fs + + self.b, self.a = butter(self.order, cutoff / (0.5 * fs), btype='low') + + def filter(self, data: list): + """ + Filters the given data using a Butterworth filter. + + :param data: The data to be filtered. + + :return: The filtered data. + """ + return lfilter(self.b, self.a, data) diff --git a/src/pycram/ros_utils/force_torque_sensor.py b/src/pycram/ros_utils/force_torque_sensor.py index 3d98e79cd..e3af68b5b 100644 --- a/src/pycram/ros_utils/force_torque_sensor.py +++ b/src/pycram/ros_utils/force_torque_sensor.py @@ -2,19 +2,24 @@ import time import threading +import rospy from geometry_msgs.msg import WrenchStamped from std_msgs.msg import Header + +from ..datastructures.enums import FilterConfig from ..datastructures.world import World +from ..ros.filter import Butterworth from ..ros.data_types import Time from ..ros.publisher import create_publisher -class ForceTorqueSensor: +class ForceTorqueSensorSimulated: """ Simulated force-torque sensor for a joint with a given name. Reads simulated forces and torques at that joint from world and publishes geometry_msgs/Wrench messages to the given topic. """ + def __init__(self, joint_name, fts_topic="/pycram/fts", interval=0.1): """ The given joint_name has to be part of :py:attr:`~pycram.world.World.robot` otherwise a @@ -77,3 +82,147 @@ def _stop_publishing(self) -> None: """ self.kill_event.set() self.thread.join() + + +class ForceTorqueSensor: + """ + Monitor a force-torque sensor of a supported robot and save relevant data. + + Apply a specified filter and save this data as well. + Default filter is the low pass filter 'Butterworth' + + Can also calculate the derivative of (un-)filtered data + + :param robot_name: Name of the robot + :param filter_config: Desired filter (default: Butterworth) + :param filter_order: Order of the filter. Declares the number of elements that delay the sampling + :param custom_topic: Declare a custom topic if the default topics do not fit + """ + + filtered = 'filtered' + unfiltered = 'unfiltered' + + def __init__(self, robot_name, filter_config=FilterConfig.butterworth, filter_order=4, custom_topic=None, + debug=False): + self.robot_name = robot_name + self.filter_config = filter_config + self.filter = self._get_filter(order=filter_order) + self.debug = debug + + self.wrench_topic_name = custom_topic + self.force_torque_subscriber = None + self.init_data = True + + self.whole_data = None + self.prev_values = None + + self.order = filter_order + + self._setup() + + def _setup(self): + self._get_robot_parameters() + self.subscribe() + + def _get_robot_parameters(self): + if self.wrench_topic_name is not None: + return + + if self.robot_name == 'hsrb': + self.wrench_topic_name = '/hsrb/wrist_wrench/compensated' + + elif self.robot_name == 'iai_donbot': + self.wrench_topic_name = '/kms40_driver/wrench' + else: + rospy.logerr(f'{self.robot_name} is not supported') + + def _get_rospy_data(self, data_compensated: WrenchStamped): + if self.init_data: + self.init_data = False + self.prev_values = [data_compensated] * (self.order + 1) + self.whole_data = {self.unfiltered: [data_compensated], + self.filtered: [data_compensated]} + + filtered_data = self._filter_data(data_compensated) + + self.whole_data[self.unfiltered].append(data_compensated) + self.whole_data[self.filtered].append(filtered_data) + + self.prev_values.append(data_compensated) + self.prev_values.pop(0) + + if self.debug: + rospy.logdebug( + f'x: {data_compensated.wrench.force.x}, ' + f'y: {data_compensated.wrench.force.y}, ' + f'z: {data_compensated.wrench.force.z}') + + def _get_filter(self, order=4, cutoff=10, fs=60): + if self.filter_config == FilterConfig.butterworth: + return Butterworth(order=order, cutoff=cutoff, fs=fs) + + def _filter_data(self, current_wrench_data: WrenchStamped) -> WrenchStamped: + filtered_data = WrenchStamped() + for attr in ['x', 'y', 'z']: + force_values = [getattr(val.wrench.force, attr) for val in self.prev_values] + [ + getattr(current_wrench_data.wrench.force, attr)] + torque_values = [getattr(val.wrench.torque, attr) for val in self.prev_values] + [ + getattr(current_wrench_data.wrench.torque, attr)] + + filtered_force = self.filter.filter(force_values)[-1] + filtered_torque = self.filter.filter(torque_values)[-1] + + setattr(filtered_data.wrench.force, attr, filtered_force) + setattr(filtered_data.wrench.torque, attr, filtered_torque) + + return filtered_data + + def subscribe(self): + """ + Subscribe to the specified wrench topic. + + This will automatically be called on setup. + Only use this if you already unsubscribed before. + """ + + self.force_torque_subscriber = rospy.Subscriber(name=self.wrench_topic_name, + data_class=WrenchStamped, + callback=self._get_rospy_data) + + def unsubscribe(self): + """ + Unsubscribe from the specified topic + """ + self.force_torque_subscriber.unregister() + + def get_last_value(self, is_filtered=True) -> WrenchStamped: + """ + Get the most current data values. + + :param is_filtered: Decides about using filtered or raw data + + :return: A list containing the most current values (newest are first) + """ + status = self.filtered if is_filtered else self.unfiltered + return self.whole_data[status][-1] + + def get_derivative(self, is_filtered=True) -> WrenchStamped: + """ + Calculate the derivative of current data. + + :param is_filtered: Decides about using filtered or raw data + """ + status = self.filtered if is_filtered else self.unfiltered + + before: WrenchStamped = self.whole_data[status][-2] + after: WrenchStamped = self.whole_data[status][-1] + derivative = WrenchStamped() + + derivative.wrench.force.x = before.wrench.force.x - after.wrench.force.x + derivative.wrench.force.y = before.wrench.force.y - after.wrench.force.y + derivative.wrench.force.z = before.wrench.force.z - after.wrench.force.z + derivative.wrench.torque.x = before.wrench.torque.x - after.wrench.torque.x + derivative.wrench.torque.y = before.wrench.torque.y - after.wrench.torque.y + derivative.wrench.torque.z = before.wrench.torque.z - after.wrench.torque.z + + return derivative diff --git a/test/test_butterworth_filter.py b/test/test_butterworth_filter.py new file mode 100644 index 000000000..bcbf9f95f --- /dev/null +++ b/test/test_butterworth_filter.py @@ -0,0 +1,44 @@ +import unittest +from pycram.ros.filter import Butterworth + +class TestButterworthFilter(unittest.TestCase): + + def test_initialization_with_default_values(self): + filter = Butterworth() + self.assertEqual(filter.order, 4) + self.assertEqual(filter.cutoff, 10) + self.assertEqual(filter.fs, 60) + + def test_initialization_with_custom_values(self): + filter = Butterworth(order=2, cutoff=5, fs=30) + self.assertEqual(filter.order, 2) + self.assertEqual(filter.cutoff, 5) + self.assertEqual(filter.fs, 30) + + def test_filter_data_with_default_values(self): + filter = Butterworth() + data = [1, 2, 3, 4, 5] + filtered_data = filter.filter(data) + self.assertEqual(len(filtered_data), len(data)) + + def test_filter_data_with_custom_values(self): + filter = Butterworth(order=2, cutoff=5, fs=30) + data = [1, 2, 3, 4, 5] + filtered_data = filter.filter(data) + self.assertEqual(len(filtered_data), len(data)) + + def test_filter_empty_data(self): + filter = Butterworth() + data = [] + filtered_data = filter.filter(data) + self.assertEqual(filtered_data.tolist(), data) + + def test_filter_single_value_data(self): + filter = Butterworth() + data = [1] + filtered_data = filter.filter(data) + expected_filtered_data = [0.026077721701092293] # The expected filtered value + self.assertEqual(filtered_data.tolist(), expected_filtered_data) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file diff --git a/test/test_force_torque_sensor.py b/test/test_force_torque_sensor.py new file mode 100644 index 000000000..90b61eca2 --- /dev/null +++ b/test/test_force_torque_sensor.py @@ -0,0 +1,48 @@ +import unittest +from unittest.mock import patch, MagicMock +from pycram.ros_utils.force_torque_sensor import ForceTorqueSensor, ForceTorqueSensorSimulated +from pycram.datastructures.enums import FilterConfig +from geometry_msgs.msg import WrenchStamped + +class ForceTorqueSensorTestCase(unittest.TestCase): + + @patch('pycram.ros_utils.force_torque_sensor.World') + @patch('pycram.ros_utils.force_torque_sensor.create_publisher') + def test_initialization_simulated_sensor(self, mock_create_publisher, mock_world): + mock_world.current_world.robot.joint_name_to_id = {'joint1': 1} + sensor = ForceTorqueSensorSimulated('joint1') + self.assertEqual('joint1', 'joint1') + self.assertIsNotNone(sensor.fts_pub) + sensor._stop_publishing() + + @patch('pycram.ros_utils.force_torque_sensor.World') + @patch('pycram.ros_utils.force_torque_sensor.create_publisher') + def test_initialization_simulated_sensor_invalid_joint(self, mock_create_publisher, mock_world): + mock_world.current_world.robot.joint_name_to_id = {} + with self.assertRaises(RuntimeError): + ForceTorqueSensorSimulated('invalid_joint') + + @patch('pycram.ros_utils.force_torque_sensor.rospy') + def test_initialization_force_torque_sensor(self, mock_rospy): + sensor = ForceTorqueSensor('hsrb') + self.assertEqual(sensor.robot_name, 'hsrb') + self.assertEqual(sensor.wrench_topic_name, '/hsrb/wrist_wrench/compensated') + + @patch('pycram.ros_utils.force_torque_sensor.rospy') + def test_get_last_value(self, mock_rospy): + sensor = ForceTorqueSensor('hsrb') + mock_data = MagicMock(spec=WrenchStamped) + sensor.whole_data = {sensor.filtered: [mock_data], sensor.unfiltered: [mock_data]} + self.assertEqual(sensor.get_last_value(), mock_data) + + @patch('pycram.ros_utils.force_torque_sensor.rospy') + def test_get_derivative(self, mock_rospy): + sensor = ForceTorqueSensor('hsrb') + mock_data1 = MagicMock(spec=WrenchStamped) + mock_data2 = MagicMock(spec=WrenchStamped) + sensor.whole_data = {sensor.filtered: [mock_data1, mock_data2], sensor.unfiltered: [mock_data1, mock_data2]} + derivative = sensor.get_derivative() + self.assertIsInstance(derivative, WrenchStamped) + +if __name__ == '__main__': + unittest.main() \ No newline at end of file