PyTorch port of Google Research's LEAF Audio paper



PyTorch port of Google Research's LEAF Audio paper, published at ICLR 2021.

This port is not completely finished, but the Leaf() frontend is fully ported, functional, and validated to produce outputs similar to those of the original TensorFlow implementation. A few smaller components are still missing, such as the SincNet and SincNet+ implementations and some of the pooling layers.

PLEASE leave issues, pull requests, comments, or anything else you find while using this repository that may be of value to others who try to use it.

Installation

From the root directory of this repo, run:

pip install -e .

Usage

leaf_audio_pytorch mirrors its original repository; imports and arguments are the same.

import leaf_audio_pytorch.frontend as frontend

leaf = frontend.Leaf()
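Leaf() works with channels-first tensors, so audio stored as (batch_size, num_samples, 1), which is the shape used in the repo's test code, has to be reordered to (batch_size, 1, num_samples) before the forward pass. In PyTorch that reordering is tensor.permute(0, 2, 1). Below is a minimal NumPy-only sketch of the same step. The name to_channels_first is a hypothetical helper, not part of leaf_audio_pytorch.

```python
import numpy as np


def to_channels_first(audio: np.ndarray) -> np.ndarray:
    """Reorder (batch, num_samples, channels) audio to (batch, channels, num_samples).

    Hypothetical helper; on a torch tensor the same reordering is
    tensor.permute(0, 2, 1).
    """
    if audio.ndim != 3:
        raise ValueError(f"expected a 3-D array, got shape {audio.shape}")
    # np.transpose returns a strided view; ascontiguousarray turns it into a
    # compact C-ordered copy.
    return np.ascontiguousarray(np.transpose(audio, (0, 2, 1)))


# (batch_size, num_samples, 1), the shape used in the repo's tests
test_audio = np.random.random((8, 15000, 1)).astype(np.float32)
channels_first = to_channels_first(test_audio)
print(channels_first.shape)  # (8, 1, 15000)
```

The resulting array can then be wrapped with torch.from_numpy and passed to Leaf().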

Installation for Developing

If you want to develop on this repo, requirements.txt contains everything needed to run the PyTorch and TensorFlow implementations of LEAF side by side.

NOTE: The original leaf-audio repo has some dependency conflicts, which appear to involve lingvo and waymo-open-dataset. The commands below are a workaround.

Install the required packages:

pip install -r requirements.txt --no-deps

Install the leaf-audio repo from Git SSH:

pip install git+ssh:// --no-deps

Then install the leaf_audio_pytorch package as well:

python develop

At this point everything should be good to go! The scripts in test/ contain testing code that validates that the PyTorch implementation mirrors the TensorFlow one.

Some Things to Keep in Mind (PLEASE READ)

  • When writing this port, I ran the torch and tf implementations side by side in a debugger and validated that each layer and operation mirrors the TensorFlow implementation to within a few significant digits (i.e., a tensor's values may vary by about 0.001). There is one notable exception: the depthwise convolution within the GaussianLowpass pooling layer shows a larger variation in tensor values, although the ported operation still produces similar outputs. I'm not sure why this operation produces different values, but I'm currently looking into it. Please do your own due diligence when using this port and make sure it works as expected.

  • As of March 29, I finished the initial version of the port, but I have not tested Leaf() in a training setting yet. Calling .backward() on Leaf() raises no errors, so backpropagation runs end to end. However, running without errors does not guarantee that gradients reach every parameter, and I do not yet know how Leaf() will behave during training.

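  • Because PyTorch and TensorFlow follow different tensor ordering conventions (PyTorch convolutions expect channels first, while TensorFlow/Keras defaults to channels last), Leaf() performs all of its operations and outputs its tensors in channels-first order.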
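The first note above describes checking each PyTorch layer against its TensorFlow counterpart to within about 0.001. Below is a minimal sketch of such a check. It assumes both outputs have already been converted to NumPy arrays in the same channels-first layout. The helper compare_outputs, the 1e-3 default tolerance, and the synthetic data are illustrative choices, not part of the repo's test scripts.

```python
import numpy as np


def compare_outputs(torch_out: np.ndarray, tf_out: np.ndarray, atol: float = 1e-3):
    """Return (max_abs_diff, within_tolerance) for two layer outputs.

    Hypothetical helper: both arrays must share the same layout, e.g. a
    TensorFlow channels-last output transposed to channels first beforehand.
    """
    if torch_out.shape != tf_out.shape:
        raise ValueError(f"shape mismatch: {torch_out.shape} vs {tf_out.shape}")
    max_abs_diff = float(np.max(np.abs(torch_out - tf_out)))
    # rtol=0 so the test is a pure absolute-difference bound
    within = bool(np.allclose(torch_out, tf_out, rtol=0.0, atol=atol))
    return max_abs_diff, within


rng = np.random.default_rng(0)
reference = rng.standard_normal((8, 40, 1500)).astype(np.float32)
# Simulate a ported layer whose values drift by at most 5e-4
ported = reference + rng.uniform(-5e-4, 5e-4, reference.shape).astype(np.float32)
print(compare_outputs(ported, reference))
```

Reporting the maximum absolute difference, not just a pass/fail flag, makes it easier to spot a layer that behaves like the GaussianLowpass depthwise convolution, whose deviation is larger than the rest.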


All credit and attribution go to Neil Zeghidour and the Google Research team, who wrote the paper and created the TensorFlow implementation.

Please visit their GitHub repository and review their ICLR publication.

  • Backpropagation doesn't change weights of the convolutional layer


    Thank you for creating this GitHub repo! I was wondering why the forward call of the GaborConstraint class uses "with torch.no_grad()" (line 28 in leaf_audio_pytorch/). Because of this, the convolutional layer's weights appear not to change during backpropagation; after removing it, I see the weights update. Is there a reason we need torch.no_grad()? Attached is sample code that compares the Gabor layer parameters before and after training on a toy problem.


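    import os
    os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # set before CUDA is initialized

    import torch
    from leaf_audio_pytorch import frontend

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    a = frontend.Leaf().to(device)
    # A random input that resembles an audio waveform: (batch, channels, samples)
    x = torch.rand(1, 1, 97339, device=device)
    opt = torch.optim.Adam(a.parameters(), lr=0.002, betas=(0.9, 0.999), eps=1e-08,
                           weight_decay=1e-6, amsgrad=True)
    # Snapshot every parameter so we can see which ones training changes
    before = {name: p.detach().clone() for name, p in a.named_parameters()}
    for _ in range(100):
        opt.zero_grad()
        b = a(x)
        # Ground-truth target for the Leaf output
        b_gt = torch.ones_like(b)
        loss = torch.norm((b - b_gt).reshape(-1), p=2)
        loss.backward()
        opt.step()
    for name, p in a.named_parameters():
        print(name, "changed:", not torch.equal(before[name], p.detach()))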


    opened by sahu-ji
  • Can't run inference - RuntimeError: unsupported input to tensordot, got dims=0

    When I run inference (for example, the one provided in the tests):

    import numpy as np
    import torch
    import leaf_audio_pytorch.frontend as torch_frontend

    py_leaf = torch_frontend.Leaf().cuda()
    # (batch_size, num_samples, 1)
    test_audio = np.random.random((8, 15000, 1)).astype(np.float32)
    # convert to channels first for PyTorch: (batch_size, 1, num_samples)
    t_audio = torch.from_numpy(test_audio).permute(0, 2, 1).cuda()
    print(py_leaf(t_audio))

    I get this error:

    RuntimeError                              Traceback (most recent call last)
    <ipython-input-27-eeba248942f0> in <module>
    ----> 1 print(py_leaf(t_audio))
    ~/Envs/leaf-audio-pytorch/lib/python3.7/site-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
       1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1050                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1051             return forward_call(*input, **kwargs)
       1052         # Do not call functions when jit is used
       1053         full_backward_hooks, non_full_backward_hooks = [], []
    ~/src/leaf-audio-pytorch/leaf_audio_pytorch/ in forward(self, x)
        131             outputs = outputs[:,:,1:]
    --> 133         outputs = self._complex_conv(outputs)
        134         outputs = self._activation(outputs)
        135         outputs = self._pooling(outputs)
    ~/Envs/leaf-audio-pytorch/lib/python3.7/site-packages/torch/nn/modules/ in _call_impl(self, *input, **kwargs)
       1049         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
       1050                 or _global_forward_hooks or _global_forward_pre_hooks):
    -> 1051             return forward_call(*input, **kwargs)
       1052         # Do not call functions when jit is used
       1053         full_backward_hooks, non_full_backward_hooks = [], []
    ~/src/leaf-audio-pytorch/leaf_audio_pytorch/ in forward(self, x)
         74             kernel = torch.gather(kernel, dim=0, index=filter_order)
    ---> 76         filters = impulse_responses.gabor_filters(kernel, self._kernel_size, self.gabor_filter_init_t)
         77         real_filters = torch.real(filters)
         78         img_filters = torch.imag(filters)
    ~/src/leaf-audio-pytorch/leaf_audio_pytorch/ in gabor_filters(kernel, size, t_tensor)
         33     return gabor_impulse_response(
         34         t_tensor,
    ---> 35         center=kernel[:, 0], fwhm=kernel[:, 1])
    ~/src/leaf-audio-pytorch/leaf_audio_pytorch/ in gabor_impulse_response(t, center, fwhm)
          9     """Computes the gabor impulse response."""
         10     denominator = 1.0 / (np.sqrt(2.0 * math.pi) * fwhm)
    ---> 11     gaussian = torch.exp(torch.tensordot(1.0 / (2. * fwhm**2), -t**2, dims=0)) # TODO: validate the dims here
         12     center_frequency_complex = center.type(torch.complex64)
         13     t_complex = t.type(torch.complex64)
    ~/Envs/leaf-audio-pytorch/lib/python3.7/site-packages/torch/ in tensordot(a, b, dims, out)
        928     if len(dims_a) == 0 or len(dims_b) == 0:
    --> 929         raise RuntimeError(f"unsupported input to tensordot, got dims={dims}")
        931     if out is None:
    RuntimeError: unsupported input to tensordot, got dims=0
    opened by dariocazzani
  • Why use complex number in convolution

    Hi @denfed, the paper says "we compute instead the convolution with 2N real-valued filters φ̃_n, n = 1, ..., 2N, and perform squared ℓ2-pooling". I am wondering why the code still uses complex numbers when initializing the filters.

    Thanks for the great work! Junjie

    opened by jjjjohnson
  • Remove torch.no_grad call that breaks backprop on the complex conv

    Thanks to @sahu-ji, this bug fix makes the complex conv backpropagatable. I'm not sure exactly why this was a bug, because none of the operations inside the torch.no_grad() block in GaborConstraint are learnable; it is only a clipping (i.e., constraint) operation. Nonetheless, it did break backprop. Not anymore :)

    opened by denfed
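The "Can't run inference" traceback above fails at torch.tensordot(..., dims=0), which the reporter's PyTorch version rejects. With dims=0, tensordot is simply an outer product, so the same tensor can be built with broadcasting (a[:, None] * b[None, :]) or with torch.outer(a, b). The sketch below applies broadcasting to the Gabor impulse response. It uses NumPy so it runs without PyTorch, and the same indexing works unchanged on torch tensors. Two things are assumptions rather than a copy of the repo's code: the complex carrier exp(i * center * t), which is suggested by the center_frequency_complex and t_complex casts in the traceback, and the sample values.

```python
import math

import numpy as np


def gabor_impulse_response(t: np.ndarray, center: np.ndarray, fwhm: np.ndarray) -> np.ndarray:
    """Gabor impulse responses with shape (len(center), len(t)), one row per filter.

    Sketch that replaces tensordot(..., dims=0) with broadcasting.
    """
    denominator = 1.0 / (math.sqrt(2.0 * math.pi) * fwhm)  # (n_filters,)
    # Outer product via broadcasting: (n_filters, 1) * (1, n_taps)
    gaussian = np.exp((1.0 / (2.0 * fwhm**2))[:, None] * (-t**2)[None, :])
    # Assumed carrier: complex sinusoid exp(i * center * t), also an outer product
    sinusoid = np.exp(1j * center[:, None] * t[None, :])
    return denominator[:, None] * sinusoid * gaussian


t = np.arange(-200, 201, dtype=np.float64)  # 401 taps centred on t = 0
center = np.array([0.1, 0.5, 1.0])           # centre frequencies (radians/sample)
fwhm = np.array([20.0, 10.0, 5.0])           # Gaussian widths (samples)

filters = gabor_impulse_response(t, center, fwhm)
print(filters.shape)  # (3, 401)

# The broadcast Gaussian matches NumPy's tensordot(..., axes=0) outer product
outer = np.tensordot(1.0 / (2.0 * fwhm**2), -t**2, axes=0)
expected_magnitude = (1.0 / (math.sqrt(2.0 * math.pi) * fwhm))[:, None] * np.exp(outer)
print(np.allclose(np.abs(filters), expected_magnitude))  # True
```

Broadcasting avoids relying on tensordot's handling of dims=0, so the computation does not depend on whether a given PyTorch version supports that case.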
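On the "Why use complex number in convolution" question: for a real-valued waveform, convolving with one complex filter and taking the squared modulus gives the same result as two separate steps. First convolve with the filter's real part and with its imaginary part, giving two real filters per complex one (hence 2N real filters for N complex filters). Then sum the squares, which is the squared ℓ2-pooling the paper describes. Building the filters as complex numbers and then splitting them, as the torch.real(filters) and torch.imag(filters) lines in the traceback above suggest, is therefore consistent with the paper. Below is a NumPy check of that identity, using a random signal and filter chosen purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal(1000)                                # real-valued waveform
h = rng.standard_normal(65) + 1j * rng.standard_normal(65)  # one complex filter

# Option 1: a single complex convolution, then the squared modulus
complex_energy = np.abs(np.convolve(x, h, mode="same")) ** 2

# Option 2: two real convolutions (real and imaginary parts), squared and summed
real_part = np.convolve(x, h.real, mode="same")
imag_part = np.convolve(x, h.imag, mode="same")
pooled_energy = real_part**2 + imag_part**2

print(np.allclose(complex_energy, pooled_energy))  # True
```

The identity holds because convolution is linear and x is real, so x * (h_r + i h_i) = (x * h_r) + i (x * h_i).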
Dennis Fedorishin
UB | Computer Science PhD Candidate
