Chex

Chex is a library of utilities for helping to write reliable JAX code.

This includes utils to help:

  • Instrument your code (e.g. assertions)
  • Debug (e.g. transforming pmaps into vmaps within a context manager).
  • Test JAX code across many variants (e.g. jitted vs non-jitted).

Installation

Chex can be installed with pip directly from GitHub with the following command:

pip install git+https://github.com/deepmind/chex.git

or from PyPI:

pip install chex

Modules Overview

Dataclass (dataclass.py)

Dataclasses are a popular construct introduced in Python 3.7 that make it easy to specify typed data structures with minimal boilerplate code. However, they are not compatible with JAX and dm-tree out of the box.

In Chex we provide a JAX-friendly dataclass implementation that reuses Python dataclasses.

The Chex implementation registers dataclasses as internal PyTree nodes to ensure compatibility with JAX data structures.

In addition, we provide a class wrapper that exposes dataclasses as collections.abc.Mapping descendants, which allows dm-tree methods to process them (e.g. (un-)flatten them) like ordinary Python dictionaries. See the @mappable_dataclass docstring for more details.

Example:

import chex
import jax
import jax.numpy as jnp
import tree  # dm-tree

@chex.dataclass
class Parameters:
  x: chex.ArrayDevice
  y: chex.ArrayDevice

parameters = Parameters(
    x=jnp.ones((2, 2)),
    y=jnp.ones((1, 2)),
)

# Dataclasses can be treated as JAX pytrees
jax.tree_util.tree_map(lambda x: 2.0 * x, parameters)

# and as mappings by dm-tree
tree.flatten(parameters)

NOTE: Unlike standard Python 3.7 dataclasses, Chex dataclasses (which are mappable by default) cannot be constructed using positional arguments. They accept constructor arguments in the same format as the Python dict constructor. If necessary, dataclasses can be converted to and from tuples with the to_tuple and from_tuple methods.

parameters = Parameters(
    jnp.ones((2, 2)),
    jnp.ones((1, 2)),
)
# ValueError: Mappable dataclass constructor doesn't support positional args.

Assertions (asserts.py)

One limitation of PyType annotations for JAX is that they do not support the specification of DeviceArray ranks, shapes or dtypes. Chex includes a number of functions that allow flexible and concise specification of these properties.

For example, suppose you want to ensure that tensors t1, t2 and t3 all have the same shape, that t4 has rank 2, and that t5 has rank 3 or 4.

chex.assert_equal_shape([t1, t2, t3])
chex.assert_rank([t4, t5], [2, {3, 4}])

More examples:

from chex import assert_shape, assert_rank, ...

assert_shape(x, (2, 3))                # x has shape (2, 3)
assert_shape([x, y], [(), (2,3)])      # x is scalar and y has shape (2, 3)

assert_rank(x, 0)                      # x is scalar
assert_rank([x, y], [0, 2])            # x is scalar and y is a rank-2 array
assert_rank([x, y], {0, 2})            # x and y are scalar OR rank-2 arrays

assert_type(x, int)                    # x has type `int` (x can be an array)
assert_type([x, y], [int, float])      # x has type `int` and y has type `float`

assert_equal_shape([x, y, z])          # x, y, and z have equal shapes

assert_trees_all_close(tree_x, tree_y) # values and structure of trees match
assert_tree_all_finite(tree_x)         # all tree_x leaves are finite

assert_devices_available(2, 'gpu')     # 2 GPUs available
assert_tpu_available()                 # at least 1 TPU available

assert_numerical_grads(f, (x, y), j)   # f^{(j)}(x, y) matches numerical grads

All Chex assertions support the following optional kwargs for manipulating the emitted exception messages:

  • custom_message: A string to include in the emitted exception messages.
  • include_default_message: Whether to include the default Chex message into the emitted exception messages.
  • exception_type: An exception type to use. AssertionError by default.

For example, the following code:

dataset = load_dataset()
params = init_params()
for i in range(num_steps):
  params = update_params(params, dataset.sample())
  chex.assert_tree_all_finite(params,
                              custom_message=f'Failed at iteration {i}.',
                              exception_type=ValueError)

will raise a ValueError that includes the step number when params get polluted with NaNs or infinities.

JAX re-traces a jitted function every time the structure of its arguments (e.g. their shapes or dtypes) changes. This is often inadvertent and can lead to a significant performance drop that is hard to debug. The @chex.assert_max_traces decorator asserts that the function is not re-traced more than n times during program execution.

The global trace counter can be cleared by calling chex.clear_trace_counter(). This function can be used to isolate unit tests relying on @chex.assert_max_traces.

Examples:

  import chex
  import jax
  import jax.numpy as jnp

  @jax.jit
  @chex.assert_max_traces(n=1)
  def fn_sum_jitted(x, y):
    return x + y

  z = fn_sum_jitted(jnp.zeros(3), jnp.zeros(3))
  t = fn_sum_jitted(jnp.zeros((6, 7)), jnp.zeros((6, 7)))  # AssertionError!

It can be used with jax.pmap() as well:

  def fn_sub(x, y):
    return x - y

  fn_sub_pmapped = jax.pmap(chex.assert_max_traces(fn_sub, n=10))

See the JAX documentation for more about tracing.

See documentation of asserts.py for details on all supported assertions.

Test variants (variants.py)

JAX relies extensively on code transformation and compilation, which can make it hard to ensure that code is properly tested. For instance, testing a Python function that uses JAX without jitting it will not cover the code path that actually executes when it is jitted, and that path also differs depending on whether the code is jitted for CPU, GPU, or TPU. This has been a source of obscure, hard-to-catch bugs where XLA changes led to undesirable behaviours that manifest only under one specific code transformation.

Variants make it easy to ensure that unit tests cover different ‘variations’ of a function, by providing a simple decorator that can be used to repeat any test under all (or a subset) of the relevant code transformations.

E.g. suppose you want to test the output of a function fn with or without jit. You can use chex.variants to run the test with both the jitted and non-jitted version of the function by simply decorating a test method with @chex.variants, and then using self.variant(fn) in place of fn in the body of the test.

import chex

def fn(x, y):
  return x + y
...

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    var_fn = self.variant(fn)
    self.assertEqual(fn(1, 2), 3)
    self.assertEqual(var_fn(1, 2), fn(1, 2))

If you define the function in the test method, you may also use self.variant as a decorator in the function definition. For example:

class ExampleTest(chex.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  def test(self):
    @self.variant
    def var_fn(x, y):
      return x + y

    self.assertEqual(var_fn(1, 2), 3)

Example of parameterized test:

import chex
from absl.testing import parameterized

# Could also be:
#  `class ExampleParameterizedTest(chex.TestCase, parameterized.TestCase):`
#  `class ExampleParameterizedTest(chex.TestCase):`
class ExampleParameterizedTest(parameterized.TestCase):

  @chex.variants(with_jit=True, without_jit=True)
  @parameterized.named_parameters(
      ('case_positive', 1, 2, 3),
      ('case_negative', -1, -2, -3),
  )
  def test(self, arg_1, arg_2, expected):
    @self.variant
    def var_fn(x, y):
      return x + y

    self.assertEqual(var_fn(arg_1, arg_2), expected)

Chex currently supports the following variants:

  • with_jit -- applies jax.jit() transformation to the function.
  • without_jit -- uses the function as is, i.e. identity transformation.
  • with_device -- places all arguments (except those specified in the ignore_argnums argument) into device memory before applying the function.
  • without_device -- places all arguments in RAM before applying the function.
  • with_pmap -- applies jax.pmap() transformation to the function (see notes below).

See documentation in variants.py for more details on the supported variants. More examples can be found in variants_test.py.

Variants notes

  • Test classes that use @chex.variants must inherit from chex.TestCase (or any other base class that unrolls test generators within TestCase, e.g. absl.testing.parameterized.TestCase).

  • [jax.vmap] All variants can be applied to a vmapped function; please see an example in variants_test.py (test_vmapped_fn_named_params and test_pmap_vmapped_fn).

  • [@chex.all_variants] You can get all supported variants by using the decorator @chex.all_variants.

  • [with_pmap variant] jax.pmap(fn) performs a parallel map of fn onto multiple devices. Since most tests run in a single-device environment (i.e. with access to a single CPU or GPU), where jax.pmap is functionally equivalent to jax.jit, the with_pmap variant is skipped by default (although it works fine with a single device). Below we describe a way to properly test fn if it is supposed to be used in multi-device environments (TPUs or multiple CPUs/GPUs). To disable skipping of with_pmap variants on a single device, add --chex_skip_pmap_variant_if_single_device=false to your test command.

Fakes (fake.py)

Debugging in JAX is made more difficult by code transformations such as jit and pmap, which introduce optimizations that make code hard to inspect and trace. It can also be difficult to disable those transformations during debugging as they can be called at several places in the underlying code. Chex provides tools to globally replace jax.jit with a no-op transformation and jax.pmap with a (non-parallel) jax.vmap, in order to more easily debug code in a single-device context.

For example, you can use Chex to fake pmap and have it replaced with a vmap. This can be achieved by wrapping your code with a context manager:

with chex.fake_pmap():
  @jax.pmap
  def fn(inputs):
    ...

  # Function will be vmapped over inputs
  fn(inputs)

The same functionality can also be invoked with start and stop:

fake_pmap = chex.fake_pmap()
fake_pmap.start()
# ... your JAX code ...
fake_pmap.stop()

In addition, you can fake a real multi-device test environment with a multi-threaded CPU. See the section Faking multi-device test environments for more details.

See documentation in fake.py and examples in fake_test.py for more details.

Faking multi-device test environments

In situations where you do not have easy access to multiple devices, you can still test parallel computation using single-device multi-threading.

In particular, one can force XLA to use a single CPU's threads as separate devices, i.e. to fake a real multi-device environment with a multi-threaded one. These two options are theoretically equivalent from XLA's perspective because they expose the same interface and use identical abstractions.

Chex has a flag chex_n_cpu_devices that specifies the number of CPU threads to use as XLA devices.

To set up a multi-threaded XLA environment for absl tests, define a setUpModule function in your test module:

def setUpModule():
  chex.set_n_cpu_devices()

Now you can launch your test with python test.py --chex_n_cpu_devices=N to run it in a multi-device regime. Note that all tests within a module will have access to N devices.

More examples can be found in variants_test.py, fake_test.py and fake_set_n_cpu_devices_test.py.

Citing Chex

To cite this repository:

@software{chex2020github,
  author = {David Budden and Matteo Hessel and Iurii Kemaev and Stephen Spencer
            and Fabio Viola},
  title = {Chex: Testing made fun, in JAX!},
  url = {http://github.com/deepmind/chex},
  version = {0.0.1},
  year = {2020},
}

In this BibTeX entry, the version number is intended to be taken from chex/__init__.py, and the year corresponds to the project's open-source release.

Comments
  • [chex] Allow an ellipsis in the expected shape passed to `assert_shape`.

    This allows things like:

    chex.assert_shape(a, [..., seq_len, features])
    

    This is particularly useful for situations like variable numbers of batch dimensions.

    cla: yes 
    opened by copybara-service[bot] 8
  • CpuDevice no longer in jax

    Hello,

    Seems like the newest version of jax (0.3.7) removed some classes that are used here in chex. Should chex upper bound the jax version? I see this conflicting code is not currently on the main branch -- alternatively, maybe a new release can be made?

    https://github.com/google/jax/pull/10326

    opened by adamgayoso 4
  • `AssertsChexifyTest.test_uninspected_checks` test failure

    I'm seeing the following test failure when running the test suite:

    ============================= test session starts ==============================
    platform linux -- Python 3.10.7, pytest-7.1.3, pluggy-1.0.0
    rootdir: /build/source
    collected 548 items                                                            
    
    chex/chex_test.py .                                                      [  0%]
    chex/_src/asserts_chexify_test.py ......F.....                           [  2%]
    chex/_src/asserts_internal_test.py .s.s.........                         [  4%]
    chex/_src/asserts_test.py ..s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s. [ 13%]
    s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s..................... [ 26%]
    ........................................................................ [ 39%]
    ........................................................................ [ 52%]
    .................................                                        [ 58%]
    chex/_src/dataclass_test.py ...........................................  [ 66%]
    chex/_src/dimensions_test.py .................                           [ 69%]
    chex/_src/fake_set_n_cpu_devices_test.py s                               [ 69%]
    chex/_src/fake_test.py ................................                  [ 75%]
    chex/_src/restrict_backends_test.py ssssssssss                           [ 77%]
    chex/_src/variants_test.py .....................s....s............s....s [ 85%]
    ..........................................................ssssssssssssss [ 98%]
    sssssss                                                                  [100%]
    
    =================================== FAILURES ===================================
    __________________ AssertsChexifyTest.test_uninspected_checks __________________
    
    self = <chex._src.asserts_chexify_test.AssertsChexifyTest testMethod=test_uninspected_checks>
    
        def test_uninspected_checks(self):
        
          @jax.jit
          def _pos_sum(x):
            chex_value_assert_positive(x, custom_message='err_label')
            return x.sum()
        
          invalid_x = -jnp.ones(3)
          chexify_async(_pos_sum)(invalid_x)  # async error
        
    >     with self.assertRaisesRegex(AssertionError, 'err_label'):
    E     AssertionError: AssertionError not raised
    
    chex/_src/asserts_chexify_test.py:179: AssertionError
    ------------------------------ Captured log call -------------------------------
    WARNING  absl:asserts_chexify.py:57 [Chex] Some of chexify assetion statuses were not inspected due to async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html). Consider calling `chex.block_until_chexify_assertions_complete()` at the end of computations that rely on jitted chex assetions.
    =============================== warnings summary ===============================
    chex/_src/asserts_chexify_test.py: 12 warnings
      /build/source/chex/_src/asserts_chexify_test.py:58: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
        return jnp.all(jnp.array([(x > 0).all() for x in jax.tree_leaves(tree)]))
    
    chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_static_assertion__with_jit
    chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_static_assertion__without_jit
      /build/source/chex/_src/asserts_chexify_test.py:86: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
        return sum(x.sum() for x in jax.tree_leaves(tree))
    
    chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
      /nix/store/4y9j6xdkgqwkdx5ki508l175smcjgs9l-python3.10-pytest-7.1.3/lib/python3.10/site-packages/_pytest/unraisableexception.py:78: PytestUnraisableExceptionWarning: Exception ignored in atexit callback: <function _check_if_hanging_assertions at 0x7ffddfe66d40>
      
      Traceback (most recent call last):
        File "/build/source/chex/_src/asserts_chexify.py", line 32, in _check_error
          checkify.check_error(err)
        File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 476, in check_error
          return assert_p.bind(err, code, payload, msgs=error.msgs)
        File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 328, in bind
          return self.bind_with_trace(find_top_trace(args), args, params)
        File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 331, in bind_with_trace
          out = trace.process_primitive(self, map(trace.full_raise, args), params)
        File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 698, in process_primitive
          return primitive.impl(*tracers, **params)
        File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 483, in assert_impl
          raise_error(Error(err, code, msgs, payload))
        File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 123, in raise_error
          raise ValueError(err)
      ValueError: [Chex] chexify assertion failed [err_label] [failed at /build/source/chex/_src/asserts_chexify_test.py:173] (check failed at /build/source/chex/_src/asserts_internal.py:229 (_chex_assert_fn))
      
      During handling of the above exception, another exception occurred:
      
      Traceback (most recent call last):
        File "/build/source/chex/_src/asserts_chexify.py", line 62, in _check_if_hanging_assertions
          block_until_chexify_assertions_complete()
        File "/build/source/chex/_src/asserts_chexify.py", line 51, in block_until_chexify_assertions_complete
          wait_fn()
        File "/build/source/chex/_src/asserts_chexify.py", line 180, in _wait_checks
          _check_error(async_check_futures.popleft().result(async_timeout))
        File "/build/source/chex/_src/asserts_chexify.py", line 40, in _check_error
          raise AssertionError(msg)  # pylint:disable=raise-missing-from
      AssertionError: [Chex] chexify assertion failed [err_label] [failed at /build/source/chex/_src/asserts_chexify_test.py:173] 
      
        warnings.warn(pytest.PytestUnraisableExceptionWarning(msg))
    
    chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
    chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
    chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
    chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
    chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
      /build/source/chex/_src/asserts_chexify_test.py:52: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
        if not all((x > 0).all() for x in jax.tree_leaves(tree)):
    
    -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
    =========================== short test summary info ============================
    FAILED chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
    ====== 1 failed, 461 passed, 86 skipped, 20 warnings in 84.47s (0:01:24) =======
    error: builder for '/nix/store/f9icjsb9pbz4p8qpsyhp9gq1fvjvwwhz-python3.10-chex-0.1.5.drv' failed with exit code 1;
           last 10 log lines:
           > chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
           > chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
           > chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
           >   /build/source/chex/_src/asserts_chexify_test.py:52: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
           >     if not all((x > 0).all() for x in jax.tree_leaves(tree)):
           >
           > -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
           > =========================== short test summary info ============================
           > FAILED chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
           > ====== 1 failed, 461 passed, 86 skipped, 20 warnings in 84.47s (0:01:24) =======
    

    I'm using

    • jax v0.3.23
    • jaxlib v0.3.22
    • absl-py v1.2.0
    • dm-tree from commit https://github.com/deepmind/tree/commit/b452e5c2743e7489b4ba7f16ecd51c516d7cd8e3
    • numpy 1.23.3
    • toolz 0.12.0
    opened by samuela 3
  • AttributeError: module 'jax' has no attribute '_src'

    trying to import optax and getting an error AttributeError: module 'jax' has no attribute '_src' for jax versions > 0.3.17

    optax version == 0.1.3 chex version == 0.1.3

    In [1]: import optax
    /home/penn/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/pytypes.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.
      PyTreeDef = type(jax.tree_structure(None))
    ---------------------------------------------------------------------------
    AttributeError                            Traceback (most recent call last)
    Input In [1], in <cell line: 1>()
    ----> 1 import optax
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/__init__.py:17, in <module>
          1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
          2 #
          3 # Licensed under the Apache License, Version 2.0 (the "License");
       (...)
         13 # limitations under the License.
         14 # ==============================================================================
         15 """Optax: composable gradient processing and optimization, in JAX."""
    ---> 17 from optax import experimental
         18 from optax._src.alias import adabelief
         19 from optax._src.alias import adafactor
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/experimental/__init__.py:20, in <module>
          1 # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
          2 #
          3 # Licensed under the Apache License, Version 2.0 (the "License");
       (...)
         13 # limitations under the License.
         14 # ==============================================================================
         15 """Experimental features in Optax.
         16 
         17 Features may be removed or modified at any time.
         18 """
    ---> 20 from optax._src.experimental.complex_valued import split_real_and_imaginary
         21 from optax._src.experimental.complex_valued import SplitRealAndImaginaryState
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/_src/experimental/complex_valued.py:32, in <module>
         15 """Complex-valued optimization.
         16 
         17 When using `split_real_and_imaginary` to wrap an optimizer, we split the complex
       (...)
         27 See details at https://github.com/deepmind/optax/issues/196
         28 """
         30 from typing import NamedTuple, Union
    ---> 32 import chex
         33 import jax
         34 import jax.numpy as jnp
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/__init__.py:17, in <module>
          1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
          2 #
          3 # Licensed under the Apache License, Version 2.0 (the "License");
       (...)
         13 # limitations under the License.
         14 # ==============================================================================
         15 """Chex: Testing made fun, in JAX!"""
    ---> 17 from chex._src.asserts import assert_axis_dimension
         18 from chex._src.asserts import assert_axis_dimension_comparator
         19 from chex._src.asserts import assert_axis_dimension_gt
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/asserts.py:26, in <module>
         23 import unittest
         24 from unittest import mock
    ---> 26 from chex._src import asserts_internal as _ai
         27 from chex._src import pytypes
         28 import jax
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/asserts_internal.py:32, in <module>
         29 from typing import Any, Sequence, Union, Callable, Optional, Set, Tuple, Type
         31 from absl import logging
    ---> 32 from chex._src import pytypes
         33 import jax
         34 import jax.numpy as jnp
    
    File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/pytypes.py:44, in <module>
         40 Device = jax.lib.xla_extension.Device
         42 ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
    ---> 44 ArrayDType = jax._src.numpy.lax_numpy._ScalarMeta
    
    AttributeError: module 'jax' has no attribute '_src'
    
    opened by jenkspt 3
  • Use dataclass_transform to help type checkers with @chex.dataclass

    closes #155

    I basically copied this example: https://peps.python.org/pep-0681/#id1

    Tested with pyright/pylance.

    I had to specify a return type for chex.dataclass because otherwise pyright/pylance is ignoring it completely if arguments are passed to it (like chex.dataclass(eq=False)), but if it's used bare (just chex.dataclass without any parentheses) then it also works without the return type annotation.

    There is an (expected) test failure from pytype:

    FAILED: /home/tmk/dev/python/chex/.pytype/pyi/chex/_src/dataclass.pyi 
    /tmp/chex-env/bin/python3 -m pytype.single --imports_info /home/tmk/dev/python/chex/.pytype/imports/chex._src.dataclass.imports --module-name chex._src.dataclass --platform linux -V 3.9 -o /home/tmk/dev/python/chex/.pytype/pyi/chex/_src/dataclass.pyi --analyze-annotated --nofail --quick /home/tmk/dev/python/chex/chex/_src/dataclass.py
    File "/home/tmk/dev/python/chex/chex/_src/dataclass.py", line 90, in <module>: typing_extensions.dataclass_transform not supported yet [not-supported-yet]
    

    I'm not sure how to deal with that.

    There is also not really a way to write tests for this...

    cc @hbq1

    opened by thomkeh 3
  • Prevent chex.fake_pmap|jit function signature inspection from following through wrappers, otherwise if a wrapper changes the signature in some way, the fakes choke on those.

    cla: no 
    opened by copybara-service[bot] 3
  • Using variants with pytest

    Hi,

    First of all thank you for this very useful library !

    I have a project in JAX for which I have already implemented my tests using pytest. However, the possibilities that chex.variants offers are too nice to ignore. At the same time, I would rather not rewrite all my tests.

    Is there a way to reconcile pytest and chex ?

    Thank you again for all the work! Best,

    opened by pablo2909 2
  • [chex] Allow `set`s of alternatives in expected shape for `assert_shape`.

    This extends the behavior allowed by assert_rank to assert_shape, enabling things like:

    chex.assert_shape(mask, (batch_size, {num_heads, 1}, q_len, kv_len))
    

    In this example, axis 1 can either be num_heads or 1, which is helpful, for example, in situations where you want to allow only particular dimensions to be broadcastable.

    cla: yes 
    opened by copybara-service[bot] 2
  • [REQ] Conda recipe

    Hi, I'm the lead developer of NetKet, an established machine learning / quantum physics package.

    We have recently finished rewriting our core to be based on Jax (and flax), and recently released a beta version. Since many physicists seem to use anaconda, we would also like to update our conda recipe. However, since we depend on optax (and therefore on Chex), we would need Chex to have a Conda recipe.

    Is that something you'd consider? I am willing to volunteer some work to help you.

    I tried creating a recipe starting from your PyPI source distribution, but that is problematic because you don't bundle your requirements.txt file, which is required to run setup.py. I could create a recipe from the tag tarballs on GitHub, but that sometimes prevents the conda packages from auto-updating the recipe for later releases.

    opened by PhilipVinc 2
  • Chex dataclass defaulting mappable_dataclass=True

    To start with, thanks for open sourcing your work on Chex, it's a great tooling library for building robust Jax applications!

    As I was upgrading to the latest release 0.0.3, I noticed quite a few of my tests breaking. It happens that the default option mappable_dataclass=True in chex.dataclass is breaking the usual interface of dataclasses (which is clearly expected reading the code documentation!)

    I guess from the perspective of DeepMind's usage, it probably makes sense to default this option. But from an external user's point of view, it is rather surprising to have a dataclass decorator not behaving like a dataclass. I think it would be great to make it clear in the library README that this option needs to be turned off to get the full dataclass behaviour (or to turn it off by default).

    opened by balancap 2
  • Add ability to check shapes with wildcards

    I often find myself writing the following sort of thing:

    chex.assert_rank(x, 2)
    assert x.shape[1] == num_actions, "some custom message ..."
    

    It would be nice to be able to simply check the shape with a wildcard, i.e.

    chex.assert_shape(x, (None, num_actions))
    

    What do you think?

    cla: yes 
    opened by KristianHolsheimer 2
  • [chex] Add `assert_trees_all_equal_shapes_and_dtypes`

    This is purely a convenience function, asserting both assert_trees_all_equal_shapes and assert_trees_all_equal_dtypes.

    opened by copybara-service[bot] 0
  • Improve support for custom `__init__` methods in dataclasses.

    Chex dataclasses assume the dataclass has a default constructor, which is necessary for flatten/unflatten. This change allows custom initializers by keeping an internal reference to a default initializer for use with flatten/unflatten.

    opened by copybara-service[bot] 0
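The idea behind this change can be sketched outside Chex with a standard dataclass registered as a JAX pytree. The dataclass-generated initializer is saved before a custom one replaces it, and unflattening goes through the saved reference, so it never depends on the custom signature. This illustrates the technique described above and is not Chex's actual implementation; `Params` and its `size` argument are hypothetical:

```python
import dataclasses

import jax
import jax.numpy as jnp


@dataclasses.dataclass
class Params:
  w: jnp.ndarray
  b: jnp.ndarray


# Keep an internal reference to the dataclass-generated (default) initializer.
_default_init = Params.__init__


def _custom_init(self, size: int):
  """Hypothetical custom constructor that builds the arrays from a size."""
  _default_init(self, w=jnp.zeros((size, size)), b=jnp.zeros((size,)))


Params.__init__ = _custom_init


def _flatten(params):
  return (params.w, params.b), None  # children, no auxiliary data


def _unflatten(aux_data, children):
  del aux_data  # unused
  obj = object.__new__(Params)  # bypass the custom __init__
  _default_init(obj, *children)  # rebuild fields via the default initializer
  return obj


jax.tree_util.register_pytree_node(Params, _flatten, _unflatten)

params = Params(3)  # goes through the custom initializer
shifted = jax.tree_util.tree_map(lambda x: x + 1.0, params)  # uses _unflatten
```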
  • post_init error in inherited dataclass

    When inheriting one dataclass from another, Chex's dataclass does not allow the subclass's __post_init__ to call super().__post_init__(). This works with Python's standard dataclasses module.

    A minimum working example is

    from chex import dataclass
    
    @dataclass
    class ChexBase:
        a: int
    
        def __post_init__(self):
            self.b = self.a + 1
    
    @dataclass
    class ChexSub(ChexBase):
        a: int
    
        def __post_init__(self):
            super().__post_init__()
            self.c = self.a + 2
    
    temp = ChexSub(a=1)
    temp.b
    

    Importing dataclass from the standard dataclasses module instead runs without error and returns 2, as expected.

    Environment

    • Chex version 0.1.5
    • Ubuntu 20.04
    • Python 3.9
    opened by thomaspinder 1
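For reference, the same hierarchy built with the standard library's `dataclasses` module shows the behaviour the report expects: `super().__post_init__()` runs the parent hook, so both derived attributes are set. The class names are shortened from the report:

```python
from dataclasses import dataclass


@dataclass
class Base:
  a: int

  def __post_init__(self):
    self.b = self.a + 1


@dataclass
class Sub(Base):
  a: int

  def __post_init__(self):
    super().__post_init__()  # chains to Base.__post_init__
    self.c = self.a + 2


temp = Sub(a=1)
print(temp.b, temp.c)  # 2 3
```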
Releases(v0.1.5)
  • v0.1.5(Sep 13, 2022)

    What's Changed

    • Add support for value assertions in jitted functions. by @copybara-service in https://github.com/deepmind/chex/pull/178
    • [JAX] Avoid private implementation detail _ScalarMeta. by @copybara-service in https://github.com/deepmind/chex/pull/180
    • [JAX] Avoid implicit references to jax._src. by @copybara-service in https://github.com/deepmind/chex/pull/181
    • Release v0.1.5 by @copybara-service in https://github.com/deepmind/chex/pull/184

    Full Changelog: https://github.com/deepmind/chex/compare/v0.1.4...v0.1.5

    Source code(tar.gz)
    Source code(zip)
  • v0.1.4(Aug 4, 2022)

    What's Changed

    • Add an InitVar field in the dataclass tests. by @copybara-service in https://github.com/deepmind/chex/pull/161
    • Download latest .pylintrc version in tests. by @copybara-service in https://github.com/deepmind/chex/pull/167
    • Fix assert_axis_dimension_comparator usages. by @copybara-service in https://github.com/deepmind/chex/pull/168
    • Update "jax.tree_util" functions by @copybara-service in https://github.com/deepmind/chex/pull/171
    • Use jax.tree_util.tree_map in place of deprecated tree_multimap. by @copybara-service in https://github.com/deepmind/chex/pull/175
    • Silence some pytype errors. by @copybara-service in https://github.com/deepmind/chex/pull/174
    • Add chex.Dimensions utility for readable shape asserts. by @copybara-service in https://github.com/deepmind/chex/pull/169

    Full Changelog: https://github.com/deepmind/chex/compare/v0.1.3...v0.1.4

    Source code(tar.gz)
    Source code(zip)
  • v0.1.3(Apr 19, 2022)

    What's Changed

    • Slight helping clarification to clear_trace_counter. by @lucasb-eyer in https://github.com/deepmind/chex/pull/148
    • Add new JAX-specific pytypes to chex pytypes. by @copybara-service in https://github.com/deepmind/chex/pull/153
    • Remove chex.{C,G,T}puDevice in favour of chex.Device. by @copybara-service in https://github.com/deepmind/chex/pull/154

    New Contributors

    • @lucasb-eyer made their first contribution in https://github.com/deepmind/chex/pull/148

    Full Changelog: https://github.com/deepmind/chex/compare/v0.1.2...v0.1.3

    Source code(tar.gz)
    Source code(zip)
  • v0.1.2(Mar 31, 2022)

    What's Changed

    • Support JAX parallel operations in chex.fake_pmap contexts by @copybara-service in https://github.com/deepmind/chex/pull/142
    • Remove references to jax.numpy.lax_numpy. by @copybara-service in https://github.com/deepmind/chex/pull/150

    Full Changelog: https://github.com/deepmind/chex/compare/v0.1.1...v0.1.2

    Source code(tar.gz)
    Source code(zip)
  • v0.1.1(Feb 25, 2022)

    What's Changed

    • Move dataclass registration to init so that it's invoked after deserialization. by @copybara-service in https://github.com/deepmind/chex/pull/111
    • Add pytype for jax array's dtypes. by @copybara-service in https://github.com/deepmind/chex/pull/112
    • Fix dataclass registration on deserialization. by @copybara-service in https://github.com/deepmind/chex/pull/114
    • Fix restrict_backends after jax.xla.backend_compile was moved by @copybara-service in https://github.com/deepmind/chex/pull/116
    • Refactor asserts.py and warn users not to rely on asserts_internal's functionality. by @copybara-service in https://github.com/deepmind/chex/pull/117
    • Set up ReadTheDoc pages and add a few examples. by @copybara-service in https://github.com/deepmind/chex/pull/118
    • Include Sphinx builds into CI tests. by @copybara-service in https://github.com/deepmind/chex/pull/119
    • Adds internal functionality by @copybara-service in https://github.com/deepmind/chex/pull/122
    • Update Chex citation. by @copybara-service in https://github.com/deepmind/chex/pull/125
    • Refactor assertions in preparation for including them into the RTD docs. by @copybara-service in https://github.com/deepmind/chex/pull/126
    • Add asserts, variants, and pytypes modules to the RTD docs. by @copybara-service in https://github.com/deepmind/chex/pull/127
    • Fix references to collections.abc.Mappable -> collections.abc.Mapping in docs and comments. collections.abc.Mappable does not exist. by @copybara-service in https://github.com/deepmind/chex/pull/129
    • Document the rationale behind the mappability of chex.dataclasses. by @copybara-service in https://github.com/deepmind/chex/pull/130
    • Add 3 new tree assertions: by @copybara-service in https://github.com/deepmind/chex/pull/131
    • Add assert_tree_is_sharded for asserting that a tree is sharded across the specified devices. by @copybara-service in https://github.com/deepmind/chex/pull/132
    • Add PyTreeDef to pytypes. by @copybara-service in https://github.com/deepmind/chex/pull/134
    • Disallow ShardedDeviceArrays in assert_tree_is_on_host and assert_tree_is_on_device. by @copybara-service in https://github.com/deepmind/chex/pull/133
    • Bump ipython from 7.16.1 to 7.16.3 in /requirements by @dependabot in https://github.com/deepmind/chex/pull/135
    • Remove the old venv directory before testing the package. by @copybara-service in https://github.com/deepmind/chex/pull/138
    • Refactor asserts.py in preparation for experimental device assertions. by @copybara-service in https://github.com/deepmind/chex/pull/137
    • Fix minor typo in docs. by @copybara-service in https://github.com/deepmind/chex/pull/139
    • Improve exception message for assert_tree_shape_prefix. by @copybara-service in https://github.com/deepmind/chex/pull/143
    • Release v0.1.1 by @copybara-service in https://github.com/deepmind/chex/pull/146

    Full Changelog: https://github.com/deepmind/chex/compare/v0.1.0...v0.1.1

    Source code(tar.gz)
    Source code(zip)
  • v0.1.0(Nov 18, 2021)

  • v0.0.9(Nov 16, 2021)

    This is the last version compatible with Python 3.6. See https://github.com/deepmind/optax/issues/222 for more details.

    Changes since 0.0.8:

    • Use rtol=1e-6 in asserts.assert_tree_close;
    • Added asserts.assert_trees_all_equal;
    • Removed restricted_inheritance option from Chex dataclasses;
    • Added dims= option to assert_equal_shape, to check a subset of dims;
    • Added test.sh for launching CI tests on a local machine;
    • Added support for default exception messages and types to assertions;
    • Added support for jnp.bfloat16 to asserts.assert_trees_all_close();
    • Added support for static_argnames to variants.with_jit;
    • Added a restrict_backends module for constraining the set of backends that a region of code can use;
    • Added asserts.assert_trees_all_equal_dtypes assertion;
    • Exposed asserts.assert_tree_shape_suffix to the public API;
    • Added asserts.assert_tree_shape_suffix to check whether arrays share the same suffix.
    Source code(tar.gz)
    Source code(zip)
  • v0.0.8(Jul 2, 2021)

    Changes:

    • Add support for static_broadcasted_argnums to fake_pmap;
    • Allows sets of alternatives and ellipsis in assert_shape;
    • Format @variant test names to use only underscores and lowercase letters;
    • Fix incorrect type annotation in asserts.py;
    • Fix dataclass (un-)flatten functions;
    • Add more tests for dataclasses;
    • Raise ValueError when no variants are selected;
    • Exclude Chex's internal frames from AssertionError tracebacks;
    • Add '[Chex] ' prefix to AssertionError messages;
    • Include path to leaves that failed the equality check in assert_tree_all_close;
    • Clean up asserts.py;
    • Asserts which only make sense on >1 tree now demand this (can result in breakages in the existing code).
    Source code(tar.gz)
    Source code(zip)
  • v0.0.7(May 4, 2021)

    Changelog

    Full Changelog

    Closed issues:

    • [REQ] Conda recipe #37

    Merged pull requests:

    * This Changelog was automatically generated by github_changelog_generator

    Source code(tar.gz)
    Source code(zip)
  • v0.0.6(Mar 25, 2021)

  • v0.0.5(Mar 22, 2021)

    Changelog

    Note: this is the first GitHub release of Chex. It includes all changes since the repo was created.

    Full Changelog

    Closed issues:

    • Chex dataclass throws an exception in Python 3.9 #10
    • 'jax.interpreters.xla' has no attribute '_DeviceArray' for jax <= 0.2.5 #9
    • Chex dataclass defaulting mappable_dataclass=True #8
    • DeprecationWarning for importing toolz #4
    • Fake contexts by calling .start() not working #3

    Merged pull requests:

    Source code(tar.gz)
    Source code(zip)
Owner
DeepMind