Model parallel transformers in JAX and Haiku

Overview

Table of contents

  1. Mesh Transformer JAX
    1. Updates
  2. Pretrained Models
    1. GPT-J-6B
      1. Links
      2. Acknowledgments
      3. License
      4. Model Details
      5. Zero-Shot Evaluations
  3. Architecture and Usage
    1. Fine-tuning
    2. JAX Dependency
  4. Citation
  5. TODO

Mesh Transformer JAX

A Haiku library using the xmap/pjit operators in JAX for model parallelism of transformers.

The parallelism scheme is similar to the original Megatron-LM, which is efficient on TPUs thanks to their high-speed 2D mesh network. There is also an experimental model version which implements ZeRO-style sharding.
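
To make the sharding scheme concrete, here is a minimal sketch of Megatron-style tensor parallelism expressed with xmap. It is illustrative only, written against the jax==0.2.12-era experimental API, and is not code from this repository (the real layers live in mesh_transformer/layers.py): each device holds a 1/mp slice of a feedforward block's weights, and partial outputs are combined with a psum over the model-parallel axis.

```python
# Minimal sketch of Megatron-style tensor parallelism with the old experimental xmap API
# (not code from this repository). Each shard holds a 1/mp slice of the FF weights.
import jax
import jax.numpy as jnp
import numpy as np
from jax.experimental import maps

mp = jax.device_count()  # e.g. 8 on a TPU v3-8

def ff_block(x, w_in, w_out):
    h = jax.nn.gelu(x @ w_in)     # column-parallel: each shard computes d_ff/mp hidden units
    y = h @ w_out                 # row-parallel: each shard produces a partial output
    return jax.lax.psum(y, "mp")  # all-reduce the partial outputs across shards

d_model, d_ff = 4096, 16384
x = jnp.ones((1, d_model))
w_in = jnp.full((mp, d_model, d_ff // mp), 0.01)   # sharded along the leading "mp" axis
w_out = jnp.full((mp, d_ff // mp, d_model), 0.01)

parallel_ff = maps.xmap(
    ff_block,
    in_axes=([...], ["mp", ...], ["mp", ...]),  # x replicated, weights sharded over "mp"
    out_axes=[...],
    axis_resources={"mp": "mp"},
)

with maps.mesh(np.array(jax.devices()), ("mp",)):
    print(parallel_ff(x, w_in, w_out).shape)  # (1, 4096)
```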

This library is designed for scalability up to approximately 40B parameters on TPUv3s, beyond which different parallelism strategies should be used. See other implementations such as GPT-NeoX or DeepSpeed for that.

One future direction for research is integrating this codebase with swarm-jax, to achieve further scalability with pipeline parallelism.

Updates

12-07-21: Added a guide to fine-tuning.

Pretrained Models

GPT-J-6B

A 6 billion parameter, autoregressive text generation model trained on The Pile.

Links

Slim weights (bf16 weights only, for inference, 9GB)

Full weights (including optimizer params, 61GB)

Colab demo

Web demo

Aran's blog post

Acknowledgments

This project would not have been possible without compute generously provided by the TPU Research Cloud with assistance from EleutherAI.

Thanks to the Cloud TPU team at Google for providing early access to the Cloud TPU VM alpha (now publicly available!)

Thanks to everyone who has helped out in one way or another (listed alphabetically):

  • Aran Komatsuzaki for advice with experiment design and writing the blog posts.
  • James Bradbury for valuable assistance with debugging JAX issues.
  • Janko Prester for creating the web demo frontend.
  • Laurence Golding for adding some features to the web demo.
  • Leo Gao for running zero shot evaluations for the baseline models for the table.

License

The weights of GPT-J-6B are licensed under version 2.0 of the Apache License.

Model Details

Hyperparameter Value
n_parameters 6,053,381,344
n_layers 28*
d_model 4,096
d_ff 16,384
n_heads 16
d_head 256
n_ctx 2,048
n_vocab 50,257 (same tokenizer as GPT-2/3)
position encoding Rotary position encodings (RoPE)
RoPE dimensions 64

* Each layer consists of one feedforward block and one self-attention block.

The model consists of 28 layers with a model dimension of 4096 and a feedforward dimension of 16384. The model dimension is split into 16 heads, each with a dimension of 256. Rotary position embeddings (RoPE) are applied to 64 dimensions of each head. The model is trained with a tokenization vocabulary of 50,257, using the same set of BPEs as GPT-2/GPT-3.
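
As a sanity check, the headline parameter count can be roughly reproduced from these hyperparameters. The back-of-the-envelope estimate below ignores biases, LayerNorm parameters, and embedding padding, which account for the small gap to the reported 6,053,381,344:

```python
# Rough parameter count for GPT-J-6B from the hyperparameters above
# (ignores biases, LayerNorms, and embedding padding, so it slightly undershoots).
n_layers, d_model, d_ff, n_vocab = 28, 4096, 16384, 50257

attn = 4 * d_model * d_model      # q, k, v and output projections per layer
ff = 2 * d_model * d_ff           # feedforward in + out projections per layer
embed = n_vocab * d_model         # token embedding
lm_head = d_model * n_vocab       # projection back to the vocabulary

total = n_layers * (attn + ff) + embed + lm_head
print(f"{total:,}")               # 6,048,849,920, i.e. ~6.05B
```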

Zero-Shot Evaluations

Models roughly sorted by performance, or by FLOPs if not available.

Model Training FLOPs LAMBADA PPL ↓ LAMBADA Acc ↑ Winogrande ↑ Hellaswag ↑ PIQA ↑ Dataset Size (GB)
Chance 0 ~a lot ~0% 50% 25% 25% 0
GPT-3-Ada‡ ----- 9.95 51.6% 52.9% 43.4% 70.5% -----
GPT-2-1.5B ----- 10.63 51.21% 59.4% 50.9% 70.8% 40
GPTNeo-1.3B‡ 3.0e21 7.50 57.2% 55.0% 48.9% 71.1% 825
Megatron-2.5B* 2.4e21 ----- 61.7% ----- ----- ----- 174
GPTNeo-2.7B‡ 6.8e21 5.63 62.2% 56.5% 55.8% 73.0% 825
GPT-3-1.3B*‡ 2.4e21 5.44 63.6% 58.7% 54.7% 75.1% ~800
GPT-3-Babbage‡ ----- 5.58 62.4% 59.0% 54.5% 75.5% -----
Megatron-8.3B* 7.8e21 ----- 66.5% ----- ----- ----- 174
GPT-3-2.7B*‡ 4.8e21 4.60 67.1% 62.3% 62.8% 75.6% ~800
Megatron-11B† 1.0e22 ----- ----- ----- ----- ----- 161
GPT-J-6B 1.5e22 3.99 69.7% 65.3% 66.1% 76.5% 825
GPT-3-6.7B*‡ 1.2e22 4.00 70.3% 64.5% 67.4% 78.0% ~800
GPT-3-Curie‡ ----- 4.00 69.3% 65.6% 68.5% 77.9% -----
GPT-3-13B*‡ 2.3e22 3.56 72.5% 67.9% 70.9% 78.5% ~800
GPT-3-175B*‡ 3.1e23 3.00 76.2% 70.2% 78.9% 81.0% ~800
GPT-3-Davinci‡ ----- 3.0 75% 72% 78% 80% -----
Gopher 230B* 6.31e23 ----- 74.5% 70.1% 79.2% 81.8% 1344
MT-NLG 530B*‡ ----- ----- 76.6% 73.0% 80.2% 82.0% -----

* Evaluation numbers reported by their respective authors; all other numbers were obtained by running lm-evaluation-harness with either the released weights or API access. Due to subtle implementation differences as well as different zero-shot task framing, these might not be directly comparable. See this blog post for more details.

† The Megatron-11B model provides no comparable metrics, and several implementations using the released weights do not reproduce the generation quality and evaluations (see 1, 2, 3), so evaluation was not attempted.

‡ These models have been trained with data which contains possible test-set contamination. The OpenAI GPT-3 models failed to deduplicate training data against certain test sets, while the GPT-Neo models as well as this one were trained on The Pile, which has not been deduplicated against any test sets.

Architecture and Usage

Most scripts in this repository are designed to be run on TPUs, which under the TPU VM architecture are virtual machines that can run arbitrary code. The scripts typically spin up a TPU, SSH into it to set up the dependencies and copy code over from the local directory, and then start a Ray worker which can accept RPC calls.

The TPU VMs handle running model training steps and evaluation, as well as checkpoint saving and loading, while the driver Python program handles data loading and general orchestration (such as when to save checkpoints).
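
As a rough illustration of that split (a toy sketch, not the repository's actual NetworkRunner actor in mesh_transformer/train_actor.py; the class and method names here are hypothetical), the driver only ever talks to the TPU host through Ray RPC:

```python
# Toy driver/worker split: the worker would live on the TPU VM and hold the model,
# while the driver feeds it batches and decides when to checkpoint.
import ray

@ray.remote
class ToyTrainWorker:  # hypothetical stand-in for the repo's Ray actor
    def __init__(self):
        self.step = 0

    def train(self, batch):
        # on the real TPU worker this would run one model-parallel training step
        self.step += 1
        return {"step": self.step, "loss": 0.0}

ray.init()
worker = ToyTrainWorker.remote()
for batch in [[1, 2, 3], [4, 5, 6]]:            # driver side: data loading
    print(ray.get(worker.train.remote(batch)))  # driver side: orchestration via RPC
```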

This means that most scripts (train.py, eval_harness.py, etc.) expect to be run on a GCE virtual machine in the same region as the TPUs, to minimize RPC latency and data transfer cost. Other scripts (usually ones which don't take a --tpu argument, such as device_sample.py, device_serve.py, or device_train.py) expect to be run directly on a TPU VM. The device_* scripts only work on a v3-8 and not on larger pods.

Furthermore, there is an example (resharding_example.py) of how to convert the provided checkpoints (which have 8 shards in the case of GPT-J-6B) down to a smaller number of shards, such as when running on GPU(s).
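
As a hypothetical illustration of what resharding a single sharded parameter involves (the real logic, which also handles replicated parameters and optimizer state, is in resharding_example.py; the helper name here is made up), pieces are joined along the model-parallel axis and re-split into the target number of shards:

```python
# Toy resharding of one parameter: 8 shards -> fewer shards.
import numpy as np

def reshard_param(pieces, new_shards, axis):
    merged = np.concatenate(pieces, axis=axis)      # 8 shards -> full parameter
    return np.split(merged, new_shards, axis=axis)  # full parameter -> new_shards pieces

pieces = [np.zeros((4096, 2048)) for _ in range(8)]          # toy (d_model, d_ff/8) slices
print([p.shape for p in reshard_param(pieces, 1, axis=1)])   # [(4096, 16384)]
```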

Fine-tuning

To fine-tune the model, run device_train.py on a TPU VM. Using a TPU v3-8, you can fine-tune at a rate of ~5000 tokens/second, which should be sufficient for small-to-medium-size datasets.

Please read the step-by-step guide for thorough fine-tuning instructions.
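
For a rough sense of what ~5000 tokens/second means in wall-clock terms (simple arithmetic, not a measured benchmark; the dataset size here is an arbitrary assumption):

```python
# Back-of-the-envelope fine-tuning time at ~5000 tokens/s on a TPU v3-8.
tokens_per_second = 5000
dataset_tokens = 25_000_000   # assumed: roughly a 100 MB text dataset at ~4 bytes/token

seconds_per_epoch = dataset_tokens / tokens_per_second
print(f"~{seconds_per_epoch / 3600:.1f} hours per epoch")  # ~1.4 hours
```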

JAX Dependency

Note that this library has specific JAX version requirements. To use the v1 models (including GPT-J-6B), jax==0.2.12 is required, which in turn depends on jaxlib==0.1.68. If these exact versions are not used, you will get cryptic xmap errors.
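
A quick way to catch a mismatched environment before hitting those errors (a convenience sketch, not part of this repository):

```python
# Fail fast if the installed versions don't match the v1 requirements stated above.
import jax
import jaxlib

assert jax.__version__ == "0.2.12", f"v1 models need jax==0.2.12, found {jax.__version__}"
assert jaxlib.__version__ == "0.1.68", f"v1 models need jaxlib==0.1.68, found {jaxlib.__version__}"
print("JAX environment looks compatible with the v1 (GPT-J-6B) code path.")
```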

However, to use the v2 model code (no publicly released weights), the newest JAX version can be used.

Citation

To cite this repository:

@misc{mesh-transformer-jax,
  author = {Wang, Ben},
  title = {{Mesh-Transformer-JAX: Model-Parallel Implementation of Transformer Language Model with JAX}},
  howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
  year = 2021,
  month = May
}

To cite the weights of GPT-J-6B:

@misc{gpt-j,
  author = {Wang, Ben and Komatsuzaki, Aran},
  title = {{GPT-J-6B: A 6 Billion Parameter Autoregressive Language Model}},
  howpublished = {\url{https://github.com/kingoflolz/mesh-transformer-jax}},
  year = 2021,
  month = May
}

If you use this repository or any of the pretrained weights to do something cool, we would love to hear about it. Feel free to open a GitHub issue or reach out over email (in profile).

TODO

  • disentangle heads and shards
  • test/benchmark on TPU
  • implement gradient checkpointing
  • fix initialization
  • mixed precision
  • deal with preemptible TPUs
  • test and validate generation
  • shard activations instead of replicating for memory efficiency (in v2)
  • support ZeRO style sharding (in v2)

Comments

  • AttributeError: module 'jax.random' has no attribute 'KeyArray' while fine tuning.

    I am following your repo to fine-tune GPT-J on a TPU. When I run "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/" with my bucket name and the config file I have created, I get the error "AttributeError: module 'jax.random' has no attribute 'KeyArray'". These are some of the specs:

    OS: Ubuntu 20.04
    JAX version: 0.2.12
    TPU: v3-8
    Zone: us-central1-b

    The error is caused by line 7 of device_train.py, where optax is imported ("import optax").

    This is the error stack:

    WARNING: Logging before InitGoogle() is written to STDERR
    I0420 11:47:44.856002 10240 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
    Traceback (most recent call last):
      File "device_train.py", line 7, in <module>
        import optax
      File "/usr/local/lib/python3.8/dist-packages/optax/__init__.py", line 17, in <module>
        from optax._src.alias import adabelief
      File "/usr/local/lib/python3.8/dist-packages/optax/_src/alias.py", line 21, in <module>
        from optax._src import base
      File "/usr/local/lib/python3.8/dist-packages/optax/_src/base.py", line 18, in <module>
        import chex
      File "/usr/local/lib/python3.8/dist-packages/chex/__init__.py", line 17, in <module>
        from chex._src.asserts import assert_axis_dimension
      File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts.py", line 26, in <module>
        from chex._src import asserts_internal as _ai
      File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts_internal.py", line 32, in <module>
        from chex._src import pytypes
      File "/usr/local/lib/python3.8/dist-packages/chex/_src/pytypes.py", line 36, in <module>
        PRNGKey = jax.random.KeyArray
    AttributeError: module 'jax.random' has no attribute 'KeyArray'

    Any help is appreciated!

    opened by samyakai 13
  • the-eye.eu down - alternative access to GPT-J-6B/step_383500_slim.tar.zstd ?

    Thank you for your hard work and dedication to the project. I've been trying to load the slim checkpoint from the-eye.eu. Apparently, the site is down. I've tried every means of access, but it is probably an outage on their end (https://downforeveryoneorjustme.com/the-eye.eu).

    They might resolve it shortly - or maybe not. Is there any alternative address to wget GPT-J-6B/step_383500_slim.tar.zstd ? Thank you.

    opened by PhilWicke 13
  • Return probabilities from CausalTransformer.generate()

    Convert logits to probabilities for all classes at every generation step and return them as a single tensor with shape (1, gen_length, 1, vocab) in the output of CausalTransformer.generate().

    resolves #41

    opened by narphorium 13
  • TpuEmbeddingEngine_WriteParameters not available in this library.

    I followed all of the instructions in the training guide but when I run the device_train script, I get this error:

    2022-02-23 07:56:56.271731: F external/org_tensorflow/tensorflow/core/tpu/tpu_library_init_fns.inc:104] TpuEmbeddingEngine_WriteParameters not available in this library.
    

    This is my exact command for the training process:

    python3 device_train.py --config=configs/6B.json --tune-model-path=gs://nnrap/step_383500
    
    opened by nikhilanayak 11
  • JAX or Ray Library issue

    I am using this repo to train on a small custom dataset, with jax 0.2.16, although requirements.txt says jax 0.2.12. I don't know how Ray and JAX work together internally, but my assumption is that this code should work fine on a v3-8 TPU. When I execute train.py, it generates the following error:

    (pid=9454, ip=10.164.0.9) jax runtime initialization starting
    2021-07-09 12:11:59,794 ERROR worker.py:78 -- Unhandled error (suppress with RAY_IGNORE_UNHANDLED_ERRORS=1): ray::NetworkRunner.run() (pid=9454, ip=10.164.0.9)
      File "python/ray/_raylet.pyx", line 501, in ray._raylet.execute_task
      File "python/ray/_raylet.pyx", line 451, in ray._raylet.execute_task.function_executor
      File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/ray/_private/function_manager.py", line 563, in actor_method_executor
        return method(__ray_actor, *args, **kwargs)
      File "/home/paramjeetsingh80/mesh-transformer-jax/mesh_transformer/train_actor.py", line 24, in run
    TypeError: __new__() missing 1 required positional argument: 'loops'

    I am not sure what this error means. Can someone help me debug it? Is it a library versioning issue or something else?

    opened by paramjeet2021 11
  • Colab GPT-J-6B Inference Demo.ipynb error


    ImportError                               Traceback (most recent call last)
    in <module>()
         28 maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))
         29
    ---> 30 tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

    29 frames
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/layers/__init__.py in <module>()
        144
        145 # Normalization layers.
    --> 146 from tensorflow.python.keras.layers.normalization import LayerNormalization
        147 from tensorflow.python.keras.layers.normalization_v2 import SyncBatchNormalization
        148

    ImportError: cannot import name 'LayerNormalization' from 'tensorflow.python.keras.layers.normalization' (/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/layers/normalization/__init__.py)


    NOTE: If your import is failing due to a missing package, you can manually install dependencies using either !pip or !apt.

    To view examples of installing some common dependencies, click the "Open Examples" button below.

    opened by panweigood 9
  • Illegal Instruction Core dumped (resharding_example.py)

    Whenever I try to run resharding_example.py I keep getting this error: Illegal instruction (core dumped).

    I made sure to untar the weights and put them in the root folder of the project, and I installed the required dependencies. Is there still something that I am missing?

    opened by Metroman123 8
  • division by zero

    If gradient_accumulation_steps == 1, there is a division by zero when calculating G_noise and S_noise:

    https://github.com/kingoflolz/mesh-transformer-jax/blob/22de86e3cabb995ad1005cd90b6a407c0a5f954f/device_train.py#L351

    opened by mgrankin 7
  • Error in prediction code

    Any update on issue #65?

    Even after following @kingoflolz's recommendations in #65, I am not able to move ahead, as the library itself breaks and even import jax aborts.

    opened by paramjeet2021 7
  • Jax TPU Issue

    Hi @kingoflolz, amazing work!!

    I am trying to test the model on TPU VM using the step_383500/ data.

    After following the steps mentioned in jax-quickstart-tpu-vm, running jax.devices() returns the correct output, which is

    [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0), TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1), TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0), TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1), TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0), TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1), TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0), TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
    

    After that, I clone your repo, run pip install -r requirements.txt, and then run pip install jax==0.2.12 as you mentioned in your fine-tuning docs, but it gives this error:

    2021-07-18 05:15:21.807145: F external/org_tensorflow/tensorflow/core/tpu/tpu_executor_init_fns.inc:110] TpuTransferManager_ReadDynamicShapes not available in this library.
    Aborted (core dumped)
    

    So I run this from the jax-quickstart-tpu-vm docs:

    pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    

    after which jax.devices() gives this output

    >>> import jax
    >>> jax.devices()
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    [CpuDevice(id=0)]
    

    Somehow JAX fails to detect the TPU; I don't know what the issue is. I am running Google Cloud's TPU VM v3-8 with the v2-alpha software version. Would really appreciate your help.

    opened by Alexmhack 6
  • out_axes specification issue when you run the script with sudo, works fine without sudo

    The out_axes specification issue occurs when you run the script with sudo; it works fine without sudo. I want to run the script through a systemd service, which is how it hits this error. Even if I run it as sudo python3 device_serve.py it gives the same error, while it works fine with plain python3 device_serve.py.

    Here's some of the stack trace:

    key shape (8, 2)
    in shape (1, 2048)
    dp 1
    mp 8
    read from disk/gcs in 6.40554s
    Traceback (most recent call last):
      File "simple.py", line 149, in <module>
        output = network.generate(batched_tokens, length, gen_length, {"top_p": np.ones(total_batch) * 0.9,
      File "/home/ahmedjawed/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 309, in generate
        return self.generate_xmap(self.state,
      File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 615, in fun_mapped
        out_flat = xmap_p.bind(
      File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 818, in bind
        return core.call_bind(self, fun, *args, **params)  # type: ignore
      File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 1551, in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
      File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 821, in process
        return trace.process_xmap(self, fun, tracers, params)
      File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 606, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 646, in xmap_impl
        xmap_callable = make_xmap_callable(
      File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 262, in memoized_fun
        ans = call(fun, *args)
      File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 673, in make_xmap_callable
        _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
      File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 1454, in _check_out_avals_vs_out_axes
        raise TypeError(f"One of xmap results has an out_axes specification of "
    TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard

    opened by ahmedjawedaj 6
  • Could not find a version that satisfies the requirement ray[default]==1.4.1

    I'm running pip install -r mesh-transformer-jax/requirements.txt on a fresh clone of this repo and getting this:

    ERROR: Could not find a version that satisfies the requirement ray[default]==1.4.1 (from versions: 1.13.0, 2.0.0rc0, 2.0.0rc1, 2.0.0, 2.0.1, 2.1.0, 2.2.0)
    ERROR: No matching distribution found for ray[default]==1.4.1
    

    I can see this version available on PyPI though, so I'm not sure why it can't find it: https://pypi.org/project/ray/1.4.1/

    I don't usually use Python, thanks for the assistance!

    opened by Maxim-Mazurok 0
  • Can we please get a quickstart guide?

    TPU? Proxy? I just want to docker up, and I don't know what any of these scripts will or won't set up automatically. I don't like spending hours on setup only to find that something doesn't work.

    opened by tswallen 0
  • TPU not found on VM (jax version 0.2.16)

    Hello

    I'm running a TPU v3-8 VM on Google. On the VM I installed jax with pip install "jax[tpu]==0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html.

    Unfortunately, I'm getting the message No GPU/TPU found, falling back to CPU. when issuing jax.device_count(). The same holds for pip install jax==0.2.12. Only when I use pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html (the newest JAX version) does it work. As far as I can see, for fine-tuning we need jax version 0.2.12 or 0.2.16.

    How can I get it running with these versions?

    opened by Eichhof 0
  • Project dependencies may have API risk issues

    Hi. In mesh-transformer-jax, inappropriate dependency versioning constraints can cause risks.

    Below are the dependencies and version constraints that the project is using

    numpy~=1.19.5
    tqdm>=4.45.0
    wandb>=0.11.2
    einops~=0.3.0
    requests~=2.25.1
    fabric~=2.6.0
    optax==0.0.9
    dm-haiku==0.0.5
    git+https://github.com/EleutherAI/lm-evaluation-harness/
    ray[default]==1.4.1
    jax~=0.2.12
    Flask~=1.1.2
    cloudpickle~=1.3.0
    tensorflow-cpu~=2.6.0
    google-cloud-storage~=1.36.2
    transformers
    smart_open[gcs]
    func_timeout
    ftfy
    fastapi
    uvicorn
    lm_dataformat
    pathy
    

    The version constraint == introduces a risk of dependency conflicts because it is too strict, while constraints with no upper bound (or *) introduce a risk of missing-API errors, because the latest version of a dependency may remove some APIs.

    After further analysis, in this project, The version constraint of dependency tqdm can be changed to >=4.36.0,<=4.64.0. The version constraint of dependency wandb can be changed to >=0.5.16,<=0.9.7. The version constraint of dependency dm-haiku can be changed to >=0.0.1,<=0.0.8. The version constraint of dependency google-cloud-storage can be changed to >=1.17.0,<=2.4.0. The version constraint of dependency fastapi can be changed to >=0.1.2,<=0.78.0.

    The above suggestions can reduce dependency conflicts as much as possible while still allowing the latest versions that do not introduce API errors in the project.

    The current project invokes all of the following methods:

    The calling methods from the tqdm
    tqdm.tqdm.set_postfix
    tqdm.tqdm.update
    tqdm.tqdm
    
    The calling methods from the wandb
    wandb.init
    wandb.log
    
    The calling methods from the dm-haiku
    haiku.PRNGSequence.take
    haiku.without_apply_rng
    haiku.experimental.optimize_rng_use
    haiku.data_structures.tree_size
    
    The calling methods from the google-cloud-storage
    google.cloud.storage.Client.list_blobs
    
    The calling methods from the fastapi
    fastapi.FastAPI.post
    fastapi.FastAPI
    
    The calling methods from the all methods
    optax.chain
    rotate_every_two
    mesh_transformer.checkpoint.write_ckpt
    x.reshape
    jax.eval_shape
    flask.Flask
    multiprocessing.pool.ThreadPool
    haiku.Flatten
    get_bearer
    jax.numpy.exp
    queue.Queue
    mesh_transformer.TPU_cluster.TPUCluster
    mesh_transformer.layers.TransformerLayerShardV2
    json.load
    time.sleep
    n.run.remote
    flask.make_response
    config.config.TransformerLayerShardV2.get_init_decode_state
    super
    self.proj.loss.hk.remat
    self.move_weights_pjit
    shard_strategy.count
    vocab.example_shape.key.next.jax.random.uniform.astype
    numpy.sum
    jax.experimental.pjit.pjit
    val_loss.np.array.mean.append
    eval_apply_fn
    self._relative_position_bucket
    f.read
    to_id
    jax.lax.dot_general
    ops.get_gptj_model
    jax.nn.softmax
    numpy.logical_and.tolist
    jax.numpy.stack
    self.queue.put
    mesh_transformer.layers.ProjectionShard
    jax.lax.sort_key_val
    self.config.get
    response.headers.add
    l.decode_once
    jax.local_device_count
    lm_eval.evaluator.evaluate.items
    rotate_every_two_v2
    threading.Lock
    jax.numpy.linalg.norm
    x.jnp.transpose.reshape
    argparse.ArgumentParser.add_argument
    mesh_transformer.layers.TransformerLayerShard
    ckpt_step.str.meta.get
    p.map
    val_set.reset
    mesh_transformer.layers.Projection
    projection_apply_fn
    jax.numpy.array_equal
    tqdm.tqdm.update
    numpy.zeros_like
    qids.append
    jax.tree_unflatten
    tensorflow.data.TFRecordDataset
    param_init_fn
    self.rpe
    _unshard
    self.queue_ids.get
    compile_model
    mesh_transformer.util.head_print
    n.eval.remote
    jax.numpy.log
    ops.get_gptj_model.add_to_queue
    f_psum
    mesh_transformer.build_model.build_model.eval
    haiku.transform
    jax.random.uniform
    jax.host_count
    haiku.get_parameter
    numpy.array
    self.convert_requests
    n.train.remote
    jax.numpy.var
    config.config.TransformerLayerShardV2
    grad_norm.np.array.mean
    x.self.k.reshape
    batch_items.append
    x.reshape.reshape
    getattr
    tfrecord_loader.TFRecordNewInputs.get_state
    self.proj
    self.output_q.get
    threading.Thread
    data.items
    json.load.append
    numpy.prod
    jax.lax.stop_gradient
    x.numpy
    node.load_ckpt.remote
    input
    numpy.sqrt
    numpy.where
    jax.numpy.sort
    min
    queue.Queue.put
    transformers.GPT2TokenizerFast.from_pretrained.decode
    fastapi.FastAPI.on_event
    self.tokenizer.encode
    numpy.arange
    self.input
    jax.tree_map
    f_psum.defvjp
    _corsify_actual_response
    _unshard.append
    global_norm
    self.eval
    config.get
    tree_flatten_with_names
    numpy.empty
    states.append
    self.train_pjit
    flask.jsonify
    state.items
    is_leaf
    json.load.items
    jax.numpy.transpose
    jax.numpy.array
    self.input_q.get
    given_length.astype
    numpy.log
    jax.random.split
    repr
    ClipByGlobalNormState
    exit
    int
    i.decode
    apply_fns
    new_states.append
    iter
    all_array_equal
    z_loss.sum_exp_logits.jnp.log.jnp.square.mean
    GPTJ
    self.to_data
    l.hk.remat
    read_remote.remote
    mesh_transformer.transformer_shard.CausalTransformer.generate
    getnorm
    requests.post
    mesh_transformer.transformer_shard.CausalTransformer
    x.reshape.astype
    batch_flattened.append
    wandb.log
    loss.np.array.mean
    functools.partial
    jax.lax.pmax
    jax.lax.rsqrt
    jax.lax.broadcasted_iota
    tokenizer
    self.train_xmap
    generate_fn
    numpy.finfo
    jax.nn.one_hot
    str
    self.network_builder.generate
    numpy.tril
    self.ff
    mesh_transformer.transformer_shard.CausalTransformer.eval
    payloads.QueueResponse
    _build_cors_prelight_response
    mesh_transformer.layers.EmbeddingShard
    self.proj.max
    attention_vec.reshape.reshape
    ray.get.append
    google.cloud.storage.Client
    jax.random.categorical
    itertools.cycle
    numpy.zeros
    embed_apply_fn
    x.self.q.reshape
    self.self_attn
    x.jnp.zeros_like.astype
    itertools.zip_longest
    transformers.GPT2TokenizerFast.from_pretrained.encode
    jax.lax.all_gather
    unstacked.append
    jax.tree_multimap
    jax.experimental.maps.ResourceEnv
    g_psum
    jax.numpy.sqrt
    jax.lax.pmean
    train_step
    jax.lax.axis_index
    loss.append
    mesh_transformer.util.additive_weight_decay
    output.append
    mesh_transformer.util.clip_by_global_norm
    qid.self.queue_ids.get
    numpy.array.mean
    NotImplementedError
    ValueError
    get_project
    self.dense_proj_o
    self.prepare_item.get
    ops.get_gptj_model.load_model
    tfrecord_loader.TFRecordNewInputs
    json.dumps
    jax.experimental.PartitionSpec
    mesh_transformer.util.global_norm
    numpy.ones
    Exception
    zip
    mesh_transformer.util.to_f32.to_bf16.early_cast
    id
    process_request
    next
    jax.experimental.maps.Mesh
    file.prefetch.apply
    init_decode_apply
    file_index.dir.np.load.keys
    self.input_proj
    jax.numpy.argsort
    lm_eval.tasks.get_task_dict
    batch_items.items
    mesh_transformer.build_model.build_model.move
    last.astype
    app.run.threading.Thread.start
    fabric.Connection
    config.EmbeddingShardV2
    requests.delete.json
    self.generate_xmap
    ftfy.fix_text
    sum
    jax.numpy.multiply
    timer
    haiku.data_structures.tree_size
    self.init_pjit
    config.config.TransformerLayerShardV2.decode_once
    CausalTransformerShard.generate_initial
    infer
    CausalTransformerShard
    mesh_transformer.util.to_f32.to_bf16.bf16_optimizer
    jax.host_id
    self.reset
    file.prefetch.prefetch
    config.Projection.loss
    open
    apply_rotary_pos_emb
    nucleaus_filter
    jax.random.PRNGKey
    config.Projection
    json.dump
    p.imap
    self.output_proj
    val_grad_fn
    mesh_transformer.build_model.build_model.train
    mesh_transformer.transformer_shard.CausalTransformer.move_xmap
    io.BytesIO
    x.x.all
    x.self.v.reshape
    self.network_builder.train
    val_set.get_samples
    optax.scale
    all_top_p.append
    setuptools.find_packages
    self.output
    iter_decode_apply
    jax.experimental.pjit.with_sharding_constraint
    smart_open.open
    random.randint
    residual.hk.remat
    pytree.items
    tasks.util.shrink_seq
    pad_amount.tokens.np.pad.astype
    decode
    tensorflow.io.VarLenFeature
    self.network_builder.write_ckpt
    fastapi.FastAPI
    node.write_ckpt.remote
    qid.self.queue_ids.put
    last_loss.np.array.mean
    all_ctx.append
    check_tpu
    conn.run
    list
    func_timeout.func_set_timeout
    numpy.maximum.astype
    jax.numpy.concatenate
    conn.put
    tasks.eval_harness.EvalHarnessAdaptor
    self.infer_batch
    jax.lax.scan
    delete_tpu
    ops.get_gptj_model.start_background
    argparse.ArgumentParser
    lm_eval.evaluator.evaluate
    fastapi.FastAPI.post
    tensorflow.sparse.to_dense
    threading.Thread.start
    numpy.minimum
    numpy.sum.tolist
    jax.value_and_grad
    mesh_transformer.train_actor.NetworkRunner.options
    range
    new_vals.append
    reshard.all
    optax.apply_updates
    split
    self.network.generate
    tensorflow.sparse.reorder
    self.dense_proj
    optax.scale_by_schedule
    ReplicatedLayerNorm
    self.q
    jax.numpy.where
    all_temp.append
    google.cloud.storage.Client.list_blobs
    f_pmean.defvjp
    eval_step
    optax.GradientTransformation
    jax.lax.psum
    parallel_write
    init_fns
    blob.delete
    params.append
    eval_loss_fn
    tasks.EvalHarnessAdaptor
    projection_init_fn
    jax.tree_leaves
    reshard
    params.get
    res.ray.get.i.i.np.array.mean
    numpy.savez
    wandb.init
    grouper
    numpy.max
    compression.i.tf.data.TFRecordDataset.map
    optax.scale_by_adam
    haiku.PRNGSequence
    jax.numpy.zeros_like
    time.time
    logging.getLogger
    self.embed.hk.remat
    multiprocessing.Pool
    x.reshape.append
    jax.numpy.square
    last_loss.append
    numpy.pad
    mesh_transformer.build_model.build_model
    mesh_transformer.transformer_shard.CausalTransformer.train
    self.o
    mesh_transformer.checkpoint.load_ckpt_v2
    self.network_builder
    self.glu
    jax.experimental.maps.mesh
    glob.glob
    ray.init
    tensorflow.io.FixedLenFeature
    self.tokenizer.add_special_tokens
    self.norm
    sch
    self.pool.imap
    reshaped.reshape.reshape
    numpy.cos
    jax.tree_flatten
    tqdm.tqdm
    step.step.all
    self.sample_once
    self.map_fn
    batch.append
    noise_scale_stats.update
    ray_tpu.wait_til
    optimizer.update
    super.__init__
    read_sharded_v2
    val_set.sample_once
    max
    logging.getLogger.debug
    save
    queue.Queue.get
    jax.devices.np.array.reshape
    self.v
    haiku.remat
    self.qvk_proj
    fastapi.FastAPI.add_middleware
    self.network_builder.load_ckpt
    haiku.initializers.TruncatedNormal
    all_length.append
    haiku.experimental.optimize_rng_use
    json.load.pop
    temp.logits.key.jax.random.categorical.astype
    mesh_transformer.util.to_f32.to_bf16.config
    val_sets.values
    numpy.sin
    jax.numpy.zeros
    tasks.util.sample_batch
    requests.delete
    self.eval_xmap
    haiku.Linear
    self.network_builder.eval
    tfrecord_loader.TFRecordNewInputs.get_samples
    optax.additive_weight_decay
    task_res.items
    jax.numpy.zeros.block_until_ready
    haiku.LayerNorm
    self.used.append
    jax.numpy.sum
    fixed_pos_embedding
    self.input_q.put
    setuptools.setup
    jax.numpy.arange
    jax.numpy.asarray
    apply_rotary_pos_emb_v2
    warnings.filterwarnings
    functools.lru_cache
    softmax_sample
    format
    traceback.print_exc
    numpy.dtype
    transformer_init_fn
    n.get_params.remote
    v.attention_weights.jnp.einsum.reshape
    payloads.CompletionResponse
    itertools.chain
    jax.numpy.exp.sum
    mesh_transformer.build_model.build_model.save
    ray.is_initialized
    aux.get
    reshard.reshape
    jax.devices
    ray.remote
    self.prepare_item
    scheduler
    mesh_transformer.util.maybe_shard
    ray_tpu.get_connection
    jax.numpy.clip
    ray.get
    node.move_params.remote
    json.load.get
    divmod
    p.imap_unordered
    jax.numpy.einsum
    grad_norm_micro.np.array.mean
    CausalTransformerShard.loss
    gpt3_schedule
    index_weights
    parallel_read
    fix_dtype
    jax.lax.pmean.mean
    subprocess.check_output.decode
    optax.AdditiveWeightDecayState
    numpy.load
    jax.numpy.mean
    numpy.concatenate
    TFRecordWIT
    n.generate.remote
    jax.experimental.maps.xmap
    outputs.append
    index_fname.open.read
    float
    CausalTransformerShard.generate_once
    ray_tpu.create_tpu
    val_sets.items
    mesh_transformer.checkpoint.read_ckpt
    bcast_iota.Ellipsis.jnp.newaxis.rp_bucket.jnp.array.astype
    self.nodes.append
    mesh_transformer.util.g_psum
    self.dim_per_head.np.sqrt.astype
    transformers.GPT2TokenizerFast.from_pretrained
    samples.append
    dp.mp.key.next.jax.random.uniform.astype
    tree_leaves_with_names
    subprocess.check_output
    get_inital
    queue.Queue.qsize
    conn.sudo
    q.put
    transformer_apply_fn
    jax.numpy.cumsum
    l.get_init_decode_state
    mesh_transformer.util.gpt3_schedule
    self.tpu.eval
    optimizer.init
    haiku.next_rng_key
    jax.numpy.split
    parse_args
    all
    print
    ops.get_gptj_model.wait_for_queue
    val_loss.np.array.mean
    argparse.ArgumentParser.parse_args
    self.queue.get
    self.embed
    RMSNorm
    mesh_transformer.layers.RelativePositionEmbs
    numpy.stack
    self.head_split
    requests.get
    g_psum.defvjp
    process_init
    self.network_builder.move_xmap
    mesh_transformer.checkpoint.write_ckpt_v2
    logging.getLogger.info
    map
    filter
    jax.numpy.broadcast_to
    self.eval_pjit
    uvicorn.run
    einops.repeat
    mesh_transformer.train_actor.NetworkRunner.options.remote
    jax.numpy.reshape
    self.d_head.np.sqrt.astype
    len
    self.get_samples
    gc.collect
    mesh_transformer.util.to_f32
    self.norm.reshape
    TFRecordWIT.sample_once
    haiku.PRNGSequence.take
    einops.rearrange
    sampler
    float.update
    index_fname.open.read.splitlines
    numpy.einsum
    out.reshape.reshape
    haiku.without_apply_rng
    tensorflow.io.parse_single_example
    self.output_q.put
    all_tokenized.append
    shrink_seq
    self.transformer_layers.append
    network.state.count.item
    bool
    embed_init_fn
    self.tokenizer.decode
    RuntimeError
    all_q.append
    multiprocessing.set_start_method
    haiku.initializers.Constant
    tensorflow.data.experimental.dense_to_ragged_batch
    tensorflow.cast
    jax.nn.gelu
    mesh_transformer.build_model.build_model.load
    jax.device_put
    max_exact.num_buckets.max_exact.max_distance.np.log.np.float32.np.finfo.eps.max_exact.np.float32.n.astype.np.log.astype
    mesh_transformer.util.to_bf16
    convert_fn
    max_lengths.append
    tqdm.tqdm.set_postfix
    subprocess.check_output.decode.strip
    isinstance
    numpy.array_split
    numpy.logical_and
    jax.numpy.cos
    mesh_transformer.util.f_psum
    enumerate
    flask.Flask.route
    mesh_transformer.layers.EmbeddingShardV2
    self.init_xmap
    seq.np.zeros.astype
    os.path.expanduser
    jax.device_count
    contexts.append
    self.k
    numpy.maximum
    

    @developer Could you please help me check this issue? May I open a pull request to fix it? Thank you very much.

    opened by PyDeps 0