jax2torch
Use JAX functions in PyTorch via DLPack, as outlined in a gist by @mattjj. This repository was created to make this differentiable alignment work interoperable with PyTorch projects.
Install
$ pip install jax2torch
Usage
import jax
import torch
from jax2torch import jax2torch


# A JAX function, compiled with jax.jit
@jax.jit
def jax_pow(x, y=2):
    """Return x raised elementwise to the power y."""
    return x ** y


# Convert it into a function that takes and returns Torch tensors
torch_pow = jax2torch(jax_pow)

# Run it on Torch data!
x = torch.tensor([1., 2., 3.])
out = torch_pow(x, y=3)
print(out)  # tensor([ 1.,  8., 27.])

# And differentiate through it: d/dx sum(x ** 3) = 3 * x ** 2
x = torch.tensor([2., 3.], requires_grad=True)
loss = torch.sum(torch_pow(x, y=3))
loss.backward()
print(x.grad)  # tensor([12., 27.])