Flaxformer: transformer architectures in JAX/Flax

Overview

Flaxformer is a transformer library, primarily for NLP and multimodal research at Google. It is used for many NLP research use cases, providing both off-the-shelf BERT and T5 models, as well as several research projects built on shared components.

General library goals

The Flaxformer library aims to provide transformer models that are:

  • High performance: Models are annotated for use with the pjit API, enabling them to be used for training the largest models.
  • Reusable: Components have self-contained configuration, and high-level modules like encoders, decoders, etc. don't make too many assumptions about what their sub-modules look like.
  • Tested: We aim to employ a reasonable amount of unit testing, and write tests whenever bugs are encountered. However, no guarantees are provided.
  • Maintainable: We have created a versioning strategy for our modules so that code refactors which alter the module structure can take place. This is tricky in Flax, because Flax generates a tree of parameters based on the exact module structure (see the sketch after this list). Our approach lets us maintain compatibility with previously trained model checkpoints.
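
To make the last point concrete, below is a minimal sketch in plain Flax (not Flaxformer code; the module and layer names are invented for illustration) of how Flax derives checkpoint parameter paths from the module tree:

import jax
import jax.numpy as jnp
from flax import linen as nn

class Block(nn.Module):
  # Toy sub-module standing in for an attention or MLP component.
  @nn.compact
  def __call__(self, x):
    return nn.Dense(features=8, name='mlp_out')(x)

class Model(nn.Module):
  # Toy top-level module standing in for an encoder.
  @nn.compact
  def __call__(self, x):
    return Block(name='encoder_block')(x)

params = Model().init(jax.random.PRNGKey(0), jnp.ones((1, 4)))
# Parameter paths mirror the module structure, e.g.
# params/encoder_block/mlp_out/kernel. Renaming 'encoder_block' or wrapping it
# in another module changes these paths, which is why refactors need a
# versioning strategy to keep previously trained checkpoints loadable.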

Code locations

Modeling components such as dense attention, layer norms, and MLP blocks can be found in the components/ directory.

Higher-level classes which combine these components can be found in the architectures/ directory. The current architecture file for the T5 family of models is architectures/t5/t5_architecture.py; this is a mid-level API requiring sub-components to be configured. A high-level starting point, exposing fewer parameters, is architectures/t5/t5_1_1.py.
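
For orientation, these directories map directly onto Python import paths. In the sketch below, the imports follow the layout described above, but the MlpBlock constructor arguments are assumptions modeled on typical T5 1.1 configurations rather than a verified signature; the unit tests are the authoritative reference.

# Illustrative imports; paths follow the directory layout described above.
from flaxformer.components import dense
from flaxformer.architectures.t5 import t5_architecture  # mid-level API
from flaxformer.architectures.t5 import t5_1_1  # high-level starting point

# Hypothetical component construction; argument names are assumptions.
mlp_block = dense.MlpBlock(
    use_bias=False,
    intermediate_dim=2048,
    activations=('gelu', 'linear'),
)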

Relationship to other codebases

Flaxformer is primarily used by other research projects, in particular T5X. We hope to release examples demonstrating the integration of these codebases soon.

If you would like to use Flaxformer independently of T5X, please see the unit tests for examples instantiating the models. In the medium-term future, we hope to provide more stand-alone examples of Flaxformer use.
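
In the meantime, instantiation follows the standard Flax pattern: construct a module, call init to create the parameter tree, then apply for a forward pass. The sketch below uses a plain Flax module as a stand-in for a configured Flaxformer encoder, since the exact constructor arguments vary by model and are best copied from the unit tests.

import jax
import jax.numpy as jnp
from flax import linen as nn

# Stand-in module; in practice this would be an encoder/decoder assembled from
# the architectures/ classes, configured as in the unit tests.
model = nn.Dense(features=16)

variables = model.init(jax.random.PRNGKey(0), jnp.ones((2, 8)))  # parameter tree
outputs = model.apply(variables, jnp.ones((2, 8)))  # forward pass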

Contributions

Unfortunately, we cannot accept contributions to the Flaxformer repo at this time, so any pull requests will be automatically closed - but please file issues as needed!

Installing dependencies and running tests

After checking out this repository, in its root directory, you can install it along with test dependencies by running,

pip3 install '.[testing]'

If you like, you can run the tests with pytest using the following invocation,

python3 -m pytest
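
pytest also accepts a path if you only want to run a subset of the tests, for example (assuming the standard layout, with components under flaxformer/components):

python3 -m pytest flaxformer/components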

Uninstalling

If you need to uninstall Flaxformer, please run,

pip3 uninstall flaxformer

Troubleshooting

Flax deps

Flaxformer is developed in close collaboration with the Flax team. There may be bugs if your Flax version is not up to date. To install the latest version from GitHub, please run,

pip3 uninstall flax
pip3 install git+https://github.com/google/flax
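
You can then confirm which version is installed with,

python3 -c "import flax; print(flax.__version__)"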

Note

Flaxformer is a project maintained by a team in Google Research. It is not an official Google product.

Comments
  • How to run a simple inference on Switch base

    Hi there!

    First of all, awesome work on Switch Transformers 🔥 I was wondering if there is a simple example script or commands to run inference using the switch_base model? Thanks!

    opened by younesbelkada 1
  • ModuleNotFoundError: No module named 'flaxformer.architectures.longt5'

    Installing the flaxformer repo using pip results in the above error when importing flaxformer.architectures.longt5. I believe the issue is caused by the missing __init__.py file in the respective directory.

    opened by nradwan 1
  • Failed to map logical axes for target/decoder/logits...

    I am getting the following error when fine-tuning a LongT5 model:

    ValueError                                Traceback (most recent call last)
    Input In [16], in <cell line: 21>()
    ---> 21 gin_utils.run(main_train)

    File ~/Downloads/t5x/t5x/gin_utils.py:105, in run(main)
    --> 105 app.run(
        106     main,
        107     flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a)))

    File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:312, in run(main, argv, flags_parser)
    --> 312 _run_main(main, args)

    File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:258, in _run_main(main, argv)
    --> 258 sys.exit(main(argv))

    Input In [15], in main_train(argv)
    ----> 3 _main(argv)

    Input In [16], in _main(argv)
    ---> 19 train_using_gin()

    File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1605, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
    -> 1605 utils.augment_exception_message_and_reraise(e, err_str)

    File ~/opt/miniconda3/lib/python3.9/site-packages/gin/utils.py:41, in augment_exception_message_and_reraise(exception, message)
    ---> 41 raise proxy.with_traceback(exception.__traceback__) from None

    File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1582, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
    -> 1582 return fn(*new_args, **new_kwargs)

    Input In [7], in train(model, train_dataset_cfg, train_eval_dataset_cfg, infer_eval_dataset_cfg, checkpoint_cfg, partitioner, trainer_cls, model_dir, total_steps, eval_steps, eval_period, stats_period, random_seed, use_hardware_rng, summarize_config_fn, inference_evaluator_cls, get_dataset_fn, concurrent_metrics, actions, train_eval_get_dataset_fn, run_eval_before_training, use_gda)
    --> 228 train_state_initializer = utils.TrainStateInitializer(
        229     optimizer_def=model.optimizer_def,
        230     init_fn=model.get_initial_variables,
        231     input_shapes=input_shapes,
        232     input_types=input_types,
        233     partitioner=partitioner)

    File ~/Downloads/t5x/t5x/utils.py:368, in TrainStateInitializer.__init__(self, optimizer_def, init_fn, input_shapes, partitioner, input_types)
    --> 368 self.train_state_axes = partitioner.get_mesh_axes(
        369     self.global_train_state_shape)

    File ~/Downloads/t5x/t5x/partitioning.py:892, in PjitPartitioner.get_mesh_axes(self, train_state)
    --> 892 flat_mesh_axes = {
        893     k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
        894 }

    File ~/Downloads/t5x/t5x/partitioning.py:893, in <dictcomp>(.0)
    --> 893 k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()

    File ~/Downloads/t5x/t5x/partitioning.py:888, in PjitPartitioner.get_mesh_axes.<locals>._logical_to_mesh_axes(param_name, logical_axes)
    --> 888 raise ValueError(f'Failed to map logical axes for {param_name}') from e

    ValueError: Failed to map logical axes for target/decoder/logits_dense/kernel
    In call to configurable 'train' (<function train at 0x2b751e160>)

    opened by ibulu 0
  • BERT Pre-Training

    Hi,

    I would like to use the Flaxformer library to pre-train a BERT model from scratch.

    What is necessary to create the pre-training data (MLM with a duplication factor) from my own corpus, using a custom WordPiece-based vocab?

    How can the pre-training be started?

    I'm really excited to test it; any help is highly appreciated!

    opened by stefan-it 0