SEW (Squeezed and Efficient Wav2vec)
This repo contains the code for the paper "Performance-Efficiency Trade-offs in Unsupervised Pre-training for Speech Recognition" by Felix Wu, Kwangyoun Kim, Jing Pan, Kyu Han, Kilian Q Weinberger, and Yoav Artzi.
Model Checkpoints
Unsupervised Pre-training on LibriSpeech 960h
Model | Pre-training updates | Dataset | Checkpoint |
---|---|---|---|
W2V2-tiny | 100K | LibriSpeech 960h | download |
W2V2-small | 100K | LibriSpeech 960h | download |
W2V2-mid | 100K | LibriSpeech 960h | download |
W2V2-base | 100K | LibriSpeech 960h | download |
SEW-tiny | 100K | LibriSpeech 960h | download |
SEW-small | 100K | LibriSpeech 960h | download |
SEW-mid | 100K | LibriSpeech 960h | download |
SEW-D-tiny | 100K | LibriSpeech 960h | download |
SEW-D-small | 100K | LibriSpeech 960h | download |
SEW-D-mid | 100K | LibriSpeech 960h | download |
SEW-D-mid (k127) | 100K | LibriSpeech 960h | download |
SEW-D-base | 100K | LibriSpeech 960h | download |
SEW-D-base+ | 100K | LibriSpeech 960h | download |
SEW-D-mid | 400K | LibriSpeech 960h | download |
SEW-D-mid (k127) | 400K | LibriSpeech 960h | download |
SEW-D-base+ | 400K | LibriSpeech 960h | download |
ASR models fine-tuned on LibriSpeech train-clean-100 (100h)
Model | Pre-training updates | Fine-tuning split | Checkpoint |
---|---|---|---|
SEW-tiny | 100K | 100h | download |
SEW-D-tiny | 100K | 100h | download |
SEW-D-mid | 400K | 100h | download |
SEW-D-mid (k127) | 400K | 100h | download |
SEW-D-base+ | 400K | 100h | download |
Usage
Dependencies
The code has been tested with fairseq commit 05255f9, DeBERTa commit bf17ca4, and the following packages:
torch==1.8.0
torchaudio==0.8.0
tqdm==4.49.0
Hydra==2.5
hydra-core==1.0.4
fvcore==0.1.5.post20210330
omegaconf==2.0.5
einops==0.3.0
fire==0.2.1
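To confirm that an environment matches the pinned versions above, a short Python check can help. This is a minimal sketch: the helper names are our own (not part of the SEW repo), and it only compares the pinned PyPI packages, not the fairseq or DeBERTa commits.

```python
"""Compare installed package versions against the pins listed above.

Hypothetical helper, not part of the SEW repo. PINNED mirrors the README list.
"""
from importlib import metadata
from typing import Callable, Dict, Optional

# Versions the SEW code was tested with (copied from the README).
PINNED: Dict[str, str] = {
    "torch": "1.8.0",
    "torchaudio": "0.8.0",
    "tqdm": "4.49.0",
    "Hydra": "2.5",
    "hydra-core": "1.0.4",
    "fvcore": "0.1.5.post20210330",
    "omegaconf": "2.0.5",
    "einops": "0.3.0",
    "fire": "0.2.1",
}


def installed_version(name: str) -> Optional[str]:
    """Return the installed version of a distribution, or None if missing."""
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None


def find_mismatches(
    pinned: Dict[str, str],
    lookup: Callable[[str], Optional[str]] = installed_version,
) -> Dict[str, Optional[str]]:
    """Map each package that differs from its pin to the version found.

    A local build suffix such as "+cu111" is ignored, so "1.8.0+cu111"
    counts as matching "1.8.0". None means the package is not installed.
    """
    mismatches: Dict[str, Optional[str]] = {}
    for name, wanted in pinned.items():
        found = lookup(name)
        if found is None or found.split("+")[0] != wanted:
            mismatches[name] = found
    return mismatches


if __name__ == "__main__":
    for name, found in find_mismatches(PINNED).items():
        print(f"{name}: expected {PINNED[name]}, found {found or 'not installed'}")
```

Running it prints one line per package that is missing or at a different version; no output means the pins are satisfied.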
Apex
Please install NVIDIA's Apex with:
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
--global-option="--deprecated_fused_adam" --global-option="--xentropy" \
--global-option="--fast_multihead_attn" ./
wav2letter decoder
Currently, we decode with the wav2letter v0.2 Python bindings at commit 96f5f9d.
Please install the Python bindings from https://github.com/flashlight/wav2letter/tree/96f5f9d3b41e01af0a031ee0d2604acd9ef3b1b0/bindings/python. Note that the newest commit in the v0.2 branch, d5a93f0,
leads to worse WER for the wav2vec 2.0 baselines.
Installation
git clone https://github.com/asappresearch/sew.git
cd sew
pip install -e .
Pre-training
Pre-training SEW models
Run the following command, where $model_size can be tiny, small, or mid, and $ngpu is the number of GPUs you want to use.
bash scripts/pt-sew.sh $model_size $ngpu
Pre-training SEW-D models
bash scripts/pt-sew-d.sh $model_size $ngpu
where $model_size can be tiny, small, mid, mid-k127, base, or base+.
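If you launch runs from Python, a thin wrapper can catch an invalid model size before a job is queued. The sketch below is hypothetical (not part of the repo): it only validates the arguments against the sizes listed in this README and builds the same command line as above, without starting training.

```python
"""Hypothetical wrapper around scripts/pt-sew.sh and scripts/pt-sew-d.sh.

It validates arguments and returns the argv you would otherwise type by hand.
"""
import shlex
from typing import List

# Model sizes accepted by each script, as listed in the README.
SEW_SIZES = ("tiny", "small", "mid")
SEW_D_SIZES = ("tiny", "small", "mid", "mid-k127", "base", "base+")


def pretrain_command(family: str, model_size: str, ngpu: int) -> List[str]:
    """Build the argv for pre-training a SEW ("sew") or SEW-D ("sew-d") model."""
    if family == "sew":
        script, sizes = "scripts/pt-sew.sh", SEW_SIZES
    elif family == "sew-d":
        script, sizes = "scripts/pt-sew-d.sh", SEW_D_SIZES
    else:
        raise ValueError(f"unknown model family: {family!r}")
    if model_size not in sizes:
        raise ValueError(f"{family} has no size {model_size!r}; choose from {sizes}")
    if ngpu < 1:
        raise ValueError("ngpu must be a positive integer")
    return ["bash", script, model_size, str(ngpu)]


if __name__ == "__main__":
    # Print the command for SEW-D-mid on 8 GPUs instead of running it.
    print(shlex.join(pretrain_command("sew-d", "mid", 8)))
```

To actually start a run from the repo root, the returned list can be passed to `subprocess.run(cmd, check=True)`.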
Fine-tuning
Run the following script to fine-tune a model with the hyperparameters from wav2vec 2.0.
bash scripts/ft-model.sh $pre_trained_model $split $ngpu
where $pre_trained_model can be a W2V2, SEW, or SEW-D model checkpoint, $split can be 10m, 1h, 10h, or 100h, and $ngpu is the number of GPUs.
We also provide an alternative set of hyperparameters that sets all dropout rates to the same values as in the pre-training stage; we found this setting to be more stable:
bash scripts/ft-model-stable.sh $pre_trained_model $split $ngpu
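Choosing between the two fine-tuning scripts can be expressed as a single flag. This is a hypothetical helper, not part of the repo: it checks the split against the values listed above, picks `ft-model-stable.sh` when `stable=True`, and returns the command without launching it. The checkpoint path in the example is a placeholder.

```python
"""Hypothetical helper that assembles the fine-tuning command documented above."""
import shlex
from typing import List

# Labeled-data splits accepted by the fine-tuning scripts (from the README).
SPLITS = ("10m", "1h", "10h", "100h")


def finetune_command(pre_trained_model: str, split: str, ngpu: int,
                     stable: bool = False) -> List[str]:
    """Return argv for ft-model.sh, or ft-model-stable.sh when stable=True."""
    if split not in SPLITS:
        raise ValueError(f"split must be one of {SPLITS}, got {split!r}")
    if ngpu < 1:
        raise ValueError("ngpu must be a positive integer")
    script = "scripts/ft-model-stable.sh" if stable else "scripts/ft-model.sh"
    return ["bash", script, pre_trained_model, split, str(ngpu)]


if __name__ == "__main__":
    # Placeholder checkpoint path: replace it with your own pre-trained model.
    cmd = finetune_command("checkpoints/sew-d-mid.pt", "100h", 8, stable=True)
    print(shlex.join(cmd))
```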
If you run out of GPU memory, scale down dataset.max_tokens and scale up optimization.update_freq in scripts/ft-model.sh. For example, change these lines
dataset.max_tokens=3200000 \
optimization.update_freq="[$((8 / $ngpu))]" \
to
dataset.max_tokens=1600000 \
optimization.update_freq="[$((16 / $ngpu))]" \
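This halves the batch processed by each GPU per forward pass and doubles the number of gradient accumulation steps, so training uses less GPU memory while the effective batch size per update stays the same (as long as $ngpu divides 8).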
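The claim that both settings give the same effective batch can be checked with a little arithmetic. This sketch assumes the usual fairseq convention for raw-audio datasets, where `max_tokens` counts audio samples per GPU per forward pass (LibriSpeech is 16 kHz), and treats it as an upper bound since real batches are rarely filled exactly. The function names are illustrative only.

```python
"""Worked example: upper bound on audio per optimizer update for the two
settings of dataset.max_tokens and optimization.update_freq shown above.

Assumption: for raw audio, max_tokens counts samples per GPU per forward pass.
"""

SAMPLE_RATE = 16_000  # LibriSpeech audio is sampled at 16 kHz


def update_freq(numerator: int, ngpu: int) -> int:
    """Mirror the bash expression $((numerator / ngpu)) (integer division)."""
    return numerator // ngpu


def max_samples_per_update(max_tokens: int, freq: int, ngpu: int) -> int:
    """Upper bound on samples summed over all GPUs and accumulation steps."""
    return max_tokens * freq * ngpu


for ngpu in (1, 2, 4, 8):
    default = max_samples_per_update(3_200_000, update_freq(8, ngpu), ngpu)
    low_mem = max_samples_per_update(1_600_000, update_freq(16, ngpu), ngpu)
    print(f"{ngpu} GPU(s): default {default / SAMPLE_RATE:.0f}s, "
          f"low-memory {low_mem / SAMPLE_RATE:.0f}s of audio per update")
```

For each GPU count in the loop, both settings come to at most 25,600,000 samples, about 1,600 seconds of audio per update. With a GPU count that does not divide 8, the integer division changes the total, and with more than 8 GPUs the default expression yields an update frequency of 0.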
Evaluation
- Please run this script to prepare the official LibriSpeech 4-gram language model.
bash scripts/prepare_librispeech_lm.sh $kenlm_build_bin
where $kenlm_build_bin is the folder that contains the KenLM build_binary executable (e.g. /home/user/kenlm/build/bin).
- Then run this script to evaluate a fine-tuned ASR model, where $asr_checkpoint is the path to the model checkpoint:
python tools/eval_w2v.py tunelm --subsets '["dev-clean", "dev-other", "test-clean", "test-other"]' --model $asr_checkpoint
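ASR results on these LibriSpeech subsets are conventionally reported as word error rate (WER), the metric mentioned in the decoder note. For reference, here is a minimal sketch of the standard corpus-level definition: word-level edit distance (substitutions, deletions, insertions) divided by the number of reference words. It is not the code used by tools/eval_w2v.py, and it performs no text normalization.

```python
"""Minimal corpus-level WER using the standard word-level Levenshtein distance.

Illustrative only; not the implementation in tools/eval_w2v.py.
"""
from typing import Iterable, Sequence


def edit_distance(ref: Sequence[str], hyp: Sequence[str]) -> int:
    """Minimum number of word substitutions, deletions, and insertions."""
    # prev[j] = distance between the reference words seen so far and hyp[:j]
    prev = list(range(len(hyp) + 1))
    for i, r in enumerate(ref, start=1):
        curr = [i] + [0] * len(hyp)
        for j, h in enumerate(hyp, start=1):
            curr[j] = min(
                prev[j] + 1,             # delete reference word r
                curr[j - 1] + 1,         # insert hypothesis word h
                prev[j - 1] + (r != h),  # substitute (free when words match)
            )
        prev = curr
    return prev[-1]


def wer(references: Iterable[str], hypotheses: Iterable[str]) -> float:
    """Total word edits divided by total reference words."""
    errors = total = 0
    for ref, hyp in zip(references, hypotheses):
        ref_words, hyp_words = ref.split(), hyp.split()
        errors += edit_distance(ref_words, hyp_words)
        total += len(ref_words)
    if total == 0:
        raise ValueError("references contain no words")
    return errors / total
```

For example, `wer(["the cat sat on the mat"], ["the cat sit on mat"])` counts one substitution and one deletion over six reference words, giving 2/6 ≈ 0.333.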