diff --git a/larrecodnn/NuGraph/CMakeLists.txt b/larrecodnn/NuGraph/CMakeLists.txt index 252e7a3d..c17c882c 100644 --- a/larrecodnn/NuGraph/CMakeLists.txt +++ b/larrecodnn/NuGraph/CMakeLists.txt @@ -1,3 +1,5 @@ +add_subdirectory(Tools) + cet_set_compiler_flags(DIAGS VIGILANT DWARF_VER 4 WERROR NO_UNDEFINED diff --git a/larrecodnn/NuGraph/NuGraphInferenceSonicTriton_module.cc b/larrecodnn/NuGraph/NuGraphInferenceSonicTriton_module.cc index 0f9b8902..8001b984 100644 --- a/larrecodnn/NuGraph/NuGraphInferenceSonicTriton_module.cc +++ b/larrecodnn/NuGraph/NuGraphInferenceSonicTriton_module.cc @@ -10,107 +10,37 @@ #include "art/Framework/Core/EDProducer.h" #include "art/Framework/Core/ModuleMacros.h" #include "art/Framework/Principal/Event.h" -#include "art/Framework/Principal/Handle.h" -#include "art/Framework/Principal/Run.h" -#include "art/Framework/Principal/SubRun.h" -#include "art/Utilities/ToolMacros.h" -#include "canvas/Persistency/Common/FindManyP.h" -#include "canvas/Utilities/InputTag.h" +#include "canvas/Persistency/Common/Ptr.h" #include "fhiclcpp/ParameterSet.h" #include "fhiclcpp/types/Table.h" #include "messagefacility/MessageLogger/MessageLogger.h" -#include #include #include #include "lardataobj/AnalysisBase/MVAOutput.h" #include "lardataobj/RecoBase/Hit.h" -#include "lardataobj/RecoBase/SpacePoint.h" -#include "lardataobj/RecoBase/Vertex.h" //this creates a conflict with torch script if included before it... +#include "lardataobj/RecoBase/Vertex.h" #include "larrecodnn/ImagePatternAlgs/NuSonic/Triton/TritonClient.h" #include "larrecodnn/ImagePatternAlgs/NuSonic/Triton/TritonData.h" -#include -#include +#include "larrecodnn/NuGraph/Tools/DecoderToolBase.h" +#include "larrecodnn/NuGraph/Tools/LoaderToolBase.h" #include #include #include #include #include -#include #include -#include "grpc_client.h" - class NuGraphInferenceSonicTriton; using anab::FeatureVector; using anab::MVADescription; using recob::Hit; -using recob::SpacePoint; -using std::array; using std::vector; -#define FAIL_IF_ERR(X, MSG) \ - { \ - tc::Error err = (X); \ - if (!err.IsOk()) { \ - std::cerr << "error: " << (MSG) << ": " << err << std::endl; \ - exit(1); \ - } \ - } -namespace tc = triton::client; - -namespace { - - template - int arg_max(std::vector const& vec) - { - return static_cast(std::distance(vec.begin(), max_element(vec.begin(), vec.end()))); - } - - template - void softmax(std::array& arr) - { - T m = -std::numeric_limits::max(); - for (size_t i = 0; i < arr.size(); i++) { - if (arr[i] > m) { m = arr[i]; } - } - T sum = 0.0; - for (size_t i = 0; i < arr.size(); i++) { - sum += expf(arr[i] - m); - } - T offset = m + logf(sum); - for (size_t i = 0; i < arr.size(); i++) { - arr[i] = expf(arr[i] - offset); - } - return; - } -} - -// Function to convert string to integer -int stoi(const std::string& str) -{ - std::istringstream iss(str); - int num; - iss >> num; - return num; -} - -// Function to print elements of a vector -void printVector(const std::vector& vec) -{ - for (size_t i = 0; i < vec.size(); ++i) { - std::cout << vec[i]; - // Print space unless it's the last element - if (i != vec.size() - 1) { std::cout << " "; } - } - std::cout << std::endl; - std::cout << std::endl; -} - class NuGraphInferenceSonicTriton : public art::EDProducer { public: explicit NuGraphInferenceSonicTriton(fhicl::ParameterSet const& p); @@ -125,432 +55,120 @@ class NuGraphInferenceSonicTriton : public art::EDProducer { void produce(art::Event& e) override; private: - vector planes; - art::InputTag hitInput; - 
art::InputTag spsInput;
   size_t minHits;
   bool debug;
-  // vector<vector<float>> avgs;
-  // vector<vector<float>> devs;
-  bool filterDecoder;
-  bool semanticDecoder;
-  bool vertexDecoder;
-  std::string inference_url;
-  std::string inference_model_name;
-  bool inference_ssl;
-  std::string ssl_root_certificates;
-  std::string ssl_private_key;
-  std::string ssl_certificate_chain;
+  vector<std::string> planes;
+  fhicl::ParameterSet tritonPset;
+  std::unique_ptr<lartriton::TritonClient> triton_client;
+
+  // loader tool
+  std::unique_ptr<LoaderToolBase> _loaderTool;
+  // decoder tools
+  std::vector<std::unique_ptr<DecoderToolBase>> _decoderToolsVec;
+
+  template <class T>
+  void setShapeAndToServer(lartriton::TritonData<T>& triton_input,
+                           vector<T>& vec,
+                           size_t batchSize)
+  {
+    triton_input.setShape({static_cast<int64_t>(vec.size())});
+    triton_input.toServer(
+      std::make_shared<lartriton::TritonInput<T>>(lartriton::TritonInput<T>(batchSize, vec)));
+  }
 };

 NuGraphInferenceSonicTriton::NuGraphInferenceSonicTriton(fhicl::ParameterSet const& p)
   : EDProducer{p}
-  , planes(p.get<vector<std::string>>("planes"))
-  , hitInput(p.get<art::InputTag>("hitInput"))
-  , spsInput(p.get<art::InputTag>("spsInput"))
   , minHits(p.get<size_t>("minHits"))
   , debug(p.get<bool>("debug"))
-  , filterDecoder(p.get<bool>("filterDecoder"))
-  , semanticDecoder(p.get<bool>("semanticDecoder"))
-  , vertexDecoder(p.get<bool>("vertexDecoder"))
-  , inference_url(p.get<std::string>("url"))
-  , inference_model_name(p.get<std::string>("modelName"))
-  , inference_ssl(p.get<bool>("ssl"))
-  , ssl_root_certificates(p.get<std::string>("sslRootCertificates", ""))
-  , ssl_private_key(p.get<std::string>("sslPrivateKey", ""))
-  , ssl_certificate_chain(p.get<std::string>("sslCertificateChain", ""))
+  , planes(p.get<vector<std::string>>("planes"))
+  , tritonPset(p.get<fhicl::ParameterSet>("TritonConfig"))
 {
-  if (filterDecoder) { produces<vector<FeatureVector<1>>>("filter"); }
-  //
-  if (semanticDecoder) {
-    produces<vector<FeatureVector<5>>>("semantic");
-    produces<MVADescription<5>>("semantic");
+  // ... Create the Triton inference client
+  if (debug) std::cout << "TritonConfig: " << tritonPset.to_string() << std::endl;
+  triton_client = std::make_unique<lartriton::TritonClient>(tritonPset);
+
+  // Loader Tool
+  _loaderTool = art::make_tool<LoaderToolBase>(p.get<fhicl::ParameterSet>("LoaderTool"));
+  _loaderTool->setDebugAndPlanes(debug, planes);
+
+  // configure and construct Decoder Tools
+  auto const tool_psets = p.get<fhicl::ParameterSet>("DecoderTools");
+  for (auto const& tool_pset_labels : tool_psets.get_pset_names()) {
+    std::cout << "decoder label: " << tool_pset_labels << std::endl;
+    auto const tool_pset = tool_psets.get<fhicl::ParameterSet>(tool_pset_labels);
+    _decoderToolsVec.push_back(art::make_tool<DecoderToolBase>(tool_pset));
+    _decoderToolsVec.back()->setDebugAndPlanes(debug, planes);
+    _decoderToolsVec.back()->declareProducts(producesCollector());
   }
-  //
-  if (vertexDecoder) { produces<vector<recob::Vertex>>("vertex"); }
 }
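For context: the constructor above now discovers its decoders from the `DecoderTools` parameter set and instantiates each one through `art::make_tool`, so adding a new network output no longer requires touching the module. A minimal conforming tool might look like the following sketch; `ExampleDecoder`, its product type, and its fill logic are hypothetical placeholders, while the interface itself is the `DecoderToolBase` introduced under `larrecodnn/NuGraph/Tools` later in this diff (see `FilterDecoder_tool.cc` for the real reference implementation):

```cpp
#include "DecoderToolBase.h"

#include "lardataobj/AnalysisBase/MVAOutput.h"

// Hypothetical example tool; product, instance name, and fill logic are
// placeholders illustrating the DecoderToolBase contract.
class ExampleDecoder : public DecoderToolBase {
public:
  ExampleDecoder(const fhicl::ParameterSet& p) : DecoderToolBase{p} {}

  // Declare the data product this decoder will write.
  void declareProducts(art::ProducesCollector& collector) override
  {
    collector.produces<std::vector<anab::FeatureVector<1>>>(instancename);
  }

  // Put a default-filled product when the event has too few hits.
  void writeEmptyToEvent(art::Event& e, const vector<vector<size_t>>& idsmap) override
  {
    size_t size = 0;
    for (auto& v : idsmap)
      size += v.size();
    auto col = std::make_unique<std::vector<anab::FeatureVector<1>>>(
      size, anab::FeatureVector<1>(std::array<float, 1>({-1.})));
    e.put(std::move(col), instancename);
  }

  // Translate the named network output into the data product.
  void writeToEvent(art::Event& e,
                    const vector<vector<size_t>>& idsmap,
                    const vector<NuGraphOutput>& infer_output) override
  {
    // Sketch: find the NuGraphOutput whose output_name matches
    // outputname + plane, map each score back to its hit via idsmap,
    // then e.put(...) the filled product.
  }
};

DEFINE_ART_CLASS_TOOL(ExampleDecoder)
```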
void NuGraphInferenceSonicTriton::produce(art::Event& e)
{
-  art::Handle<vector<Hit>> hitListHandle;
+  //
+  // Load the data and fill the graph inputs
+  //
   vector<art::Ptr<Hit>> hitlist;
-  if (e.getByLabel(hitInput, hitListHandle)) { art::fill_ptr_vector(hitlist, hitListHandle); }
-
-  std::unique_ptr<vector<FeatureVector<1>>> filtcol(
-    new vector<FeatureVector<1>>(hitlist.size(), FeatureVector<1>(std::array<float, 1>({-1.}))));
-
-  std::unique_ptr<vector<FeatureVector<5>>> semtcol(new vector<FeatureVector<5>>(
-    hitlist.size(), FeatureVector<5>(std::array<float, 5>({-1., -1., -1., -1., -1.}))));
-  std::unique_ptr<MVADescription<5>> semtdes(
-    new MVADescription<5>(hitListHandle.provenance()->moduleLabel(),
-                          "semantic",
-                          {"MIP", "HIP", "shower", "michel", "diffuse"}));
-
-  std::unique_ptr<vector<recob::Vertex>> vertcol(new vector<recob::Vertex>());
+  vector<vector<size_t>> idsmap;
+  vector<NuGraphInput> graphinputs;
+  _loaderTool->loadData(e, hitlist, graphinputs, idsmap);

   if (debug) std::cout << "Hits size=" << hitlist.size() << std::endl;
   if (hitlist.size() < minHits) {
-    if (filterDecoder) { e.put(std::move(filtcol), "filter"); }
-    if (semanticDecoder) {
-      e.put(std::move(semtcol), "semantic");
-      e.put(std::move(semtdes), "semantic");
+    // Writing the empty outputs to the output root file
+    for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
+      _decoderToolsVec[i]->writeEmptyToEvent(e, idsmap);
     }
-    if (vertexDecoder) { e.put(std::move(vertcol), "vertex"); }
     return;
   }

-  vector<vector<size_t>> idsmap(planes.size(), vector<size_t>());
-  vector<size_t> idsmapRev(hitlist.size(), hitlist.size());
-  for (auto h : hitlist) {
-    idsmap[h->View()].push_back(h.key());
-    idsmapRev[h.key()] = idsmap[h->View()].size() - 1;
-  }
-
-  // event id
-  int run = e.id().run();
-  int subrun = e.id().subRun();
-  int event = e.id().event();
-
-  array<int, 3> evtID;
-  evtID[0] = run;
-  evtID[1] = subrun;
-  evtID[2] = event;
-
-  // hit table
-  vector<int32_t> hit_table_hit_id_data;
-  vector<int32_t> hit_table_local_plane_data;
-  vector<float> hit_table_local_time_data;
-  vector<int32_t> hit_table_local_wire_data;
-  vector<float> hit_table_integral_data;
-  vector<float> hit_table_rms_data;
-  for (auto h : hitlist) {
-    hit_table_hit_id_data.push_back(h.key());
-    hit_table_local_plane_data.push_back(h->View());
-    hit_table_local_time_data.push_back(h->PeakTime());
-    hit_table_local_wire_data.push_back(h->WireID().Wire);
-    hit_table_integral_data.push_back(h->Integral());
-    hit_table_rms_data.push_back(h->RMS());
-  }
-
-  // Get spacepoints from the event record
-  art::Handle<vector<SpacePoint>> spListHandle;
-  vector<art::Ptr<SpacePoint>> splist;
-  if (e.getByLabel(spsInput, spListHandle)) { art::fill_ptr_vector(splist, spListHandle); }
-  // Get assocations from spacepoints to hits
-  vector<vector<art::Ptr<Hit>>> sp2Hit(splist.size());
-  if (splist.size() > 0) {
-    art::FindManyP<recob::Hit> fmp(spListHandle, e, "sps");
-    for (size_t spIdx = 0; spIdx < sp2Hit.size(); ++spIdx) {
-      sp2Hit[spIdx] = fmp.at(spIdx);
-    }
-  }
-
-  // space point table
-  vector<int32_t> spacepoint_table_spacepoint_id_data;
-  vector<int32_t> spacepoint_table_hit_id_u_data;
-  vector<int32_t> spacepoint_table_hit_id_v_data;
-  vector<int32_t> spacepoint_table_hit_id_y_data;
-  for (size_t i = 0; i < splist.size(); ++i) {
-    spacepoint_table_spacepoint_id_data.push_back(i);
-    spacepoint_table_hit_id_u_data.push_back(-1);
-    spacepoint_table_hit_id_v_data.push_back(-1);
-    spacepoint_table_hit_id_y_data.push_back(-1);
-    for (size_t j = 0; j < sp2Hit[i].size(); ++j) {
-      if (sp2Hit[i][j]->View() == 0) spacepoint_table_hit_id_u_data.back() = sp2Hit[i][j].key();
-      if (sp2Hit[i][j]->View() == 1) spacepoint_table_hit_id_v_data.back() = sp2Hit[i][j].key();
-      if (sp2Hit[i][j]->View() == 2) spacepoint_table_hit_id_y_data.back() = sp2Hit[i][j].key();
-    }
-  }
+  //
+  // NuSonic Triton Server section
+  //
+  auto start = std::chrono::high_resolution_clock::now();
+  //
   //Here the input should be sent to Triton
-  bool fTritonVerbose = false;
-  std::string fTritonModelVersion = "";
-  unsigned fTritonTimeout = 0;
-  unsigned fTritonAllowedTries = 1;
-  std::unique_ptr<lartriton::TritonClient> triton_client;
-
-  // ... Create parameter set for Triton inference client
-  fhicl::ParameterSet TritonPset;
-  TritonPset.put("serverURL", inference_url);
-  TritonPset.put("verbose", fTritonVerbose);
-  TritonPset.put("ssl", inference_ssl);
-  TritonPset.put("sslRootCertificates", ssl_root_certificates);
-  TritonPset.put("sslPrivateKey", ssl_private_key);
-  TritonPset.put("sslCertificateChain", ssl_certificate_chain);
-  TritonPset.put("modelName", inference_model_name);
-  TritonPset.put("modelVersion", fTritonModelVersion);
-  TritonPset.put("timeout", fTritonTimeout);
-  TritonPset.put("allowedTries", fTritonAllowedTries);
-  TritonPset.put("outputs", "[]");
-
-  // ...
Create the Triton inference client - triton_client = std::make_unique(TritonPset); - - triton_client->setBatchSize(1); // set batch size - - auto hit_table_hit_id_ptr = std::make_shared>(); - auto hit_table_local_plane_ptr = std::make_shared>(); - auto hit_table_local_time_ptr = std::make_shared>(); - auto hit_table_local_wire_ptr = std::make_shared>(); - auto hit_table_integral_ptr = std::make_shared>(); - auto hit_table_rms_ptr = std::make_shared>(); - auto spacepoint_table_spacepoint_id_ptr = std::make_shared>(); - auto spacepoint_table_hit_id_u_ptr = std::make_shared>(); - auto spacepoint_table_hit_id_v_ptr = std::make_shared>(); - auto spacepoint_table_hit_id_y_ptr = std::make_shared>(); - - hit_table_hit_id_ptr->reserve(1); - hit_table_local_plane_ptr->reserve(1); - hit_table_local_time_ptr->reserve(1); - hit_table_local_wire_ptr->reserve(1); - hit_table_integral_ptr->reserve(1); - hit_table_rms_ptr->reserve(1); - spacepoint_table_spacepoint_id_ptr->reserve(1); - spacepoint_table_hit_id_u_ptr->reserve(1); - spacepoint_table_hit_id_v_ptr->reserve(1); - spacepoint_table_hit_id_y_ptr->reserve(1); - - auto& hit_table_hit_id = hit_table_hit_id_ptr->emplace_back(); - auto& hit_table_local_plane = hit_table_local_plane_ptr->emplace_back(); - auto& hit_table_local_time = hit_table_local_time_ptr->emplace_back(); - auto& hit_table_local_wire = hit_table_local_wire_ptr->emplace_back(); - auto& hit_table_integral = hit_table_integral_ptr->emplace_back(); - auto& hit_table_rms = hit_table_rms_ptr->emplace_back(); - auto& spacepoint_table_spacepoint_id = spacepoint_table_spacepoint_id_ptr->emplace_back(); - auto& spacepoint_table_hit_id_u = spacepoint_table_hit_id_u_ptr->emplace_back(); - auto& spacepoint_table_hit_id_v = spacepoint_table_hit_id_v_ptr->emplace_back(); - auto& spacepoint_table_hit_id_y = spacepoint_table_hit_id_y_ptr->emplace_back(); - + triton_client->reset(); + size_t batchSize = 1; //the code below assumes/has only been tested for batch size = 1 + triton_client->setBatchSize(batchSize); // set batch size + // auto& inputs = triton_client->input(); for (auto& input_pair : inputs) { const std::string& key = input_pair.first; auto& triton_input = input_pair.second; - - if (key == "hit_table_hit_id") { - for (size_t i = 0; i < hit_table_hit_id_data.size(); ++i) { - hit_table_hit_id.push_back(hit_table_hit_id_data[i]); - } - triton_input.setShape({static_cast(hit_table_hit_id_data.size())}); - triton_input.toServer(hit_table_hit_id_ptr); - } - else if (key == "hit_table_local_plane") { - for (size_t i = 0; i < hit_table_local_plane_data.size(); ++i) { - hit_table_local_plane.push_back(hit_table_local_plane_data[i]); - } - triton_input.setShape({static_cast(hit_table_local_plane_data.size())}); - triton_input.toServer(hit_table_local_plane_ptr); - } - else if (key == "hit_table_local_time") { - for (size_t i = 0; i < hit_table_local_time_data.size(); ++i) { - hit_table_local_time.push_back(hit_table_local_time_data[i]); - } - triton_input.setShape({static_cast(hit_table_local_time_data.size())}); - triton_input.toServer(hit_table_local_time_ptr); - } - else if (key == "hit_table_local_wire") { - for (size_t i = 0; i < hit_table_local_wire_data.size(); ++i) { - hit_table_local_wire.push_back(hit_table_local_wire_data[i]); - } - triton_input.setShape({static_cast(hit_table_local_wire_data.size())}); - triton_input.toServer(hit_table_local_wire_ptr); - } - else if (key == "hit_table_integral") { - for (size_t i = 0; i < hit_table_integral_data.size(); ++i) { - 
hit_table_integral.push_back(hit_table_integral_data[i]); - } - triton_input.setShape({static_cast(hit_table_integral_data.size())}); - triton_input.toServer(hit_table_integral_ptr); - } - else if (key == "hit_table_rms") { - for (size_t i = 0; i < hit_table_rms_data.size(); ++i) { - hit_table_rms.push_back(hit_table_rms_data[i]); - } - triton_input.setShape({static_cast(hit_table_rms_data.size())}); - triton_input.toServer(hit_table_rms_ptr); - } - else if (key == "spacepoint_table_spacepoint_id") { - for (size_t i = 0; i < spacepoint_table_spacepoint_id_data.size(); ++i) { - spacepoint_table_spacepoint_id.push_back(spacepoint_table_spacepoint_id_data[i]); - } - triton_input.setShape({static_cast(spacepoint_table_spacepoint_id_data.size())}); - triton_input.toServer(spacepoint_table_spacepoint_id_ptr); - } - else if (key == "spacepoint_table_hit_id_u") { - for (size_t i = 0; i < spacepoint_table_hit_id_u_data.size(); ++i) { - spacepoint_table_hit_id_u.push_back(spacepoint_table_hit_id_u_data[i]); - } - triton_input.setShape({static_cast(spacepoint_table_hit_id_u_data.size())}); - triton_input.toServer(spacepoint_table_hit_id_u_ptr); - } - else if (key == "spacepoint_table_hit_id_v") { - for (size_t i = 0; i < spacepoint_table_hit_id_v_data.size(); ++i) { - spacepoint_table_hit_id_v.push_back(spacepoint_table_hit_id_v_data[i]); - } - triton_input.setShape({static_cast(spacepoint_table_hit_id_v_data.size())}); - triton_input.toServer(spacepoint_table_hit_id_v_ptr); - } - else if (key == "spacepoint_table_hit_id_y") { - for (size_t i = 0; i < spacepoint_table_hit_id_y_data.size(); ++i) { - spacepoint_table_hit_id_y.push_back(spacepoint_table_hit_id_y_data[i]); - } - triton_input.setShape({static_cast(spacepoint_table_hit_id_y_data.size())}); - triton_input.toServer(spacepoint_table_hit_id_y_ptr); + // + for (auto& gi : graphinputs) { + if (key != gi.input_name) continue; + if (gi.isInt) + setShapeAndToServer(triton_input, gi.input_int32_vec, batchSize); + else + setShapeAndToServer(triton_input, gi.input_float_vec, batchSize); } } - - auto start = std::chrono::high_resolution_clock::now(); // ~~~~ Send inference request triton_client->dispatch(); + // ~~~~ Retrieve inference results + auto& infer_result = triton_client->output(); auto end = std::chrono::high_resolution_clock::now(); std::chrono::duration elapsed = end - start; std::cout << "Time taken for inference: " << elapsed.count() << " seconds" << std::endl; - // ~~~~ Retrieve inference results - const auto& triton_output0 = triton_client->output().at("x_semantic_u"); - const auto& prob0 = triton_output0.fromServer(); - size_t triton_input0_elements = std::distance(prob0[0].begin(), prob0[0].end()); - - const auto& triton_output1 = triton_client->output().at("x_semantic_v"); - const auto& prob1 = triton_output1.fromServer(); - size_t triton_input1_elements = std::distance(prob1[0].begin(), prob1[0].end()); - - const auto& triton_output2 = triton_client->output().at("x_semantic_y"); - const auto& prob2 = triton_output2.fromServer(); - size_t triton_input2_elements = std::distance(prob2[0].begin(), prob2[0].end()); - - const auto& triton_output3 = triton_client->output().at("x_filter_u"); - const auto& prob3 = triton_output3.fromServer(); - size_t triton_input3_elements = std::distance(prob3[0].begin(), prob3[0].end()); - - const auto& triton_output4 = triton_client->output().at("x_filter_v"); - const auto& prob4 = triton_output4.fromServer(); - size_t triton_input4_elements = std::distance(prob4[0].begin(), prob4[0].end()); - - const 
auto& triton_output5 = triton_client->output().at("x_filter_y"); - const auto& prob5 = triton_output5.fromServer(); - size_t triton_input5_elements = std::distance(prob5[0].begin(), prob5[0].end()); - - // putting in the resp output vectors - std::vector x_semantic_u_data; - x_semantic_u_data.reserve(triton_input0_elements); - x_semantic_u_data.insert(x_semantic_u_data.end(), prob0[0].begin(), prob0[0].end()); - - std::vector x_semantic_v_data; - x_semantic_v_data.reserve(triton_input1_elements); - x_semantic_v_data.insert(x_semantic_v_data.end(), prob1[0].begin(), prob1[0].end()); - - std::vector x_semantic_y_data; - x_semantic_y_data.reserve(triton_input2_elements); - x_semantic_y_data.insert(x_semantic_y_data.end(), prob2[0].begin(), prob2[0].end()); - - std::vector x_filter_u_data; - x_filter_u_data.reserve(triton_input3_elements); - x_filter_u_data.insert(x_filter_u_data.end(), prob3[0].begin(), prob3[0].end()); - - std::vector x_filter_v_data; - x_filter_v_data.reserve(triton_input4_elements); - x_filter_v_data.insert(x_filter_v_data.end(), prob4[0].begin(), prob4[0].end()); - - std::vector x_filter_y_data; - x_filter_y_data.reserve(triton_input5_elements); - x_filter_y_data.insert(x_filter_y_data.end(), prob5[0].begin(), prob5[0].end()); - - std::cout << "Triton Input: " << std::endl; - - std::cout << "x_semantic_u: " << std::endl; - printVector(x_semantic_u_data); - - std::cout << "x_semantic_v: " << std::endl; - printVector(x_semantic_v_data); - - std::cout << "x_semantic_y: " << std::endl; - printVector(x_semantic_y_data); - - std::cout << "x_filter_u: " << std::endl; - printVector(x_filter_u_data); - - std::cout << "x_filter_v: " << std::endl; - printVector(x_filter_v_data); - - std::cout << "x_filter_y: " << std::endl; - printVector(x_filter_y_data); - - // writing the outputs to the output root file - if (semanticDecoder) { - size_t n_cols = 5; - for (size_t p = 0; p < planes.size(); p++) { - torch::Tensor s; - torch::TensorOptions options = torch::TensorOptions().dtype(torch::kFloat32); - if (planes[p] == "u") { - size_t n_rows = x_semantic_u_data.size() / n_cols; - s = torch::from_blob(x_semantic_u_data.data(), - {static_cast(n_rows), static_cast(n_cols)}, - options); - } - else if (planes[p] == "v") { - size_t n_rows = x_semantic_v_data.size() / n_cols; - s = torch::from_blob(x_semantic_v_data.data(), - {static_cast(n_rows), static_cast(n_cols)}, - options); - } - else if (planes[p] == "y") { - size_t n_rows = x_semantic_y_data.size() / n_cols; - s = torch::from_blob(x_semantic_y_data.data(), - {static_cast(n_rows), static_cast(n_cols)}, - options); - } - else { - std::cout << "Error!!" 
<< std::endl; - } - - for (int i = 0; i < s.sizes()[0]; ++i) { - size_t idx = idsmap[p][i]; - std::array input({s[i][0].item(), - s[i][1].item(), - s[i][2].item(), - s[i][3].item(), - s[i][4].item()}); - softmax(input); - FeatureVector<5> semt = FeatureVector<5>(input); - (*semtcol)[idx] = semt; - } - } - } - if (filterDecoder) { - for (size_t p = 0; p < planes.size(); p++) { - torch::Tensor f; - torch::TensorOptions options = torch::TensorOptions().dtype(torch::kFloat32); - if (planes[p] == "u") { - int64_t num_elements = x_filter_u_data.size(); - f = torch::from_blob(x_filter_u_data.data(), {num_elements}, options); - } - else if (planes[p] == "v") { - int64_t num_elements = x_filter_v_data.size(); - f = torch::from_blob(x_filter_v_data.data(), {num_elements}, options); - } - else if (planes[p] == "y") { - int64_t num_elements = x_filter_y_data.size(); - f = torch::from_blob(x_filter_y_data.data(), {num_elements}, options); - } - else { - std::cout << "error!" << std::endl; - } - - for (int i = 0; i < f.numel(); ++i) { - size_t idx = idsmap[p][i]; - std::array input({f[i].item()}); - (*filtcol)[idx] = FeatureVector<1>(input); - } - } + // + // Get pointers to the result returned and write to the event + // + vector infer_output; + for (const auto& [name, data] : infer_result) { + const auto& prob = data.fromServer(); + std::vector out_data(prob[0].begin(), prob[0].end()); + infer_output.emplace_back(name, std::move(out_data)); } - if (filterDecoder) { e.put(std::move(filtcol), "filter"); } - if (semanticDecoder) { - e.put(std::move(semtcol), "semantic"); - e.put(std::move(semtdes), "semantic"); + // Write the outputs to the output root file + for (size_t i = 0; i < _decoderToolsVec.size(); i++) { + _decoderToolsVec[i]->writeToEvent(e, idsmap, infer_output); } - if (vertexDecoder) { e.put(std::move(vertcol), "vertex"); } } DEFINE_ART_MODULE(NuGraphInferenceSonicTriton) diff --git a/larrecodnn/NuGraph/NuGraphInferenceTriton_module.cc b/larrecodnn/NuGraph/NuGraphInferenceTriton_module.cc index c3786524..e0dfb28c 100644 --- a/larrecodnn/NuGraph/NuGraphInferenceTriton_module.cc +++ b/larrecodnn/NuGraph/NuGraphInferenceTriton_module.cc @@ -10,12 +10,9 @@ #include "art/Framework/Core/EDProducer.h" #include "art/Framework/Core/ModuleMacros.h" #include "art/Framework/Principal/Event.h" -#include "art/Framework/Principal/Handle.h" -#include "art/Framework/Principal/Run.h" -#include "art/Framework/Principal/SubRun.h" -#include "canvas/Persistency/Common/FindManyP.h" -#include "canvas/Utilities/InputTag.h" +#include "canvas/Persistency/Common/Ptr.h" #include "fhiclcpp/ParameterSet.h" +#include "fhiclcpp/types/Table.h" #include "messagefacility/MessageLogger/MessageLogger.h" #include @@ -24,18 +21,15 @@ #include "lardataobj/AnalysisBase/MVAOutput.h" #include "lardataobj/RecoBase/Hit.h" -#include "lardataobj/RecoBase/SpacePoint.h" -#include "lardataobj/RecoBase/Vertex.h" //this creates a conflict with torch script if included before it... 
-
-#include
-#include
+#include "lardataobj/RecoBase/Vertex.h"
+#include "larrecodnn/NuGraph/Tools/DecoderToolBase.h"
+#include "larrecodnn/NuGraph/Tools/LoaderToolBase.h"

 #include
 #include
 #include
 #include
 #include
-#include
 #include

 #include "grpc_client.h"

@@ -45,7 +39,6 @@
 class NuGraphInferenceTriton;

 using anab::FeatureVector;
 using anab::MVADescription;
 using recob::Hit;
-using recob::SpacePoint;
 using std::array;
 using std::vector;

@@ -59,51 +52,6 @@ using std::vector;
   }
 namespace tc = triton::client;

-namespace {
-
-  template <class T>
-  int arg_max(std::vector<T> const& vec)
-  {
-    return static_cast<int>(std::distance(vec.begin(), max_element(vec.begin(), vec.end())));
-  }
-
-  template <class T, size_t N>
-  void softmax(std::array<T, N>& arr)
-  {
-    T m = -std::numeric_limits<T>::max();
-    for (size_t i = 0; i < arr.size(); i++) {
-      if (arr[i] > m) { m = arr[i]; }
-    }
-    T sum = 0.0;
-    for (size_t i = 0; i < arr.size(); i++) {
-      sum += expf(arr[i] - m);
-    }
-    T offset = m + logf(sum);
-    for (size_t i = 0; i < arr.size(); i++) {
-      arr[i] = expf(arr[i] - offset);
-    }
-    return;
-  }
-}
-
-// Function to convert string to integer
-int stoi(const std::string& str)
-{
-  std::istringstream iss(str);
-  int num;
-  iss >> num;
-  return num;
-}
-
-void printFloatArray(const float* data, size_t num_elements)
-{
-  for (size_t i = 0; i < num_elements; ++i) {
-    std::cout << data[i];
-    if (i < num_elements - 1) { std::cout << " "; }
-  }
-  std::cout << std::endl;
-}
-
 class NuGraphInferenceTriton : public art::EDProducer {
 public:
   explicit NuGraphInferenceTriton(fhicl::ParameterSet const& p);
@@ -118,164 +66,126 @@ class NuGraphInferenceTriton : public art::EDProducer {
   void produce(art::Event& e) override;

 private:
-  vector<std::string> planes;
-  art::InputTag hitInput;
-  art::InputTag spsInput;
   size_t minHits;
   bool debug;
-  // vector<vector<float>> avgs;
-  // vector<vector<float>> devs;
-  bool filterDecoder;
-  bool semanticDecoder;
-  bool vertexDecoder;
+  vector<std::string> planes;
   std::string inference_url;
   std::string inference_model_name;
+  std::string model_version;
   bool inference_ssl;
   std::string ssl_root_certificates;
   std::string ssl_private_key;
   std::string ssl_certificate_chain;
+  bool verbose;
+  uint32_t client_timeout;
+
+  // loader tool
+  std::unique_ptr<LoaderToolBase> _loaderTool;
+  // decoder tools
+  std::vector<std::unique_ptr<DecoderToolBase>> _decoderToolsVec;
 };

 NuGraphInferenceTriton::NuGraphInferenceTriton(fhicl::ParameterSet const& p)
   : EDProducer{p}
-  , planes(p.get<vector<std::string>>("planes"))
-  , hitInput(p.get<art::InputTag>("hitInput"))
-  , spsInput(p.get<art::InputTag>("spsInput"))
   , minHits(p.get<size_t>("minHits"))
   , debug(p.get<bool>("debug"))
-  , filterDecoder(p.get<bool>("filterDecoder"))
-  , semanticDecoder(p.get<bool>("semanticDecoder"))
-  , vertexDecoder(p.get<bool>("vertexDecoder"))
-  , inference_url(p.get<std::string>("url"))
-  , inference_model_name(p.get<std::string>("modelName"))
-  , inference_ssl(p.get<bool>("ssl"))
-  , ssl_root_certificates(p.get<std::string>("sslRootCertificates", ""))
-  , ssl_private_key(p.get<std::string>("sslPrivateKey", ""))
-  , ssl_certificate_chain(p.get<std::string>("sslCertificateChain", ""))
+  , planes(p.get<vector<std::string>>("planes"))
 {
-  if (filterDecoder) { produces<vector<FeatureVector<1>>>("filter"); }
-  //
-  if (semanticDecoder) {
-    produces<vector<FeatureVector<5>>>("semantic");
-    produces<MVADescription<5>>("semantic");
+  fhicl::ParameterSet tritonPset = p.get<fhicl::ParameterSet>("TritonConfig");
+  inference_url = tritonPset.get<std::string>("serverURL");
+  inference_model_name = tritonPset.get<std::string>("modelName");
+  inference_ssl = tritonPset.get<bool>("ssl");
+  ssl_root_certificates = tritonPset.get<std::string>("sslRootCertificates", "");
+  ssl_private_key = tritonPset.get<std::string>("sslPrivateKey", "");
+  ssl_certificate_chain = tritonPset.get<std::string>("sslCertificateChain", "");
+  verbose = tritonPset.get<bool>("verbose", false);
+  model_version = tritonPset.get<std::string>("modelVersion", "");
+  client_timeout = tritonPset.get<unsigned>("timeout", 0);
+
+  // Loader Tool
+  _loaderTool = art::make_tool<LoaderToolBase>(p.get<fhicl::ParameterSet>("LoaderTool"));
+  _loaderTool->setDebugAndPlanes(debug, planes);
+
+  // configure and construct Decoder Tools
+  auto const tool_psets = p.get<fhicl::ParameterSet>("DecoderTools");
+  for (auto const& tool_pset_labels : tool_psets.get_pset_names()) {
+    std::cout << "decoder label: " << tool_pset_labels << std::endl;
+    auto const tool_pset = tool_psets.get<fhicl::ParameterSet>(tool_pset_labels);
+    _decoderToolsVec.push_back(art::make_tool<DecoderToolBase>(tool_pset));
+    _decoderToolsVec.back()->setDebugAndPlanes(debug, planes);
+    _decoderToolsVec.back()->declareProducts(producesCollector());
   }
-  //
-  if (vertexDecoder) { produces<vector<recob::Vertex>>("vertex"); }
 }
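For context: the connection parameters now live in a single `TritonConfig` table rather than in per-module parameters. For illustration, an equivalent parameter set can be assembled programmatically with the same keys the constructor above reads; the pattern mirrors the hand-built `TritonPset` removed from the Sonic module earlier in this diff, and the server address and model name below are placeholders, not project defaults:

```cpp
#include "fhiclcpp/ParameterSet.h"

#include <string>

// Illustrative only: a TritonConfig-like table built in code. Key names
// follow the get<> calls in the constructor above; values are placeholders.
fhicl::ParameterSet makeExampleTritonConfig()
{
  fhicl::ParameterSet pset;
  pset.put("serverURL", std::string{"localhost:8001"}); // placeholder address
  pset.put("modelName", std::string{"nugraph"});        // placeholder model
  pset.put("ssl", false);
  pset.put("sslRootCertificates", std::string{});
  pset.put("sslPrivateKey", std::string{});
  pset.put("sslCertificateChain", std::string{});
  pset.put("verbose", false);
  pset.put("modelVersion", std::string{});
  pset.put("timeout", 0u);
  return pset;
}
```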
tritonPset.get("modelVersion", ""); + client_timeout = tritonPset.get("timeout", 0); + + // Loader Tool + _loaderTool = art::make_tool(p.get("LoaderTool")); + _loaderTool->setDebugAndPlanes(debug, planes); + + // configure and construct Decoder Tools + auto const tool_psets = p.get("DecoderTools"); + for (auto const& tool_pset_labels : tool_psets.get_pset_names()) { + std::cout << "decoder lablel: " << tool_pset_labels << std::endl; + auto const tool_pset = tool_psets.get(tool_pset_labels); + _decoderToolsVec.push_back(art::make_tool(tool_pset)); + _decoderToolsVec.back()->setDebugAndPlanes(debug, planes); + _decoderToolsVec.back()->declareProducts(producesCollector()); } - // - if (vertexDecoder) { produces>("vertex"); } } void NuGraphInferenceTriton::produce(art::Event& e) { - art::Handle> hitListHandle; + // + // Load the data and fill the graph inputs + // vector> hitlist; - if (e.getByLabel(hitInput, hitListHandle)) { art::fill_ptr_vector(hitlist, hitListHandle); } - - std::unique_ptr>> filtcol( - new vector>(hitlist.size(), FeatureVector<1>(std::array({-1.})))); - - std::unique_ptr>> semtcol(new vector>( - hitlist.size(), FeatureVector<5>(std::array({-1., -1., -1., -1., -1.})))); - std::unique_ptr> semtdes( - new MVADescription<5>(hitListHandle.provenance()->moduleLabel(), - "semantic", - {"MIP", "HIP", "shower", "michel", "diffuse"})); - - std::unique_ptr> vertcol(new vector()); + vector> idsmap; + vector graphinputs; + _loaderTool->loadData(e, hitlist, graphinputs, idsmap); if (debug) std::cout << "Hits size=" << hitlist.size() << std::endl; if (hitlist.size() < minHits) { - if (filterDecoder) { e.put(std::move(filtcol), "filter"); } - if (semanticDecoder) { - e.put(std::move(semtcol), "semantic"); - e.put(std::move(semtdes), "semantic"); + // Writing the empty outputs to the output root file + for (size_t i = 0; i < _decoderToolsVec.size(); i++) { + _decoderToolsVec[i]->writeEmptyToEvent(e, idsmap); } - if (vertexDecoder) { e.put(std::move(vertcol), "vertex"); } return; } - vector> idsmap(planes.size(), vector()); - vector idsmapRev(hitlist.size(), hitlist.size()); - for (auto h : hitlist) { - idsmap[h->View()].push_back(h.key()); - idsmapRev[h.key()] = idsmap[h->View()].size() - 1; - } - - // event id - int run = e.id().run(); - int subrun = e.id().subRun(); - int event = e.id().event(); - - array evtID; - evtID[0] = run; - evtID[1] = subrun; - evtID[2] = event; - - // hit table - vector hit_table_hit_id_data; - vector hit_table_local_plane_data; - vector hit_table_local_time_data; - vector hit_table_local_wire_data; - vector hit_table_integral_data; - vector hit_table_rms_data; - for (auto h : hitlist) { - hit_table_hit_id_data.push_back(h.key()); - hit_table_local_plane_data.push_back(h->View()); - hit_table_local_time_data.push_back(h->PeakTime()); - hit_table_local_wire_data.push_back(h->WireID().Wire); - hit_table_integral_data.push_back(h->Integral()); - hit_table_rms_data.push_back(h->RMS()); - } - - // Get spacepoints from the event record - art::Handle> spListHandle; - vector> splist; - if (e.getByLabel(spsInput, spListHandle)) { art::fill_ptr_vector(splist, spListHandle); } - // Get assocations from spacepoints to hits - vector>> sp2Hit(splist.size()); - if (splist.size() > 0) { - art::FindManyP fmp(spListHandle, e, "sps"); - for (size_t spIdx = 0; spIdx < sp2Hit.size(); ++spIdx) { - sp2Hit[spIdx] = fmp.at(spIdx); - } - } - // space point table - vector spacepoint_table_spacepoint_id_data; - vector spacepoint_table_hit_id_u_data; - vector spacepoint_table_hit_id_v_data; 
- vector spacepoint_table_hit_id_y_data; - for (size_t i = 0; i < splist.size(); ++i) { - spacepoint_table_spacepoint_id_data.push_back(i); - spacepoint_table_hit_id_u_data.push_back(-1); - spacepoint_table_hit_id_v_data.push_back(-1); - spacepoint_table_hit_id_y_data.push_back(-1); - for (size_t j = 0; j < sp2Hit[i].size(); ++j) { - if (sp2Hit[i][j]->View() == 0) spacepoint_table_hit_id_u_data.back() = sp2Hit[i][j].key(); - if (sp2Hit[i][j]->View() == 1) spacepoint_table_hit_id_v_data.back() = sp2Hit[i][j].key(); - if (sp2Hit[i][j]->View() == 2) spacepoint_table_hit_id_y_data.back() = sp2Hit[i][j].key(); - } + // + // Triton-specific section + // + const vector* hit_table_hit_id_data = nullptr; + const vector* hit_table_local_plane_data = nullptr; + const vector* hit_table_local_time_data = nullptr; + const vector* hit_table_local_wire_data = nullptr; + const vector* hit_table_integral_data = nullptr; + const vector* hit_table_rms_data = nullptr; + const vector* spacepoint_table_spacepoint_id_data = nullptr; + const vector* spacepoint_table_hit_id_u_data = nullptr; + const vector* spacepoint_table_hit_id_v_data = nullptr; + const vector* spacepoint_table_hit_id_y_data = nullptr; + for (const auto& gi : graphinputs) { + if (gi.input_name == "hit_table_hit_id") + hit_table_hit_id_data = &gi.input_int32_vec; + else if (gi.input_name == "hit_table_local_plane") + hit_table_local_plane_data = &gi.input_int32_vec; + else if (gi.input_name == "hit_table_local_time") + hit_table_local_time_data = &gi.input_float_vec; + else if (gi.input_name == "hit_table_local_wire") + hit_table_local_wire_data = &gi.input_int32_vec; + else if (gi.input_name == "hit_table_integral") + hit_table_integral_data = &gi.input_float_vec; + else if (gi.input_name == "hit_table_rms") + hit_table_rms_data = &gi.input_float_vec; + else if (gi.input_name == "spacepoint_table_spacepoint_id") + spacepoint_table_spacepoint_id_data = &gi.input_int32_vec; + else if (gi.input_name == "spacepoint_table_hit_id_u") + spacepoint_table_hit_id_u_data = &gi.input_int32_vec; + else if (gi.input_name == "spacepoint_table_hit_id_v") + spacepoint_table_hit_id_v_data = &gi.input_int32_vec; + else if (gi.input_name == "spacepoint_table_hit_id_y") + spacepoint_table_hit_id_y_data = &gi.input_int32_vec; } //Here the input should be sent to Triton - bool verbose = false; - std::string url(inference_url); tc::Headers http_headers; - uint32_t client_timeout = 0; - bool use_ssl = inference_ssl; grpc_compression_algorithm compression_algorithm = grpc_compression_algorithm::GRPC_COMPRESS_NONE; bool test_use_cached_channel = false; bool use_cached_channel = true; - // the element-wise difference. - std::string model_name = inference_model_name; - std::string model_version = ""; - // Create a InferenceServerGrpcClient instance to communicate with the // server using gRPC protocol. std::unique_ptr client; tc::SslOptions ssl_options = tc::SslOptions(); std::string err; - if (use_ssl) { + if (inference_ssl) { ssl_options.root_certificates = ssl_root_certificates; ssl_options.private_key = ssl_private_key; ssl_options.certificate_chain = ssl_certificate_chain; @@ -287,14 +197,18 @@ void NuGraphInferenceTriton::produce(art::Event& e) // Run with the same name to ensure cached channel is not used int numRuns = test_use_cached_channel ? 
2 : 1; for (int i = 0; i < numRuns; ++i) { - FAIL_IF_ERR( - tc::InferenceServerGrpcClient::Create( - &client, url, verbose, use_ssl, ssl_options, tc::KeepAliveOptions(), use_cached_channel), - err); - - std::vector hit_table_shape{int64_t(hit_table_hit_id_data.size())}; + FAIL_IF_ERR(tc::InferenceServerGrpcClient::Create(&client, + inference_url, + verbose, + inference_ssl, + ssl_options, + tc::KeepAliveOptions(), + use_cached_channel), + err); + + std::vector hit_table_shape{int64_t(hit_table_hit_id_data->size())}; std::vector spacepoint_table_shape{ - int64_t(spacepoint_table_spacepoint_id_data.size())}; + int64_t(spacepoint_table_spacepoint_id_data->size())}; // Initialize the inputs with the data. tc::InferInput* hit_table_hit_id; @@ -373,53 +287,54 @@ void NuGraphInferenceTriton::produce(art::Event& e) std::shared_ptr spacepoint_table_hit_id_y_ptr; spacepoint_table_hit_id_y_ptr.reset(spacepoint_table_hit_id_y); - FAIL_IF_ERR( - hit_table_hit_id_ptr->AppendRaw(reinterpret_cast(&hit_table_hit_id_data[0]), - hit_table_hit_id_data.size() * sizeof(float)), - "unable to set data for hit_table_hit_id"); + FAIL_IF_ERR(hit_table_hit_id_ptr->AppendRaw( + reinterpret_cast(hit_table_hit_id_data->data()), + hit_table_hit_id_data->size() * sizeof(int32_t)), + "unable to set data for hit_table_hit_id"); FAIL_IF_ERR(hit_table_local_plane_ptr->AppendRaw( - reinterpret_cast(&hit_table_local_plane_data[0]), - hit_table_local_plane_data.size() * sizeof(float)), + reinterpret_cast(hit_table_local_plane_data->data()), + hit_table_local_plane_data->size() * sizeof(int32_t)), "unable to set data for hit_table_local_plane"); - FAIL_IF_ERR( - hit_table_local_time_ptr->AppendRaw(reinterpret_cast(&hit_table_local_time_data[0]), - hit_table_local_time_data.size() * sizeof(float)), - "unable to set data for hit_table_local_time"); + FAIL_IF_ERR(hit_table_local_time_ptr->AppendRaw( + reinterpret_cast(hit_table_local_time_data->data()), + hit_table_local_time_data->size() * sizeof(float)), + "unable to set data for hit_table_local_time"); - FAIL_IF_ERR( - hit_table_local_wire_ptr->AppendRaw(reinterpret_cast(&hit_table_local_wire_data[0]), - hit_table_local_wire_data.size() * sizeof(float)), - "unable to set data for hit_table_local_wire"); + FAIL_IF_ERR(hit_table_local_wire_ptr->AppendRaw( + reinterpret_cast(hit_table_local_wire_data->data()), + hit_table_local_wire_data->size() * sizeof(int32_t)), + "unable to set data for hit_table_local_wire"); - FAIL_IF_ERR( - hit_table_integral_ptr->AppendRaw(reinterpret_cast(&hit_table_integral_data[0]), - hit_table_integral_data.size() * sizeof(float)), - "unable to set data for hit_table_integral"); + FAIL_IF_ERR(hit_table_integral_ptr->AppendRaw( + reinterpret_cast(hit_table_integral_data->data()), + hit_table_integral_data->size() * sizeof(float)), + "unable to set data for hit_table_integral"); - FAIL_IF_ERR(hit_table_rms_ptr->AppendRaw(reinterpret_cast(&hit_table_rms_data[0]), - hit_table_rms_data.size() * sizeof(float)), - "unable to set data for hit_table_rms"); + FAIL_IF_ERR( + hit_table_rms_ptr->AppendRaw(reinterpret_cast(hit_table_rms_data->data()), + hit_table_rms_data->size() * sizeof(float)), + "unable to set data for hit_table_rms"); FAIL_IF_ERR(spacepoint_table_spacepoint_id_ptr->AppendRaw( - reinterpret_cast(&spacepoint_table_spacepoint_id_data[0]), - spacepoint_table_spacepoint_id_data.size() * sizeof(float)), + reinterpret_cast(spacepoint_table_spacepoint_id_data->data()), + spacepoint_table_spacepoint_id_data->size() * sizeof(int32_t)), "unable to set data 
for spacepoint_table_spacepoint_id"); FAIL_IF_ERR(spacepoint_table_hit_id_u_ptr->AppendRaw( - reinterpret_cast(&spacepoint_table_hit_id_u_data[0]), - spacepoint_table_hit_id_u_data.size() * sizeof(float)), + reinterpret_cast(spacepoint_table_hit_id_u_data->data()), + spacepoint_table_hit_id_u_data->size() * sizeof(int32_t)), "unable to set data for spacepoint_table_hit_id_u"); FAIL_IF_ERR(spacepoint_table_hit_id_v_ptr->AppendRaw( - reinterpret_cast(&spacepoint_table_hit_id_v_data[0]), - spacepoint_table_hit_id_v_data.size() * sizeof(float)), + reinterpret_cast(spacepoint_table_hit_id_v_data->data()), + spacepoint_table_hit_id_v_data->size() * sizeof(int32_t)), "unable to set data for spacepoint_table_hit_id_v"); FAIL_IF_ERR(spacepoint_table_hit_id_y_ptr->AppendRaw( - reinterpret_cast(&spacepoint_table_hit_id_y_data[0]), - spacepoint_table_hit_id_y_data.size() * sizeof(float)), + reinterpret_cast(spacepoint_table_hit_id_y_data->data()), + spacepoint_table_hit_id_y_data->size() * sizeof(int32_t)), "unable to set data for spacepoint_table_hit_id_y"); // Generate the outputs to be requested. @@ -461,7 +376,7 @@ void NuGraphInferenceTriton::produce(art::Event& e) x_filter_y_ptr.reset(x_filter_y); // The inference settings. Will be using default for now. - tc::InferOptions options(model_name); + tc::InferOptions options(inference_model_name); options.model_version_ = model_version; options.client_timeout_ = client_timeout; @@ -494,144 +409,26 @@ void NuGraphInferenceTriton::produce(art::Event& e) std::shared_ptr results_ptr; results_ptr.reset(results); - // Get pointers to the result returned... - - const float* x_semantic_u_data; - size_t x_semantic_u_byte_size; - FAIL_IF_ERR(results_ptr->RawData( - "x_semantic_u", (const uint8_t**)&x_semantic_u_data, &x_semantic_u_byte_size), - "unable to get result data for 'x_semantic_u'"); - - const float* x_semantic_v_data; - size_t x_semantic_v_byte_size; - FAIL_IF_ERR(results_ptr->RawData( - "x_semantic_v", (const uint8_t**)&x_semantic_v_data, &x_semantic_v_byte_size), - "unable to get result data for 'x_semantic_v'"); - - const float* x_semantic_y_data; - size_t x_semantic_y_byte_size; - FAIL_IF_ERR(results_ptr->RawData( - "x_semantic_y", (const uint8_t**)&x_semantic_y_data, &x_semantic_y_byte_size), - "unable to get result data for 'x_semantic_y'"); - - const float* x_filter_u_data; - size_t x_filter_u_byte_size; - FAIL_IF_ERR( - results_ptr->RawData("x_filter_u", (const uint8_t**)&x_filter_u_data, &x_filter_u_byte_size), - "unable to get result data for 'x_filter_u'"); - - const float* x_filter_v_data; - size_t x_filter_v_byte_size; - FAIL_IF_ERR( - results_ptr->RawData("x_filter_v", (const uint8_t**)&x_filter_v_data, &x_filter_v_byte_size), - "unable to get result data for 'x_filter_v'"); - - const float* x_filter_y_data; - size_t x_filter_y_byte_size; - FAIL_IF_ERR( - results_ptr->RawData("x_filter_y", (const uint8_t**)&x_filter_y_data, &x_filter_y_byte_size), - "unable to get result data for 'x_filter_y'"); - - std::cout << "Trition output: " << std::endl; - - std::cout << "x_semantic_u: " << std::endl; - printFloatArray(x_semantic_u_data, x_semantic_u_byte_size / sizeof(float)); - std::cout << std::endl; - - std::cout << "x_semantic_v: " << std::endl; - printFloatArray(x_semantic_v_data, x_semantic_v_byte_size / sizeof(float)); - std::cout << std::endl; - - std::cout << "x_semantic_y: " << std::endl; - printFloatArray(x_semantic_y_data, x_semantic_y_byte_size / sizeof(float)); - std::cout << std::endl; - - std::cout << "x_filter_u: " << std::endl; 
- printFloatArray(x_filter_u_data, x_filter_u_byte_size / sizeof(float)); - std::cout << std::endl; - - std::cout << "x_filter_v: " << std::endl; - printFloatArray(x_filter_v_data, x_filter_v_byte_size / sizeof(float)); - std::cout << std::endl; - - std::cout << "x_filter_y: " << std::endl; - printFloatArray(x_filter_y_data, x_filter_y_byte_size / sizeof(float)); - std::cout << std::endl; - - if (semanticDecoder) { - size_t n_cols = 5; - for (size_t p = 0; p < planes.size(); p++) { - torch::Tensor s; - torch::TensorOptions options = torch::TensorOptions().dtype(torch::kFloat32); - if (planes[p] == "u") { - size_t n_rows = x_semantic_u_byte_size / (n_cols * sizeof(float)); - s = torch::from_blob((void*)x_semantic_u_data, - {static_cast(n_rows), static_cast(n_cols)}, - options); - } - else if (planes[p] == "v") { - size_t n_rows = x_semantic_v_byte_size / (n_cols * sizeof(float)); - s = torch::from_blob((void*)x_semantic_v_data, - {static_cast(n_rows), static_cast(n_cols)}, - options); - } - else if (planes[p] == "y") { - size_t n_rows = x_semantic_y_byte_size / (n_cols * sizeof(float)); - s = torch::from_blob((void*)x_semantic_y_data, - {static_cast(n_rows), static_cast(n_cols)}, - options); - } - else { - std::cout << "Error!!" << std::endl; - } - - for (int i = 0; i < s.sizes()[0]; ++i) { - size_t idx = idsmap[p][i]; - std::array input({s[i][0].item(), - s[i][1].item(), - s[i][2].item(), - s[i][3].item(), - s[i][4].item()}); - softmax(input); - FeatureVector<5> semt = FeatureVector<5>(input); - (*semtcol)[idx] = semt; - } - } + // + // Get pointers to the result returned and write to the event + // + vector infer_output; + vector outnames = { + "x_semantic_u", "x_semantic_v", "x_semantic_y", "x_filter_u", "x_filter_v", "x_filter_y"}; + for (const auto& name : outnames) { + const float* _data; + size_t _byte_size; + FAIL_IF_ERR(results_ptr->RawData(name, (const uint8_t**)&_data, &_byte_size), + "unable to get result data for " + name); + size_t n_elements = _byte_size / sizeof(float); + std::vector out_data(_data, _data + n_elements); + infer_output.push_back(NuGraphOutput(name, out_data)); } - if (filterDecoder) { - for (size_t p = 0; p < planes.size(); p++) { - torch::Tensor f; - torch::TensorOptions options = torch::TensorOptions().dtype(torch::kFloat32); - if (planes[p] == "u") { - int64_t num_elements = x_filter_u_byte_size / sizeof(float); - f = torch::from_blob((void*)x_filter_u_data, {num_elements}, options); - } - else if (planes[p] == "v") { - int64_t num_elements = x_filter_v_byte_size / sizeof(float); - f = torch::from_blob((void*)x_filter_v_data, {num_elements}, options); - } - else if (planes[p] == "y") { - int64_t num_elements = x_filter_y_byte_size / sizeof(float); - f = torch::from_blob((void*)x_filter_y_data, {num_elements}, options); - } - else { - std::cout << "error!" 
<< std::endl;
-      }
-
-      for (int i = 0; i < f.numel(); ++i) {
-        size_t idx = idsmap[p][i];
-        std::array<float, 1> input({f[i].item<float>()});
-        (*filtcol)[idx] = FeatureVector<1>(input);
-      }
-    }
+
+    // Write the outputs
+    for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
+      _decoderToolsVec[i]->writeToEvent(e, idsmap, infer_output);
     }
   }
-  if (filterDecoder) { e.put(std::move(filtcol), "filter"); }
-  if (semanticDecoder) {
-    e.put(std::move(semtcol), "semantic");
-    e.put(std::move(semtdes), "semantic");
-  }
-  if (vertexDecoder) { e.put(std::move(vertcol), "vertex"); }
 }
-
 DEFINE_ART_MODULE(NuGraphInferenceTriton)
diff --git a/larrecodnn/NuGraph/NuGraphInference_module.cc b/larrecodnn/NuGraph/NuGraphInference_module.cc
index 1f9214a2..f9ac3c1b 100644
--- a/larrecodnn/NuGraph/NuGraphInference_module.cc
+++ b/larrecodnn/NuGraph/NuGraphInference_module.cc
@@ -30,6 +30,9 @@
 #include "lardataobj/RecoBase/SpacePoint.h"
 #include "lardataobj/RecoBase/Vertex.h" //this creates a conflict with torch script if included before it...

+#include "larrecodnn/NuGraph/Tools/DecoderToolBase.h"
+#include "larrecodnn/NuGraph/Tools/LoaderToolBase.h"
+
 class NuGraphInference;

 using anab::FeatureVector;
@@ -80,28 +83,24 @@ class NuGraphInference : public art::EDProducer {

 private:
   vector<std::string> planes;
-  art::InputTag hitInput;
-  art::InputTag spsInput;
   size_t minHits;
   bool debug;
   vector<vector<float>> avgs;
   vector<vector<float>> devs;
-  bool filterDecoder;
-  bool semanticDecoder;
-  bool vertexDecoder;
+  vector<float> pos_norm;
   torch::jit::script::Module model;
+  // loader tool
+  std::unique_ptr<LoaderToolBase> _loaderTool;
+  // decoder tools
+  std::vector<std::unique_ptr<DecoderToolBase>> _decoderToolsVec;
 };

 NuGraphInference::NuGraphInference(fhicl::ParameterSet const& p)
   : EDProducer{p}
   , planes(p.get<vector<std::string>>("planes"))
-  , hitInput(p.get<art::InputTag>("hitInput"))
-  , spsInput(p.get<art::InputTag>("spsInput"))
   , minHits(p.get<size_t>("minHits"))
   , debug(p.get<bool>("debug"))
-  , filterDecoder(p.get<bool>("filterDecoder"))
-  , semanticDecoder(p.get<bool>("semanticDecoder"))
-  , vertexDecoder(p.get<bool>("vertexDecoder"))
+  , pos_norm(p.get<vector<float>>("pos_norm"))
 {

   for (size_t ip = 0; ip < planes.size(); ++ip) {
@@ -109,14 +108,19 @@ NuGraphInference::NuGraphInference(fhicl::ParameterSet const& p)
     avgs.push_back(p.get<vector<float>>("avgs_" + planes[ip]));
     devs.push_back(p.get<vector<float>>("devs_" + planes[ip]));
   }

-  if (filterDecoder) { produces<vector<FeatureVector<1>>>("filter"); }
-  //
-  if (semanticDecoder) {
-    produces<vector<FeatureVector<5>>>("semantic");
-    produces<MVADescription<5>>("semantic");
+  // Loader Tool
+  _loaderTool = art::make_tool<LoaderToolBase>(p.get<fhicl::ParameterSet>("LoaderTool"));
+  _loaderTool->setDebugAndPlanes(debug, planes);
+
+  // configure and construct Decoder Tools
+  auto const tool_psets = p.get<fhicl::ParameterSet>("DecoderTools");
+  for (auto const& tool_pset_labels : tool_psets.get_pset_names()) {
+    std::cout << "decoder label: " << tool_pset_labels << std::endl;
+    auto const tool_pset = tool_psets.get<fhicl::ParameterSet>(tool_pset_labels);
+    _decoderToolsVec.push_back(art::make_tool<DecoderToolBase>(tool_pset));
+    _decoderToolsVec.back()->setDebugAndPlanes(debug, planes);
+    _decoderToolsVec.back()->declareProducts(producesCollector());
   }
-  //
-  if (vertexDecoder) { produces<vector<recob::Vertex>>("vertex"); }

   cet::search_path sp("FW_SEARCH_PATH");
   model = torch::jit::load(sp.find_file(p.get<std::string>("modelFileName")));
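For context: all three modules now consume the same `NuGraphInput` objects produced by the loader tool, so the libTorch path below only needs to look inputs up by name. The following sketch shows how a loader might fill two of those inputs; it assumes `NuGraphInput` carries the `input_name`/`input_int32_vec`/`input_float_vec` fields used throughout this diff and offers name-plus-vector constructors analogous to `NuGraphOutput` (its actual definition lives in `LoaderToolBase.h`, which is truncated at the end of this section):

```cpp
#include "canvas/Persistency/Common/Ptr.h"
#include "lardataobj/RecoBase/Hit.h"
#include "larrecodnn/NuGraph/Tools/LoaderToolBase.h"

#include <cstdint>
#include <vector>

// Hypothetical helper: fill one integer-valued and one float-valued graph
// input from a hit list. Assumes NuGraphInput(string, vector<T>) ctors.
void fillExampleInputs(std::vector<art::Ptr<recob::Hit>> const& hitlist,
                       std::vector<NuGraphInput>& graphinputs)
{
  std::vector<int32_t> plane;
  std::vector<float> integral;
  for (auto const& h : hitlist) {
    plane.push_back(h->View());        // wire plane of each hit
    integral.push_back(h->Integral()); // deposited charge
  }
  // Names must match what the model (and the modules above) expect.
  graphinputs.emplace_back("hit_table_local_plane", std::move(plane));
  graphinputs.emplace_back("hit_table_integral", std::move(integral));
}
```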
@@ -124,51 +128,63 @@
 void NuGraphInference::produce(art::Event& e)
 {
-  art::Handle<vector<Hit>> hitListHandle;
-  vector<art::Ptr<Hit>> hitlist;
-  if (e.getByLabel(hitInput, hitListHandle)) { art::fill_ptr_vector(hitlist, hitListHandle); }
-
-  std::unique_ptr<vector<FeatureVector<1>>> filtcol(
-    new vector<FeatureVector<1>>(hitlist.size(), FeatureVector<1>(std::array<float, 1>({-1.}))));
-  std::unique_ptr<vector<FeatureVector<5>>> semtcol(new vector<FeatureVector<5>>(
-    hitlist.size(), FeatureVector<5>(std::array<float, 5>({-1., -1., -1., -1., -1.}))));
-  std::unique_ptr<MVADescription<5>> semtdes(
-    new MVADescription<5>(hitListHandle.provenance()->moduleLabel(),
-                          "semantic",
-                          {"MIP", "HIP", "shower", "michel", "diffuse"}));
-
-  std::unique_ptr<vector<recob::Vertex>> vertcol(new vector<recob::Vertex>());
+  //
+  // Load the data and fill the graph inputs
+  //
+  vector<art::Ptr<Hit>> hitlist;
+  vector<vector<size_t>> idsmap;
+  vector<NuGraphInput> graphinputs;
+  _loaderTool->loadData(e, hitlist, graphinputs, idsmap);

   if (debug) std::cout << "Hits size=" << hitlist.size() << std::endl;
   if (hitlist.size() < minHits) {
-    if (filterDecoder) { e.put(std::move(filtcol), "filter"); }
-    if (semanticDecoder) {
-      e.put(std::move(semtcol), "semantic");
-      e.put(std::move(semtdes), "semantic");
+    // Writing the empty outputs to the output root file
+    for (size_t i = 0; i < _decoderToolsVec.size(); i++) {
+      _decoderToolsVec[i]->writeEmptyToEvent(e, idsmap);
     }
-    if (vertexDecoder) { e.put(std::move(vertcol), "vertex"); }
     return;
   }

-  vector<vector<float>> nodeft_bare(planes.size(), vector<float>());
-  vector<vector<float>> nodeft(planes.size(), vector<float>());
-  vector<vector<double>> coords(planes.size(), vector<double>());
-  vector<vector<size_t>> idsmap(planes.size(), vector<size_t>());
+  //
+  // libTorch-specific section: requires extracting inputs, create graph, run inference
+  //
+  const vector<int32_t>* spids = nullptr;
+  const vector<int32_t>* hitids_u = nullptr;
+  const vector<int32_t>* hitids_v = nullptr;
+  const vector<int32_t>* hitids_y = nullptr;
+  const vector<int32_t>* hit_plane = nullptr;
+  const vector<float>* hit_time = nullptr;
+  const vector<int32_t>* hit_wire = nullptr;
+  const vector<float>* hit_integral = nullptr;
+  const vector<float>* hit_rms = nullptr;
+  for (const auto& gi : graphinputs) {
+    if (gi.input_name == "spacepoint_table_spacepoint_id")
+      spids = &gi.input_int32_vec;
+    else if (gi.input_name == "spacepoint_table_hit_id_u")
+      hitids_u = &gi.input_int32_vec;
+    else if (gi.input_name == "spacepoint_table_hit_id_v")
+      hitids_v = &gi.input_int32_vec;
+    else if (gi.input_name == "spacepoint_table_hit_id_y")
+      hitids_y = &gi.input_int32_vec;
+    else if (gi.input_name == "hit_table_local_plane")
+      hit_plane = &gi.input_int32_vec;
+    else if (gi.input_name == "hit_table_local_time")
+      hit_time = &gi.input_float_vec;
+    else if (gi.input_name == "hit_table_local_wire")
+      hit_wire = &gi.input_int32_vec;
+    else if (gi.input_name == "hit_table_integral")
+      hit_integral = &gi.input_float_vec;
+    else if (gi.input_name == "hit_table_rms")
+      hit_rms = &gi.input_float_vec;
+  }
+
+  // Reverse lookup from hit key to index within its plane
   vector<size_t> idsmapRev(hitlist.size(), hitlist.size());
-  for (auto h : hitlist) {
-    idsmap[h->View()].push_back(h.key());
-    idsmapRev[h.key()] = idsmap[h->View()].size() - 1;
-    coords[h->View()].push_back(h->PeakTime() * 0.055);
-    coords[h->View()].push_back(h->WireID().Wire * 0.3);
-    nodeft[h->View()].push_back((h->WireID().Wire * 0.3 - avgs[h->View()][0]) / devs[h->View()][0]);
-    nodeft[h->View()].push_back((h->PeakTime() * 0.055 - avgs[h->View()][1]) / devs[h->View()][1]);
-    nodeft[h->View()].push_back((h->Integral() - avgs[h->View()][2]) / devs[h->View()][2]);
-    nodeft[h->View()].push_back((h->RMS() - avgs[h->View()][3]) / devs[h->View()][3]);
-    nodeft_bare[h->View()].push_back(h->WireID().Wire * 0.3);
-    nodeft_bare[h->View()].push_back(h->PeakTime() * 0.055);
-    nodeft_bare[h->View()].push_back(h->Integral());
-    nodeft_bare[h->View()].push_back(h->RMS());
+  for (const auto& ipv : idsmap) {
+    for (size_t ih = 0; ih < ipv.size(); ih++) {
+      idsmapRev[ipv[ih]] = ih;
+    }
   }

   struct Edge {
@@ -183,12 +199,19 @@
     };
   };

+  // Delaunay graph construction
   auto start_preprocess1 = std::chrono::high_resolution_clock::now();
   vector<vector<Edge>>
edge2d(planes.size(), vector()); for (size_t p = 0; p < planes.size(); p++) { - if (debug) std::cout << "Plane " << p << " has N hits=" << coords[p].size() / 2 << std::endl; - if (coords[p].size() / 2 < 3) { continue; } - delaunator::Delaunator d(coords[p]); + vector coords; + for (size_t i = 0; i < hit_plane->size(); ++i) { + if (size_t(hit_plane->at(i)) != p) continue; + coords.push_back(hit_time->at(i) * pos_norm[1]); + coords.push_back(hit_wire->at(i) * pos_norm[0]); + } + if (debug) std::cout << "Plane " << p << " has N hits=" << coords.size() / 2 << std::endl; + if (coords.size() / 2 < 3) { continue; } + delaunator::Delaunator d(coords); if (debug) std::cout << "Found N triangles=" << d.triangles.size() / 3 << std::endl; for (std::size_t i = 0; i < d.triangles.size(); i += 3) { //create edges in both directions @@ -237,58 +260,60 @@ void NuGraphInference::produce(art::Event& e) auto end_preprocess1 = std::chrono::high_resolution_clock::now(); std::chrono::duration elapsed_preprocess1 = end_preprocess1 - start_preprocess1; - // Get spacepoints from the event record - art::Handle> spListHandle; - vector> splist; - if (e.getByLabel(spsInput, spListHandle)) { art::fill_ptr_vector(splist, spListHandle); } - // Get assocations from spacepoints to hits - vector>> sp2Hit(splist.size()); - if (splist.size() > 0) { - art::FindManyP fmp(spListHandle, e, "sps"); - for (size_t spIdx = 0; spIdx < sp2Hit.size(); ++spIdx) { - sp2Hit[spIdx] = fmp.at(spIdx); - } - } - - //Edges are the same as in pyg, but order is not identical. - //It should not matter but better verify that output is indeed the same. + // Nexus edges auto start_preprocess2 = std::chrono::high_resolution_clock::now(); vector> edge3d(planes.size(), vector()); - for (size_t i = 0; i < splist.size(); ++i) { - for (size_t j = 0; j < sp2Hit[i].size(); ++j) { + for (size_t i = 0; i < spids->size(); ++i) { + if (hitids_u->at(i) >= 0) { + Edge e; + e.n1 = idsmapRev[hitids_u->at(i)]; + e.n2 = spids->at(i); + edge3d[0].push_back(e); + } + if (hitids_v->at(i) >= 0) { + Edge e; + e.n1 = idsmapRev[hitids_v->at(i)]; + e.n2 = spids->at(i); + edge3d[1].push_back(e); + } + if (hitids_y->at(i) >= 0) { Edge e; - e.n1 = idsmapRev[sp2Hit[i][j].key()]; - e.n2 = i; - edge3d[sp2Hit[i][j]->View()].push_back(e); + e.n1 = idsmapRev[hitids_y->at(i)]; + e.n2 = spids->at(i); + edge3d[2].push_back(e); } } + // Prepare inputs auto x = torch::Dict(); auto batch = torch::Dict(); for (size_t p = 0; p < planes.size(); p++) { - long int dim = nodeft[p].size() / 4; + vector nodeft; + for (size_t i = 0; i < hit_plane->size(); ++i) { + if (size_t(hit_plane->at(i)) != p) continue; + nodeft.push_back((hit_wire->at(i) * pos_norm[0] - avgs[hit_plane->at(i)][0]) / + devs[hit_plane->at(i)][0]); + nodeft.push_back((hit_time->at(i) * pos_norm[1] - avgs[hit_plane->at(i)][1]) / + devs[hit_plane->at(i)][1]); + nodeft.push_back((hit_integral->at(i) - avgs[hit_plane->at(i)][2]) / + devs[hit_plane->at(i)][2]); + nodeft.push_back((hit_rms->at(i) - avgs[hit_plane->at(i)][3]) / devs[hit_plane->at(i)][3]); + } + long int dim = nodeft.size() / 4; torch::Tensor ix = torch::zeros({dim, 4}, torch::dtype(torch::kFloat32)); if (debug) { std::cout << "plane=" << p << std::endl; - std::cout << std::fixed; - std::cout << std::setprecision(4); - std::cout << "before, plane=" << planes[p] << std::endl; - for (size_t n = 0; n < nodeft_bare[p].size(); n = n + 4) { - std::cout << nodeft_bare[p][n] << " " << nodeft_bare[p][n + 1] << " " - << nodeft_bare[p][n + 2] << " " << nodeft_bare[p][n + 3] << " " << 
std::endl; - } std::cout << std::scientific; - std::cout << "after, plane=" << planes[p] << std::endl; - for (size_t n = 0; n < nodeft[p].size(); n = n + 4) { - std::cout << nodeft[p][n] << " " << nodeft[p][n + 1] << " " << nodeft[p][n + 2] << " " - << nodeft[p][n + 3] << " " << std::endl; + for (size_t n = 0; n < nodeft.size(); n = n + 4) { + std::cout << nodeft[n] << " " << nodeft[n + 1] << " " << nodeft[n + 2] << " " + << nodeft[n + 3] << " " << std::endl; } } - for (size_t n = 0; n < nodeft[p].size(); n = n + 4) { - ix[n / 4][0] = nodeft[p][n]; - ix[n / 4][1] = nodeft[p][n + 1]; - ix[n / 4][2] = nodeft[p][n + 2]; - ix[n / 4][3] = nodeft[p][n + 3]; + for (size_t n = 0; n < nodeft.size(); n = n + 4) { + ix[n / 4][0] = nodeft[n]; + ix[n / 4][1] = nodeft[n + 1]; + ix[n / 4][2] = nodeft[n + 2]; + ix[n / 4][3] = nodeft[n + 3]; } x.insert(planes[p], ix); torch::Tensor ib = torch::zeros({dim}, torch::dtype(torch::kInt64)); @@ -341,7 +366,7 @@ void NuGraphInference::produce(art::Event& e) } } - long int spdim = splist.size(); + long int spdim = spids->size(); auto nexus = torch::empty({spdim, 0}, torch::dtype(torch::kFloat32)); std::vector inputs; @@ -350,6 +375,8 @@ void NuGraphInference::produce(art::Event& e) inputs.push_back(edge_index_nexus); inputs.push_back(nexus); inputs.push_back(batch); + + // Run inference auto end_preprocess2 = std::chrono::high_resolution_clock::now(); std::chrono::duration elapsed_preprocess2 = end_preprocess2 - start_preprocess2; if (debug) std::cout << "FORWARD!" << std::endl; @@ -363,67 +390,31 @@ void NuGraphInference::produce(art::Event& e) << " seconds" << std::endl; std::cout << "output =" << outputs << std::endl; } - if (semanticDecoder) { - for (size_t p = 0; p < planes.size(); p++) { - torch::Tensor s = outputs.at("x_semantic").toGenericDict().at(planes[p]).toTensor(); - for (int i = 0; i < s.sizes()[0]; ++i) { - size_t idx = idsmap[p][i]; - std::array input({s[i][0].item(), - s[i][1].item(), - s[i][2].item(), - s[i][3].item(), - s[i][4].item()}); - softmax(input); - FeatureVector<5> semt = FeatureVector<5>(input); - (*semtcol)[idx] = semt; - } - if (debug) { - for (int j = 0; j < 5; j++) { - std::cout << "x_semantic category=" << j << " : "; - for (size_t p = 0; p < planes.size(); p++) { - torch::Tensor s = outputs.at("x_semantic").toGenericDict().at(planes[p]).toTensor(); - for (int i = 0; i < s.sizes()[0]; ++i) - std::cout << s[i][j].item() << ", "; - } - std::cout << std::endl; - } - } - } - } - if (filterDecoder) { - for (size_t p = 0; p < planes.size(); p++) { - torch::Tensor f = outputs.at("x_filter").toGenericDict().at(planes[p]).toTensor(); - for (int i = 0; i < f.numel(); ++i) { - size_t idx = idsmap[p][i]; - std::array input({f[i].item()}); - (*filtcol)[idx] = FeatureVector<1>(input); - } + + // + // Get pointers to the result returned and write to the event + // + vector infer_output; + for (const auto& elem1 : outputs) { + if (elem1.value().isTensor()) { + torch::Tensor tensor = elem1.value().toTensor(); + std::vector vec(tensor.data_ptr(), tensor.data_ptr() + tensor.numel()); + infer_output.push_back(NuGraphOutput(elem1.key().to(), vec)); } - if (debug) { - std::cout << "x_filter : "; - for (size_t p = 0; p < planes.size(); p++) { - torch::Tensor f = outputs.at("x_filter").toGenericDict().at(planes[p]).toTensor(); - for (int i = 0; i < f.numel(); ++i) - std::cout << f[i].item() << ", "; + else if (elem1.value().isGenericDict()) { + for (const auto& elem2 : elem1.value().toGenericDict()) { + torch::Tensor tensor = elem2.value().toTensor(); + 
std::vector vec(tensor.data_ptr(), tensor.data_ptr() + tensor.numel()); + infer_output.push_back( + NuGraphOutput(elem1.key().to() + "_" + elem2.key().to(), vec)); } - std::cout << std::endl; } } - if (vertexDecoder) { - torch::Tensor v = outputs.at("x_vertex").toGenericDict().at(0).toTensor(); - double vpos[3]; - vpos[0] = v[0].item(); - vpos[1] = v[1].item(); - vpos[2] = v[2].item(); - vertcol->push_back(recob::Vertex(vpos)); - } - if (filterDecoder) { e.put(std::move(filtcol), "filter"); } - if (semanticDecoder) { - e.put(std::move(semtcol), "semantic"); - e.put(std::move(semtdes), "semantic"); + // Write the outputs to the output root file + for (size_t i = 0; i < _decoderToolsVec.size(); i++) { + _decoderToolsVec[i]->writeToEvent(e, idsmap, infer_output); } - if (vertexDecoder) { e.put(std::move(vertcol), "vertex"); } } DEFINE_ART_MODULE(NuGraphInference) diff --git a/larrecodnn/NuGraph/Tools/CMakeLists.txt b/larrecodnn/NuGraph/Tools/CMakeLists.txt new file mode 100644 index 00000000..d484b733 --- /dev/null +++ b/larrecodnn/NuGraph/Tools/CMakeLists.txt @@ -0,0 +1,16 @@ +cet_enable_asserts() + +set( nugraph_tool_lib_list + lardataobj::RecoBase + lardataobj::AnalysisBase + larrecodnn_ImagePatternAlgs_NuSonic_Triton + TorchScatter::TorchScatter + art::Framework_Core + hep_concurrency::hep_concurrency +) +art_make(TOOL_LIBRARIES ${nugraph_tool_lib_list} ) + +install_headers() +install_fhicl() +install_source() + diff --git a/larrecodnn/NuGraph/Tools/DecoderToolBase.h b/larrecodnn/NuGraph/Tools/DecoderToolBase.h new file mode 100644 index 00000000..f4c3d003 --- /dev/null +++ b/larrecodnn/NuGraph/Tools/DecoderToolBase.h @@ -0,0 +1,111 @@ +#ifndef DECODERTOOLBASE_H +#define DECODERTOOLBASE_H + +// art TOOL +#include "art/Utilities/ToolMacros.h" +#include "art/Utilities/make_tool.h" + +#include "art/Framework/Core/EDProducer.h" +#include "art/Framework/Principal/Event.h" +#include "fhiclcpp/ParameterSet.h" + +#include +#include +#include + +using std::string; +using std::vector; + +struct NuGraphOutput { + NuGraphOutput(string s, vector vf) : output_name(s), output_vec(std::move(vf)) {} + string output_name; + vector output_vec; +}; + +class DecoderToolBase { + +public: + /** + * @brief Virtual Destructor + */ + virtual ~DecoderToolBase() noexcept = default; + + /** + * @brief Construcutor + * + * @param ParameterSet The input set of parameters for configuration + */ + DecoderToolBase(fhicl::ParameterSet const& p) + : instancename{p.get("instanceName", "filter")} + , outputname{p.get("outputName", "x_filter_")} + {} + + /** + * @brief declareProducts function + * + * @param art::ProducesCollector + */ + virtual void declareProducts(art::ProducesCollector&) = 0; + + /** + * @brief writeEmptyToEvent function + * + * @param art::Event event record + */ + virtual void writeEmptyToEvent(art::Event& e, const vector>& idsmap) = 0; + + /** + * @brief writeToEvent function + * + * @param art::Event event record + * @param idsmap + * @param infer_output + */ + virtual void writeToEvent(art::Event& e, + const vector>& idsmap, + const vector& infer_output) = 0; + + // Function to print elements of a vector + void printVector(const std::vector& vec) + { + for (size_t i = 0; i < vec.size(); ++i) { + std::cout << vec[i]; + // Print space unless it's the last element + if (i != vec.size() - 1) { std::cout << " "; } + } + std::cout << std::endl; + std::cout << std::endl; + } + + template + void softmax(std::array& arr) + { + T m = -std::numeric_limits::max(); + for (size_t i = 0; i < arr.size(); i++) { + if 
(arr[i] > m) { m = arr[i]; } + } + T sum = 0.0; + for (size_t i = 0; i < arr.size(); i++) { + sum += expf(arr[i] - m); + } + T offset = m + logf(sum); + for (size_t i = 0; i < arr.size(); i++) { + arr[i] = expf(arr[i] - offset); + } + return; + } + + void setDebugAndPlanes(bool d, vector& p) + { + debug = d; + planes = p; + } + +protected: + bool debug; + vector planes; + std::string instancename; + std::string outputname; +}; + +#endif diff --git a/larrecodnn/NuGraph/Tools/FilterDecoder_tool.cc b/larrecodnn/NuGraph/Tools/FilterDecoder_tool.cc new file mode 100644 index 00000000..bb97004c --- /dev/null +++ b/larrecodnn/NuGraph/Tools/FilterDecoder_tool.cc @@ -0,0 +1,101 @@ +#include "DecoderToolBase.h" + +#include "lardataobj/AnalysisBase/MVAOutput.h" +#include + +using anab::FeatureVector; +using anab::MVADescription; + +class FilterDecoder : public DecoderToolBase { + +public: + /** + * @brief Constructor + * + * @param pset + */ + FilterDecoder(const fhicl::ParameterSet& pset); + + /** + * @brief Virtual Destructor + */ + virtual ~FilterDecoder() noexcept = default; + + /** + * @brief declareProducts function + * + * @param art::ProducesCollector + */ + void declareProducts(art::ProducesCollector& collector) override + { + collector.produces>>(instancename); + } + + /** + * @brief writeEmptyToEvent function + * + * @param art::Event event record + */ + void writeEmptyToEvent(art::Event& e, const vector>& idsmap) override; + + /** + * @brief Decoder function + * + * @param art::Event event record for decoder + */ + void writeToEvent(art::Event& e, + const vector>& idsmap, + const vector& infer_output) override; +}; + +FilterDecoder::FilterDecoder(const fhicl::ParameterSet& p) : DecoderToolBase{p} {} + +void FilterDecoder::writeEmptyToEvent(art::Event& e, const vector>& idsmap) +{ + // + size_t size = 0; + for (auto& v : idsmap) + size += v.size(); + auto filtcol = + std::make_unique>>(size, FeatureVector<1>(std::array({-1.}))); + e.put(std::move(filtcol), instancename); + // +} + +void FilterDecoder::writeToEvent(art::Event& e, + const vector>& idsmap, + const vector& infer_output) +{ + // + size_t size = 0; + for (auto& v : idsmap) + size += v.size(); + auto filtcol = + std::make_unique>>(size, FeatureVector<1>(std::array({-1.}))); + // + for (size_t p = 0; p < planes.size(); p++) { + // + const std::vector* x_filter_data = 0; + for (auto& io : infer_output) { + if (io.output_name == outputname + planes[p]) x_filter_data = &io.output_vec; + } + if (debug) { + std::cout << outputname + planes[p] << std::endl; + printVector(*x_filter_data); + } + // + torch::TensorOptions options = torch::TensorOptions().dtype(torch::kFloat32); + int64_t num_elements = x_filter_data->size(); + const torch::Tensor f = + torch::from_blob(const_cast(x_filter_data->data()), {num_elements}, options); + // + for (int i = 0; i < f.numel(); ++i) { + size_t idx = idsmap[p][i]; + std::array input({f[i].item()}); + (*filtcol)[idx] = FeatureVector<1>(input); + } + } + e.put(std::move(filtcol), instancename); +} + +DEFINE_ART_CLASS_TOOL(FilterDecoder) diff --git a/larrecodnn/NuGraph/Tools/LoaderToolBase.h b/larrecodnn/NuGraph/Tools/LoaderToolBase.h new file mode 100644 index 00000000..5c2731cb --- /dev/null +++ b/larrecodnn/NuGraph/Tools/LoaderToolBase.h @@ -0,0 +1,63 @@ +#ifndef LOADERTOOLBASE_H +#define LOADERTOOLBASE_H + +// art TOOL +#include "art/Utilities/ToolMacros.h" +#include "art/Utilities/make_tool.h" + +#include "art/Framework/Core/EDProducer.h" +#include "art/Framework/Principal/Event.h" +#include 
"fhiclcpp/ParameterSet.h" + +#include "lardataobj/RecoBase/Hit.h" + +#include +#include +#include + +using std::string; +using std::vector; + +struct NuGraphInput { + NuGraphInput(string s, vector vi) + : input_name(s), isInt(true), input_int32_vec(std::move(vi)) + {} + NuGraphInput(string s, vector vf) + : input_name(s), isInt(false), input_float_vec(std::move(vf)) + {} + string input_name; + bool isInt; + vector input_int32_vec; + vector input_float_vec; +}; + +class LoaderToolBase { + +public: + /** + * @brief Virtual Destructor + */ + virtual ~LoaderToolBase() noexcept = default; + + /** + * @brief loadData virtual function + * + * @param art::Event event record, list of input, idsmap + */ + virtual void loadData(art::Event& e, + vector>& hitlist, + vector& inputs, + vector>& idsmap) = 0; + + void setDebugAndPlanes(bool d, vector& p) + { + debug = d; + planes = p; + } + +protected: + bool debug; + vector planes; +}; + +#endif diff --git a/larrecodnn/NuGraph/Tools/SemanticDecoder_tool.cc b/larrecodnn/NuGraph/Tools/SemanticDecoder_tool.cc new file mode 100644 index 00000000..b4fb6e81 --- /dev/null +++ b/larrecodnn/NuGraph/Tools/SemanticDecoder_tool.cc @@ -0,0 +1,126 @@ +#include "DecoderToolBase.h" + +#include "lardataobj/AnalysisBase/MVAOutput.h" +#include + +using anab::FeatureVector; +using anab::MVADescription; + +// fixme: this only works for 5 categories and should be extended to different sizes. This may require making the class templated. +class SemanticDecoder : public DecoderToolBase { + +public: + /** + * @brief Constructor + * + * @param pset + */ + SemanticDecoder(const fhicl::ParameterSet& pset); + + /** + * @brief Virtual Destructor + */ + virtual ~SemanticDecoder() noexcept = default; + + /** + * @brief declareProducts function + * + * @param art::ProducesCollector + */ + void declareProducts(art::ProducesCollector& collector) override + { + collector.produces>>(instancename); + collector.produces>(instancename); + } + + /** + * @brief writeEmptyToEvent function + * + * @param art::Event event record + */ + void writeEmptyToEvent(art::Event& e, const vector>& idsmap) override; + + /** + * @brief Decoder function + * + * @param art::Event event record for decoder + */ + void writeToEvent(art::Event& e, + const vector>& idsmap, + const vector& infer_output) override; + +private: + std::vector categories; + art::InputTag hitInput; +}; + +SemanticDecoder::SemanticDecoder(const fhicl::ParameterSet& p) + : DecoderToolBase(p) + , categories{p.get>("categories")} + , hitInput{p.get("hitInput")} +{} + +void SemanticDecoder::writeEmptyToEvent(art::Event& e, const vector>& idsmap) +{ + // + auto semtdes = std::make_unique>(hitInput.label(), instancename, categories); + e.put(std::move(semtdes), instancename); + // + size_t size = 0; + for (auto& v : idsmap) + size += v.size(); + std::array arr; + std::fill(arr.begin(), arr.end(), -1.); + auto semtcol = std::make_unique>>(size, FeatureVector<5>(arr)); + e.put(std::move(semtcol), instancename); + // +} + +void SemanticDecoder::writeToEvent(art::Event& e, + const vector>& idsmap, + const vector& infer_output) +{ + // + auto semtdes = std::make_unique>(hitInput.label(), instancename, categories); + e.put(std::move(semtdes), instancename); + // + size_t size = 0; + for (auto& v : idsmap) + size += v.size(); + std::array arr; + std::fill(arr.begin(), arr.end(), -1.); + auto semtcol = std::make_unique>>(size, FeatureVector<5>(arr)); + + size_t n_cols = categories.size(); + for (size_t p = 0; p < planes.size(); p++) { + // + const 
std::vector* x_semantic_data = 0; + for (auto& io : infer_output) { + if (io.output_name == outputname + planes[p]) x_semantic_data = &io.output_vec; + } + if (debug) { + std::cout << outputname + planes[p] << std::endl; + printVector(*x_semantic_data); + } + + torch::TensorOptions options = torch::TensorOptions().dtype(torch::kFloat32); + size_t n_rows = x_semantic_data->size() / n_cols; + const torch::Tensor s = + torch::from_blob(const_cast(x_semantic_data->data()), + {static_cast(n_rows), static_cast(n_cols)}, + options); + + for (int i = 0; i < s.sizes()[0]; ++i) { + size_t idx = idsmap[p][i]; + std::array input; + for (size_t j = 0; j < n_cols; ++j) + input[j] = s[i][j].item(); + softmax(input); + FeatureVector<5> semt = FeatureVector<5>(input); + (*semtcol)[idx] = semt; + } + } + e.put(std::move(semtcol), instancename); +} + +DEFINE_ART_CLASS_TOOL(SemanticDecoder) diff --git a/larrecodnn/NuGraph/Tools/StandardLoader_tool.cc b/larrecodnn/NuGraph/Tools/StandardLoader_tool.cc new file mode 100644 index 00000000..408f2b6a --- /dev/null +++ b/larrecodnn/NuGraph/Tools/StandardLoader_tool.cc @@ -0,0 +1,116 @@ +#include "LoaderToolBase.h" + +#include "canvas/Persistency/Common/FindManyP.h" +#include "canvas/Persistency/Common/Ptr.h" +#include "canvas/Utilities/InputTag.h" +#include "lardataobj/RecoBase/SpacePoint.h" +#include + +class StandardLoader : public LoaderToolBase { + +public: + /** + * @brief Constructor + * + * @param pset + */ + StandardLoader(const fhicl::ParameterSet& pset); + + /** + * @brief Virtual Destructor + */ + virtual ~StandardLoader() noexcept = default; + + /** + * @brief loadData function + * + * @param art::Event event record, list of input, idsmap + */ + void loadData(art::Event& e, + vector>& hitlist, + vector& inputs, + vector>& idsmap) override; + +private: + art::InputTag hitInput; + art::InputTag spsInput; +}; + +StandardLoader::StandardLoader(const fhicl::ParameterSet& p) + : hitInput{p.get("hitInput")}, spsInput{p.get("spsInput")} +{} + +void StandardLoader::loadData(art::Event& e, + vector>& hitlist, + vector& inputs, + vector>& idsmap) +{ + // + art::Handle> hitListHandle; + if (e.getByLabel(hitInput, hitListHandle)) { art::fill_ptr_vector(hitlist, hitListHandle); } + // + idsmap = std::vector>(planes.size(), std::vector()); + for (auto h : hitlist) { + idsmap[h->View()].push_back(h.key()); + } + + vector hit_table_hit_id_data; + vector hit_table_local_plane_data; + vector hit_table_local_time_data; + vector hit_table_local_wire_data; + vector hit_table_integral_data; + vector hit_table_rms_data; + vector spacepoint_table_spacepoint_id_data; + vector spacepoint_table_hit_id_u_data; + vector spacepoint_table_hit_id_v_data; + vector spacepoint_table_hit_id_y_data; + + // hit table + for (auto h : hitlist) { + hit_table_hit_id_data.push_back(h.key()); + hit_table_local_plane_data.push_back(h->View()); + hit_table_local_time_data.push_back(h->PeakTime()); + hit_table_local_wire_data.push_back(h->WireID().Wire); + hit_table_integral_data.push_back(h->Integral()); + hit_table_rms_data.push_back(h->RMS()); + } + + // Get spacepoints from the event record + art::Handle> spListHandle; + vector> splist; + if (e.getByLabel(spsInput, spListHandle)) { art::fill_ptr_vector(splist, spListHandle); } + // Get assocations from spacepoints to hits + vector>> sp2Hit(splist.size()); + if (splist.size() > 0) { + art::FindManyP fmp(spListHandle, e, spsInput); + for (size_t spIdx = 0; spIdx < sp2Hit.size(); ++spIdx) { + sp2Hit[spIdx] = fmp.at(spIdx); + } + } + + // space point 
table + for (size_t i = 0; i < splist.size(); ++i) { + spacepoint_table_spacepoint_id_data.push_back(i); + spacepoint_table_hit_id_u_data.push_back(-1); + spacepoint_table_hit_id_v_data.push_back(-1); + spacepoint_table_hit_id_y_data.push_back(-1); + for (size_t j = 0; j < sp2Hit[i].size(); ++j) { + if (sp2Hit[i][j]->View() == 0) spacepoint_table_hit_id_u_data.back() = sp2Hit[i][j].key(); + if (sp2Hit[i][j]->View() == 1) spacepoint_table_hit_id_v_data.back() = sp2Hit[i][j].key(); + if (sp2Hit[i][j]->View() == 2) spacepoint_table_hit_id_y_data.back() = sp2Hit[i][j].key(); + } + } + + inputs.emplace_back("hit_table_hit_id", hit_table_hit_id_data); + inputs.emplace_back("hit_table_local_plane", hit_table_local_plane_data); + inputs.emplace_back("hit_table_local_time", hit_table_local_time_data); + inputs.emplace_back("hit_table_local_wire", hit_table_local_wire_data); + inputs.emplace_back("hit_table_integral", hit_table_integral_data); + inputs.emplace_back("hit_table_rms", hit_table_rms_data); + + inputs.emplace_back("spacepoint_table_spacepoint_id", spacepoint_table_spacepoint_id_data); + inputs.emplace_back("spacepoint_table_hit_id_u", spacepoint_table_hit_id_u_data); + inputs.emplace_back("spacepoint_table_hit_id_v", spacepoint_table_hit_id_v_data); + inputs.emplace_back("spacepoint_table_hit_id_y", spacepoint_table_hit_id_y_data); +} +DEFINE_ART_CLASS_TOOL(StandardLoader) diff --git a/larrecodnn/NuGraph/Tools/VertexDecoder_tool.cc b/larrecodnn/NuGraph/Tools/VertexDecoder_tool.cc new file mode 100644 index 00000000..de42ca54 --- /dev/null +++ b/larrecodnn/NuGraph/Tools/VertexDecoder_tool.cc @@ -0,0 +1,88 @@ +#include "DecoderToolBase.h" + +#include "lardataobj/RecoBase/Vertex.h" +#include + +class VertexDecoder : public DecoderToolBase { + +public: + /** + * @brief Constructor + * + * @param pset + */ + VertexDecoder(const fhicl::ParameterSet& pset); + + /** + * @brief Virtual Destructor + */ + virtual ~VertexDecoder() noexcept = default; + + /** + * @brief declareProducts function + * + * @param art::ProducesCollector + */ + void declareProducts(art::ProducesCollector& collector) override + { + collector.produces>(instancename); + } + + /** + * @brief writeEmptyToEvent function + * + * @param art::Event event record + */ + void writeEmptyToEvent(art::Event& e, const vector>& idsmap) override; + + /** + * @brief Decoder function + * + * @param art::Event event record for decoder + */ + void writeToEvent(art::Event& e, + const vector>& idsmap, + const vector& infer_output) override; + +private: + string outputDictElem; +}; + +VertexDecoder::VertexDecoder(const fhicl::ParameterSet& p) + : DecoderToolBase(p), outputDictElem{p.get("outputDictElem")} +{} + +void VertexDecoder::writeEmptyToEvent(art::Event& e, const vector>& idsmap) +{ + // + auto vertcol = std::make_unique>(); + e.put(std::move(vertcol), instancename); + // +} + +void VertexDecoder::writeToEvent(art::Event& e, + const vector>& idsmap, + const vector& infer_output) +{ + // + auto vertcol = std::make_unique>(); + + const std::vector* x_vertex_data = nullptr; + for (auto& io : infer_output) { + if (io.output_name == outputDictElem) x_vertex_data = &io.output_vec; + } + if (x_vertex_data->size() == 3) { + double vpos[3] = {(*x_vertex_data)[0], (*x_vertex_data)[1], (*x_vertex_data)[2]}; + vertcol->push_back(recob::Vertex(vpos)); + if (debug) + std::cout << "NuGraph vertex pos=" << vpos[0] << ", " << vpos[1] << ", " << vpos[2] + << std::endl; + } + else { + std::cout << "ERROR -- Wrong size returned by NuGraph vertex decoder" << 
std::endl; + } + e.put(std::move(vertcol), instancename); + // +} + +DEFINE_ART_CLASS_TOOL(VertexDecoder) diff --git a/larrecodnn/NuGraph/Tools/nugraph_decoders.fcl b/larrecodnn/NuGraph/Tools/nugraph_decoders.fcl new file mode 100644 index 00000000..46e88fce --- /dev/null +++ b/larrecodnn/NuGraph/Tools/nugraph_decoders.fcl @@ -0,0 +1,23 @@ +BEGIN_PROLOG + +FilterDecoderTool: { + instanceName: "filter" + outputName: "x_filter_" + tool_type: "FilterDecoder" +} + +SemanticDecoderTool: { + instanceName: "semantic" + outputName: "x_semantic_" + categories: ["MIP", "HIP", "shower", "michel", "diffuse"] + hitInput: @nil + tool_type: "SemanticDecoder" +} + +VertexDecoderTool: { + instanceName: "vertex" + outputDictElem: "x_vertex" + tool_type: "VertexDecoder" +} + +END_PROLOG diff --git a/larrecodnn/NuGraph/Tools/nugraph_loaders.fcl b/larrecodnn/NuGraph/Tools/nugraph_loaders.fcl new file mode 100644 index 00000000..d27f3d5d --- /dev/null +++ b/larrecodnn/NuGraph/Tools/nugraph_loaders.fcl @@ -0,0 +1,9 @@ +BEGIN_PROLOG + +StandardLoader: { + hitInput: "nuslhits" + spsInput: "sps" + tool_type: "StandardLoader" +} + +END_PROLOG diff --git a/larrecodnn/NuGraph/nugraph.fcl b/larrecodnn/NuGraph/nugraph.fcl index 7a59f97e..a5bc8ac8 100644 --- a/larrecodnn/NuGraph/nugraph.fcl +++ b/larrecodnn/NuGraph/nugraph.fcl @@ -1,74 +1,72 @@ +#include "nugraph_loaders.fcl" +#include "nugraph_decoders.fcl" + BEGIN_PROLOG NuGraphCommon: { - planes: ["u","v","y"] - hitInput: "nuslhits" - spsInput: "sps" minHits: 10 debug: false - filterDecoder: true - semanticDecoder: true - vertexDecoder: false + planes: ["u","v","y"] + LoaderTool: @local::StandardLoader + DecoderTools: { + FilterDecoderTool: @local::FilterDecoderTool + SemanticDecoderTool: @local::SemanticDecoderTool + } } +NuGraphCommon.DecoderTools.SemanticDecoderTool.hitInput: @local::NuGraphCommon.LoaderTool.hitInput -NuGraph: { - @table::NuGraphCommon +TritonConfig: { + serverURL: @nil #"test-1-eaf.fnal.gov:443" + verbose: false + ssl: true + sslRootCertificates: "" + sslPrivateKey: "" + sslCertificateChain: "" + modelName: "nugraph2" + modelVersion: "" + timeout: 0 + allowedTries: 1 + outputs: [] +} + +NuGraphLibTorch: { + @table::NuGraphCommon avgs_u: [389.00632, 173.42017, 144.42065, 4.5582113] avgs_v: [3.6914261e+02, 1.7347592e+02, 8.5748262e+08, 4.4525051e+00] avgs_y: [547.38995, 173.13017, 109.57691, 4.1024675] devs_u: [148.02893, 78.83508, 223.89404, 2.2621224] devs_v: [1.4524565e+02, 8.1395981e+01, 1.0625440e+13, 1.9223815e+00] devs_y: [284.20657, 74.47823, 108.93791, 1.4318414] + pos_norm: [0.3, 0.055] modelFileName: "model.pt" module_type: "NuGraphInference" } +NuGraph: @local::NuGraphLibTorch NuGraphTriton: { - @table::NuGraphCommon - url: @nil # "test-1-eaf.fnal.gov:443" - ssl: true - modelName: "nugraph2" + @table::NuGraphCommon + TritonConfig: @local::TritonConfig module_type: "NuGraphInferenceTriton" } -CPUNuGraphTriton: { - @table::NuGraphCommon - url: @nil # "test-1-eaf.fnal.gov:443" - ssl: true - modelName: "nugraph2_cpu" - module_type: "NuGraphInferenceTriton" -} +CPUNuGraphTriton: @local::NuGraphTriton +CPUNuGraphTriton.TritonConfig.modelName: "nugraph2_cpu" -ApptainerNuGraphTriton: { - @table::NuGraphCommon - url: @nil # "localhost:8001" - ssl: false - modelName: "nugraph2" - module_type: "NuGraphInferenceTriton" -} +ApptainerNuGraphTriton: @local::NuGraphTriton +ApptainerNuGraphTriton.TritonConfig.serverURL: "localhost:8001" +ApptainerNuGraphTriton.TritonConfig.ssl: false NuGraphNuSonicTriton: { - @table::NuGraphCommon - url: @nil # 
"test-1-eaf.fnal.gov:443" - ssl: true - modelName: "nugraph2" + @table::NuGraphCommon + TritonConfig: @local::TritonConfig module_type: "NuGraphInferenceSonicTriton" } -CPUNuGraphNuSonicTriton: { - @table::NuGraphCommon - url: @nil # "test-1-eaf.fnal.gov:443" - ssl: true - modelName: "nugraph2_cpu" - module_type: "NuGraphInferenceSonicTriton" -} +CPUNuGraphNuSonicTriton: @local::NuGraphNuSonicTriton +CPUNuGraphNuSonicTriton.TritonConfig.modelName: "nugraph2_cpu" -ApptainerNuGraphNuSonicTriton: { - @table::NuGraphCommon - url: @nil # "localhost:8001" - ssl: false - modelName: "nugraph2" - module_type: "NuGraphInferenceSonicTriton" -} +ApptainerNuGraphNuSonicTriton: @local::NuGraphNuSonicTriton +ApptainerNuGraphNuSonicTriton.TritonConfig.serverURL: "localhost:8001" +ApptainerNuGraphNuSonicTriton.TritonConfig.ssl: false END_PROLOG