Chex

Chex is a library of utilities for helping to write reliable JAX code.

This includes utils to help:

  • Instrument your code (e.g. assertions).
  • Debug it (e.g. transforming pmaps into vmaps within a context manager).
  • Test JAX code across many variants (e.g. jitted vs non-jitted).

Installation

Chex can be installed with pip directly from GitHub with the following command:

pip install git+https://github.com/deepmind/chex.git

or from PyPI:

pip install chex

Modules Overview

Dataclass (dataclass.py)

Dataclasses are a popular construct introduced in Python 3.7 that makes it easy to specify typed data structures with minimal boilerplate. They are not, however, compatible with JAX and dm-tree out of the box.

In Chex we provide a JAX-friendly dataclass implementation that reuses Python dataclasses.

The Chex implementation of dataclass registers dataclasses as internal PyTree nodes to ensure compatibility with JAX data structures.

In addition, we provide a class wrapper that exposes dataclasses as collections.Mapping descendants, which allows dm-tree methods to process them (e.g. (un-)flatten them) like ordinary Python dictionaries. See the @mappable_dataclass docstring for more details.

Example:

import chex
import jax
import jax.numpy as jnp
import tree  # dm-tree

@chex.dataclass
class Parameters:
  x: chex.ArrayDevice
  y: chex.ArrayDevice

parameters = Parameters(
    x=jnp.ones((2, 2)),
    y=jnp.ones((1, 2)),
)

# Dataclasses can be treated as JAX pytrees...
jax.tree_map(lambda x: 2.0 * x, parameters)

# ...and as mappings by dm-tree.
tree.flatten(parameters)

NOTE: Unlike standard Python 3.7 dataclasses, Chex dataclasses cannot be constructed using positional arguments. They support construction arguments provided in the same format as the Python dict constructor. Dataclasses can be converted to tuples with the from_tuple and to_tuple methods if necessary.

parameters = Parameters(
    jnp.ones((2, 2)),
    jnp.ones((1, 2)),
)
# ValueError: Mappable dataclass constructor doesn't support positional args.
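
To illustrate the note above, a minimal sketch of the tuple round-trip (reusing the Parameters dataclass defined earlier; the ordering comment is an assumption about typical dataclass field order):

as_tuple = parameters.to_tuple()            # (x, y), in field-definition order
restored = Parameters.from_tuple(as_tuple)  # rebuilds a Parameters instance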

Assertions (asserts.py)

One limitation of PyType annotations for JAX is that they do not support the specification of DeviceArray ranks, shapes or dtypes. Chex includes a number of functions that allow flexible and concise specification of these properties.

E.g. suppose you want to ensure that all tensors t1, t2, t3 have the same shape, and that tensors t4, t5 have rank 2 and (3 or 4), respectively.

chex.assert_equal_shape([t1, t2, t3])
chex.assert_rank([t4, t5], [2, {3, 4}])

More examples:

from chex import assert_shape, assert_rank, ...

assert_shape(x, (2, 3))                # x has shape (2, 3)
assert_shape([x, y], [(), (2, 3)])     # x is scalar and y has shape (2, 3)

assert_rank(x, 0)                      # x is scalar
assert_rank([x, y], [0, 2])            # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2})            # x and y are scalar OR rank-2 arrays

assert_type(x, int)                    # x has type `int` (x can be an array)
assert_type([x, y], [int, float])      # x has type `int` and y has type `float`

assert_equal_shape([x, y, z])          # x, y, and z have equal shapes

assert_trees_all_close(tree_x, tree_y) # values and structure of trees match
assert_tree_all_finite(tree_x)         # all tree_x leaves are finite

assert_devices_available(2, 'gpu')     # 2 GPUs available
assert_tpu_available()                 # at least 1 TPU available

assert_numerical_grads(f, (x, y), j)   # f^{(j)}(x, y) matches numerical grads

All Chex assertions support the following optional kwargs for manipulating the emitted exception messages:

  • custom_message: A string to include in the emitted exception message.
  • include_default_message: Whether to include the default Chex message in the emitted exception message.
  • exception_type: The exception type to use; AssertionError by default.

For example, the following code:

dataset = load_dataset()
params = init_params()
for i in range(num_steps):
  params = update_params(params, dataset.sample())
  chex.assert_tree_all_finite(params,
                              custom_message=f'Failed at iteration {i}.',
                              exception_type=ValueError)

will raise a ValueError that includes a step number when params get polluted with NaNs or Nones.

JAX re-traces jitted functions every time the structure of the passed arguments changes. This behaviour is often inadvertent and leads to a significant performance drop that is hard to debug. The @chex.assert_max_traces decorator asserts that the function is not re-traced more than n times during program execution.

The global trace counter can be cleared by calling chex.clear_trace_counter(). This function can be used to isolate unit tests that rely on @chex.assert_max_traces (see the sketch after the examples below).

Examples:

  @jax.jit
  @chex.assert_max_traces(n=1)
  def fn_sum_jitted(x, y):
    return x + y

  z = fn_sum_jitted(jnp.zeros(3), jnp.zeros(3))
  t = fn_sum_jitted(jnp.zeros((6, 7)), jnp.zeros((6, 7)))  # AssertionError!

Can be used with jax.pmap() as well:

  def fn_sub(x, y):
    return x - y

  fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))
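
As noted above, unit tests that rely on @chex.assert_max_traces can be isolated by resetting the global counter; a minimal sketch, assuming a chex.TestCase subclass (the class name is illustrative):

  class MaxTracesTest(chex.TestCase):

    def setUp(self):
      super().setUp()
      # Reset the global trace counter so each test starts from zero.
      chex.clear_trace_counter()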

See the documentation of asserts.py for details on all supported assertions, and the JAX documentation for more about tracing.

Test variants (variants.py)

JAX relies extensively on code transformation and compilation, meaning that it can be hard to ensure that code is properly tested. For instance, just testing a Python function that uses JAX will not cover the code path that is actually executed when it is jitted, and that path will also differ depending on whether the code is jitted for CPU, GPU, or TPU. This has been a source of obscure and hard-to-catch bugs, where XLA changes lead to undesirable behaviours that manifest only under one specific code transformation.

Variants make it easy to ensure that unit tests cover different ‘variations’ of a function, by providing a simple decorator that can be used to repeat any test under all (or a subset) of the relevant code transformations.

E.g. suppose you want to test the output of a function fn with or without jit. You can use chex.variants to run the test with both the jitted and non-jitted version of the function by simply decorating a test method with @chex.variants, and then using self.variant(fn) in place of fn in the body of the test.

def fn(x, y):
  return x + y
...

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    var_fn = self.variant(fn)
    self.assertEqual(fn(1, 2), 3)
    self.assertEqual(var_fn(1, 2), fn(1, 2))

If you define the function in the test method, you may also use self.variant as a decorator in the function definition. For example:

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    @self.variant
    def var_fn(x, y):
      return x + y

    self.assertEqual(var_fn(1, 2), 3)

Example of parameterized test:

from absl.testing import parameterized

# Could also be:
#  `class ExampleParameterizedTest(chex.TestCase, parameterized.TestCase):`
#  `class ExampleParameterizedTest(chex.TestCase):`
class ExampleParameterizedTest(parameterized.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  @parameterized.named_parameters(
      ('case_positive', 1, 2, 3),
      ('case_negative', -1, -2, -3),
  )
  def test(self, arg_1, arg_2, expected):
    @self.variant
    def var_fn(x, y):
      return x + y

    self.assertEqual(var_fn(arg_1, arg_2), expected)

Chex currently supports the following variants:

  • with_jit -- applies the jax.jit() transformation to the function.
  • without_jit -- uses the function as is, i.e. the identity transformation.
  • with_device -- places all arguments (except those specified in the ignore_argnums argument) into device memory before applying the function.
  • without_device -- places all arguments in RAM before applying the function.
  • with_pmap -- applies the jax.pmap() transformation to the function (see notes below).

See documentation in variants.py for more details on the supported variants. More examples can be found in variants_test.py.
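
For instance, a minimal sketch of a test exercising the device-placement variants (the function and test names are illustrative, not from the library):

def fn_prod(x, y):
  return x * y

class DevicePlacementTest(chex.TestCase):

  @chex.variants(with_device=True, without_device=True)
  def test_prod(self):
    # with_device commits the arguments to device memory before the call;
    # without_device keeps them as host-side values.
    var_fn = self.variant(fn_prod)
    self.assertEqual(var_fn(2, 3), 6)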

Variants notes

  • Test classes that use @chex.variants must inherit from chex.TestCase (or any other base class that unrolls test generators within TestCase, e.g. absl.testing.parameterized.TestCase).

  • [jax.vmap] All variants can be applied to a vmapped function; please see an example in variants_test.py (test_vmapped_fn_named_params and test_pmap_vmapped_fn).

  • [@chex.all_variants] You can get all supported variants by using the decorator @chex.all_variants.

  • [with_pmap variant] jax.pmap(fn) (doc) performs a parallel map of fn onto multiple devices. Since most tests run in a single-device environment (i.e. with access to a single CPU or GPU), where jax.pmap is functionally equivalent to jax.jit, the with_pmap variant is skipped by default (although it works fine with a single device). Below we describe a way to properly test fn if it is supposed to be used in multi-device environments (TPUs, or multiple CPUs/GPUs). To disable skipping with_pmap variants in the single-device case, add --chex_skip_pmap_variant_if_single_device=false to your test command.
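
For completeness, a minimal sketch of the @chex.all_variants decorator mentioned above (assuming a single-device environment, where the with_pmap variant is skipped by default):

class AllVariantsTest(chex.TestCase):

  # Runs under with_jit, without_jit, with_device and without_device;
  # with_pmap is skipped by default on a single device (see above).
  @chex.all_variants
  def test_abs(self):
    var_fn = self.variant(lambda x: abs(x))
    self.assertEqual(var_fn(-3), 3)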

Fakes (fake.py)

Debugging in JAX is made more difficult by code transformations such as jit and pmap, which introduce optimizations that make code hard to inspect and trace. It can also be difficult to disable those transformations during debugging as they can be called at several places in the underlying code. Chex provides tools to globally replace jax.jit with a no-op transformation and jax.pmap with a (non-parallel) jax.vmap, in order to more easily debug code in a single-device context.

For example, you can use Chex to fake pmap and have it replaced with a vmap. This can be achieved by wrapping your code with a context manager:

with chex.fake_pmap():
  @jax.pmap
  def fn(inputs):
    ...

  # Function will be vmapped over inputs
  fn(inputs)

The same functionality can also be invoked with start and stop:

fake_pmap = chex.fake_pmap()
fake_pmap.start()
... your jax code ...
fake_pmap.stop()
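
chex.fake_jit works analogously; a minimal sketch of a debugging session (the print call is purely illustrative):

with chex.fake_jit():
  @jax.jit
  def fn(x):
    # Under fake_jit the function runs eagerly, so Python-level side
    # effects such as this print fire on every call.
    print('inputs:', x)
    return x * 2

  fn(jnp.ones(3))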

In addition, you can fake a real multi-device test environment with a multi-threaded CPU. See section Faking multi-device test environments for more details.

See documentation in fake.py and examples in fake_test.py for more details.

Faking multi-device test environments

In situations where you do not have easy access to multiple devices, you can still test parallel computation using single-device multi-threading.

In particular, one can force XLA to use a single CPU's threads as separate devices, i.e. to fake a real multi-device environment with a multi-threaded one. These two options are theoretically equivalent from the XLA perspective, because they expose the same interface and use identical abstractions.

Chex provides a flag chex_n_cpu_devices that specifies the number of CPU threads to use as XLA devices.

To set up a multi-threaded XLA environment for absl tests, define a setUpModule function in your test module:

def setUpModule():
  chex.set_n_cpu_devices()

Now you can launch your test with python test.py --chex_n_cpu_devices=N to run it in a multi-device regime. Note that all tests within a module will have access to N devices.
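
A minimal sketch of a test body exercising the faked devices (assuming the module was launched with --chex_n_cpu_devices=2; the test name is illustrative):

def test_psum_on_faked_devices(self):
  chex.assert_devices_available(2, 'cpu')  # 2 CPU devices available
  fn = jax.pmap(lambda x: jax.lax.psum(x, 'i'), axis_name='i')
  # Each device receives the sum over the mapped axis: [0., 1.] -> [1., 1.].
  chex.assert_trees_all_close(fn(jnp.arange(2.)), jnp.array([1., 1.]))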

More examples can be found in variants_test.py, fake_test.py and fake_set_n_cpu_devices_test.py.

Citing Chex

To cite this repository:

@software{chex2020github,
  author = {David Budden and Matteo Hessel and Iurii Kemaev and Stephen Spencer
            and Fabio Viola},
  title = {Chex: Testing made fun, in JAX!},
  url = {http://github.com/deepmind/chex},
  version = {0.0.1},
  year = {2020},
}

In this bibtex entry, the version number is intended to be from chex/__init__.py, and the year corresponds to the project's open-source release.

Issues
  • [chex] Allow an ellipsis in the expected shape passed to assert_shape.

    This allows things like:

    chex.assert_shape(a, [..., seq_len, features])

    This is particularly useful for situations like variable numbers of batch dimensions.

    opened by copybara-service[bot]
  • Improve chex.fake_jit by also disabling internal jitting in functions such as jax.lax.scan.

    opened by copybara-service[bot]
  • Raise informative error for negative axes when asserting axis dimension.

    opened by copybara-service[bot]
  • Add option to fake pmap with vmap while still jitting the result.

    opened by copybara-service[bot]
  • Document that fake_pmap may change the output of pmapped functions.

    opened by copybara-service[bot]
  • Prevent chex.fake_pmap|jit function signature inspection from following through wrappers; otherwise, if a wrapper changes the signature in some way, the fakes choke on it.

    opened by copybara-service[bot]
  • [chex] Allow sets of alternatives in the expected shape for assert_shape.

    This extends the behavior allowed by assert_rank to assert_shape, enabling things like:

    chex.assert_shape(mask, (batch_size, {num_heads, 1}, q_len, kv_len))

    In this example, axis 1 can be either num_heads or 1, which is helpful, for example, when you want to allow only particular dimensions to be broadcastable.

    opened by copybara-service[bot]
  • Fix typo

    opened by copybara-service[bot]
  • [REQ] Conda recipe

    Hi, I'm the lead developer of NetKet, an established machine learning / quantum physics package.

    We have recently finished rewriting our core to be based on JAX (and Flax), and recently released a beta version. Since many physicists seem to use Anaconda, we would also like to update our conda recipe. However, since we depend on Optax (and therefore on Chex), we would need Chex to have a conda recipe.

    Is that something you'd consider? I am willing to volunteer some work to help you.

    I tried creating a recipe starting from your PyPI source distribution, but that is problematic because you don't bundle your requirements.txt file, which is required to run setup.py. I could create a recipe from the tag tarballs on GitHub, but that sometimes prevents the conda packages from auto-updating the recipe for later releases.

    opened by PhilipVinc
  • Create py.typed

    opened by graingert
  • chex.variants(with_pmap=True) ignores static_argnames

    The _with_pmap function accepts static_argnums as a parameter, but not static_argnames. This is inconsistent with other variants, such as with_jit and with_device. Crucially, this prevents testing methods that require passing arguments by name (e.g., Distrax's Distribution.sample()).

    More generally, it would be best if all variants accepted the same parameters where possible (i.e., where not specific to a single variant), and I would suggest checking all keys in **unused_kwargs against a list of allowed parameters (i.e., the union of the parameters of all variant functions) to prevent silent errors due to, e.g., misspellings.

    opened by fvisin
  • Move dataclass registration to __init__ so that it's invoked after deserialization.

    opened by copybara-service[bot]
  • Consider supporting static attributes in chex.dataclass

    from jax import jit
    from jax.lax import scan
    from tjax import IntegralNumeric, RealNumeric
    from tjax.dataclasses import dataclass, field
    import chex

    def f(carry, _):
      return carry + 1.0, None

    @jit
    def do_scan(c):
      final, _ = scan(f, c.x, None, c.y)
      return final

    @dataclass
    class C:
      x: RealNumeric
      y: IntegralNumeric = field(static=True)

    print(do_scan(C(1.0, 10)))  # works

    @chex.dataclass
    class D:
      x: RealNumeric
      y: IntegralNumeric

    print(do_scan(D(x=1.0, y=10)))  # fails

    opened by NeilGirdhar
  • Error with Pydantic

    Hello! I'm interested in using pydantic's recursive constructor / asdict functionality, but jax.jit-ed functions give the following error:

    Argument '_Pydantic_OptimConfig_93971134241088(.. SOMETHING HERE...)' of type <class 'pydantic.dataclasses._Pydantic_OptimConfig_93971134241088'> is not a valid JAX type.

    opened by kaiwenw
  • Specify non-pytree node dataclass fields

    Hi,

    Thanks for making this awesome library!

    Is it possible to mark certain fields in chex.dataclass definitions as non-pytree nodes? This feature is supported in Flax (https://flax.readthedocs.io/en/latest/_modules/flax/struct.html#dataclass), and I found it quite useful when defining dataclasses with fields (such as JAX functions) that shouldn't be mapped over by dm-tree or jax.tree_map. I am not sure if this is supported out of the box by chex at the moment, but it is something that I hope could become part of chex.

    opened by ethanluoyc
  • without_jit=True for already jitted functions

    In most JAX-based implementations, jit is almost always included. Basically, if there is no reason not to use it, people will try to take advantage of its speedup.

    I noticed that @chex.variants(with_jit=True, without_jit=True) is a great way to assert the same behavior for both execution paths, as long as the variant is derived from a non-jitted function.

    In the following example, I would expect to see "Tracing fn" four times total: three times for the non-jitted variants and once for the initial jit compilation. In reality, test_variant_pre_jitted() is executed twice with the jitted fn, resulting in two tracer outputs.

    @chex.variants(with_jit=True, without_jit=True)
    def test_variant_pre_jitted(self):
      @jit
      def fn(x, y):
        print("Tracing fn")
        return x + y

      var_fn = self.variant(fn)
      self.assertEqual(var_fn(1, 2), 3)
      self.assertEqual(var_fn(3, 4), 7)
      self.assertEqual(var_fn(5, 6), 11)

    Of course, omitting @jit will lead to the expected behavior. However, when more complex implementations already make use of jit, variants sadly no longer make sense.

    My case is the latter, and I only see the option of implementing a model-wide use_jit flag so that I can derive variants from non-jitted code. However, this makes the whole idea of variants rather obsolete.

    I'm aware this could well be a limitation of JAX and jit itself rather than chex. In that case, I think an error when jitted code is passed to variant() would make this more transparent.

    opened by fabiannagel