ONNX Runtime with DirectML: An Introductory Tutorial
This tutorial will guide you through using ONNX Runtime with DirectML to accelerate your machine learning models on Windows. DirectML is a high-performance, hardware-accelerated, DirectX 12-based library for machine learning, enabling you to leverage the power of your GPU across a wide range of Windows devices and hardware vendors.
What is ONNX Runtime?
ONNX Runtime is an open-source inference engine for machine learning models. It supports models from various frameworks like PyTorch, TensorFlow, and scikit-learn, as long as they are exported to the ONNX (Open Neural Network Exchange) format. By integrating with DirectML, ONNX Runtime can utilize hardware acceleration on supported GPUs.
Prerequisites
- Windows 10 (version 1903 or later) or Windows 11
- DirectX 12 capable GPU
- Visual Studio 2019 or later with the "Desktop development with C++" workload
- Git
- Python (recommended for model conversion)
Step 1: Set Up Your Development Environment
Ensure you have all the prerequisites installed. You can verify your DirectX 12 support by running dxdiag from the command prompt.
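If you prefer a programmatic check over dxdiag, the following standalone sketch (not part of the tutorial project) asks Direct3D 12 whether a device could be created on the default adapter. It assumes D3D_FEATURE_LEVEL_11_0 as the baseline and links d3d12.lib:
#include <windows.h>
#include <d3d12.h>
#include <iostream>
#pragma comment(lib, "d3d12.lib")

int main() {
    // Passing nullptr as the output pointer only tests whether a D3D12 device
    // could be created on the default adapter, without actually creating one.
    HRESULT hr = D3D12CreateDevice(nullptr, D3D_FEATURE_LEVEL_11_0,
                                   __uuidof(ID3D12Device), nullptr);
    std::cout << (SUCCEEDED(hr) ? "DirectX 12 capable GPU found."
                                : "No DirectX 12 capable device available.")
              << std::endl;
    return 0;
}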
For this tutorial, we'll be using a pre-trained ONNX model. If you need to convert your own model, you'll typically use Python scripts with libraries like onnxruntime-tools and framework-specific exporters.
Step 2: Obtain a Sample ONNX Model
We'll use a simple image classification model (e.g., a pre-trained MobileNet) for this example. You can often find pre-trained ONNX models on model zoos or convert them yourself.
For demonstration, let's assume you have a model named mobilenet.onnx.
Step 3: Create a C++ Project
Create a new C++ Console Application project in Visual Studio.
- Open Visual Studio.
- Select "Create a new project".
- Search for "Console App" and select the C++ template.
- Name your project (e.g., DirectML_ONNX_Demo) and choose a location.
- Click "Create".
Step 4: Install ONNX Runtime (DirectML Execution Provider)
You need to add the ONNX Runtime NuGet package with the DirectML execution provider to your project.
- In Visual Studio, right-click on your project in the Solution Explorer and select "Manage NuGet Packages...".
- Go to the "Browse" tab.
- Search for Microsoft.ML.OnnxRuntime.DirectML.
- Select the package and click "Install".
This will add the necessary libraries and headers to your project.
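After installation, you can optionally confirm that the DirectML execution provider is present in the ONNX Runtime build you are linking against. A minimal sketch, assuming a reasonably recent package that exposes Ort::GetAvailableProviders:
#include <iostream>
#include <string>
#include <onnxruntime_cxx_api.h>

int main() {
    // Lists the execution providers compiled into this ONNX Runtime binary.
    // "DmlExecutionProvider" should appear when the DirectML package is installed.
    for (const std::string& provider : Ort::GetAvailableProviders()) {
        std::cout << provider << std::endl;
    }
    return 0;
}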
Step 5: Write the Inference Code
Replace the contents of your main C++ file (e.g., DirectML_ONNX_Demo.cpp) with the following code. It uses the ONNX Runtime C++ API (onnxruntime_cxx_api.h; a recent package version is assumed) to load the ONNX model, enable the DirectML execution provider, run inference, and read back the output.
#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

#include <onnxruntime_cxx_api.h>   // ONNX Runtime C++ API
#include <dml_provider_factory.h>  // OrtSessionOptionsAppendExecutionProvider_DML
#include <windows.h>               // GetModuleFileNameW

// Helper function to get the directory containing the current executable.
std::wstring GetCurrentExecutableDirectory() {
    wchar_t path[MAX_PATH];
    GetModuleFileNameW(NULL, path, MAX_PATH);
    std::wstring wstrPath = path;
    size_t pos = wstrPath.find_last_of(L"\\/");
    return wstrPath.substr(0, pos);
}

int main() {
    // Assumes the model is in the same directory as the executable.
    const std::wstring model_path = GetCurrentExecutableDirectory() + L"\\mobilenet.onnx";

    try {
        // 1. Initialize the ONNX Runtime environment.
        // ORT_LOGGING_LEVEL_WARNING keeps logging reasonably quiet.
        Ort::Env env(ORT_LOGGING_LEVEL_WARNING, "DirectML_ONNX_Tutorial");

        // 2. Set up session options and enable the DirectML execution provider.
        Ort::SessionOptions session_options;
        session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
        // DirectML requires memory pattern optimization to be disabled and
        // sequential execution to be used.
        session_options.DisableMemPattern();
        session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
        // Device ID 0 selects the default GPU; pass another ID to pick a specific adapter.
        Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(session_options, 0));

        // 3. Load the ONNX model. On Windows the session constructor takes a wide-string path.
        Ort::Session session(env, model_path.c_str(), session_options);
        std::cout << "ONNX Runtime session created successfully with DirectML." << std::endl;

        // 4. Query the input and output names from the model (assuming one of each).
        Ort::AllocatorWithDefaultOptions allocator;
        auto input_name = session.GetInputNameAllocated(0, allocator);
        auto output_name = session.GetOutputNameAllocated(0, allocator);
        const char* input_names[] = { input_name.get() };
        const char* output_names[] = { output_name.get() };

        // 5. Prepare input data (placeholder - adapt to your model's input).
        // For image classification this typically means loading an image, resizing,
        // normalizing, and laying it out as an NCHW float32 tensor.
        // Here a [1, 3, 224, 224] tensor is simply filled with dummy data.
        const int64_t input_dims[] = { 1, 3, 224, 224 };
        const size_t input_rank = sizeof(input_dims) / sizeof(input_dims[0]);
        std::vector<float> input_values(1 * 3 * 224 * 224, 0.5f);

        Ort::MemoryInfo memory_info =
            Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
        Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
            memory_info, input_values.data(), input_values.size(),
            input_dims, input_rank);

        // 6. Run inference.
        std::vector<Ort::Value> output_tensors = session.Run(
            Ort::RunOptions{ nullptr },
            input_names, &input_tensor, 1,
            output_names, 1);
        std::cout << "Inference completed successfully." << std::endl;

        // 7. Process the output (placeholder - adapt to your model's output).
        // For a classifier this is typically a [1, num_classes] score vector.
        Ort::Value& output_tensor = output_tensors.front();
        auto output_info = output_tensor.GetTensorTypeAndShapeInfo();
        std::vector<int64_t> output_shape = output_info.GetShape();
        size_t total_output_elements = output_info.GetElementCount();
        float* output_values = output_tensor.GetTensorMutableData<float>();

        std::cout << "Output Tensor Shape: [";
        for (size_t i = 0; i < output_shape.size(); ++i) {
            std::cout << output_shape[i] << (i + 1 < output_shape.size() ? ", " : "");
        }
        std::cout << "]" << std::endl;

        std::cout << "Output Values (first 10):" << std::endl;
        for (size_t i = 0; i < std::min<size_t>(10, total_output_elements); ++i) {
            std::cout << output_values[i] << " ";
        }
        std::cout << std::endl;

        // The C++ API wrappers (Ort::Env, Ort::Session, Ort::Value, ...) release
        // their underlying resources automatically when they go out of scope.
    }
    catch (const Ort::Exception& e) {
        std::wcerr << L"ONNX Runtime error: " << e.what() << std::endl;
        std::wcerr << L"Make sure '" << model_path << L"' exists and the DirectML provider is available." << std::endl;
        return -1;
    }
    return 0;
}
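The code above treats a failure to enable DirectML as fatal. A more robust application might instead fall back to ONNX Runtime's built-in CPU execution provider. A minimal sketch of that pattern, using a hypothetical MakeSessionOptions helper and the same headers as the main example:
#include <iostream>
#include <onnxruntime_cxx_api.h>
#include <dml_provider_factory.h>

// Hypothetical helper: builds session options that prefer DirectML but fall
// back to the default CPU execution provider when DirectML cannot be enabled.
Ort::SessionOptions MakeSessionOptions() {
    Ort::SessionOptions session_options;
    session_options.SetGraphOptimizationLevel(GraphOptimizationLevel::ORT_ENABLE_ALL);
    try {
        session_options.DisableMemPattern();
        session_options.SetExecutionMode(ExecutionMode::ORT_SEQUENTIAL);
        Ort::ThrowOnError(OrtSessionOptionsAppendExecutionProvider_DML(session_options, 0));
        std::cout << "Using the DirectML execution provider." << std::endl;
    } catch (const Ort::Exception& e) {
        // With no provider appended, ONNX Runtime uses its built-in CPU
        // execution provider for the session.
        std::cerr << "DirectML unavailable (" << e.what() << "); falling back to CPU." << std::endl;
    }
    return session_options;
}
In the main program you would then create the session with Ort::Session session(env, model_path.c_str(), MakeSessionOptions()).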
Step 6: Build and Run
Place your mobilenet.onnx file in the same directory as your executable (e.g., the x64/Debug folder within your project directory). Build and run your C++ project.
You should see output indicating that the ONNX Runtime session was created with DirectML and that inference completed, followed by the shape and the first few values of the output tensor.
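For a classifier such as MobileNet, the printed values are raw scores; the predicted class is simply the index of the largest score (optionally after applying softmax). A small sketch, assuming the output_values and total_output_elements variables from the Step 5 code:
#include <algorithm>
#include <cstddef>

// Returns the index of the largest score, i.e. the predicted class id.
size_t ArgMax(const float* scores, size_t count) {
    return static_cast<size_t>(std::max_element(scores, scores + count) - scores);
}

// Usage with the Step 5 variables:
//   size_t predicted_class = ArgMax(output_values, total_output_elements);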