Mathieu De Coster

Compile time CUDA device checking in Rust

24 July 2020

Stable Rust is soon getting const generics. Let's look at how const generics can be used to avoid a certain bug at compile time. The bug? Trying to perform operations on data on separate CUDA devices in PyTorch.

The deep learning framework PyTorch is becoming more and more popular. It's a great framework, and my current personal favourite. When writing PyTorch code, however, it is easy to make small mistakes that cause runtime errors. You could, for example, forget a .to(device) call (1). This kind of bug is usually found quickly during testing and is little more than an annoyance. However, you can actually find and avoid these bugs at compile time with const generics.

For this experiment, I am going to use the Rust programming language. Specifically, I will be using the tch-rs bindings to libtorch. This allows writing code that is quite similar to PyTorch, but in Rust.

The Problem

In deep learning, we work with high dimensional data represented as tensors. PyTorch allows you to easily perform calculations on these tensors on multiple GPUs in parallel. You can explicitly move data to a different device by calling .to(device) on a tensor. device is a string: "cpu" or "cuda", or "cuda:X" for a specific CUDA device at index X.

For example, the following code adds two tensors with random values on the CPU:

import torch

tensor_1 = torch.randn(2, 3).float()
tensor_2 = torch.randn(2, 3).float()
result = tensor_1 + tensor_2

And this code runs on the first CUDA device:

device = 'cuda:0'
tensor_1 = torch.randn(2, 3).float().to(device)
tensor_2 = torch.randn(2, 3).float().to(device)
result = tensor_1 + tensor_2

But what if we do this?

device = 'cuda:0'
tensor_1 = torch.randn(2, 3).float()  # On CPU
tensor_2 = torch.randn(2, 3).float().to(device)  # On CUDA:0
result = tensor_1 + tensor_2  # ???

The result is a runtime error:

RuntimeError: expected backend CPU and dtype Float but got backend CUDA and dtype Float

Or if we try to add a tensor on CUDA:0 and another on CUDA:1, we get:

RuntimeError: binary_op(): expected both inputs to be on same device, but input a is on cuda:0 and input b is on cuda:1

Equivalent code in Rust

Here is the equivalent code in Rust (using tch-rs) for the failing example:

use tch::{Device, Kind};

let tensor_1 = tch::Tensor::randn(&[2, 3], (Kind::Float, Device::Cpu)); // On CPU
let tensor_2 = tch::Tensor::randn(&[2, 3], (Kind::Float, Device::Cuda(0))); // On CUDA:0
let result = &tensor_1 + &tensor_2; // ???

Note how we do specify the device in the code, but it is not checked at compile time. We are going to use const generics to add that check.

Const Generics

Const generics (RFC) are a feature of the Rust programming language which allows using constant values in type signatures. Why do we need these? Well, let's say we are writing a wrapper DeviceTensor around the Tensor struct from tch-rs, with an added generic parameter D, which indicates the device.

We want operations between DeviceTensors to be possible only if those tensors live on the same device. We could write DeviceTensor<Cpu> and DeviceTensor<Cuda> using "regular" generics. However, this only covers the "cpu" and "cuda" device strings; it does not work for "cuda:X". An obvious workaround is to introduce additional types, Cuda0, Cuda1, et cetera, but that does not scale.

The const generics feature allows us to use const (constant) parameters in our type annotations. Since integer constants match that description, we can express our device as Cuda<N>, where N is an integer.
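
To get a feel for the syntax, here is a minimal sketch that has nothing to do with tensors: a struct parameterised by a compile-time integer. (At the time of writing, this may still require a nightly toolchain.)

struct Buffer<const N: usize> {
    data: [f32; N],
}

fn main() {
    // `N` is fixed at compile time; Buffer<3> and Buffer<4> are distinct types.
    let buffer = Buffer::<3> { data: [0.0; 3] };
    assert_eq!(buffer.data.len(), 3);
}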

The Solution

To solve the above problem, we will create a struct that is generic over the device on which the internal data in that struct lives.

pub struct DeviceTensor<D> where D: Device {
    internal: tch::Tensor,
    _device_marker: std::marker::PhantomData<D>,
}

Note two things: the where clause and the _device_marker field. The where clause requires our D parameter to implement the Device trait (defined below). The _device_marker is PhantomData, a zero-sized field. It is needed because otherwise the compiler would complain that our D parameter is unused, which is an error.
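
As a quick illustration (BrokenDeviceTensor is a hypothetical name, not part of the final code), leaving out the marker field makes the compiler reject the struct:

// This variant without the PhantomData field does not compile:
pub struct BrokenDeviceTensor<D> where D: Device {
    internal: tch::Tensor,
}
// error[E0392]: parameter `D` is never used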

The Device trait is quite simple. It has a single function, tch_device, which is used to pick the correct tch-rs device when constructing a tensor (see the randn implementation below). We implement two device variants. The first is CpuDevice. The second is CudaDevice, which makes use of const generics. The N parameter indicates the device on which the tensor lives, and in tch_device, we use N to return the correct device.

pub trait Device {
    fn tch_device() -> tch::Device;
}

pub struct CpuDevice;
impl Device for CpuDevice {
    fn tch_device() -> tch::Device {
        tch::Device::Cpu
    }
}

pub struct CudaDevice<const N: usize>; // Const generics!
impl<const N: usize> Device for CudaDevice<N> {
    fn tch_device() -> tch::Device {
        tch::Device::Cuda(N) // We use `N` here
    }
}

We can construct a DeviceTensor on the CPU and GPU as follows:

let tensor_1: DeviceTensor<CpuDevice> = DeviceTensor::randn(&[2, 3]);
let tensor_2: DeviceTensor<CudaDevice<0>> = DeviceTensor::randn(&[2, 3]);

In the implementation of randn, we construct the tensor directly on the device given by the type parameter, using tch-rs.

impl<D: Device> DeviceTensor<D> {
    pub fn randn(size: &[i64]) -> DeviceTensor<D> {
        DeviceTensor {
            internal: tch::Tensor::randn(size, (Kind::Float, D::tch_device())),
            _device_marker: Default::default(),
        }
    }
}

We will now define an implementation for the + operator so that we can add two tensors.

impl<'a, 'b, D: Device> std::ops::Add<&'b DeviceTensor<D>> for &'a DeviceTensor<D> {
    type Output = DeviceTensor<D>;

    fn add(self, other: &'b DeviceTensor<D>) -> DeviceTensor<D> {
        DeviceTensor {
            internal: &self.internal + &other.internal,
            _device_marker: Default::default(),
        }
    }
}

Note how this simply adds the two tch-rs Tensors. So where does the compile time check happen? Well, it's quite simple: The addition operator is only defined for two DeviceTensor instances with the same type parameter D. Try to add two instances with a different D, and you'll get a compile time error.

We can now add two tensors that live on the same device, but as soon as the devices (and therefore the types) differ, we get a compilation error.

fn will_compile() {
    let tensor_1: DeviceTensor<CpuDevice> = DeviceTensor::randn(&[2, 3]);
    let tensor_2: DeviceTensor<CpuDevice> = DeviceTensor::randn(&[2, 3]);

    let result = &tensor_1 + &tensor_2;
}

fn wont_compile() {
    let tensor_1: DeviceTensor<CpuDevice> = DeviceTensor::randn(&[2, 3]);
    let tensor_2: DeviceTensor<CudaDevice<0>> = DeviceTensor::randn(&[2, 3]);

    let result = &tensor_1 + &tensor_2; // Error!
}

The static checking of the device does not mean that the device needs to be static. Moving a tensor to a different device is easy to implement. We simply consume the current tensor and construct a new tensor with a different type, but the same internal data. The data is moved with tch-rs's to_device function to the device given by the target device type D2.

// (In the implementation of DeviceTensor)
pub fn to_device<D2: Device>(self) -> DeviceTensor<D2> {
    DeviceTensor {
        internal: self.internal.to_device(D2::tch_device()),
        _device_marker: Default::default(),
    }
}

This allows us to do things like the following:

fn main() {
    let tensor_1: DeviceTensor<CpuDevice> = DeviceTensor::randn(&[2, 3]);
    let tensor_2: DeviceTensor<CudaDevice<0>> = DeviceTensor::randn(&[2, 3]);
    let tensor_2_cpu = tensor_2.to_device::<CpuDevice>();

    let result = &tensor_1 + &tensor_2_cpu; // Compiles and runs without error!
}

Note how to_device takes self by value. This means that ownership of the tensor tensor_2 is passed to the function. At the end of the function, self goes out of scope and is dropped, so tensor_2 can no longer be used after calling to_device on it. Perfect, because its type parameter doesn't match the device anymore. We get this extra guarantee for our DeviceTensor model for free thanks to Rust's ownership model.
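
For illustration (wont_compile_either is just a made-up name, in the style of the examples above), trying to use tensor_2 after the move is rejected by the compiler:

fn wont_compile_either() {
    let tensor_2: DeviceTensor<CudaDevice<0>> = DeviceTensor::randn(&[2, 3]);
    let tensor_2_cpu = tensor_2.to_device::<CpuDevice>(); // `tensor_2` is moved here

    let result = &tensor_2 + &tensor_2; // error[E0382]: borrow of moved value: `tensor_2`
}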

Conclusion

We've seen an example use case for const generics. We've avoided a small bug in PyTorch code using const generics, Rust's type system, and its ownership model.

Are these compile-time guarantees compelling enough that you should drop Python for deep learning and move straight to Rust? No. The above code is far more verbose than any PyTorch code you will ever write to do the same thing. Deep learning research is highly iterative, and Python makes iterating far easier than Rust does. Python is and will remain the primary language for deep learning for quite some time.

To quote Are We Learning Yet,

It's ripe for experimentation, but the ecosystem isn't very complete yet.

Your main PyTorch code will likely see little to no performance benefit from using Rust instead of Python (2), and experimentation in Python is simply far faster and easier. The ecosystem is built around Python.

But let's say this piqued your interest in Rust. Where could you fit Rust into your deep learning pipeline? In your pre-processing code, for example. I could also see Rust being used to deploy machine learning models, something that is currently often done in C++.

The complete Rust code for this blog post is available here. You can discuss this blog post on Reddit.


(1) .to(device) is easy to forget when you're developing on a device without CUDA, and then deploying to a GPU server.

(2) When you're writing deep learning code, most Python is just glue between libraries implemented in C(++), FORTRAN, CUDA, or some other low level language.