Garry's Blog

About machine learning, debugging, Python, C++ and other interesting stuff.
Carefully documenting everything I screwed up for future generations.

Advanced libtorch

Part 3 of 3 - Bringing your Deep Learning Model to Production with libtorch

This is part 3 of a 3-part series on libtorch. Part 1 covers the rationale for PyTorch and using libtorch in production. Part 2 covers the basics of getting your model up-and-running in libtorch. This part discusses some more advanced topics.

The Abstract Syntax Tree

Before I show you the code, I want to quickly go over how PyTorch converts the model to a C++-usable one. Basically, it parses your Python code, much like the interpreter would, and builds an Abstract Syntax Tree (AST) representation of your code. This is what you see when you do print(traced.code) in Python. ASTs are commonplace in interpreters and source code parsers. In libtorch, the AST is loaded in and used to correctly execute the model when calling model.forward(). But libtorch also provides a bunch of other functions for interacting with the Torch Script model, such as attr, setattr and run_method.
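
If you want to inspect that representation from the C++ side as well, the loaded module exposes its methods and their graphs. A minimal sketch, assuming model has already been loaded with torch::jit::load as shown in part 2:

// Print the IR graph of the `forward` method -- roughly the C++
// counterpart of inspecting `traced.code` / `traced.graph` in Python.
std::cout << *model.get_method("forward").graph() << std::endl;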

The same Torch Script based approach is also used for all the other libtorch functionality. Whilst this blog series focuses mostly on using an existing PyTorch model in C++, the libtorch API also allows for creating models, training and inference all from C++. I would say close to 90% of Python functionality is included. There’s a great blog series on how to do all this by Kushashwa Ravi Shrimali here.
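
To give a flavour of that, here is a small, self-contained sketch (not from the original series) of defining and training a toy model entirely in C++, fitted on random data just to show the moving parts:

#include <torch/torch.h>

int main() {
  // A tiny fully-connected regression model defined directly in C++.
  torch::nn::Sequential net(
      torch::nn::Linear(4, 16),
      torch::nn::ReLU(),
      torch::nn::Linear(16, 1));

  torch::optim::SGD optimizer(net->parameters(), /*lr=*/0.01);

  // Toy data: 64 random samples with 4 features each.
  auto x = torch::randn({ 64, 4 });
  auto y = torch::randn({ 64, 1 });

  for (int epoch = 0; epoch < 100; ++epoch) {
    optimizer.zero_grad();
    auto loss = torch::mse_loss(net->forward(x), y);
    loss.backward();
    optimizer.step();
  }
}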

Lastly, some of you C++ programmers might be worried about the copy semantics of torch::Tensor (I was). Don’t worry though, the memory of those objects is fully managed as if it were wrapped by a shared_ptr. You cannot accidentally deep-copy a tensor: copies of the handle share the underlying data, and an actual copy only happens when you explicitly call torch::Tensor::clone. Now, internally a lot of things might happen when doing operations on the tensor (especially the more complicated ones like permute, more on that later), including copying around. So if you are worried about performance - and you probably are, why else use C++? - you should grab a profiler to measure which operations are taking the most time. The whole performance topic is a subject for another post, especially since it is not libtorch-specific. Now it’s time for the code!
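
A quick illustration of those copy semantics (just a sketch, not from the original post):

// Assigning a tensor copies the handle, not the data.
torch::Tensor a = torch::ones({ 2, 2 });
torch::Tensor b = a;          // b shares a's storage, like copying a shared_ptr
b.add_(1);                    // in-place op, also visible through `a`

// An actual deep copy requires an explicit clone.
torch::Tensor c = a.clone();
c.add_(1);                    // independent memory, `a` and `b` are unaffected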

Setting class attributes

Let’s start off by modifying some of the class attributes of the PyTorch model from the C++ side. We modify the model so that the initialization function can be changed by setting the init_func attribute on Sequence:

def __init__(self):
    super(Sequence, self).__init__()
    self.lstm1 = nn.LSTMCell(1, 51)
    self.lstm2 = nn.LSTMCell(51, 51)
    self.linear = nn.Linear(51, 1)
+
+    self.init_func = 'zeros'

And in the forward method, we use the new attribute to decide how the hidden state is initialized:

+if self.init_func == 'ones':
+    h_t = torch.ones(input.size(0), 51, dtype=torch.float)
+    c_t = torch.ones(input.size(0), 51, dtype=torch.float)
+    h_t2 = torch.ones(input.size(0), 51, dtype=torch.float)
+    c_t2 = torch.ones(input.size(0), 51, dtype=torch.float)
+    output = torch.ones((1, 1))
+else:
    h_t = torch.zeros(input.size(0), 51, dtype=torch.float)
    c_t = torch.zeros(input.size(0), 51, dtype=torch.float)
    h_t2 = torch.zeros(input.size(0), 51, dtype=torch.float)
    c_t2 = torch.zeros(input.size(0), 51, dtype=torch.float)
    output = torch.zeros((1, 1))

Now we can set it to either zeros (the default) or ones to change how the LSTM state is initialized.

We also add a line at the top of both the Python and the C++ code to make sure we produce deterministic results:

import torch
from torch import nn

+torch.manual_seed(0)

And the equivalent at the top of the C++ code:

+torch::manual_seed(0);
+
torch::jit::script::Module model;

If we run the same C++ code as before a couple of times, we will get the same output each time:

$ ./build-and-run.sh
0.102669, 0.103464, 0.105305, 0.106987, 0.108209, 0.109003, 0.109485, 0.109763, 0.109919, 0.110003, 0.110046, 0.110068, 0.110078, 0.110083, 0.110084, 0.110084, 0.110084, 0.110083, 0.110083, 0.110082, 0.109766, 0.109395, 0.109072, 0.108825, 0.10865, 0.108532, 0.108457, 0.108409, 0.10838, 0.108363

We can list the model attributes in C++ like so:

for (const auto& attr : model.named_attributes())
  std::cout << attr.name << std::endl;

The list includes the three layers, as well as their sub-attributes (nice!), and the init_func attribute we added is also there!

Now, let’s see if we can switch to using the ones initialization function from C++. We can use attr to read an attribute value and setattr to change it.

// Get current value of `init_func`
auto init_func = model.attr("init_func");

// Set value of `init_func`
model.setattr("init_func", "ones");
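
Note that attr returns an IValue rather than the raw value, so you still have to convert it. A small sketch of reading the string back out:

// `attr` gives us an IValue; convert it to read the actual string.
const std::string& init_func_value = model.attr("init_func").toStringRef();
std::cout << "init_func = " << init_func_value << std::endl;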

In main.cpp, we add a line here to change the init_func:

model = torch::jit::load(std::string(argv[1]));

// Explicitly load model onto CPU, you can use kGPU if you are on Linux
// and have libtorch version with CUDA support (and a GPU)
model.to(torch::kCPU);

// Set to `eval` mode (just like in Python)
model.eval();

+// Set initialization function to ones
+model.setattr("init_func", "ones");

If you run the program again, you should see output like this:

$ ./build-and-run.sh
-0.373218, -0.187147, -0.0826549, -0.0266808, 0.00465929, 0.0232547, 0.0347184, 0.0419016, 0.0464152, 0.0492396, 0.0509939, 0.0520746, 0.0527348, 0.0531352, 0.0533768, 0.0535219, 0.0536091, 0.0536615, 0.0536932, 0.0537126, 0.0537104, 0.0537147, 0.0537257, 0.0537386, 0.0537503, 0.0537595, 0.0537662, 0.0537708, 0.0537738, 0.0537758

Perfect, we got different output, so this is working! There are a few related functions for doing similar things that I want to quickly go over:

// Get a list of methods in the model class
for (const auto& method : model.get_methods())
  std::cout << method.name() << "\n";
  
// Run a member function other than `forward`, passing parameters:
// (Note that `model.forward` is just a shortcut for `model.run_method("forward", ...)`)
auto result = model.run_method("some_func", ...);

// See parameters (similar to the Python API)
std::cout << model.parameters() << std::endl;

Some of this stuff is hardly documented, but you can find some information in the class reference documentation of torch::jit::Module.
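
As a concrete example of run_method, suppose the model exported a method reset_state that takes a single tensor. The method name and argument shape here are made up, but the calling pattern is the same for any exported method:

// Arguments are converted to IValues, and the result comes back
// as an IValue that you unpack yourself.
auto result = model.run_method("reset_state", torch::zeros({ 1, 51 }));
if (result.isTensor()) {
  auto result_tensor = result.toTensor();
  std::cout << result_tensor.sizes() << std::endl;
}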

Converting raw data to a Tensor and back

At some point, you will have to convert raw data (for example: images) to a proper torch::Tensor and back. To do this, you can create an empty Tensor, acquire a pointer to the Tensor contents, and then copy over the data. You have to make sure that the data is in the right format to avoid any memory issues.

For example, let’s look at some code to convert an OpenCV matrix image (cv::Mat) to Torch. First, we need to make sure that the raw memory inside the cv::Mat is continuous. Sometimes, instead of applying an actual operation on the data, OpenCV will simply set some internal attributes that change the way the data is read, rather than copying it, to preserve performance. (I believe this is only done when extracting submatrices, but I am not sure.) Copying raw data from such a non-continuous cv::Mat would leave the resulting Tensor scrambled. Use cv::Mat::clone to make sure it is continuous:

if (!mat.isContinuous()) {
  mat = mat.clone();
}

Now we can pre-allocate an empty tensor with the same dimensions as the cv::Mat and copy the raw memory into it!

// (We use torch::empty here since it can be somewhat faster than `zeros`
//  or `ones` since it is allowed to fill the tensor with garbage.)
auto tensor = torch::empty(
  { mat.rows, mat.cols, mat.channels() },
  // Set dtype=byte and place on CPU, you can change these to whatever
  // suits your use-case.
  torch::TensorOptions()
    .dtype(torch::kByte)
    .device(torch::kCPU));
    
// Copy over the data
std::memcpy(tensor.data_ptr(), mat.data, tensor.numel() * tensor.element_size());

We get a pointer to the tensor’s internal memory through tensor.data_ptr(). The cv::Mat data is in mat.data. To calculate the number of bytes to copy, we multiply the total number of elements by the size of each element: tensor.numel() * tensor.element_size(). Using element_size() keeps the copy correct even if you later change the dtype in the tensor options above.
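
As an aside (not covered further in this post), libtorch also provides torch::from_blob, which wraps existing memory in a tensor without copying. It does not take ownership of the memory, so clone the result if it has to outlive the cv::Mat:

// Wrap the cv::Mat memory directly; no copy is made.
auto tensor_view = torch::from_blob(
  mat.data,
  { mat.rows, mat.cols, mat.channels() },
  torch::TensorOptions().dtype(torch::kByte));

// Clone if the tensor must outlive `mat`.
auto tensor_owned = tensor_view.clone();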

The above code gives us a tensor in channels-last (HWC) format, matching the layout of the original image. Because most PyTorch models expect channels-first (CHW) format, we need to convert it. The torch::Tensor::permute function is perfect for this:

// You should read this as "permute the tensor in such a way that the dimensions of the new tensor are:
//  - First dimension: third dimension of old tensor
//  - Second dimension: first dimension of old tensor
//  - Third dimension: second dimension of old tensor
auto tensor_cf = tensor.permute({ 2, 0, 1 });

This would convert a tensor of shape (300, 200, 3) to (3, 300, 200).
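
Depending on your model, you will usually also need a leading batch dimension and float input. A short sketch of those follow-up steps, assuming the model was trained on float input scaled to [0, 1]:

// (3, H, W) -> (1, 3, H, W), byte -> float, scale to [0, 1]
auto input = tensor_cf
  .unsqueeze(0)
  .to(torch::kFloat)
  .div(255.0);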

Now, the conversion back to cv::Mat (or whatever other data structure you might use) is quite straightforward, so I’ll leave it as an exercise for the reader. There is just one important thing to note: just like OpenCV, torch also uses strided memory in some places, and the memory layout is not guaranteed to be contiguous! For example, the torch::Tensor::permute method above is really fast because it does not actually move the underlying data around; it just adjusts the strides so that subsequent operations read the tensor as if the memory had been permuted. Therefore, use the torch::Tensor::contiguous method to rearrange the memory correctly:

// Do this before using the raw data somewhere else!
auto tensor_raw_data_ptr = tensor.contiguous().data_ptr();

I can tell you from first-hand experience that this little piece of advice might save you a couple of hours debugging :)
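
If you want to check your solution to that exercise, here is one possible sketch of the reverse direction, assuming a CHW float tensor with values in [0, 1] going back to an 8-bit, 3-channel cv::Mat:

// CHW float in [0, 1] -> HWC byte, contiguous in memory.
auto out = tensor
  .permute({ 1, 2, 0 })
  .mul(255.0)
  .clamp(0, 255)
  .to(torch::kByte)
  .contiguous();

// Copy the raw bytes into an OpenCV image of matching size.
cv::Mat image(out.size(0), out.size(1), CV_8UC3);
std::memcpy(image.data, out.data_ptr(), out.numel() * out.element_size());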

Other tidbits that might come in handy

This section is just an information-dense collection of general tips and tricks; hopefully they will save you some time if you run into similar problems.

  1. To check whether a GPU is available, you can use the following code, which is very similar to the Python version:
auto device = torch::cuda::is_available()
  ? torch::kCUDA
  : torch::kCPU;

// You can use `.to(device)` from here on modules and
// Tensors to make sure they are loaded in the correct
// memory location!
  2. You can convert a Python tuple attribute into its C++ elements by calling toTuple() on the IValue returned by attr, like so:
const auto& tuple_elements = module.attr("my_attribute").toTuple()->elements();
  3. libtorch throws exceptions of type c10::Error; you can catch them like this:
try {
  module_ = torch::jit::load(filename);
  module_.eval();
}
catch (const c10::Error& e) {
  throw std::invalid_argument("Failed to load model: " + e.msg());
}
  4. Calls that run on the GPU are generally asynchronous in libtorch (as in PyTorch). If you are profiling a traced model running on the GPU, you might run into weird timings like this:
// This call takes 200 milliseconds.
auto y = module.forward(x).toTensor();
// This call takes 300 milliseconds (???).
auto y_activation = torch::softmax(y, 0).to(torch::kCPU);

How could a simple softmax invocation be more time-consuming than the entire model inference iteration? Well, it’s not. But because of the async nature of libtorch, inference is still running after module.forward(). Only once to(torch::kCPU) is called does it block, since the output of the previous call must be available for the data to be copied over to the CPU.

You can use cudaDeviceSynchronize() (declared in the CUDA runtime header cuda_runtime.h) to block at any point in your code.

// This call takes 200 milliseconds.
auto y = module.forward(x).toTensor();

// This call takes 300 milliseconds.
cudaDeviceSynchronize();

// This call now takes just a few msecs.
auto y_activation = torch::softmax(y, 0).to(torch::kCPU);
  5. To turn off cuDNN benchmarking, use the following code (which is a little different from the Python version):
at::globalContext().setBenchmarkCuDNN(false);
  6. Lastly, keep in mind that the JIT compiler will do a lot of caching and optimization on the first inference run to improve performance later on. For me, the first run takes over 45 seconds; subsequent runs take 250 milliseconds. I was a little worried when running my model in libtorch the first time, but this is normal! A warm-up pass at startup (see the sketch after this list) keeps that cost out of your first real request.
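
A tiny sketch of such a warm-up pass; the input shape is a made-up placeholder, use whatever your model expects:

// Run one dummy forward pass at startup so the JIT's first-run
// optimization cost is paid before any real request arrives.
std::vector<torch::jit::IValue> warmup_inputs;
warmup_inputs.push_back(torch::zeros({ 1, 1 }));  // hypothetical input shape
model.forward(warmup_inputs);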

Conclusions

The PyTorch team did a great job building a native counterpart of their library that is as powerful as the Python version! There are a few gotchas here and there, which I have tried to document as much as possible throughout this blog series. My own project with libtorch has resulted in a significant speedup as well as more stability and a more predictable memory footprint. As I said in the beginning, using the C++ library is not necessary for most use cases, but it’s nice to have in case you are one of the few people who do need it.

Thanks for reading!