Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more

Overview

JAX: Autograd and XLA

Quickstart | Transformations | Install guide | Neural net libraries | Change logs | Reference docs | Code search

News: JAX tops largest-scale MLPerf Training 0.7 benchmarks!

What is JAX?

JAX is Autograd and XLA, brought together for high-performance machine learning research.

With its updated version of Autograd, JAX can automatically differentiate native Python and NumPy functions. It can differentiate through loops, branches, recursion, and closures, and it can take derivatives of derivatives of derivatives. It supports reverse-mode differentiation (a.k.a. backpropagation) via grad as well as forward-mode differentiation, and the two can be composed arbitrarily to any order.

What’s new is that JAX uses XLA to compile and run your NumPy programs on GPUs and TPUs. Compilation happens under the hood by default, with library calls getting just-in-time compiled and executed. But JAX also lets you just-in-time compile your own Python functions into XLA-optimized kernels using a one-function API, jit. Compilation and automatic differentiation can be composed arbitrarily, so you can express sophisticated algorithms and get maximal performance without leaving Python. You can even program multiple GPUs or TPU cores at once using pmap, and differentiate through the whole thing.

Dig a little deeper, and you'll see that JAX is really an extensible system for composable function transformations. Both grad and jit are instances of such transformations. Others are vmap for automatic vectorization and pmap for single-program multiple-data (SPMD) parallel programming of multiple accelerators, with more to come.

This is a research project, not an official Google product. Expect bugs and sharp edges. Please help by trying it out, reporting bugs, and letting us know what you think!

import jax.numpy as jnp
from jax import grad, jit, vmap

def predict(params, inputs):
  for W, b in params:
    outputs = jnp.dot(inputs, W) + b
    inputs = jnp.tanh(outputs)
  return outputs

def logprob_fun(params, inputs, targets):
  preds = predict(params, inputs)
  return jnp.sum((preds - targets)**2)

grad_fun = jit(grad(logprob_fun))  # compiled gradient evaluation function
perex_grads = jit(vmap(grad_fun, in_axes=(None, 0, 0)))  # fast per-example grads
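
For concreteness, here is a minimal sketch of how these transformed functions might be called; the layer sizes, random parameters, and batch below are made up purely for illustration.

import numpy as np

layer_sizes = [3, 4, 2]                       # hypothetical toy network
params = [(np.random.randn(m, n), np.random.randn(n))
          for m, n in zip(layer_sizes[:-1], layer_sizes[1:])]
inputs = np.random.randn(8, 3)                # hypothetical batch of 8 examples
targets = np.random.randn(8, 2)

batch_grads = grad_fun(params, inputs, targets)      # gradients summed over the batch
per_example = perex_grads(params, inputs, targets)   # one gradient per example, stacked along axis 0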

Quickstart: Colab in the Cloud

Jump right in using a notebook in your browser, connected to a Google Cloud GPU. Here are some starter notebooks:

JAX now runs on Cloud TPUs. To try out the preview, see the Cloud TPU Colabs.

For a deeper dive into JAX:

You can also take a look at the mini-libraries in jax.experimental, like stax for building neural networks and optimizers for first-order stochastic optimization, or the examples.
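
For example, here is a minimal sketch of the stax idiom (the layer sizes and MNIST-style input shape are just assumptions for illustration): stax.serial returns a pair of functions, one to initialize parameters and one to apply the network.

from jax import random
from jax.experimental import stax
from jax.experimental.stax import Dense, Relu, LogSoftmax

init_fun, apply_fun = stax.serial(Dense(512), Relu, Dense(10), LogSoftmax)
rng = random.PRNGKey(0)
out_shape, net_params = init_fun(rng, (-1, 784))   # -1 stands in for the batch dimension
# logits = apply_fun(net_params, batch_of_images)  # batch_of_images is hypothetical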

Transformations

At its core, JAX is an extensible system for transforming numerical functions. Here are four transformations of primary interest: grad, jit, vmap, and pmap.

Automatic differentiation with grad

JAX has roughly the same API as Autograd. The most popular function is grad for reverse-mode gradients:

from jax import grad
import jax.numpy as jnp

def tanh(x):  # Define a function
  y = jnp.exp(-2.0 * x)
  return (1.0 - y) / (1.0 + y)

grad_tanh = grad(tanh)  # Obtain its gradient function
print(grad_tanh(1.0))   # Evaluate it at x = 1.0
# prints 0.4199743

You can differentiate to any order with grad.

print(grad(grad(grad(tanh)))(1.0))
# prints 0.62162673

For more advanced autodiff, you can use jax.vjp for reverse-mode vector-Jacobian products and jax.jvp for forward-mode Jacobian-vector products. The two can be composed arbitrarily with one another, and with other JAX transformations. Here's one way to compose those to make a function that efficiently computes full Hessian matrices:

from jax import jit, jacfwd, jacrev

def hessian(fun):
  return jit(jacfwd(jacrev(fun)))
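
As a quick sanity check of this composition, here is a hedged example on a toy quadratic (the function and test point are just illustrative choices):

import jax.numpy as jnp

def quadratic(x):
  return jnp.dot(x, x)   # f(x) = sum_i x_i**2, so the Hessian should be 2*I

print(hessian(quadratic)(jnp.array([1.0, 2.0, 3.0])))
# expected: a 3x3 matrix equal to twice the identity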

As with Autograd, you're free to use differentiation with Python control structures:

def abs_val(x):
  if x > 0:
    return x
  else:
    return -x

abs_val_grad = grad(abs_val)
print(abs_val_grad(1.0))   # prints 1.0
print(abs_val_grad(-1.0))  # prints -1.0 (abs_val is re-evaluated)

See the reference docs on automatic differentiation and the JAX Autodiff Cookbook for more.

Compilation with jit

You can use XLA to compile your functions end-to-end with jit, used either as an @jit decorator or as a higher-order function.

import jax.numpy as jnp
from jax import jit

def slow_f(x):
  # Element-wise ops see a large benefit from fusion
  return x * x + x * 2.0

x = jnp.ones((5000, 5000))
fast_f = jit(slow_f)
%timeit -n10 -r3 fast_f(x)  # ~ 4.5 ms / loop on Titan X
%timeit -n10 -r3 slow_f(x)  # ~ 14.5 ms / loop (also on GPU via JAX)

You can mix jit and grad and any other JAX transformation however you like.

Using jit puts constraints on the kind of Python control flow the function can use; see the Gotchas Notebook for more.
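
For example, a Python if on a traced value inside a jitted function raises an error at trace time. One common workaround, sketched below with a hypothetical scale function and the assumption that the branching argument can be treated as a compile-time constant, is jit's static_argnums:

from functools import partial
from jax import jit
import jax.numpy as jnp

@partial(jit, static_argnums=(1,))
def scale(x, use_double):
  # `use_double` is static, so ordinary Python control flow on it is fine;
  # jit recompiles once per distinct static value.
  if use_double:
    return 2.0 * x
  return x

print(scale(jnp.arange(3.), True))   # prints [0. 2. 4.]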

Auto-vectorization with vmap

vmap is the vectorizing map. It has the familiar semantics of mapping a function along array axes, but instead of keeping the loop on the outside, it pushes the loop down into a function’s primitive operations for better performance.

Using vmap can save you from having to carry around batch dimensions in your code. For example, consider this simple unbatched neural network prediction function:

def predict(params, input_vec):
  assert input_vec.ndim == 1
  activations = input_vec
  for W, b in params:
    outputs = jnp.dot(W, activations) + b  # `activations` on the right-hand side!
    activations = jnp.tanh(outputs)
  return outputs

We often instead write jnp.dot(inputs, W) to allow for a batch dimension on the left side of inputs, but we’ve written this particular prediction function to apply only to single input vectors. If we wanted to apply this function to a batch of inputs at once, semantically we could just write

from functools import partial
predictions = jnp.stack(list(map(partial(predict, params), input_batch)))

But pushing one example through the network at a time would be slow! It’s better to vectorize the computation, so that at every layer we’re doing matrix-matrix multiplication rather than matrix-vector multiplication.

The vmap function does that transformation for us. That is, if we write

from jax import vmap
predictions = vmap(partial(predict, params))(input_batch)
# or, alternatively
predictions = vmap(predict, in_axes=(None, 0))(params, input_batch)

then the vmap function will push the outer loop inside the function, and our machine will end up executing matrix-matrix multiplications exactly as if we’d done the batching by hand.

It’s easy enough to manually batch a simple neural network without vmap, but in other cases manual vectorization can be impractical or impossible. Take the problem of efficiently computing per-example gradients: that is, for a fixed set of parameters, we want to compute the gradient of our loss function evaluated separately at each example in a batch. With vmap, it’s easy:

per_example_gradients = vmap(partial(grad(loss), params))(inputs, targets)

Of course, vmap can be arbitrarily composed with jit, grad, and any other JAX transformation! We use vmap with both forward- and reverse-mode automatic differentiation for fast Jacobian and Hessian matrix calculations in jax.jacfwd, jax.jacrev, and jax.hessian.
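
As a rough sketch of how that works (not the actual library implementation, just the idea), a reverse-mode Jacobian can be built by vmapping a vjp pullback over the rows of an identity matrix; the toy function f below is made up:

from jax import vjp, vmap
import jax.numpy as jnp

def f(x):
  return jnp.sin(x) * x[0]      # toy function from R^3 to R^3

x = jnp.arange(1., 4.)
y, pullback = vjp(f, x)
jac = vmap(pullback)(jnp.eye(len(y)))[0]   # one VJP per output component, batched with vmap
# `jac` matches jax.jacrev(f)(x)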

SPMD programming with pmap

For parallel programming of multiple accelerators, like multiple GPUs, use pmap. With pmap you write single-program multiple-data (SPMD) programs, including fast parallel collective communication operations. Applying pmap will mean that the function you write is compiled by XLA (similarly to jit), then replicated and executed in parallel across devices.

Here's an example on an 8-GPU machine:

from jax import random, pmap
import jax.numpy as jnp

# Create 8 random 5000 x 6000 matrices, one per GPU
keys = random.split(random.PRNGKey(0), 8)
mats = pmap(lambda key: random.normal(key, (5000, 6000)))(keys)

# Run a local matmul on each device in parallel (no data transfer)
result = pmap(lambda x: jnp.dot(x, x.T))(mats)  # result.shape is (8, 5000, 5000)

# Compute the mean on each device in parallel and print the result
print(pmap(jnp.mean)(result))
# prints [1.1566595 1.1805978 ... 1.2321935 1.2015157]

In addition to expressing pure maps, you can use fast collective communication operations between devices:

from functools import partial
from jax import lax

@partial(pmap, axis_name='i')
def normalize(x):
  return x / lax.psum(x, 'i')

print(normalize(jnp.arange(4.)))
# prints [0.         0.16666667 0.33333334 0.5       ]

You can even nest pmap functions for more sophisticated communication patterns.

It all composes, so you're free to differentiate through parallel computations:

from jax import grad

@pmap
def f(x):
  y = jnp.sin(x)
  @pmap
  def g(z):
    return jnp.cos(z) * jnp.tan(y.sum()) * jnp.tanh(x).sum()
  return grad(lambda w: jnp.sum(g(w)))(x)

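# `x` is assumed here to be an array of shape (4, 2) defined earlier, so the nested pmaps together use all 8 devices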
print(f(x))
# [[ 0.        , -0.7170853 ],
#  [-3.1085174 , -0.4824318 ],
#  [10.366636  , 13.135289  ],
#  [ 0.22163185, -0.52112055]]

print(grad(lambda x: jnp.sum(f(x)))(x))
# [[ -3.2369726,  -1.6356447],
#  [  4.7572474,  11.606951 ],
#  [-98.524414 ,  42.76499  ],
#  [ -1.6007166,  -1.2568436]]

When reverse-mode differentiating a pmap function (e.g. with grad), the backward pass of the computation is parallelized just like the forward pass.

See the SPMD Cookbook and the SPMD MNIST classifier from scratch example for more.

Current gotchas

For a more thorough survey of current gotchas, with examples and explanations, we highly recommend reading the Gotchas Notebook. Some standouts:

  1. JAX transformations only work on pure functions, which don't have side-effects and respect referential transparency (i.e. object identity testing with is isn't preserved). If you use a JAX transformation on an impure Python function, you might see an error like Exception: Can't lift Traced... or Exception: Different traces at same level.
  2. In-place mutating updates of arrays, like x[i] += y, aren't supported, but there are functional alternatives (see the sketch just after this list). Under a jit, those functional alternatives will reuse buffers in-place automatically.
  3. Random numbers are different, but for good reasons.
  4. If you're looking for convolution operators, they're in the jax.lax package.
  5. JAX enforces single-precision (32-bit, e.g. float32) values by default, and to enable double-precision (64-bit, e.g. float64) one needs to set the jax_enable_x64 variable at startup (or set the environment variable JAX_ENABLE_X64=True).
  6. Some of NumPy's dtype promotion semantics involving a mix of Python scalars and NumPy types aren't preserved, namely np.add(1, np.array([2], np.float32)).dtype is float64 rather than float32.
  7. Some transformations, like jit, constrain how you can use Python control flow. You'll always get loud errors if something goes wrong. You might have to use jit's static_argnums parameter, structured control flow primitives like lax.scan, or just use jit on smaller subfunctions.
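
As promised in item 2, here is a minimal sketch of the functional update pattern using the .at indexed-update syntax (the array, index, and bump function are arbitrary examples):

import jax.numpy as jnp
from jax import jit

x = jnp.zeros(5)
y = x.at[2].add(10.0)        # functional alternative to `x[2] += 10.0`; x itself is unchanged

@jit
def bump(x):
  return x.at[2].add(10.0)   # under jit this update can reuse the buffer in place

print(bump(x))  # prints [ 0.  0. 10.  0.  0.]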

Installation

JAX is written in pure Python, but it depends on XLA, which needs to be installed as the jaxlib package. Use the following instructions to install a binary package with pip, or to build JAX from source.

We support installing or building jaxlib on Linux (Ubuntu 16.04 or later) and macOS (10.12 or later) platforms. Windows users can use JAX on CPU and GPU via the Windows Subsystem for Linux. There is some initial native Windows support, but since it is still somewhat immature, there are no binary releases and it must be built from source.

pip installation

To install a CPU-only version, which might be useful for doing local development on a laptop, you can run

pip install --upgrade pip
pip install --upgrade jax jaxlib  # CPU-only version

On Linux, it is often necessary to first update pip to a version that supports manylinux2010 wheels.

If you want to install JAX with both CPU and NVidia GPU support, you must first install CUDA and CuDNN, if they have not already been installed. Unlike some other popular deep learning systems, JAX does not bundle CUDA or CuDNN as part of the pip package. The CUDA 10 JAX wheels require CuDNN 7, whereas the CUDA 11 wheels of JAX require CuDNN 8. Other combinations of CUDA and CuDNN are possible but require building from source.

Next, run

pip install --upgrade pip
pip install --upgrade jax jaxlib==0.1.60+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

The jaxlib version must correspond to the version of the existing CUDA installation you want to use, with cuda111 for CUDA 11.1, cuda110 for CUDA 11.0, cuda102 for CUDA 10.2, and cuda101 for CUDA 10.1. You can find your CUDA version with the command:

nvcc --version

Note that some GPU functionality expects the CUDA installation to be at /usr/local/cuda-X.X, where X.X should be replaced with the CUDA version number (e.g. cuda-10.2). If CUDA is installed elsewhere on your system, you can either create a symlink:

sudo ln -s /path/to/cuda /usr/local/cuda-X.X

Alternatively, you can set the following environment variable before importing JAX:

XLA_FLAGS=--xla_gpu_cuda_data_dir=/path/to/cuda
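
For example, one way to set this from Python (as also done in one of the issues further down this page) is to export it before jax is imported; the path below is a placeholder:

import os
os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/path/to/cuda"  # must be set before importing jax

import jax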

Please let us know on the issue tracker if you run into any errors or problems with the prebuilt wheels.

Building JAX from source

See Building JAX from source.

Neural network libraries

Multiple Google research groups develop and share libraries for training neural networks in JAX. If you want a fully featured library for neural network training with examples and how-to guides, try Flax. Another option is Trax, a combinator-based framework focused on ease-of-use and end-to-end single-command examples, especially for sequence models and reinforcement learning. Finally, Objax is a minimalist object-oriented framework with a PyTorch-like interface.

DeepMind has open-sourced an ecosystem of libraries around JAX including Haiku for neural network modules, Optax for gradient processing and optimization, RLax for RL algorithms, and chex for reliable code and testing.

Citing JAX

To cite this repository:

@software{jax2018github,
  author = {James Bradbury and Roy Frostig and Peter Hawkins and Matthew James Johnson and Chris Leary and Dougal Maclaurin and George Necula and Adam Paszke and Jake Vander{P}las and Skye Wanderman-{M}ilne and Qiao Zhang},
  title = {{JAX}: composable transformations of {P}ython+{N}um{P}y programs},
  url = {http://github.com/google/jax},
  version = {0.2.5},
  year = {2018},
}

In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from jax/version.py, and the year corresponds to the project's open-source release.

A nascent version of JAX, supporting only automatic differentiation and compilation to XLA, was described in a paper that appeared at SysML 2018. We're currently working on covering JAX's ideas and capabilities in a more comprehensive and up-to-date paper.

Reference documentation

For details about the JAX API, see the reference documentation.

For getting started as a JAX developer, see the developer documentation.

Comments
  • Provide wheels for macOS ARM

    Hi all,

    I was digging around to see what might need to happen to allow JAX to work on Apple Silicon. Knowing that JAX gets compiled to XLA, my guess here is that XLA would need to be made Apple Silicon-compatible first before JAX could run on it. May I ask, do you all know if there are plans on the XLA team to make that happen, or is it being ignored completely? (Knowing the answer can help me make some decisions on how I should set up my development environment mostly.)

    Cheers, Eric

    enhancement contributions welcome open P2 (eventual) 
    opened by ericmjl 130
  • conda-based installation

    Putting this here and tagging myself @ericmjl so that I can remember this exists.

    To get jax into the hands of data scientists and machine learning researchers, conda installation would be very useful. I will take a stab at this on conda-forge, and record my progress here.

    enhancement build contributions welcome P2 (eventual) NVIDIA GPU 
    opened by ericmjl 71
  • Failure to build jaxlib v0.1.62 on Windows (Updated)

    The original thread #5981 was closed due to a proposed fix. However, even with this in place, using the same build parameters, it still fails to build, but much further into the build process.

         _   _  __  __
        | | / \ \ \/ /
     _  | |/ _ \ \  /
    | |_| / ___ \/  \
     \___/_/   \/_/\_\
    
    
    Starting local Bazel server and connecting to it...
    Bazel binary path: C:\bazel\bazel.EXE
    Python binary path: C:/Users/Adam/anaconda3/python.exe
    Python version: 3.8
    MKL-DNN enabled: yes
    Target CPU features: release
    CUDA enabled: yes
    CUDA toolkit path: C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0
    CUDNN library path: C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0
    CUDA compute capabilities: 7.5
    CUDA version: 11.0
    CUDNN version: 8.0.5
    TPU enabled: no
    ROCm enabled: no
    
    Building XLA and installing it in the jaxlib source tree...
    C:\bazel\bazel.EXE run --verbose_failures=true --config=short_logs --config=mkl_open_source_only --config=cuda --define=xla_python_enable_gpu=true :build_wheel -- --output_path=C:\sdks\jax-jaxlib-v0.1.62\dist
    INFO: Options provided by the client:
      Inherited 'common' options: --isatty=1 --terminal_columns=80
    INFO: Reading rc options for 'run' from c:\sdks\jax-jaxlib-v0.1.62\.bazelrc:
      Inherited 'common' options: --experimental_repo_remote_exec
    INFO: Options provided by the client:
      Inherited 'build' options: --python_path=C:/Users/Adam/anaconda3/python.exe
    INFO: Reading rc options for 'run' from c:\sdks\jax-jaxlib-v0.1.62\.bazelrc:
      Inherited 'build' options: --repo_env PYTHON_BIN_PATH=C:/Users/Adam/anaconda3/python.exe --action_env=PYENV_ROOT --python_path=C:/Users/Adam/anaconda3/python.exe --repo_env TF_NEED_CUDA=1 --action_env TF_CUDA_COMPUTE_CAPABILITIES=7.5 --repo_env TF_NEED_ROCM=0 --action_env TF_ROCM_AMDGPU_TARGETS=gfx803,gfx900,gfx906,gfx1010 --distinct_host_configuration=false -c opt --apple_platform_type=macos --macos_minimum_os=10.9 --announce_rc --define open_source_build=true --define=no_kafka_support=true --define=no_ignite_support=true --define=grpc_no_ares=true --spawn_strategy=standalone --strategy=Genrule=standalone --enable_platform_specific_config --action_env CUDA_TOOLKIT_PATH=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0 --action_env CUDNN_INSTALL_PATH=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0 --action_env TF_CUDA_VERSION=11.0 --action_env TF_CUDNN_VERSION=8.0.5
    INFO: Found applicable config definition build:short_logs in file c:\sdks\jax-jaxlib-v0.1.62\.bazelrc: --output_filter=DONT_MATCH_ANYTHING
    INFO: Found applicable config definition build:mkl_open_source_only in file c:\sdks\jax-jaxlib-v0.1.62\.bazelrc: --define=tensorflow_mkldnn_contraction_kernel=1
    INFO: Found applicable config definition build:cuda in file c:\sdks\jax-jaxlib-v0.1.62\.bazelrc: --crosstool_top=@local_config_cuda//crosstool:toolchain --@local_config_cuda//:enable_cuda
    INFO: Found applicable config definition build:windows in file c:\sdks\jax-jaxlib-v0.1.62\.bazelrc: --copt=/D_USE_MATH_DEFINES --host_copt=/D_USE_MATH_DEFINES --copt=-DWIN32_LEAN_AND_MEAN --host_copt=-DWIN32_LEAN_AND_MEAN --copt=-DNOGDI --host_copt=-DNOGDI --copt=/Zc:preprocessor --cxxopt=/std:c++14 --host_cxxopt=/std:c++14 --linkopt=/DEBUG --host_linkopt=/DEBUG --linkopt=/OPT:REF --host_linkopt=/OPT:REF --linkopt=/OPT:ICF --host_linkopt=/OPT:ICF --experimental_strict_action_env=true
    DEBUG: Rule 'io_bazel_rules_docker' indicated that a canonical reproducible form can be obtained by modifying arguments shallow_since = "1556410077 -0400"
    DEBUG: Repository io_bazel_rules_docker instantiated at:
      C:/sdks/jax-jaxlib-v0.1.62/WORKSPACE:34:10: in <toplevel>
      C:/users/adam/_bazel_adam/nzquhzn2/external/org_tensorflow/tensorflow/workspace0.bzl:105:34: in workspace
      C:/users/adam/_bazel_adam/nzquhzn2/external/bazel_toolchains/repositories/repositories.bzl:37:23: in repositories
    Repository rule git_repository defined at:
      C:/users/adam/_bazel_adam/nzquhzn2/external/bazel_tools/tools/build_defs/repo/git.bzl:199:33: in <toplevel>
    INFO: Analyzed target //build:build_wheel (182 packages loaded, 15809 targets configured).
    INFO: Found 1 target...
    ERROR: C:/users/adam/_bazel_adam/nzquhzn2/external/org_tensorflow/tensorflow/core/tpu/BUILD:105:11: C++ compilation of rule '@org_tensorflow//tensorflow/core/tpu:tpu_initializer_helper' failed (Exit 2): python.exe failed: error executing command
      cd C:/users/adam/_bazel_adam/nzquhzn2/execroot/__main__
      SET CUDA_TOOLKIT_PATH=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0
        SET CUDNN_INSTALL_PATH=C:/Program Files/NVIDIA GPU Computing Toolkit/CUDA/v11.0
        SET INCLUDE=C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.28.29333\ATLMFC\include;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.28.29333\include;C:\Program Files (x86)\Windows Kits\NETFXSDK\4.8\include\um;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\ucrt;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\shared;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\um;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\winrt;C:\Program Files (x86)\Windows Kits\10\include\10.0.18362.0\cppwinrt
        SET LIB=C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.28.29333\ATLMFC\lib\x64;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.28.29333\lib\x64;C:\Program Files (x86)\Windows Kits\NETFXSDK\4.8\lib\um\x64;C:\Program Files (x86)\Windows Kits\10\lib\10.0.18362.0\ucrt\x64;C:\Program Files (x86)\Windows Kits\10\lib\10.0.18362.0\um\x64
        SET PATH=C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\\Extensions\Microsoft\IntelliCode\CLI;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\VC\Tools\MSVC\14.28.29333\bin\HostX64\x64;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\VC\VCPackages;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\CommonExtensions\Microsoft\TestWindow;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\CommonExtensions\Microsoft\TeamFoundation\Team Explorer;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\MSBuild\Current\bin\Roslyn;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Team Tools\Performance Tools\x64;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Team Tools\Performance Tools;C:\Program Files (x86)\Microsoft Visual Studio\Shared\Common\VSPerfCollectionTools\vs2019\\x64;C:\Program Files (x86)\Microsoft Visual Studio\Shared\Common\VSPerfCollectionTools\vs2019\;C:\Program Files (x86)\Microsoft SDKs\Windows\v10.0A\bin\NETFX 4.8 Tools\x64\;C:\Program Files (x86)\Windows Kits\10\bin\10.0.18362.0\x64;C:\Program Files (x86)\Windows Kits\10\bin\x64;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\\MSBuild\Current\Bin;C:\Windows\Microsoft.NET\Framework64\v4.0.30319;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\Tools\;;C:\WINDOWS\system32;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\CommonExtensions\Microsoft\CMake\CMake\bin;C:\Program Files (x86)\Microsoft Visual Studio\2019\Enterprise\Common7\IDE\CommonExtensions\Microsoft\CMake\Ninja
        SET PWD=/proc/self/cwd
        SET RUNFILES_MANIFEST_ONLY=1
        SET TEMP=C:\Users\Adam\AppData\Local\Temp
        SET TF_CUDA_COMPUTE_CAPABILITIES=7.5
        SET TF_CUDA_VERSION=11.0
        SET TF_CUDNN_VERSION=8.0.5
        SET TF_ROCM_AMDGPU_TARGETS=gfx803,gfx900,gfx906,gfx1010
        SET TMP=C:\Users\Adam\AppData\Local\Temp
      C:/Users/Adam/anaconda3/python.exe -B external/local_config_cuda/crosstool/windows/msvc_wrapper_for_nvcc.py /nologo /DCOMPILER_MSVC /DNOMINMAX /D_WIN32_WINNT=0x0600 /D_CRT_SECURE_NO_DEPRECATE /D_CRT_SECURE_NO_WARNINGS /D_SILENCE_STDEXT_HASH_DEPRECATION_WARNINGS /bigobj /Zm500 /J /Gy /GF /EHsc /wd4351 /wd4291 /wd4250 /wd4996 /Iexternal/org_tensorflow /Ibazel-out/x64_windows-opt/bin/external/org_tensorflow /Iexternal/eigen_archive /Ibazel-out/x64_windows-opt/bin/external/eigen_archive /Iexternal/com_google_absl /Ibazel-out/x64_windows-opt/bin/external/com_google_absl /Iexternal/nsync /Ibazel-out/x64_windows-opt/bin/external/nsync /Iexternal/eigen_archive /Ibazel-out/x64_windows-opt/bin/external/eigen_archive /Iexternal/nsync/public /Ibazel-out/x64_windows-opt/bin/external/nsync/public /DEIGEN_MPL2_ONLY /DEIGEN_MAX_ALIGN_BYTES=64 /D__CLANG_SUPPORT_DYN_ANNOTATION__ /showIncludes /MD /O2 /DNDEBUG /D_USE_MATH_DEFINES -DWIN32_LEAN_AND_MEAN -DNOGDI /Zc:preprocessor /std:c++14 /Fobazel-out/x64_windows-opt/bin/external/org_tensorflow/tensorflow/core/tpu/_objs/tpu_initializer_helper/tpu_initializer_helper.obj /c external/org_tensorflow/tensorflow/core/tpu/tpu_initializer_helper.cc
    Execution platform: @local_execution_config_platform//:platform
    Target //build:build_wheel failed to build
    INFO: Elapsed time: 1450.560s, Critical Path: 375.03s
    INFO: 5002 processes: 2127 internal, 2875 local.
    FAILED: Build did NOT complete successfully
    FAILED: Build did NOT complete successfully
    Traceback (most recent call last):
      File ".\build\build.py", line 521, in <module>
        main()
      File ".\build\build.py", line 516, in main
        shell(command)
      File ".\build\build.py", line 51, in shell
        output = subprocess.check_output(cmd)
      File "C:\Users\Adam\anaconda3\lib\subprocess.py", line 411, in check_output
        return run(*popenargs, stdout=PIPE, timeout=timeout, check=True,
      File "C:\Users\Adam\anaconda3\lib\subprocess.py", line 512, in run
        raise CalledProcessError(retcode, process.args,
    subprocess.CalledProcessError: Command '['C:\\bazel\\bazel.EXE', 'run', '--verbose_failures=true', '--config=short_logs', '--config=mkl_open_source_only', '--config=cuda', '--define=xla_python_enable_gpu=true', ':build_wheel', '--', '--output_path=C:\\sdks\\jax-jaxlib-v0.1.62\\dist']' returned non-zero exit status 1.
    

    Originally posted by @oracle3001 in https://github.com/google/jax/issues/5981#issuecomment-793318112

    opened by adam-hartshorne 41
  • Gmres qr

    A synthesis of @gehring's and @shoyer's GMRES implementation #4025 with the one in TensorNetwork.

    Per those authors, this implementation supports both preconditioning and pytrees, but I've modified it to also use a more efficient QR implementation of the main Arnoldi loop that builds the QR decomposition of the Krylov matrix instead of the matrix itself. This greatly reduces the overhead of the final linear solve, since the relevant system is now triangular. It also allows the iteration to terminate early if convergence is reached midway through the Arnoldi iterations.

    The behaviour of the maxiter and restart arguments has also been modified (I think) to mirror that in SciPy. Finally, a few bugs have been fixed such that the implementation passes some simple tests, but I haven't tested very extensively.

    @gehring @shoyer, it would be great if you had a look!

    Relevant issues: #1531

    cla: yes 
    opened by alewis 36
  • [jax2tf] NotImplementedError: Call to scatter add cannot be converted with enable_xla=False

    Conversions for XlaScatter are currently unsupported when using enable_xla=False. I'm wondering if support could be added?

    Here's the full error that I'm seeing:

    NotImplementedError                       Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/jax/experimental/jax2tf/impl_no_xla.py in op(*arg, **kwargs)
         52 
         53   def op(*arg, **kwargs):
    ---> 54     raise _xla_disabled_error(name)
         55 
         56   return op
    
    NotImplementedError: Call to scatter add cannot be converted with enable_xla=False.
    

    Here is some code that reproduces this error:

    !pip install --upgrade flax
    !pip install git+https://github.com/josephrocca/transformers.git@patch-2
    
    import jax
    from jax.experimental import jax2tf
    from jax import numpy as jnp
    
    import numpy as np
    import tensorflow as tf
    
    from transformers import FlaxCLIPModel
    
    clip = FlaxCLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    
    def score(pixel_values, input_ids, attention_mask):
        pixel_values = jax.image.resize(pixel_values, (3, 224, 224), "nearest")
        inputs = {"pixel_values":jnp.array([pixel_values]), "input_ids":input_ids, "attention_mask":attention_mask}
        outputs = clip(**inputs)
        return outputs.logits_per_image[0][0][0]
    
    score_tf = jax2tf.convert(jax.grad(score), enable_xla=False)
    
    my_model = tf.Module()
    my_model.f = tf.function(score_tf, autograph=False, jit_compile=True, input_signature=[
      tf.TensorSpec([3, 40, 40], tf.float32),
      tf.TensorSpec([1, 30], tf.int32),
      tf.TensorSpec([1, 30], tf.int32),
    ])
    
    model_name = 'pixel_text_score_grad'
    tf.saved_model.save(my_model, model_name, options=tf.saved_model.SaveOptions(experimental_custom_gradients=True))
    

    Here's a public colab with that code: https://colab.research.google.com/drive/1HjpRXsa8Ue9KWiKbVVWUX6DlXoIYx2r8?usp=sharing You can click "Runtime > Run all" to see the error.

    Thanks!

    enhancement 
    opened by josephrocca 35
  • cannot find libdevice

    Hi

    Jax cannot find libdevice. I'm running Python 3.7 with CUDA 10.0 on my personal laptop with a GeForce RTX 2080. I installed jax using pip.

    I made a little test script shown below

    import os
    os.environ["XLA_FLAGS"]="--xla_gpu_cuda_data_dir=/home/murphyk/miniconda3/lib"
    os.environ["CUDA_HOME"]="/usr"
    
    
    import jax
    import jax.numpy as np
    print("jax version {}".format(jax.__version__))
    from jax.lib import xla_bridge
    print("jax backend {}".format(xla_bridge.get_backend().platform))
    
    
    from jax import random
    key = random.PRNGKey(0)
    x = random.normal(key, (5,5))
    print(x)
    

    The output is shown below.

    jax version 0.1.39
    jax backend gpu
    2019-07-07 16:44:03.905071: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/llvm_gpu_backend/nvptx_backend_lib.cc:105] Unknown compute capability (7, 5) .Defaulting to libdevice for compute_20
    Traceback (most recent call last):
    
      File "<ipython-input-15-e39e42274024>", line 1, in <module>
        runfile('/home/murphyk/github/pyprobml/scripts/jax_debug.py', wdir='/home/murphyk/github/pyprobml/scripts')
    
      File "/home/murphyk/miniconda3/lib/python3.7/site-packages/spyder_kernels/customize/spydercustomize.py", line 827, in runfile
        execfile(filename, namespace)
    
      File "/home/murphyk/miniconda3/lib/python3.7/site-packages/spyder_kernels/customize/spydercustomize.py", line 110, in execfile
        exec(compile(f.read(), filename, 'exec'), namespace)
    
      File "/home/murphyk/github/pyprobml/scripts/jax_debug.py", line 18, in <module>
        x = random.normal(key, (5,5))
    
      File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/random.py", line 389, in normal
        return _normal(key, shape, dtype)
    
      File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/api.py", line 123, in f_jitted
        out = xla.xla_call(flat_fun, *args_flat, device_values=device_values)
    
      File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/core.py", line 663, in call_bind
        ans = primitive.impl(f, *args, **params)
    
      File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/interpreters/xla.py", line 606, in xla_call_impl
        compiled_fun = xla_callable(fun, device_values, *map(abstractify, args))
    
      File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/linear_util.py", line 208, in memoized_fun
        ans = call(f, *args)
    
      File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/interpreters/xla.py", line 621, in xla_callable
        compiled, result_shape = compile_jaxpr(jaxpr, consts, *abstract_args)
    
      File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jax/interpreters/xla.py", line 207, in compile_jaxpr
        backend=xb.get_backend()), result_shape
    
      File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jaxlib/xla_client.py", line 535, in Compile
        return backend.compile(self.computation, compile_options)
    
      File "/home/murphyk/miniconda3/lib/python3.7/site-packages/jaxlib/xla_client.py", line 118, in compile
        compile_options.device_assignment)
    
    RuntimeError: Not found: ./libdevice.compute_20.10.bc not found
    
    opened by murphyk 35
  • an experiment in handling instances with __jax_array__

    Before raising an error on an unrecognized type, first check if the object defines a __jax_array__ method. If it does, call it!

    This provides a way for custom types to be auto-converted to JAX-compatible types.

    We didn't add a check for __array__ because that might entail a significant change in behavior. For example, we'd be auto-converting tf.Tensor values. Maybe it's better to remain loud in those cases.

    Implementing this method is not sufficient for a type to be duck-typed enough for use with jax.numpy. But it may be necessary. That is, someone trying to add a duck-typed array to be used with JAX identified a need for __jax_array__ or similar.

    This feature is experimental, so it may disappear, change arbitrarily, or never be documented.

    cla: yes pull ready 
    opened by mattjj 33
  • FFT on CPU noticeably slower than SciPy's FFT

    The FFT implementation in JAX seems to be noticeably slower than the one in SciPy even though both use some flavor of the PocketFFT FFT implementation.

    from jax.config import config
    
    config.update("jax_enable_x64", True)
    
    import timeit
    import numpy as np
    from scipy import fft as sp_fft
    from jax import numpy as jnp
    from jax import jit, random
    
    key = random.PRNGKey(42)
    jr = random.normal(key, (2**26, ))
    r = np.array(jr)
    
    N_IT = 7
    
    timing = timeit.timeit(lambda: np.fft.fft(r), number=N_IT) / N_IT
    print(f"NumPy: {timing} s")
    timing = timeit.timeit(lambda: sp_fft.fft(r), number=N_IT) / N_IT
    print(f"SciPy: {timing} s")
    
    jax_fft = jnp.fft.fft
    timing = timeit.timeit(
        lambda: jax_fft(jr).block_until_ready(), number=N_IT
    ) / N_IT
    print(f"JAX (unjitted): {timing} s")
    
    jax_fft_jit = jit(jax_fft)
    jax_fft_jit(jr)  # Warm-up
    timing = timeit.timeit(
        lambda: jax_fft_jit(jr).block_until_ready(), number=N_IT
    ) / N_IT
    print(f"JAX (jitted): {timing} s")
    

    On an AMD Ryzen 7 4800H (with Radeon Graphics) with JAX 0.2.18 and jaxlib 0.1.69 installed from PyPI, I get the following timings:

    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    NumPy: 3.5420487681424544 s
    SciPy: 1.9076589474290293 s
    JAX (unjitted): 4.5806920641433475 s
    JAX (jitted): 4.63735637514245 s
    

    The timings improve if I compile JAX(lib) myself though it still compares unfavorably to SciPy (compiled JAX (unjitted): 3.506302507857202 s).

    My hypothesis is that JAX is distributing binaries that are suboptimal for recent AMD CPUs and, much more importantly, that JAX is probably using an outdated version of PocketFFT.

    bug 
    opened by Edenhofer 32
  • add L-BFGS optimizer

    As discussed in #1400

    • [x] Type check
    • [x] Rolling Buffer for histories (added TODO for now)
    • [x] someone double check the math (@shoyer @Joshuaalbert ?)
    • [ ] Compare histories to scipy
    cla: yes pull ready 
    opened by Jakob-Unfried 32
  • BFGS algorithm

    Addressing https://github.com/google/jax/issues/1400

    This PR adds jax.scipy.optimize and supplies the jittable BFGS algorithm.

    The inexact line search attempts to satisfy the first Wolfe condition and the strong second Wolfe condition.

    It adds the feature of using JAX to get the exact hessian for initialisation.

    Tests are in the committed files and need to be placed somewhere properly.

    A possible extension is to use custom_root to define the gradient, @shoyer.

    cla: yes 
    opened by Joshuaalbert 32
  • cuda failed to allocate errors

    When running a training script using the new memory allocation backend (https://github.com/google/jax/issues/417), I see a bunch of non-fatal errors like this:

    [1] 2019-05-29 23:55:55.555823: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 528.00M (553648128 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
    [1] 2019-05-29 23:55:55.581962: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc:525] Resource exhausted: Failed to allocate request for 528.00MiB (553648128B) on device ordinal 0
    [7] 2019-05-29 23:55:55.594693: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 528.00M (553648128 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
    [7] 2019-05-29 23:55:55.606314: E external/org_tensorflow/tensorflow/compiler/xla/service/gpu/cudnn_conv_algorithm_picker.cc:525] Resource exhausted: Failed to allocate request for 528.00MiB (553648128B) on device ordinal 0
    [1] 2019-05-29 23:55:55.633261: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 1.14G (1224736768 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
    [1] 2019-05-29 23:55:55.635169: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 1.05G (1132822528 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
    [1] 2019-05-29 23:55:55.646031: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 561.11M (588365824 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
    [1] 2019-05-29 23:55:55.647926: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 592.04M (620793856 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
    [7] 2019-05-29 23:55:55.655470: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:828] failed to allocate 1.14G (1224736768 bytes) from device: CUDA_ERROR_OUT_OF_MEMORY: out of memory
    

    Is this a known issue? The errors go away when using XLA_PYTHON_CLIENT_ALLOCATOR=platform.

    bug 
    opened by christopherhesse 32
  • removing repeated average function implementation and docs rendering

    Closes https://github.com/google/jax/issues/13853

    The function jax.numpy.average is repeated many times on the documentation page jax.numpy.average

    This PR removes the duplicated code that caused the issue. Thanks!

    opened by hirwa-nshuti 0
  • The `jax.numpy.average` function is rendering repeatedly in Docs

    Description

    This is caused by repetition of the same function on the implementation side (the reduction functions). I am going to open a pull request that fixes the mentioned issue.

    What jax/jaxlib version are you using?

    No response

    Which accelerator(s) are you using?

    CPU

    Additional system info

    Linux

    NVIDIA GPU info

    No response

    bug 
    opened by hirwa-nshuti 0
  • [sparse] bcoo_extract behavior with duplicate indices

    bcoo_extract is a function that, given sparse indices, extracts the corresponding values from a dense array. When the indices in a sparse array are unique, the semantics are clear, and it is effectively the inverse of bcoo_todense. For example:

    import jax.numpy as jnp
    from jax.experimental import sparse
    
    # bcoo_extract works as expected with unique indices
    data = jnp.array([1, 3, 5])
    indices = jnp.array([[0], [2], [4]])  # note: unique values
    
    spmat = sparse.BCOO((data, indices), shape=(5,))
    mat = spmat.todense()
    
    data2 = sparse.bcoo_extract(indices, mat)
    spmat2 = sparse.BCOO((data2, indices), shape=(5,))
    
    print(mat)
    # [1 0 3 0 5]
    print(spmat.todense())
    # [1 0 3 0 5]
    print(spmat2.todense()) # <--- matches the above
    # [1 0 3 0 5]
    

    However, when duplicate indices are present, bcoo_extract will effectively extract duplicate copies of any particular value, which may lead to surprising results:

    # bcoo_extract is surprising with duplicate indices
    data = jnp.array([1, 3, 5])
    indices = jnp.array([[2], [2], [4]])  # note: duplicate values
    
    spmat = sparse.BCOO((data, indices), shape=(5,))
    mat = spmat.todense()
    
    data2 = sparse.bcoo_extract(indices, mat)
    spmat2 = sparse.BCOO((data2, indices), shape=(5,))
    
    print(mat)
    # [0 0 4 0 5]
    print(spmat.todense())
    # [0 0 4 0 5]
    print(spmat2.todense()) # <--- different from the above
    # [0 0 8 0 5]
    

    So bcoo_extract is not precisely the inverse of bcoo_todense in these cases.

    We probably should fix this, but the correct semantics are not exactly clear (should bcoo_extract with duplicate indices only extract the first and leave the other entries blank? Should it divide the extracted entry evenly among the duplicate buckets? Should it just fail?)

    A complication here is that the current implementation actually is the correct transpose operation for bcoo_todense, and likewise bcoo_todense is the transpose of the current implementation of bcoo_extract. So if we were to change this behavior, we would need to implement the transpose rule for these primitives more carefully.

    With this in mind, it may be that we should add an optional keyword argument to bcoo_extract to control its behavior in the presence of duplicates, and add a complementary keyword to bcoo_todense that would allow the transpose rules of the modified primitives to map onto each other.

    P3 (no schedule) 
    opened by jakevdp 0
  • [jax2tf] Improves jax2tf (enable_xla=False) model testing logic.

    • Previously we were creating the variables for all models, even if we did not test them. This change ensures we only create them if we actually test the model.
    • It also reports when we aren't testing any models.
    • Ensures we can generate markdown both internally and externally.
    • Ran all tests again and updated the g3doc with the results, which are slightly different now.
    opened by copybara-service[bot] 0
  • `ravel_pytree` now produces jit-compatible unravel functions

    Previously,

    _, unravel1 = ravel_pytree(pytree)
    _, unravel2 = ravel_pytree(pytree)
    
    @partial(jax.jit, static_argnums=0)
    def run(unravel, ...):
        ...
    
    run(unravel1, ...)
    run(unravel2, ...)
    

    would unnecessarily induce recompilation.

    opened by patrick-kidger 0
Releases(jaxlib-v0.4.1)
  • jaxlib-v0.4.1(Dec 13, 2022)

    • Changes
      • Support for Python 3.7 has been dropped, in accordance with JAX's {ref}version-support-policy.
      • The behavior of XLA_PYTHON_CLIENT_MEM_FRACTION=.XX has been changed to allocate XX% of the total GPU memory instead of the previous behavior of using currently available GPU memory to calculate preallocation. Please refer to GPU memory allocation for more details.
      • The deprecated method .block_host_until_ready() has been removed. Use .block_until_ready() instead.
    Source code(tar.gz)
    Source code(zip)
  • jax-v0.4.1(Dec 13, 2022)

    • Changes
      • Support for Python 3.7 has been dropped, in accordance with JAX's {ref}version-support-policy.
      • We introduce jax.Array, which is a unified array type that subsumes DeviceArray, ShardedDeviceArray, and GlobalDeviceArray types in JAX. The jax.Array type helps make parallelism a core feature of JAX, simplifies and unifies JAX internals, and allows us to unify jit and pjit. jax.Array has been enabled by default in JAX 0.4 and makes some breaking changes to the pjit API. The jax.Array migration guide can help you migrate your codebase to jax.Array. You can also look at the Distributed arrays and automatic parallelization tutorial to understand the new concepts.
      • PartitionSpec and Mesh are now out of experimental. The new API endpoints are jax.sharding.PartitionSpec and jax.sharding.Mesh. jax.experimental.maps.Mesh and jax.experimental.PartitionSpec are deprecated and will be removed in 3 months.
      • with_sharding_constraint's new public endpoint is jax.lax.with_sharding_constraint.
      • If using ABSL flags together with jax.config, the ABSL flag values are no longer read or written after the JAX configuration options are initially populated from the ABSL flags. This change improves performance of reading jax.config options, which are used pervasively in JAX.
      • The jax2tf.call_tf function now uses, for TF lowering, the first TF device of the same platform as the one used by the embedding JAX computation. Before, it was using the 0th device for the JAX-default backend.
      • A number of jax.numpy functions now have their arguments marked as positional-only, matching NumPy.
      • jnp.msort is now deprecated, following the deprecation of np.msort in numpy 1.24. It will be removed in a future release, in accordance with the {ref}api-compatibility policy. It can be replaced with jnp.sort(a, axis=0).
    Source code(tar.gz)
    Source code(zip)
  • jaxlib-v0.3.25(Nov 15, 2022)

  • jax-v0.3.25(Nov 15, 2022)

  • jaxlib-v0.3.24(Nov 4, 2022)

  • jax-v0.3.24(Nov 4, 2022)

  • jax-v0.3.23(Oct 12, 2022)

  • jaxlib-v0.3.22(Oct 11, 2022)

  • jax-v0.3.22(Oct 11, 2022)

    • Changes
      • Add JAX_PLATFORMS=tpu,cpu as default setting in TPU initialization, so JAX will raise an error if the TPU cannot be initialized instead of falling back to CPU. Set JAX_PLATFORMS='' to override this behavior and automatically choose an available backend (the original default), or set JAX_PLATFORMS=cpu to always use CPU regardless of whether a TPU is available.
    • Deprecations
      • Several test utilities deprecated in JAX v0.3.8 are now removed from {mod}jax.test_util.
    Source code(tar.gz)
    Source code(zip)
  • jax-v0.3.21(Oct 3, 2022)

    • Changes
      • The persistent compilation cache will now warn instead of raising an exception on error ({jax-issue}#12582), so program execution can continue if something goes wrong with the cache. Set JAX_RAISE_PERSISTENT_CACHE_ERRORS=true to revert this behavior.
    Source code(tar.gz)
    Source code(zip)
  • jaxlib-v0.3.20(Sep 28, 2022)

    Notable changes:

    • Fixes support for limiting the visible CUDA devices via jax_cuda_visible_devices in distributed jobs. This functionality is needed for the JAX/SLURM integration on GPU (#12533).
    Source code(tar.gz)
    Source code(zip)
  • jax-v0.3.20(Sep 28, 2022)

    Notable changes:

    • Adds .pyi files that were missing from the previous release (#12536).
    • Fixes an incompatibility between jax 0.3.19 and the libtpu version it pinned (#12550). Requires jaxlib 0.3.20.
    • Fix incorrect pip url in setup.py comment (#12528).
    Source code(tar.gz)
    Source code(zip)
  • jax-v0.3.19(Sep 27, 2022)

  • jax-v0.3.18(Sep 26, 2022)

    • GitHub commits.
    • Changes
      • Ahead-of-time lowering and compilation functionality (tracked in {jax-issue}#7733) is stable and public. See the overview and the API docs for {mod}jax.stages.
      • Introduced {class}jax.Array, intended to be used for both isinstance checks and type annotations for array types in JAX. Notice that this included some subtle changes to how isinstance works for {class}jax.numpy.ndarray for jax-internal objects, as {class}jax.numpy.ndarray is now a simple alias of {class}jax.Array.
    • Breaking changes
      • jax._src is no longer imported into the public jax namespace. This may break users that were using JAX internals.
      • jax.soft_pmap has been deleted. Please use pjit or xmap instead. jax.soft_pmap is undocumented. If it were documented, a deprecation period would have been provided.
    Source code(tar.gz)
    Source code(zip)
  • jax-v0.3.17(Aug 31, 2022)

    • GitHub commits.
    • Bugs
      • Fix corner case issue in gradient of lax.pow with an exponent of zero (#12041)
    • Breaking changes
      • jax.checkpoint, also known as jax.remat, no longer supports the concrete option, following the previous version's deprecation; see JEP 11830.
    • Changes
      • Added jax.pure_callback that enables calling back to pure Python functions from compiled functions (e.g. functions decorated with jax.jit or jax.pmap).
    • Deprecations:
      • The deprecated DeviceArray.tile() method has been removed. Use jax.numpy.tile (#11944).
      • DeviceArray.to_py() has been deprecated. Use np.asarray(x) instead.
    Source code(tar.gz)
    Source code(zip)
  • jax-v0.3.16(Aug 12, 2022)

    • GitHub commits.
    • Breaking changes
      • Support for NumPy 1.19 has been dropped, per the deprecation policy. Please upgrade to NumPy 1.20 or newer.
    • Changes
      • Added jax.debug, which includes utilities for runtime value debugging such as jax.debug.print and jax.debug.breakpoint.
      • Added new documentation for runtime value debugging
    • Deprecations
      • The jax.mask and jax.shapecheck APIs have been removed. See #11557.
      • jax.experimental.loops has been removed. See #10278 for an alternative API.
      • jax.tree_util.tree_multimap has been removed. It has been deprecated since JAX release 0.3.5, and jax.tree_util.tree_map is a direct replacement.
      • Removed jax.experimental.stax; it has long been a deprecated alias of jax.example_libraries.stax.
      • Removed jax.experimental.optimizers; it has long been a deprecated alias of jax.example_libraries.optimizers.
      • jax.checkpoint, also known as jax.remat, has a new implementation switched on by default, meaning the old implementation is deprecated; see JEP 11830.
    Source code(tar.gz)
    Source code(zip)
  • jaxlib-v0.3.15(Jul 22, 2022)

  • jax-v0.3.15(Jul 22, 2022)

  • jaxlib-v0.3.14(Jun 21, 2022)

  • jax-v0.3.14(Jun 21, 2022)

  • jax-v0.3.13(May 16, 2022)

  • jax-v0.3.12(May 16, 2022)

  • jax-v0.3.11(May 15, 2022)

    • Changes
      • {func}jax.lax.eigh now accepts an optional sort_eigenvalues argument that allows users to opt out of eigenvalue sorting on TPU.
    • Deprecations
      • Non-array arguments to functions in {mod}jax.lax.linalg are now marked keyword-only. As a backward-compatibility step passing keyword-only arguments positionally yields a warning, but in a future JAX release passing keyword-only arguments positionally will fail. However, most users should prefer to use {mod}jax.numpy.linalg instead.
      • {func}jax.scipy.linalg.polar_unitary, which was a JAX extension to the scipy API, is deprecated. Use {func}jax.scipy.linalg.polar instead.
    Source code(tar.gz)
    Source code(zip)
  • jaxlib-v0.3.10(May 4, 2022)

  • jax-v0.3.10(May 4, 2022)

  • jax-v0.3.9(May 3, 2022)

  • jax-v0.3.8(Apr 30, 2022)

    • GitHub commits.
    • Changes
      • {func}jax.numpy.linalg.svd on TPUs uses a qdwh-svd solver.
      • {func}jax.numpy.linalg.cond on TPUs now accepts complex input.
      • {func}jax.numpy.linalg.pinv on TPUs now accepts complex input.
      • {func}jax.numpy.linalg.matrix_rank on TPUs now accepts complex input.
      • {func}jax.scipy.cluster.vq.vq has been added.
      • jax.experimental.maps.mesh has been deleted. Please use jax.experimental.maps.Mesh. Please see https://jax.readthedocs.io/en/latest/_autosummary/jax.experimental.maps.Mesh.html#jax.experimental.maps.Mesh for more information.
      • {func}jax.scipy.linalg.qr now returns a length-1 tuple rather than the raw array when mode='r', in order to match the behavior of scipy.linalg.qr ({jax-issue}#10452)
      • {func}jax.numpy.take_along_axis now takes an optional mode parameter that specifies the behavior of out-of-bounds indexing. By default, invalid values (e.g., NaN) will be returned for out-of-bounds indices. In previous versions of JAX, invalid indices were clamped into range. The previous behavior can be restored by passing mode="clip".
      • {func}jax.numpy.take now defaults to mode="fill", which returns invalid values (e.g., NaN) for out-of-bounds indices.
      • Scatter operations, such as x.at[...].set(...), now have "drop" semantics. This has no effect on the scatter operation itself, but it means that when differentiated the gradient of a scatter will yield zero cotangents for out-of-bounds indices. Previously out-of-bounds indices were clamped into range for the gradient, which was not mathematically correct.
      • {func}jax.numpy.take_along_axis now raises a TypeError if its indices are not of an integer type, matching the behavior of {func}numpy.take_along_axis. Previously non-integer indices were silently cast to integers.
      • {func}jax.numpy.ravel_multi_index now raises a TypeError if its dims argument is not of an integer type, matching the behavior of {func}numpy.ravel_multi_index. Previously non-integer dims was silently cast to integers.
      • {func}jax.numpy.split now raises a TypeError if its axis argument is not of an integer type, matching the behavior of {func}numpy.split. Previously non-integer axis was silently cast to integers.
      • {func}jax.numpy.indices now raises a TypeError if its dimensions are not of an integer type, matching the behavior of {func}numpy.indices. Previously non-integer dimensions were silently cast to integers.
      • {func}jax.numpy.diag now raises a TypeError if its k argument is not of an integer type, matching the behavior of {func}numpy.diag. Previously non-integer k was silently cast to integers.
      • Added {func}jax.random.orthogonal.
    • Deprecations
      • Many functions and objects available in {mod}jax.test_util are now deprecated and will raise a warning on import. This includes cases_from_list, check_close, check_eq, device_under_test, format_shape_dtype_string, rand_uniform, skip_on_devices, with_config, xla_bridge, and _default_tolerance ({jax-issue}#10389). These, along with previously-deprecated JaxTestCase, JaxTestLoader, and BufferDonationTestCase, will be removed in a future JAX release. Most of these utilities can be replaced by calls to standard python & numpy testing utilities found in e.g. {mod}unittest, {mod}absl.testing, {mod}numpy.testing, etc. JAX-specific functionality such as device checking can be replaced through the use of public APIs such as {func}jax.devices. Many of the deprecated utilities will still exist in {mod}jax._src.test_util, but these are not public APIs and as such may be changed or removed without notice in future releases.
    Source code(tar.gz)
    Source code(zip)
  • jaxlib-v0.3.7(Apr 29, 2022)

  • jax-v0.3.7(Apr 29, 2022)

    • Fixed a performance problem if the indices passed to jax.numpy.take_along_axis were broadcasted (#10281).
    • jax.scipy.special.expit and jax.scipy.special.logit now require their arguments to be scalars or JAX arrays. They also now promote integer arguments to floating point.
    • The DeviceArray.tile() method is deprecated, because numpy arrays do not have a tile() method. As a replacement for this, use jax.numpy.tile (#10266).
    Source code(tar.gz)
    Source code(zip)
  • jax-v0.3.6(Apr 13, 2022)

Owner
Google
Google ❤️ Open Source
Composable transformations of Python+NumPy programs

Chex Chex is a library of utilities for helping to write reliable JAX code. This includes utils to help: Instrument your code (e.g. assertions) Debug

DeepMind 506 Jan 8, 2023
Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX.

Equinox Callable PyTrees and filtered JIT/grad transformations => neural networks in JAX Equinox brings more power to your model building in JAX. Repr

Patrick Kidger 909 Dec 30, 2022
Grow Function: Generate 3D Stacked Bifurcating Double Deep Cellular Automata based organisms which differentiate using a Genetic Algorithm...

Grow Function: A 3D Stacked Bifurcating Double Deep Cellular Automata which differentiates using a Genetic Algorithm... TLDR;High Def Trees that you can mint as NFTs on Solana

Nathaniel Gibson 4 Oct 8, 2022
Jittor is a high-performance deep learning framework based on JIT compiling and meta-operators.

Jittor: a Just-in-time(JIT) deep learning framework Quickstart | Install | Tutorial | Chinese Jittor is a high-performance deep learning framework bas

null 2.7k Jan 3, 2023
This repo is a C++ version of yolov5_deepsort_tensorrt. Packing all C++ programs into .so files, using Python script to call C++ programs further.

yolov5_deepsort_tensorrt_cpp Introduction This repo is a C++ version of yolov5_deepsort_tensorrt. And packing all C++ programs into .so files, using P

null 41 Dec 27, 2022
functorch is a prototype of JAX-like composable function transforms for PyTorch.

functorch is a prototype of JAX-like composable function transforms for PyTorch.

Facebook Research 1.2k Jan 9, 2023
Refactoring dalle-pytorch and taming-transformers for TPU VM

Text-to-Image Translation (DALL-E) for TPU in Pytorch Refactoring Taming Transformers and DALLE-pytorch for TPU VM with Pytorch Lightning Requirements

Kim, Taehoon 61 Nov 7, 2022
A minimal TPU compatible Jax implementation of NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis

NeRF Minimal Jax implementation of NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis. Result of Tiny-NeRF RGB Depth

Soumik Rakshit 11 Jul 24, 2022
The full training script for Enformer (Tensorflow Sonnet) on TPU clusters

Enformer TPU training script (wip) The full training script for Enformer (Tensorflow Sonnet) on TPU clusters, in an effort to migrate the model to pyt

Phil Wang 10 Oct 19, 2022
MLP-Numpy - A simple modular implementation of Multi Layer Perceptron in pure Numpy.

MLP-Numpy A simple modular implementation of Multi Layer Perceptron in pure Numpy. I used the Iris dataset from scikit-learn library for the experimen

Soroush Omranpour 1 Jan 1, 2022
Multiple types of NN model optimization environments. It is possible to directly access the host PC GUI and the camera to verify the operation. Intel iHD GPU (iGPU) support. NVIDIA GPU (dGPU) support.

mtomo Multiple types of NN model optimization environments. It is possible to directly access the host PC GUI and the camera to verify the operation.

Katsuya Hyodo 24 Mar 2, 2022
High performance Cross-platform Inference-engine, you could run Anakin on x86-cpu,arm, nv-gpu, amd-gpu,bitmain and cambricon devices.

Anakin2.0 Welcome to the Anakin GitHub. Anakin is a cross-platform, high-performance inference engine, which is originally developed by Baidu engineer

null 514 Dec 28, 2022
GrabGpu_py: a scripts for grab gpu when gpu is free

GrabGpu_py a scripts for grab gpu when gpu is free. WaitCondition: gpu_memory >

tianyuluan 3 Jun 18, 2022
⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.

Optimized Einsum Optimized Einsum: A tensor contraction order optimizer Optimized einsum can significantly reduce the overall execution time of einsum

Daniel Smith 653 Dec 30, 2022
We present a framework for training multi-modal deep learning models on unlabelled video data by forcing the network to learn invariances to transformations applied to both the audio and video streams.

Multi-Modal Self-Supervision using GDT and StiCa This is an official pytorch implementation of papers: Multi-modal Self-Supervision from Generalized D

Facebook Research 42 Dec 9, 2022
Image transformations designed for Scene Text Recognition (STR) data augmentation. Published at ICCV 2021 Workshop on Interactive Labeling and Data Augmentation for Vision.

Data Augmentation for Scene Text Recognition (ICCV 2021 Workshop) (Pronounced as "strog") Paper Arxiv Why it matters? Scene Text Recognition (STR) req

Rowel Atienza 152 Dec 28, 2022
Using some basic methods to show linkages and transformations of robotic arms

roboticArmVisualizer Python GUI application to create custom linkages and adjust joint angles. In the future, I plan to add 2d inverse kinematics solv

Sandesh Banskota 1 Nov 19, 2021