diff --git a/rclcpp/include/rclcpp/callback_group.hpp b/rclcpp/include/rclcpp/callback_group.hpp index 63ebca8..47ea417 100644 --- a/rclcpp/include/rclcpp/callback_group.hpp +++ b/rclcpp/include/rclcpp/callback_group.hpp @@ -66,7 +66,7 @@ public: get_service_ptrs() const; RCLCPP_PUBLIC - const std::vector & + const std::vector & get_client_ptrs() const; RCLCPP_PUBLIC @@ -100,7 +100,7 @@ private: std::vector subscription_ptrs_; std::vector timer_ptrs_; std::vector service_ptrs_; - std::vector client_ptrs_; + std::vector client_ptrs_; std::atomic_bool can_be_taken_from_; }; diff --git a/rclcpp/include/rclcpp/client.hpp b/rclcpp/include/rclcpp/client.hpp index 2eaed00..3d27cb8 100644 --- a/rclcpp/include/rclcpp/client.hpp +++ b/rclcpp/include/rclcpp/client.hpp @@ -60,7 +60,7 @@ public: virtual std::shared_ptr create_response() = 0; virtual std::shared_ptr create_request_header() = 0; virtual void handle_response( - std::shared_ptr & request_header, std::shared_ptr & response) = 0; + std::shared_ptr request_header, std::shared_ptr response) = 0; private: RCLCPP_DISABLE_COPY(ClientBase); @@ -111,13 +111,17 @@ public: return std::shared_ptr(new rmw_request_id_t); } - void handle_response(std::shared_ptr & request_header, std::shared_ptr & response) + void handle_response(std::shared_ptr request_header, std::shared_ptr response) { + std::lock_guard lock(pending_requests_mutex_); auto typed_request_header = std::static_pointer_cast(request_header); auto typed_response = std::static_pointer_cast(response); int64_t sequence_number = typed_request_header->sequence_number; - // TODO(esteve) this must check if the sequence_number is valid otherwise the - // call_promise will be null + // TODO(esteve) this should throw instead since it is not expected to happen in the first place + if (this->pending_requests_.count(sequence_number) == 0) { + fprintf(stderr, "Received invalid sequence number. Ignoring...\n"); + return; + } auto tuple = this->pending_requests_[sequence_number]; auto call_promise = std::get<0>(tuple); auto callback = std::get<1>(tuple); @@ -143,6 +147,7 @@ public: > SharedFuture async_send_request(SharedRequest request, CallbackT && cb) { + std::lock_guard lock(pending_requests_mutex_); int64_t sequence_number; if (RMW_RET_OK != rmw_send_request(get_client_handle(), request.get(), &sequence_number)) { // *INDENT-OFF* (prevent uncrustify from making unecessary indents here) @@ -187,6 +192,7 @@ private: RCLCPP_DISABLE_COPY(Client); std::map> pending_requests_; + std::mutex pending_requests_mutex_; }; } // namespace client diff --git a/rclcpp/include/rclcpp/strategies/allocator_memory_strategy.hpp b/rclcpp/include/rclcpp/strategies/allocator_memory_strategy.hpp index 6d14aca..d4d4975 100644 --- a/rclcpp/include/rclcpp/strategies/allocator_memory_strategy.hpp +++ b/rclcpp/include/rclcpp/strategies/allocator_memory_strategy.hpp @@ -151,7 +151,8 @@ public: services_.push_back(service); } } - for (auto & client : group->get_client_ptrs()) { + for (auto & weak_client : group->get_client_ptrs()) { + auto client = weak_client.lock(); if (client) { clients_.push_back(client); } diff --git a/rclcpp/src/rclcpp/callback_group.cpp b/rclcpp/src/rclcpp/callback_group.cpp index 2555973..6d728a6 100644 --- a/rclcpp/src/rclcpp/callback_group.cpp +++ b/rclcpp/src/rclcpp/callback_group.cpp @@ -41,7 +41,7 @@ CallbackGroup::get_service_ptrs() const return service_ptrs_; } -const std::vector & +const std::vector & CallbackGroup::get_client_ptrs() const { return client_ptrs_; diff --git a/rclcpp/src/rclcpp/memory_strategy.cpp b/rclcpp/src/rclcpp/memory_strategy.cpp index 28c4bfd..a196fa3 100644 --- a/rclcpp/src/rclcpp/memory_strategy.cpp +++ b/rclcpp/src/rclcpp/memory_strategy.cpp @@ -84,8 +84,9 @@ MemoryStrategy::get_client_by_handle(void * client_handle, const WeakNodeVector if (!group) { continue; } - for (auto & client : group->get_client_ptrs()) { - if (client->get_client_handle()->data == client_handle) { + for (auto & weak_client : group->get_client_ptrs()) { + auto client = weak_client.lock(); + if (client && client->get_client_handle()->data == client_handle) { return client; } } @@ -182,8 +183,9 @@ MemoryStrategy::get_group_by_client( if (!group) { continue; } - for (auto & cli : group->get_client_ptrs()) { - if (cli == client) { + for (auto & weak_client : group->get_client_ptrs()) { + auto cli = weak_client.lock(); + if (cli && cli == client) { return group; } }