NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling @ INTERSPEECH 2021 Accepted

Overview

NU-Wave — Official PyTorch Implementation

NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling
Junhyeok Lee, Seungu Han @ MINDsLab Inc., SNU

Paper(arXiv): https://arxiv.org/abs/2104.02321 (Accepted to INTERSPEECH 2021)
Audio Samples: https://mindslab-ai.github.io/nuwave

Official PyTorch + PyTorch Lightning implementation of NU-Wave.

Update: CODE RELEASED! README is DONE.

Requirements

Install the libraries listed in requirement.txt:
$ pip install -r requirement.txt

Preprocessing

Before running the project, download the dataset and preprocess it into .pt files:

  1. Download the VCTK dataset.
  2. Remove speakers p280 and p315.
  3. Set data:dir in hparameter.yaml to the path of the downloaded dataset.
  4. Run utils/wav2pt.py:
$ python utils/wav2pt.py
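Steps 2 and 4 come down to two decisions: which audio files to convert, and where each .pt file goes. Below is a minimal sketch of that file-selection step in plain Python. It assumes the layout <root>/<speaker>/<utterance>.wav. select_utterances and pt_path_for are hypothetical helpers. The naming rule (same directory, .pt suffix) is an assumption inferred from the '*mic1.pt' format in the data section, not taken from utils/wav2pt.py. Loading, trimming, and saving the audio itself are left out; the repository does those with librosa and torch.

```python
from pathlib import Path

# Speakers the preprocessing steps say to drop.
EXCLUDED_SPEAKERS = {"p280", "p315"}


def select_utterances(root, pattern="*.wav"):
    """Hypothetical helper: audio files under root/<speaker>/, minus excluded speakers.

    The root/<speaker>/<file> layout and the default pattern are assumptions.
    """
    root = Path(root)
    selected = []
    for spk_dir in sorted(p for p in root.iterdir() if p.is_dir()):
        if spk_dir.name in EXCLUDED_SPEAKERS:
            continue  # skip p280 and p315 entirely
        selected.extend(sorted(spk_dir.glob(pattern)))
    return selected


def pt_path_for(wav_path):
    """Hypothetical naming rule: same directory, .pt instead of the audio suffix."""
    return Path(wav_path).with_suffix(".pt")
```

With this rule, p225_001_mic1.wav maps to p225_001_mic1.pt, which matches the '*mic1.pt' glob used in the data section.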

Training

  1. Adjust hparameter.yaml, especially the train section:
train:
  batch_size: 18 # Dependent on GPU memory size
  lr: 0.00003
  weight_decay: 0.00
  num_workers: 64 # Dependent on CPU cores
  gpus: 2 # number of GPUs
  opt_eps: 1e-9
  beta1: 0.5
  beta2: 0.999
  • To train on a single speaker, use VCTKSingleSpkDataset instead of VCTKMultiSpkDataset as the dataset in dataloader.py, and use batch_size=1 for the validation dataloader.
  • Adjust the data section in hparameter.yaml:
data:
  dir: '/DATA1/VCTK/VCTK-Corpus/wav48/p225' #dir/spk/format
  format: '*mic1.pt'
  cv_ratio: (223./231., 8./231., 0.00) #train/val/test
  2. Run trainer.py:
$ python trainer.py
  • To resume training from a checkpoint, use the flags defined by the argument parser in trainer.py:
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--resume_from', type=int,
                        required=False, help="Resume Checkpoint epoch number")
    parser.add_argument('-s', '--restart', action="store_true",
                        required=False, help="Significant change occurred, use this")
    parser.add_argument('-e', '--ema', action="store_true",
                        required=False, help="Start from ema checkpoint")
    args = parser.parse_args()
  • During training, the TensorBoard logger records the loss, spectrograms, and audio. To view them:
$ tensorboard --logdir=./tensorboard --bind_all
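The train and data sections are plain YAML, but two values need care when they are loaded:

- PyYAML follows YAML 1.1, whose float pattern requires a decimal point. As a result, opt_eps: 1e-9 loads as the string "1e-9" unless it is converted explicitly.
- cv_ratio is written as a tuple of divisions, which YAML also returns as a string. ast.literal_eval rejects the division, so the sketch parses it by hand instead of calling eval().

Below is a minimal sketch in plain Python, under stated assumptions. The optimizer is assumed to be Adam-style; the README lists lr, betas, eps and weight_decay but does not name the optimizer. The helpers optimizer_kwargs, samples_per_step, parse_ratio and split_by_ratio are hypothetical, not functions from the repository. samples_per_step also assumes that, as usual with DDP, each GPU process builds its own batches of batch_size.

```python
# Values copied from the train and data sections above. opt_eps and cv_ratio
# are strings here because that is how PyYAML returns them.
train = {"batch_size": 18, "lr": 0.00003, "weight_decay": 0.00,
         "num_workers": 64, "gpus": 2, "opt_eps": "1e-9",
         "beta1": 0.5, "beta2": 0.999}
data = {"cv_ratio": "(223./231., 8./231., 0.00)"}  # train/val/test


def optimizer_kwargs(cfg):
    """Hypothetical helper: keyword arguments in the shape torch.optim.Adam accepts."""
    return {
        "lr": float(cfg["lr"]),
        "betas": (float(cfg["beta1"]), float(cfg["beta2"])),
        "eps": float(cfg["opt_eps"]),  # float() accepts "1e-9" and 1e-9 alike
        "weight_decay": float(cfg["weight_decay"]),
    }


def samples_per_step(cfg):
    """Hypothetical helper: per-GPU batch times GPU count (assumes DDP)."""
    return cfg["batch_size"] * max(1, cfg["gpus"])


def parse_ratio(text):
    """Hypothetical helper: parse "(a/b, c/d, e)" into floats without eval()."""
    ratios = []
    for part in text.strip().strip("()").split(","):
        part = part.strip()
        if not part:
            continue
        if "/" in part:
            num, den = part.split("/")
            ratios.append(float(num) / float(den))
        else:
            ratios.append(float(part))
    return tuple(ratios)


def split_by_ratio(items, ratios):
    """Hypothetical helper: split items into consecutive chunks (train/val/test).

    Every chunk except the last gets at most round(ratio * len(items)) items;
    the last chunk takes whatever remains, so each item lands in one chunk.
    """
    n, start, chunks = len(items), 0, []
    for r in ratios[:-1]:
        end = min(n, start + round(r * n))
        chunks.append(items[start:end])
        start = end
    chunks.append(items[start:])
    return chunks
```

Splitting 231 single-speaker files this way gives 223 / 8 / 0 files; 231 is presumably what the 223./231. and 8./231. fractions refer to. Rounding, rather than int(), avoids losing a file when 223/231 * 231 evaluates to 222.999... The resulting kwargs would be used as torch.optim.Adam(model.parameters(), **optimizer_kwargs(train)).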

Evaluation

Run for_test.py or test.py:

$ python test.py -r {checkpoint_number} [-e] [--save]
or
$ python for_test.py -r {checkpoint_number} [-e] [--save]

Add -e to evaluate the EMA checkpoint and --save to save the resulting files.

The options are defined by the following parser:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--resume_from', type=int,
                        required=True, help="Resume Checkpoint epoch number")
    parser.add_argument('-e', '--ema', action="store_true",
                        required=False, help="Start from ema checkpoint")
    parser.add_argument('--save', action="store_true",
                        required=False, help="Save file")
    args = parser.parse_args()

We also provide test.py, a Lightning-style test script, but it depends on the device setup, so we recommend for_test.py instead.
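Both trainer.py and the test scripts turn -r (an epoch number) and -e (EMA) into a checkpoint file. Below is a minimal sketch of that lookup, assuming the naming visible elsewhere on this page:

- Regular checkpoints match *_epoch=<N>.ckpt, from the glob in the Trainer call quoted in the comments.
- EMA checkpoints are named <name>_epoch=<N>_EMA with no extension, as in the test log line nuwave_x2_01_07_22_epoch=645_EMA.

build_test_parser, find_checkpoint and trainer_resume_path are hypothetical helpers, and the repository's own lookup may differ.

```python
import argparse
import glob
import os


def build_test_parser():
    """Same flags as the test.py / for_test.py parser shown above."""
    parser = argparse.ArgumentParser()
    parser.add_argument('-r', '--resume_from', type=int, required=True,
                        help="Resume Checkpoint epoch number")
    parser.add_argument('-e', '--ema', action="store_true",
                        help="Start from ema checkpoint")
    parser.add_argument('--save', action="store_true", help="Save file")
    return parser


def find_checkpoint(checkpoint_dir, epoch, ema=False):
    """Hypothetical helper: return the last (sorted) checkpoint file for an epoch.

    Assumed naming: *_epoch=<N>.ckpt for regular checkpoints and
    *_epoch=<N>_EMA for EMA checkpoints.
    """
    suffix = f"_epoch={epoch}_EMA" if ema else f"_epoch={epoch}.ckpt"
    matches = sorted(glob.glob(os.path.join(checkpoint_dir, "*" + suffix)))
    if not matches:
        raise FileNotFoundError(f"no *{suffix} file in {checkpoint_dir}")
    return matches[-1]


def trainer_resume_path(checkpoint_dir, resume_from, restart):
    """Hypothetical mirror of the resume_from_checkpoint expression in the quoted
    Trainer call: no -r, or -s given, means training starts from scratch (None)."""
    if resume_from is None or restart:
        return None
    return find_checkpoint(checkpoint_dir, resume_from)
```

The quoted Trainer call indexes [-1] directly on the glob result. Raising FileNotFoundError instead means a mistyped epoch number fails with a readable message rather than an IndexError.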

References

This implementation uses code from the following repositories:

This README and the webpage for the audio samples are inspired by:

The audio samples on our webpage are partially derived from:

Repository Structure

.
├── Dockerfile
├── dataloader.py           # Dataloader for train/val(=test)
├── filters.py              # Filter implementation
├── test.py                 # Test with lightning_loop.
├── for_test.py             # Test with for_loop. Recommended due to device dependency of lightning
├── hparameter.yaml         # Config
├── lightning_model.py      # NU-Wave implementation. DDPM is based on ivanvovk's WaveGrad implementation
├── model.py                # NU-Wave model based on lmnt-com's DiffWave implementation
├── requirement.txt         # requirement libraries
├── sampling.py             # Sampling a file
├── trainer.py              # Lightning trainer
├── README.md           
├── LICSENSE
├── utils
│  ├── stft.py              # STFT layer
│  ├── tblogger.py          # Tensorboard Logger for lightning
│  └── wav2pt.py            # Preprocessing
└── docs                    # For github.io
   └─ ...
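lightning_model.py is described as a DDPM based on a WaveGrad implementation. The sketch below illustrates the schedule arithmetic involved. It uses the noise_schedule (1000 linearly spaced betas from 1e-6 to 0.006) and the 8-step infer_schedule from the hparameter.yaml posted in one of the comments. From those it computes the standard DDPM quantities alpha_t = 1 - beta_t and alpha_bar_t = prod_{s<=t} alpha_s, plus the forward-noising draw y_t = sqrt(alpha_bar_t) * y_0 + sqrt(1 - alpha_bar_t) * eps.

This is plain Python written for illustration, not the repository's code. How NU-Wave conditions the network on the noise level (pos_emb_scale, pos_emb_channels) is not modeled. With these values, alpha_bar after all 1000 training steps is about 0.05. The 8-step inference schedule ends at about 0.09.

```python
import math
import random


def linspace(start, stop, num):
    """Plain-Python equivalent of torch.linspace(start, stop, num) for num >= 2."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def alpha_bars(betas):
    """Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)."""
    out, acc = [], 1.0
    for beta in betas:
        acc *= 1.0 - beta
        out.append(acc)
    return out


def q_sample(y0, alpha_bar, rng):
    """Draw y_t ~ q(y_t | y_0) = N(sqrt(alpha_bar) * y_0, (1 - alpha_bar) * I)."""
    a, s = math.sqrt(alpha_bar), math.sqrt(1.0 - alpha_bar)
    return [a * x + s * rng.gauss(0.0, 1.0) for x in y0]


# Values from the posted config: noise_schedule and infer_schedule.
train_betas = linspace(1e-6, 0.006, 1000)
infer_betas = [1e-6, 2e-6, 1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 9e-1]

train_ab = alpha_bars(train_betas)  # train_ab[-1] is about 0.05
infer_ab = alpha_bars(infer_betas)  # infer_ab[-1] is about 0.09
```

A small alpha_bar means the noisiest step keeps only a few percent of the clean signal's variance. The steep final inference betas (0.1, 0.9) let 8 steps reach a comparable noise level to the 1000-step training chain.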

Citation & Contact

If this repository is useful for your research, please consider citing it! The BibTeX entry will be updated after the INTERSPEECH 2021 conference.

@article{lee2021nuwave,
  title={NU-Wave: A Diffusion Probabilistic Model for Neural Audio Upsampling},
  author={Lee, Junhyeok and Han, Seungu},
  journal={arXiv preprint arXiv:2104.02321},
  year={2021}
}

If you have any questions or inquiries, please contact Junhyeok Lee at [email protected]

Comments
  • Upsample file has static in the background or complete silence

    Hello-there, Junhyeok & Seungu,

    My name is David. I'm writing an article about your awesome repository for the Level Up Coding publication on Medium.

    I'm still new to deep learning so I've been stumbling a bit through your implementation.

    I think your paper mentioned that 8 epochs produced similar results to 1000 epochs.

    I trained the model on a 1080 Ti (11 GB) with a batch size of 3 for 7 epochs so far. It created a checkpoint file for the 5th epoch and an EMA checkpoint file for the 7th epoch.

    Here's the strange part...

    The regular checkpoints produce an upsample file that has constant static in the background. The ema checkpoints produce an upsample file with complete silence.

    Would either of you be able to help shed some light on how to make the most of your awesome repository?

    With appreciation,

    David

    opened by david-littlefield 2
  • MU-GAN outputs

    Hi!

    Congrats on the nice results! I looked at your MU-GAN outputs. Are these the outputs of your own implementation?

    Why didn't you try NU-GAN? https://arxiv.org/abs/2010.11362

    opened by listener17 2
  • _pickle.UnpicklingError running test.py script (SOLVED)

    Hi @junjun3518,

    First, congratulations on the work. I trained the nuwave model for r=2 and r=3, however I'm having trouble running the test.py script.

    If you could help me, that would be great. The following error message occurs when running:

    # python test.py -r=645 -e
    checkpoint_ratio_2/nuwave_x2_01_07_22_epoch=645_EMA
    GPU available: True, used: True
    TPU available: None, using: 0 TPU cores
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
    full speakers 109
    2 9
    num_workers:  64
    Testing:   0%|          | 0/3552 [00:00<?, ?it/s]
    Exception in thread Thread-3:
    Traceback (most recent call last):
      File "/opt/conda/lib/python3.6/threading.py", line 916, in _bootstrap_inner
        self.run()
      File "/opt/conda/lib/python3.6/site-packages/prefetch_generator/__init__.py", line 80, in run
        for item in self.generator:
      File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 388, in __next__
        data = self._next_data()
      File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1018, in _next_data
        return self._process_data(data)
      File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/dataloader.py", line 1043, in _process_data
        data.reraise()
      File "/opt/conda/lib/python3.6/site-packages/torch/_utils.py", line 420, in reraise
        raise self.exc_type(msg)
    _pickle.UnpicklingError: Caught UnpicklingError in DataLoader worker process 0.
    Original Traceback (most recent call last):
      File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/worker.py", line 185, in _worker_loop
        data = fetcher.fetch(index)
      File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/opt/conda/lib/python3.6/site-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
        data = [self.dataset[idx] for idx in possibly_batched_index]
      File "/root/Documentos/Upsampling/nuwave/dataloader.py", line 104, in __getitem__
        wav = torch.load(self.data_list[index])
      File "/opt/conda/lib/python3.6/site-packages/torch/serialization.py", line 595, in load
        return _legacy_load(opened_file, map_location, pickle_module, **pickle_load_args)
      File "/opt/conda/lib/python3.6/site-packages/torch/serialization.py", line 765, in _legacy_load
        magic_number = pickle_module.load(f, **pickle_load_args)
    _pickle.UnpicklingError: unpickling stack underflow
    

    I'm using docker container. My hparameter.yaml file is as follows:

    train:
      batch_size: 16
      lr: 0.00003
      weight_decay: 0.00
      num_workers: 64
      gpus: 1 #ddp
      opt_eps: 1e-9
      beta1: 0.5
      beta2: 0.999
    
    data:
      dir: './VCTK-Corpus/wav48/' #dir/spk/format
      format: '*.wav'
      cv_ratio: (100./108., 8./108., 0.00) #train/val/test
    
    audio:
      sr: 48000
      nfft: 1024
      hop: 256
      ratio: 2 #upscale_ratio
      length: 32768 #32*1024 ~ 1sec
    
    arch:
      residual_layers: 30 #
      residual_channels: 64
      dilation_cycle_length: 10
      pos_emb_dim: 512 
    
    ddpm:
      max_step: 1000
      noise_schedule: "torch.linspace(1e-6, 0.006, hparams.ddpm.max_step)"
      pos_emb_scale: 50000
      pos_emb_channels: 128 
      infer_step: 8
      infer_schedule: "torch.tensor([1e-6,2e-6,1e-5,1e-4,1e-3,1e-2,1e-1,9e-1])"
    
    log:
      name: 'nuwave_x2'
      checkpoint_dir: 'checkpoint_ratio_2/'
      tensorboard_dir: 'tensorboard_ratio_2/'
      test_result_dir: 'test_sample/result'
    

    Thank you very much in advance.

    opened by freds0 1
  • About to be a Contribution

    I decided to make a Colab notebook. It took hours to sort out the pip packaging and all the other mismatches, but the notebook now seems to work, with no packaging or "Module not found" errors anymore. The problem is that it doesn't really upscale. I used the x3 last.ckpt from freds0, which should upscale 16 kHz to 48 kHz, and fed it a 16 kHz file provided from the demo. It produced the following: [screenshot: Annotation 2022-06-08 032011]

    By the way, I don't understand why it produced so many outputs. I took "sample_0_48000.wav", and there is no upscaling improvement at all :( Interestingly, any audio software reports it as 48 kHz, but it neither sounds like 48 kHz nor has a spectrogram that looks like it. The 16 kHz file I experimented on was 66 KB; after the supposed upscaling it is 197 KB, but it doesn't sound like 48 kHz at ALL.

    Here comes the crazy part: when I imported the same 16 kHz, 66 KB file into Adobe Audition and exported it at 48 kHz (which I know is wrong, but I did it), the result was also 197 KB, the same size as the NU-Wave output, and it had the same spectrogram. That's when I knew something was wrong and NU-Wave wasn't really working properly.

    P.S. During inference there were also some warnings, but they didn't stop it. The command was:
    !python sampling.py -c last.ckpt -f sample_0.wav
    and it printed:
    /usr/local/lib/python3.7/dist-packages/torch/functional.py:472: UserWarning: stft will soon require the return_complex parameter be given for real inputs, and will further require that return_complex=True in a future PyTorch release. (Triggered internally at /pytorch/aten/src/ATen/native/SpectralOps.cpp:664.)
      normalized, onesided, return_complex)
    /usr/local/lib/python3.7/dist-packages/torch/functional.py:546: UserWarning: istft will require a complex-valued input tensor in a future PyTorch release. Matching the output from stft with return_complex=True. (Triggered internally at /pytorch/aten/src/ATen/native/SpectralOps.cpp:817.)
      normalized, onesided, length, return_complex)

    I hope you can reply as soon as possible, as this problem of the file claiming to be 48 kHz without sounding like it is driving me crazy :)

    Have a nice day/night

    opened by dutchsing009 10
  • Contribution: checkpoints AVAILABLE!

    Hi guys, first I would like to thank @junjun3518 for the excellent work of developing and sharing the code. I trained the model following the paper's settings for two weeks on a V100 GPU, using ratio=2 and ratio=3. I would like to contribute to the project by sharing the checkpoints. The download links are below.

    nuwave x2: https://drive.google.com/file/d/1pegayKs-i78yWlPuLIp-BCU8KxxCpBzd/view?usp=sharing

    nuwave x3: https://drive.google.com/file/d/12RUMjEALAs0EoEw6Fqf9ZkpTm3COX6sf/view?usp=sharing

    The following are images of the training logs.

    nuwave x2: [training log plots: epoch, loss, val_loss]

    nuwave x3: [training log plots: epoch, loss, val_loss]

    opened by freds0 7
  • robustness issue: use of methods marked for deprecation

    Hi,

    1. Do you have pretrained models? I don't see them linked on GitHub, and it would be great to have those.

    2. To test your model I'm retraining it, and I noticed a couple of easy fixes that would make the code robust to current libraries (librosa 0.9 and pytorch-lightning 1.4). I get that you put older versions (librosa 0.8 and pytorch-lightning 1.1.6) in the requirements, but the affected calls were already marked for deprecation, and since my environment was already built I didn't want to install older libraries. So, for your consideration only: you may want to keep the old code, but it doesn't work for me. I forked the repo, and while I don't know whether all processes are running correctly, it seems to be training alright.

    file: nuwave/utils/wav2pt.py: in librosa 0.9.0, effects.trim() requires keyword arguments for all but the first argument. Minimal change:

    rosa.effects.trim(y, top_db=15)   
    

    file: nuwave/trainer.py: pytorch-lightning has the terrible habit of deprecating and renaming. I think these changes should also work in the older version, since the old arguments were already slated for deprecation. From the CHANGELOG:

      • (#5321) Removed deprecated checkpoint argument filepath. Use dirpath + filename instead.
      • (#6162) Removed deprecated ModelCheckpoint arguments prefix, mode="auto".

        checkpoint_callback = ModelCheckpoint(dirpath=hparams.log.checkpoint_dir,
                                              filename=ckpt_path,
                                              verbose=True,
                                              save_last=True,
                                              save_top_k=3,
                                              monitor='val_loss',
                                              mode='min')
    

    The Trainer() class does not accept the checkpoint_callback kwarg (#9754: deprecate checkpoint_callback from the Trainer constructor in favour of enable_checkpointing).

        trainer = Trainer(
            checkpoint_callback=True,
            gpus=hparams.train.gpus,
            accelerator='ddp' if hparams.train.gpus > 1 else None,
            #plugins='ddp_sharded',
            amp_backend='apex',  #
            amp_level='O2',  #
            #num_sanity_val_steps = -1,
            check_val_every_n_epoch=2,
            gradient_clip_val = 0.5,
            max_epochs=200000,
            logger=tblogger,
            progress_bar_refresh_rate=4,
            callbacks=[
                EMACallback(os.path.join(hparams.log.checkpoint_dir,
                            f'{hparams.name}_epoch={{epoch}}_EMA')),
                            checkpoint_callback
                      ],
            resume_from_checkpoint=None
            if args.resume_from == None or args.restart else sorted(
                glob(
                    os.path.join(hparams.log.checkpoint_dir,
                                 f'*_epoch={args.resume_from}.ckpt')))[-1])
    

    (#11578) Deprecated Callback.on_epoch_end hook in favour of Callback.on_{train/val/test}_epoch_end

        @rank_zero_only
        def on_train_epoch_end(self, trainer, pl_module):
            self.queue.append(trainer.current_epoch)
            ...
    
    opened by xvdp 0
  •  How do I upsample a wav once the model is trained?

    I think this is a very interesting project and I'd like to test it but I'm not a data scientist. I see how to train and test but I don't see any examples of how to use it. I looked through the code but didn't see anything that gave me a clear indication on how to use it. How do I upsample a wav once the model is trained?

    opened by go-dustin 3
  • Different sampling rates

    Hi!

    Did you try training with sampling rates other than those on the demo page, such as 8K->16K, 8K->22K, or 16K->22K?

    And what changes should we make to train on such data? (maybe hop length, n_fft, noise_schedule, pos_emb_scale, etc.)

    opened by EmreOzkose 26
Owner
MINDs Lab
MINDsLab provides an AI platform and various AI engines based on deep machine learning.