Haiku: Sonnet for JAX

Overview | Why Haiku? | Quickstart | Installation | Examples | User manual | Documentation | Citing Haiku

What is Haiku?

Haiku is a tool
For building neural networks
Think: "Sonnet for JAX"

Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet, a neural network library for TensorFlow.

Documentation on Haiku can be found at https://dm-haiku.readthedocs.io/.

Disambiguation: if you are looking for Haiku the operating system then please see https://haiku-os.org/.

Overview

JAX is a numerical computing library that combines NumPy, automatic differentiation, and first-class GPU/TPU support.

Haiku is a simple neural network library for JAX that enables users to use familiar object-oriented programming models while allowing full access to JAX's pure function transformations.

Haiku provides two core tools: a module abstraction, hk.Module, and a simple function transformation, hk.transform.

hk.Modules are Python objects that hold references to their own parameters, other modules, and methods that apply functions on user inputs.

hk.transform turns functions that use these object-oriented, functionally "impure" modules into pure functions that can be used with jax.jit, jax.grad, jax.pmap, etc.
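
As a minimal sketch of how these two pieces fit together (the Affine module below is illustrative, not part of Haiku):

import haiku as hk
import jax
import jax.numpy as jnp

class Affine(hk.Module):
  """A toy module that adds a learned bias to its input."""

  def __call__(self, x):
    b = hk.get_parameter("b", shape=x.shape[-1:], dtype=x.dtype, init=jnp.zeros)
    return x + b

# hk.transform turns the impure function below into a pure (init, apply) pair.
forward = hk.transform(lambda x: Affine()(x))

x = jnp.ones([2, 3])
params = forward.init(jax.random.PRNGKey(0), x)  # e.g. {'affine': {'b': ...}}
y = forward.apply(params, None, x)               # pure, so safe for jax.jit / jax.grad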

Why Haiku?

There are a number of neural network libraries for JAX. Why should you choose Haiku?

Haiku has been tested by researchers at DeepMind at scale.

  • DeepMind has reproduced a number of experiments in Haiku and JAX with relative ease. These include large-scale results in image and language processing, generative models, and reinforcement learning.

Haiku is a library, not a framework.

  • Haiku is designed to make specific things simpler: managing model parameters and other model state.
  • Haiku can be expected to compose with other libraries and work well with the rest of JAX.
  • Haiku otherwise is designed to get out of your way - it does not define custom optimizers, checkpointing formats, or replication APIs.

Haiku does not reinvent the wheel.

  • Haiku builds on the programming model and APIs of Sonnet, a neural network library with near universal adoption at DeepMind. It preserves Sonnet's Module-based programming model for state management while retaining access to JAX's function transformations.
  • Haiku APIs and abstractions are as close as reasonable to Sonnet. Many users have found Sonnet to be a productive programming model in TensorFlow; Haiku enables the same experience in JAX.

Transitioning to Haiku is easy.

  • By design, transitioning from TensorFlow and Sonnet to JAX and Haiku is easy.
  • Outside of new features (e.g. hk.transform), Haiku aims to match the API of Sonnet 2. Modules, methods, argument names, defaults, and initialization schemes should match.

Haiku makes other aspects of JAX simpler.

  • Haiku offers a trivial model for working with random numbers. Within a transformed function, hk.next_rng_key() returns a unique rng key.
  • These unique keys are deterministically derived from an initial random key passed into the top-level transformed function, and are thus safe to use with JAX program transformations, as the sketch below illustrates.
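
A small, illustrative sketch of this determinism (the draw function is not part of Haiku):

import haiku as hk
import jax

def draw(shape):
  # Each call to hk.next_rng_key() yields a fresh key, derived deterministically
  # from the key passed to `init`/`apply`.
  return jax.random.normal(hk.next_rng_key(), shape)

draw_t = hk.transform(draw)
params = draw_t.init(jax.random.PRNGKey(0), (2,))  # no parameters here, but `init` still runs
a = draw_t.apply(params, jax.random.PRNGKey(1), (2,))
b = draw_t.apply(params, jax.random.PRNGKey(1), (2,))
assert (a == b).all()  # same top-level key, same sample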

Quickstart

Let's take a look at an example neural network and loss function.

import haiku as hk
import jax
import jax.numpy as jnp

def softmax_cross_entropy(logits, labels):
  one_hot = jax.nn.one_hot(labels, logits.shape[-1])
  return -jnp.sum(jax.nn.log_softmax(logits) * one_hot, axis=-1)

def loss_fn(images, labels):
  mlp = hk.Sequential([
      hk.Linear(300), jax.nn.relu,
      hk.Linear(100), jax.nn.relu,
      hk.Linear(10),
  ])
  logits = mlp(images)
  return jnp.mean(softmax_cross_entropy(logits, labels))

# There are two transforms in Haiku, hk.transform and hk.transform_with_state.
# If our network updated state during the forward pass (e.g. like the moving
# averages in hk.BatchNorm) we would need hk.transform_with_state, but for our
# simple MLP we can just use hk.transform.
loss_fn_t = hk.transform(loss_fn)

# MLP is deterministic once we have our parameters, as such we will not need to
# pass an RNG key to apply. without_apply_rng is a convenience wrapper that will
# make the rng argument to `loss_fn_t.apply` default to `None`.
loss_fn_t = hk.without_apply_rng(loss_fn_t)

hk.transform allows us to turn this function into a pair of pure functions: init and apply. All JAX transformations (e.g. jax.grad) require you to pass in a pure function for correct behaviour. Haiku makes it easy to write them.

The init function returned by hk.transform allows you to collect the initial value of any parameters in the network. Haiku does this by running your function, keeping track of any parameters requested through hk.get_parameter and returning them to you:

# Initial parameter values are typically random. In JAX you need a key in order
# to generate random numbers and so Haiku requires you to pass one in.
rng = jax.random.PRNGKey(42)

# `init` runs your function, as such we need an example input. Typically you can
# pass "dummy" inputs (e.g. ones of the same shape and dtype) since initialization
# is not usually data dependent.
images, labels = next(input_dataset)

# The result of `init` is a nested data structure of all the parameters in your
# network. You can pass this into `apply`.
params = loss_fn_t.init(rng, images, labels)

The params object is designed for you to inspect and manipulate. It is a mapping of module name to module parameters, where a module parameter is a mapping of parameter name to parameter value. For example:

{'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
            'w': ndarray(..., shape=(28, 300), dtype=float32)},
 'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
              'w': ndarray(..., shape=(300, 100), dtype=float32)},
 'linear_2': {'b': ndarray(..., shape=(10,), dtype=float32),
              'w': ndarray(..., shape=(100, 10), dtype=float32)}}
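
Because params is just a nested mapping of arrays (a JAX pytree), standard JAX tree utilities work on it directly; for example, a quick sketch for inspecting shapes and counting parameters:

# Print the shape of every parameter in the network.
print(jax.tree_util.tree_map(jnp.shape, params))

# Count the total number of trainable parameters.
num_params = sum(x.size for x in jax.tree_util.tree_leaves(params))
print(num_params)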

The apply function allows you to inject parameter values into your function. Whenever hk.get_parameter is called the value returned will come from the params you provide as input to apply:

loss = loss_fn_t.apply(params, images, labels)

Since apply is a pure function we can pass it to jax.grad (or any of JAX's other transforms):

grads = jax.grad(loss_fn_t.apply)(params, images, labels)
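
Because apply is also compatible with jax.jit, the gradient computation can be compiled end to end:

grad_fn = jax.jit(jax.grad(loss_fn_t.apply))
grads = grad_fn(params, images, labels)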

Finally, we put this all together into a simple training loop:

def sgd(param, update):
  return param - 0.01 * update

for images, labels in input_dataset:
  grads = jax.grad(loss_fn_t.apply)(params, images, labels)
  params = jax.tree_util.tree_map(sgd, params, grads)

Here we used jax.tree_util.tree_map (the replacement for the deprecated jax.tree_multimap) to apply the sgd function across all matching entries in params and grads. The result has the same structure as the previous params and can again be used with apply.
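
Our examples use optax for optimization rather than a hand-written update rule; a rough sketch of the equivalent loop with optax:

import optax

optimiser = optax.sgd(learning_rate=0.01)
opt_state = optimiser.init(params)

for images, labels in input_dataset:
  grads = jax.grad(loss_fn_t.apply)(params, images, labels)
  updates, opt_state = optimiser.update(grads, opt_state)
  params = optax.apply_updates(params, updates)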

For more, see our examples directory. The MNIST example is a good place to start.

Installation

Haiku is written in pure Python, but depends on C++ code via JAX.

Because JAX installation is different depending on your CUDA version, Haiku does not list JAX as a dependency in requirements.txt.

First, follow the JAX installation instructions to install JAX with the relevant accelerator support.

Then, install Haiku using pip:

$ pip install git+https://github.com/deepmind/dm-haiku

Our examples rely on additional libraries (e.g. bsuite). You can install the full set of additional requirements using pip:

$ pip install -r examples/requirements.txt

User manual

Writing your own modules

In Haiku, all modules are a subclass of hk.Module. You can implement any method you like (nothing is special-cased), but typically modules implement __init__ and __call__.

Let's work through implementing a linear layer:

class MyLinear(hk.Module):

  def __init__(self, output_size, name=None):
    super().__init__(name=name)
    self.output_size = output_size

  def __call__(self, x):
    j, k = x.shape[-1], self.output_size
    w_init = hk.initializers.TruncatedNormal(1. / jnp.sqrt(j))
    w = hk.get_parameter("w", shape=[j, k], dtype=x.dtype, init=w_init)
    b = hk.get_parameter("b", shape=[k], dtype=x.dtype, init=jnp.zeros)
    return jnp.dot(x, w) + b

All modules have a name. When no name argument is passed to the module, its name is inferred from the name of the Python class (for example MyLinear becomes my_linear). Modules can have named parameters that are accessed using hk.get_parameter(param_name, ...). We use this API (rather than just using object properties) so that we can convert your code into a pure function using hk.transform.

When using modules you need to define functions and transform them into a pair of pure functions using hk.transform. See our quickstart for more details about the functions returned from transform:

def forward_fn(x):
  model = MyLinear(10)
  return model(x)

# Turn `forward_fn` into an object with `init` and `apply` methods. By default,
# the `apply` will require an rng (which can be None), to be used with
# `hk.next_rng_key`.
forward = hk.transform(forward_fn)

x = jnp.ones([1, 1])

# When we run `forward.init`, Haiku will run `forward_fn(x)` and collect initial
# parameter values. Haiku requires you to pass an RNG key to `init`, since parameters
# are typically initialized randomly:
key = hk.PRNGSequence(42)
params = forward.init(next(key), x)

# When we run `forward.apply`, Haiku will run `forward_fn(x)` and inject parameter
# values from the `params` that are passed as the first argument.  Note that
# models transformed using `hk.transform(f)` must be called with an additional
# `rng` argument: `forward.apply(params, rng, x)`. Use
# `hk.without_apply_rng(hk.transform(f))` if this is undesirable.
y = forward.apply(params, None, x)
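
For reference, the params returned here are keyed by the inferred module name; with the [1, 1] input above the structure looks roughly like this (illustrative output):

print(jax.tree_util.tree_map(jnp.shape, params))
# {'my_linear': {'b': (10,), 'w': (1, 10)}}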

Working with stochastic models

Some models may require random sampling as part of the computation. For example, in variational autoencoders with the reparametrization trick, a random sample from the standard normal distribution is needed. For dropout we need a random mask to drop units from the input. The main hurdle in making this work with JAX is in management of PRNG keys.

In Haiku we provide a simple API for maintaining a PRNG key sequence associated with modules: hk.next_rng_key() (or next_rng_keys() for multiple keys):

class MyDropout(hk.Module):

  def __init__(self, rate=0.5, name=None):
    super().__init__(name=name)
    self.rate = rate

  def __call__(self, x):
    key = hk.next_rng_key()
    p = jax.random.bernoulli(key, 1.0 - self.rate, shape=x.shape)
    return x * p / (1.0 - self.rate)

forward = hk.transform(lambda x: MyDropout()(x))

key1, key2 = jax.random.split(jax.random.PRNGKey(42), 2)
params = forward.init(key1, x)
prediction = forward.apply(params, key2, x)

For a more complete look at working with stochastic models, please see our VAE example.

Note: hk.next_rng_key() is not functionally pure which means you should avoid using it alongside JAX transformations which are inside hk.transform. For more information and possible workarounds, please consult the docs on Haiku transforms and available wrappers for JAX transforms inside Haiku networks.

Working with non-trainable state

Some models may want to maintain some internal, mutable state. For example, in batch normalization a moving average of values encountered during training is maintained.

In Haiku we provide a simple API for maintaining mutable state that is associated with modules: hk.set_state and hk.get_state. When using these functions you need to transform your function using hk.transform_with_state since the signature of the returned pair of functions is different:

def forward(x, is_training):
  net = hk.nets.ResNet50(1000)
  return net(x, is_training)

forward = hk.transform_with_state(forward)

# The `init` function now returns parameters **and** state. State contains
# anything that was created using `hk.set_state`. The structure is the same as
# params (e.g. it is a per-module mapping of named values).
params, state = forward.init(rng, x, is_training=True)

# The apply function now takes both params **and** state. Additionally it will
# return updated values for state. In the resnet example this will be the
# updated values for moving averages used in the batch norm layers.
logits, state = forward.apply(params, state, rng, x, is_training=True)

If you forget to use hk.transform_with_state, don't worry: we will print a clear error pointing you to hk.transform_with_state rather than silently dropping your state.
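
As a minimal sketch of the state API itself (the call_counter function below is illustrative, not a Haiku built-in):

import haiku as hk
import jax
import jax.numpy as jnp

def call_counter(x):
  # Read the current state (created on first use) and write back an update.
  count = hk.get_state("count", shape=[], dtype=jnp.int32, init=jnp.zeros)
  hk.set_state("count", count + 1)
  return x

f = hk.transform_with_state(call_counter)

x = jnp.ones([1])
params, state = f.init(jax.random.PRNGKey(0), x)  # state contains {'~': {'count': 0}}
out, state = f.apply(params, state, None, x)      # state now contains {'~': {'count': 1}}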

Distributed training with jax.pmap

The pure functions returned from hk.transform (or hk.transform_with_state) are fully compatible with jax.pmap. For more details on SPMD programming with jax.pmap, see the JAX documentation.

One common use of jax.pmap with Haiku is for data-parallel training on many accelerators, potentially across multiple hosts. With Haiku, that might look like this:

def loss_fn(inputs, labels):
  logits = hk.nets.MLP([8, 4, 2])(inputs)
  return jnp.mean(softmax_cross_entropy(logits, labels))

loss_fn_t = hk.transform(loss_fn)
loss_fn_t = hk.without_apply_rng(loss_fn_t)

# Initialize the model on a single device.
rng = jax.random.PRNGKey(428)
sample_image, sample_label = next(input_dataset)
params = loss_fn_t.init(rng, sample_image, sample_label)

# Replicate params onto all devices.
num_devices = jax.local_device_count()
params = jax.tree_util.tree_map(lambda x: np.stack([x] * num_devices), params)

def make_superbatch():
  """Constructs a superbatch, i.e. one batch of data per device."""
  # Get N batches, then split into list-of-images and list-of-labels.
  superbatch = [next(input_dataset) for _ in range(num_devices)]
  superbatch_images, superbatch_labels = zip(*superbatch)
  # Stack the superbatches to be one array with a leading dimension, rather than
  # a python list. This is what `jax.pmap` expects as input.
  superbatch_images = np.stack(superbatch_images)
  superbatch_labels = np.stack(superbatch_labels)
  return superbatch_images, superbatch_labels

def update(params, inputs, labels, axis_name='i'):
  """Updates params based on performance on inputs and labels."""
  grads = jax.grad(loss_fn_t.apply)(params, inputs, labels)
  # Take the mean of the gradients across all data-parallel replicas.
  grads = jax.lax.pmean(grads, axis_name)
  # Update parameters using SGD or Adam or ...
  new_params = my_update_rule(params, grads)
  return new_params

# Run several training updates.
for _ in range(10):
  superbatch_images, superbatch_labels = make_superbatch()
  params = jax.pmap(update, axis_name='i')(params, superbatch_images,
                                           superbatch_labels)

For a more complete look at distributed Haiku training, take a look at our ResNet-50 on ImageNet example.

Citing Haiku

To cite this repository:

@software{haiku2020github,
  author = {Tom Hennigan and Trevor Cai and Tamara Norman and Igor Babuschkin},
  title = {{H}aiku: {S}onnet for {JAX}},
  url = {http://github.com/deepmind/dm-haiku},
  version = {0.0.3},
  year = {2020},
}

In this bibtex entry, the version number is intended to be from haiku/__init__.py, and the year corresponds to the project's open-source release.

Comments
  • Is there a good way to save/load & compress/decompress model weights?

    Hey- This is Chris. I'm using this open-source for my project.

    https://github.com/chris-chris/haiku-scalable-example

    Since I'm new to JAX and haiku, I have some questions.

    Is there a good way to save/load & compress/decompress & serialize model weights?

    • save/load model (network only or weight only)
    • compress/decompress weights
    • serialize

    I think serialization is an important issue on scalability. Can you give me some keywords or hints about this issue?

    Thanks!

    opened by chris-chris 10
  • Is there a way to share parameters between methods?

    You can write and transform multiple methods on the same module, but it doesn't seem possible to share parameters between them without manually merging the two parameter FlatMappings. It's particularly cumbersome if the shared parameters are used several submodules deep. Is there any more convenient way to accomplish something like this?

    opened by davisyoshida 7
  • Adds Identity initializer

    This PR adds the initializers that Sonnet has but are missing in Haiku. I haven't written tests for them yet since I don't know if there is interest in this PR; as soon as the Haiku team gives me the green light I will add the tests.

    cla: yes 
    opened by joaogui1 7
  • He initialization

    The default initialization for linear and convolutional modules seems to be Glorot initialization, but for the commonly used ReLU activation function He initialization is superior, while only requiring a quick change to the stddev definition. Should we implement better defaults? I know that there are many initialization schemes; I only suggest it as it wouldn't be computationally expensive and would also be only a minor code change.

    enhancement 
    opened by joaogui1 7
  • Feeding in dictionary of data?

    Hey all!

    One thing I really enjoyed about Tensorflow was the feeddict option where I could then access the data by the keys to easily access and process the data in chunks. E.g

    {
        "data_to_be_embedded": ... #Some (batch_size, N , M) matrix
        "timeseries data": ... # (batch_size, timeseries_window)
    }
    

    I suppose one option would be to define multiple models and wrap them all into a single function which applies them key-value by key-value. Is there a more "idiomatic" way of doing this in Haiku?

    opened by IanQS 6
  • Correct way to transform and init a `hk.Module` with non-default parameter?

    Hey all!

    I'm trying to run a linear regression example and I've got the following

    import jax.numpy as jnp
    from sklearn.datasets import load_boston
    import haiku as hk
    import optax
    import jax
    
    
    X, y = load_boston(return_X_y=True)
    train_X = jnp.asarray(X.tolist())
    train_y = jnp.asarray(y.tolist())
        
    class Model(hk.Module):
        def __init__(self, input_dims):
            super().__init__()
            self.input_dims = input_dims
        
        def __call__(self, X: jnp.ndarray) -> jnp.ndarray:
            l1 = hk.Linear(self.input_dims)
            return l1(X)
        
    model = hk.transform(lambda x: Model()(x))  # <-- where I would specify the model shape if at all? 
    

    So I'm running into an issue where I'm not able to specify the model shape. If I do not specify it as in the above, I get the error of

    __init__() missing 1 required positional argument: 'input_dims'

    but if I do specify the shape via

    model = hk.transform(lambda x: Model(train_X.shape[1])(x))
    

    I get Argument '<function without_state.<locals>.init_fn at 0x7f1e5c616430>' of type <class 'function'> is not a valid JAX type.


    What is the recommended way of addressing this? I'm reading through hk.transform but I'm not sure. Looking at the code examples, there are __init__ functions without default args so I know it's possible.

    opened by IanQS 6
  • Efficient Ways for Saving and Loading weights

    I'm sorry if it's not the right place as I could not find the discussions or forum page.

    I was wondering what are some of the most efficient ways to save and load models (also verify it's properly loaded into GPU)?

    1. In the docs, it's given as save_the_model using TensorFlow 2. I also understand that the weights of a haiku network are stored in a dictionary, as an example

    {'linear': {'b': ndarray(..., shape=(300,), dtype=float32),
                'w': ndarray(..., shape=(28, 300), dtype=float32)},
     'linear_1': {'b': ndarray(..., shape=(100,), dtype=float32),
                  'w': ndarray(..., shape=(1000, 100), dtype=float32)}}
    

    Does haiku have some inbuilt function to save and load models? It becomes crucial in transfer learning tasks. Thanks in advance,

    opened by VIGNESHinZONE 6
  • "NCHW" data_format in Conv not working with latest CUDA

    I'm not able to use the NCHW data format in conv layers:

    import os
    import numpy as np
    import jax
    import haiku as hk
    
    # os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
    os.environ["XLA_FLAGS"] = "--xla_gpu_cuda_data_dir=/opt/cuda"
    
    def net(x):
        model = hk.Sequential([hk.Conv2D(2, 5, padding="VALID", data_format="NCHW")])
        return model(x)
    
    key = jax.random.PRNGKey(42)
    net_transformed = hk.without_apply_rng(hk.transform(net))
    params = net_transformed.init(key, np.zeros((1, 1, 28, 28)))
    

    The snippet above works fine on the CPU but on the GPU gives tensorflow-style spew of errors below. The problem goes away if I change data_format to NHWC. I'm running pretty recent versions of nvidia driver and cuda and the same snippet seems to run on older versions (according to a few people I sent it to) so pretty sure it's related to those. My versions are:

    cuda 11.1.0-2
    nvidia driver: 455.38
    jax 0.2.5
    jaxlib 0.1.57+cuda111 
    dm-haiku 0.0.2
    

    Error:

    2020-11-17 12:18:03.717098: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
    2020-11-17 12:18:03.718623: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
    2020-11-17 12:18:03.719796: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
    2020-11-17 12:18:03.719969: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/gpu_conv_algorithm_picker.cc:772] Failed to determine best cudnn convolution algorithm: Internal: All algorithms tried for convolution %custom-call = (f32[1,20,24,24]{3,2,1,0}, u8[0]{0}) custom-call(f32[1,1,28,28]{3,2,1,0} %parameter.1, f32[5,5,1,20]{1,0,2,3} %copy.1), window={size=5x5}, dim_labels=bf01_01io->bf01, custom_call_target="__cudnn$convForward", metadata={op_type="conv_general_dilated" op_name="conv_general_dilated[ batch_group_count=1\n                      dimension_numbers=ConvDimensionNumbers(lhs_spec=(0, 1, 2, 3), rhs_spec=(3, 2, 0, 1), out_spec=(0, 1, 2, 3))\n                      feature_group_count=1\n                      lhs_dilation=(1, 1)\n                      lhs_shape=(1, 1, 28, 28)\n                      padding=((0, 0), (0, 0))\n                      precision=None\n                      rhs_dilation=(1, 1)\n                      rhs_shape=(5, 5, 1, 20)\n                      window_strides=(1, 1) ]"}, backend_config="{\"algorithm\":\"0\",\"tensor_ops_enabled\":false,\"conv_result_scale\":1,\"activation_mode\":\"0\",\"side_input_scale\":0}" failed. Falling back to default algorithm. 
    
    Convolution performance may be suboptimal.
    2020-11-17 12:18:03.800681: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_dnn.cc:349] Could not create cudnn handle: CUDNN_STATUS_INTERNAL_ERROR
    2020-11-17 12:18:03.800721: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_client.cc:1809] Execution of replica 0 failed: Unimplemented: DNN library is not found.
    Traceback (most recent call last):
      File "scratch.py", line 17, in <module>
        params = net_transformed.init(key, np.zeros((1, 1, 28, 28)))
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/transform.py", line 111, in init_fn
        params, state = f.init(*args, **kwargs)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/transform.py", line 277, in init_fn
        f(*args, **kwargs)
      File "scratch.py", line 12, in net
        return model(x)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/module.py", line 406, in wrapped
        out = f(*args, **kwargs)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/module.py", line 263, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/basic.py", line 124, in __call__
        out = layer(out, *args, **kwargs)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/module.py", line 406, in wrapped
        out = f(*args, **kwargs)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/module.py", line 263, in run_interceptors
        return bound_method(*args, **kwargs)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/haiku/_src/conv.py", line 195, in __call__
        out = lax.conv_general_dilated(inputs,
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/jax/_src/lax/lax.py", line 571, in conv_general_dilated
        return conv_general_dilated_p.bind(
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/jax/core.py", line 266, in bind
        out = top_trace.process_primitive(self, tracers, params)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/jax/core.py", line 576, in process_primitive
        return primitive.impl(*tracers, **params)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/jax/interpreters/xla.py", line 234, in apply_primitive
        return compiled_fun(*args)
      File "/home/milad/miniconda3/envs/my_project/lib/python3.8/site-packages/jax/interpreters/xla.py", line 349, in _execute_compiled_primitive
        out_bufs = compiled.execute(input_bufs)
    RuntimeError: Unimplemented: DNN library is not found.
    
    opened by mil-ad 6
  • Iterating through hk modules

    Let's say I want to iterate through all modules inside an hk model and replace all hk.Linears with my own custom Module or monkey-patch some of their properties. Does haiku currently support something along these lines?

    opened by mil-ad 6
  • Dealing with conditionally constant state

    How could I add a constant state to my haiku module? Specifically I would want something like this:

    class MyModule(hk.Module):
      def __init__(self, output_size, const, name=None):
        if const:
          self.b = hk.conts(jnp.ones(output_size))  # won't be updated when adding gradient
        else:
          self.b = jnp.zeros(output_size)  # will get updated when adding gradient
    
    opened by joaogui1 6
  • FutureWarning: jax.tree_util.tree_multimap() is deprecated

    Looks like dm-haiku is still using tree_multimap() which is now deprecated (resulting in annoying "future warning" messages with the latest jax)

    /usr/local/lib/python3.7/dist-packages/jax/_src/tree_util.py:189: FutureWarning: jax.tree_util.tree_multimap() is deprecated. Please use jax.tree_util.tree_map() instead as a drop-in replacement.
      'instead as a drop-in replacement.', FutureWarning)
    
    opened by sokrypton 5
  • Argument `init` in `get_parameter` is not optional

    Hi,

    According to the docs, the init argument of get_parameter is optional, while in reality it raises the error: ValueError: Initializer must be specified. (See line 496 in base.py.)

    Example: a case where init=None might occur is when the initialisation is done outside Haiku. For example:

    import haiku as hk
    import jax.numpy as jnp
    
    @hk.without_apply_rng
    @hk.transform
    def foo(x):
        w = hk.get_parameter("w", [1], init=None)
        return x + w
    
    # Initialise params outside haiku, without `foo.init`.
    params = {'~': {'w': jnp.array([2.], dtype=jnp.float32)}}
    
    x = jnp.array([1])
    foo.apply(params, x)
    

    Kind regards,

    Hylke

    opened by hylkedonker 0
  • How to reinitialize the hidden states of RNNs?

    I want to use initial_state in this way but get an error: AttributeError: 'Transformed' object has no attribute 'init_hidden_state'. What is the best way to do this?

    import haiku as hk
    
    class RNN(hk.Module):
      def __init__(self, hidden_size=4, name=None):
        super().__init__(name=name)
        self.rnn = hk.LSTM(hidden_size)
    
      def __call__(self, h, x):
        out, h = self.rnn(x, h)
        return h, out
    
      def init_hidden_state(self, batch_size=1):
        return self.rnn.initial_state(batch_size)
    
    model = hk.without_apply_rng(hk.transform(lambda h, x: RNN(4)(h, x)))
    h = model.init_hidden_state(1)
    
    opened by qlan3 0
  • Correct way to integrate tf2jax output with a hk.Module

    I'm looking at the tf2jax project, and the ability to take TensorFlow pretrained modules and convert them to haiku would be a really useful functionality, since there aren't a lot of available Haiku checkpoints. A typical application is something like

    import tf2jax
    import tensorflow as tf
    import jax.numpy as jnp
    jax_func, jax_params = tf2jax.convert(tf.function(tf.keras.applications.resnet50.ResNet50()), jnp.zeros((1, 224, 224, 3)))
    

    So now I have a function and parameters to do what I want, but I need to insert them into a Haiku module. How should I do this? I'm hoping for some way to eventually be able to

    class MyModule(hk.Module):
        def __call__(self, x):
            x = ResNet50Jax()(x)
            x = # some other module specific stuff
            return x
    

    that I can then proceed with hk.transform as usual. I wasn't able to find an obvious way to do this. Any thoughts?

    More broadly, is it a bad idea to rely on tf2jax for checkpoints, versus perhaps making the model directly in Haiku and manually copying over weights from PyTorch/tensorflow?

    opened by rdilip 1
  • Bump certifi from 2021.10.8 to 2022.12.7 in /docs

    Bumps certifi from 2021.10.8 to 2022.12.7.

    dependencies 
    opened by dependabot[bot] 0
  • Mypy error from `next_rng_key` type inconsistency with jax `PRNGKeyArray`

    Hi,

    It seems that my mypy (version 0.942) is complaining that Haiku's random key generated by hk.next_rng_key() is not compatible with Jax's PRNGKeyArray type. The latter are the types of the key argument in various jax.random samplers.

    Example

    import jax
    import haiku as hk
    
    def sample_phi(alpha: float):
        phi = jax.random.gamma(hk.next_rng_key(), a=alpha)
        return phi
    

    Error

    example.py:5: error: Argument 1 to "gamma" has incompatible type "ndarray"; expected "Union[Array, PRNGKeyArray]"
    

    Apart from explicitly silencing these errors in mypy, are there any other suggestions to fix these errors?

    Thanks in advance,

    Hylke

    Environment

    dm-haiku==0.0.9
    jax==0.3.25
    jaxlib==0.3.25
    mypy==0.942
    
    opened by hylkedonker 0
Releases
  • v0.0.9(Nov 16, 2022)

    What's Changed

    • Support vmap where in_axes is a list rather than a tuple in https://github.com/deepmind/dm-haiku/commit/307cf7dbda64d637ca423cacc9978f0ca19dc8a6
    • Pass pmap axis specs optionally to make_model_info in https://github.com/deepmind/dm-haiku/commit/d0ba451c96a6ac4f44fb9457e252b1d675a5416a
    • Remove use of jax_experimental_name_stack flag in https://github.com/deepmind/dm-haiku/commit/dbc0b1f2ffee9b348a3cb67460f28f9cc4667f08
    • Add param_axis argument to RMSNorm to allow setting scale param shape in https://github.com/deepmind/dm-haiku/commit/a4998a02bc4e8303f9897e5c32ded90cc38fa84f
    • Add documentation and error messages for w_init and w_init_scale to avoid confusion in https://github.com/deepmind/dm-haiku/pull/541
    • Fix hk.while_loop carrying state when reserving variable sizes of rng keys. by @copybara-service in https://github.com/deepmind/dm-haiku/pull/551
    • Add ensemble example to hk.lift documentation. by @copybara-service in https://github.com/deepmind/dm-haiku/pull/556

    Full Changelog: https://github.com/deepmind/dm-haiku/compare/v0.0.8...v0.0.9

  • v0.0.8(Sep 21, 2022)

    • Added experimental.force_name.
    • Added ability to simulate a method name in experimental.name_scope.
    • Added a config option for PRNG key block size.
    • Added unroll parameter to dynamic_unroll.
    • Remove use of deprecated jax.tree_* functions.
    • Many improvements to our examples.
    • Improve error messages in vmap.
    • Support jax_experimental_name_stack in jaxpr_info.
    • transform_and_run now supports a map on PRNG keys.
    • remat now uses the new JAX remat implementation.
    • Scale parameter is now optional in RMSNorm.
  • v0.0.7(Jul 4, 2022)

  • v0.0.6(Feb 14, 2022)

  • v0.0.5(Nov 1, 2021)

    • Added support for mixed precision training (dba1fd9) via jmp
    • Added hk.with_empty_state(..).
    • Added hk.multi_transform(..) (#137), supporting transforming multiple functions that share parameters.
    • Added hk.data_structures.is_subset(..) to test whether parameters are a subset of another.
    • Minimum Python version is now 3.7.
    • Multiple changes in preparation for a future version of Haiku changing to plain dicts.
    • hk.next_rng_keys(..) now returns a stacked array rather than a collection.
    • hk.MultiHeadAttention now supports distinct sequence lengths in query and key/value.
    • hk.LayerNorm now optionally supports faster (but less stable) variance computation.
    • hk.nets.MLP now has an output_shape property.
    • hk.nets.ResNet now supports changing strides.
    • UnexpectedTracerError inside a Haiku transform now has a more useful error message.
    • hk.{lift,custom_creator,custom_getter} are no longer experimental.
    • Haiku now supports JAX's pluggable RNGs.
    • We have made multiple improvements to our docs and error messages.

    And many other small fixes and improvements.

  • v0.0.4(Apr 12, 2021)

    Changelog:

    • (Important Fix) Fixed strides in basic block (300e6a40be3).
    • Added map, partition_n and traverse to data_structures.
    • Added "build your own Haiku" to the docs.
    • Added summarise utility to Haiku.
    • Added visualisation section to docs.
    • Added precision arg to Linear, Conv and ConvTranspose.
    • Added RMSNorm.
    • Added module_name and name to GetterContext.
    • Added hk.eval_shape.
    • Improved performance of non cross-replica BN variance.
    • Haiku branch functions are only traced once (mirroring JAX).
    • Attention logits are rescaled before the softmax now.
    • ModuleMetaclass now inherits from Protocol.
    • Removed "dot access" to FlatMapping.
    • Removed query_size from MultiHeadAttention constructor.

    And many other small fixes and improvements.

  • v0.0.3(Nov 24, 2020)

    Changelog:

    • Added hk.experimental.intercept_methods.
    • Added hk.running_init.
    • Added hk.experimental.name_scope.
    • Added optional support for state in custom_creator and custom_getter.
    • Added index groups to BatchNorm.
    • Added interactive notebooks to documentation, including basics guide.
    • Added support for batch major unrolls in static_unroll and dynamic_unroll.
    • Added hk.experimental.abstract_to_dot.
    • Added step markers in imagenet example.
    • Added hk.MultiHeadAttention.
    • Added option to remove double bias from VanillaRNN.
    • Added support for feature_group_count in ConvND.
    • Added logits config to resnet models.
    • Added various control flow primitives (fori_loop, switch, while_loop).
    • Added cross_replica_axis to VectorQuantizerEMA.
    • Added original_shape to ParamContext.
    • Added hk.SeparableDepthwiseConv2D.
    • Added support for unroll kwarg to hk.scan.
    • Added output_shape argument to ConvTranspose modules.
    • Replaced frozendict with FlatMapping, significantly reducing overheads when calling jitted computations.
    • Misc changes to ensure parameter dtype follows input dtype.
    • Multiple changes to support JAX omnistaging.
    • ExponentialMovingAverage.initialize now takes shape/dtype not value.
    • Replaced optix with optax in examples.
    • hk.Embed embeddings now created lazily.
    • Re-indexed documentation for easier navigation.
  • v0.0.2(Jul 29, 2020)

    Changelog:

    • Changed the default value of apply_rng to True in hk.transform to simplify the apply_fn signature.
    • Made ConvND, ConvNDTranspose, ResetCore and pooling modules optionally batched.
    • Added hk.GroupNorm.
    • Added hk.scan.
    • Changed hk.BatchNorm to always create state for moving averages.
    • Changed use_projection in hk.nets.ResNet to take a sequence of bools.
    • Exposed hk.net.ResNet.{BlockGroup, BlockV1, BlockV2}.
    • Added original_dtype to ParamContext to expose the original parameter dtype to custom_getters.
    • Added GAN example notebook.
  • v0.0.1(Jun 4, 2020)

    Haiku is a simple neural network library for JAX developed by some of the authors of Sonnet, a neural network library for TensorFlow.

    Changelog:

    Features:

    • Exposed hk.nets.ResNet and added hk.nets.ResNet{18,34,101,152,200}.
    • Added IdentityCore.
    • Added custom_getter API for advanced parameter manipulation.
    • Added ConvND and lifted N<=3 restriction.
    • Added tree_size and tree_bytes to easily compute parameter counts.
    • hk.remat now only threads changed values (faster compilation).
    • Added support for @dataclass to define modules.
    • Added support for splitting >1 key at a time k1, k2 = hk.next_rng_keys(2).
    • Experimental: Added profiler_name_scopes API to add Haiku names to XProf.
    • Experimental: Added optimize_rng_use to improve compilation time for models with lots of RNG keys.

    Examples:

    • Added language model example.
    • Added VQVAE example.

    Bug fixes:

    • LayerNorm now correctly handles bf16 inputs.
    • TruncatedNormal initializer now respects dtype.

    Usability:

    • Improved error messages for get_parameter, to_module and others.
    • Reimplemented core modules with "public" API (easier to read and fork).
    • Added tests that ensure all public symbols are included in documentation.
    • Added type annotations to more internal code.
  • v0.0.1-beta(Mar 26, 2020)

    Changes

    Examples

    • Added VAE example.
    • Added pruning example (https://arxiv.org/abs/1710.01878).
    • MNIST example uses 300-100-10 MLP.
    • Updated imagenet dataset to return correctly scaled examples.

    Breaking changes

    • State arg to hk.transform dropped in favor of transform_with_state.
    • Decay argument is now required in BatchNorm.

    Features

    • Added hk.maybe_next_rng_key().
    • BatchNorm and LayerNorm speed improvements.
    • Added support for partition/filter/merge params.
    • Haiku now allows running with jax_numpy_rank_promotion.

    Experimental features

    • hk.experimental.to_dot - experimental visualisation support.
    • hk.experimental.lift - experimental purification support.

    Usability

    • Improved error message when the RNG arg is not an RNG.
    • Improved documentation.
    • Improved test coverage.
  • v0.0.1-alpha(Feb 20, 2020)
