leaf-audio-pytorch

Overview

A PyTorch port of Google Research's LEAF Audio paper, published at ICLR 2021.

This port is not completely finished, but the Leaf() frontend is fully ported, functional, and validated to produce outputs similar to the original TensorFlow implementation. A few small pieces are still missing, such as the SincNet and SincNet+ implementations and a few of the alternative pooling layers.

PLEASE leave issues, pull requests, or comments about anything you find while using this repository that may be of value to others.

Installation

From the root directory of this repo, run:

pip install -e .

Usage

leaf_audio_pytorch mirrors its original repository; imports and arguments are the same.

import leaf_audio_pytorch.frontend as frontend

leaf = frontend.Leaf()
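
For example, a minimal forward pass might look like the following sketch. The one-second, 16 kHz input is an arbitrary illustration, and the output is channels first, as discussed under "Some Things to Keep in Mind":

import torch
import leaf_audio_pytorch.frontend as frontend

leaf = frontend.Leaf()

# Channels-first waveform batch: (batch_size, channels, num_samples),
# e.g. one second of 16 kHz audio.
waveform = torch.randn(4, 1, 16000)

# Output is also channels first: (batch_size, num_filters, num_frames).
features = leaf(waveform)
print(features.shape)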

Installation for Developing

If you are looking to develop on this repo, requirements.txt contains everything needed to run the PyTorch and TensorFlow implementations of LEAF simultaneously.

NOTE: The original leaf-audio repo has some dependency conflicts, apparently between lingvo and waymo-open-dataset. The commands below are a workaround.

Install the packages required:

pip install -r requirements.txt --no-deps

Install the leaf-audio repo from Git SSH:

pip install git+ssh://[email protected]/google-research/leaf-audio.git --no-deps

Then add the leaf_audio_pytorch package as well:

python setup.py develop

At this point everything should be good to go! The scripts in test/ contain testing code that validates the torch implementation mirrors the TensorFlow one; a rough sketch of such a check follows.
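
As a rough sketch of what such a comparison can look like (this assumes the TensorFlow leaf_audio package is installed per the steps above; the input shape follows the examples in this repo, and the atol tolerance is an assumption based on the ~0.001 variation noted in the next section):

import numpy as np
import tensorflow as tf
import torch
import leaf_audio.frontend as tf_frontend
import leaf_audio_pytorch.frontend as torch_frontend

# Random waveform batch: (batch_size, num_samples, 1)
audio = np.random.random((8, 15000, 1)).astype(np.float32)

# The TensorFlow frontend consumes the channels-last batch directly.
tf_out = tf_frontend.Leaf()(tf.convert_to_tensor(audio)).numpy()

# The PyTorch frontend expects channels first, so permute before the call.
torch_out = torch_frontend.Leaf()(torch.from_numpy(audio).permute(0, 2, 1))

# Permute the TensorFlow output to channels first and compare loosely.
print(np.allclose(np.transpose(tf_out, (0, 2, 1)),
                  torch_out.detach().numpy(), atol=1e-2))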

Some Things to Keep in Mind (PLEASE READ)

  • When writing this port, I ran debuggers on the torch and tf implementations side by side and validated that each layer and operation mirrors the TensorFlow implementation to within a few significant digits (i.e. a tensor's values may vary by ~0.001). There is one notable exception: the depthwise convolution within the GaussianLowpass pooling layer has a larger variation in tensor values, but the ported operation still produces similar outputs. I'm not sure why this operation produces different values, but I'm currently looking into it. Please do your own due diligence when using this port and make sure it works as expected.

  • As of March 29, I have finished the initial version of the port, but I have not tested Leaf() in a training setting yet. Calling .backward() on Leaf() throws no errors, so gradients propagate, but that alone does not guarantee correct training behavior; I do not yet know how this will function during training.

  • As PyTorch and TensorFlow follow different tensor ordering conventions, Leaf() performs all of its operations and outputs tensors with channels first (see the snippet below).
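
For instance, audio batched channels last, as in the TensorFlow examples, needs a permute before being passed to Leaf() (a minimal sketch):

import torch

# (batch_size, num_samples, channels) -> (batch_size, channels, num_samples)
audio = torch.randn(8, 15000, 1)
audio_channels_first = audio.permute(0, 2, 1)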

Reference

All credit and attribution go to Neil Zeghidour and the Google Research team, who wrote the paper and created the TensorFlow implementation.

Please visit their GitHub repository and review their ICLR publication.


Comments
  • Backpropagation doesn't change weights of the convolutional layer

    Hi,

    Thank you for creating this GitHub! I was wondering why the forward call of the GaborConstraint class has "with torch.no_grad()" (line 28 in leaf_audio_pytorch/convolution.py)? It appears the convolutional layer's weights don't change during backpropagation because of this. However, after removing it I see that the weights update. Is there a reason we need torch.no_grad()? Attached is sample code comparing the Gabor layer parameters before and after training a toy model.

    import torch
    from leaf_audio_pytorch import frontend
    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    a = frontend.Leaf().to(device)
    ## A random input that is like an audio waveform
    x = torch.rand(1,1,97339).cuda(device, non_blocking=True)
    a.train()
    print(a._complex_conv._kernel.data)
    opt = torch.optim.Adam(a.parameters(), lr=0.002, betas=(0.9, 0.999), eps=1e-08, weight_decay=1e-6, amsgrad=True)
    for _ in range(100):
        b = a(x)
        ## Ground truth for the output for leaf module
        b_gt = torch.ones_like(b)
        loss = torch.norm((b-b_gt).view(1,-1),p=2,dim=-1)
        opt.zero_grad()
        loss.backward()
        opt.step()
    print('====================================================')
    print(a._complex_conv._kernel.data)
    

    opened by sahu-ji 3
  • Can't run inference - RuntimeError: unsupported input to tensordot, got dims=0

    When I run inference (for example, the one provided in the tests):

    import numpy as np
    import torch
    import leaf_audio_pytorch.frontend as torch_frontend
    py_leaf = torch_frontend.Leaf().cuda()
    # (batch_size, num_samples, 1)
    test_audio = np.random.random((8,15000,1)).astype(np.float32)
    
    # convert to channel first for pytorch
    t_audio = torch.Tensor(test_audio).permute(0,2,1).cuda()
    print(py_leaf(t_audio))
    

    I get this error:

    RuntimeError                              Traceback (most recent call last)
    <ipython-input-27-eeba248942f0> in <module>
    ----> 1 print(py_leaf(t_audio))
    
    ~/Envs/leaf-audio-pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
       1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1050                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1051             return forward_call(*input, **kwargs)
       1052         # Do not call functions when jit is used
       1053         full_backward_hooks, non_full_backward_hooks = [], []
    
    ~/src/leaf-audio-pytorch/leaf_audio_pytorch/frontend.py in forward(self, x)
        131             outputs = outputs[:,:,1:]
        132 
    --> 133         outputs = self._complex_conv(outputs)
        134         outputs = self._activation(outputs)
        135         outputs = self._pooling(outputs)
    
    ~/Envs/leaf-audio-pytorch/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
       1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1050                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1051             return forward_call(*input, **kwargs)
       1052         # Do not call functions when jit is used
       1053         full_backward_hooks, non_full_backward_hooks = [], []
    
    ~/src/leaf-audio-pytorch/leaf_audio_pytorch/convolution.py in forward(self, x)
         74             kernel = torch.gather(kernel, dim=0, index=filter_order)
         75 
    ---> 76         filters = impulse_responses.gabor_filters(kernel, self._kernel_size, self.gabor_filter_init_t)
         77         real_filters = torch.real(filters)
         78         img_filters = torch.imag(filters)
    
    ~/src/leaf-audio-pytorch/leaf_audio_pytorch/impulse_responses.py in gabor_filters(kernel, size, t_tensor)
         33     return gabor_impulse_response(
         34         t_tensor,
    ---> 35         center=kernel[:, 0], fwhm=kernel[:, 1])
         36 
         37 
    
    ~/src/leaf-audio-pytorch/leaf_audio_pytorch/impulse_responses.py in gabor_impulse_response(t, center, fwhm)
          9     """Computes the gabor impulse response."""
         10     denominator = 1.0 / (np.sqrt(2.0 * math.pi) * fwhm)
    ---> 11     gaussian = torch.exp(torch.tensordot(1.0 / (2. * fwhm**2), -t**2, dims=0)) # TODO: validate the dims here
         12     center_frequency_complex = center.type(torch.complex64)
         13     t_complex = t.type(torch.complex64)
    
    ~/Envs/leaf-audio-pytorch/lib/python3.7/site-packages/torch/functional.py in tensordot(a, b, dims, out)
        927 
        928     if len(dims_a) == 0 or len(dims_b) == 0:
    --> 929         raise RuntimeError(f"unsupported input to tensordot, got dims={dims}")
        930 
        931     if out is None:
    
    RuntimeError: unsupported input to tensordot, got dims=0
    
    opened by dariocazzani 2
  • Why use complex numbers in convolution

    Hi @denfed, the paper says "we compute instead the convolution with 2N real-valued filters φ̃_n, n = 1, . . . , 2N, and perform squared L2-pooling". I am wondering why the code still uses complex numbers when initializing the filters.

    Thanks for the great work! Junjie

    opened by jjjjohnson 1
  • Remove torch.no_grad call that breaks backprop on the complex conv

    Thanks to @sahu-ji, this is a bug fix that makes the complex conv backpropagatable. I'm not sure exactly why this was a bug, because the operations inside the torch.no_grad() in GaborConstraint are not learnable; it is only a clipping (i.e. constraint) operation. Nonetheless, it did break backprop. Not anymore :)

    opened by denfed 0