Garry's Blog

About machine learning, debugging, Python, C++ and other interesting stuff.
Carefully documenting everything I screwed up for future generations.

Tracing your PyTorch model

Part 2 of 3 - Bringing your Deep Learning Model to Production with libtorch

This is part 2 of a 3-part series on libtorch. Part 1 covers the rationale for PyTorch and using libtorch in production. This part covers the basics of getting your model up-and-running in libtorch. Part 3 discusses some more advanced topics.

Porting your PyTorch Model to Torch Script with the JIT

Before we start converting our model to something we can use with libtorch, we need to talk a little bit about the JIT and Torch Script. Torch Script is an intermediate format used to store your models so that they are portable between PyTorch and libtorch. A JIT (Just-In-Time compiler) is included to allow for exporting and importing Torch Script files.

There are two ways to convert your PyTorch model to a Torch Script one:

  1. Tracing. The JIT runs your model for one inference iteration on example inputs and records the operations it executes.
  2. Scripting. The JIT parses your Python code and converts it to Torch Script directly.

Tracing requires no changes to your Python code, but it doesn’t deal well with complicated models. For example, if your model contains data-dependent control flow, so that its behavior can change from one inference run to the next, a single trace only captures the path taken during that one run.

Scripting, when it succeeds, always produces a correct Torch Script model, but it often requires some changes to your Python code because the compiler is quite picky. I’ll quickly go through the most basic example (you can find more in the libtorch docs) and move on to the more complicated stuff after that.
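
To make the difference concrete, here is a minimal sketch with a made-up Gate module (not the model from this post) whose behavior depends on the values in its input:

import torch
from torch import nn

# A toy module with data-dependent control flow: the branch taken
# depends on the values inside `x`.
class Gate(nn.Module):
    def forward(self, x):
        if x.sum() > 0:
            return x * 2
        return x - 1

model = Gate()
example = torch.ones(3)

# Tracing records only the operations executed for `example`
# (here: the `x * 2` branch); the other branch is silently lost.
traced = torch.jit.trace(model, example)

# Scripting compiles the Python source itself, so both branches
# are preserved in the resulting Torch Script.
scripted = torch.jit.script(model)
print(scripted.code)

Tracing will typically even warn you about this (a TracerWarning), while the scripted version keeps both branches.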

Converting a sine estimation model to Torch Script (taken from the PyTorch repo examples):

import torch
from torch import nn

"""
 class Sequence;

 Model capable of doing sine wave prediction. We don't care about the model itself for now,
 just assume it works!
 
 Source:
    https://github.com/pytorch/examples/blob/master/time_sequence_prediction/train.py

"""
class Sequence(nn.Module):

    def __init__(self):
        super(Sequence, self).__init__()
        self.lstm1 = nn.LSTMCell(1, 51)
        self.lstm2 = nn.LSTMCell(51, 51)
        self.linear = nn.Linear(51, 1)

    def forward(self, input, future=0):
        outputs = []

        # (I changed `torch.double` to `torch.float` here because otherwise
        #  the LSTM layer will complain!)
        h_t = torch.zeros(input.size(0), 51, dtype=torch.float)
        c_t = torch.zeros(input.size(0), 51, dtype=torch.float)
        h_t2 = torch.zeros(input.size(0), 51, dtype=torch.float)
        c_t2 = torch.zeros(input.size(0), 51, dtype=torch.float)

        for i, input_t in enumerate(input.chunk(input.size(1), dim=1)):
            h_t, c_t = self.lstm1(input_t, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]

        for i in range(future):
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)
            outputs += [output]

        outputs = torch.stack(outputs, 1).squeeze(2)

        return outputs

model = Sequence()

with torch.no_grad():
    # Script the model using `torch.jit.script`. Unlike tracing,
    # scripting compiles the Python source directly, so it does not
    # need any example inputs.
    scripted = torch.jit.script(model)

    # Print the Torch Script code
    print(scripted.code)

    # We can also store the model like usual:
    scripted.save('traced.ptc')

That last part is important: we use torch.jit.script to script the model. Even this relatively simple model cannot be converted in its current form. What happens when the JIT cannot compile a model is one of those things that’s not very clear from the docs, but fortunately the error messages are quite helpful. Let’s run through them one by one, because that’s how we learn.

Here’s the first error:

RuntimeError: 
all inputs of range must be ints, found Tensor in argument 0:
  File "sine.py", line 26
            outputs += [output]

        for i in range(future):
                 ~~~~~~~~~~~~ <--- HERE
            h_t, c_t = self.lstm1(output, (h_t, c_t))
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))

So, the range function obviously needs an integer. When executing the model without the JIT, it receives an integer for the future parameter through the forward function and everything is fine. With the JIT, suddenly it’s a Tensor? This is not some kind of weird bug: when the JIT scripts a model, every parameter without a type annotation is assumed to be a Tensor, even parameters that are supposed to be integers (a Tensor with shape (1,) works just as well). So even though the future parameter has the integer default value 0, the JIT treats it as a Tensor and the range function no longer works.

We’ll just make the conversion explicit and that gets rid of the error:

- for i in range(future):
+ for i in range(int(future)):
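
As an aside, a different way around this, not used in the rest of this post (which sticks with the Tensor-based approach because that is what the C++ code will pass in later), is to give the parameter an explicit type annotation, which the Torch Script compiler understands. A minimal sketch on a made-up Repeat module:

import torch
from torch import nn

# Hypothetical toy module, only to illustrate the annotation: `times`
# is declared as an int, so the JIT no longer assumes it is a Tensor,
# the integer default value stays valid and `range(times)` needs no cast.
class Repeat(nn.Module):
    def forward(self, x, times: int = 1):
        out = x
        for _ in range(times):
            out = out + x
        return out

scripted = torch.jit.script(Repeat())
print(scripted(torch.ones(2), 3))  # tensor([4., 4.])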

Next error:

RuntimeError: 
undefined value output:  
  File "sine.py", line 37

        for i in range(int(future)):
            h_t, c_t = self.lstm1(output, (h_t, c_t))
                                  ~~~~~~ <--- HERE
            h_t2, c_t2 = self.lstm2(h_t, (h_t2, c_t2))
            output = self.linear(h_t2)

This piece of code uses the output variable declared inside the previous loop, which is valid in Python (though not very nice) but not in Torch Script. We can predefine the variable as a tensor earlier:

h_t = torch.zeros(input.size(0), 51, dtype=torch.float)
c_t = torch.zeros(input.size(0), 51, dtype=torch.float)
h_t2 = torch.zeros(input.size(0), 51, dtype=torch.float)
c_t2 = torch.zeros(input.size(0), 51, dtype=torch.float)
+ output = torch.zeros((1, 1))

Finally:

RuntimeError: 
Expected a default value of type Tensor on parameter "future":
  File "sine.py", line 22
    def forward(self, input, future=0):
    ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~...  <--- HERE
        outputs = []

Yes, we mentioned this earlier: since the JIT assumes future is a Tensor, its default value should be a Tensor too. Instead of making the default value a tensor, we simply remove it, because we’ll always want to predict the future in inference mode.


-def forward(self, input, future = 0):
+def forward(self, input, future):

A quick note about that: most people who are looking to convert their model to Torch Script will be doing this for the sake of inference optimization. You can do training in C++ as well, but in most commercial settings inference optimization matters more than training optimization (the latter can be solved by throwing more money at the problem anyway). Because of this, you can often remove large parts of your model that are only relevant for the training phase. However, this makes your model code branch out, which adds maintenance overhead. You can use something like the Strategy pattern to cope with this architecturally.
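
Here is a rough sketch of one way to do that. The NoAux, AuxHead and Model names are made up for illustration and are not part of this post’s model; the idea is simply that the training-only behavior is injected, so the exported variant is built with a no-op strategy instead of if-statements scattered through forward:

import torch
from torch import nn

class NoAux(nn.Module):
    """Inference strategy: no auxiliary output."""
    def forward(self, features):
        return features

class AuxHead(nn.Module):
    """Training strategy: auxiliary head used for an extra loss."""
    def __init__(self, dim: int):
        super().__init__()
        self.head = nn.Linear(dim, 10)

    def forward(self, features):
        return self.head(features)

class Model(nn.Module):
    def __init__(self, aux: nn.Module):
        super().__init__()
        self.encoder = nn.Linear(16, 32)
        self.decoder = nn.Linear(32, 1)
        self.aux = aux

    def forward(self, x):
        features = self.encoder(x)
        # The second output only matters when the training strategy is
        # plugged in; at inference time it is just the features.
        return self.decoder(features), self.aux(features)

train_model = Model(AuxHead(32))    # used during training
export_model = Model(NoAux())       # used for Torch Script export
scripted = torch.jit.script(export_model)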

Even with a simple model (just two LSTM cells and a linear layer), we have already had to change quite a bit. The work required to make your model JIT-proof grows roughly with the size of your model code. For large models it can be quite a lot of work, but it is doable.

The code will print the scripted version of the model in Torch Script, which looks a lot like Python code. You can see that some statements have been made more explicit, some have been rewritten or split up, and the variable names have been mangled.

The Torch Script model was saved as traced.ptc. There’s some discussion about which extension to use for Torch Script models: .pt is recommended for “normal” PyTorch models, but it’s unclear what to use for compiled ones. I like .ptc, so I used it consistently here.
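
Before moving on to C++, it can be worth sanity-checking the saved file by loading it back in Python. A minimal check along those lines (the shapes are just an example, and this step is optional):

import torch

# Load the saved Torch Script file back into Python and run it once
# with dummy inputs, mirroring what the C++ code will do later.
model = torch.jit.load('traced.ptc')
model.eval()

with torch.no_grad():
    out = model(torch.zeros(1, 20), torch.full((1,), 10))
    print(out.shape)  # should be (1, 30): 20 input steps + 10 future steps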

Get up-and-running with libtorch

Just a heads up: I’ll be using C++17 for the remainder of this post. It’s not a requirement, though; you can replace the C++17 features with their older (and more verbose) counterparts.

First, we need to install libtorch. It comes prepackaged as a zip file that you can download from the PyTorch website. You need to select LibTorch and C++/Java to get the correct one.

Here’s a quick link to libtorch version 1.5 for macOS (no CUDA support).

Let’s install it to /usr/local so it gets picked up by CMake. Here’s a script you can use for that:

cd /tmp
echo "Downloading libtorch 1.5..."
wget https://download.pytorch.org/libtorch/cpu/libtorch-macos-1.5.0.zip && \
unzip libtorch-macos-1.5.0.zip && \
cd libtorch/
echo "Installing libtorch..."
sudo mv include/* /usr/local/include/ && \
sudo mv lib/* /usr/local/lib/ && \
sudo mv share/* /usr/local/share/

This will work on Linux and macOS (it will not work on Windows). You can replace the download link with your respective version.

We create a minimum working CMakeLists.txt like this:

cmake_minimum_required(VERSION 3.10)

project(tutorial-libtorch VERSION 1.0.0)

# This sets the C++ version to C++17
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)

# Import Torch C++ (this looks for the CMake files in
# a number of paths, including `/usr/local/share/cmake`,
# where we installed them)
find_package(Torch REQUIRED)

# Make target
add_executable(tutorial-libtorch main.cpp)

# Link Torch C++ libs
target_link_libraries(tutorial-libtorch "${TORCH_LIBRARIES}")

Since /usr/local is in CMake’s search path, this should pick up libtorch automatically!

Let’s use a basic example I took from the PyTorch website to test whether our build setup works.

main.cpp

#include <iostream>

#include <torch/torch.h>

int main() {
  torch::Tensor tensor = torch::eye(3);
  std::cout << tensor << std::endl;
}

Using CMake, an out-of-source build is quite easy:

mkdir -p build && \
cd build && \
cmake .. && \
make && \
echo "\nRunning the program...\n" && \
./tutorial-libtorch && \
cd ..

There you go! If everything went well, you should see:


 1  0  0
 0  1  0
 0  0  1
[ CPUFloatType{3,3} ]

If you are not familiar with CMake and the likes, I recommend reading up on the basics. Either way, your project structure (at this point) should look something like this:

`- tutorial
    `- CMakeLists.txt
    `- main.cpp
    `- build
        `- tutorial-libtorch (where the built binary will live)
        `- ... (more CMake files and folders)

Loading and running our model

The libtorch API is designed such that it almost feels like you’re coding in Python (you are not). We can load and run the model like so:

#include <iostream>

#include <torch/script.h>

int main(int argc, char** argv) {
  if (argc < 2) {
    std::cerr << "Please provide model name." << std::endl;
    return 1;
  }

  torch::jit::script::Module model;

  try {
    model = torch::jit::load(std::string(argv[1]));

    // Explicitly load the model onto the CPU; use torch::kCUDA instead if you
    // are on Linux and have a libtorch build with CUDA support (and a GPU)
    model.to(torch::kCPU);

    // Set to `eval` mode (just like Python)
    model.eval();

    // Within this scope/thread, don't use gradients (again, like in Python)
    torch::NoGradGuard no_grad_;

    // Input to the model is a vector of "IValues" (tensors)
    std::vector<torch::jit::IValue> input = {
      // Corresponds to `input`
      torch::zeros({ 1, 20 }, torch::dtype(torch::kFloat)),
      // Corresponds to the `future` parameter
      torch::full({ 1 }, 10, torch::dtype(torch::kInt))
    };

    // `model.forward` does what you think it does; it returns an IValue
    // which we convert back to a Tensor
    auto output = model
      .forward(input)
      .toTensor();

    // Extract size of output (of the first and only batch) and preallocate
    // a vector with that size
    auto output_size = output.sizes()[1];
    auto output_vector = std::vector<float>(output_size);

    // Fill result vector with tensor items using `Tensor::item`
    for (int i = 0; i < output_size; i++) {
      output_vector[i] = output[0][i].item<float>();
    }

    // Print the vector here
    for (float x : output_vector)
      std::cout << x << ", ";
    std::cout << std::endl;
  }
  catch (const c10::Error& e) {
    std::cerr << "An error ocurred: " << e.what() << std::endl;
    return 1;
  }

  return 0;
}

If you run this code, you should see some meaningless output (we haven’t trained the model so nothing to see there). Most of the code is self-explanatory, but there are a few bits I want to highlight.

First, you will notice that we explicitly move the model to the CPU. I found that doing this consistently (to CPU or GPU, whichever you are using) can prevent some errors later on. This already matters in Python PyTorch, but even more so in libtorch, because issues can arise if you trace the model on the CPU but try to load it onto a GPU, or vice versa. Just remember that exported models tend to be tied to the device they were originally traced or scripted on (not always, but for more complicated models they usually are).

model.to(torch::kCPU);

We use a guard where we would use a with-block in Python. The guard ends when the object goes out of scope.

// Within this scope/thread, don't use gradients (again, like in Python)
torch::NoGradGuard no_grad_;

There are a number of ways to initialize tensors. Here we pass TensorOptions to the torch::zeros and torch::full factory functions.

// Input to the model is a vector of "IValues" (tensors)
std::vector<torch::jit::IValue> input = {
  // Corresponds to `input`
  torch::zeros({ 1, 20 }, torch::dtype(torch::kFloat)),
  // Corresponds to the `future` parameter
  torch::full({ 1 }, 10, torch::dtype(torch::kInt))
};

Alternatively, you could do it like this:

torch::zeros({ 1, 20 }).to(torch::kFloat);
torch::full({ 1 }, 10).to(torch::kInt);

// If you want to make the device explicit as well:
torch::zeros({ 1, 20 }).to(torch::kFloat).to(torch::kCPU);
torch::full({ 1 }, 10).to(torch::kInt).to(torch::kCPU);

Lastly, we use the Tensor::item call to extract the data from the tensor and put it in a vector. We could have just printed the tensor directly, but in a more realistic use case you would want to do something with the output.

// Extract size of output (of the first and only batch) and preallocate
// a vector with that size
auto output_size = output.sizes()[1];
auto output_vector = std::vector<float>(output_size);

// Fill result vector with tensor items using `Tensor::item`
for (int i = 0; i < output_size; i++) {
  output_vector[i] = output[0][i].item<float>();
}

If everything went well, you should see a comma-separated list of outputs on your screen!

Fortunately, the Tensor Creation API is very well documented on pytorch.org. The documentation is relatively sparse beyond this though. You can find an API overview here and a bunch of great examples in the PyTorch repo here. Getting up and running is quite easy, but we want to do some more complicated stuff as well. Part 3 covers some of the more advanced functionality in libtorch that is hardly documented anywhere!