Flax is a neural network ecosystem for JAX that is designed for flexibility.

Flax: A neural network library and ecosystem for JAX designed for flexibility

Overview | Quick install | What does Flax look like? | Documentation

See our full documentation to learn everything you need to know about Flax.

Flax was originally started by engineers and researchers within the Brain Team in Google Research (in close collaboration with the JAX team), and is now developed jointly with the open source community.

Flax is being used by a growing community of hundreds of folks in various Alphabet research departments for their daily work, as well as by a growing number of open source projects.

The Flax team's mission is to serve the growing JAX neural network research ecosystem -- both within Alphabet and with the broader community, and to explore the use-cases where JAX shines. We use GitHub for almost all of our coordination and planning, as well as where we discuss upcoming design changes. We welcome feedback on any of our discussion, issue and pull request threads. We are in the process of moving some remaining internal design docs and conversation threads to GitHub discussions, issues and pull requests. We hope to increasingly engage with the needs of the broader ecosystem. Please let us know how we can help!

NOTE: The new Flax "Linen" module API is now stable and we recommend it for all new projects. The old flax.nn API will be deprecated.

Please report any feature requests, issues, questions or concerns in our discussion forum, or just let us know what you're working on!

We expect to improve Flax, but we don't anticipate significant breaking changes to the core API. We use Changelog entries and deprecation warnings when possible.

In case you want to reach us directly, we're at [email protected].

Overview

Flax is a high-performance neural network library and ecosystem for JAX that is designed for flexibility: Try new forms of training by forking an example and by modifying the training loop, not by adding features to a framework.

Flax is being developed in close collaboration with the JAX team and comes with everything you need to start your research, including:

  • Neural network API (flax.linen): Dense, Conv, {Batch|Layer|Group} Norm, Attention, Pooling, {LSTM|GRU} Cell, Dropout

  • Optimizers (flax.optim): SGD, Momentum, Adam, LARS, Adagrad, LAMB, RMSprop

  • Utilities and patterns: replicated training, serialization and checkpointing, metrics, prefetching on device

  • Educational examples that work out of the box: MNIST, LSTM seq2seq, Graph Neural Networks, Sequence Tagging

  • Fast, tuned large-scale end-to-end examples: CIFAR10, ResNet on ImageNet, Transformer LM1b

Quick install

You will need Python 3.6 or later and a working JAX installation (with or without GPU support; see the JAX installation instructions). For a CPU-only version:

> pip install --upgrade pip # To support manylinux2010 wheels.
> pip install --upgrade jax jaxlib # CPU-only

Then install Flax from PyPI:

> pip install flax

To upgrade to the latest version of Flax, you can use:

> pip install --upgrade git+https://github.com/google/flax.git
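
As a quick sanity check (not from the official docs), you can verify that both packages import and that JAX sees at least one device:

import jax
import flax

print(flax.__version__)   # installed Flax version
print(jax.devices())      # available devices (CPU-only unless a GPU/TPU jaxlib is installed)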

What does Flax look like?

We provide three examples using the Flax API: a simple multi-layer perceptron, a CNN and an auto-encoder.

To learn more about the Module abstraction, please check our docs and our broad intro to the Module abstraction, or visit our HOWTO guides page for additional concrete demonstrations of best practices.

from typing import Sequence, Tuple

import jax.numpy as jnp
import flax.linen as nn

class MLP(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, x):
    for feat in self.features[:-1]:
      x = nn.relu(nn.Dense(feat)(x))
    x = nn.Dense(self.features[-1])(x)
    return x

class CNN(nn.Module):
  @nn.compact
  def __call__(self, x):
    x = nn.Conv(features=32, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = nn.Conv(features=64, kernel_size=(3, 3))(x)
    x = nn.relu(x)
    x = nn.avg_pool(x, window_shape=(2, 2), strides=(2, 2))
    x = x.reshape((x.shape[0], -1))  # flatten
    x = nn.Dense(features=256)(x)
    x = nn.relu(x)
    x = nn.Dense(features=10)(x)
    x = nn.log_softmax(x)
    return x

class AutoEncoder(nn.Module):
  encoder_widths: Sequence[int]
  decoder_widths: Sequence[int]
  input_shape: Tuple[int, ...]

  def setup(self):
    input_dim = int(jnp.prod(jnp.array(self.input_shape)))
    self.encoder = MLP(self.encoder_widths)
    self.decoder = MLP(tuple(self.decoder_widths) + (input_dim,))

  def __call__(self, x):
    return self.decode(self.encode(x))

  def encode(self, x):
    assert x.shape[1:] == self.input_shape
    return self.encoder(jnp.reshape(x, (x.shape[0], -1)))

  def decode(self, z):
    z = self.decoder(z)
    x = nn.sigmoid(z)
    x = jnp.reshape(x, (x.shape[0],) + self.input_shape)
    return x
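
Once a module is defined, it can be initialized and applied along these lines (a minimal sketch; the widths and dummy input are arbitrary and not part of the original examples):

import jax

model = MLP(features=[16, 8, 4])
x = jnp.ones((1, 10))                          # dummy batch of one 10-dimensional input
params = model.init(jax.random.PRNGKey(0), x)  # initialize the parameter PyTree
y = model.apply(params, x)                     # run the forward pass
print(jax.tree_util.tree_map(jnp.shape, params))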

Citing Flax

To cite this repository:

@software{flax2020github,
  author = {Jonathan Heek and Anselm Levskaya and Avital Oliver and Marvin Ritter and Bertrand Rondepierre and Andreas Steiner and Marc van {Z}ee},
  title = {{F}lax: A neural network library and ecosystem for {JAX}},
  url = {http://github.com/google/flax},
  version = {0.3.0},
  year = {2020},
}

In the above bibtex entry, names are in alphabetical order, the version number is intended to be that from flax/version.py, and the year corresponds to the project's open-source release.

Note

Flax is an open source project maintained by a dedicated team in Google Research, but is not an official Google product.

Comments
  • Module tabulation

    Module tabulation

    What does this PR do?

    Adds a mechanism to tabulate a Module, similar to Haiku's tabulate or Keras's summary. This builds on the ideas/comments from #288 and aims to use rich to ultimately generate a nice table representation of the module. As previously discussed, the idea is to implement 2 steps:

    1. Generate an intermediate representation
    2. Render such representation

    The PR has the following goals:

    • [x] Create function to extract an intermediate representation for the Module (e.g. module_info).
    • [x] Create a function that takes the intermediate representation and renders it as a string (e.g. render_table).
    • [x] Add a method to Module that directly renders the module to a string (e.g. Module.tabulate).
    • [x] Add docs
    • [x] Add tests

    NOTE: the names used here are open for debate.
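
    For illustration, usage of the proposed method might look roughly like this (a sketch only, since the names above are still open for debate):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Net(nn.Module):
      @nn.compact
      def __call__(self, x):
        x = nn.Dense(8)(x)
        return nn.Dense(1)(x)

    # Render a table of submodules, output shapes and parameter counts.
    print(Net().tabulate(jax.random.PRNGKey(0), jnp.ones((1, 4))))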

    Priority: P2 - eventual 
    opened by cgarciae 32
  • FLIP: Make module instances semantically meaningful by not overriding `module.__new__`

    FLIP: Make module instances semantically meaningful by not overriding `module.__new__`

    Introduction

    Currently, while Flax modules are defined by subclassing flax.nn.Module, those modules don't behave the same way that normal Python objects behave.

    One of the large differences is that Flax Modules override __new__, meaning that module instances aren't a semantically meaningful thing in Flax at the moment. Right now, in Flax, what looks like module construction (nn.Dense(x, features=10)) actually does two things:

    1. Construct an object of type nn.Dense (using the non-documented API module.new_instance())
    2. Call the apply method on that instance and return it.

    Some upsides of the current approach are:

    1. Modules are defined as a single function, as opposed to the style of other libraries, such as Haiku, where you need to scroll up and down between __init__ and __call__ to fully understand what a module does.
    2. Calls to submodules are very concise, e.g. nn.Dense(x, features=10).

    Some downsides of the current approach are:

    1. In order to reuse a module, you must use the module.shared() abstraction which has a confusing mental model -- what does module.shared() return? A module class? A module instance? Moreover, which arguments must be passed into module.shared() in order for the shared module to be usable? (Behind the scenes shared is implemented on top of partial)
    2. You can't instantiate a module directly, outside of another module. This leads to surprising things like new nn.Model(nn.Dense.partial(features=10), params) -- why do we need to use partial to instantiate a Model? What type does the first argument to nn.Model have? Is it a module class? Module instance?
    3. In a few spots in flax/nn/base.py there is code that does "kwarg mangling". Some of this code had bugs before. It would be nice to reduce the need for kwarg mangling.
    4. In order to support multiple methods on a module, the module_method decorator turns methods that aren't apply into new Modules. This is surprising, for example how would I do the equivalent of module.call(params, *args) but to call a method foo that's not apply? That would be module.foo.call(params, *args). That's a pretty surprising mental model.
    5. Wanting to define shared parameters or submodules that work across multiple methods requires either using non-traditional patterns and/or with more complexity in Flax core (see discussion on https://github.com/google/flax/issues/161)
    6. apply was a special-cased method on modules.

    Proposal

    1. No longer override __new__ in Modules
    2. Eliminate .partial()
    3. Potentially eliminate .shared() (though we may choose to keep it as a safeguard -- see below)
    4. Split up current module's apply methods into the controlled use of Python 3.7 dataclasses (for storing module hyperparameters) and a "vanilla Python" __call__ method (or actually, any name you want) that only takes in the module input(s)
    5. This may even allow for module instances to directly refer to a read-only version of their parameters, e.g.:
    class Foo(Module):
      def __init__(self, x):
        dense = nn.Dense(features=16)
        x = dense(x)
        # `dense.params` is defined here; maybe also `dense.params.kernel` and `dense.params.bias`
    

    For example, a simple Dense layer may look like this:

    @dataclass
    class Dense(Module):
      features: int
      kernel_init: Callable = initializers.lecun_normal()
      bias_init: Callable = initializers.zeros
    
      def __call__(self, x):
        """Applies a linear transformation to the inputs along the last dimension."""
        kernel = self.param('kernel', (x.shape[-1], self.features), self.kernel_init)
        bias = self.param('bias', (self.features,), self.bias_init)
        return jnp.dot(x, kernel) + bias
    

    Then, an MLP would look like this:

    class MLP(Module):
      def __call__(self, x):
        x = nn.Dense(features=16)(x)
        x = nn.relu(x)
        x = nn.Dense(features=16)(x)
        return x
    

    I believe that this proposal keeps the conciseness of current Flax, while having the potential to significantly reduce both implementation complexity and mental model complexity. The mental model in Flax now reduces to the same one as Keras (other than the fact that parameters are immutable).

    For example, in this case re-using a module is trivial -- keep a reference to nn.Dense(features=16) and re-use that. (NOTE: We may choose to keep the safe-guarding behavior of .shared() that makes it hard to accidentally copy and paste code that accidentally re-uses modules. We can achieve that by having modules default to raising an error when __call__ is invoked a second time, unless .shared() was called on the module instance first)
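
    For instance, under this proposal parameter sharing reduces to ordinary Python object reuse; a sketch in the proposed style (not valid Flax as it existed when this FLIP was written):

    class Siamese(Module):
      def __call__(self, x1, x2):
        dense = nn.Dense(features=16)  # one module instance...
        return dense(x1), dense(x2)    # ...called twice, so its parameters are shared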

    With this proposal, there's also no need for module.partial -- you can just use functools.partial(module.__call__) or functools.partial(module). (Though this is a bit different than in current Flax because the return value of functools.partial in itself isn't a module, rather it's a function. But maybe it was always confusing to understand module.partial -- does it override kwargs for all module methods? Just apply?)

    Possible transition plan

    Given the non-trivial amount of code written using Flax, and the fact that this proposal would change every module written with Flax, we need an upgrade plan.

    I propose adding, alongside every new module in flax.nn, a function with the same name but lower-cased, that operates the same as in current Flax. These functions would be deprecated-on-arrival. E.g., alongside Dense as shown above we would also have

    def dense(x, features, kernel_init, bias_init):
      """DEPRECATED. Use the new Module API: http://link/to/upgrade/guide."""
      return Dense(features, kernel_init, bias_init)(x)
    

    Then the first part of the upgrade process is mainly search and replace "Dense" -> "dense", etc. In addition, some more manual changes will possibly be necessary for uses of .partial and .shared. Later, users can transition to the new API at a time they see fit.

    Current State

    @avital has a messy work-in-progress branch checking the viability of using dataclasses in this setting. Results so far seem cautiously promising, but more work is needed before this proposal is ready to be acted on.

    opened by avital 32
  • Set field on dataclass transform decorator

    Set field on dataclass transform decorator

    What does this PR do?

    • Set field on dataclass transform decorator so that type checkers know about the field descriptor.
    • Use the dataclass transform on the class rather than the metaclass.
    • Use typing_extensions.dataclass_transform

    Fixes # (issue)

    Checklist

    • [x] This PR fixes a minor issue (e.g.: typo or small bug) or improves the docs (you can dismiss the other checks if that's the case).
    pull ready 
    opened by NeilGirdhar 27
  • Adds extract_field method

    Adds extract_field method

    What does this PR do?

    Fixes #2609. Adds the Module.extract_field method, useful to extract submodules that are defined inside .setup and their corresponding variables.

    Sample Usage

    class MyModule(nn.Module):
      def setup(self):
        self.submodule = nn.Dense(4)
    
      def __call__(self, x):
        return self.submodule(x)
    
    module = MyModule()
    variables = module.init(jax.random.PRNGKey(0), jnp.ones((1, 2)))
    
    submodule, submodule_vars = module.extract_field('submodule', variables)
    assert submodule.features == 4
    

    Also works for nested objects by passing a lambda:

    nested_module, nested_vars = module.extract_field(
      lambda m: m.some['deeply']['nested'].submodule, variables)
    opened by cgarciae 23
  • Inconsistent network behaviour when using different batch sizes for `model.apply` on CPU

    Inconsistent network behaviour when using different batch sizes for `model.apply` on CPU

    Hey, thanks for the great work!

    I'm using BatchNorm in my network, but I have set the use_running_average parameter of the BatchNorm layers to true, which means they will not compute running means/stds from the input data passing through the network and will instead use the pre-computed statistics. Thus, the network's behaviour should not change across different batches (ideally, at least, that should be true).
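
    For reference, a minimal sketch of that kind of configuration (a simplified stand-in, not the actual WideResNet):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Net(nn.Module):
      @nn.compact
      def __call__(self, x):
        x = nn.Dense(8)(x)
        # Only reads the stored batch statistics; nothing is recomputed from the batch.
        x = nn.BatchNorm(use_running_average=True)(x)
        return x

    variables = Net().init(jax.random.PRNGKey(0), jnp.ones((4, 3)))
    y = Net().apply(variables, jnp.ones((4, 3)))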

    I've provided a simple reproducible Colab notebook that reproduces the example. The colab needs two files to run properly which are:

    • wide_resnet_jax.py: The Python file containing the shallow WideResNet module implemented using Flax. You can download it from this gist: https://gist.github.com/mohamad-amin/5334109dba81b9c26e7b4d1ded7fd9ad
    • psd_data.pkl, which can be downloaded from: https://drive.google.com/file/d/18eb93M34vaWjFyNzIq-vnOfll0T6HCjT/view?usp=sharing

    psd_data.pkl is the pickled version of a dict containing three things:

    • data: The train and test data used for training the model.
    • params: The trained parameters of the WideResNet module that we're using, such that it will achieve 1.0 train accuracy and 0.89 test accuracy.
    • labels: The labels of the datapoints in data, to double check the accuracies.

    The problem that I have is:

    ys = []
    for i in range(10):
      ys.append(apply_fn(params, X_train[i:i+1]))
    ys = jnp.stack(ys).squeeze()
    vs = apply_fn(params, X_train[:10])
    np.allclose(ys, vs)
    # Outputs False!
    

    which shows that the network's behaviour varies with the batch size. I expect this to output True, as I have fixed the parameters and the BatchNorm layers. Am I doing something wrong?

    https://colab.research.google.com/drive/1a_SheAt9RH9tPRJ1DC60yccsbYaFssDx?usp=sharing

    Priority: P2 - eventual 
    opened by mohamad-amin 22
  • Remove float32 dtype assumption

    Remove float32 dtype assumption

    Fixes #1777

    Checklist

    • [x] This change is discussed in a Github issue/discussion (please add a link).
    • [x] The documentation and docstrings adhere to the documentation guidelines.
    • [x] This change includes necessary high-coverage tests. (No quality testing = no merge!)
    Priority: P1 - soon pull ready 
    opened by NeilGirdhar 20
  • Guidelines for Examples

    Guidelines for Examples

    Following some discussion I wrote down a set of guidelines for Flax examples inside the google/flax repository. It will take some work to update the current set of examples but I believe that the examples will be even better afterwards and easier to maintain.

    Guidelines

    • Each example should have at least 2 owners. Creating examples costs time and we want to recognize this. We also would like to have a point of contact in case the example needs to be updated to a new API.
    • Every example should have a README.md which specifies at least:
      • Command for running the example with a link to http://tensorboard.dev for a successful run. Optionally one can also include some output logged on the command line.
      • Link to related references (paper for the model or a great blogpost).
    • Keep the example focused.
      • Each example should support a single model on a single dataset.
      • Keep the number of configurable hyper-parameters basic (<10) and use absl.flags. We do not want to overwhelm readers.
    • Default hyperparameter configuration should be provided as ConfigDict in configs/default.py (see the sketch after this list). An additional configs/test.py for a CPU-friendly unit test case is OK.
    • To make the code easier to test, maintain, and reuse, structure it as follows:
      • main.py contains the flags and calls a method from train.py to run the training loop. This should be the only file defining flags! This can be almost identical for all examples; please copy from linen_examples/wmt/main.py. If possible, stick with the 2 flags (--workdir (not model_dir) and --config).
      • train.py contains classes and methods for training and evaluating the model.
      • train_test.py for test cases of the training code. At a minimum the test should run a single training step but more fine grained unit tests are a bonus. You can use tfds.testing.mock_data to avoid real data from disk.
      • Additional files for more complex architectures or input_pipeline.py.
    • Only use public datasets available in TensorFlow Datasets. TensorFlow Datasets standardizes how datasets are downloaded and prepared. Synthetic data is also allowed.
    • Follow the Google Python Style Guide (no relative imports, use type annotation)
    • requirements.txt if running the example requires additional Python packages that are not installed by Flax or its dependencies.
    • Use the clu package where appropriate. This helps us to de-duplicate common functionality and make the code agnostic to the platform. We would like to see examples that work equally well on Google's internal systems as on Cloud or locally.
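
    As a sketch of the configuration convention mentioned above (the hyperparameter names are illustrative only):

    # configs/default.py
    import ml_collections

    def get_config():
      config = ml_collections.ConfigDict()
      config.learning_rate = 0.1
      config.batch_size = 128
      config.num_epochs = 10
      return config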

    Future Work

    Examples should also be covered by regression tests, see #144.

    We should also add general instructions for running on examples (on GPU and TPU) using Google Cloud.

    Priority: P2 - eventual 
    opened by Marvin182 19
  • Documentation for recurrent

    Documentation for recurrent

    I'm studying RNNs using JAX, so I'm currently investigating Flax. I think the documentation in the RNN module is incorrect or out of date.

    https://github.com/google-research/flax/blob/e7247d58e4f3460c03da5f935cb83d9c0883a97c/flax/nn/recurrent.py#L21-L23

    Results in TypeError: apply() missing 1 required positional argument: 'inputs'

    Also create builds and evaluates the model and returns a (y, model), so I feel like the design has changed and the recurrent examples should either initialise the state before calling create (but that wouldn't scan), or call create_by_shape?

    Edit: I found a test which seems to confirm that the docstring is incorrect, i.e. the code below creates an initial carry and passes it to create:

    https://github.com/google-research/flax/blob/e7247d58e4f3460c03da5f935cb83d9c0883a97c/tests/nn_test.py#L461-L468

    2nd Edit:

    Also, I'm slightly confused by LSTMCell.initialize_carry() - it requires a batch_dim, and returns an initialised (zero) state for each batch. I might be missing something but this seems to preclude using lax.scan() to process each batch sequentially using the state from the previous batch as the initial state for the next batch (or some other state estimator which is specifically what I'm attempting). For example I have 365 "trajectories" (timeseries) each consisting of 24 samples and 5 features. So the state should be size 5 and I want to scan each trajectory from some initial state I provide (the intent is to use another net to estimate the state).

    opened by david-waterworth 18
  • Add functionality for capturing intermediates

    Add functionality for capturing intermediates

    • Module.sow tracks values in a collection if it's mutable and acts as a no-op otherwise
    • Module.apply(capture_intermediates=...) can be used to track return values of submodules (by default from the __call__ method), but a custom filter can be passed if necessary.

    See https://flax.readthedocs.io/en/latest/howtos/extracting_intermediates.html for how this is used.
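
    A minimal sketch of the sow-based path (module and variable names here are illustrative):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    class Model(nn.Module):
      @nn.compact
      def __call__(self, x):
        h = nn.Dense(4)(x)
        self.sow('intermediates', 'hidden', h)  # no-op unless 'intermediates' is mutable
        return nn.Dense(1)(h)

    x = jnp.ones((1, 3))
    variables = Model().init(jax.random.PRNGKey(0), x)
    y, state = Model().apply(variables, x, mutable=['intermediates'])
    # state['intermediates']['hidden'] is a tuple holding the sown value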

    cla: yes Priority: P1 - soon pull ready 
    opened by jheek 17
  • 4x slowdown in evaluation of RBM (Flax.linen vs jax.experimental.stax)

    4x slowdown in evaluation of RBM (Flax.linen vs jax.experimental.stax)

    Problem you have encountered:

    I compare a simple implementation of an RBM with Flax against a similar implementation with jax.experimental.stax. See this gist notebook.

    The two produce the same jaxpr code when traced, so I would expect comparable performance (minus dispatch cost and time taken to flatten/unflatten the inputs and outputs), but that is not the case, and Flax has a 4x disadvantage.

    Essentially, the two implementations are

    stax.serial(stax.Dense(alpha * L), stax.Sigmoid, SumLayer)
    

    and

    class FlaxRBM(nn.Module):
        dtype: Any = np.float32
        alpha: int = 1
        use_bias: bool = True
    
        @nn.compact
        def __call__(self, x):
            x = nn.Dense(
                name="Dense",
                features=self.alpha * x.shape[-1],
                dtype=self.dtype,
                use_bias=self.use_bias,
            )(x)
            x = nn.activation.sigmoid(x)
            return jnp.sum(x, axis=-1)
    

    What I observe is peculiar:

    # alpha=1
    # Input shape (1,1) 
    
    # jax
    63.3 µs ± 247 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    
    # flax
    252 µs ± 7.65 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    

    which would suggest that flax has 4x the dispatch cost of jax (weird... but ok).

    Still, if I increase the size:

    # alpha=3
    # Input shape (32,20) 
    
    # jax
    69.5 µs ± 4.92 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    
    # flax
    280 µs ± 31.2 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    

    We finally overcome the dispatch cost, but flax runtime increases too?

    # alpha=6
    # Input shape (128,30) 
    
    # jax
    116 µs ± 8.51 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    
    # flax
    407 µs ± 83.7 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    

    ???

    I also checked that the two produced the same jaxpr, which is indeed the case

    jax.make_jaxpr(j_ma.apply)(j_w, x)
    { lambda  ; a b c.
      let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                           precision=None ] c a
          e = broadcast_in_dim[ broadcast_dimensions=(1,)
                                shape=(1, 180) ] b
          f = add d e
          g = sign f
          h = mul f g
          i = mul h -2.0
          j = exp i
          k = add j 1.0
          l = log k
          m = add h l
          n = log 2.0
          o = sub m n
          p = reduce_sum[ axes=(1,) ] o
      in (p,) }
    
    jax.make_jaxpr(f_ma.apply)(f_w, x)
    { lambda  ; a b c.
      let d = dot_general[ dimension_numbers=(((1,), (0,)), ((), ()))
                           precision=None ] c b
          e = broadcast_in_dim[ broadcast_dimensions=(1,)
                                shape=(1, 180) ] a
          f = add d e
          g = sign f
          h = mul f g
          i = mul h -2.0
          j = exp i
          k = add j 1.0
          l = log k
          m = add h l
          n = log 2.0
          o = sub m n
          p = reduce_sum[ axes=(1,) ] o
      in (p,) }
    

    I have jax==v0.2.8 and flax==v0.3.0

    Status: blocked Priority: P2 - eventual 
    opened by PhilipVinc 17
  • Run unit tests for each affected example in a howto

    Run unit tests for each affected example in a howto

    From an offline discussion with @avital:

    HOWTOs can fail for two reasons -- the diff can be stale and no longer apply without conflicts, or it can apply and unit tests won't pass. Ideally, both of these should be tested on every commit in CI.

    It looks like the first is already tested for during the apply-howto-branches.yml workflow. This line checks which files were touched by the howto and runs pytest on those directories.

    It seems that the changes in this PR very narrowly resolve the issue @avital raised; the check assumes:

    1. Files that are touched are accompanied by tests in the same directory (examples fall into this category)
    2. The only changes a howto can make are testable via the tests in the same directory (example tests should fall into this category)

    Let me know if/how I misjudged the scope of the issue and I'll be happy to address!

    cla: yes 
    opened by danielsuo 17
  • Rename ConvLSTM

    Rename ConvLSTM

    All cell type names end with the Cell suffix except ConvLSTM. Since we will add full layers like RNN and CudnnLSTM in the near future, this naming disparity might get confusing, so we should fix it. @marcvanzee did some digging into some internal code and it's not that many cases, so the fix should be easy.

    Priority: P1 - soon 
    opened by cgarciae 0
  • linen.scan/remat_scan create params with shape containing 0 when scan length=1

    linen.scan/remat_scan create params with shape containing 0 when scan length=1

    Hello!

    I am not sure if this is done out of necessity or just by chance, but nn.scan and remat_scan output params with leading dim shape = 0 when scan length = 1. I noticed but didn't care until I ran into problems later on, as it breaks the convention of counting params using prod(shape), which is used fairly often in deep learning code, including distributed shampoo.

    for example:

    nn.remat_scan(
        ResBlockV2,
        lengths=(1,),
    )()(x)
    

    can output shape (0, 2, 3, 4, 4), so the workaround for now would be

    if self.n_layers > 1:
        return nn.remat_scan(
            ResBlockV2,
            lengths=(self.n_layers,),
        )()(x)
    else:
        return ResBlockV2()(x)
    

    Is it possible for a single-loop scan to output leading dim shape = 1, or would there be negative consequences that I don't know about? If it's possible, I think it'd let code be concise without the workaround above while minimizing surprises from a param shape containing a 0.

    System information

    • Flax, jax, jaxlib versions: 0.6.2 flax, 0.3.25 jax/jaxlib
    • Python version: 3.9
    Priority: P3 - no schedule 
    opened by evanatyourservice 1
  • Flax docs restructuring: introduce Developer notes (ex-Advanced), move Contributing and Philosophy to main ToC, and other changes

    Flax docs restructuring: introduce Developer notes (ex-Advanced), move Contributing and Philosophy to main ToC, and other changes

    Partially addresses https://github.com/google/flax/issues/2627 by @marcvanzee

    1. Move 6 docs out of Advanced Topics:
    • Move CONTRIBUTING.md to the main ToC --> Make the Contributing guide visible to (new) users/future contributors.
    • Move The Flax philosophy to the main ToC (@marcvanzee maybe we can move it to Developer notes - see below) --> Make the Philosophy visible.
    • Move Dealing with Flax Module arguments to Guides.
    • Move Convert PyTorch to Flax to Guides.
    • Move Upgrading my codebase to Optax to Guides.
    • Move Upgrading my codebase to Linen to Guides.
    2. Rename Advanced Topics to Developer notes.
    3. Move The Sharp Bits to Developer notes.
    4. Minor updates to/linting of various docs, including:
    • CONTRIBUTING.md
    • The INDEX pages
    • The Flax philosophy
    • Convert PyTorch to Flax
    Priority: P2 - eventual 
    opened by 8bitmp3 5
  • Improve pjit guide

    Improve pjit guide

    The pjit guide #2730 is out, and just to summarize a few items to visit in the future:

    • Change the pip3 install line when v0.6.4 is out.
    • Visit new JAX APIs like jax.sharding, device_put and jit after the APIs are more finalized.
    • Try out the use cases of pjit or jit without input/output sharding specifications.
    Priority: P2 - eventual 
    opened by IvyZX 0
  • Initialization of Submodules Lifted with `flax.nn.scan`

    Initialization of Submodules Lifted with `flax.nn.scan`

    One more issue 😄 . Promise this is the last one. There are a lot of questions about flax.nn.scan, and RTD and existing GitHub issues do not solve them.

    With a very deep model, compilation times become insane and it takes about 1 hour to compile the model for the Nvidia runtime. So, I decided to prevent loop unrolling with jax.lax.scan and its lifted counterpart flax.nn.scan. However, I faced multiple issues. An incomplete list follows.

    1. There is no clear way to initialize scanned submodules. I came up with a solution: pass everything as args and kwargs to the __call__ of the submodule (in this case MLP).
    2. There is no keyword argument of flax.nn.scan as RTD says.
    3. Func flax.nn.scan always returns (carry, args) even if there is only carry and no args.
    4. RTD says that target should be either a type of nn.Module or a function which accepts nn.Module (type?) as its first positional argument.
    5. If one specifies names of modules in MLP, then an exception is thrown. It is a bit strange because all parameter trees are merged into a single parameter tree.

    import flax.linen as nn
    import jax
    import jax.numpy as jnp
    
    
    def initializer(val):
        def init(key, shape, dtype):
            return jnp.full(shape, val, dtype)
    
        return init
    
    
    class MLP(nn.Module):
    
        @nn.compact
        def __call__(self, xs, var):
            h = nn.Dense(features=2, kernel_init=initializer(var))(xs)
            h = nn.relu(h)
            h = nn.Dense(features=2, kernel_init=initializer(var))(xs)
            return xs + h, None
    
    
    class Transformer(nn.Module):
    
        length: int = 3
    
        def setup(self):
            def fn(self, *args, **kwargs):
                return MLP(self, *args, **kwargs)
    
            # FAIL: Function instead of derived type from nn.Module does not work.
            #
            #   ScanMLP = nn.scan(target=fn, ...)
            #
            #   jax._src.traceback_util.UnfilteredStackTrace: TypeError:
            #   Transformer.setup.<locals>.fn() missing 1 required positional
            #   argument: 'self'
    
            # OK: No problems.
            ScanMLP = nn.scan(target=fn,
                              variable_axes={'params': 0},
                              variable_broadcast=False,
                              split_rngs={'params': True},
                              length=self.length)
    
            self.vars = jnp.arange(self.length)  # e.g. [0, 1, 2]
            self.mlp = ScanMLP()  # FAIL: ScanMLP(self.vars)
    
        @nn.compact  # OK: This decorator does nothing. Why?
        def __call__(self, xs):
            carry, out = self.mlp(xs, self.vars)  # OK: Axis 0 (implicitely).
            assert out is None
            return carry
    
    
    model = Transformer(length=1250)
    ys, state = jax.jit(model.init_with_output)(jax.random.PRNGKey(42),
                                                jnp.ones((3, 2)))
    kernel = state['params']['mlp']['Dense_0']['kernel']
    assert (kernel[0, ...] == jnp.zeros((2, 2))).all()
    assert (kernel[1, ...] == jnp.ones((2, 2))).all()
    

    In this experiments flax v0.6.3 and jax v0.4.1 are used.

    opened by daskol 0
Releases (v0.6.3)
  • v0.6.3(Dec 9, 2022)

    What's Changed

    • Add gfile api shim to remove tensorflow dependency for basic IO. by @chiamp in https://github.com/google/flax/pull/2586
    • Remove Mypy type errors by @zaxtax in https://github.com/google/flax/pull/2594
    • Attempt to fix pytype issue. by @levskaya in https://github.com/google/flax/pull/2628
    • example/imagenet: use absolute path to locate the Flax root dir by @yhtang in https://github.com/google/flax/pull/2630
    • Move Flax - The Sharp Bits ToC location by @8bitmp3 in https://github.com/google/flax/pull/2633
    • Speed up Github Actions CI by @levskaya in https://github.com/google/flax/pull/2635
    • Update requirements by @marcvanzee in https://github.com/google/flax/pull/2652
    • Switch to Orbax for Flax single-checkpoint support under the hood. by @copybara-service in https://github.com/google/flax/pull/2637
    • Added path discrepancy details in serialization errors by @chiamp in https://github.com/google/flax/pull/2632
    • Fix get_type_hints again by @cgarciae in https://github.com/google/flax/pull/2654
    • BatchNorm guide by @cgarciae in https://github.com/google/flax/pull/2536
    • Generalize pool to handle multiple batch dimensions by @cgarciae in https://github.com/google/flax/pull/2591
    • Improve transfer learning guide by @cgarciae in https://github.com/google/flax/pull/2595
    • update docstrings and error messages in traverse_util.py. by @copybara-service in https://github.com/google/flax/pull/2666
    • Use a different rng key for each batch element in DenseGeneral init by @j-towns in https://github.com/google/flax/pull/2665
    • Added check for Mac M1 chip, when conducting serialization test. by @chiamp in https://github.com/google/flax/pull/2657
    • Added explicit warning if flax.io is using default Python I/O operations by @chiamp in https://github.com/google/flax/pull/2625
    • Check file or directory before removing checkpoints. by @IvyZX in https://github.com/google/flax/pull/2676
    • updated extracting_intermediates in flax docs by @chiamp in https://github.com/google/flax/pull/2685
    • updated model_surgery in flax_docs by @chiamp in https://github.com/google/flax/pull/2687
    • updated getting_started in flax docs by @chiamp in https://github.com/google/flax/pull/2684
    • Updated flax docs. by @chiamp in https://github.com/google/flax/pull/2667
    • updated flax_basics in flax docs by @chiamp in https://github.com/google/flax/pull/2686
    • Update python version support by @cgarciae in https://github.com/google/flax/pull/2682
    • add mypy.ini placeholder. by @copybara-service in https://github.com/google/flax/pull/2693
    • Release 0.6.3 by @IvyZX in https://github.com/google/flax/pull/2705

    New Contributors

    • @yhtang made their first contribution in https://github.com/google/flax/pull/2630

    Full Changelog: https://github.com/google/flax/compare/v0.6.2...v0.6.3

    Source code(tar.gz)
    Source code(zip)
  • v0.6.2(Nov 15, 2022)

    What's Changed

    • Refactor out dataclass transform that allows parent and name to be moved to the end of the argument list into a more general "kw_only_dataclasses" module. by @copybara-service in https://github.com/google/flax/pull/2468
    • Don't create reference cycles among Modules. by @levskaya in https://github.com/google/flax/pull/2499
    • adds perturb to docs by @cgarciae in https://github.com/google/flax/pull/2511
    • Add pre-commit hook to remove trailing white spaces by @cgarciae in https://github.com/google/flax/pull/2513
    • Add extracting gradients section to capture intermediates guide by @cgarciae in https://github.com/google/flax/pull/2515
    • Make is_fully_replicated and is_fully_addressble a property rather than a method. by @copybara-service in https://github.com/google/flax/pull/2516
    • Adds CallSetupUnboundModuleError by @cgarciae in https://github.com/google/flax/pull/2496
    • Adding documentation to Dropout around rng use by @zaxtax in https://github.com/google/flax/pull/2492
    • no-op when double wrapping with struct.dataclass by @cgarciae in https://github.com/google/flax/pull/2505
    • Add epub and pdf RDT formats by @cgarciae in https://github.com/google/flax/pull/2517
    • Update landing page example by @cgarciae in https://github.com/google/flax/pull/2366
    • Use gfile.remove for files because it doesn't work on GCS files. by @IvyZX in https://github.com/google/flax/pull/2518
    • Add save_checkpoint_multiprocess to api reference. by @IvyZX in https://github.com/google/flax/pull/2519
    • Add Tensorstore back to required dependencies. by @IvyZX in https://github.com/google/flax/pull/2520
    • Delete flax.optim.rst by @ppwwyyxx in https://github.com/google/flax/pull/2522
    • Added an error for when we call init, init_with_output and apply on Module class. by @chiamp in https://github.com/google/flax/pull/2529
    • Adding link to Bayesian Inference example that uses Flax by @zaxtax in https://github.com/google/flax/pull/2521
    • Added IncorrectPostInitOverrideError to capture incorrect post init overrides. by @copybara-service in https://github.com/google/flax/pull/2535
    • Fix to use is_initializing for init-detection. by @yotarok in https://github.com/google/flax/pull/2486
    • add rng_collection argument to Dropout by @cgarciae in https://github.com/google/flax/pull/2540
    • Cancel tests if other jobs fail by @cgarciae in https://github.com/google/flax/pull/2507
    • Update Guides link in Flax README by @8bitmp3 in https://github.com/google/flax/pull/2544
    • Pin jupytext version in requirements.txt by @IvyZX in https://github.com/google/flax/pull/2545
    • Fix flax.linen.stochastic.Dropout by @dslisleedh in https://github.com/google/flax/pull/2510
    • Transfer Learning Guide by @cgarciae in https://github.com/google/flax/pull/2394
    • Update Flax Contributing.md by @8bitmp3 in https://github.com/google/flax/pull/2546
    • cap dynamic scale to float32 max by @jheek in https://github.com/google/flax/pull/2553
    • Remove optional import of jax.experimental.array for older JAX versions. by @copybara-service in https://github.com/google/flax/pull/2552
    • Update examples.rst with denoising-diffusion-flax by @yiyixuxu in https://github.com/google/flax/pull/2487
    • return None if no _parent_ref is set by @cgarciae in https://github.com/google/flax/pull/2548
    • Add a documentation page on checkpointing by @IvyZX in https://github.com/google/flax/pull/2530
    • Lint Flax Contributing guide by @8bitmp3 in https://github.com/google/flax/pull/2560
    • Remove extra metadata in Checkpointing guide by @8bitmp3 in https://github.com/google/flax/pull/2559
    • Update getting_started.ipynb and getting_started.md by @chiamp in https://github.com/google/flax/pull/2563
    • Remove unused svn dependency by @8bitmp3 in https://github.com/google/flax/pull/2574
    • Fix pytype check in checkpoints.py by @IvyZX in https://github.com/google/flax/pull/2592
    • Add new 🔪 Flax - The Sharp Bits 🔪 Dropout and randomness by @8bitmp3 in https://github.com/google/flax/pull/2593
    • Fixes launch_gce.sh with imagenet example. by @andsteing in https://github.com/google/flax/pull/2598
    • Added test to check for Variable warning. by @chiamp in https://github.com/google/flax/pull/2610
    • Release version 0.6.2 by @IvyZX in https://github.com/google/flax/pull/2613

    New Contributors

    • @zaxtax made their first contribution in https://github.com/google/flax/pull/2492
    • @ppwwyyxx made their first contribution in https://github.com/google/flax/pull/2522
    • @chiamp made their first contribution in https://github.com/google/flax/pull/2529
    • @yotarok made their first contribution in https://github.com/google/flax/pull/2486
    • @yiyixuxu made their first contribution in https://github.com/google/flax/pull/2487

    Full Changelog: https://github.com/google/flax/compare/v0.6.1...v0.6.2

    Source code(tar.gz)
    Source code(zip)
  • v0.6.1(Oct 4, 2022)

    What's Changed

    • Updates examples/{imagenet,wmt} requirements. by @andsteing in https://github.com/google/flax/pull/2405
    • Bump rich dependency version by @yklcs in https://github.com/google/flax/pull/2407
    • Adds axis_name and axis_index_groups to LayerNorm and GroupNorm. by @copybara-service in https://github.com/google/flax/pull/2402
    • Plumb spmd_axis_name through transforms.vmap through to JAX vmap by @copybara-service in https://github.com/google/flax/pull/2398
    • Support multiple inputs in flax lifted vjp/custom_vjp by @copybara-service in https://github.com/google/flax/pull/2399
    • Explicit reexport initializers from jax by @lkhphuc in https://github.com/google/flax/pull/2409
    • Improve tabulate by @cgarciae in https://github.com/google/flax/pull/2316
    • Add path_aware_map function by @cgarciae in https://github.com/google/flax/pull/2371
    • PIL does not accept DeviceArray, so needed to use numpy. by @villeh1 in https://github.com/google/flax/pull/2427
    • Move examples to RTD by @cgarciae in https://github.com/google/flax/pull/2367
    • Simplify dynamic context by @cgarciae in https://github.com/google/flax/pull/2388
    • Remove pytype generic workaround by @jheek in https://github.com/google/flax/pull/2446
    • Add static_argnums to nn.checkpoint by @cgarciae in https://github.com/google/flax/pull/2457
    • ignore tf deprecation warning. by @copybara-service in https://github.com/google/flax/pull/2466
    • Fix Managing Parameters and State docs by @cgarciae in https://github.com/google/flax/pull/2473
    • Use gfile.listdir instead of gfile.glob in checkpointing by @IvyZX in https://github.com/google/flax/pull/2470
    • Create test matrix to speedup tests by @cgarciae in https://github.com/google/flax/pull/2458
    • Fix Conv docstrings by @cgarciae in https://github.com/google/flax/pull/2425
    • Use proper scikit-learn dependency by @cgarciae in https://github.com/google/flax/pull/2465
    • Improve attribute error msg for unbounded modules by @cgarciae in https://github.com/google/flax/pull/2440
    • Adding "count_include_pad" argument to flax.linen.pooling.avg_pool by @dslisleedh in https://github.com/google/flax/pull/2451
    • Add perturb() to allow capturing intermediate gradients by @IvyZX in https://github.com/google/flax/pull/2476
    • fix DynamicScale docstring by @cgarciae in https://github.com/google/flax/pull/2491
    • test against python 3.8 and 3.9 by @cgarciae in https://github.com/google/flax/pull/2490
    • Update version to 0.6.1 by @cgarciae in https://github.com/google/flax/pull/2494
    • Adoption cache should use WeakValueDictionary. by @levskaya in https://github.com/google/flax/pull/2495
    • FLIP: General metadata by @jheek in https://github.com/google/flax/pull/2435

    New Contributors

    • @yklcs made their first contribution in https://github.com/google/flax/pull/2407
    • @villeh1 made their first contribution in https://github.com/google/flax/pull/2427
    • @dslisleedh made their first contribution in https://github.com/google/flax/pull/2451

    Full Changelog: https://github.com/google/flax/compare/v0.6.0...v0.6.1

    Source code(tar.gz)
    Source code(zip)
  • v0.6.0(Aug 17, 2022)

    What's Changed

    • Add on_commit_callback to put the responsibility of renaming the directories on the users of the serialization library. This will also fix the GCS atomic rename issue where the users can write a success file when the commit is successful and check the existence of that file before deserialization. by @copybara-service in https://github.com/google/flax/pull/2328
    • RDT Redesign by @cgarciae in https://github.com/google/flax/pull/2177
    • Fix stale URLs to read the docs site. by @levskaya in https://github.com/google/flax/pull/2338
    • Replace all jax.tree_* calls with jax.tree_util.tree_* by @levskaya in https://github.com/google/flax/pull/2325
    • Further fix the singular-leaf checkpointing, and add tests. by @copybara-service in https://github.com/google/flax/pull/2336
    • Forward unroll argument in scan_with_axes by @sanchit-gandhi in https://github.com/google/flax/pull/2339
    • Document repo analytics by @cgarciae in https://github.com/google/flax/pull/2317
    • Improve flax basics by @cgarciae in https://github.com/google/flax/pull/2291
    • Split build into multiple jobs by @cgarciae in https://github.com/google/flax/pull/2277
    • Fix type annotation for step in training.checkpoints by @Chuxiaof in https://github.com/google/flax/pull/2343
    • Add test for writing and restoring empty checkpoints. by @copybara-service in https://github.com/google/flax/pull/2345
    • Allow all processes to checkpoint when not using GDA by @copybara-service in https://github.com/google/flax/pull/2350
    • Make importing tensorstore optional and move related type hints to comments. by @copybara-service in https://github.com/google/flax/pull/2348
    • Use jax.named_scope for name stack rather than named_call. by @copybara-service in https://github.com/google/flax/pull/2349
    • Internal change by @copybara-service in https://github.com/google/flax/pull/2356
    • Fix sphinx CI errors by @cgarciae in https://github.com/google/flax/pull/2361
    • Forward path to rewound Scope by @jheek in https://github.com/google/flax/pull/2360
    • Make link a link on the getting started by @Davidnet in https://github.com/google/flax/pull/2340
    • Fix colab & github links by @cgarciae in https://github.com/google/flax/pull/2363
    • Fix ConvTranspose with circular padding by @cgarciae in https://github.com/google/flax/pull/2364
    • Add some docstrings to the old flax.training.common_utils module. by @levskaya in https://github.com/google/flax/pull/2373
    • Correct state variable name by @Jeevesh8 in https://github.com/google/flax/pull/2369
    • Allow linen's Conv layer to operate on arbitrary-rank inputs. by @copybara-service in https://github.com/google/flax/pull/2308
    • Copies dynamic_scale.py from optim/ to training/. by @copybara-service in https://github.com/google/flax/pull/2375
    • Add option auto_flush to flax.metrics.tensorboard.SummaryWriter by @copybara-service in https://github.com/google/flax/pull/2376
    • updated supported transforms in lifting docs. by @levskaya in https://github.com/google/flax/pull/2374
    • Test docstrings with autodoc on CI by @cgarciae in https://github.com/google/flax/pull/2372
    • skip remat test that fails with autodiff by @mattjj in https://github.com/google/flax/pull/2389
    • Removes flax.optim.dynamic_scale. by @copybara-service in https://github.com/google/flax/pull/2314
    • Plumb spmd_axis_name from vmap_with_axes through to JAX vmap by @copybara-service in https://github.com/google/flax/pull/2390
    • fixed math expressions by @banda-larga in https://github.com/google/flax/pull/2392

    New Contributors

    • @sanchit-gandhi made their first contribution in https://github.com/google/flax/pull/2339
    • @Chuxiaof made their first contribution in https://github.com/google/flax/pull/2343
    • @Davidnet made their first contribution in https://github.com/google/flax/pull/2340
    • @Jeevesh8 made their first contribution in https://github.com/google/flax/pull/2369
    • @banda-larga made their first contribution in https://github.com/google/flax/pull/2392

    Full Changelog: https://github.com/google/flax/compare/v0.5.3...v0.6.0

    Source code(tar.gz)
    Source code(zip)
  • v0.5.3(Jul 25, 2022)

    What's Changed

    • Adds .pre-commit-config.yaml by @copybara-service in https://github.com/google/flax/pull/2212
    • Fix missing passthrough of nn.scan unroll arg by @jheek in https://github.com/google/flax/pull/2213
    • Test Notebooks on CI by @cgarciae in https://github.com/google/flax/pull/2166
    • Bump numpy from 1.21.4 to 1.22.0 in examples by @marcvanzee in https://github.com/google/flax/pull/2228
    • Add nn.switch by @cgarciae in https://github.com/google/flax/pull/2205
    • Fix notebooks by @cgarciae in https://github.com/google/flax/pull/2231
    • Add launch section with colab button by @cgarciae in https://github.com/google/flax/pull/2235
    • Enabling the dollarmath extension of MyST to render correctly math expresions by @WaterKnight1998 in https://github.com/google/flax/pull/2238
    • Update codediff to use sphinx-design tabs by @cgarciae in https://github.com/google/flax/pull/2204
    • Fix tests by @cgarciae in https://github.com/google/flax/pull/2253
    • Add single-host async save to save_checkpoint. by @IvyZX in https://github.com/google/flax/pull/2233
    • Add a method for detecting the use of "init" functions. by @levskaya in https://github.com/google/flax/pull/2234
    • Small fix in MNIST example by @marcvanzee in https://github.com/google/flax/pull/2258
    • Fix typos in the doc of flax.linen.Module.bind by @nalzok in https://github.com/google/flax/pull/2269
    • Add colab button to flax_basics by @cgarciae in https://github.com/google/flax/pull/2276
    • Fix type annotations by @cgarciae in https://github.com/google/flax/pull/2281
    • Exclude pseudo-fields of dataclass by @YouJiacheng in https://github.com/google/flax/pull/2199
    • Fix variable aliasing in put_variable by @jheek in https://github.com/google/flax/pull/2296
    • Update reference to tree_map to avoid deprecation warning. by @copybara-service in https://github.com/google/flax/pull/2298
    • Fix nondeterministic bug arising from sharing logic during module adoption. by @copybara-service in https://github.com/google/flax/pull/2302
    • fix ppo example typo by @fuyw in https://github.com/google/flax/pull/2306
    • Forward axis_size tot jax.vmap by @jheek in https://github.com/google/flax/pull/2310
    • cleanup: replace deprecated jax.tree_map with jax.tree_util.tree_map by @copybara-service in https://github.com/google/flax/pull/2311
    • Add GlobalDeviceArray/multihost checkpoint support to Flax. by @copybara-service in https://github.com/google/flax/pull/2287
    • 0.5.3 update version & changelog by @IvyZX in https://github.com/google/flax/pull/2330
    • Replace use of id() with global counter-based id. by @levskaya in https://github.com/google/flax/pull/2313

    New Contributors

    • @WaterKnight1998 made their first contribution in https://github.com/google/flax/pull/2238
    • @nalzok made their first contribution in https://github.com/google/flax/pull/2269
    • @YouJiacheng made their first contribution in https://github.com/google/flax/pull/2199
    • @fuyw made their first contribution in https://github.com/google/flax/pull/2306

    Full Changelog: https://github.com/google/flax/compare/v0.5.2...v0.5.3

    Source code(tar.gz)
    Source code(zip)
  • v0.5.2(Jun 21, 2022)

    What's Changed

    • Flax Basics docs: Add missing @jax.jit to mse by @rsokl in https://github.com/google/flax/pull/2181
    • add missing colon in example code by @PWhiddy in https://github.com/google/flax/pull/2188
    • New-sphinx-theme by @cgarciae in https://github.com/google/flax/pull/2171
    • Add missing PyYAML dependency by @cgarciae in https://github.com/google/flax/pull/2193
    • Improve module docs by @cgarciae in https://github.com/google/flax/pull/2167
    • Changed optimizer to optax by @berndbohnet in https://github.com/google/flax/pull/1916
    • Show repository button by @PhilipVinc in https://github.com/google/flax/pull/2206
    • Updates filterwarning in pytest.ini by @marcvanzee in https://github.com/google/flax/pull/2209
    • v0.5.2 by @cgarciae in https://github.com/google/flax/pull/2203

    New Contributors

    • @rsokl made their first contribution in https://github.com/google/flax/pull/2181
    • @PWhiddy made their first contribution in https://github.com/google/flax/pull/2188
    • @berndbohnet made their first contribution in https://github.com/google/flax/pull/1916

    Full Changelog: https://github.com/google/flax/compare/v0.5.1...v0.5.2

    Source code(tar.gz)
    Source code(zip)
  • v0.5.1(Jun 10, 2022)

    What's Changed

    • Adds flax import to summary.py by @marcvanzee in https://github.com/google/flax/pull/2138
    • Add options for fallback behavior. by @copybara-service in https://github.com/google/flax/pull/2130
    • Upgrade to modern python idioms using pyupgrade. by @levskaya in https://github.com/google/flax/pull/2132
    • Update download_dataset_metadata.sh by @mattiasmar in https://github.com/google/flax/pull/1801
    • Mark correct minimum jax version requirement by @PhilipVinc in https://github.com/google/flax/pull/2136
    • Edited contributing.md by @IvyZX in https://github.com/google/flax/pull/2151
    • Bump tensorflow from 2.8.0 to 2.8.1 in /examples/imagenet by @dependabot in https://github.com/google/flax/pull/2143
    • Bump tensorflow from 2.8.0 to 2.8.1 in /examples/wmt by @dependabot in https://github.com/google/flax/pull/2142
    • Add typehint to Module.scope by @cgarciae in https://github.com/google/flax/pull/2106
    • Correcting Mistakes In Flip Docs by @saiteja13427 in https://github.com/google/flax/pull/2140
    • Add CAUSAL padding for 1D convolution. by @copybara-service in https://github.com/google/flax/pull/2141
    • Calculate cumulative number or issues and prs by @cgarciae in https://github.com/google/flax/pull/2154
    • Improve setup instructions in contributing guide by @cgarciae in https://github.com/google/flax/pull/2155
    • Forward unroll argument in lifted scan by @jheek in https://github.com/google/flax/pull/2158
    • Improve tabulate by @cgarciae in https://github.com/google/flax/pull/2162
    • Remove unused variable from nlp_seq example by @marcvanzee in https://github.com/google/flax/pull/2163
    • Allow nn.cond, nn.while to act on bound methods. by @levskaya in https://github.com/google/flax/pull/2172
    • 0.5.1 by @cgarciae in https://github.com/google/flax/pull/2180
    • Update normalization.py by @yechengxi in https://github.com/google/flax/pull/2182

    New Contributors

    • @mattiasmar made their first contribution in https://github.com/google/flax/pull/1801
    • @PhilipVinc made their first contribution in https://github.com/google/flax/pull/2136
    • @IvyZX made their first contribution in https://github.com/google/flax/pull/2151
    • @saiteja13427 made their first contribution in https://github.com/google/flax/pull/2140
    • @yechengxi made their first contribution in https://github.com/google/flax/pull/2182

    Full Changelog: https://github.com/google/flax/compare/v0.5.0...v0.5.1

    Source code(tar.gz)
    Source code(zip)
  • v0.5.0(May 23, 2022)

    New features:

    • Added flax.jax_utils.ad_shard_unpad() by @lucasb-eyer
    • Implemented default dtype FLIP. This means the default dtype is now inferred from inputs and params rather than being hard-coded to float32. This is especially useful for dealing with complex numbers because the standard Modules will no longer truncate complex numbers to their real component by default. Instead the complex dtype is preserved by default.
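
    For illustration, a minimal sketch of the inferred-dtype behavior described above (shapes and values are arbitrary):

    import jax
    import jax.numpy as jnp
    import flax.linen as nn

    x = jnp.ones((1, 4), dtype=jnp.complex64)
    layer = nn.Dense(3)  # no explicit dtype: the output dtype is inferred from inputs and params
    variables = layer.init(jax.random.PRNGKey(0), x)
    y = layer.apply(variables, x)
    # y.dtype is complex64; the imaginary part is no longer silently dropped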

    Bug fixes:

    • Fix support for JAX's experimental_name_stack.

    Breaking changes:

    • In rare cases the dtype of a layer can change due to default dtype FLIP. See the "Backward compatibility" section of the proposal for more information.
    Source code(tar.gz)
    Source code(zip)
  • v0.4.3(May 5, 2022)

  • v0.4.2(May 5, 2022)

    What's Changed

    • Canonicalize conv padding by @jheek in https://github.com/google/flax/pull/2009
    • Update ScopeParamNotFoundError message. by @melissatan in https://github.com/google/flax/pull/2013
    • Set field on dataclass transform decorator by @NeilGirdhar in https://github.com/google/flax/pull/1927
    • Don't recommend mixing setup and compact in docs. by @levskaya in https://github.com/google/flax/pull/2018
    • Clarifies optim.Adam(weight_decay) parameter. by @copybara-service in https://github.com/google/flax/pull/2016
    • Update linear regression example in Jax intro and Flax intro. by @melissatan in https://github.com/google/flax/pull/2015
    • Lifted cond by @jheek in https://github.com/google/flax/pull/2020
    • Use tree_map instead of deprecated tree_multimap by @jheek in https://github.com/google/flax/pull/2024
    • Remove tree_multimap from docs, examples, and tests by @jheek in https://github.com/google/flax/pull/2026
    • Fix bug where the linen Module state is reused. by @jheek in https://github.com/google/flax/pull/2025
    • Add getattribute with lazy setup trigger. by @levskaya in https://github.com/google/flax/pull/2028
    • Better error messages for loading checkpoints. by @copybara-service in https://github.com/google/flax/pull/2035
    • Add filterwarning for jax.tree_multimap by @marcvanzee in https://github.com/google/flax/pull/2038
    • Adds Flax logo to README by @marcvanzee in https://github.com/google/flax/pull/2036
    • Module lifecycle note by @jheek in https://github.com/google/flax/pull/1964
    • Fix linter errors in core/scope.py and core/tracers.py. by @copybara-service in https://github.com/google/flax/pull/2004
    • Handle edge-case of rate==1.0 in Dropout layer. by @levskaya in https://github.com/google/flax/pull/2055
    • Bug fixes and generalizations of nn.partitioning api. by @copybara-service in https://github.com/google/flax/pull/2062
    • Add support for JAX dynamic stack-based named_call. by @copybara-service in https://github.com/google/flax/pull/2063
    • Updates pooling docstrings by @marcvanzee in https://github.com/google/flax/pull/2064
    • Makes annotated_mnist use Optax's xent loss. by @andsteing in https://github.com/google/flax/pull/2071

    Full Changelog: https://github.com/google/flax/compare/v0.4.1...v0.4.2

  • v0.4.1(Mar 23, 2022)

    What's Changed

    • Added locally-connected (unshared CNN) layer flax.linen.ConvLocal.
    • Improved seq2seq example: factored out model and input pipeline code.
    • Added Optax update guide and deprecated flax.optim.
    • Added sep argument to flax.traverse_util.flatten_dict().
    • Implemented Sequential module, in flax.linen.combinators.
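
    A minimal sketch of two of the additions above (assuming Flax >= 0.4.1; layer sizes are arbitrary): the Sequential combinator from flax.linen.combinators and the sep argument of flax.traverse_util.flatten_dict.

      import flax.linen as nn
      from flax.linen.combinators import Sequential
      from flax.traverse_util import flatten_dict

      # Chain callables (Modules or plain functions) into one Module.
      mlp = Sequential([nn.Dense(8), nn.relu, nn.Dense(1)])

      # sep joins nested keys into path strings instead of key tuples.
      flat = flatten_dict({'layer1': {'kernel': 1, 'bias': 2}}, sep='/')
      # {'layer1/kernel': 1, 'layer1/bias': 2}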
  • v0.4.0(Jan 27, 2022)

    What's Changed

    • Add PReLU Activation by @isaaccorley in https://github.com/google/flax/pull/1570
    • Fix GroupNorm type hint for param num_groups. by @lkhphuc in https://github.com/google/flax/pull/1657
    • Add named_call overrides to docs by @jheek in https://github.com/google/flax/pull/1649
    • mission statement by @jheek in https://github.com/google/flax/pull/1668
    • Improves Flax Modules for RTD by @marcvanzee in https://github.com/google/flax/pull/1416
    • Add clarifying docstring for 'size' argument to prefetch_to_device's by @avital in https://github.com/google/flax/pull/1574
    • Add circular padding to flax.linen.Conv and flax.linen.ConvTranspose by @sgrigory in https://github.com/google/flax/pull/1661
    • Fix child scope rng reuse. by @jheek in https://github.com/google/flax/pull/1692
    • Numerically stable weight norm by @jheek in https://github.com/google/flax/pull/1693
    • Remove cyclic refs from scope by @jheek in https://github.com/google/flax/pull/1696
    • Add unroll to jax_utils.scan_in_dim by @ptigwe in https://github.com/google/flax/pull/1691
    • Removes rng arguments from Dropout's __call__. by @copybara-service in https://github.com/google/flax/pull/1689
    • Add error for empty scopes. by @jheek in https://github.com/google/flax/pull/1698
    • correct axis resolution in case of repeated axis in the logical axis r… by @ultrons in https://github.com/google/flax/pull/1703
    • Fix lost mutation bug in transforms on nested scopes. by @levskaya in https://github.com/google/flax/pull/1716
    • Expose put_variable function to Module. by @levskaya in https://github.com/google/flax/pull/1710
    • add eq and hash for scopes by @jheek in https://github.com/google/flax/pull/1720
    • Fixes a bug in DenseGeneral. by @copybara-service in https://github.com/google/flax/pull/1722
    • Add param_dtype argument to linen Modules by @jheek in https://github.com/google/flax/pull/1739
    • Implement custom vjp by @jheek in https://github.com/google/flax/pull/1738
    • Handle setup with transformed methods taking submodules of self. by @levskaya in https://github.com/google/flax/pull/1745
    • validate RNG key shape against jax's default by @copybara-service in https://github.com/google/flax/pull/1780
    • Adds optax update guide. by @andsteing in https://github.com/google/flax/pull/1774
    • Implement LazyRNG by @jheek in https://github.com/google/flax/pull/1723
    • make params_with_axes() work when params_axes is not mutable by @copybara-service in https://github.com/google/flax/pull/1811
    • Updates the ensembling HOWTO to Optax. by @andsteing in https://github.com/google/flax/pull/1806
    • Adds prominent scenic link to examples/README.md by @copybara-service in https://github.com/google/flax/pull/1809
    • Removes PixelCNN++ example by @copybara-service in https://github.com/google/flax/pull/1819
    • Add support for non-float32 normalization for linen normalization layers by @jheek in https://github.com/google/flax/pull/1804
    • Make Filter a Collection instead of a Container by @NeilGirdhar in https://github.com/google/flax/pull/1815
    • Removes deprecated API from RTD by @marcvanzee in https://github.com/google/flax/pull/1824

    New Contributors

    • @isaaccorley made their first contribution in https://github.com/google/flax/pull/1570
    • @lkhphuc made their first contribution in https://github.com/google/flax/pull/1657
    • @sgrigory made their first contribution in https://github.com/google/flax/pull/1661
    • @ptigwe made their first contribution in https://github.com/google/flax/pull/1691
    • @ultrons made their first contribution in https://github.com/google/flax/pull/1703
    • @dependabot made their first contribution in https://github.com/google/flax/pull/1749
    • @NeilGirdhar made their first contribution in https://github.com/google/flax/pull/1699
    • @saeta made their first contribution in https://github.com/google/flax/pull/1784
    • @melissatan made their first contribution in https://github.com/google/flax/pull/1793

    Full Changelog: https://github.com/google/flax/compare/v0.3.6...v0.4.0

  • v0.3.6(Oct 27, 2021)

    Breaking changes:

    • Move flax.nn to flax.deprecated.nn.

    New features:

    • Add experimental checkpoint policy argument. See flax.linen.checkpoint and the sketch after this list.
    • Add lifted versions of jvp and vjp.
    • Add lifted transformation for mapping variables. See flax.linen.map_variables.
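
    A minimal sketch of the checkpoint (remat) policy argument referenced above (assuming Flax >= 0.3.6; the MLP module and the chosen policy are illustrative): the lifted transform recomputes activations in the backward pass, and the policy controls which intermediates are saved.

      import jax
      import flax.linen as nn

      class MLP(nn.Module):
          @nn.compact
          def __call__(self, x):
              x = nn.relu(nn.Dense(128)(x))
              return nn.Dense(1)(x)

      # Keep only the results of dot products; recompute everything else.
      RematMLP = nn.checkpoint(MLP, policy=jax.checkpoint_policies.checkpoint_dots)
      model = RematMLP()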
  • v0.3.5(Sep 21, 2021)

    Breaking changes:

    • You can no longer pass an int as the kernel_size of flax.linen.Conv; instead, a TypeError is raised stating that a tuple/list should be provided (see the sketch below). The stride and dilation arguments still accept a single int, which is broadcast, because this is unambiguous once the kernel rank is known.
    • flax.linen.enable_named_call and flax.linen.disable_named_call now take effect everywhere instead of only affecting Modules constructed after the enable/disable call. Additionally, flax.linen.override_named_call provides a context manager to locally enable/disable named_call.
    • NamedTuples are no longer converted to tuples on assignment to a linen.Module.

    New features:
    • Flax internal stack frames are now removed from exception state traces.
    • Added flax.linen.nowrap to decorate methods that should not be transformed because they are stateful.
    • Flax no longer uses implicit rank broadcasting. Thus, you can now use Flax with --jax_numpy_rank_promotion=raise.
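
    A minimal sketch of two of the changes above (assuming this release; the layer sizes are arbitrary): kernel_size must be given as a tuple/list, a single int is still broadcast for strides, and strict rank promotion can now be enabled in JAX.

      import jax
      import flax.linen as nn

      # Strict rank promotion is now safe to enable alongside Flax.
      jax.config.update('jax_numpy_rank_promotion', 'raise')

      conv = nn.Conv(features=32, kernel_size=(3, 3), strides=2)  # int stride is broadcast
      # nn.Conv(features=32, kernel_size=3)  # under this release: TypeError, pass a tuple/list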

    Bugfixes:

    • linen Modules and dataclasses made with flax.struct.dataclass or flax.struct.PyTreeNode are now correctly recognized as dataclasses by static analysis tools like PyLance. Autocomplete of constructors has been verified to work with VSCode.
    • Fixed a bug in FrozenDict which didn't allow copying dicts with reserved names.
    • Fix the serialization of named tuples. Tuple fields are no longer stored in the state dict and the named tuple class is no longer recreated (bug).
    • Mixed precision training with float16 now works correctly with the attention layers.
    • Auto-generated linen Module __hash__, __eq__, and __repr__ no longer fail by default on non-init attributes.
  • v0.3.4(May 18, 2021)

    Possibly breaking changes:

    • When calling init the 'intermediates' collection is no longer mutable. Therefore, intermediates will no longer be returned from initialization by default.
    • Don't update batch statistics during initialization.
    • When not using any non-determinism (e.g., dropout), it is no longer necessary to specify the deterministic argument in MultiHeadDotProductAttention.

    Other changes:

    • Rewrote various examples to use Optax instead of Flax optimizers (e.g., Imagenet, SST2).
    • Added an NLP text classification example (on the SST-2 dataset) to examples/sst2 that uses a bidirectional LSTM (BiLSTM) to encode the input text.
    • Added flax.training.train_state to simplify using Optax optimizers (see the sketch after this list).
    • The mutable argument is now available on Module.init and Module.init_with_output.
    • Bug fix: Correctly handle non-default parameters of Linen Modules with nested inheritance.
    • Expose dot_product_attention_weights, allowing access to attention weights.
    • BatchNorm instances will behave correctly during init when called multiple times.
    • Added a more extensive "how to contribute" guide in contributing.md.
    • Add proper cache behavior for lift.jit, fixing cache misses.
    • Fix bug in Embed layer: make sure it behaves correctly when embedding is np.array.
    • Fix linen.Module for deep inheritance chains.
    • Fix bug in DenseGeneral: correctly expand bias to account for batch & noncontracting dimensions.
    • Allow Flax lifted transforms to work on partially applied Modules.
    • Make MultiOptimizer use apply_gradient instead of apply_param_gradient.
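
    A minimal sketch of flax.training.train_state mentioned above (assuming Flax >= 0.3.4 and Optax installed; the model, learning rate, and placeholder gradients are illustrative): TrainState bundles the apply function, parameters, and optimizer state, and apply_gradients performs one Optax update.

      import jax
      import jax.numpy as jnp
      import optax
      import flax.linen as nn
      from flax.training import train_state

      model = nn.Dense(features=1)
      params = model.init(jax.random.PRNGKey(0), jnp.ones((1, 4)))['params']

      state = train_state.TrainState.create(
          apply_fn=model.apply,
          params=params,
          tx=optax.adam(1e-3),
      )

      grads = jax.tree_util.tree_map(jnp.zeros_like, state.params)  # placeholder gradients
      state = state.apply_gradients(grads=grads)                    # new params + optimizer state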
  • v0.3.3(Mar 31, 2021)

    Possible breaking changes:

    • Bug Fix: Disallow modifying attributes in Modules after they are initialized.
    • Raise an error when saving a checkpoint which has a smaller step than the latest checkpoint already saved.
    • MultiOptimizer now rejects the case where multiple sub optimizers update the same parameter.

    Other changes:

    • Added custom error classes to many Linen errors. See: https://flax.readthedocs.io/en/latest/flax.errors.html
    • Adds Module.bind for binding variables and RNGs to an interactive Module (see the sketch after this list).
    • Adds nn.apply and nn.init for transforming arbitrary functions that take a linen.Module as their first argument.
    • Add option to overwrite existing checkpoints in save_checkpoint.
    • Remove JAX omnistaging check for forward compatibility.
    • Pathlib compatibility for checkpoint paths.
    • Added an is_leaf argument to traverse_util.flatten_dict.
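
    A minimal sketch of Module.bind mentioned above (assuming Flax >= 0.3.3; the layer and shapes are arbitrary): bind returns an interactive copy of the Module with variables attached, so it can be called directly without apply.

      import jax
      import jax.numpy as jnp
      import flax.linen as nn

      model = nn.Dense(features=2)
      x = jnp.ones((1, 3))
      variables = model.init(jax.random.PRNGKey(0), x)

      bound = model.bind(variables)  # interactive: no Module.apply needed
      y = bound(x)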
  • v0.3.2(Mar 1, 2021)

  • 0.3.1(Feb 26, 2021)

    Many improvements to Linen, and the old flax.nn is officially deprecated!

    Notably, there's a clean API for extracting intermediates from modules defined using @nn.compact, a more ergonomic API for using Batch Norm and Dropout in modules defined using setup, support for MultiOptimizer with Linen, and multiple safety, performance and error message improvements.
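
    A minimal sketch of the intermediates API mentioned above (assuming Flax >= 0.3.1; the model is illustrative): self.sow stores values in the 'intermediates' collection, which apply returns when that collection is marked mutable.

      import jax
      import jax.numpy as jnp
      import flax.linen as nn

      class Model(nn.Module):
          @nn.compact
          def __call__(self, x):
              h = nn.Dense(4)(x)
              self.sow('intermediates', 'hidden', h)  # record an intermediate activation
              return nn.Dense(1)(h)

      model = Model()
      x = jnp.ones((1, 3))
      variables = model.init(jax.random.PRNGKey(0), x)
      y, state = model.apply(variables, x, mutable=['intermediates'])
      print(state['intermediates']['hidden'])  # tuple of sown values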

    See the CHANGELOG for more details.

  • 0.3.0(Dec 8, 2020)

  • v0.2.2(Oct 1, 2020)

  • v0.1.0rc2(Mar 18, 2020)
