Model parallel transformers in Jax and Haiku

Overview

Mesh Transformer Jax

A haiku library using the new(ly documented) xmap operator in Jax for model parallelism of transformers.
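
In this scheme, a weight array carries a logical, named shard axis, and xmap lays that axis out over the physical device mesh so each accelerator holds and applies one slice. A minimal sketch of the pattern (illustrative only, not this library's code; it assumes the jax 0.2.x-era jax.experimental.maps API that requirements.txt pins):

    import jax
    import jax.numpy as jnp
    import numpy as np
    from jax.experimental.maps import mesh, xmap

    def layer(x, w):
        # Runs once per 'shard'; each device applies its own slice of w.
        return jnp.dot(x, w)

    sharded_layer = xmap(
        layer,
        in_axes=([...], ['shard', ...]),   # x replicated, w split over 'shard'
        out_axes=['shard', ...],           # one output partition per shard
        axis_resources={'shard': 'mp'},    # lay the named axis over the mesh
    )

    x = jnp.ones((4, 512))                 # (batch, d_model), same on all devices
    w = jnp.ones((8, 512, 64))             # (shards, d_model, d_head)
    with mesh(np.array(jax.devices()), ('mp',)):
        y = sharded_layer(x, w)            # (8, 4, 64)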

See enwik8_example.py for an example of using this to implement an autoregressive language model.

Benchmarks

On a TPU v3-8 (see tpuv38_example.py):

~2.7B model

Initialized in 121.842s
Total parameters: 2722382080
Compiled in 49.0534s
it: 0, loss: 20.311113357543945
<snip>
it: 90, loss: 3.987450361251831
100 steps in 109.385s
effective flops (not including attn): 2.4466e+14

~4.8B model

Initialized in 101.016s
Total parameters: 4836720896
Compiled in 52.7404s
it: 0, loss: 4.632925987243652
<snip>
it: 40, loss: 3.2406811714172363
50 steps in 102.559s
effective flops (not including attn): 2.31803e+14

10B model

Initialized in 152.762s
Total parameters: 10073579776
Compiled in 92.6539s
it: 0, loss: 5.3125
<snip>
it: 40, loss: 3.65625
50 steps in 100.235s
effective flops (not including attn): 2.46988e+14
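
As a sanity check (a reading of the logs above, not something the repo states), the effective-flops figures are consistent with the standard estimate of roughly 6 x parameters x tokens per second. For the ~2.7B run, assuming 16384 tokens per step (e.g. batch 8 x sequence length 2048, which is what the number implies):

    params = 2_722_382_080
    steps_per_sec = 100 / 109.385                # "100 steps in 109.385s"
    tokens_per_sec = 16_384 * steps_per_sec      # assumed batch 8 x seq 2048
    print(f"{6 * params * tokens_per_sec:.4e}")  # ~2.4466e+14, matching the log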

TODO

  • disentangle heads and shards
  • test/benchmark on TPU
  • implement gradient checkpointing (see the hk.remat sketch after this list)
  • fix initialization
  • mixed precision
  • shard activations instead of replicating
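
A sketch of the gradient checkpointing item above, using Haiku's hk.remat (its wrapper around jax.checkpoint), which recomputes a block's activations in the backward pass instead of storing them. This is a guess at the shape of the change, not this repo's eventual implementation:

    import haiku as hk
    import jax
    import jax.numpy as jnp

    def forward(x):
        block = hk.nets.MLP([2048, 512])
        # Recompute the block's activations during backprop rather than
        # keeping them resident, trading compute for activation memory.
        return hk.remat(lambda h: block(h))(x)

    f = hk.without_apply_rng(hk.transform(forward))
    params = f.init(jax.random.PRNGKey(0), jnp.ones((2, 512)))
    y = f.apply(params, jnp.ones((2, 512)))
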
Comments
  • AttributeError: module 'jax.random' has no attribute 'KeyArray' while fine tuning.

    I am following your repo to fine-tune GPT-J on TPU. When I run "python3 device_train.py --config=YOUR_CONFIG.json --tune-model-path=gs://YOUR-BUCKET/step_383500/" with my bucket name and the config file I have created, I get the error "AttributeError: module 'jax.random' has no attribute 'KeyArray'". These are some of the specs:

    OS: Ubuntu 20.04
    jax version: 0.2.12
    TPU: v3-8
    Zone: us-central1-b

    The error is raised at line 7 of device_train.py, where optax is imported: "import optax".

    This is the error stack:

    WARNING: Logging before InitGoogle() is written to STDERR
    I0420 11:47:44.856002 10240 tpu_initializer_helper.cc:94] libtpu.so already in use by another process. Run "$ sudo lsof -w /dev/accel0" to figure out which process is using the TPU. Not attempting to load libtpu.so in this process.
    Traceback (most recent call last):
      File "device_train.py", line 7, in <module>
        import optax
      File "/usr/local/lib/python3.8/dist-packages/optax/__init__.py", line 17, in <module>
        from optax._src.alias import adabelief
      File "/usr/local/lib/python3.8/dist-packages/optax/_src/alias.py", line 21, in <module>
        from optax._src import base
      File "/usr/local/lib/python3.8/dist-packages/optax/_src/base.py", line 18, in <module>
        import chex
      File "/usr/local/lib/python3.8/dist-packages/chex/__init__.py", line 17, in <module>
        from chex._src.asserts import assert_axis_dimension
      File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts.py", line 26, in <module>
        from chex._src import asserts_internal as _ai
      File "/usr/local/lib/python3.8/dist-packages/chex/_src/asserts_internal.py", line 32, in <module>
        from chex._src import pytypes
      File "/usr/local/lib/python3.8/dist-packages/chex/_src/pytypes.py", line 36, in <module>
        PRNGKey = jax.random.KeyArray
    AttributeError: module 'jax.random' has no attribute 'KeyArray'
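
    The failing statement is chex's import-time alias PRNGKey = jax.random.KeyArray; jax 0.2.12 predates that attribute, so the installed chex is too new for the pinned jax. A minimal probe of the mismatch (illustrative, not from the repo):

        import jax

        print(jax.__version__)                  # 0.2.12 in this report
        print(hasattr(jax.random, "KeyArray"))  # False -> chex expects a newer jax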

    Any help is appreciated!

    opened by samyakai 13
  • the-eye.eu down - alternative access to GPT-J-6B/step_383500_slim.tar.zstd ?

    Thank you for your hard work and dedication to the project. I've been trying to load the slim checkpoint from the-eye.eu. Apparently, the site is down. I've tried every way of accessing it, but it is probably an outage on their end (https://downforeveryoneorjustme.com/the-eye.eu).

    They might resolve it shortly, or maybe not. Is there any alternative address to wget GPT-J-6B/step_383500_slim.tar.zstd? Thank you.

    opened by PhilWicke 13
  • Return probabilities from CausalTransformer.generate()

    Convert logits to probabilities for all classes at every generation step and return them as a single tensor with shape (1, gen_length, 1, vocab) in the output of CausalTransformer.generate().
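
    A minimal sketch of the conversion this describes (the real change is in the PR itself; shapes are illustrative, with 50400 being GPT-J's padded vocab size):

        import jax
        import jax.numpy as jnp

        logits = jnp.zeros((1, 1, 50400))        # one step: (batch, 1, vocab)
        probs = jax.nn.softmax(logits, axis=-1)  # probabilities over all classes
        # Stacking one such array per generated step yields the proposed
        # (1, gen_length, 1, vocab) output tensor.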

    resolves #41

    opened by narphorium 13
  • TpuEmbeddingEngine_WriteParameters not available in this library.

    I followed all of the instructions in the training guide but when I run the device_train script, I get this error:

    2022-02-23 07:56:56.271731: F external/org_tensorflow/tensorflow/core/tpu/tpu_library_init_fns.inc:104] TpuEmbeddingEngine_WriteParameters not available in this library.
    

    This is my exact command for the training process:

    python3 device_train.py --config=configs/6B.json --tune-model-path=gs://nnrap/step_383500
    
    opened by nikhilanayak 11
  • JAX or Ray Library issue

    I am using this repo to train on a small custom dataset, with jax 0.2.16. However, requirements.txt says jax 0.2.12. I don't know how ray and jax work together internally, but my assumption is that this code should work fine on a v3-8 TPU. When I execute train.py, it generates the following error:

    (pid=9454, ip=10.164.0.9) jax runtime initialization starting
    2021-07-09 12:11:59,794 ERROR worker.py:78 -- Unhandled error (suppress with RAY_IGNORE_UNHANDLED_ERRORS=1): ray::NetworkRunner.run() (pid=9454, ip=10.164.0.9)
      File "python/ray/_raylet.pyx", line 501, in ray._raylet.execute_task
      File "python/ray/_raylet.pyx", line 451, in ray._raylet.execute_task.function_executor
      File "/home/paramjeetsingh80/.local/lib/python3.8/site-packages/ray/_private/function_manager.py", line 563, in actor_method_executor
        return method(__ray_actor, *args, **kwargs)
      File "/home/paramjeetsingh80/mesh-transformer-jax/mesh_transformer/train_actor.py", line 24, in run
    TypeError: __new__() missing 1 required positional argument: 'loops'

    I am not sure what this error means. Can someone help me debug it? Is it a library versioning issue or something else?

    opened by paramjeet2021 11
  • Colab GPT-J-6B Inference Demo.ipynb error

    ImportError                               Traceback (most recent call last)
    <ipython-input> in <module>()
         28 maps.thread_resources.env = maps.ResourceEnv(maps.Mesh(devices, ('dp', 'mp')))
         29
    ---> 30 tokenizer = transformers.GPT2TokenizerFast.from_pretrained('gpt2')

    29 frames
    /usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/layers/__init__.py in <module>()
        144
        145 # Normalization layers.
    --> 146 from tensorflow.python.keras.layers.normalization import LayerNormalization
        147 from tensorflow.python.keras.layers.normalization_v2 import SyncBatchNormalization
        148

    ImportError: cannot import name 'LayerNormalization' from 'tensorflow.python.keras.layers.normalization' (/usr/local/lib/python3.7/dist-packages/tensorflow/python/keras/layers/normalization/__init__.py)


    NOTE: If your import is failing due to a missing package, you can manually install dependencies using either !pip or !apt.

    To view examples of installing some common dependencies, click the "Open Examples" button below.

    opened by panweigood 9
  • Illegal Instruction Core dumped (resharding_example.py)

    Whenever I try to run resharding_example.py I keep getting this error: Illegal instruction (core dumped)

    I made sure to tar the weights and put them in the root folder of the project, and I installed the needed dependencies. Is there still something that I am missing?

    opened by Metroman123 8
  • division by zero

    If gradient_accumulation_steps == 1, there is a division by zero when calculating G_noise and S_noise:

    https://github.com/kingoflolz/mesh-transformer-jax/blob/22de86e3cabb995ad1005cd90b6a407c0a5f954f/device_train.py#L351
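
    For context, here is a sketch of the estimator in question, assuming it follows the standard gradient-noise-scale formulation (the names follow the issue; this is not device_train.py's actual code). With gradient_accumulation_steps == 1, the large and small batch sizes coincide and both denominators below are zero:

        def noise_scale(g_norm_big, g_norm_small, b_big, b_small):
            # Compares gradient norms measured at two batch sizes; the estimate
            # is undefined when they are equal, i.e. with a single microbatch.
            if b_big == b_small:
                raise ValueError("need gradient_accumulation_steps > 1")
            g_noise = (b_big * g_norm_big**2 - b_small * g_norm_small**2) / (b_big - b_small)
            s_noise = (g_norm_small**2 - g_norm_big**2) / (1 / b_small - 1 / b_big)
            return g_noise, s_noise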

    opened by mgrankin 7
  • Error in prediction code

    Any update on issue #65?

    Even after following @kingoflolz's recommendations in #65, I am not able to move ahead, as the library itself breaks and even import jax aborts.

    opened by paramjeet2021 7
  • Jax TPU Issue

    Hi @kingoflolz, amazing work!!

    I am trying to test the model on TPU VM using the step_383500/ data.

    After following the steps mentioned in jax-quickstart-tpu-vm, running jax.devices() returns the correct output, which is

    [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0),
     TpuDevice(id=1, process_index=0, coords=(0,0,0), core_on_chip=1),
     TpuDevice(id=2, process_index=0, coords=(1,0,0), core_on_chip=0),
     TpuDevice(id=3, process_index=0, coords=(1,0,0), core_on_chip=1),
     TpuDevice(id=4, process_index=0, coords=(0,1,0), core_on_chip=0),
     TpuDevice(id=5, process_index=0, coords=(0,1,0), core_on_chip=1),
     TpuDevice(id=6, process_index=0, coords=(1,1,0), core_on_chip=0),
     TpuDevice(id=7, process_index=0, coords=(1,1,0), core_on_chip=1)]
    

    After that, I clone your repo, run pip install -r requirements.txt, and then run pip install jax==0.2.12 as you mentioned in your fine-tuning docs, but it gives this error:

    2021-07-18 05:15:21.807145: F external/org_tensorflow/tensorflow/core/tpu/tpu_executor_init_fns.inc:110] TpuTransferManager_ReadDynamicShapes not available in this library.
    Aborted (core dumped)
    

    So I run this from the jax-quickstart-tpu-vm docs:

    pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html
    

    after which jax.devices() gives this output:

    >>> import jax
    >>> jax.devices()
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    [CpuDevice(id=0)]
    

    Somehow jax fails to detect the TPU; I don't know what the issue is. I am running Google Cloud's TPU VM v3-8 with the v2-alpha software version. Would really appreciate your help.

    opened by Alexmhack 6
  • out_axes specification issue when you run the script with sudo, works fine without sudo

    The out_axes specification issue appears when running the script with sudo; it works fine without sudo. But I want to run the script through a systemd service, so I hit this error. Even if I run it as sudo python3 device_serve.py it gives the same error; it works fine with python3 device_serve.py.

    Here's the stack trace:

    key shape (8, 2)
    in shape (1, 2048)
    dp 1
    mp 8
    read from disk/gcs in 6.40554s
    Traceback (most recent call last):
      File "simple.py", line 149, in <module>
        output = network.generate(batched_tokens, length, gen_length, {"top_p": np.ones(total_batch) * 0.9,
      File "/home/ahmedjawed/mesh-transformer-jax/mesh_transformer/transformer_shard.py", line 309, in generate
        return self.generate_xmap(self.state,
      File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 615, in fun_mapped
        out_flat = xmap_p.bind(
      File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 818, in bind
        return core.call_bind(self, fun, *args, **params)  # type: ignore
      File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 1551, in call_bind
        outs = primitive.process(top_trace, fun, tracers, params)
      File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 821, in process
        return trace.process_xmap(self, fun, tracers, params)
      File "/usr/local/lib/python3.8/dist-packages/jax/core.py", line 606, in process_call
        return primitive.impl(f, *tracers, **params)
      File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 646, in xmap_impl
        xmap_callable = make_xmap_callable(
      File "/usr/local/lib/python3.8/dist-packages/jax/linear_util.py", line 262, in memoized_fun
        ans = call(fun, *args)
      File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 673, in make_xmap_callable
        _check_out_avals_vs_out_axes(out_avals, out_axes, global_axis_sizes)
      File "/usr/local/lib/python3.8/dist-packages/jax/experimental/maps.py", line 1454, in _check_out_avals_vs_out_axes
        raise TypeError(f"One of xmap results has an out_axes specification of "
    TypeError: One of xmap results has an out_axes specification of ['batch', ...], but is actually mapped along more axes defined by this xmap call: shard

    opened by ahmedjawedaj 6
  • Could not find a version that satisfies the requirement ray[default]==1.4.1

    I'm running pip install -r mesh-transformer-jax/requirements.txt on a fresh clone of this repo and getting this:

    ERROR: Could not find a version that satisfies the requirement ray[default]==1.4.1 (from versions: 1.13.0, 2.0.0rc0, 2.0.0rc1, 2.0.0, 2.0.1, 2.1.0, 2.2.0)
    ERROR: No matching distribution found for ray[default]==1.4.1
    

    I can see this version available on pip though, so I'm not sure why it can't find it: https://pypi.org/project/ray/1.4.1/

    I don't usually use Python; thanks for the assistance!

    opened by Maxim-Mazurok 0
  • Can we please get a quickstart guide?

    TPU? Proxy? I just want to docker up, and I don't know what any of these scripts will or won't set up automatically. I don't like spending hours on setup only to find that something doesn't work.

    opened by tswallen 0
  • TPU not found on VM (jax version 0.2.16)

    Hello

    I'm running a TPU v3-8 VM on Google. On the VM I installed jax with pip install "jax[tpu]==0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html.

    Unfortunately, I'm getting the message No GPU/TPU found, falling back to CPU. when issuing jax.device_count(). The same holds for pip install jax==0.2.12. Only when I use pip install "jax[tpu]>=0.2.16" -f https://storage.googleapis.com/jax-releases/libtpu_releases.html (the newest jax version) does it work. As far as I can see, for fine-tuning we need jax version 0.2.12 or 0.2.16.

    How can I get it running with these versions?

    opened by Eichhof 0
  • Project dependencies may have API risk issues

    Hi. In mesh-transformer-jax, inappropriate dependency version constraints can cause risks.

    Below are the dependencies and version constraints that the project is using:

    numpy~=1.19.5
    tqdm>=4.45.0
    wandb>=0.11.2
    einops~=0.3.0
    requests~=2.25.1
    fabric~=2.6.0
    optax==0.0.9
    dm-haiku==0.0.5
    git+https://github.com/EleutherAI/lm-evaluation-harness/
    ray[default]==1.4.1
    jax~=0.2.12
    Flask~=1.1.2
    cloudpickle~=1.3.0
    tensorflow-cpu~=2.6.0
    google-cloud-storage~=1.36.2
    transformers
    smart_open[gcs]
    func_timeout
    ftfy
    fastapi
    uvicorn
    lm_dataformat
    pathy
    

    The version constraint == introduces a risk of dependency conflicts because it pins dependencies too strictly. Constraints with no upper bound, or with *, introduce a risk of missing-API errors because the latest version of a dependency may remove some APIs.

    After further analysis, in this project the version constraint of tqdm can be changed to >=4.36.0,<=4.64.0; of wandb to >=0.5.16,<=0.9.7; of dm-haiku to >=0.0.1,<=0.0.8; of google-cloud-storage to >=1.17.0,<=2.4.0; and of fastapi to >=0.1.2,<=0.78.0.

    These suggested changes reduce the risk of dependency conflicts as much as possible while still allowing the latest versions that do not introduce API errors.

    The current project invokes all of the following methods:

    The calling methods from the tqdm
    tqdm.tqdm.set_postfix
    tqdm.tqdm.update
    tqdm.tqdm
    
    The calling methods from the wandb
    wandb.init
    wandb.log
    
    The calling methods from the dm-haiku
    haiku.PRNGSequence.take
    haiku.without_apply_rng
    haiku.experimental.optimize_rng_use
    haiku.data_structures.tree_size
    
    The calling methods from the google-cloud-storage
    google.cloud.storage.Client.list_blobs
    
    The calling methods from the fastapi
    fastapi.FastAPI.post
    fastapi.FastAPI
    
    The calling methods from the all methods
    optax.chain
    rotate_every_two
    mesh_transformer.checkpoint.write_ckpt
    x.reshape
    jax.eval_shape
    flask.Flask
    multiprocessing.pool.ThreadPool
    haiku.Flatten
    get_bearer
    jax.numpy.exp
    queue.Queue
    mesh_transformer.TPU_cluster.TPUCluster
    mesh_transformer.layers.TransformerLayerShardV2
    json.load
    time.sleep
    n.run.remote
    flask.make_response
    config.config.TransformerLayerShardV2.get_init_decode_state
    super
    self.proj.loss.hk.remat
    self.move_weights_pjit
    shard_strategy.count
    vocab.example_shape.key.next.jax.random.uniform.astype
    numpy.sum
    jax.experimental.pjit.pjit
    val_loss.np.array.mean.append
    eval_apply_fn
    self._relative_position_bucket
    f.read
    to_id
    jax.lax.dot_general
    ops.get_gptj_model
    jax.nn.softmax
    numpy.logical_and.tolist
    jax.numpy.stack
    self.queue.put
    mesh_transformer.layers.ProjectionShard
    jax.lax.sort_key_val
    self.config.get
    response.headers.add
    l.decode_once
    jax.local_device_count
    lm_eval.evaluator.evaluate.items
    rotate_every_two_v2
    threading.Lock
    jax.numpy.linalg.norm
    x.jnp.transpose.reshape
    argparse.ArgumentParser.add_argument
    mesh_transformer.layers.TransformerLayerShard
    ckpt_step.str.meta.get
    p.map
    val_set.reset
    mesh_transformer.layers.Projection
    projection_apply_fn
    jax.numpy.array_equal
    tqdm.tqdm.update
    numpy.zeros_like
    qids.append
    jax.tree_unflatten
    tensorflow.data.TFRecordDataset
    param_init_fn
    self.rpe
    _unshard
    self.queue_ids.get
    compile_model
    mesh_transformer.util.head_print
    n.eval.remote
    jax.numpy.log
    ops.get_gptj_model.add_to_queue
    f_psum
    mesh_transformer.build_model.build_model.eval
    haiku.transform
    jax.random.uniform
    jax.host_count
    haiku.get_parameter
    numpy.array
    self.convert_requests
    n.train.remote
    jax.numpy.var
    config.config.TransformerLayerShardV2
    grad_norm.np.array.mean
    x.self.k.reshape
    batch_items.append
    x.reshape.reshape
    getattr
    tfrecord_loader.TFRecordNewInputs.get_state
    self.proj
    self.output_q.get
    threading.Thread
    data.items
    json.load.append
    numpy.prod
    jax.lax.stop_gradient
    x.numpy
    node.load_ckpt.remote
    input
    numpy.sqrt
    numpy.where
    jax.numpy.sort
    min
    queue.Queue.put
    transformers.GPT2TokenizerFast.from_pretrained.decode
    fastapi.FastAPI.on_event
    self.tokenizer.encode
    numpy.arange
    self.input
    jax.tree_map
    f_psum.defvjp
    _corsify_actual_response
    _unshard.append
    global_norm
    self.eval
    config.get
    tree_flatten_with_names
    numpy.empty
    states.append
    self.train_pjit
    flask.jsonify
    state.items
    is_leaf
    json.load.items
    jax.numpy.transpose
    jax.numpy.array
    self.input_q.get
    given_length.astype
    numpy.log
    jax.random.split
    repr
    ClipByGlobalNormState
    exit
    int
    i.decode
    apply_fns
    new_states.append
    iter
    all_array_equal
    z_loss.sum_exp_logits.jnp.log.jnp.square.mean
    GPTJ
    self.to_data
    l.hk.remat
    read_remote.remote
    mesh_transformer.transformer_shard.CausalTransformer.generate
    getnorm
    requests.post
    mesh_transformer.transformer_shard.CausalTransformer
    x.reshape.astype
    batch_flattened.append
    wandb.log
    loss.np.array.mean
    functools.partial
    jax.lax.pmax
    jax.lax.rsqrt
    jax.lax.broadcasted_iota
    tokenizer
    self.train_xmap
    generate_fn
    numpy.finfo
    jax.nn.one_hot
    str
    self.network_builder.generate
    numpy.tril
    self.ff
    mesh_transformer.transformer_shard.CausalTransformer.eval
    payloads.QueueResponse
    _build_cors_prelight_response
    mesh_transformer.layers.EmbeddingShard
    self.proj.max
    attention_vec.reshape.reshape
    ray.get.append
    google.cloud.storage.Client
    jax.random.categorical
    itertools.cycle
    numpy.zeros
    embed_apply_fn
    x.self.q.reshape
    self.self_attn
    x.jnp.zeros_like.astype
    itertools.zip_longest
    transformers.GPT2TokenizerFast.from_pretrained.encode
    jax.lax.all_gather
    unstacked.append
    jax.tree_multimap
    jax.experimental.maps.ResourceEnv
    g_psum
    jax.numpy.sqrt
    jax.lax.pmean
    train_step
    jax.lax.axis_index
    loss.append
    mesh_transformer.util.additive_weight_decay
    output.append
    mesh_transformer.util.clip_by_global_norm
    qid.self.queue_ids.get
    numpy.array.mean
    NotImplementedError
    ValueError
    get_project
    self.dense_proj_o
    self.prepare_item.get
    ops.get_gptj_model.load_model
    tfrecord_loader.TFRecordNewInputs
    json.dumps
    jax.experimental.PartitionSpec
    mesh_transformer.util.global_norm
    numpy.ones
    Exception
    zip
    mesh_transformer.util.to_f32.to_bf16.early_cast
    id
    process_request
    next
    jax.experimental.maps.Mesh
    file.prefetch.apply
    init_decode_apply
    file_index.dir.np.load.keys
    self.input_proj
    jax.numpy.argsort
    lm_eval.tasks.get_task_dict
    batch_items.items
    mesh_transformer.build_model.build_model.move
    last.astype
    app.run.threading.Thread.start
    fabric.Connection
    config.EmbeddingShardV2
    requests.delete.json
    self.generate_xmap
    ftfy.fix_text
    sum
    jax.numpy.multiply
    timer
    haiku.data_structures.tree_size
    self.init_pjit
    config.config.TransformerLayerShardV2.decode_once
    CausalTransformerShard.generate_initial
    infer
    CausalTransformerShard
    mesh_transformer.util.to_f32.to_bf16.bf16_optimizer
    jax.host_id
    self.reset
    file.prefetch.prefetch
    config.Projection.loss
    open
    apply_rotary_pos_emb
    nucleaus_filter
    jax.random.PRNGKey
    config.Projection
    json.dump
    p.imap
    self.output_proj
    val_grad_fn
    mesh_transformer.build_model.build_model.train
    mesh_transformer.transformer_shard.CausalTransformer.move_xmap
    io.BytesIO
    x.x.all
    x.self.v.reshape
    self.network_builder.train
    val_set.get_samples
    optax.scale
    all_top_p.append
    setuptools.find_packages
    self.output
    iter_decode_apply
    jax.experimental.pjit.with_sharding_constraint
    smart_open.open
    random.randint
    residual.hk.remat
    pytree.items
    tasks.util.shrink_seq
    pad_amount.tokens.np.pad.astype
    decode
    tensorflow.io.VarLenFeature
    self.network_builder.write_ckpt
    fastapi.FastAPI
    node.write_ckpt.remote
    qid.self.queue_ids.put
    last_loss.np.array.mean
    all_ctx.append
    check_tpu
    conn.run
    list
    func_timeout.func_set_timeout
    numpy.maximum.astype
    jax.numpy.concatenate
    conn.put
    tasks.eval_harness.EvalHarnessAdaptor
    self.infer_batch
    jax.lax.scan
    delete_tpu
    ops.get_gptj_model.start_background
    argparse.ArgumentParser
    lm_eval.evaluator.evaluate
    fastapi.FastAPI.post
    tensorflow.sparse.to_dense
    threading.Thread.start
    numpy.minimum
    numpy.sum.tolist
    jax.value_and_grad
    mesh_transformer.train_actor.NetworkRunner.options
    range
    new_vals.append
    reshard.all
    optax.apply_updates
    split
    self.network.generate
    tensorflow.sparse.reorder
    self.dense_proj
    optax.scale_by_schedule
    ReplicatedLayerNorm
    self.q
    jax.numpy.where
    all_temp.append
    google.cloud.storage.Client.list_blobs
    f_pmean.defvjp
    eval_step
    optax.GradientTransformation
    jax.lax.psum
    parallel_write
    init_fns
    blob.delete
    params.append
    eval_loss_fn
    tasks.EvalHarnessAdaptor
    projection_init_fn
    jax.tree_leaves
    reshard
    params.get
    res.ray.get.i.i.np.array.mean
    numpy.savez
    wandb.init
    grouper
    numpy.max
    compression.i.tf.data.TFRecordDataset.map
    optax.scale_by_adam
    haiku.PRNGSequence
    jax.numpy.zeros_like
    time.time
    logging.getLogger
    self.embed.hk.remat
    multiprocessing.Pool
    x.reshape.append
    jax.numpy.square
    last_loss.append
    numpy.pad
    mesh_transformer.build_model.build_model
    mesh_transformer.transformer_shard.CausalTransformer.train
    self.o
    mesh_transformer.checkpoint.load_ckpt_v2
    self.network_builder
    self.glu
    jax.experimental.maps.mesh
    glob.glob
    ray.init
    tensorflow.io.FixedLenFeature
    self.tokenizer.add_special_tokens
    self.norm
    sch
    self.pool.imap
    reshaped.reshape.reshape
    numpy.cos
    jax.tree_flatten
    tqdm.tqdm
    step.step.all
    self.sample_once
    self.map_fn
    batch.append
    noise_scale_stats.update
    ray_tpu.wait_til
    optimizer.update
    super.__init__
    read_sharded_v2
    val_set.sample_once
    max
    logging.getLogger.debug
    save
    queue.Queue.get
    jax.devices.np.array.reshape
    self.v
    haiku.remat
    self.qvk_proj
    fastapi.FastAPI.add_middleware
    self.network_builder.load_ckpt
    haiku.initializers.TruncatedNormal
    all_length.append
    haiku.experimental.optimize_rng_use
    json.load.pop
    temp.logits.key.jax.random.categorical.astype
    mesh_transformer.util.to_f32.to_bf16.config
    val_sets.values
    numpy.sin
    jax.numpy.zeros
    tasks.util.sample_batch
    requests.delete
    self.eval_xmap
    haiku.Linear
    self.network_builder.eval
    tfrecord_loader.TFRecordNewInputs.get_samples
    optax.additive_weight_decay
    task_res.items
    jax.numpy.zeros.block_until_ready
    haiku.LayerNorm
    self.used.append
    jax.numpy.sum
    fixed_pos_embedding
    self.input_q.put
    setuptools.setup
    jax.numpy.arange
    jax.numpy.asarray
    apply_rotary_pos_emb_v2
    warnings.filterwarnings
    functools.lru_cache
    softmax_sample
    format
    traceback.print_exc
    numpy.dtype
    transformer_init_fn
    n.get_params.remote
    v.attention_weights.jnp.einsum.reshape
    payloads.CompletionResponse
    itertools.chain
    jax.numpy.exp.sum
    mesh_transformer.build_model.build_model.save
    ray.is_initialized
    aux.get
    reshard.reshape
    jax.devices
    ray.remote
    self.prepare_item
    scheduler
    mesh_transformer.util.maybe_shard
    ray_tpu.get_connection
    jax.numpy.clip
    ray.get
    node.move_params.remote
    json.load.get
    divmod
    p.imap_unordered
    jax.numpy.einsum
    grad_norm_micro.np.array.mean
    CausalTransformerShard.loss
    gpt3_schedule
    index_weights
    parallel_read
    fix_dtype
    jax.lax.pmean.mean
    subprocess.check_output.decode
    optax.AdditiveWeightDecayState
    numpy.load
    jax.numpy.mean
    numpy.concatenate
    TFRecordWIT
    n.generate.remote
    jax.experimental.maps.xmap
    outputs.append
    index_fname.open.read
    float
    CausalTransformerShard.generate_once
    ray_tpu.create_tpu
    val_sets.items
    mesh_transformer.checkpoint.read_ckpt
    bcast_iota.Ellipsis.jnp.newaxis.rp_bucket.jnp.array.astype
    self.nodes.append
    mesh_transformer.util.g_psum
    self.dim_per_head.np.sqrt.astype
    transformers.GPT2TokenizerFast.from_pretrained
    samples.append
    dp.mp.key.next.jax.random.uniform.astype
    tree_leaves_with_names
    subprocess.check_output
    get_inital
    queue.Queue.qsize
    conn.sudo
    q.put
    transformer_apply_fn
    jax.numpy.cumsum
    l.get_init_decode_state
    mesh_transformer.util.gpt3_schedule
    self.tpu.eval
    optimizer.init
    haiku.next_rng_key
    jax.numpy.split
    parse_args
    all
    print
    ops.get_gptj_model.wait_for_queue
    val_loss.np.array.mean
    argparse.ArgumentParser.parse_args
    self.queue.get
    self.embed
    RMSNorm
    mesh_transformer.layers.RelativePositionEmbs
    numpy.stack
    self.head_split
    requests.get
    g_psum.defvjp
    process_init
    self.network_builder.move_xmap
    mesh_transformer.checkpoint.write_ckpt_v2
    logging.getLogger.info
    map
    filter
    jax.numpy.broadcast_to
    self.eval_pjit
    uvicorn.run
    einops.repeat
    mesh_transformer.train_actor.NetworkRunner.options.remote
    jax.numpy.reshape
    self.d_head.np.sqrt.astype
    len
    self.get_samples
    gc.collect
    mesh_transformer.util.to_f32
    self.norm.reshape
    TFRecordWIT.sample_once
    haiku.PRNGSequence.take
    einops.rearrange
    sampler
    float.update
    index_fname.open.read.splitlines
    numpy.einsum
    out.reshape.reshape
    haiku.without_apply_rng
    tensorflow.io.parse_single_example
    self.output_q.put
    all_tokenized.append
    shrink_seq
    self.transformer_layers.append
    network.state.count.item
    bool
    embed_init_fn
    self.tokenizer.decode
    RuntimeError
    all_q.append
    multiprocessing.set_start_method
    haiku.initializers.Constant
    tensorflow.data.experimental.dense_to_ragged_batch
    tensorflow.cast
    jax.nn.gelu
    mesh_transformer.build_model.build_model.load
    jax.device_put
    max_exact.num_buckets.max_exact.max_distance.np.log.np.float32.np.finfo.eps.max_exact.np.float32.n.astype.np.log.astype
    mesh_transformer.util.to_bf16
    convert_fn
    max_lengths.append
    tqdm.tqdm.set_postfix
    subprocess.check_output.decode.strip
    isinstance
    numpy.array_split
    numpy.logical_and
    jax.numpy.cos
    mesh_transformer.util.f_psum
    enumerate
    flask.Flask.route
    mesh_transformer.layers.EmbeddingShardV2
    self.init_xmap
    seq.np.zeros.astype
    os.path.expanduser
    jax.device_count
    contexts.append
    self.k
    numpy.maximum
    

    @developer Could you please help me check this issue? May I open a pull request to fix it? Thank you very much.

    opened by PyDeps 0
Owner
Ben Wang