Skip to content

Commit

Permalink
[onnxruntime] register providers
Browse files Browse the repository at this point in the history
  • Loading branch information
jcelerier committed Aug 9, 2024
1 parent 9aaa54b commit a12a99c
Showing 1 changed file with 112 additions and 21 deletions.
133 changes: 112 additions & 21 deletions Onnx/helpers/OnnxContext.hpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
#pragma once
#include <ossia/detail/algorithms.hpp>

#include <QFile>
#include <QImage>

Expand All @@ -12,41 +14,130 @@
#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

namespace Onnx
{
/**
 * Settings that select how an ONNX Runtime session is configured.
 *
 * Both members carry a default so that a default-constructed Options is
 * fully initialized: no provider requested, device 0.
 */
struct Options
{
  /// Name of the requested execution provider, as matched by
  /// create_session_options ("CUDA", "TensorRT", "ROCM", "OpenVINO",
  /// "DirectML", "CoreML"). Empty means no provider is requested.
  std::string provider{};

  /// Index of the device intended for the selected provider.
  int device_id{0};
};

/**
 * Builds the Ort::SessionOptions matching the execution provider requested
 * in @p opts.
 *
 * A provider is appended only when opts.provider names it AND the matching
 * name is found in Ort::GetAvailableProviders(); when no branch matches, the
 * returned options have no execution provider appended.
 *
 * @param opts  Requested provider name and device index. opts.device_id is
 *              only read inside the CUDA, TensorRT and ROCM branches.
 * @return The configured session options.
 * @throws Propagates whatever Ort::ThrowOnError raises when an ONNX Runtime
 *         call reports a failure.
 *
 * NOTE(review): ONNX Runtime appears to report provider names such as
 * "CUDAExecutionProvider" rather than "CUDA" - confirm that the short names
 * tested below can actually match, otherwise no provider is ever registered.
 */
static Ort::SessionOptions create_session_options(const Options& opts)
{
  Ort::SessionOptions session_options;

  const OrtApi& api = Ort::GetApi();
  const auto& available = Ort::GetAvailableProviders();

  // The key/value provider APIs take every setting as a C string.
  const std::string device_id = std::to_string(opts.device_id);

  if (opts.provider == "CUDA" && ossia::contains(available, "CUDA"))
  {
    OrtCUDAProviderOptionsV2* raw_options = nullptr;
    Ort::ThrowOnError(api.CreateCUDAProviderOptions(&raw_options));
    // Owns the provider options so they are released on every exit path,
    // including when one of the calls below throws.
    const std::unique_ptr<
        OrtCUDAProviderOptionsV2, decltype(api.ReleaseCUDAProviderOptions)>
        cuda_options{raw_options, api.ReleaseCUDAProviderOptions};

    const std::vector<const char*> keys{
        "device_id",
        "gpu_mem_limit",
        "arena_extend_strategy",
        "cudnn_conv_algo_search",
        "do_copy_in_default_stream",
        "cudnn_conv_use_max_workspace",
        "cudnn_conv1d_pad_to_nc1d",
        "enable_cuda_graph",
        "enable_skip_layer_norm_strict_mode"};
    const std::vector<const char*> values{
        device_id.c_str(),
        "2147483648",
        "kNextPowerOfTwo",
        "EXHAUSTIVE",
        "1",
        "1",
        "1",
        "0",
        "0"};
    Ort::ThrowOnError(api.UpdateCUDAProviderOptions(
        cuda_options.get(), keys.data(), values.data(), keys.size()));
    session_options.AppendExecutionProvider_CUDA_V2(*cuda_options);
  }

  if (opts.provider == "TensorRT" && ossia::contains(available, "TensorRT"))
  {
    OrtTensorRTProviderOptionsV2* raw_options = nullptr;
    Ort::ThrowOnError(api.CreateTensorRTProviderOptions(&raw_options));
    const std::unique_ptr<
        OrtTensorRTProviderOptionsV2,
        decltype(api.ReleaseTensorRTProviderOptions)>
        trt_options{raw_options, api.ReleaseTensorRTProviderOptions};

    const std::vector<const char*> keys{
        "device_id",
        "trt_engine_cache_enable",
        "trt_timing_cache_enable",
    };
    const std::vector<const char*> values{device_id.c_str(), "1", "1"};
    Ort::ThrowOnError(api.UpdateTensorRTProviderOptions(
        trt_options.get(), keys.data(), values.data(), keys.size()));
    session_options.AppendExecutionProvider_TensorRT_V2(*trt_options);
  }

  if (opts.provider == "ROCM" && ossia::contains(available, "ROCM"))
  {
    OrtROCMProviderOptions* raw_options = nullptr;
    Ort::ThrowOnError(api.CreateROCMProviderOptions(&raw_options));
    const std::unique_ptr<
        OrtROCMProviderOptions, decltype(api.ReleaseROCMProviderOptions)>
        rocm_options{raw_options, api.ReleaseROCMProviderOptions};

    rocm_options->device_id = opts.device_id;
    session_options.AppendExecutionProvider_ROCM(*rocm_options);
  }

  if (opts.provider == "OpenVINO" && ossia::contains(available, "OpenVINO"))
  {
    std::unordered_map<std::string, std::string> options;
    options["device_type"] = "GPU";
    options["precision"] = "FP32";
    session_options.AppendExecutionProvider("OpenVINO", options);

    // https://onnxruntime.ai/docs/execution-providers/OpenVINO-ExecutionProvider.html#onnxruntime-graph-level-optimization
    session_options.SetGraphOptimizationLevel(ORT_DISABLE_ALL);
  }

#if _WIN32
  // Requested as "DirectML" but looked up and appended as "DML".
  if (opts.provider == "DirectML" && ossia::contains(available, "DML"))
  {
    const std::unordered_map<std::string, std::string> options;
    session_options.AppendExecutionProvider("DML", options);
  }
#endif

#if __APPLE__
  if (opts.provider == "CoreML" && ossia::contains(available, "CoreML"))
  {
    const uint32_t coreml_flags = 0;
    Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_CoreML(
        session_options, coreml_flags));
  }
#endif

  return session_options;
}
struct OnnxRunContext
{
Options opts;
Ort::Env env;

Ort::SessionOptions session_options;
Ort::Session session;

// print name/shape of inputs
Ort::AllocatorWithDefaultOptions allocator;

explicit OnnxRunContext(std::string_view name)
: env(ORT_LOGGING_LEVEL_WARNING, "example")
, session_options(create_session_options(opts))
, session(env, name.data(), session_options)
{
}
Expand Down

0 comments on commit a12a99c

Please sign in to comment.