Flaxformer: transformer architectures in JAX/Flax

Overview

Flaxformer is a transformer library, primarily for NLP and multimodal research at Google. It is used for many NLP research use cases, providing both off-the-shelf BERT and T5 models, as well as several research projects built on shared components.

General library goals

The Flaxformer library aims to provide transformer models that are:

  • High performance: Models are annotated for use with the PJIT API, enabling training of the largest models (a sketch of what such annotations typically look like follows this list).
  • Reusable: Components have self-contained configuration, and high-level modules like encoders, decoders, etc. don't make too many assumptions about what their sub-modules look like.
  • Tested: We aim to employ a reasonable amount of unit testing and to write tests whenever bugs are encountered. However, no guarantees are provided.
  • Maintainable: We have created a versioning strategy for our modules so that code refactors which alter the module structure can take place. This is tricky in Flax, because Flax generates a tree of parameters based on the exact module structure. Our approach lets us maintain compatibility with previously trained model checkpoints.
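
The snippet below is a minimal sketch, not Flaxformer's actual code, of what such partitioning annotations typically look like in Flax: parameters are created with logical axis names, which partitioning rules later map onto a device mesh for PJIT. The DenseWithAxes module and the axis names 'embed'/'mlp' are illustrative assumptions.

import jax
import jax.numpy as jnp
from flax import linen as nn
from flax.linen import partitioning as nn_partitioning


class DenseWithAxes(nn.Module):
  """Toy dense layer whose kernel carries logical axis annotations."""
  features: int

  @nn.compact
  def __call__(self, x):
    # The logical names ('embed', 'mlp') are placeholders; partitioning
    # rules map them to physical mesh axes at training time.
    kernel = nn_partitioning.param_with_axes(
        'kernel',
        nn.initializers.lecun_normal(),
        (x.shape[-1], self.features),
        jnp.float32,
        axes=('embed', 'mlp'))
    return jnp.dot(x, kernel)


# Initializing the module records the axis names in a separate
# 'params_axes' collection alongside the usual 'params'.
variables = DenseWithAxes(features=16).init(
    jax.random.PRNGKey(0), jnp.ones((2, 8)))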

Code locations

Modeling components such as dense attention, layer norms, and MLP blocks can be found in the components/ directory.

Higher-level classes which combine these components can be found in the architectures/ directory. The current architecture file for the T5 family of models is architectures/t5/t5_architecture.py; this is a mid-level API requiring sub-components to be configured. A high-level starting point, exposing fewer parameters, is architectures/t5/t5_1_1.py.

Relationship to other codebases

Flaxformer is primarily used by other research projects, in particular T5X. We hope to release examples demonstrating the integration of these codebases soon.

If you would like to use Flaxformer independently of T5X, please see the unit tests for examples of instantiating the models; a generic sketch of that pattern follows below. In the medium-term future, we hope to provide more stand-alone examples of Flaxformer use.
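
As a placeholder until those examples land, here is a minimal, self-contained sketch of the generic Flax pattern the unit tests follow: instantiate a module, call init with dummy inputs to build the parameter tree, then call apply. The ToyEncoderBlock below is illustrative only and is not a Flaxformer class.

import jax
import jax.numpy as jnp
from flax import linen as nn


class ToyEncoderBlock(nn.Module):
  """Illustrative pre-norm attention + MLP block; stands in for a real model."""
  num_heads: int = 2
  mlp_dim: int = 64

  @nn.compact
  def __call__(self, x):
    # Self-attention sub-block with a residual connection.
    y = nn.SelfAttention(num_heads=self.num_heads)(nn.LayerNorm()(x))
    x = x + y
    # MLP sub-block with a residual connection.
    y = nn.Dense(self.mlp_dim)(nn.LayerNorm()(x))
    y = nn.Dense(x.shape[-1])(nn.relu(y))
    return x + y


model = ToyEncoderBlock()
dummy_inputs = jnp.ones((1, 8, 32))  # [batch, length, embedding]
variables = model.init(jax.random.PRNGKey(0), dummy_inputs)
outputs = model.apply(variables, dummy_inputs)
assert outputs.shape == dummy_inputs.shape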

Contributions

Unfortunately, we cannot accept contributions to the Flaxformer repo at this time, so any pull requests will be automatically closed - but please file issues as needed!

Installing dependencies and running tests

After checking out this repository, in its root directory, you can install it along with test dependencies by running,

pip3 install '.[testing]'

If you like, you can run the tests with pytest using the following invocation,

python3 -m pytest
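
This runs the whole suite. pytest also accepts a path if you only want to exercise part of the library, for example (assuming the package layout described above, with the modeling components under flaxformer/components/),

python3 -m pytest flaxformer/components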

Uninstalling

If you need to uninstall Flaxformer, please run,

pip3 uninstall flaxformer

Troubleshooting

Flax deps

Flaxformer is developed in close collaboration with the Flax team. There may be bugs if your Flax version is not up to date. To install the latest version from GitHub, please run,

pip3 uninstall flax
pip3 install git+https://github.com/google/flax
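
To confirm which Flax version actually ended up installed in your environment, you can check it with, for example,

python3 -c 'import flax; print(flax.__version__)'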

Note

Flaxformer is a project maintained by a team in Google Research. It is not an official Google product.

Comments
  • How to run a simple inference on Switch base

    Hi there!

    First of all, awesome work on Switch Transformers 🔥 I was wondering if there is a simple example script / set of commands to do a simple inference using the switch_base model? Thanks!

    opened by younesbelkada 1
  • ModuleNotFoundError: No module named 'flaxformer.architectures.longt5'

    Installing the flaxformer repo using pip results in the above error when importing flaxformer.architectures.longt5. I believe the issue is caused by the missing __init__.py file in the respective directory.

    opened by nradwan 1
  • Failed to map logical axes for target/decoder/logits...

    I am getting the following error when fine-tuning a LongT5 model:

    ValueError                                Traceback (most recent call last)
    Input In [16], in <cell line: 21>()
        ---> 21 gin_utils.run(main_train)

    File ~/Downloads/t5x/t5x/gin_utils.py:105, in run(main)
        --> 105 app.run(main, flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a)))

    File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:312, in run(main, argv, flags_parser)
        --> 312 _run_main(main, args)

    File ~/opt/miniconda3/lib/python3.9/site-packages/absl/app.py:258, in _run_main(main, argv)
        --> 258 sys.exit(main(argv))

    Input In [15], in main_train(argv)
        ----> 3 _main(argv)

    Input In [16], in _main(argv)
        ---> 19 train_using_gin()

    File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1605, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
        -> 1605 utils.augment_exception_message_and_reraise(e, err_str)

    File ~/opt/miniconda3/lib/python3.9/site-packages/gin/utils.py:41, in augment_exception_message_and_reraise(exception, message)
        ---> 41 raise proxy.with_traceback(exception.__traceback__) from None

    File ~/opt/miniconda3/lib/python3.9/site-packages/gin/config.py:1582, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
        -> 1582 return fn(*new_args, **new_kwargs)

    Input In [7], in train(model, train_dataset_cfg, train_eval_dataset_cfg, infer_eval_dataset_cfg, checkpoint_cfg, partitioner, trainer_cls, model_dir, ...)
        --> 228 train_state_initializer = utils.TrainStateInitializer(
                    optimizer_def=model.optimizer_def,
                    init_fn=model.get_initial_variables,
                    input_shapes=input_shapes,
                    input_types=input_types,
                    partitioner=partitioner)

    File ~/Downloads/t5x/t5x/utils.py:368, in TrainStateInitializer.__init__(self, optimizer_def, init_fn, input_shapes, partitioner, input_types)
        --> 368 self.train_state_axes = partitioner.get_mesh_axes(
                    self.global_train_state_shape)

    File ~/Downloads/t5x/t5x/partitioning.py:892, in PjitPartitioner.get_mesh_axes(self, train_state)
        --> 892 flat_mesh_axes = {
                    k: _logical_to_mesh_axes(k, v) for k, v in flat_logical_axes.items()
                }

    File ~/Downloads/t5x/t5x/partitioning.py:888, in PjitPartitioner.get_mesh_axes.<locals>._logical_to_mesh_axes(param_name, logical_axes)
        --> 888 raise ValueError(f'Failed to map logical axes for {param_name}') from e

    ValueError: Failed to map logical axes for target/decoder/logits_dense/kernel
    In call to configurable 'train' (<function train at 0x2b751e160>)

    opened by ibulu 0
  • BERT Pre-Training

    Hi,

    I would like to use the Flaxformer library to pre-train a BERT model from scratch.

    What is necessary to create the pre-training data (MLM with a duplication factor) from my own corpus, using my own WordPiece-based vocabulary?

    How can the pre-training be started?

    I'm really excited to test it, any help is highly appreciated!

    opened by stefan-it 0