A modular, research-friendly framework for high-performance training and inference of sequence models at many scales

T5X

T5X is a modular, composable, research-friendly framework for high-performance, configurable, self-service training, evaluation, and inference of sequence models (starting with language) at many scales.

It is essentially a new and improved implementation of the T5 codebase (based on Mesh TensorFlow) in JAX and Flax.

Installation

Note that all the commands in this document should be run on the command line of the TPU VM instance unless otherwise stated.

  1. Follow the instructions to set up a Google Cloud Platform (GCP) account and enable the Cloud TPU API.

    Note: While T5X works with GPUs as well, we haven't heavily tested GPU usage.

  2. Create a Cloud TPU VM instance following these instructions. We recommend that you develop your workflow on a single v3-8 TPU (i.e., --accelerator-type=v3-8) and scale up to pod slices once the pipeline is ready. In this README, we focus on using a single v3-8 TPU. See here to learn more about TPU architectures.

  3. With Cloud TPU VMs, you ssh directly into the host machine of the TPU VM. You can install packages, run your code, etc. on the host machine. Once the TPU instance is created, ssh into it with

    gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE}

    where TPU_NAME and ZONE are the name and the zone used in step 2.

  4. Install T5X and its dependencies. JAX and Gin-config need to be installed from source.

    git clone --branch=main https://github.com/google-research/t5x
    cd t5x
    
    python3 -m pip install -e . -f \
      https://storage.googleapis.com/jax-releases/libtpu_releases.html
    
  5. Create a Google Cloud Storage (GCS) bucket to store the dataset and model checkpoints. To create a GCS bucket, see these instructions.

Example: English to German translation

As a running example, we use the WMT14 En-De translation. The raw dataset is available in TensorFlow Datasets as "wmt_t2t_translate".

T5 casts a translation example such as the following

{'en': 'That is good.', 'de': 'Das ist gut.'}

to the form called "text-to-text":

{'inputs': 'translate English to German: That is good.', 'targets': 'Das ist gut.'}

This formulation allows many different classes of language tasks to be expressed in a uniform manner, so that a single encoder-decoder architecture can handle them without any task-specific parameters. For more detail, refer to the T5 paper (Raffel et al. 2019).
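
The cast itself is a simple mapping. Below is a plain-Python sketch of it; the real pipeline applies an equivalent preprocessor (t5.data.preprocessors.translate, shown later) over a tf.data.Dataset:

def to_text_to_text(example, source='en', target='de'):
    # Prefix the source sentence with a task description, as in the T5 paper.
    names = {'en': 'English', 'de': 'German'}
    return {
        'inputs': f"translate {names[source]} to {names[target]}: {example[source]}",
        'targets': example[target],
    }

print(to_text_to_text({'en': 'That is good.', 'de': 'Das ist gut.'}))
# {'inputs': 'translate English to German: That is good.', 'targets': 'Das ist gut.'}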

For a scalable data pipeline and an evaluation framework, we use SeqIO, which was factored out of the T5 library. A seqio.Task packages together the raw dataset, the vocabulary, preprocessing such as tokenization, and evaluation metrics such as BLEU, and provides a tf.data instance.

The T5 library provides a number of seqio.Tasks that were used in the T5 paper. In this example, we use wmt_t2t_ende_v003.
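
As a quick sanity check, a registered task can be loaded and inspected directly. A minimal sketch (assuming the t5 package is installed so that its tasks are registered, and that the TFDS data is available):

import seqio
import t5.data.tasks  # Importing this module registers the T5 tasks.

task = seqio.get_mixture_or_task('wmt_t2t_ende_v003')
ds = task.get_dataset(
    sequence_length={'inputs': 256, 'targets': 256}, split='validation')
for example in ds.take(1):
    print(example['inputs'])  # Token ids of 'translate English to German: ...'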

Training

To run a training job, we use the t5x/train.py script.

# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR="..."

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR="..."
T5X_DIR="..."  # directory where the T5X repo is cloned.

python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/t5_1_1_base_wmt_from_scratch.gin" \
  --gin.MODEL_DIR="'${MODEL_DIR}'" \
  --tfds_data_dir=${TFDS_DATA_DIR}

The configuration for this training run is defined in the Gin file t5_1_1_base_wmt_from_scratch.gin. Gin-config is a library to handle configurations based on dependency injection. Among many benefits, Gin allows users to pass custom components such as a custom model to the T5X library without having to modify the core library. The custom components section shows how this is done.
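
To give a flavor of how Gin's dependency injection works, here is a generic, minimal sketch (not T5X-specific): values are bound to function parameters from a config rather than hard-coded.

import gin

@gin.configurable
def build_optimizer(learning_rate=1e-3, weight_decay=0.0):
    return {'learning_rate': learning_rate, 'weight_decay': weight_decay}

# These bindings would normally live in a .gin file or be passed via --gin. flags.
gin.parse_config("""
build_optimizer.learning_rate = 5e-4
build_optimizer.weight_decay = 0.01
""")

print(build_optimizer())  # {'learning_rate': 0.0005, 'weight_decay': 0.01}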

While the core library is independent of Gin, it is central to the examples we provide, so we give a short introduction to Gin in the context of T5X. All the configurations used for a run are written to a file "config.gin" in MODEL_DIR, which makes debugging as well as reproducing the experiment much easier.

In addition to config.gin, a model-info.txt file summarizes the model parameters (shapes, names of the axes, partitioning info) as well as the optimizer states.

TensorBoard

To monitor the training in TensorBoard, it is much easier (due to authentication issues) to launch TensorBoard on your own machine rather than in the TPU VM. So, from a command line on your own machine (not the one you used to ssh into the TPU VM), launch TensorBoard with the logdir pointing to MODEL_DIR.

# NB: run this on your machine not TPU VM!
MODEL_DIR="..."  # Copy from the TPU VM.
tensorboard --logdir=${MODEL_DIR}

Or you can launch the TensorBoard inside a Colab. In a Colab cell, run

from google.colab import auth
auth.authenticate_user()

to authorize the Colab to access the GCS bucket and launch the TensorBoard.

%load_ext tensorboard
# Paste the MODEL_DIR value copied from the TPU VM directly into the flag:
%tensorboard --logdir=gs://...

TODO(hwchung): Add tfds preparation instruction

Fine-tuning

We can leverage the benefits of self-supervised pre-training by initializing from one of our pre-trained models. Here we use the T5.1.1 Base checkpoint.

# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR="..."

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR="..."
T5X_DIR="..."  # directory where the T5X repo is cloned.

python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/t5_1_1_base_wmt_finetune.gin" \
  --gin.MODEL_DIR="'${MODEL_DIR}'" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Note: when supplying a string, dict, list, or tuple value, or a bash variable via a flag, you must put it in quotes. In the case of strings, this means nesting quotes, typically double quotes around single quotes. For example: --gin.utils.DatasetConfig.split="'validation'" or --gin.MODEL_DIR="'${MODEL_DIR}'".

Gin makes it easy to change a number of configurations. For example, you can override partitioning.ModelBasedPjitPartitioner.num_partitions (overriding the value in t5_1_1_base_wmt_from_scratch.gin) to change the parallelism strategy by passing it as a command-line arg:

--gin.partitioning.ModelBasedPjitPartitioner.num_partitions=8

Evaluation

To run offline evaluation (i.e., without training), you can use the t5x/eval.py script.

EVAL_OUTPUT_DIR="..."  # directory to write eval output
T5X_DIR="..."  # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
TFDS_DATA_DIR="..."
CHECKPOINT_PATH="..."

python3 ${T5X_DIR}/t5x/eval.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/t5_1_1_base_wmt_eval.gin" \
  --gin.CHECKPOINT_PATH="'${CHECKPOINT_PATH}'" \
  --gin.EVAL_OUTPUT_DIR="'${EVAL_OUTPUT_DIR}'" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Inference

To run inference, you can use the t5x/infer.py script. Here we use the same seqio.Task, but for inference the 'targets' feature is not used except to be logged alongside the prediction in a JSON file.

INFER_OUTPUT_DIR="..."  # directory to write infer output
T5X_DIR="..."  # directory where the t5x is cloned, e.g., ${HOME}"/t5x".
TFDS_DATA_DIR="..."
CHECKPOINT_PATH="..."

python3 ${T5X_DIR}/t5x/infer.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/t5_1_1_base_wmt_infer.gin" \
  --gin.CHECKPOINT_PATH="'${CHECKPOINT_PATH}'" \
  --gin.INFER_OUTPUT_DIR="'${INFER_OUTPUT_DIR}'" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Custom components

The translation example uses the encoder-decoder model that T5X provides as well as a dataset from the T5 library. This section shows how you can use your own dataset and model and pass them to T5X via Gin.

Example: custom dataset in a user directory

For this example, we have the following directory structure with ${HOME}/dir1/user_dir representing a user directory with custom components.

${HOME}
└── dir1
    └── user_dir
        ├── t5_1_1_base_de_en.gin
        └── tasks.py

As an example, let's define a new dataset. Here we use the same translation dataset, but define the task in the opposite direction, i.e., German to English instead of English to German. We define this task in tasks.py:

# ${HOME}/dir1/user_dir/tasks.py

import functools
import seqio
import tensorflow_datasets as tfds
from t5.evaluation import metrics
from t5.data import preprocessors

vocabulary = seqio.SentencePieceVocabulary(
    'gs://t5-data/vocabs/cc_all.32000/sentencepiece.model', extra_ids=100)
output_features = {
    'inputs': seqio.Feature(vocabulary=vocabulary),
    'targets': seqio.Feature(vocabulary=vocabulary)
}

seqio.TaskRegistry.add(
    'wmt_t2t_de_en_v003',
    source=seqio.TfdsDataSource(tfds_name='wmt_t2t_translate/de-en:1.0.0'),
    preprocessors=[
        functools.partial(
            preprocessors.translate,
            source_language='de', target_language='en'),
        seqio.preprocessors.tokenize,
        seqio.CacheDatasetPlaceholder(),
        seqio.preprocessors.append_eos_after_trim,
    ],
    metric_fns=[metrics.bleu],
    output_features=output_features)

In the Gin file, most of the settings are equivalent to those used in the En->De example, so we include the Gin file from that example. To use the "wmt_t2t_de_en_v003" task we just defined, we need to import the task module "tasks.py". Note that we use a relative path defined with respect to the user directory; this will be specified as a flag.

# ${HOME}/dir1/user_dir/t5_1_1_base_de_en.gin
from __gin__ import dynamic_registration
import tasks  # This imports the task defined in dir1/user_dir/tasks.py.

include "t5x-tmp/t5x/examples/t5/t5_1_1/examples/t5_1_1_base_wmt_from_scratch.gin"
MIXTURE_OR_TASK_NAME = "wmt_t2t_de_en_v003"

Finally, we launch training, passing the user directory via the --gin_search_paths flag so that the Gin file and Python modules can be specified with relative paths.

PROJECT_DIR=${HOME}"/dir1/user_dir"
T5X_DIR="..."  # directory where the t5x is cloned.
TFDS_DATA_DIR="..."
MODEL_DIR="..."
export PYTHONPATH=${PROJECT_DIR}

python3 ${T5X_DIR}/t5x/train.py \
  --gin_search_paths=${PROJECT_DIR} \
  --gin_file="t5_1_1_base_de_en.gin" \
  --gin.MODEL_DIR="'${MODEL_DIR}'" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Released Checkpoints

We release checkpoints for the T5.1.1 models in the native T5X format. These are converted from the public Mesh TensorFlow checkpoints.

Compatibility with the Mesh TensorFlow checkpoints

The Mesh TensorFlow checkpoints trained using the T5 library can be directly loaded into T5X. For example, we can rerun the fine-tuning example initializing from the MTF checkpoint by changing the INIT_CHECKPOINT Gin macro.

# Model dir to save logs, ckpts, etc. in "gs://model_dir" format.
MODEL_DIR="..."

# Data dir to save the processed dataset in "gs://data_dir" format.
TFDS_DATA_DIR="..."
T5X_DIR="..."  # directory where the T5X repo is cloned.

python3 ${T5X_DIR}/t5x/train.py \
  --gin_file="t5x/examples/t5/t5_1_1/examples/wmt19_ende_from_scratch.gin" \
  --gin.MODEL_DIR="'${MODEL_DIR}'" \
  --gin.MIXTURE_OR_TASK_NAME="'wmt_t2t_ende_v003'" \
  --gin.INIT_CHECKPOINT="'gs://t5-data/pretrained_models/t5.1.1.base/model.ckpt-1000000'" \
  --tfds_data_dir=${TFDS_DATA_DIR}

Note that restoring directly from Mesh TensorFlow checkpoints can be inefficient if heavy model parallelism is used for large models, because each host first loads the entire copy of the model and then keeps only the relevant slices dictated by the model parallelism specification. If you have Mesh TensorFlow checkpoints that you use often, we recommend converting them to the native T5X format using Checkpointer.convert_from_tf_checkpoint.

TODO(hwchung): Add a conversion script.

Note

This is not an officially supported Google product.

Comments
  • Seg Fault after saving checkpoints

    Hi,

    I am sometimes getting a seg fault after the model has saved a checkpoint. It does not happen at every checkpoint, and which checkpoints it crashes after seems random. I am not sure if it is related to issue #340.

    For example, I am running prompt_tuning/scripts/sst2-demo-xxl.sh, and the output is below.

    I0317 18:14:56.525280 140415323761728 utils.py:138] Saved Numpy Arrays for step 1104000 to gs://nicl/checkpoint_models/sst/full_dataset/prompt-tuning/t5-11b/numpy_checkpoints/checkpoint_1104000
    I0317 18:14:56.604028 140415323761728 checkpoints.py:600] Saving checkpoint for step 1104000 to gs://nicl/checkpoint_models/sst/full_dataset/prompt-tuning/t5-11b/checkpoint_1104000.tmp-1647540896
    I0317 18:14:56.614308 140622481194048 checkpoints.py:600] Saving checkpoint for step 1104000 to gs://nicl/checkpoint_models/sst/full_dataset/prompt-tuning/t5-11b/checkpoint_1104000.tmp-1647540896
    I0317 18:14:56.624289 140590966570048 checkpoints.py:600] Saving checkpoint for step 1104000 to gs://nicl/checkpoint_models/sst/full_dataset/prompt-tuning/t5-11b/checkpoint_1104000.tmp-1647540896
    I0317 18:14:56.653718 140272509271104 checkpoints.py:600] Saving checkpoint for step 1104000 to gs://nicl/checkpoint_models/sst/full_dataset/prompt-tuning/t5-11b/checkpoint_1104000.tmp-1647540896
    Fatal Python error: Segmentation fault
    
    
    Thread 0x00007fdb1dc01700 (most recent call first):
      File "/home/dptam/.local/lib/python3.8/site-packages/jax/interpreters/pxla.py", line 664 in _sda_value
      File "/home/dptam/.local/lib/python3.8/site-packages/jax/_src/device_array.py", line 266 in __array__
      File "/home/dptam/.local/lib/python3.8/site-packages/t5x/checkpoints.py", line 447 in <lambda>
      File "/home/dptam/.local/lib/python3.8/site-packages/t5x/checkpoint_importer.py", line 84 in get
      File "/usr/lib/python3.8/concurrent/futures/thread.py", line 57 in run
      File "/usr/lib/python3.8/concurrent/futures/thread.py", line 80 in _worker
      File "/usr/lib/python3.8/threading.py", line 870 in run
      File "/usr/lib/python3.8/threading.py", line 932 in _bootstrap_inner
      File "/usr/lib/python3.8/threading.py", line 890 in _bootstrap
    
    Thread 0x00007f56809df700 (most recent call first):
      File "/usr/lib/python3.8/concurrent/futures/thread.py", line 78 in _worker
      File "/usr/lib/python3.8/threading.py", line 870 in run
      File "/usr/lib/python3.8/threading.py", line 932 in _bootstrap_inner
      File "/usr/lib/python3.8/threading.py", line 890 in _bootstrap
    
    Thread 0x00007f56c7aad700 (most recent call first):
      File "/usr/lib/python3.8/concurrent/futures/thread.py", line 78 in _worker
      File "/usr/lib/python3.8/threading.py", line 870 in run
      File "/usr/lib/python3.8/threading.py", line 932 in _bootstrap_inner
      File "/usr/lib/python3.8/threading.py", line 890 in _bootstrap
    Thread 0x00007fdde29efc40 (most recent call first):
      File "/home/dptam/.local/lib/python3.8/site-packages/t5x/checkpoints.py", line 693 in _write_array
    https://symbolize.stripped_domain/r/?trace=7fdde2e4203b,7fdde2e420bf,e,5ef27540f,e,26f7c5aff,f,b15f59df&map= 
    E0317 18:14:57.770066  341059 process_state.cc:1062] RAW: Signal 11 raised at PC: 0x7fdde2e4203b while already in FailureSignalHandler!
    E0317 18:14:57.770096  341059 process_state.cc:1065] RAW: tid: 341059 raised new signal
        @                0xf       1440  (unknown)
        @        0x25ed159b0  (unknown)  (unknown)
        @               0x10   76231216  (unknown)
        @        0x261cdc840  (unknown)  (unknown)
        @        0x2dfdd4780  (unknown)  (unknown)
        @        0x5f1f8a120  (unknown)  (unknown)
    https://symbolize.stripped_domain/r/?trace=7fdde301ffd3,7fddd98d57f9,7fdde2e420bf,7,e,25ed159af,f,261cdc83f,2dfdd477f,5f1f8a11f&map=7a511a57244151c993b16b37978e7ed7:7fddcaefd000-7fddd9c3fd50 
    E0317 18:14:57.818885  341068 coredump_hook.cc:365] RAW: Remote crash data gathering hook invoked.
    E0317 18:14:57.818900  341068 coredump_hook.cc:411] RAW: Skipping coredump since rlimit was 0 at process start.
    E0317 18:14:57.818919  341068 client.cc:221] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
    E0317 18:14:57.818922  341068 coredump_hook.cc:473] RAW: Sending fingerprint to remote end.
    E0317 18:14:57.818928  341068 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
    E0317 18:14:57.818933  341068 coredump_hook.cc:477] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
    E0317 18:14:57.818938  341068 coredump_hook.cc:550] RAW: Discarding core.
    prompt_tuning/scripts/sst2-demo-xxl.sh: line 37: 337643 Segmentation fault      (core dumped) python3 -m t5x.train --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" --gin_file="prompt_tuning/configs/prompts/from_class_labels.gin" --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" --gin.CLASS_LABELS="['positive', 'negative']" --gin.MODEL_DIR="'${MODEL_DIR}'" --gin.MIXTURE_OR_TASK_NAME="'taskless_glue_sst2_v200_examples'" --gin.MIXTURE_OR_TASK_MODULE="'prompt_tuning.data.glue'" --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" --gin.TRAIN_STEPS="1_212_000" --gin.USE_CACHED_TASKS="False" --gin.BATCH_SIZE="16" --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" --tfds_data_dir=${TFDS_DATA_DIR}
    ##### Command execution on worker 3 failed with return code 139. Continuing.
    prompt_tuning/scripts/sst2-demo-xxl.sh: line 37: 334750 Aborted                 (core dumped) python3 -m t5x.train --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" --gin_file="prompt_tuning/configs/prompts/from_class_labels.gin" --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" --gin.CLASS_LABELS="['positive', 'negative']" --gin.MODEL_DIR="'${MODEL_DIR}'" --gin.MIXTURE_OR_TASK_NAME="'taskless_glue_sst2_v200_examples'" --gin.MIXTURE_OR_TASK_MODULE="'prompt_tuning.data.glue'" --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" --gin.TRAIN_STEPS="1_212_000" --gin.USE_CACHED_TASKS="False" --gin.BATCH_SIZE="16" --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" --tfds_data_dir=${TFDS_DATA_DIR}
    ##### Command execution on worker 1 failed with return code 134. Continuing.
    prompt_tuning/scripts/sst2-demo-xxl.sh: line 37: 335504 Aborted                 (core dumped) python3 -m t5x.train --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" --gin_file="prompt_tuning/configs/prompts/from_class_labels.gin" --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" --gin.CLASS_LABELS="['positive', 'negative']" --gin.MODEL_DIR="'${MODEL_DIR}'" --gin.MIXTURE_OR_TASK_NAME="'taskless_glue_sst2_v200_examples'" --gin.MIXTURE_OR_TASK_MODULE="'prompt_tuning.data.glue'" --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" --gin.TRAIN_STEPS="1_212_000" --gin.USE_CACHED_TASKS="False" --gin.BATCH_SIZE="16" --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" --tfds_data_dir=${TFDS_DATA_DIR}
    ##### Command execution on worker 0 failed with return code 134. Continuing.
    

    Thanks

    opened by dptam 13
  • Fatal Python error: Segmentation fault, when training t5x-XXL on a TPU Pod v3-32

    Hi,

    I was able to train and infer prompt tuning with t5x-XXL on a TPU Pod v3-32 for my custom task defined from a TSV file, but I am seeing now an error and can't understand it.

    I follow the instructions from prompt tuning to train and infer a prompt on a Pod slice, except that the latest libtpu release gives an error (TPUEmbeddingEngineState_Create not available in this library), so I installed the release from February 15, 2022.

    I run the following script

    
    MODEL_DIR=${1:-${MODEL_DIR}}
    TFDS_DATA_DIR=${2:-${TFDS_DATA_DIR}}
    
    if [ -z ${MODEL_DIR} ] || [ -z ${TFDS_DATA_DIR} ]; then
      echo "usage: ./rec_sys.sh gs://your-bucket/path/to/model_dir gs://your-bucket/path/to/tfds/cache"
      exit 1
    fi
    
    T5X_DIR="`python3 -m prompt_tuning.scripts.find_module t5x`/.."
    FLAXFORMER_DIR="`python3 -m prompt_tuning.scripts.find_module flaxformer`/.."
    PROMPT_DIR="`python3 -m prompt_tuning.scripts.find_module prompt_tuning`/.."
    echo "Searching for gin configs in:"
    echo "- ${T5X_DIR}"
    echo "- ${FLAXFORMER_DIR}"
    echo "- ${PROMPT_DIR}"
    echo "============================="
    PRETRAINED_MODEL="gs://t5-data/pretrained_models/t5x/t5_1_1_lm100k_xxl/checkpoint_1100000"
    
    python3 -m t5x.train \
      --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" \
      --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" \
      --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" \
      --gin.MODEL_DIR="'${MODEL_DIR}'" \
      --gin.BATCH_SIZE="16" \
      --gin.MIXTURE_OR_TASK_NAME="'yelp'" \
      --gin.MIXTURE_OR_TASK_MODULE="'task_dir.mytasks'" \
      --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" \
      --gin.USE_CACHED_TASKS="False" \
      --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" \
      --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" \
      --gin.TRAIN_STEPS="1_100_010"
    

    and get the following errors. First:

    
    tensorstore/internal/oauth2/google_auth_provider.cc:163: Credentials file not found. NOT_FOUND: $GOOGLE_APPLICATION_CREDENTIALS is not set or corrupt.
    tensorstore/internal/oauth2/google_auth_provider.cc:168: Credentials file not found. NOT_FOUND: Could not find the credentials file in the standard gcloud location [/home/leojlaugier/.config/gcloud/application_default_credentials.json]
    tensorstore/internal/oauth2/google_auth_provider.cc:203: Running on GCE, using GCE Auth Provider
    
    Fatal Python error: Segmentation fault
    
    Thread 0x00007f3304539c40 (most recent call first):
      File "/usr/lib/python3.8/selectors.py", line 468 in select
      File "/usr/lib/python3.8/asyncio/base_events.py", line 1823 in _run_once
      File "/usr/lib/python3.8/asyncio/base_events.py", line 570 in run_forever
      File "/usr/lib/python3.8/asyncio/base_events.py", line 603 in run_until_complete
      File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/checkpoints.py", line 160 in _run_future_tree
      File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/checkpoints.py", line 913 in _read_state_from_tensorstore
      File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/checkpoints.py", line 860 in restore
      File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/utils.py", line 455 in _restore_path
      File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/utils.py", line 466 in from_checkpoints
      File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/utils.py", line 507 in from_checkpoint
      File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/utils.py", line 522 in from_checkpoint_or_scratch
      File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/train.py", line 320 in train
      File "/home/leojlaugier/.local/lib/python3.8/site-packages/gin/config.py", line 1582 in gin_wrapper
      File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/train.py", line 623 in _main
      File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/train.py", line 605 in main
      File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 251 in _run_main
      File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 303 in run
      File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/gin_utils.py", line 105 in run
      File "/home/leojlaugier/.local/lib/python3.8/site-packages/t5x/train.py", line 625 in <module>
      File "/usr/lib/python3.8/runpy.py", line 87 in _run_code
      File "/usr/lib/python3.8/runpy.py", line 194 in _run_module_as_main
    https://symbolize.stripped_domain/r/?trace=7f330498e18b,7f330498e20f,6&map=
    *** SIGSEGV (@0x7d100002a60), see gl__________41#s15 received by PID 10848 (TID 12157) on cpu 19; stack trace: ***
    PC: @     0x7f330498e18b  (unknown)  raise
        @     0x7f32fb6ea1fa        992  (unknown)
        @     0x7f330498e210  (unknown)  (unknown)
        @                0x7  (unknown)  (unknown)
    https://symbolize.stripped_domain/r/?trace=7f330498e18b,7f32fb6ea1f9,7f330498e20f,6&map=55976a7e1de583f3a9544af1c86ac940:7f32ed01c000-7f32fba50d80
    E0310 16:51:25.580514   12157 coredump_hook.cc:365] RAW: Remote crash data gathering hook invoked.
    E0310 16:51:25.580525   12157 coredump_hook.cc:411] RAW: Skipping coredump since rlimit was 0 at process start.
    E0310 16:51:25.580535   12157 client.cc:221] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
    E0310 16:51:25.580557   12157 coredump_hook.cc:473] RAW: Sending fingerprint to remote end.
    E0310 16:51:25.580562   12157 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
    E0310 16:51:25.580565   12157 coredump_hook.cc:477] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
    E0310 16:51:25.580569   12157 coredump_hook.cc:550] RAW: Discarding core.

    And later

    I0310 16:54:01.224571 140510105828416 train.py:456] Epoch 1100 of 1101
    I0310 16:54:01.224764 140510105828416 train.py:462] BEGIN Train loop.
    I0310 16:54:01.224818 140510105828416 train.py:467] Training for 10 steps.
    I0310 16:54:01.226046 140497868457728 logging_writer.py:48] [1100000] collection=train timing/compilation_seconds=87.301567
    I0310 16:54:01.230673 140510105828416 trainer.py:491] Training: step 1100000
    I0310 16:54:01.635557 140510105828416 train.py:490] END Train loop.
    ./train_yelp_xxl.sh: line 34: 12024 Aborted                 (core dumped) python3 -m t5x.train --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" --gin.MODEL_DIR="'${MODEL_DIR}'" --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" --gin.BATCH_SIZE="16" --gin.MIXTURE_OR_TASK_NAME="'yelp'" --gin.MIXTURE_OR_TASK_MODULE="'task_dir.mytasks'" --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" --gin.USE_CACHED_TASKS="False" --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" --gin.TRAIN_STEPS="1_100_010"
    ##### Command execution on worker 1 failed with return code 134. Continuing.
    ./train_yelp_xxl.sh: line 34: 10848 Aborted                 (core dumped) python3 -m t5x.train --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" --gin.MODEL_DIR="'${MODEL_DIR}'" --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" --gin.BATCH_SIZE="16" --gin.MIXTURE_OR_TASK_NAME="'yelp'" --gin.MIXTURE_OR_TASK_MODULE="'task_dir.mytasks'" --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" --gin.USE_CACHED_TASKS="False" --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" --gin.TRAIN_STEPS="1_100_010"
    ##### Command execution on worker 0 failed with return code 134. Continuing.
    ./train_yelp_xxl.sh: line 34: 11607 Aborted                 (core dumped) python3 -m t5x.train --gin_search_paths="${T5X_DIR},${FLAXFORMER_DIR},${PROMPT_DIR}" --gin_file="prompt_tuning/configs/models/t5_1_1_xxl_prompt.gin" --gin_file="prompt_tuning/configs/runs/prompt_finetune.gin" --gin.MODEL_DIR="'${MODEL_DIR}'" --gin.partitioning.PjitPartitioner.model_parallel_submesh="(4, 4, 1, 2)" --gin.BATCH_SIZE="16" --gin.MIXTURE_OR_TASK_NAME="'yelp'" --gin.MIXTURE_OR_TASK_MODULE="'task_dir.mytasks'" --gin.TASK_FEATURE_LENGTHS="{'inputs': 512, 'targets': 8}" --gin.USE_CACHED_TASKS="False" --gin.INITIAL_CHECKPOINT_PATH="'${PRETRAINED_MODEL}'" --gin.TRAIN_STEPS="1_100_010"
    ##### Command execution on worker 3 failed with return code 134. Continuing.
    

    Then the run freezes. I might be missing something obvious, but I don't think I have changed anything except the data since the last time I was able to train and infer with prompt tuning. Moreover, I was able to train on the same training data, but problems arose when I tried to infer. Could you help me understand the issue?

    Thanks in advance for your time.

    opened by LeoLaugier 11
  • Error loading model from checkpoint on Apple M1

    I am trying to load a LongT5 model from a checkpoint and am getting the following error. Any help is much appreciated.

    RuntimeError                              Traceback (most recent call last)
    Input In [9], in <cell line: 1>()
    ----> 1 t5x_checkpoint = t5x.checkpoints.load_t5x_checkpoint(checkpoint_dir)
    
    File ~/t5x/t5x/checkpoints.py:1594, in load_t5x_checkpoint(path, step, state_transformation_fns, remap, restore_dtype, lazy_parameters)
       1592 if not lazy_parameters:
       1593     future_state_dict = jax.tree_map(lambda x: x.get_async(), state_dict)
    -> 1594 state_dict = _run_future_tree(future_state_dict)
       1596 if restore_dtype is not None:
       1597     state_dict['target'] = _cast(state_dict['target'], restore_dtype)
    
    File ~/t5x/t5x/checkpoints.py:167, in _run_future_tree(future_tree)
        165 # TODO(adarob): Use asyncio.run in py3.7+.
        166 loop = asyncio.get_event_loop()
    --> 167 leaves = loop.run_until_complete(asyncio.gather(*future_leaves))
        168 return jax.tree_unflatten(treedef, leaves)
    
    File ~/opt/miniconda3/lib/python3.9/asyncio/base_events.py:623, in BaseEventLoop.run_until_complete(self, future)
        612 """Run until the Future is done.
        613 
        614 If the argument is a coroutine, it is wrapped in a Task.
        (...)
        620 Return the Future's result, or raise its exception.
        621 """
        622 self._check_closed()
    --> 623 self._check_running()
        625 new_task = not futures.isfuture(future)
        626 future = tasks.ensure_future(future, loop=self)
    
    File ~/opt/miniconda3/lib/python3.9/asyncio/base_events.py:583, in BaseEventLoop._check_running(self)
        581 def _check_running(self):
        582     if self.is_running():
    --> 583         raise RuntimeError('This event loop is already running')
        584     if events._get_running_loop() is not None:
        585         raise RuntimeError(
        586             'Cannot run the event loop while another loop is running')
    
    RuntimeError: This event loop is already running

    opened by ibulu 9
  • `num_partitions` does not work for GPU

    Hi, thanks for the great work. I have already carefully read the docs on partitioning, but I am still confused about how it works and what the partitioning rules mean. I tried to run the pretraining code on a single node with 8 A100 GPUs. When I pretrain T5 with the HuggingFace trainer and DeepSpeed ZeRO-2, it works well. However, when I run pretraining with the scripts provided in the examples with

    partitioning.PjitPartitioner:
      num_partitions = 1
      logical_axis_rules= @partitioning.standard_logical_axis_rules()
    
    partitioning.standard_logical_axis_rules:
      activation_partitioning_dims = 2
      parameter_partitioning_dims = 2
    

    I get the following errors:

    Traceback (most recent call last):
      File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/runpy.py", line 196, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/runpy.py", line 86, in _run_code
        exec(code, run_globals)
      File "/mnt/cache/namco/t5x/t5x/train.py", line 659, in <module>
        gin_utils.run(main)
      File "/mnt/cache/namco/t5x/t5x/gin_utils.py", line 105, in run
        app.run(
      File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/absl/app.py", line 312, in run
        _run_main(main, args)
      File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/absl/app.py", line 258, in _run_main
        sys.exit(main(argv))
      File "/mnt/cache/namco/t5x/t5x/train.py", line 637, in main
        _main(argv)
      File "/mnt/cache/namco/t5x/t5x/train.py", line 657, in _main
        train_using_gin()
      File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/config.py", line 1605, in gin_wrapper
        utils.augment_exception_message_and_reraise(e, err_str)
      File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
        raise proxy.with_traceback(exception.__traceback__) from None
      File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/gin/config.py", line 1582, in gin_wrapper
        return fn(*new_args, **new_kwargs)
      File "/mnt/cache/namco/t5x/t5x/train.py", line 321, in train
        train_state = train_state_initializer.from_checkpoint_or_scratch(
      File "/mnt/cache/namco/t5x/t5x/utils.py", line 523, in from_checkpoint_or_scratch
        or self.from_scratch(init_rng))
      File "/mnt/cache/namco/t5x/t5x/utils.py", line 395, in from_scratch
        return p_initialize_train_state_fn(init_rng)
      File "/mnt/cache/namco/t5x/t5x/partitioning.py", line 729, in __call__
        return self._pjitted_fn(*args)
      File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/experimental/pjit.py", line 267, in wrapped
        args_flat, _, params, _, out_tree, _ = infer_params(*args, **kwargs)
      File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/experimental/pjit.py", line 246, in infer_params
        jaxpr, canonicalized_out_axis_resources_flat = _pjit_jaxpr(
      File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/linear_util.py", line 272, in memoized_fun
        ans = call(fun, *args)
      File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/experimental/pjit.py", line 411, in _pjit_jaxpr
        _check_shapes_against_resources("pjit outputs", mesh.is_multi_process, mesh.shape,
      File "/mnt/cache/namco/anaconda3/envs/jax/lib/python3.10/site-packages/jax/experimental/pjit.py", line 588, in _check_shapes_against_resources
        raise ValueError(f"One of {what} was given the resource assignment "
    ValueError: One of pjit outputs was given the resource assignment of PartitionSpec('model', None), which implies that the size of its dimension 0 should be divisible by 8, but it is equal to 12
      In call to configurable 'train' (<function train at 0x7f598523c790>)
    

    Could you please help me fix this error?

    opened by Namco0816 8
  • Out of RAM when training T5x on WMT14 En-De with Colab's TPU

    Hi,

    I wanted to check if I can train and evaluate T5x Base on Colab's TPU.

    When I try it in this Colab notebook, the session crashes after requesting more RAM than available (12 GB).

    The last information log I receive is

    trainer.py:472] Training: step 0

    And the last warning I get is:

    /usr/local/lib/python3.7/dist-packages/jax/interpreters/xla.py:800: UserWarning: Some donated buffers were not usable: s32[], f32[768]{0}, f32[768]{0}, f32[768]{0}, [...]

    Do you know if this is some Jax / TPU related issue, and whether one can use Colab's TPU / GPU to train and eval T5x?

    Thank you for your time,

    Leo

    opened by LeoLaugier 8
  • partitioning issues during inference on v3-32

    Hi,

    I was running inference with prompt-tuning, which I think calls this codebase, and I ran into a partitioning issue when doing inference on a v3-32: TypeError: 'ShapeDtypeStruct' object is not iterable. Training works fine on a v3-32, and training and inference work fine on a v3-8.

    Here is the traceback.

    Traceback (most recent call last):
      File "/usr/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/usr/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/dptam/.local/lib/python3.8/site-packages/t5x/eval.py", line 234, in <module>
        gin_utils.run(main)
      File "/home/dptam/.local/lib/python3.8/site-packages/t5x/gin_utils.py", line 103, in run
        app.run(
      File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 303, in run
        _run_main(main, args)
      File "/usr/local/lib/python3.8/dist-packages/absl/app.py", line 251, in _run_main
        sys.exit(main(argv))
      File "/home/dptam/.local/lib/python3.8/site-packages/t5x/eval.py", line 213, in main
        _main(argv)
      File "/home/dptam/.local/lib/python3.8/site-packages/t5x/eval.py", line 231, in _main
        evaluate_using_gin()
      File "/home/dptam/.local/lib/python3.8/site-packages/gin/config.py", line 1605, in gin_wrapper
        utils.augment_exception_message_and_reraise(e, err_str)
      File "/home/dptam/.local/lib/python3.8/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
        raise proxy.with_traceback(exception.__traceback__) from None
      File "/home/dptam/.local/lib/python3.8/site-packages/gin/config.py", line 1582, in gin_wrapper
        return fn(*new_args, **new_kwargs)
      File "/home/dptam/.local/lib/python3.8/site-packages/t5x/eval.py", line 127, in evaluate
        train_state_initializer = utils.TrainStateInitializer(
      File "/home/dptam/.local/lib/python3.8/site-packages/t5x/utils.py", line 365, in __init__
        self.train_state_axes = partitioner.get_mesh_axes(
      File "/home/dptam/.local/lib/python3.8/site-packages/t5x/partitioning.py", line 826, in get_mesh_axes
        mesh_axes_dict = jax.tree_map(flax_partitioning.logical_to_mesh_axes,
      File "/home/dptam/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 178, in tree_map
        return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
      File "/home/dptam/.local/lib/python3.8/site-packages/jax/_src/tree_util.py", line 178, in <genexpr>
        return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
      File "/home/dptam/.local/lib/python3.8/site-packages/flax/linen/partitioning.py", line 154, in logical_to_mesh_axes
        axis_name_counts = collections.Counter(array_dim_names)
      File "/usr/lib/python3.8/collections/__init__.py", line 552, in __init__
        self.update(iterable, **kwds)
      File "/usr/lib/python3.8/collections/__init__.py", line 637, in update
        _count_elements(self, iterable)
    TypeError: 'ShapeDtypeStruct' object is not iterable
      In call to configurable 'evaluate' (<function evaluate at 0x7f784d161700>)
    Rewritten gin arg: --gin_bindings=MIXTURE_OR_TASK_NAME = 'glue_rte_32_shot_32_seed'
    Rewritten gin arg: --gin_bindings=MIXTURE_OR_TASK_MODULE = 'prompt_tuning.data.few_glue'
    Rewritten gin arg: --gin_bindings=TASK_FEATURE_LENGTHS = {'inputs': 512, 'targets': 8}
    Rewritten gin arg: --gin_bindings=CHECKPOINT_PATH = 'gs://nicl/pretrained_models/t5x_checkpoints/t0_3b/checkpoint_1112000'
    Rewritten gin arg: --gin_bindings=EVAL_OUTPUT_DIR = 'gs://nicl/checkpoint_models/rte/32_shot/32_seed/prompt-tuning/t0-3b/eval'
    Rewritten gin arg: --gin_bindings=utils.DatasetConfig.split = 'validation'
    Rewritten gin arg: --gin_bindings=utils.DatasetConfig.batch_size = 128
    Rewritten gin arg: --gin_bindings=USE_CACHED_TASKS = False
    Rewritten gin arg: --gin_bindings=partitioning.ModelBasedPjitPartitioner.model_parallel_submesh = (4, 4, 1, 2)
    Rewritten gin arg: --gin_bindings=PROMPT_FILE = 'gs://nicl/checkpoint_models/rte/32_shot/32_seed/prompt-tuning/t0-3b/numpy_checkpoints/checkpoint_1112300/encoder.prompt.prompt.prompt'
    ##### Command execution on worker 0 failed with return code 1. Continuing.
    ##### Command execution on worker 3 failed with return code 1. Continuing.
    ##### Command execution on worker 1 failed with return code 1. Continuing.
    ##### Command execution on worker 2 failed with return code 1. Continuing.
    
    opened by dptam 7
  • T5X: Introduce eval_fn in the abstract BaseModel class, to be able to customize execution when training VS evaluating a model. By default, eval_fn behaves like loss_fn so this change is backward compatible.

    T5X: Introduce eval_fn in the abstract BaseModel class, to be able to customize execution when training VS evaluating a model. By default, eval_fn behaves like loss_fn so this change is backward compatible.
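
    A hedged sketch of the described pattern (the class and method signatures below are illustrative, not t5x's actual code):

    import abc

    class BaseModel(abc.ABC):

        @abc.abstractmethod
        def loss_fn(self, params, batch, dropout_rng):
            """Computes the training loss for a batch."""

        def eval_fn(self, params, batch):
            # Backward-compatible default: evaluation computes the same loss
            # as training, with dropout disabled.
            return self.loss_fn(params, batch, dropout_rng=None)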

    opened by copybara-service[bot] 7
  • ValueError: None values not supported

    Upon running a seqio mixture on mT5 and ByT5, I get an error stating: ValueError: None values not supported

    I am currently using a seqio mixture that I define in my task.py file, with the default mT5 tokenizer gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model and extra_ids=0.

    This is how my task.py file looks:

    import functools
    import seqio
    import tensorflow as tf
    import t5.data
    from datasets import load_from_disk, load_dataset
    from t5.data import postprocessors
    from t5.data import preprocessors
    from t5.evaluation import metrics
    from seqio import FunctionDataSource, utils
    
    TaskRegistry = seqio.TaskRegistry
    vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)
    
    
    DEFAULT_OUTPUT_FEATURES = {
        "inputs": seqio.Feature(
            vocabulary=vocabulary, add_eos=True,
            required=False),
        "targets": seqio.Feature(
            vocabulary=vocabulary, add_eos=True)
    }
    
    
    
    def gen_dataset(split, shuffle=False, seed=None, column="text", dataset_path=None):
        dataset = load_dataset(dataset_path, streaming=True, use_auth_token=True)
        if shuffle:
            if seed:
                dataset = dataset.shuffle(seed=seed)
            else:
                dataset = dataset.shuffle()
        while True:
            for item in dataset[str(split)]:
                yield item[column]
    
    
    def dataset_fn(split, shuffle_files, seed=None, dataset_path=None):
        return tf.data.Dataset.from_generator(
            functools.partial(gen_dataset, split, shuffle_files, seed, dataset_path=dataset_path),
            output_signature=tf.TensorSpec(shape=(), dtype=tf.string, name=dataset_path)
        )
    
    
    @utils.map_over_dataset
    def target_to_key(x, key_map, target_key):
        """Assign the value from the dataset to target_key in key_map"""
        return {**key_map, target_key: x}
    
    # link to the mt5 sentencepiece tokenizer vocabulary
    vocabulary = seqio.SentencePieceVocabulary('gs://t5-data/vocabs/mc4.250000.100extra/sentencepiece.model', extra_ids=0)
    
    TaskRegistry.add(
        "hindi_span_curruption",
        source=seqio.FunctionDataSource(
            dataset_fn=functools.partial(dataset_fn, dataset_path='StephennFernandes/ciil_mega_corpus_hindi'),
            splits=("train", "validation"),
            caching_permitted=False),
        preprocessors=[
            functools.partial(
                target_to_key, key_map={
                    "inputs": None,
                    "targets": None,
                }, target_key="targets"),
            seqio.preprocessors.tokenize,
            # seqio.CacheDatasetPlaceholder(),
            preprocessors.span_corruption, 
            seqio.preprocessors.append_eos_after_trim,
        ],
        output_features={"targets": DEFAULT_OUTPUT_FEATURES["targets"],"inputs": seqio.Feature(vocabulary=vocabulary,add_eos=True)},
        metric_fns=[]
    )
    ### similar multiple tasks exist for multiple languages. ### 
    
    seqio.MixtureRegistry.add(
      "ciil_mix_3",
      ["assamese_span_curruption", "bengali_span_curruption", 
      "bhisnupuriya_span_curruption", "bodo_span_curruption", 
      "divehi_span_curruption", "dogri_span_curruption", 
      "english_span_curruption", "gujarati_span_curruption",
      "hindi_span_curruption", "kannada_span_curruption", 
      "kashmiri_span_curruption", "konkani_span_curruption", 
      "maithili_span_curruption", "malayalam_span_curruption",
      "manipuri_span_curruption", "marathi_span_curruption",
      "nepali_span_curruption", "odia_span_curruption",
      "panjabi_span_curruption", "sanskrit_span_curruption",
      "tamil_span_curruption", "telugu_span_curruption",
       "urdu_span_curruption" ],
      default_rate=3
    )
    

    I use the ciil_mix_3 mixture in my .gin file. This is how my .gin file looks:

    from __gin__ import dynamic_registration
    import t5.data.mixtures
    import __main__ as train_script
    
    
    include 't5x/examples/t5/mt5/base.gin'
    include 't5x/configs/runs/pretrain.gin'
    
    import task 
    
    MIXTURE_OR_TASK_NAME = "ciil_mix_3"
    TASK_FEATURE_LENGTHS = {"inputs": 512, "targets": 114}
    TRAIN_STEPS = 100000
    DROPOUT_RATE = 0.0
    BATCH_SIZE = 32
    
    
    train_script.train:
      eval_period = 2000
    

    The following is the entire stack trace:

    Traceback (most recent call last):
      File "/home/stephen/anaconda3/lib/python3.9/runpy.py", line 197, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/stephen/anaconda3/lib/python3.9/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/train.py", line 748, in <module>
        gin_utils.run(main)
      File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/gin_utils.py", line 107, in run
        app.run(
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/absl/app.py", line 308, in run
        _run_main(main, args)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/absl/app.py", line 254, in _run_main
        sys.exit(main(argv))
      File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/train.py", line 708, in main
        _main(argv)
      File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/train.py", line 744, in _main
        train_using_gin()
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/config.py", line 1605, in gin_wrapper
        utils.augment_exception_message_and_reraise(e, err_str)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/utils.py", line 41, in augment_exception_message_and_reraise
        raise proxy.with_traceback(exception.__traceback__) from None
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/gin/config.py", line 1582, in gin_wrapper
        return fn(*new_args, **new_kwargs)
      File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/train.py", line 249, in train
        train_ds = get_dataset_fn(train_dataset_cfg, ds_shard_id, num_ds_shards,
      File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/utils.py", line 1366, in get_dataset
        return get_dataset_inner(cfg, shard_info, feature_converter_cls, seed,
      File "/home/stephen/Desktop/t5x_final_test/t5x/t5x/utils.py", line 1387, in get_dataset_inner
        ds = seqio.get_dataset(
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 1671, in get_dataset
        ds = mixture_or_task.get_dataset(
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 1457, in get_dataset
        datasets = [
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 1458, in <listcomp>
        task.get_dataset(  # pylint:disable=g-complex-comprehension
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 1209, in get_dataset
        ds = self.preprocess_postcache(
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 1044, in preprocess_postcache
        dataset = self._preprocess_dataset(
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/dataset_providers.py", line 965, in _preprocess_dataset
        dataset = prep_fn(dataset, **kwargs)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/preprocessors.py", line 83, in tokenize
        return utils.map_over_dataset(fn=tokenize_fn)(dataset)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/utils.py", line 778, in wrapped_fn
        return ds.map(
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 2050, in map
        return ParallelMapDataset(
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/ops/dataset_ops.py", line 5284, in __init__
        self._map_func = structured_function.StructuredFunctionWrapper(
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/ops/structured_function.py", line 271, in __init__
        self._function = fn_factory()
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 2567, in get_concrete_function
        graph_function = self._get_concrete_function_garbage_collected(
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 2533, in _get_concrete_function_garbage_collected
        graph_function, _ = self._maybe_define_function(args, kwargs)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 2711, in _maybe_define_function
        graph_function = self._create_graph_function(args, kwargs)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/eager/function.py", line 2627, in _create_graph_function
        func_graph_module.func_graph_from_py_func(
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/framework/func_graph.py", line 1141, in func_graph_from_py_func
        func_outputs = python_func(*func_args, **func_kwargs)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/ops/structured_function.py", line 248, in wrapped_fn
        ret = wrapper_helper(*args)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/data/ops/structured_function.py", line 177, in wrapper_helper
        ret = autograph.tf_convert(self._func, ag_ctx)(*nested_args)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 692, in wrapper
        raise e.ag_error_metadata.to_exception(e)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 689, in wrapper
        return converted_call(f, args, kwargs, options=options)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
        result = converted_f(*effective_args, **kwargs)
      File "/tmp/__autograph_generated_fileu9gu1w4n.py", line 8, in <lambda>
        tf__lam = lambda arg: ag__.with_function_scope(lambda lscope: ag__.converted_call(fn, (arg,) + tuple(args), dict(**kargs), lscope), 'lscope', ag__.STD)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/core/function_wrappers.py", line 113, in with_function_scope
        return thunk(scope)
      File "/tmp/__autograph_generated_fileu9gu1w4n.py", line 8, in <lambda>
        tf__lam = lambda arg: ag__.with_function_scope(lambda lscope: ag__.converted_call(fn, (arg,) + tuple(args), dict(**kargs), lscope), 'lscope', ag__.STD)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 352, in converted_call
        return converted_call(
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 439, in converted_call
        result = converted_f(*effective_args, **kwargs)
      File "/tmp/__autograph_generated_filezbhafqmt.py", line 113, in tf__tokenize_impl
        ag__.for_stmt(ag__.converted_call(ag__.ld(features).items, (), None, fscope), None, loop_body, get_state_4, set_state_4, (), {'iterate_names': '(k, v)'})
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 449, in for_stmt
        _py_for_stmt(iter_, extra_test, body, None, None)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 498, in _py_for_stmt
        body(target)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 464, in protected_body
        original_body(protected_iter)
      File "/tmp/__autograph_generated_filezbhafqmt.py", line 105, in loop_body
        ag__.if_stmt(ag__.ld(k) in ag__.ld(output_features), if_body_3, else_body_3, get_state_3, set_state_3, ('v',), 1)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1341, in if_stmt
        _py_if_stmt(cond, body, orelse)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/operators/control_flow.py", line 1394, in _py_if_stmt
        return body() if cond else orelse()
      File "/tmp/__autograph_generated_filezbhafqmt.py", line 63, in if_body_3
        v = ag__.converted_call(ag__.ld(vocab).encode_tf, (ag__.ld(v),), None, fscope)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 441, in converted_call
        result = converted_f(*effective_args)
      File "/tmp/__autograph_generated_filef9jwq2ra.py", line 13, in tf__encode_tf
        retval_ = ag__.converted_call(ag__.ld(self)._encode_tf, (ag__.ld(s),), None, fscope)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 441, in converted_call
        result = converted_f(*effective_args)
      File "/tmp/__autograph_generated_filezpl5g8b_.py", line 21, in tf___encode_tf
        retval_ = ag__.converted_call(ag__.ld(self).tf_tokenizer.tokenize, (ag__.ld(s),), None, fscope)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 441, in converted_call
        result = converted_f(*effective_args)
      File "/tmp/__autograph_generated_filet9vre1mq.py", line 22, in tf__tokenize
        input_tensor = ag__.converted_call(ag__.ld(ragged_tensor).convert_to_tensor_or_ragged_tensor, (ag__.ld(input),), None, fscope)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 377, in converted_call
        return _call_unconverted(f, args, kwargs, options)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/autograph/impl/api.py", line 459, in _call_unconverted
        return f(*args)
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/ops/ragged/ragged_tensor.py", line 2683, in convert_to_tensor_or_ragged_tensor
        return ops.convert_to_tensor_v2_with_dispatch(
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/util/traceback_utils.py", line 153, in error_handler
        raise e.with_traceback(filtered_tb) from None
      File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow/python/framework/tensor_util.py", line 441, in make_tensor_proto
        raise ValueError("None values not supported.")
    ValueError: in user code:
    
        File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/utils.py", line 779, in None  *
            lambda arg: fn(arg, *args, **kargs)
        File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/preprocessors.py", line 116, in tokenize_impl  *
            v = vocab.encode_tf(v)
        File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/vocabularies.py", line 114, in encode_tf  *
            return self._encode_tf(s)
        File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/vocabularies.py", line 413, in _encode_tf  *
            return self.tf_tokenizer.tokenize(s)
        File "/home/stephen/anaconda3/lib/python3.9/site-packages/tensorflow_text/python/ops/sentencepiece_tokenizer.py", line 133, in tokenize  *
            input_tensor = ragged_tensor.convert_to_tensor_or_ragged_tensor(input)
    
        ValueError: None values not supported.
    
      In call to configurable 'train' (<function train at 0x7f79d8db2280>)
    

    I also tried the same with ByT5, and the same error occurs. The following is the error raised when using ByT5:

    ValueError: in user code:
    
        File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/utils.py", line 779, in None  *
            lambda arg: fn(arg, *args, **kargs)
        File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/preprocessors.py", line 116, in tokenize_impl  *
            v = vocab.encode_tf(v)
        File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/vocabularies.py", line 114, in encode_tf  *
            return self._encode_tf(s)
        File "/home/stephen/anaconda3/lib/python3.9/site-packages/seqio/vocabularies.py", line 555, in _encode_tf  *
            tf_ids = tf.io.decode_raw(s, tf.uint8) + self._num_special_tokens
    
        ValueError: Tried to convert 'bytes' to a tensor and failed. Error: None values not supported.
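
    In both runs the failure is the same: seqio's tokenize_impl calls vocab.encode_tf(v) on every output feature, and one of the feature values reaching it is None, which TensorFlow cannot convert to a tensor. A minimal sketch of the failing conversion (illustrative only; the feature name is hypothetical and this is not the user's Task definition):

        import tensorflow as tf

        # If a preprocessor leaves an output feature (e.g. 'targets') as None,
        # the tokenizer's tensor conversion fails exactly like the traces above:
        tf.constant(None)  # ValueError: None values not supported.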
    
    opened by StephennFernandes 5
  • Support forward mode differentiation

    Support forward mode differentiation

    Currently, forward-mode differentiation does not work because losses.py implements cross_entropy_with_logits using jax.custom_vjp. If it were implemented with jax.custom_jvp instead, both forward and reverse mode would be supported. An example application of forward-mode differentiation is inspecting the Hessian and its eigenvalues for a model, as sketched below.
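
    A minimal sketch of that suggestion, assuming a simplified softmax cross-entropy rather than the exact t5x loss (which also carries a z-loss term):

        import jax
        import jax.numpy as jnp

        @jax.custom_jvp
        def cross_entropy_with_logits(logits, targets):
          # Numerically stable log-softmax; loss is -sum_j targets_j * log p_j.
          log_probs = logits - jax.scipy.special.logsumexp(
              logits, axis=-1, keepdims=True)
          return -jnp.sum(targets * log_probs, axis=-1)

        @cross_entropy_with_logits.defjvp
        def _cross_entropy_jvp(primals, tangents):
          logits, targets = primals
          dlogits, dtargets = tangents
          log_probs = logits - jax.scipy.special.logsumexp(
              logits, axis=-1, keepdims=True)
          loss = -jnp.sum(targets * log_probs, axis=-1)
          probs = jnp.exp(log_probs)
          # d(loss)/d(logits) = probs * sum_j(targets_j) - targets
          # d(loss)/d(targets) = -log_probs
          target_mass = jnp.sum(targets, axis=-1, keepdims=True)
          dloss = jnp.sum((probs * target_mass - targets) * dlogits, axis=-1)
          dloss = dloss - jnp.sum(dtargets * log_probs, axis=-1)
          return loss, dloss

        # Forward-over-reverse Hessians now work, e.g.:
        # jax.jacfwd(jax.grad(lambda x: cross_entropy_with_logits(x, y).sum()))(x)

    Because the tangent rule is linear in the tangents, JAX can also transpose it to recover reverse mode, so a single custom_jvp rule serves both differentiation modes.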

    opened by salayatana66 5
  • Incorrect checkpoint path

    Incorrect checkpoint path

      File "/home/torinaki/src/product-description-generation/t5x/t5x/utils.py", line 472, in from_checkpoints
    tensorstore/internal/oauth2/google_auth_provider.cc:173: Using credentials at bigquery-key.json
        yield _restore_path(path, restore_cfg)
      File "/home/torinaki/src/product-description-generation/t5x/t5x/utils.py", line 461, in _restore_path
        return restore_checkpointer.restore(
      File "/home/torinaki/src/product-description-generation/t5x/t5x/checkpoints.py", line 811, in restore
        state_dict = self._read_state_from_tensorstore(
      File "/home/torinaki/src/product-description-generation/t5x/t5x/checkpoints.py", line 860, in _read_state_from_tensorstore
        state_dict = _run_future_tree(future_state_dict)
      File "/home/torinaki/src/product-description-generation/t5x/t5x/checkpoints.py", line 160, in _run_future_tree
        leaves = loop.run_until_complete(asyncio.gather(*future_leaves))
      File "/home/torinaki/.pyenv/versions/3.8.9/lib/python3.8/asyncio/base_events.py", line 616, in run_until_complete
        return future.result()
      File "/home/torinaki/src/product-description-generation/t5x/t5x/checkpoint_importer.py", line 115, in _get_and_cast
        arr = await self._get_fn()  # pytype: disable=bad-return-type
      File "/home/torinaki/src/product-description-generation/t5x/t5x/checkpoints.py", line 1190, in _read_ts
    tensorstore/internal/oauth2/google_auth_provider.cc:189: Using ServiceAccount AuthProvider
        t = await ts.open(tmp_ts_spec_dict, open=True)
    ValueError: Error opening "zarr" driver: Metadata at "gs://t5x-dummy-bucket/gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000/target.decoder.layers_0.self_attention.value.kernel/.zarray" does not exist
      In call to configurable 'train' (<function train at 0x7f8b444b7c10>)
    

    It seems that the problem is in: https://github.com/google-research/t5x/blob/main/t5x/checkpoints.py#L219
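
    The doubled "gs://t5x-dummy-bucket/gs://t5-data/..." path is consistent with an already-absolute GCS URL being joined onto the dummy bucket. A minimal illustration of that failure mode (not the exact t5x code path):

        import os

        # os.path.join only resets on a leading '/', so an absolute gs:// URL
        # is appended to the dummy bucket instead of replacing it:
        os.path.join('gs://t5x-dummy-bucket',
                     'gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000')
        # -> 'gs://t5x-dummy-bucket/gs://t5-data/pretrained_models/t5x/t5_1_1_base/checkpoint_1000000'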

    opened by dbalabka 4
  • RuntimeError: UNIMPLEMENTED: Requested AllReduceStart not implemented on GPU

    RuntimeError: UNIMPLEMENTED: Requested AllReduceStart not implemented on GPU

    I am trying to run one of the fine-tuning examples on a machine with 2 GPUs and getting the following error:

    RuntimeError: UNIMPLEMENTED: Requested AllReduceStart not implemented on GPU; replica_count: 1; partition_count: 2, group_mode: kCrossReplicaAndPartition, operand_count: 49; NCCL support: 1; first operand array element-type: BF16 In call to configurable 'train' (<function train at 0x7fed3c31e3a0>)

    full error trace below:


    RuntimeError                              Traceback (most recent call last)
    Input In [19], in <cell line: 24>()
         17 gin_utils.parse_gin_flags(
         18     # User-provided gin paths take precedence if relative paths conflict.
         19     FLAGS.gin_search_paths,  # + _DEFAULT_GIN_SEARCH_PATHS,
         20     FLAGS.gin_file,
         21     FLAGS.gin_bindings)
         22 train_using_gin()
    ---> 24 gin_utils.run(main_train)

    File /anaconda/envs/py39/lib/python3.9/site-packages/t5x/gin_utils.py:105, in run(main)
        103 def run(main):
        104   """Wrapper for app.run that rewrites gin args before parsing."""
    --> 105   app.run(
        106       main,
        107       flags_parser=lambda a: app.parse_flags_with_usage(rewrite_gin_args(a)))

    File /anaconda/envs/py39/lib/python3.9/site-packages/absl/app.py:312, in run(main, argv, flags_parser)
        310   callback()
        311 try:
    --> 312   _run_main(main, args)
        313 except UsageError as error:
        314   usage(shorthelp=True, detailed_error=error, exitcode=error.exitcode)

    File /anaconda/envs/py39/lib/python3.9/site-packages/absl/app.py:258, in _run_main(main, argv)
        256   sys.exit(retval)
        257 else:
    --> 258   sys.exit(main(argv))

    Input In [18], in main_train(argv)
          1 def main_train(argv: Sequence[str]):
          2   """Wrapper for pdb post mortems."""
    ----> 3   _main(argv)

    Input In [19], in _main(argv)
         15 train_using_gin = gin.configurable(train)
         17 gin_utils.parse_gin_flags(
         18     # User-provided gin paths take precedence if relative paths conflict.
         19     FLAGS.gin_search_paths,  # + _DEFAULT_GIN_SEARCH_PATHS,
         20     FLAGS.gin_file,
         21     FLAGS.gin_bindings)
    ---> 22 train_using_gin()

    File /anaconda/envs/py39/lib/python3.9/site-packages/gin/config.py:1605, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
       1603 scope_info = " in scope '{}'".format(scope_str) if scope_str else ''
       1604 err_str = err_str.format(name, fn_or_cls, scope_info)
    -> 1605 utils.augment_exception_message_and_reraise(e, err_str)

    File /anaconda/envs/py39/lib/python3.9/site-packages/gin/utils.py:41, in augment_exception_message_and_reraise(exception, message)
         39 proxy = ExceptionProxy()
         40 ExceptionProxy.__qualname__ = type(exception).__qualname__
    ---> 41 raise proxy.with_traceback(exception.__traceback__) from None

    File /anaconda/envs/py39/lib/python3.9/site-packages/gin/config.py:1582, in _make_gin_wrapper.<locals>.gin_wrapper(*args, **kwargs)
       1579 new_kwargs.update(kwargs)
       1581 try:
    -> 1582   return fn(*new_args, **new_kwargs)
       1583 except Exception as e:  # pylint: disable=broad-except
       1584   err_str = ''

    Input In [10], in train(model, train_dataset_cfg, train_eval_dataset_cfg, infer_eval_dataset_cfg, checkpoint_cfg, partitioner, trainer_cls, model_dir, total_steps, eval_steps, eval_period, stats_period, random_seed, use_hardware_rng, summarize_config_fn, inference_evaluator_cls, get_dataset_fn, concurrent_metrics, actions, train_eval_get_dataset_fn, run_eval_before_training, use_gda)
        422 logging.info('Compiling train loop.')
        423 logging.flush()
    --> 424 trainer.compile_train(first_batch)
        426 # Main Loop over "epochs".
        427 for epoch in range(first_epoch, num_epochs):

    File /anaconda/envs/py39/lib/python3.9/site-packages/t5x/trainer.py:545, in BaseTrainer.compile_train(self, batch)
        532 """Pre-compiles train step (if not yet compiled).
        533
        534 Not required.
       (...)
        542 shapes and dtypes.
        543 """
        544 tick = time.time()
    --> 545 self._compiled_train_step = self._partitioner.compile(
        546     self._partitioned_train_step, self.train_state, batch)
        547 tock = time.time()
        548 self.train_metrics_manager.write_scalar("timing/compilation_seconds",
        549                                         tock - tick, self.train_state.step)

    File /anaconda/envs/py39/lib/python3.9/site-packages/t5x/partitioning.py:795, in BasePjitPartitioner.compile(self, partitioned_fn, *args)
        793 def compile(self, partitioned_fn: PjittedFnWithContext,
        794             *args) -> CompiledPartitionedCallable:
    --> 795   return partitioned_fn.lower(*args).compile()

    File /anaconda/envs/py39/lib/python3.9/site-packages/jax/_src/stages.py:221, in Lowered.compile(self)
        220 def compile(self) -> Compiled:
    --> 221   return Compiled(self._lowering.compile(), self.args_info,
        222                   self.out_tree, no_kwargs=self._no_kwargs)

    File /anaconda/envs/py39/lib/python3.9/site-packages/jax/interpreters/pxla.py:2346, in MeshComputation.compile(self, _allow_propagation_to_outputs, _allow_compile_replicated)
       2342 def compile(self,
       2343             _allow_propagation_to_outputs: bool = False,
       2344             _allow_compile_replicated: bool = True) -> 'MeshExecutable':
       2345   if self._executable is None:
    -> 2346     self._executable = MeshExecutable.from_hlo(
       2347         self._name, self._hlo, **self.compile_args,
       2348         _allow_propagation_to_outputs=_allow_propagation_to_outputs,
       2349         _allow_compile_replicated=_allow_compile_replicated)  # type: ignore
       2350   return self._executable

    File /anaconda/envs/py39/lib/python3.9/site-packages/jax/interpreters/pxla.py:2456, in MeshExecutable.from_hlo(name, computation, mesh, global_in_avals, global_out_avals, in_axes, out_axes, spmd_lowering, tuple_args, in_is_global, auto_spmd_lowering, _allow_propagation_to_outputs, _allow_compile_replicated)
       2453 else:
       2454   with dispatch.log_elapsed_time(f"Finished XLA compilation of {name} "
       2455                                  "in {elapsed_time} sec"):
    -> 2456     xla_executable = dispatch.compile_or_get_cached(backend, computation, compile_options)
       2458 if auto_spmd_lowering:
       2459   in_axes, out_axes = _get_array_mapping_from_executable(xla_executable, mesh)

    File /anaconda/envs/py39/lib/python3.9/site-packages/jax/_src/dispatch.py:664, in compile_or_get_cached(backend, computation, compile_options)
        661 ir_str = (computation if isinstance(computation, str)
        662           else computation.as_hlo_text())
        663 _dump_ir_to_file(module_name, ir_str)
    --> 664 return backend_compile(backend, computation, compile_options)

    File /anaconda/envs/py39/lib/python3.9/site-packages/jax/_src/profiler.py:206, in annotate_function.<locals>.wrapper(*args, **kwargs)
        203 @wraps(func)
        204 def wrapper(*args, **kwargs):
        205   with TraceAnnotation(name, **decorator_kwargs):
    --> 206     return func(*args, **kwargs)
        207 return wrapper

    File /anaconda/envs/py39/lib/python3.9/site-packages/jax/_src/dispatch.py:618, in backend_compile(backend, built_c, options)
        614 @profiler.annotate_function
        615 def backend_compile(backend, built_c, options):
        616   # we use a separate function call to ensure that XLA compilation appears
        617   # separately in Python profiling results
    --> 618   return backend.compile(built_c, compile_options=options)

    RuntimeError: UNIMPLEMENTED: Requested AllReduceStart not implemented on GPU; replica_count: 1; partition_count: 2, group_mode: kCrossReplicaAndPartition, operand_count: 49; NCCL support: 1; first operand array element-type: BF16 In call to configurable 'train' (<function train at 0x7fed3c31e3a0>)

    opened by ibulu 3
  • Clarification: no attention scores normalization when using Adafactor?

    Clarification: no attention scores normalization when using Adafactor?

    The comment in adafactor.py notes that "the natural scale for adafactor LR is markedly different from Adam" and that "one doesn't use the 1/sqrt(hidden) correction for this optimizer with attention-based models": https://github.com/google-research/t5x/blob/540a65958a1d4d60cb779771247e7312069f56b3/t5x/adafactor.py#L219

    Does this mean that we should not normalize attention scores (before the softmax) when using Adafactor? If so, could you please explain why?
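
    For context, a minimal sketch of the two conventions (illustrative, not the t5x attention code). The standard Transformer divides attention logits by sqrt(d_head); T5-style models trained with Adafactor omit the runtime scaling and instead fold the factor into the query initializer:

        import jax.numpy as jnp

        def attention_logits(q, k, scale_logits):
          # q, k: [..., length, d_head]
          logits = jnp.einsum('...qd,...kd->...qk', q, k)
          if scale_logits:
            # Adam-style convention: divide scores by sqrt(d_head).
            logits = logits / jnp.sqrt(q.shape[-1])
          # T5/Adafactor convention: call with scale_logits=False and absorb
          # the 1/sqrt(d_head) factor into the query kernel initializer.
          return logits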

    opened by rodrigonogueira4 0
  • Training not starting on v3-128

    Training not starting on v3-128

    I'm trying to train a T5 model with C4 data. This is the command that I'm using to train the model:

    python ${T5X_DIR}/t5x/train.py \
      --gin_file="t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin" \
      --gin.MODEL_DIR=\"${MODEL_DIR}\" \
      --alsologtostderr
    

    UPDATE: even this is stuck:

    import jax
    jax.device_count()
    

    It's stuck at this warning:

    UserWarning: TPU backend initialization is taking more than 60.0 seconds. Did you run your code on all TPU hosts? See https://jax.readthedocs.io/en/latest/multi_process.html for more information.

    Can someone please point out what I'm doing wrong?
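
    One thing worth checking (a guess based on the warning, not a confirmed diagnosis): a v3-128 slice spans multiple hosts, and JAX blocks until the same program has been started on every host. With TPU VMs that is typically done by passing --worker=all to the ssh command, e.g.:

        gcloud alpha compute tpus tpu-vm ssh ${TPU_NAME} --zone=${ZONE} \
          --worker=all \
          --command="python3 ${T5X_DIR}/t5x/train.py \
            --gin_file=t5x/examples/t5/t5_1_1/examples/small_c4_pretrain.gin \
            --gin.MODEL_DIR=\"'${MODEL_DIR}'\" \
            --alsologtostderr"

    The same applies to the interactive check: jax.device_count() on a pod slice only returns once the process is running on all hosts.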

    opened by sumanthd17 0
Owner
Google Research
A modular framework for vision & language multimodal research from Facebook AI Research (FAIR)

MMF is a modular framework for vision and language multimodal research from Facebook AI Research. MMF contains reference implementations of state-of-t

Facebook Research 5.1k Jan 4, 2023
A fast, distributed, high performance gradient boosting (GBT, GBDT, GBRT, GBM or MART) framework based on decision tree algorithms, used for ranking, classification and many other machine learning tasks.

Light Gradient Boosting Machine LightGBM is a gradient boosting framework that uses tree based learning algorithms. It is designed to be distributed a

Microsoft 14.5k Jan 8, 2023
The lightweight PyTorch wrapper for high-performance AI research. Scale your models, not the boilerplate.

The lightweight PyTorch wrapper for high-performance AI research. Scale your models, not the boilerplate. Website • Key Features • How To Use • Docs •

Pytorch Lightning 21.1k Jan 8, 2023
Official repository of OFA. Paper: Unifying Architectures, Tasks, and Modalities Through a Simple Sequence-to-Sequence Learning Framework

Paper | Blog OFA is a unified multimodal pretrained model that unifies modalities (i.e., cross-modality, vision, language) and tasks (e.g., image gene

OFA Sys 1.4k Jan 8, 2023
Sequence to Sequence Models with PyTorch

Sequence to Sequence models with PyTorch This repository contains implementations of Sequence to Sequence (Seq2Seq) models in PyTorch At present it ha

Sandeep Subramanian 708 Dec 19, 2022
High performance Cross-platform Inference-engine, you could run Anakin on x86-cpu,arm, nv-gpu, amd-gpu,bitmain and cambricon devices.

Anakin2.0 Welcome to the Anakin GitHub. Anakin is a cross-platform, high-performance inference engine, which is originally developed by Baidu engineer

null 514 Dec 28, 2022
PPLNN is a Primitive Library for Neural Network is a high-performance deep-learning inference engine for efficient AI inferencing

PPLNN is a Primitive Library for Neural Network is a high-performance deep-learning inference engine for efficient AI inferencing

null 943 Jan 7, 2023
PyTorch-LIT is the Lite Inference Toolkit (LIT) for PyTorch which focuses on easy and fast inference of large models on end-devices.

PyTorch-LIT PyTorch-LIT is the Lite Inference Toolkit (LIT) for PyTorch which focuses on easy and fast inference of large models on end-devices. With

Amin Rezaei 157 Dec 11, 2022
Source code and Dataset creation for the paper "Neural Symbolic Regression That Scales"

NeuralSymbolicRegressionThatScales Pytorch implementation and pretrained models for the paper "Neural Symbolic Regression That Scales", presented at I

null 35 Nov 25, 2022
This is the repository for CVPR2021 Dynamic Metric Learning: Towards a Scalable Metric Space to Accommodate Multiple Semantic Scales

Intro This is the repository for CVPR2021 Dynamic Metric Learning: Towards a Scalable Metric Space to Accommodate Multiple Semantic Scales Vehicle Sam

null 39 Jul 21, 2022
Contrastive Learning for Many-to-many Multilingual Neural Machine Translation(mCOLT/mRASP2), ACL2021

Contrastive Learning for Many-to-many Multilingual Neural Machine Translation(mCOLT/mRASP2), ACL2021 The code for training mCOLT/mRASP2, a multilingua

null 104 Jan 1, 2023
M2MRF: Many-to-Many Reassembly of Features for Tiny Lesion Segmentation in Fundus Images

M2MRF: Many-to-Many Reassembly of Features for Tiny Lesion Segmentation in Fundus Images This repo is the official implementation of paper "M2MRF: Man

null 12 Dec 14, 2022
CPU inference engine that delivers unprecedented performance for sparse models

The DeepSparse Engine is a CPU runtime that delivers unprecedented performance by taking advantage of natural sparsity within neural networks to reduce compute required as well as accelerate memory bound workloads. It is focused on model deployment and scaling machine learning pipelines, fitting seamlessly into your existing deployments as an inference backend.

Neural Magic 1.2k Jan 9, 2023
A user-friendly research and development tool built to standardize RL competency assessment for custom agents and environments.

Built with ❤️ by Sam Showalter Contents Overview Installation Dependencies Usage Scripts Standard Execution Environment Development Environment Benchm

SRI-AIC 1 Nov 18, 2021
Understanding and Improving Encoder Layer Fusion in Sequence-to-Sequence Learning (ICLR 2021)

Understanding and Improving Encoder Layer Fusion in Sequence-to-Sequence Learning (ICLR 2021) Citation Please cite as: @inproceedings{liu2020understan

Sunbow Liu 22 Nov 25, 2022
Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers

Segmentation Transformer Implementation of Segmentation Transformer in PyTorch, a new model to achieve SOTA in semantic segmentation while using trans

Abhay Gupta 161 Dec 8, 2022
Implementation of SETR model, Original paper: Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers.

SETR - Pytorch Since the original paper (Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers.) has no official

zhaohu xing 112 Dec 16, 2022