Extending JAX with custom C++ and CUDA code

This repository is meant as a tutorial demonstrating the infrastructure required to provide custom ops in JAX when you have an existing implementation in C++ and, optionally, CUDA. I originally wanted to write this as a blog post, but there's enough boilerplate code that I ended up deciding that it made more sense to just share it as a repo with the tutorial in the README, so here we are!

The motivation for this is that in my work I want to use libraries like JAX to fit models to data in astrophysics. In these models, there is often at least one part of the model specification that is physically motivated, and while there are generally existing implementations of these model elements, it is often inefficient or impractical to re-implement them as high-level JAX functions. Instead, I want to expose a well-tested and optimized implementation in C directly to JAX. In my work, this often includes things like iterative algorithms or special functions that are not well suited to implementation using JAX directly.

So, as part of updating my exoplanet library to interface with JAX, I had to learn what infrastructure was required to support this use case, and since I couldn't find a tutorial that covered all the pieces that I needed in one place, I wanted to put this together. Pretty much everything that I'll talk about is covered in more detail somewhere else (even if that somewhere is just a comment in some source code), but hopefully this summary can point you in the right direction if you have a use case like this.

A warning: I'm writing this in January 2021 and much of what I'm talking about is based on essentially undocumented APIs that are likely to change. Furthermore, I'm not affiliated with the JAX project and I'm far from an expert so I'm sure there are wrong things that I say. I'll try to update this if I notice things changing or if I learn of issues, but no promises! So, MIT license and all that: use at your own risk.

Related reading

As I mentioned previously, this tutorial is built on a lot of existing literature and I won't reproduce all the details of those documents here, so I wanted to start by listing the key resources that I found useful:

  1. The How primitives work tutorial in the JAX documentation includes almost all the details about how to expose a custom op to JAX and spending some quality time with that tutorial is not wasted time. The only thing missing from that document is a description of how to use the XLA CustomCall interface.

  2. Which brings us to the XLA custom calls documentation. This page is pretty telegraphic, but it includes a description of the interface that your custom call functions need to support. In particular, this is where the differences in interface between the CPU and GPU are described, including things like the "opaque" parameter and how multiple outputs are handled.

  3. I originally learned how to write the pybind11 interface for an XLA custom call from the danieljtait/jax_xla_adventures repository by Dan Tait on GitHub. Again, this doesn't include very many details, but that's actually a benefit here because it distills the infrastructure to a point where I could understand what was going on.

  4. Finally, much of what I know about this topic, I learned from spelunking in the jaxlib source code on GitHub. That code is pretty readable and includes good comments most of the time so that's a good place to look if you get stuck since folks there might have already faced the issue.

What is an "op"

In frameworks like JAX (or Theano, or TensorFlow, or PyTorch, to name a few), models are defined as a collection of operations or "ops" that can be chained, fused, or differentiated in clever ways. For our purposes, an op defines a function that knows:

  1. how the input and output parameter shapes and types are related,
  2. how to compute the output from a set of inputs, and
  3. how to propagate derivatives using the chain rule.
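To make these three responsibilities concrete, here is a toy, framework-agnostic sketch of an elementwise sine op. The SinOp type and its member names are hypothetical and not part of JAX or XLA. In JAX, roughly the same roles are played by the abstract evaluation rule (shapes and types), the lowering or translation rule (the computation), and the JVP rule (derivatives).

```cpp
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical illustration only: none of these names come from JAX or XLA.
struct SinOp {
  // 1. Shape/type rule: an elementwise op returns as many elements as it receives.
  static std::size_t output_size(std::size_t input_size) { return input_size; }

  // 2. Forward computation: y[i] = sin(x[i]).
  static std::vector<double> compute(const std::vector<double>& x) {
    std::vector<double> y(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) y[i] = std::sin(x[i]);
    return y;
  }

  // 3. Derivative propagation (a Jacobian-vector product): for a tangent dx,
  //    the chain rule gives dy[i] = cos(x[i]) * dx[i].
  static std::vector<double> jvp(const std::vector<double>& x,
                                 const std::vector<double>& dx) {
    std::vector<double> dy(x.size());
    for (std::size_t i = 0; i < x.size(); ++i) dy[i] = std::cos(x[i]) * dx[i];
    return dy;
  }
};
```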

There are a lot of choices about where you draw the lines around a single op, and there will be tradeoffs in terms of performance, generality, ease of use, and other factors when making these decisions. In my experience, it is often best to define ops with minimal scope and then allow your framework of choice to combine them efficiently with the rest of your model, but there will always be counterexamples.

Our example application: solving Kepler's equation

In this section I'll describe the application presented in this project. Feel free to skip this if you just want to get to the technical details.

This project exposes a single jit-able and differentiable JAX operation to solve Kepler's equation, a tool that is used for computing gravitational orbits in astronomy. This is basically the "hello world" example that I use whenever learning about something like this. For example, I have previously written about how to expose such an op when using Stan. The implementations used in that post and here are not meant to be the most robust or efficient, but they are relatively simple and they expose some of the interesting issues that one might face when writing custom JAX ops. If you're interested in the mathematical details, take a look at my blog post, but the key point for now is that this operation involves solving a transcendental equation, and in this tutorial we'll use a simple iterative method that you'll find in the kepler.h header file. The derivatives of this operation can then be evaluated using implicit differentiation. Unlike in the previously mentioned blog post, our operation will actually return the sine and cosine of the eccentric anomaly, since that's what most high-performance versions of this function would return and because the way XLA handles ops with multiple outputs is a little funky.
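For readers who want something concrete, here is a minimal sketch of one simple iterative solver for Kepler's equation, M = E - e sin E, assuming 0 <= e < 1. It is an illustration, not the code in kepler.h: plain Newton-Raphson started from E0 = pi. Once M is wrapped into [0, 2 pi), this converges monotonically because f(E) = E - e sin E - M is increasing, convex on [0, pi], and concave on [pi, 2 pi]. The name solve_kepler_newton is hypothetical; its signature mirrors compute_eccentric_anomaly.

```cpp
#include <cmath>
#include <limits>

// Hypothetical Newton-Raphson solver for M = E - e*sin(E), assuming 0 <= ecc < 1.
// Like compute_eccentric_anomaly, it returns sin(E) and cos(E) through pointers.
template <typename T>
void solve_kepler_newton(const T& mean_anom, const T& ecc, T* sin_ecc_anom,
                         T* cos_ecc_anom) {
  const T pi = std::acos(T(-1));

  // Wrap M into [0, 2*pi). E shifts by the same multiple of 2*pi, which leaves
  // sin(E) and cos(E) unchanged.
  T M = std::fmod(mean_anom, T(2) * pi);
  if (M < T(0)) M += T(2) * pi;

  // Newton iterations on f(E) = E - e*sin(E) - M, with f'(E) = 1 - e*cos(E) > 0.
  T E = pi;
  const T tol = T(8) * std::numeric_limits<T>::epsilon();
  for (int iter = 0; iter < 50; ++iter) {
    const T delta = (E - ecc * std::sin(E) - M) / (T(1) - ecc * std::cos(E));
    E -= delta;
    if (std::fabs(delta) < tol) break;
  }

  *sin_ecc_anom = std::sin(E);
  *cos_ecc_anom = std::cos(E);
}
```

Production solvers typically use better starting guesses and fewer transcendental calls. The point here is only the structure: an iterative loop with no closed form, which is exactly what is awkward to express efficiently in high-level JAX.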

The cost/benefit analysis

One important question to answer first is: "should I actually write a custom JAX extension?" If you're here, you've probably already thought about that, but I wanted to emphasize a few points to consider.

  1. Performance: The main reason why you might want to implement a custom op for JAX is performance. JAX's JIT compiler can get great performance in a broad range of applications, but for some of the problems I work on, finely-tuned C++ can be much faster. In my experience, iterative algorithms, other special functions, or code with complicated logic are all examples of places where a custom op might greatly improve performance. I'm not always good at doing this, but it's probably worth benchmarking performance of a version of your code implemented directly in high-level JAX against your custom op.

  2. Autodiff: One thing that is important to realize is that the extension that we write won't magically know how to propagate derivatives. Instead, we'll be required to provide a JAX interface for applying the chain rule to our op. In other words, if you're setting out to wrap that huge Fortran library that has been passed down through the generations, the payoff might not be as great as you hoped unless (a) the code already provides operations for propagating derivatives (in which case your JAX op probably won't support second and higher order differentiation), or (b) you can easily compute the differentiation rules using the algorithm that you already have (which is the case for our example application here). In my work, I try (sometimes unsuccessfully) to identify the minimum number and size of ops that I can get away with and then implement most of my models directly in JAX. In our demo application, for example, I could have chosen to make an XLA op generating a full radial velocity model, instead of just solving Kepler's equation, and that might (or might not) give better performance. But the differentiation rules are much simpler the way it is implemented.

Summary of the relevant files

The files in this repo come in three categories:

  1. In the root directory, there are the standard packaging files like a setup.py and pyproject.toml. Most of this setup is pretty standard, but I'll highlight some of the unique elements in the packaging section below. For example, we'll use a slightly strange combination of PEP-517/518 and CMake to build the extensions. This isn't strictly necessary, but it's the easiest packaging setup that I've been able to put together.

  2. Next, the src/kepler_jax directory is a Python module with the definition of our JAX primitive roughly following the JAX How primitives work tutorial.

  3. Finally, the C++ and CUDA code implementing our XLA op lives in the lib directory. The pybind11_kernel_helpers.h and kernel_helpers.h headers are boilerplate necessary for building the interface. The rest of the files include the code specific to this implementation, which I'll describe in more detail below.

Defining an XLA custom call on the CPU

The algorithm for our example problem is implemented in the lib/kepler.h header. I won't go into the details of the algorithm here; the main point is that this could be an implementation built on any external library that you can call from C++ and, if you want to support GPU usage, CUDA. That header file includes a single function compute_eccentric_anomaly with the following signature:

template <typename T>
void compute_eccentric_anomaly(
   const T& mean_anom, const T& ecc, T* sin_ecc_anom, T* cos_ecc_anom
);

This is the function that we want to expose to JAX.

As described in the XLA documentation, the signature for a CPU XLA custom call in C++ is:

void custom_call(void* out, const void** in);

where, as you might expect, the elements of in point to the input values. In our case, the inputs are an integer giving the problem size (the number of elements), an array of mean anomalies mean_anom, and an array of eccentricities ecc. Therefore, we might parse the input as follows:

#include <cstdint>  // int64_t

template <typename T>
void cpu_kepler(void *out, const void **in) {
  const std::int64_t size = *reinterpret_cast<const std::int64_t *>(in[0]);
  const T *mean_anom = reinterpret_cast<const T *>(in[1]);
  const T *ecc = reinterpret_cast<const T *>(in[2]);
}

Here we have used a template so that we can support both single and double precision versions of the op.

The output parameter is somewhat more complicated. If your op only has one output, you would access it using

T *result = reinterpret_cast<T *>(out);

but when you have multiple outputs, things get a little hairy. In our example, we have two outputs: the sine sin_ecc_anom and the cosine cos_ecc_anom of the eccentric anomaly. In this case, our out parameter (even though it looks like a void*) is actually a void** pointing to an array with one buffer per output! So we access the outputs as follows:

template <typename T>
void cpu_kepler(void *out_tuple, const void **in) {
  // ...
  void **out = reinterpret_cast<void **>(out_tuple);
  T *sin_ecc_anom = reinterpret_cast<T *>(out[0]);
  T *cos_ecc_anom = reinterpret_cast<T *>(out[1]);
}

Finally, we apply the op to each element. The full implementation, which you can find in lib/cpu_ops.cc, is:

// lib/cpu_ops.cc
#include <cstdint>

#include "kepler.h"  // compute_eccentric_anomaly

template <typename T>
void cpu_kepler(void *out_tuple, const void **in) {
  // Inputs: the number of elements, then the two input arrays
  const std::int64_t size = *reinterpret_cast<const std::int64_t *>(in[0]);
  const T *mean_anom = reinterpret_cast<const T *>(in[1]);
  const T *ecc = reinterpret_cast<const T *>(in[2]);

  // Outputs: out_tuple points to an array with one buffer per output
  void **out = reinterpret_cast<void **>(out_tuple);
  T *sin_ecc_anom = reinterpret_cast<T *>(out[0]);
  T *cos_ecc_anom = reinterpret_cast<T *>(out[1]);

  for (std::int64_t n = 0; n < size; ++n) {
    compute_eccentric_anomaly(mean_anom[n], ecc[n], sin_ecc_anom + n, cos_ecc_anom + n);
  }
}

and that's it!
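To see this calling convention end to end without JAX, here is a small host-only harness. The call_like_xla helper is hypothetical: it builds the `in` array and the array of output pointers in the layout described above and passes them to cpu_kepler. So that the sketch compiles on its own, compute_eccentric_anomaly is replaced by a stub that ignores the eccentricity (correct only for e = 0, where E = M).

```cpp
#include <cmath>
#include <cstdint>
#include <vector>

// Stub standing in for compute_eccentric_anomaly from lib/kepler.h. It ignores
// the eccentricity, so it is only correct for ecc == 0, where E equals M.
template <typename T>
void compute_eccentric_anomaly(const T& mean_anom, const T& ecc, T* sin_ecc_anom,
                               T* cos_ecc_anom) {
  (void)ecc;
  *sin_ecc_anom = std::sin(mean_anom);
  *cos_ecc_anom = std::cos(mean_anom);
}

// The CPU custom call from above.
template <typename T>
void cpu_kepler(void *out_tuple, const void **in) {
  const std::int64_t size = *reinterpret_cast<const std::int64_t *>(in[0]);
  const T *mean_anom = reinterpret_cast<const T *>(in[1]);
  const T *ecc = reinterpret_cast<const T *>(in[2]);

  void **out = reinterpret_cast<void **>(out_tuple);
  T *sin_ecc_anom = reinterpret_cast<T *>(out[0]);
  T *cos_ecc_anom = reinterpret_cast<T *>(out[1]);

  for (std::int64_t n = 0; n < size; ++n) {
    compute_eccentric_anomaly(mean_anom[n], ecc[n], sin_ecc_anom + n, cos_ecc_anom + n);
  }
}

// Hypothetical harness mimicking the CPU convention: `in` holds one pointer per
// operand (the size first), and, because the op has two outputs, the `out`
// argument actually points at an array of two output buffers.
template <typename T>
void call_like_xla(const std::vector<T>& mean_anom, const std::vector<T>& ecc,
                   std::vector<T>& sin_out, std::vector<T>& cos_out) {
  const std::int64_t size = static_cast<std::int64_t>(mean_anom.size());
  sin_out.resize(mean_anom.size());
  cos_out.resize(mean_anom.size());
  const void *in[3] = {&size, mean_anom.data(), ecc.data()};
  void *out[2] = {sin_out.data(), cos_out.data()};
  cpu_kepler<T>(out, in);
}
```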

Building & packaging for the CPU

Now that we have an implementation of our XLA custom call target, we need to expose it to JAX. This is done by compiling a CPython module that wraps this function as a PyCapsule type. This can be done using pybind11, Cython, SWIG, or the Python C API directly, but for this example we'll use pybind11 since that's what I'm most familiar with. The LAPACK ops in jaxlib are implemented using Cython if you'd like to see an example of how to do that.

Another choice that I've made is to use CMake to build the extensions. It would be totally possible (and perhaps preferable if you only support CPU usage) to stick to just using setuptools directly, but setuptools doesn't seem to have great support for compiling CUDA extensions so that's why I settled on CMake. In the end, it's not too painful since CMake can be included as a build dependency in pyproject.toml so users won't have to install it separately. Another build option would be to use bazel to compile the code, like the JAX project, but I don't have any experience with it so I decided to stick with what I know. The key point is that we're just compiling a regular old Python module so you can use whatever infrastructure you're familiar with!

With these choices out of the way, the boilerplate code required to define the interface, using the cpu_kepler function from the previous section, is:

// lib/cpu_ops.cc
#include <pybind11/pybind11.h>

// If you're looking for it, this function is actually implemented in
// lib/pybind11_kernel_helpers.h
template <typename T>
pybind11::capsule EncapsulateFunction(T* fn) {
  return pybind11::capsule((void*)fn, "xla._CUSTOM_CALL_TARGET");
}

pybind11::dict Registrations() {
  pybind11::dict dict;
  dict["cpu_kepler_f32"] = EncapsulateFunction(cpu_kepler<float>);
  dict["cpu_kepler_f64"] = EncapsulateFunction(cpu_kepler<double>);
  return dict;
}

PYBIND11_MODULE(cpu_ops, m) { m.def("registrations", &Registrations); }

In this case, we're exporting a separate function for both single and double precision. Another option would be to pass the data type to the function and perform the dispatch logic directly in C++, but I find it cleaner to do it like this.
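For comparison, the other option could look like the following sketch. It assumes, hypothetically, that the translation rule passes a small integer dtype code as an extra first operand. DtypeCode, kepler_typed, and cpu_kepler_any are illustrative names, and kepler_typed is a stub that ignores the eccentricity. Every instantiation still has to be compiled, so this moves the branching from Python into C++ rather than removing it.

```cpp
#include <cmath>
#include <cstdint>
#include <stdexcept>

// Hypothetical dtype codes: our own convention, which the Python translation
// rule would have to pass as an extra first operand. XLA defines nothing like this.
enum class DtypeCode : std::int32_t { kFloat32 = 0, kFloat64 = 1 };

// Typed body with the same operand layout as cpu_kepler<T>. For brevity it is a
// stub that ignores the eccentricity (in[2]) and returns sin/cos of the mean anomaly.
template <typename T>
void kepler_typed(void *out_tuple, const void **in) {
  const std::int64_t size = *reinterpret_cast<const std::int64_t *>(in[0]);
  const T *mean_anom = reinterpret_cast<const T *>(in[1]);
  void **out = reinterpret_cast<void **>(out_tuple);
  T *sin_out = reinterpret_cast<T *>(out[0]);
  T *cos_out = reinterpret_cast<T *>(out[1]);
  for (std::int64_t n = 0; n < size; ++n) {
    sin_out[n] = std::sin(mean_anom[n]);
    cos_out[n] = std::cos(mean_anom[n]);
  }
}

// Single untyped entry point: read the dtype code from in[0] and forward the
// remaining operands, unchanged, to the matching instantiation.
void cpu_kepler_any(void *out_tuple, const void **in) {
  const std::int32_t code = *reinterpret_cast<const std::int32_t *>(in[0]);
  switch (static_cast<DtypeCode>(code)) {
    case DtypeCode::kFloat32:
      return kepler_typed<float>(out_tuple, in + 1);
    case DtypeCode::kFloat64:
      return kepler_typed<double>(out_tuple, in + 1);
  }
  throw std::runtime_error("Unsupported dtype code");
}
```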

With that out of the way, the actual build routine is defined in the following files:

  • In ./pyproject.toml, we specify that pybind11 and cmake are required build dependencies and that we'll use setuptools.build_meta as the build backend.

  • setup.py is a pretty typical setup file with a custom class for building the extensions that executes CMake for the actual compilation step. This does include some extra configuration arguments for CMake to make sure that it uses the correct Python libraries and installs the compiled objects to the right place. It might be possible to use something like scikit-build to replace this step, but I struggled to get it working.

  • Finally, CMakeLists.txt defines the build process for CMake using pybind11's support for CMake builds. This will also, optionally, build the GPU ops as discussed below.

With these files in place, we can now compile our XLA custom call ops using

pip install .

The final thing that I wanted to reiterate in this section is that kepler_jax.cpu_ops is just a regular old CPython extension module, so anything that you already know about packaging C extensions or any other resources that you can find on that topic can be applied. This wasn't obvious when I first started learning about this so I definitely went down some rabbit holes that hopefully you can avoid.

Exposing this op as a JAX primitive

The main components that are required to now call our custom op from JAX are well covered by the How primitives work tutorial, so I won't reproduce all of that here. Instead I'll summarize the key points and then provide the missing part. If you haven't already, you should definitely read that tutorial before getting started on this part.

In summary, we will define a jax.core.Primitive object with an "abstract evaluation" rule (see src/kepler_jax/kepler_jax.py for all the details) following the primitives tutorial. Then, we'll add a "translation rule" and a "JVP rule". We're lucky in this case, and we don't need to add a "transpose rule". JAX can actually work that out automatically, since our primitive is not itself used in the calculation of the output tangents. This won't always be true, and the How primitives work tutorial includes an example of what to do in that case.

Before defining these rules, we need to register the custom call target with JAX. To do that, we import our compiled cpu_ops extension module from above and use the registrations dictionary that we defined:

from jax.lib import xla_client
from kepler_jax import cpu_ops

for _name, _value in cpu_ops.registrations().items():
    xla_client.register_cpu_custom_call_target(_name, _value)

Then, the translation rule is defined roughly as follows (the one you'll find in the source code is a little more complicated since it supports both CPU and GPU translation):

# src/kepler_jax/kepler_jax.py
import numpy as np
from jax.interpreters import xla
from jax.lib import xla_client

def _kepler_cpu_translation(c, mean_anom, ecc):
    # The inputs have "shapes" that provide both the shape and the dtype
    mean_anom_shape = c.get_shape(mean_anom)
    ecc_shape = c.get_shape(ecc)

    # Extract the dtype and shape
    dtype = mean_anom_shape.element_type()
    dims = mean_anom_shape.dimensions()
    assert ecc_shape.element_type() == dtype
    assert ecc_shape.dimensions() == dims

    # The total size of the input is the product across dimensions
    size = np.prod(dims).astype(np.int64)

    # The inputs and outputs all have the same shape so let's predefine this
    # specification
    shape = xla_client.Shape.array_shape(
        np.dtype(dtype), dims, tuple(range(len(dims) - 1, -1, -1))
    )

    # We dispatch a different call depending on the dtype
    if dtype == np.float32:
        op_name = b"cpu_kepler_f32"
    elif dtype == np.float64:
        op_name = b"cpu_kepler_f64"
    else:
        raise NotImplementedError(f"Unsupported dtype {dtype}")

    # On the CPU, we pass the size of the data as a the first input
    # argument
    return xla_client.ops.CustomCallWithLayout(
        c,
        op_name,
        # The inputs:
        operands=(xla_client.ops.ConstantLiteral(c, size), mean_anom, ecc),
        # The input shapes:
        operand_shapes_with_layout=(
            xla_client.Shape.array_shape(np.dtype(np.int64), (), ()),
            shape,
            shape,
        ),
        # The output shapes:
        shape_with_layout=xla_client.Shape.tuple_shape((shape, shape)),
    )

# _kepler_prim is the jax.core.Primitive defined earlier in kepler_jax.py
xla.backend_specific_translations["cpu"][_kepler_prim] = _kepler_cpu_translation

There appears to be a lot going on here, but most of it is just type checking. The meat of it is the CustomCallWithLayout function which, as far as I can tell, isn't documented anywhere. Here's my best summary of its arguments:

  • The first argument is the XLA builder that you were passed when your translation rule was called.

  • The second argument is the name (as bytes!) that you gave your PyCapsule in the registrations dictionary in lib/cpu_ops.cc. You can check what names your capsules had by looking at cpu_ops.registrations().keys().

  • Then, the following arguments give the input arguments and the "shapes" of the input and output arrays. In this context, a "shape" is specified by a data type, a tuple defining the size of each dimension (what I would normally call the shape), and a tuple defining the dimension order (the layout). In this case, we're requiring that all of our array inputs and outputs have the same "shape"; only the scalar size argument differs.

It's worth remembering that we're expecting the first argument to our function to be the size of the arrays, and you'll see that it is included as a ConstantLiteral parameter (explicitly cast to int64).
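The size and layout arguments can be made concrete with a little arithmetic. This hypothetical C++ sketch computes the element count that the translation rule passes as the first operand, the layout tuple (n-1, ..., 1, 0) that it builds, and the strides that this layout implies. XLA layouts list dimensions from most minor (fastest varying) to most major, so that tuple describes a dense row-major (C-order) array. This is why the CPU kernel can walk each buffer as a flat array of `size` elements.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Number of elements: the product of the dimensions (1 for a scalar), matching
// np.prod(dims) in the translation rule.
std::int64_t flat_size(const std::vector<std::int64_t>& dims) {
  std::int64_t size = 1;
  for (std::int64_t d : dims) size *= d;
  return size;
}

// The layout built in the translation rule, tuple(range(rank - 1, -1, -1)),
// listed minor-to-major.
std::vector<std::int64_t> row_major_minor_to_major(std::size_t rank) {
  std::vector<std::int64_t> layout(rank);
  for (std::size_t i = 0; i < rank; ++i) {
    layout[i] = static_cast<std::int64_t>(rank - 1 - i);
  }
  return layout;
}

// Element strides implied by a minor-to-major layout: the most minor axis has
// stride 1, and each subsequent axis strides over everything more minor.
std::vector<std::int64_t> strides_from_layout(
    const std::vector<std::int64_t>& dims,
    const std::vector<std::int64_t>& minor_to_major) {
  std::vector<std::int64_t> strides(dims.size());
  std::int64_t stride = 1;
  for (std::int64_t axis : minor_to_major) {
    strides[static_cast<std::size_t>(axis)] = stride;
    stride *= dims[static_cast<std::size_t>(axis)];
  }
  return strides;
}
```

For a (2, 3, 4) array, this gives 24 elements, layout (2, 1, 0), and strides (12, 4, 1), the usual C-order strides.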

I'm not going to talk about the JVP rule here since it's quite problem specific, but I've tried to comment the code reasonably thoroughly so check out the code in src/kepler_jax/kepler_jax.py if you're interested, and open an issue if anything isn't clear.
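For reference, the derivatives behind such a rule follow from implicitly differentiating Kepler's equation. Writing s = sin E and c = cos E, a sketch of the tangent relations (a derivation, not a transcription of kepler_jax.py) is:

```latex
\begin{aligned}
M &= E - e \sin E \\
\dot{M} &= (1 - e \cos E)\,\dot{E} - \sin E\,\dot{e} \\
\dot{E} &= \frac{\dot{M} + s\,\dot{e}}{1 - e\,c}, \qquad s = \sin E,\; c = \cos E \\
\dot{s} &= c\,\dot{E}, \qquad \dot{c} = -s\,\dot{E}
\end{aligned}
```

The tangents only need s, c, and e, so a JVP rule can be built from the op's own outputs without recovering E.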

Defining an XLA custom call on the GPU

The custom call on the GPU isn't terribly different from the CPU version above, but the syntax is somewhat different and there's a heck of a lot more boilerplate required. Since we need to compile and link CUDA code, there are also a few more packaging steps, but we'll get to that in the next section. The description in this section is a little all over the place, but the key files to look at to get more info are (a) lib/gpu_ops.cc for the dispatch functions called from Python, and (b) lib/kernels.cc.cu for the CUDA code implementing the kernel.

The signature for the GPU custom call is:

// lib/kernels.cc.cu
template <typename T>
void gpu_kepler(
  cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len
);

The first parameter is a CUDA stream, which I won't talk about at all because I don't really know very much about GPU programming and we don't need to worry about it for now. Then you'll notice that the inputs and outputs are all provided in a single void** buffers array, with the inputs first followed by the outputs, so our access code from above is replaced by:

// lib/kernels.cc.cu
template <typename T>
void gpu_kepler(
  cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len
) {
  const T *mean_anom = reinterpret_cast<const T *>(buffers[0]);
  const T *ecc = reinterpret_cast<const T *>(buffers[1]);
  T *sin_ecc_anom = reinterpret_cast<T *>(buffers[2]);
  T *cos_ecc_anom = reinterpret_cast<T *>(buffers[3]);
}
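The indexing rule can be spelled out with two tiny accessors. These are hypothetical helpers, not part of XLA: `buffers` holds one pointer per operand, in order, followed by one pointer per output. On a real GPU these are device pointers that the host must not dereference. The sketch only does pointer bookkeeping, and the check below uses host arrays.

```cpp
#include <cstddef>

// Hypothetical helper: the i-th operand of a GPU custom call.
template <typename T>
const T* input_buffer(void** buffers, std::size_t i) {
  return reinterpret_cast<const T*>(buffers[i]);
}

// Hypothetical helper: the j-th output, which comes after all num_inputs operands.
template <typename T>
T* output_buffer(void** buffers, std::size_t num_inputs, std::size_t j) {
  return reinterpret_cast<T*>(buffers[num_inputs + j]);
}
```

With this convention, gpu_kepler's buffers[2] and buffers[3] are output_buffer(buffers, 2, 0) and output_buffer(buffers, 2, 1).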

where you might notice that the size parameter is no longer one of the inputs. Instead, the array size is passed using the opaque parameter, since its value is required both on the CPU (to choose the kernel launch configuration, and the input buffers live in GPU memory where the CPU can't read them directly) and within the GPU kernel (see the XLA custom calls documentation for more details). To use this opaque parameter, we will define a type to hold the size:

// lib/kernels.h
#include <cstdint>

struct KeplerDescriptor {
  std::int64_t size;
};

and then the following boilerplate to serialize it:

// lib/kernel_helpers.h
#include <string>

// Note that std::bit_cast is only available as of C++20, so you might need
// to provide a shim like the one in lib/kernel_helpers.h
template <typename T>
std::string PackDescriptorAsString(const T& descriptor) {
  return std::string(bit_cast<const char*>(&descriptor), sizeof(T));
}

// lib/pybind11_kernel_helpers.h
#include <pybind11/pybind11.h>

template <typename T>
pybind11::bytes PackDescriptor(const T& descriptor) {
  return pybind11::bytes(PackDescriptorAsString(descriptor));
}
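If your compiler predates C++20's std::bit_cast, the usual shim copies the bytes with std::memcpy. Below is a sketch of that common pattern under the hypothetical name bit_cast_shim. The shim in lib/kernel_helpers.h may be written differently.

```cpp
#include <cstring>
#include <type_traits>

// Hypothetical pre-C++20 stand-in for std::bit_cast: reinterpret the bytes of
// `src` as a `To`, allowed only when both types have the same size and are
// trivially copyable.
template <class To, class From>
typename std::enable_if<sizeof(To) == sizeof(From) &&
                            std::is_trivially_copyable<From>::value &&
                            std::is_trivially_copyable<To>::value,
                        To>::type
bit_cast_shim(const From& src) noexcept {
  static_assert(std::is_trivially_constructible<To>::value,
                "destination type must be trivially constructible");
  To dst;
  std::memcpy(&dst, &src, sizeof(To));
  return dst;
}
```

PackDescriptorAsString only uses bit_cast to turn a `const T*` into a `const char*`, so a plain reinterpret_cast would also work there. The shim matters more for value conversions.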

This serialization procedure should then be exposed in the Python module using:

// lib/gpu_ops.cc
#include <pybind11/pybind11.h>

PYBIND11_MODULE(gpu_ops, m) {
  // ...
  m.def("build_kepler_descriptor",
        [](std::int64_t size) {
          return PackDescriptor(KeplerDescriptor{size});
        });
}

Then, to deserialize this descriptor, we can use the following procedure:

// lib/kernel_helpers.h
#include <stdexcept>

template <typename T>
const T* UnpackDescriptor(const char* opaque, std::size_t opaque_len) {
  if (opaque_len != sizeof(T)) {
    throw std::runtime_error("Invalid opaque object size");
  }
  return bit_cast<const T*>(opaque);
}

// lib/kernels.cc.cu
template <typename T>
void gpu_kepler(
  cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len
) {
  // ...
  const KeplerDescriptor &d = *UnpackDescriptor<KeplerDescriptor>(opaque, opaque_len);
  const std::int64_t size = d.size;
}
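One caveat with UnpackDescriptor is that it returns a pointer into the opaque bytes, which strictly speaking assumes those bytes are suitably aligned for T. A hypothetical alternative, sketched below, copies the bytes into a local value with std::memcpy and so makes no alignment assumption. Because the descriptor is tiny, the copy costs essentially nothing. The names pack_descriptor and unpack_descriptor_copy are illustrative.

```cpp
#include <cstdint>
#include <cstring>
#include <stdexcept>
#include <string>
#include <type_traits>

struct KeplerDescriptor {
  std::int64_t size;
};

// Serialize a trivially copyable descriptor into raw bytes, like PackDescriptorAsString.
template <typename T>
std::string pack_descriptor(const T& descriptor) {
  static_assert(std::is_trivially_copyable<T>::value,
                "descriptor must be trivially copyable");
  return std::string(reinterpret_cast<const char*>(&descriptor), sizeof(T));
}

// Hypothetical variant of UnpackDescriptor that returns a copy instead of a
// pointer into `opaque`, so the bytes need not be aligned for T.
template <typename T>
T unpack_descriptor_copy(const char* opaque, std::size_t opaque_len) {
  if (opaque_len != sizeof(T)) {
    throw std::runtime_error("Invalid opaque object size");
  }
  T descriptor;
  std::memcpy(&descriptor, opaque, sizeof(T));
  return descriptor;
}
```

In gpu_kepler this would read `const KeplerDescriptor d = unpack_descriptor_copy<KeplerDescriptor>(opaque, opaque_len);`.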

Once we have these parameters, the full procedure for launching the CUDA kernel is:

// lib/kernels.cc.cu
template <typename T>
void gpu_kepler(
  cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len
) {
  const T *mean_anom = reinterpret_cast<const T *>(buffers[0]);
  const T *ecc = reinterpret_cast<const T *>(buffers[1]);
  T *sin_ecc_anom = reinterpret_cast<T *>(buffers[2]);
  T *cos_ecc_anom = reinterpret_cast<T *>(buffers[3]);
  const KeplerDescriptor &d = *UnpackDescriptor<KeplerDescriptor>(opaque, opaque_len);
  const std::int64_t size = d.size;

  // Select block sizes, etc., no promises that these numbers are the right choices
  const int block_dim = 128;
  const int grid_dim = std::min<int>(1024, (size + block_dim - 1) / block_dim);

  // Launch the kernel
  kepler_kernel<T>
      <<<grid_dim, block_dim, 0, stream>>>(size, mean_anom, ecc, sin_ecc_anom, cos_ecc_anom);

  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) {
    throw std::runtime_error(cudaGetErrorString(error));
  }
}

Finally, the kernel itself is relatively simple:

// lib/kernels.cc.cu
template <typename T>
__global__ void kepler_kernel(
  std::int64_t size, const T *mean_anom, const T *ecc, T *sin_ecc_anom, T *cos_ecc_anom
) {
  for (std::int64_t idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size;
       idx += blockDim.x * gridDim.x) {
    compute_eccentric_anomaly(mean_anom[idx], ecc[idx], sin_ecc_anom + idx, cos_ecc_anom + idx);
  }
}
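To convince yourself that the grid-stride loop in kepler_kernel visits every element exactly once, even when the grid is capped at 1024 blocks, you can simulate it on the host. grid_dim_for and visit_counts are hypothetical helpers. This sketch makes two small departures from gpu_kepler above: it computes the block count in 64-bit arithmetic before narrowing, and it clamps an empty input to one block, since CUDA rejects a launch with a zero-sized grid.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Launch configuration: ceil(size / block_dim) blocks, capped at 1024 and
// clamped to at least one block.
int grid_dim_for(std::int64_t size, int block_dim) {
  const std::int64_t blocks = (size + block_dim - 1) / block_dim;
  return static_cast<int>(
      std::max<std::int64_t>(1, std::min<std::int64_t>(1024, blocks)));
}

// Host-side simulation of the grid-stride loop: each (block, thread) pair starts
// at its global index and advances by the total number of threads. Returns how
// many times each element would be processed.
std::vector<int> visit_counts(std::int64_t size, int block_dim, int grid_dim) {
  std::vector<int> counts(static_cast<std::size_t>(size), 0);
  const std::int64_t stride = std::int64_t(block_dim) * grid_dim;
  for (int block = 0; block < grid_dim; ++block) {
    for (int thread = 0; thread < block_dim; ++thread) {
      for (std::int64_t idx = std::int64_t(block) * block_dim + thread; idx < size;
           idx += stride) {
        ++counts[static_cast<std::size_t>(idx)];
      }
    }
  }
  return counts;
}
```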

Building & packaging for the GPU

Since we're already using CMake to build our project, it's not too hard to add support for CUDA. I've chosen to enable GPU builds via the environment variable KEPLER_JAX_CUDA=yes, which you'll see in both setup.py and CMakeLists.txt. Other than conditionally adding an Extension in setup.py, everything else on the Python side is the same. In CMakeLists.txt, we also add a conditional:

if (KEPLER_JAX_CUDA)
  enable_language(CUDA)
  # ...
else()
  message(STATUS "Building without CUDA")
endif()

Then, to expose this to JAX, we need to register the GPU targets and add a GPU version of the translation rule from above:

# src/kepler_jax/kepler_jax.py
import numpy as np
from jax.interpreters import xla
from jax.lib import xla_client
from kepler_jax import gpu_ops

for _name, _value in gpu_ops.registrations().items():
    xla_client.register_custom_call_target(_name, _value, platform="gpu")

def _kepler_gpu_translation(c, mean_anom, ecc):
    # Most of this function is the same as the CPU version above...

    # The name of the op is now prefaced with 'gpu' (our choice, see lib/gpu_ops.cc,
    # not a requirement)
    if dtype == np.float32:
        op_name = b"gpu_kepler_f32"
    elif dtype == np.float64:
        op_name = b"gpu_kepler_f64"
    else:
        raise NotImplementedError(f"Unsupported dtype {dtype}")

    # We need to serialize the array size using a descriptor
    opaque = gpu_ops.build_kepler_descriptor(size)

    # The syntax is *almost* the same as the CPU version, but we need to pass the
    # size using 'opaque' rather than as an input
    return xla_client.ops.CustomCallWithLayout(
        c,
        op_name,
        operands=(mean_anom, ecc),
        operand_shapes_with_layout=(shape, shape),
        shape_with_layout=xla_client.Shape.tuple_shape((shape, shape)),
        opaque=opaque,
    )

xla.backend_specific_translations["gpu"][_kepler_prim] = _kepler_gpu_translation

Everything else from our CPU implementation stays the same.

Testing

As usual, you should always test your code and this repo includes some unit tests in the tests directory for inspiration. You can also see an example of how to run these tests using the GitHub Actions CI service and the workflow in .github/workflows/tests.yml. I don't know of any public CI servers that provide GPU support, but I do include a test to confirm that the GPU ops can be compiled. You can see the infrastructure for that test in the .github/action directory.
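As one example of what such a test might check (an assumption, not a description of the tests in this repository), the outputs can be validated without a reference implementation. Recover E from the returned sine and cosine with atan2, measure how well Kepler's equation holds modulo 2 pi, and confirm that the pair lies on the unit circle. kepler_residual and unit_circle_error are hypothetical names.

```cpp
#include <cmath>

// Residual of M = E - e*sin(E), compared modulo 2*pi because atan2 only returns
// E in (-pi, pi] while the mean anomaly can be any real number.
template <typename T>
T kepler_residual(T mean_anom, T ecc, T sin_ecc_anom, T cos_ecc_anom) {
  const T pi = std::acos(T(-1));
  const T E = std::atan2(sin_ecc_anom, cos_ecc_anom);
  T r = std::fmod(E - ecc * sin_ecc_anom - mean_anom, T(2) * pi);
  if (r > pi) r -= T(2) * pi;
  if (r < -pi) r += T(2) * pi;
  return std::fabs(r);
}

// The two outputs should satisfy sin^2 + cos^2 = 1.
template <typename T>
T unit_circle_error(T sin_ecc_anom, T cos_ecc_anom) {
  return std::fabs(sin_ecc_anom * sin_ecc_anom + cos_ecc_anom * cos_ecc_anom - T(1));
}
```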

See this in action

To demo the use of this custom op, I put together a notebook, based on an example from the exoplanet docs. You can see this notebook in the demo.ipynb file in the root of this repository or open it on Google Colab.

References

Comments
  • A Simple Setup

    A Simple Setup

    Thank you for the excellent tutorial. You have reduced my dev time by x1000 by providing this excellent minimum working example.

    I wanted to offer a simpler setup.py. I am writing a module, and I want the module to talk with both Pytorch and Jax. To this end, I didn't want to use a CMake file since my Pytorch module currently doesn't need one. Instead, I am using the torch setup tools, described here.

    I think your CMake file logic is good still since it shows how to do it (I had no idea how). This version just requires less code.

    opened by gauenk 7
  • How would I include cuDNN?

    How would I include cuDNN?

    First of all, thank you very much for this. You have saved me tons of work and I am very grateful for the great documentation.

    I would like to extend JAX with custom calls that internally make use of cudnn. For this I added an include at the top of "kernels.cc.cu". I tried both of the following:

    #include <cudnn.h>
    #include "/usr/include/cudnn.h"
    

    The compiler finds the header and does not complain when I add the following host code:

      cudnnHandle_t handle_;
      cudnnCreate(&handle_);
    

    However as soon as I try to run the code from JAX, I get the error that cudnnCreate is an undefined symbol. If I remove the includes then the compiler complains.

    Do you have any idea how I could potentially fix this?

    opened by helange23 2
  • How to reinterpret_cast a matrix?

    How to reinterpret_cast a matrix?

    Dear @dfm , your tutorial is excellent. But I am not familar with c++. I have a naive question.

    You said that how to receive input values by:

    #include <cstdint> // int64_t

    template <typename T> void cpu_kepler(void *out, const void **in) { const std::int64_t size = *reinterpret_cast<const std::int64_t *>(in[0]); const T *mean_anom = reinterpret_cast<const T *>(in[1]); const T *ecc = reinterpret_cast<const T *>(in[2]); }

    However, if one of my input values is a matrix, how can I reinterpret_cast it?

    Thanks.

    opened by chaoming0625 2
  • XLA register translation rule fail

    XLA register translation rule fail

    Hey, i try to add custom call and define the xla translation rule follow this doc https://jax.readthedocs.io/en/latest/notebooks/How_JAX_primitives_work.html#xla-compilation-rules

    However, it miss the custom call part. And i try to implement this part follow your example code.

    You use functools.partial(translation, platform="cpu) here, however, i got func_wrapper() got an unexpected keyword argument 'platform' https://github.com/dfm/extending-jax/blob/main/src/kepler_jax/kepler_jax.py#L190-L193

    Could you please give some suggestions?

    P.S. I use Cython to implement the C++ XLA custom call function. And the only remaining part is register the xla translation rule.

    opened by llCurious 1
  • nit on the need for transpose rules

    nit on the need for transpose rules

    I love the tutorial, thanks for putting together this great resource!

    You write:

    We're lucky in this case, and we don't need to add a "transpose rule", since JAX can actually work that out by itself (our JVP is linear in the tangents).

    Every well behaved JVP is linear in the tangents, by definition -- the tangents are the "vector" part of "Jacobian-vector product."

    What's special about this primitive (and which means that we don't need the transpose rule) is that it is non-linear. That means it that it can't appear in the tangent calculation (because again, output tangents are always a linear function of input tangents, per the chain rule), and only things that appear in the tangent calculations needs to be transposed to calculate cotangents (VJPs).

    So the key question is actually whether a primitive is a linear function of one or more of its arguments. If so, then yes, you need a transpose rule to support reverse mode autodiff.

    opened by shoyer 1
  • Minor update to the new MLIR specification for custom ops

    Minor update to the new MLIR specification for custom ops

    Following #7 , this small PR proposes to update this post to the new way that JAX uses to define custom ops by directly writing their specification using MLIR.

    I'm pretty noob in MLIR but it seems very cool, and I guess at some point XLA will deprecate the old way of custom call building. Here is the doc for the custom call op in MLIR: https://www.tensorflow.org/mlir/hlo_ops#mhlocustom_call_mlirmhlocustomcallop

    All the custom ops in jaxlib are now built this way.

    Let me know if this looks good. I haven't run the checks on GPU.

    opened by EiffL
  • Interested in help updating these instructions to the new style of XLA translation rules?

    Hey @dfm :-) This week I did a bit of hacking, read a lot of XLA source code, and played with the new MLIR approach for specifying translation rules for custom ops in JAX. My understanding is that since April, all built-in LAX primitives have been transferred to MLIR equivalents, and that the old-style CustomCallWithLayout remains only for backward compatibility. Here is an example of what the custom calls look like in current JAX: https://github.com/google/jax/blob/f697b8e0876f8e1144a53ace02ee6d7eaa43fa14/jaxlib/gpu_solver.py#L66

    Before the knowledge of how to make these things work leaves my short-term memory, would you be interested in something like a PR to this post? If you prefer these posts to stay static, no worries: I can write down that info elsewhere, linking to your post for extended context ;-)

    opened by EiffL
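To make the transpose-rule point from the "nit on the need for transpose rules" thread concrete, here is a minimal sketch. It uses jax.custom_jvp on a toy sine function (a stand-in chosen for illustration, not the tutorial's Kepler primitive and not a hand-written Primitive), assuming the same reasoning carries over. The function itself is nonlinear, but its JVP is linear in the tangent t, so jax.grad works without any transpose rule for the function: JAX only has to transpose the linear tangent computation, which here is a multiplication. The last lines use jax.linear_transpose on a function that is linear in its input, which is the situation where a custom primitive would need its own transpose rule.

```python
import jax
import jax.numpy as jnp

# Toy nonlinear function with a hand-written JVP (a stand-in for a custom
# primitive such as a Kepler solver; this is an illustrative assumption).
@jax.custom_jvp
def my_sin(x):
    return jnp.sin(x)

@my_sin.defjvp
def my_sin_jvp(primals, tangents):
    (x,), (t,) = primals, tangents
    # The output tangent is linear in the input tangent `t`: it is just `t`
    # scaled by a coefficient computed from the primal input `x`.
    return my_sin(x), jnp.cos(x) * t

# Reverse mode works with no transpose rule for my_sin itself: JAX only
# transposes the linear part (`jnp.cos(x) * t`), and multiplication by a
# value that does not depend on `t` already has a transpose rule.
grad_val = jax.grad(my_sin)(0.5)

# Forward mode, for comparison; both should equal cos(0.5).
_, jvp_val = jax.jvp(my_sin, (0.5,), (1.0,))

# A function that IS linear in its argument is what gets transposed.
# jax.linear_transpose does that explicitly; the primal only fixes shape/dtype.
scale_t = jax.linear_transpose(lambda t: 3.0 * t, jnp.float32(1.0))
(cotangent,) = scale_t(jnp.float32(1.0))

print(grad_val, jvp_val, cotangent)
```

If my_sin were instead a primitive that is linear in one of its inputs, that primitive would show up inside the tangent computation of other functions, and reverse mode would fail until a transpose rule was registered for it.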
Owner
Dan Foreman-Mackey
Astrophysics. Good. Code.