diff --git a/README.md b/README.md
index 9058b1e..b8ebb7c 100644
--- a/README.md
+++ b/README.md
@@ -134,6 +134,14 @@ The `inference_pkg_launch.py`, included in this package, provides an example dem
 |`load_model`|`LoadModelSrv`|Service that is responsible for setting pre-processing algorithm and inference tasks for the specific type of model loaded.|
 |`inference_state`|`InferenceStateSrv`|Service that is responsible for starting and stopping inference tasks.|
 
+
+### Parameters
+
+| Parameter name | Description |
+| ---------------- | ----------- |
+| `device` | String that is either `CPU`, `GPU` or `MYRIAD`. Default is `CPU`. `MYRIAD` targets the Intel Neural Compute Stick 2. |
+
+
 ## Resources
 
 * [Getting started with AWS DeepRacer OpenSource](https://github.com/aws-deepracer/aws-deepracer-launcher/blob/main/getting-started.md)
diff --git a/inference_pkg/CMakeLists.txt b/inference_pkg/CMakeLists.txt
index 72aca35..2991e4d 100644
--- a/inference_pkg/CMakeLists.txt
+++ b/inference_pkg/CMakeLists.txt
@@ -1,5 +1,8 @@
 cmake_minimum_required(VERSION 3.5)
 project(inference_pkg)
+include(FetchContent)
+
+set(ABSL_PROPAGATE_CXX_STD ON)
 
 # Default to C99
 if(NOT CMAKE_C_STANDARD)
@@ -12,9 +15,20 @@ if(NOT CMAKE_CXX_STANDARD)
 endif()
 
 if(CMAKE_COMPILER_IS_GNUCXX OR CMAKE_CXX_COMPILER_ID MATCHES "Clang")
-  add_compile_options(-Wall -Wextra -Wpedantic)
+  add_compile_options(-Wno-deprecated-declarations -Wno-ignored-attributes)
 endif()
 
+FetchContent_Declare(tensorflow-lite
+  GIT_REPOSITORY https://github.com/tensorflow/tensorflow.git
+)
+FetchContent_Populate(tensorflow-lite)
+
+add_subdirectory(
+  ${tensorflow-lite_SOURCE_DIR}/tensorflow/lite
+  ${tensorflow-lite_BINARY_DIR}
+  EXCLUDE_FROM_ALL
+)
+
 # find dependencies
 find_package(ament_cmake REQUIRED)
 find_package(rclcpp REQUIRED)
@@ -44,6 +58,7 @@ endif()
 
 add_executable(inference_node
   src/inference_node.cpp
+  src/tflite_inference_eng.cpp
   src/intel_inference_eng.cpp
   src/image_process.cpp
 )
@@ -52,12 +67,16 @@
 target_include_directories(inference_node PRIVATE
   include
   ${OpenCV_INCLUDE_DIRS}
   ${InferenceEngine_INCLUDE_DIRS}
+  ${CMAKE_CURRENT_BINARY_DIR}/flatbuffers/include
+  $
 )
 target_link_libraries(inference_node
   -lm -ldl
   ${OpenCV_LIBRARIES}
+  tensorflow-lite
   ${InferenceEngine_LIBRARIES}
-  ${NGRAPH_LIBRARIES})
+  ${NGRAPH_LIBRARIES}
+)
 
 ament_target_dependencies(inference_node rclcpp deepracer_interfaces_pkg sensor_msgs std_msgs cv_bridge image_transport OpenCV InferenceEngine ngraph)
diff --git a/inference_pkg/include/inference_pkg/image_process.hpp b/inference_pkg/include/inference_pkg/image_process.hpp
index 33bce8d..b199b4c 100644
--- a/inference_pkg/include/inference_pkg/image_process.hpp
+++ b/inference_pkg/include/inference_pkg/image_process.hpp
@@ -19,6 +19,7 @@
 
 #include "rclcpp/rclcpp.hpp"
 #include "sensor_msgs/msg/image.hpp"
+#include "sensor_msgs/msg/compressed_image.hpp"
 #include "cv_bridge/cv_bridge.h"
 #include <unordered_map>
 
@@ -33,9 +34,9 @@ namespace InferTask {
         /// @param frameData ROS message containing the image data.
         /// @param retImg Open CV Mat object that will be used to store the post processed image
         /// @param params Hash map containing relevant pre-processing parameters
-        virtual void processImage(const sensor_msgs::msg::Image &frameData, cv::Mat& retImg,
+        virtual void processImage(const sensor_msgs::msg::CompressedImage &frameData, cv::Mat& retImg,
                                   const std::unordered_map<std::string, int> &params) = 0;
-        virtual void processImageVec(const std::vector<sensor_msgs::msg::Image> &frameDataArr, cv::Mat& retImg,
+        virtual void processImageVec(const std::vector<sensor_msgs::msg::CompressedImage> &frameDataArr, cv::Mat& retImg,
                                      const std::unordered_map<std::string, int> &params) = 0;
         /// Resets the image processing algorithms data if any.
         virtual void reset() = 0;
@@ -49,9 +50,9 @@ namespace InferTask {
     public:
         RGB() = default;
         virtual ~RGB() = default;
-        virtual void processImage(const sensor_msgs::msg::Image &frameData, cv::Mat& retImg,
+        virtual void processImage(const sensor_msgs::msg::CompressedImage &frameData, cv::Mat& retImg,
                                   const std::unordered_map<std::string, int> &params) override;
-        virtual void processImageVec(const std::vector<sensor_msgs::msg::Image> &frameDataArr, cv::Mat& retImg,
+        virtual void processImageVec(const std::vector<sensor_msgs::msg::CompressedImage> &frameDataArr, cv::Mat& retImg,
                                      const std::unordered_map<std::string, int> &params) override {(void)frameDataArr;(void)retImg;(void)params;}
         virtual void reset() override {}
         virtual const std::string getEncode() const;
@@ -67,9 +68,9 @@ namespace InferTask {
         /// @param isMask True if background masking should be performed on the image.
         Grey(bool isThreshold, bool isMask);
         virtual ~Grey() = default;
-        virtual void processImage(const sensor_msgs::msg::Image &frameData, cv::Mat& retImg,
+        virtual void processImage(const sensor_msgs::msg::CompressedImage &frameData, cv::Mat& retImg,
                                   const std::unordered_map<std::string, int> &params) override;
-        virtual void processImageVec(const std::vector<sensor_msgs::msg::Image> &frameDataArr, cv::Mat& retImg,
+        virtual void processImageVec(const std::vector<sensor_msgs::msg::CompressedImage> &frameDataArr, cv::Mat& retImg,
                                      const std::unordered_map<std::string, int> &params);
         virtual void reset() override;
         virtual const std::string getEncode() const;
@@ -91,9 +92,9 @@ namespace InferTask {
     public:
         GreyDiff() = default;
         virtual ~GreyDiff() = default;
-        virtual void processImage(const sensor_msgs::msg::Image &frameData, cv::Mat& retImg,
+        virtual void processImage(const sensor_msgs::msg::CompressedImage &frameData, cv::Mat& retImg,
                                   const std::unordered_map<std::string, int> &params) override;
-        virtual void processImageVec(const std::vector<sensor_msgs::msg::Image> &frameDataArr, cv::Mat& retImg,
+        virtual void processImageVec(const std::vector<sensor_msgs::msg::CompressedImage> &frameDataArr, cv::Mat& retImg,
                                      const std::unordered_map<std::string, int> &params) override {(void)frameDataArr;(void)retImg;(void)params;}
         virtual void reset() override;
         virtual const std::string getEncode() const;
diff --git a/inference_pkg/include/inference_pkg/inference_base.hpp b/inference_pkg/include/inference_pkg/inference_base.hpp
index 0dd4208..6cddcb1 100644
--- a/inference_pkg/include/inference_pkg/inference_base.hpp
+++ b/inference_pkg/include/inference_pkg/inference_base.hpp
@@ -32,8 +32,10 @@ namespace InferTask {
         /// @returns True if model loaded successfully, false otherwise
         /// @param artifactPath Path to the model artifact.
         /// @param imgProcess Pointer to the image processing algorithm
+        /// @param device Compute device to run inference on (CPU, GPU or MYRIAD)
         virtual bool loadModel(const char* artifactPath,
-                               std::shared_ptr<ImgProcessBase> imgProcess) = 0;
+                               std::shared_ptr<ImgProcessBase> imgProcess,
+                               std::string device) = 0;
         /// Starts the inference task until stopped.
         virtual void startInference() = 0;
         /// Stops the inference task if running.
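The extended `loadModel` signature threads the compute device string from the node down to each engine. A minimal sketch of how a caller is expected to drive the interface, using the OpenVINO engine declared below; the node handle, pre-processor choice, model path and device string are illustrative placeholders rather than values fixed by this package:

```cpp
#include "inference_pkg/intel_inference_eng.hpp"
#include "inference_pkg/image_process.hpp"

// Sketch: load a model artifact on an explicit device and start inference,
// mirroring what InferenceNodeMgr does. Path and device are placeholders.
void loadAndStart(std::shared_ptr<rclcpp::Node> node)
{
    auto preProc = std::make_shared<InferTask::RGB>();
    std::unique_ptr<InferTask::InferenceBase> task =
        std::make_unique<IntelInferenceEngine::RLInferenceModel>(node, "/sensor_fusion_pkg/sensor_msg");
    if (task->loadModel("/opt/ml/model/model.xml", preProc, "MYRIAD")) {
        task->startInference();
    }
}
```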
diff --git a/inference_pkg/include/inference_pkg/intel_inference_eng.hpp b/inference_pkg/include/inference_pkg/intel_inference_eng.hpp
index 47738f0..4295950 100644
--- a/inference_pkg/include/inference_pkg/intel_inference_eng.hpp
+++ b/inference_pkg/include/inference_pkg/intel_inference_eng.hpp
@@ -34,7 +34,8 @@ namespace IntelInferenceEngine {
         RLInferenceModel(std::shared_ptr<rclcpp::Node> inferenceNodePtr, const std::string &sensorSubName);
         virtual ~RLInferenceModel();
         virtual bool loadModel(const char* artifactPath,
-                               std::shared_ptr<InferTask::ImgProcessBase> imgProcess) override;
+                               std::shared_ptr<InferTask::ImgProcessBase> imgProcess,
+                               std::string device) override;
         virtual void startInference() override;
         virtual void stopInference() override;
         /// Callback method to retrieve sensor data.
diff --git a/inference_pkg/include/inference_pkg/tflite_inference_eng.hpp b/inference_pkg/include/inference_pkg/tflite_inference_eng.hpp
new file mode 100644
index 0000000..1aa5106
--- /dev/null
+++ b/inference_pkg/include/inference_pkg/tflite_inference_eng.hpp
@@ -0,0 +1,73 @@
+///////////////////////////////////////////////////////////////////////////////////
+//   Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.           //
+//                                                                                 //
+//   Licensed under the Apache License, Version 2.0 (the "License").              //
+//   You may not use this file except in compliance with the License.             //
+//   You may obtain a copy of the License at                                      //
+//                                                                                 //
+//       http://www.apache.org/licenses/LICENSE-2.0                               //
+//                                                                                 //
+//   Unless required by applicable law or agreed to in writing, software          //
+//   distributed under the License is distributed on an "AS IS" BASIS,            //
+//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.     //
+//   See the License for the specific language governing permissions and          //
+//   limitations under the License.                                               //
+///////////////////////////////////////////////////////////////////////////////////
+
+#ifndef TFLITE_INFERENCE_ENG_HPP
+#define TFLITE_INFERENCE_ENG_HPP
+
+#include "inference_pkg/inference_base.hpp"
+#include "tensorflow/lite/interpreter.h"
+#include "tensorflow/lite/kernels/register.h"
+#include "tensorflow/lite/model.h"
+#include "deepracer_interfaces_pkg/msg/evo_sensor_msg.hpp"
+#include "deepracer_interfaces_pkg/msg/infer_results_array.hpp"
+#include <atomic>
+
+namespace TFLiteInferenceEngine {
+    class RLInferenceModel : public InferTask::InferenceBase
+    {
+    /// Concrete inference task class for running reinforcement learning models
+    /// with the TensorFlow Lite runtime.
+    public:
+        /// @param inferenceNodePtr Pointer to the node that manages the inference task.
+        /// @param sensorSubName Name of the topic to subscribe to for sensor data.
+        RLInferenceModel(std::shared_ptr<rclcpp::Node> inferenceNodePtr, const std::string &sensorSubName);
+        virtual ~RLInferenceModel();
+        virtual bool loadModel(const char* artifactPath,
+                               std::shared_ptr<InferTask::ImgProcessBase> imgProcess,
+                               std::string device) override;
+        virtual void startInference() override;
+        virtual void stopInference() override;
+        /// Callback method to retrieve sensor data.
+        /// @param msg Message returned by the ROS messaging system.
+        void sensorCB(const deepracer_interfaces_pkg::msg::EvoSensorMsg::SharedPtr msg);
+
+    private:
+        /// Inference node object
+        std::shared_ptr<rclcpp::Node> inferenceNode;
+        /// ROS subscriber object to the desired sensor topic.
+        rclcpp::Subscription<deepracer_interfaces_pkg::msg::EvoSensorMsg>::SharedPtr sensorSub_;
+        /// ROS publisher object to the desired topic.
+        rclcpp::Publisher<deepracer_interfaces_pkg::msg::InferResultsArray>::SharedPtr resultPub_;
+        /// Pointer to image processing algorithm.
+        std::shared_ptr<InferTask::ImgProcessBase> imgProcess_;
+        /// Inference state variable.
+        std::atomic<bool> doInference_;
+        /// Pointer to the TensorFlow Lite flatbuffer model.
+        std::unique_ptr<tflite::FlatBufferModel> model_;
+        /// TensorFlow Lite interpreter that executes the loaded model.
+        std::unique_ptr<tflite::Interpreter> interpreter_;
+        /// Vector of hash map that stores all relevant pre-processing parameters for each input head.
+        std::vector<std::unordered_map<std::string, int>> paramsArr_;
+        /// Vector of names of the input heads
+        std::vector<std::string> inputNamesArr_;
+        /// Name of the output layer
+        std::string outputName_;
+        std::vector<std::vector<int>> outputDimsArr_;
+        std::vector<const TfLiteTensor*> output_tensors_;
+
+    };
+}
+#endif
\ No newline at end of file
diff --git a/inference_pkg/src/image_process.cpp b/inference_pkg/src/image_process.cpp
index 814c8a7..74e4e9d 100644
--- a/inference_pkg/src/image_process.cpp
+++ b/inference_pkg/src/image_process.cpp
@@ -29,7 +29,7 @@
     /// @param frameData ROS image message containing the image data.
     /// @param retImg Reference to CV object to be populated the with resized image.
    /// @param params Hash map containing resize information.
-    bool cvtToCVObjResize (const sensor_msgs::msg::Image &frameData, cv::Mat &retImg,
+    bool cvtToCVObjResize (const sensor_msgs::msg::CompressedImage &frameData, cv::Mat &retImg,
                            const std::unordered_map<std::string, int> &params) {
         cv_bridge::CvImagePtr cvPtr;
 
@@ -121,7 +121,7 @@
 }
 
 namespace InferTask {
-    void RGB::processImage(const sensor_msgs::msg::Image &frameData, cv::Mat &retImg,
+    void RGB::processImage(const sensor_msgs::msg::CompressedImage &frameData, cv::Mat &retImg,
                            const std::unordered_map<std::string, int> &params) {
         cvtToCVObjResize(frameData, retImg, params);
     }
@@ -137,7 +137,7 @@ namespace InferTask {
 
     }
 
-    void Grey::processImage(const sensor_msgs::msg::Image &frameData, cv::Mat &retImg,
+    void Grey::processImage(const sensor_msgs::msg::CompressedImage &frameData, cv::Mat &retImg,
                             const std::unordered_map<std::string, int> &params) {
         cv::Mat currImg;
         if (cvtToCVObjResize(frameData, currImg, params)) {
@@ -160,7 +160,7 @@ namespace InferTask {
         }
     }
 
-    void Grey::processImageVec(const std::vector<sensor_msgs::msg::Image> &frameDataArr, cv::Mat &retImg,
+    void Grey::processImageVec(const std::vector<sensor_msgs::msg::CompressedImage> &frameDataArr, cv::Mat &retImg,
                                const std::unordered_map<std::string, int> &params) {
         // Left camera image is sent as the top image and the right camera image is sent as second in the vector.
         // Stack operation replaces the beginning values as we loop through and hence we loop in decreasing order
@@ -188,7 +188,7 @@ namespace InferTask {
         return sensor_msgs::image_encodings::MONO8;
     }
 
-    void GreyDiff::processImage(const sensor_msgs::msg::Image &frameData, cv::Mat &retImg,
+    void GreyDiff::processImage(const sensor_msgs::msg::CompressedImage &frameData, cv::Mat &retImg,
                                 const std::unordered_map<std::string, int> &params) {
         (void)retImg;
         cv::Mat currImg;
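With the switch to `sensor_msgs::msg::CompressedImage`, each incoming frame is decoded and resized by a pre-processor before it is copied into the network input. A small sketch of that step in isolation; the 160x120x3 target shape is an assumed example only, since the node derives the real values from the loaded model's input tensor:

```cpp
#include "inference_pkg/image_process.hpp"

// Sketch: pre-process a single compressed camera frame with the RGB algorithm.
// The width/height/channels values below are placeholders.
cv::Mat preprocessFrame(const sensor_msgs::msg::CompressedImage &frame)
{
    const std::unordered_map<std::string, int> params = {
        {"width", 160}, {"height", 120}, {"channels", 3}};
    cv::Mat input;
    InferTask::RGB rgb;
    rgb.processImage(frame, input, params);
    return input;
}
```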
diff --git a/inference_pkg/src/inference_node.cpp b/inference_pkg/src/inference_node.cpp
index 56e71a5..d59636d 100644
--- a/inference_pkg/src/inference_node.cpp
+++ b/inference_pkg/src/inference_node.cpp
@@ -15,6 +15,8 @@
 ///////////////////////////////////////////////////////////////////////////////////
 
 #include "inference_pkg/intel_inference_eng.hpp"
+#include "inference_pkg/tflite_inference_eng.hpp"
+#include "std_msgs/msg/string.hpp"
 #include "deepracer_interfaces_pkg/srv/inference_state_srv.hpp"
 #include "deepracer_interfaces_pkg/srv/load_model_srv.hpp"
 
@@ -41,11 +43,25 @@ namespace InferTask {
     /// Class that will manage the inference task. In particular it will start and stop the
     /// inference tasks and feed the inference task the sensor data.
     /// @param nodeName Reference to the string containing name of the node.
+    /// The `device` and `inference_engine` ROS parameters select the compute device (CPU, GPU or MYRIAD) and the inference engine.
     public:
+        const char* MODEL_ARTIFACT_TOPIC = "model_artifact";
+
         InferenceNodeMgr(const std::string & nodeName)
-        : Node(nodeName)
+        : Node(nodeName),
+          deviceName_("CPU"),
+          inferenceEngine_("TFLITE")
         {
             RCLCPP_INFO(this->get_logger(), "%s started", nodeName.c_str());
+
+            this->declare_parameter("device", deviceName_);
+            // Device name; OpenVINO supports CPU, GPU and MYRIAD
+            deviceName_ = this->get_parameter("device").as_string();
+
+            this->declare_parameter("inference_engine", inferenceEngine_);
+            // Inference Engine name; TFLITE or OPENVINO
+            inferenceEngine_ = this->get_parameter("inference_engine").as_string();
+
             loadModelServiceCbGrp_ = this->create_callback_group(rclcpp::callback_group::CallbackGroupType::MutuallyExclusive);
             loadModelService_ = this->create_service<deepracer_interfaces_pkg::srv::LoadModelSrv>("load_model",
                                                                                                    std::bind(&InferTask::InferenceNodeMgr::LoadModelHdl,
@@ -66,6 +82,9 @@
                                                                   ::rmw_qos_profile_default,
                                                                   setInferenceStateServiceCbGrp_);
 
+            // Create a publisher to announce the model artifact that has been loaded.
+            modelArtifactPub_ = this->create_publisher<std_msgs::msg::String>(MODEL_ARTIFACT_TOPIC, 1);
+
             // Add all available task and algorithms to these hash maps.
             taskList_ = { {rlTask, nullptr} };
             preProcessList_ = { {rgb, std::make_shared<RGB>()},
@@ -119,7 +138,12 @@
             if (itInferTask != taskList_.end() && itPreProcess != preProcessList_.end()) {
                 switch(req->task_type) {
                     case rlTask:
-                        itInferTask->second.reset(new IntelInferenceEngine::RLInferenceModel(this->shared_from_this(), "/sensor_fusion_pkg/sensor_msg"));
+                        if (inferenceEngine_.compare("TFLITE") == 0) {
+                            itInferTask->second.reset(new TFLiteInferenceEngine::RLInferenceModel(this->shared_from_this(), "/sensor_fusion_pkg/sensor_msg"));
+                        } else {
+                            itInferTask->second.reset(new IntelInferenceEngine::RLInferenceModel(this->shared_from_this(), "/sensor_fusion_pkg/sensor_msg"));
+                        }
+
                         break;
                     case objDetectTask:
                         //! TODO add onject detection when class is implemented.
@@ -129,7 +153,14 @@
                         RCLCPP_ERROR(this->get_logger(), "Unknown inference task");
                         return;
                 }
-                itInferTask->second->loadModel(req->artifact_path.c_str(), itPreProcess->second);
+
+                itInferTask->second->loadModel(req->artifact_path.c_str(), itPreProcess->second, deviceName_);
+
+                // Send a message to say we have loaded a model
+                std_msgs::msg::String modelArtifactMsg;
+                modelArtifactMsg.data = req->artifact_path;
+                modelArtifactPub_->publish(modelArtifactMsg);
+
                 res->error = 0;
             }
         }
@@ -149,6 +180,14 @@ namespace InferTask {
         /// List of available pre-processing algorithms.
         std::unordered_map<int, std::shared_ptr<ImgProcessBase>> preProcessList_;
         /// Reference to the node handler.
+
+        /// ROS publisher object to publish the name of a new model.
+        rclcpp::Publisher<std_msgs::msg::String>::SharedPtr modelArtifactPub_;
+
+        /// Compute device type.
+        std::string deviceName_;
+        /// Inference Engine parameter.
+        std::string inferenceEngine_;
     };
 }
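The node now announces every successfully loaded artifact on the `model_artifact` topic as a `std_msgs/String`. A hypothetical downstream listener could consume it roughly like this (sketch only; the function name is illustrative):

```cpp
#include "rclcpp/rclcpp.hpp"
#include "std_msgs/msg/string.hpp"

// Sketch: subscribe to the model_artifact notification published by the node
// whenever a new model has been loaded.
rclcpp::Subscription<std_msgs::msg::String>::SharedPtr
subscribeToModelArtifact(const std::shared_ptr<rclcpp::Node> &node)
{
    return node->create_subscription<std_msgs::msg::String>(
        "model_artifact", 1,
        [logger = node->get_logger()](const std_msgs::msg::String::SharedPtr msg) {
            RCLCPP_INFO(logger, "Model artifact loaded: %s", msg->data.c_str());
        });
}
```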
diff --git a/inference_pkg/src/intel_inference_eng.cpp b/inference_pkg/src/intel_inference_eng.cpp
index e9e2c40..1abf108 100644
--- a/inference_pkg/src/intel_inference_eng.cpp
+++ b/inference_pkg/src/intel_inference_eng.cpp
@@ -103,7 +103,7 @@
     template<typename T, typename V> void load1DImg(V *inputPtr,
                                                     cv::Mat &retImg,
                                                     std::shared_ptr<InferTask::ImgProcessBase> imgProcessPtr,
-                                                    const sensor_msgs::msg::Image &imgData,
+                                                    const sensor_msgs::msg::CompressedImage &imgData,
                                                     const std::unordered_map<std::string, int> &params) {
         imgProcessPtr->processImage(imgData, retImg, params);
         if (retImg.empty()) {
@@ -127,7 +127,7 @@
     template<typename T, typename V> void loadStackImg(V *inputPtr,
                                                         cv::Mat &retImg,
                                                         std::shared_ptr<InferTask::ImgProcessBase> imgProcessPtr,
-                                                        const sensor_msgs::msg::Image &imgData,
+                                                        const sensor_msgs::msg::CompressedImage &imgData,
                                                         const std::unordered_map<std::string, int> &params) {
         imgProcessPtr->processImage(imgData, retImg, params);
         if (retImg.empty()) {
@@ -150,7 +150,7 @@
     template<typename T, typename V> void loadStereoImg(V *inputPtr,
                                                          cv::Mat &retImg,
                                                          std::shared_ptr<InferTask::ImgProcessBase> imgProcessPtr,
-                                                         const std::vector<sensor_msgs::msg::Image> &imgDataArr,
+                                                         const std::vector<sensor_msgs::msg::CompressedImage> &imgDataArr,
                                                          const std::unordered_map<std::string, int> &params) {
 
         imgProcessPtr->processImageVec(imgDataArr, retImg, params);
@@ -201,7 +201,8 @@
     }
 
     bool RLInferenceModel::loadModel(const char* artifactPath,
-                                     std::shared_ptr<InferTask::ImgProcessBase> imgProcess) {
+                                     std::shared_ptr<InferTask::ImgProcessBase> imgProcess,
+                                     std::string device) {
         if (doInference_) {
             RCLCPP_ERROR(inferenceNode->get_logger(), "Please stop inference prior to loading a model");
             return false;
@@ -214,7 +215,7 @@
         imgProcess_ = imgProcess;
         // Load the model
         try {
-            inferRequest_ = setMultiHeadModel(artifactPath, "CPU", core_, inputNamesArr_,
+            inferRequest_ = setMultiHeadModel(artifactPath, device, core_, inputNamesArr_,
                                               outputName_, InferenceEngine::Precision::FP32,
                                               InferenceEngine::Precision::FP32, inferenceNode);
             for(size_t i = 0; i != inputNamesArr_.size(); ++i) {
diff --git a/inference_pkg/src/tflite_inference_eng.cpp b/inference_pkg/src/tflite_inference_eng.cpp
new file mode 100644
index 0000000..500ba65
--- /dev/null
+++ b/inference_pkg/src/tflite_inference_eng.cpp
@@ -0,0 +1,306 @@
+///////////////////////////////////////////////////////////////////////////////////
+//   Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.           //
+//                                                                                 //
+//   Licensed under the Apache License, Version 2.0 (the "License").              //
+//   You may not use this file except in compliance with the License.             //
+//   You may obtain a copy of the License at                                      //
+//                                                                                 //
+//       http://www.apache.org/licenses/LICENSE-2.0                               //
+//                                                                                 //
+//   Unless required by applicable law or agreed to in writing, software          //
+//   distributed under the License is distributed on an "AS IS" BASIS,            //
+//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.     //
+//   See the License for the specific language governing permissions and          //
+//   limitations under the License.                                               //
// +/////////////////////////////////////////////////////////////////////////////////// + +#include "inference_pkg/tflite_inference_eng.hpp" + +// ROS2 message headers +#include "deepracer_interfaces_pkg/msg/infer_results.hpp" +#include "deepracer_interfaces_pkg/msg/infer_results_array.hpp" + +#include +#define RAD2DEG(x) ((x)*180./M_PI) + +const std::string LIDAR = "LIDAR"; +const std::string STEREO = "STEREO_CAMERAS"; +const std::string FRONT = "FRONT_FACING_CAMERA"; +const std::string OBS = "observation"; +const std::string LEFT = "LEFT_CAMERA"; + + +namespace { + class InferenceExcept : public std::exception + { + /// Simple exception class that is used to send a message to the catch clause. + public: + /// @param msg Message to be logged + InferenceExcept(std::string msg) + : msg_(msg) + { + } + virtual const char* what() const throw() override { + return msg_.c_str(); + } + private: + /// Store message in class so that the what method can dump it when invoked. + const std::string msg_; + }; + + /// Helper method that loads grey images into the inference engine input + /// @param inputPtr Pointer to the input data. + /// @param imgProcessPtr Pointer to the image processing algorithm. + /// @param imgData ROS message containing the image data. + /// @param params Hash map of relevant parameters for image processing. + template void load1DImg(V *inputPtr, + cv::Mat &retImg, + std::shared_ptr imgProcessPtr, + const sensor_msgs::msg::CompressedImage &imgData, + const std::unordered_map ¶ms) { + imgProcessPtr->processImage(imgData, retImg, params); + if (retImg.empty()) { + throw InferenceExcept("No image after pre-process"); + } + int height = retImg.rows; + int width = retImg.cols; + + for (int h = 0; h < height; h++) { + for (int w = 0; w < width; w++) { + inputPtr[h * width + w] = retImg.at(h, w); + } + } + } + + /// Helper method that loads multi channel images into the inference engine input + /// @param inputPtr Pointer to the input data. + /// @param imgProcessPtr Pointer to the image processing algorithm. + /// @param imgData ROS message containing the image data. + /// @param params Hash map of relevant parameters for image processing. + template void loadStackImg(V *inputPtr, + cv::Mat &retImg, + std::shared_ptr imgProcessPtr, + const sensor_msgs::msg::CompressedImage &imgData, + const std::unordered_map ¶ms) { + imgProcessPtr->processImage(imgData, retImg, params); + if (retImg.empty()) { + throw InferenceExcept("No image after-pre process"); + } + const int channelSize = retImg.rows * retImg.cols; + + for (size_t pixelNum = 0; pixelNum < channelSize; ++pixelNum) { + for (size_t ch = 0; ch < retImg.channels(); ++ch) { + inputPtr[(ch*channelSize) + pixelNum] = retImg.at(pixelNum)[ch]; + } + } + } + + /// Helper method that loads multi channel images into the inference engine input + /// @param inputPtr Pointer to the input data. + /// @param imgProcessPtr Pointer to the image processing algorithm. + /// @param imgData ROS message containing the image data. + /// @param params Hash map of relevant parameters for image processing. 
+    template<typename T, typename V> void loadStereoImg(V *inputPtr,
+                                                         cv::Mat &retImg,
+                                                         std::shared_ptr<InferTask::ImgProcessBase> imgProcessPtr,
+                                                         const std::vector<sensor_msgs::msg::CompressedImage> &imgDataArr,
+                                                         const std::unordered_map<std::string, int> &params) {
+
+        imgProcessPtr->processImageVec(imgDataArr, retImg, params);
+        if (retImg.empty()) {
+            throw InferenceExcept("No image after pre-process");
+        }
+
+        const int width = retImg.cols;
+        const int height = retImg.rows;
+        const int channel = retImg.channels();
+
+        for (int c = 0; c < channel; c++) {
+            for (int h = 0; h < height; h++) {
+                for (int w = 0; w < width; w++) {
+                    inputPtr[c * width * height + h * width + w] = retImg.at<T>(h, w)[c];
+                }
+            }
+        }
+    }
+
+    /// Helper method that loads 1D data into the inference engine input
+    /// @param inputPtr Pointer to the input data.
+    /// @param lidarData ROS message containing the lidar data.
+    void loadLidarData(float *inputPtr,
+                       const std::vector<float> &lidar_data) {
+        size_t pixelNum = 0;
+        for(const auto& lidar_value : lidar_data) {
+            inputPtr[pixelNum] = lidar_value;
+            ++pixelNum;
+        }
+    }
+}
+
+namespace TFLiteInferenceEngine {
+    RLInferenceModel::RLInferenceModel(std::shared_ptr<rclcpp::Node> inferenceNodePtr, const std::string &sensorSubName)
+      : doInference_(false)
+    {
+        inferenceNode = inferenceNodePtr;
+        RCLCPP_INFO(inferenceNode->get_logger(), "Initializing RL Model");
+        RCLCPP_INFO(inferenceNode->get_logger(), "%s", sensorSubName.c_str());
+        // Subscribe to the sensor topic and set the call back
+        sensorSub_ = inferenceNode->create_subscription<deepracer_interfaces_pkg::msg::EvoSensorMsg>(sensorSubName, 10, std::bind(&TFLiteInferenceEngine::RLInferenceModel::sensorCB, this, std::placeholders::_1));
+        resultPub_ = inferenceNode->create_publisher<deepracer_interfaces_pkg::msg::InferResultsArray>("rl_results", 1);
+    }
+
+    RLInferenceModel::~RLInferenceModel() {
+        stopInference();
+    }
+
+    bool RLInferenceModel::loadModel(const char* artifactPath,
+                                     std::shared_ptr<InferTask::ImgProcessBase> imgProcess,
+                                     std::string device) {
+        if (doInference_) {
+            RCLCPP_ERROR(inferenceNode->get_logger(), "Please stop inference prior to loading a model");
+            return false;
+        }
+        if (!imgProcess) {
+            RCLCPP_ERROR(inferenceNode->get_logger(), "Invalid image processing algorithm");
+            return false;
+        }
+
+        // Validate the artifact path.
+        auto strIdx = ((std::string) artifactPath).rfind('.');
+        if (strIdx == std::string::npos) {
+            throw InferenceExcept("Artifact missing file extension");
+        }
+        if (((std::string) artifactPath).substr(strIdx+1) != "tflite") {
+            throw InferenceExcept("No tflite extension found");
+        }
+
+        // Set the image processing algorithms
+        imgProcess_ = imgProcess;
+
+        // Load the model
+        try {
+
+            model_ = tflite::FlatBufferModel::BuildFromFile(artifactPath);
+
+            tflite::ops::builtin::BuiltinOpResolver resolver;
+            tflite::InterpreterBuilder(*model_, resolver)(&interpreter_);
+
+            interpreter_->AllocateTensors();
+
+            // Determine input and output dimensions
+            for (auto i : interpreter_->inputs())
+            {
+                auto const *input_tensor = interpreter_->tensor(i);
+
+                auto dims = std::vector<int>{};
+                std::copy(
+                    input_tensor->dims->data, input_tensor->dims->data + input_tensor->dims->size,
+                    std::back_inserter(dims));
+
+                inputNamesArr_.push_back(interpreter_->GetInputName(i));
+
+                std::unordered_map<std::string, int> params_ = {{"width", input_tensor->dims->data[2]},
+                                                                {"height", input_tensor->dims->data[1]},
+                                                                {"channels", input_tensor->dims->data[0]}};
+                paramsArr_.push_back(params_);
+
+                RCLCPP_INFO(inferenceNode->get_logger(), "Input name: %s", input_tensor->name);
+                RCLCPP_INFO(inferenceNode->get_logger(), "Input dimensions: %i x %i x %i", input_tensor->dims->data[2], input_tensor->dims->data[1], input_tensor->dims->data[0]);
+            }
+
+            for (auto o : interpreter_->outputs())
+            {
+                auto const *output_tensor = interpreter_->tensor(o);
+                output_tensors_.push_back(output_tensor);
+
+                auto dims = std::vector<int>{};
+                std::copy(
+                    output_tensor->dims->data, output_tensor->dims->data + output_tensor->dims->size,
+                    std::back_inserter(dims));
+
+                RCLCPP_INFO(inferenceNode->get_logger(), "Output name: %s", output_tensor->name);
+
+                outputDimsArr_.push_back(dims);
+            }
+
+        }
+        catch (const std::exception &ex) {
+            RCLCPP_ERROR(inferenceNode->get_logger(), "Model failed to load: %s", ex.what());
+            return false;
+        }
+        return true;
+    }
+
+    void RLInferenceModel::startInference() {
+        // Reset the image processing algorithm.
+        if (imgProcess_) {
+            imgProcess_->reset();
+        }
+        doInference_ = true;
+    }
+
+    void RLInferenceModel::stopInference() {
+        doInference_ = false;
+    }
+
+    void RLInferenceModel::sensorCB(const deepracer_interfaces_pkg::msg::EvoSensorMsg::SharedPtr msg) {
+        if(!doInference_) {
+            return;
+        }
+        try {
+            for(size_t i = 0; i < inputNamesArr_.size(); ++i) {
+                float* inputLayer = interpreter_->typed_input_tensor<float>(i);
+
+                // Object that will hold the data sent to the inference engine post processed.
+                cv::Mat retData;
+                if (inputNamesArr_[i].find(STEREO) != std::string::npos)
+                {
+                    loadStereoImg<cv::Vec2b>(inputLayer, retData, imgProcess_, msg->images, paramsArr_[i]);
+                }
+                else if (inputNamesArr_[i].find(FRONT) != std::string::npos
+                         || inputNamesArr_[i].find(LEFT) != std::string::npos
+                         || inputNamesArr_[i].find(OBS) != std::string::npos) {
+                    load1DImg<uchar>(inputLayer, retData, imgProcess_, msg->images.front(), paramsArr_[i]);
+                }
+                else if (inputNamesArr_[i].find(LIDAR) != std::string::npos){
+                    loadLidarData(inputLayer, msg->lidar_data);
+                }
+                else {
+                    RCLCPP_ERROR(inferenceNode->get_logger(), "Invalid input head");
+                    return;
+                }
+                imgProcess_->reset();
+            }
+            // Do inference
+            interpreter_->Invoke();
+
+            // Last dimension of output is number of classes
+            auto nClasses = outputDimsArr_[0].back();
+            auto * outputData = output_tensors_[0]->data.f;
+
+            auto inferMsg = deepracer_interfaces_pkg::msg::InferResultsArray();
+            for (size_t i = 0; i < msg->images.size(); ++i) {
+                // Send the image data over with the results
+                inferMsg.images.push_back(msg->images[i]);
+            }
+
+            for (size_t label = 0; label < nClasses; ++label) {
+                auto inferData = deepracer_interfaces_pkg::msg::InferResults();
+                inferData.class_label = label;
+                inferData.class_prob = outputData[label];
+                // Set bounding box data to -1 to indicate to subscribers that this model offers no
+                // localization information.
+                inferData.x_min = -1.0;
+                inferData.y_min = -1.0;
+                inferData.x_max = -1.0;
+                inferData.y_max = -1.0;
+                inferMsg.results.push_back(inferData);
+            }
+            // Send results to all subscribers.
+            resultPub_->publish(inferMsg);
+        }
+        catch (const std::exception &ex) {
+            RCLCPP_ERROR(inferenceNode->get_logger(), "Inference failed %s", ex.what());
+        }
+    }
+}
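The results published by the TFLite engine keep the same `InferResultsArray` layout as the OpenVINO path, so existing consumers keep working unchanged. A sketch of a subscriber that picks the most probable action; topic and field names come from the publisher above, while the function name and logging are illustrative:

```cpp
#include "rclcpp/rclcpp.hpp"
#include "deepracer_interfaces_pkg/msg/infer_results_array.hpp"

// Sketch: select the class with the highest probability from each rl_results message.
rclcpp::Subscription<deepracer_interfaces_pkg::msg::InferResultsArray>::SharedPtr
subscribeToResults(const std::shared_ptr<rclcpp::Node> &node)
{
    return node->create_subscription<deepracer_interfaces_pkg::msg::InferResultsArray>(
        "rl_results", 1,
        [logger = node->get_logger()](const deepracer_interfaces_pkg::msg::InferResultsArray::SharedPtr msg) {
            int bestLabel = -1;
            float bestProb = -1.0F;
            for (const auto &result : msg->results) {
                if (result.class_prob > bestProb) {
                    bestProb = result.class_prob;
                    bestLabel = result.class_label;
                }
            }
            RCLCPP_INFO(logger, "Most probable action: %d (p=%.3f)", bestLabel, bestProb);
        });
}
```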