diff --git a/tracetools_test/CMakeLists.txt b/tracetools_test/CMakeLists.txt index a7d339e..142ad6f 100644 --- a/tracetools_test/CMakeLists.txt +++ b/tracetools_test/CMakeLists.txt @@ -33,6 +33,13 @@ if(BUILD_TESTING) rclcpp std_msgs ) + add_executable(test_intra + src/test_intra.cpp + ) + ament_target_dependencies(test_intra + rclcpp + std_msgs + ) add_executable(test_ping src/test_ping.cpp ) @@ -76,14 +83,15 @@ if(BUILD_TESTING) ) install(TARGETS - test_publisher - test_subscription + test_intra test_ping test_pong - test_timer + test_publisher test_service test_service_ping test_service_pong + test_subscription + test_timer DESTINATION lib/${PROJECT_NAME} ) @@ -94,13 +102,14 @@ if(BUILD_TESTING) # Run each test in its own pytest invocation set(_tracetools_test_pytest_tests + test/test_intra.py test/test_node.py test/test_publisher.py + test/test_service.py + test/test_service_callback.py test/test_subscription.py test/test_subscription_callback.py test/test_timer.py - test/test_service.py - test/test_service_callback.py ) foreach(_test_path ${_tracetools_test_pytest_tests}) diff --git a/tracetools_test/src/test_intra.cpp b/tracetools_test/src/test_intra.cpp new file mode 100644 index 0000000..d11f81c --- /dev/null +++ b/tracetools_test/src/test_intra.cpp @@ -0,0 +1,92 @@ +// Copyright 2019 Robert Bosch GmbH +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include + +#include "rclcpp/rclcpp.hpp" +#include "std_msgs/msg/string.hpp" + +using namespace std::chrono_literals; + +#define PUB_NODE_NAME "test_intra_pub" +#define SUB_NODE_NAME "test_intra_sub" +#define TOPIC_NAME "the_topic" + +class PubIntraNode : public rclcpp::Node +{ +public: + explicit PubIntraNode(rclcpp::NodeOptions options) + : Node(PUB_NODE_NAME, options) + { + pub_ = this->create_publisher( + TOPIC_NAME, + rclcpp::QoS(10)); + timer_ = this->create_wall_timer( + 500ms, + std::bind(&PubIntraNode::timer_callback, this)); + } + +private: + void timer_callback() + { + auto msg = std::make_shared(); + msg->data = "some random intraprocess string"; + pub_->publish(*msg); + } + + rclcpp::Publisher::SharedPtr pub_; + rclcpp::TimerBase::SharedPtr timer_; +}; + +class SubIntraNode : public rclcpp::Node +{ +public: + explicit SubIntraNode(rclcpp::NodeOptions options) + : Node(SUB_NODE_NAME, options) + { + sub_ = this->create_subscription( + "the_topic", + rclcpp::QoS(10), + std::bind(&SubIntraNode::callback, this, std::placeholders::_1)); + } + +private: + void callback(const std_msgs::msg::String::SharedPtr msg) + { + RCLCPP_INFO(this->get_logger(), "[output] %s", msg->data.c_str()); + rclcpp::shutdown(); + } + + rclcpp::Subscription::SharedPtr sub_; +}; + +int main(int argc, char * argv[]) +{ + rclcpp::init(argc, argv); + + rclcpp::executors::SingleThreadedExecutor exec; + auto pub_intra_node = std::make_shared( + rclcpp::NodeOptions().use_intra_process_comms(true)); + auto sub_intra_node = std::make_shared( + rclcpp::NodeOptions().use_intra_process_comms(true)); + exec.add_node(pub_intra_node); + exec.add_node(sub_intra_node); + + printf("spinning\n"); + exec.spin(); + + // Will actually be called inside the node's callback + rclcpp::shutdown(); + return 0; +} diff --git a/tracetools_test/test/test_intra.py b/tracetools_test/test/test_intra.py new file mode 100644 index 0000000..9733b29 --- /dev/null +++ b/tracetools_test/test/test_intra.py @@ -0,0 +1,127 @@ +# Copyright 2019 Robert Bosch GmbH +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +from tracetools_test.case import TraceTestCase + + +class TestIntra(TraceTestCase): + + def __init__(self, *args) -> None: + super().__init__( + *args, + session_name_prefix='session-test-intra', + events_ros=[ + 'ros2:rcl_subscription_init', + 'ros2:rclcpp_subscription_callback_added', + 'ros2:callback_start', + 'ros2:callback_end', + ], + nodes=['test_intra'] + ) + + def test_all(self): + # Check events order as set (e.g. node_init before pub_init) + self.assertEventsOrderSet(self._events_ros) + + # Check sub_init for normal and intraprocess events + sub_init_events = self.get_events_with_name('ros2:rcl_subscription_init') + sub_init_normal_events = self.get_events_with_field_value( + 'topic_name', + '/the_topic', + sub_init_events) + sub_init_intra_events = self.get_events_with_field_value( + 'topic_name', + '/the_topic/_intra', + sub_init_events) + self.assertEqual( + len(sub_init_normal_events), + 1, + 'none or more than 1 sub init event for normal sub') + self.assertEqual( + len(sub_init_intra_events), + 1, + 'none or more than 1 sub init event for intra sub') + + # Get subscription handles for normal & intra subscriptions + sub_init_normal_event = sub_init_normal_events[0] + # Note: sub handle linked to "normal" topic is the one actually linked to intra callback + sub_handle_intra = self.get_field(sub_init_normal_event, 'subscription_handle') + print(f'sub_handle_intra: {sub_handle_intra}') + + # Get corresponding callback handle + # Callback handle + callback_added_events = self.get_events_with_field_value( + 'subscription_handle', + sub_handle_intra, + self.get_events_with_name( + 'ros2:rclcpp_subscription_callback_added')) + self.assertEqual( + len(callback_added_events), + 1, + 'none or more than 1 callback added event') + callback_added_event = callback_added_events[0] + callback_handle_intra = self.get_field(callback_added_event, 'callback') + + # Get corresponding callback start/end pairs + start_events = self.get_events_with_name('ros2:callback_start') + end_events = self.get_events_with_name('ros2:callback_end') + # Should still have two start:end pairs (1 normal + 1 intra) + self.assertEqual(len(start_events), 2, 'does not have 2 callback start events') + self.assertEqual(len(end_events), 2, 'does not have 2 callback end events') + start_events_intra = self.get_events_with_field_value( + 'callback', + callback_handle_intra, + start_events) + end_events_intra = self.get_events_with_field_value( + 'callback', + callback_handle_intra, + end_events) + self.assertEqual( + len(start_events_intra), + 1, + 'none or more than one intra start event') + self.assertEqual( + len(end_events_intra), + 1, + 'none or more than one intra end event') + + # Check is_intra_process field value + start_event_intra = start_events_intra[0] + is_intra_value_intra = self.get_field(start_event_intra, 'is_intra_process') + self.assertEqual( + is_intra_value_intra, + 1, + 'is_intra_process field value not valid for intra callback') + + # Also check that the other callback_start event (normal one) has the right field value + start_events_not_intra = self.get_events_with_field_not_value( + 'callback', + callback_handle_intra, + start_events) + self.assertEqual( + len(start_events_not_intra), + 1, + 'none or more than one normal start event') + start_event_not_intra = start_events_not_intra[0] + is_intra_value_not_intra = self.get_field(start_event_not_intra, 'is_intra_process') + self.assertEqual( + is_intra_value_not_intra, + 0, + 'is_intra_process field value not valid for normal callback') + + +if __name__ == '__main__': + unittest.main() diff --git a/tracetools_test/tracetools_test/case.py b/tracetools_test/tracetools_test/case.py index 09d980f..100d90f 100644 --- a/tracetools_test/tracetools_test/case.py +++ b/tracetools_test/tracetools_test/case.py @@ -286,6 +286,24 @@ class TraceTestCase(unittest.TestCase): events = self._events return [e for e in events if get_field(e, field_name, None) == field_value] + def get_events_with_field_not_value( + self, + field_name: str, + field_value: Any, + events: List[DictEvent] = None + ) -> List[DictEvent]: + """ + Get all events with the given field but not the value. + + :param field_name: the name of the field to check + :param field_value: the value of the field to check + :param events: the events to check (or None to check all events) + :return: the events with the given field:value pair + """ + if events is None: + events = self._events + return [e for e in events if get_field(e, field_name, None) != field_value] + def are_events_ordered(self, first_event: DictEvent, second_event: DictEvent): """ Check that the first event was generated before the second event.