AutoDSP

TL;DR: Train custom adaptive filter optimizers without hand tuning or extra labels.

About

Adaptive filtering algorithms are commonplace in signal processing and have wide-ranging applications from single-channel denoising to multi-channel acoustic echo cancellation and adaptive beamforming. Such algorithms typically operate via specialized online, iterative optimization methods and have achieved tremendous success, but require expert knowledge, are slow to develop, and are difficult to customize. In our work, we present a new method to automatically learn adaptive filtering update rules directly from data. To do so, we frame adaptive filtering as a differentiable operator and train a learned optimizer to output a gradient descent-based update rule from data via backpropagation through time. We demonstrate our general approach on an acoustic echo cancellation task (single-talk with noise) and show that we can learn high-performing adaptive filters for a variety of common linear and non-linear multidelayed block frequency domain filter architectures. We also find that our learned update rules exhibit fast convergence, can optimize in the presence of nonlinearities, and are robust to acoustic scene changes despite never encountering any during training.
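
To make the framing concrete, below is a minimal, self-contained sketch of the learned-optimizer idea (illustrative only, not the paper's exact architecture: the paper uses an RNN optimizer over per-frequency gradient features of block frequency domain filters). A small parametric update rule is applied to the filter's gradient, and its parameters are meta-trained by unrolling the filter through time and backpropagating through the unroll.

# Minimal sketch of learning an adaptive filter update rule by unrolling.
import jax
import jax.numpy as jnp

def filter_loss(w, u, d):
    # Squared prediction error of a length-L FIR filter w on input frame u.
    return jnp.mean((d - u @ w) ** 2)

def learned_update(theta, g):
    # A tiny "learned optimizer": a parametric transform of the gradient.
    return -(theta[0] * g + theta[1] * jnp.tanh(g))

def unrolled_loss(theta, w0, us, ds):
    # Run the filter for T steps under the learned rule and average the loss;
    # differentiating this w.r.t. theta is backpropagation through time.
    def step(w, frame):
        u, d = frame
        g = jax.grad(filter_loss)(w, u, d)
        w = w + learned_update(theta, g)
        return w, filter_loss(w, u, d)
    _, losses = jax.lax.scan(step, w0, (us, ds))
    return jnp.mean(losses)

# Meta-gradient used to train the optimizer parameters theta.
meta_grad_fn = jax.jit(jax.grad(unrolled_loss))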

arXiv: https://arxiv.org/abs/2110.04284

pdf: https://arxiv.org/pdf/2110.04284.pdf

Short video: https://www.youtube.com/watch?v=y51hUaw2sTg

Full video: https://www.youtube.com/watch?v=oe0owGeCsqI

Table of contents

- Setup
- Running an Experiment
- Copyright and license

Setup

Clone repo

git clone <autodsp repo URL>
cd autodsp

Get The Data

# Install Git LFS if needed
git lfs install

# Move into the folder that is one above the autodsp repo
cd <autodsp repo>/../

# Clone MS data
git clone https://github.com/microsoft/AEC-Challenge AEC-Challenge


    
   

Configure Environment

First, edit the config file to point to the dataset you downloaded.

vim ./autodsp/__config__.py
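
For example, the edit might look like the line below (the variable name and path are illustrative; use whatever __config__.py actually defines):

# autodsp/__config__.py -- point this at your local AEC-Challenge clone.
AEC_DATA_DIR = '/path/to/AEC-Challenge'  # illustrative name and path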

Next, set up your Anaconda environment

# Create a conda environment
conda create -n autodsp python=3.7

# Activate the environment
conda activate autodsp

# Install some tools
conda install -c conda-forge cudnn pip

# Install JAX
pip install --upgrade "jax[cuda111]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

# Install Haiku
pip install git+https://github.com/deepmind/dm-haiku

# Install pytorch for the dataloader
conda install pytorch cpuonly -c pytorch

You can also check out autodsp.yaml, the export from our conda environment. We found the most common culprit for JAX or CUDA errors was a CUDA/cuDNN version mismatch. You can find more details on this in the official JAX repo: https://github.com/google/jax.
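
A quick way to confirm that JAX was installed against a working CUDA/cuDNN stack is to check that it can see your GPUs:

# Should list GPU devices; falling back to CPU usually means a version mismatch.
import jax
print(jax.devices())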

Install AutoDSP

cd autodsp
pip install -e ./

This will automatically install the dependencies listed in setup.py.
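
You can sanity-check the install by importing the package:

# Should run without errors once the editable install succeeds.
import autodsp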

Running an Experiment

# move into the experiment directory
cd experiments

The entry point to train and test models is jax_run.py, which pulls configurations from jax_train_config.py. The general format for launching a training run is

python jax_run.py --cfg <cfg name> --GPUS <GPU numbers>

where <cfg name> is a config specified in jax_train_config.py and <GPU numbers> is a space-separated list of GPU indices like 0 1. You can automatically send logs to Weights and Biases by appending --wandb. A training run will automatically generate a /ckpts/ directory and log checkpoints to it. You can grab a checkpoint and run it on the test set via

python jax_run.py --cfg <cfg name> --GPUS <GPU numbers> --epochs <epoch numbers> --eval

where <cfg name> and <GPU numbers> are the same as in training and <epoch numbers> is a single epoch like 100 or a list of epochs like 100 200 300. Running evaluation will also automatically dump a .pkl file with metrics in the same directory as the checkpoint.

An explicit example is

# run the training
python jax_run.py --cfg v2_filt_2048_1_hop_1024_lin_1e4_log_24h_10unroll_2deep_earlystop_echo_noise \
                --GPUS 0 1 2 3

# run evaluation on the checkpoint from epoch 100
python jax_run.py --cfg v2_filt_2048_1_hop_1024_lin_1e4_log_24h_10unroll_2deep_earlystop_echo_noise \
                --GPUS 0 --eval --epochs 100
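
The evaluation run above writes its metrics to a .pkl file next to the checkpoint; a minimal way to inspect it (the path below is illustrative):

# Load the metrics pickle that evaluation dumps next to the checkpoint.
import pickle

with open('ckpts/<path to metrics>.pkl', 'rb') as f:  # illustrative path
    metrics = pickle.load(f)
print(metrics)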

You can find all the configurations from our paper in the jax_train_config.py file. Training can take up to a couple of days depending on model size, but it will automatically stop when it hits the max epoch count or when validation performance stops improving.
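
New configurations follow the same plain-dict format as the existing ones. A trimmed sketch is below; only scalar hyperparameters are shown, and a real config also supplies the optimizer/optimizee constructors and data generators used by the configs in jax_train_config.py.

# Trimmed sketch of a config entry for jax_train_config.py (keys follow the
# existing configs; a full config also sets the model and dataset callables).
def my_custom_cfg():
    return {
        'sys_length': 2048,      # adaptive filter length in samples
        'hop': 1024,             # block hop size
        'batch_sz': 4,
        'sr': 16000,             # sample rate
        'unroll': 10,            # BPTT unroll length
        'lr': 1e-4,              # learned-optimizer learning rate
        'early_stop_wait': 25,   # epochs without improvement before stopping
        'half_lr_wait': 10,
        'n_epochs': 501,
        'ckpt_save_dir': '../ckpts',
        # ...plus optimizee/optimizer constructors and train/val/test data gens.
    }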

Copyright and license

University of Illinois Open Source License

Copyright © 2021, University of Illinois at Urbana-Champaign. All rights reserved.

Developed by: Jonah Casebeer (1), Nicholas J. Bryan (2), and Paris Smaragdis (1, 2)

1: University of Illinois at Urbana-Champaign

2: Adobe Research

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal with the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

- Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimers.
- Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimers in the documentation and/or other materials provided with the distribution.
- Neither the names of Computational Audio Group, University of Illinois at Urbana-Champaign, nor the names of its contributors may be used to endorse or promote products derived from this Software without specific prior written permission.

THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS WITH THE SOFTWARE.

