Autoregressive Models in PyTorch.

Overview

This repository contains all the necessary PyTorch code, tailored to my presentation, to train and generate data from WaveNet-like autoregressive models.

For presentation purposes, the WaveNet-like models are applied to randomized Fourier series (1D) and MNIST (2D). In the figure below, two WaveNet-like models with different training settings make an n-step prediction on a periodic time-series from the validation dataset.

Advanced functions show how to generate MNIST images and how to progressively estimate the MNIST digit class p(y=class|x) from observed pixels, using a conditional WaveNet p(x|y=class) and Bayes' rule. Left: sampled MNIST digits; right: progressive class estimates as more pixels are observed.
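
The class estimate follows directly from Bayes' rule: p(y=c | x_1..t) ∝ p(x_1..t | y=c) p(y=c), where the conditional WaveNet supplies the per-step likelihoods p(x_t | x_<t, y=c). The sketch below only illustrates this computation; it is not the repository's API and assumes you have already collected per-class, per-step log-likelihoods from the conditional model.

import torch
import torch.nn.functional as F

def progressive_class_posterior(step_log_likelihoods: torch.Tensor) -> torch.Tensor:
    # step_log_likelihoods: (C, T), entry (c, t) = log p(x_t | x_<t, y=c)
    # The cumulative sum gives the prefix log-likelihood log p(x_1..t | y=c).
    prefix_ll = step_log_likelihoods.cumsum(dim=-1)
    # Bayes' rule with a uniform prior p(y): the posterior over classes is a
    # softmax of the prefix log-likelihoods, one distribution per time step.
    return F.softmax(prefix_ll, dim=0)  # (C, T), p(y=c | x_1..t)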

Note that this library does not implement (Gated) PixelCNNs; instead, it unrolls images into 1D sequences so they can be processed by WaveNet architectures. This works surprisingly well.
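
To make the unrolling concrete, the hypothetical helper below (not part of the library) flattens 2D images in raster-scan order so that each pixel becomes one step of a 1D sequence a WaveNet can consume.

import torch

def unroll_images(images: torch.Tensor) -> torch.Tensor:
    # images: (B, C, H, W) -> sequences: (B, C, H*W), row-major (raster-scan) order
    B, C, H, W = images.shape
    return images.reshape(B, C, H * W)

x = torch.rand(8, 1, 28, 28)            # a batch of MNIST-sized images
print(unroll_images(x).shape)           # torch.Size([8, 1, 784]) -> 784-step sequences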

Features

Currently, the following features are implemented:

  • WaveNet architecture and training as proposed in (oord2016wavenet); a simplified sketch of the core building block follows this list
  • Conditioning support (oord2016wavenet)
  • Fast generation based on (paine2016fast)
  • Fully differentiable n-step unrolling in training (heindl2021autoreg)
  • 2D image generation, completion, classification, and progressive classification support based on MNIST dataset
  • A randomized Fourier dataset
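
As referenced in the feature list, the sketch below illustrates the dilated causal convolution at the heart of a WaveNet stack. It is a simplified stand-in, not the repository's implementation, and omits the gated activations, residual and skip connections of the full model (oord2016wavenet).

import torch
import torch.nn as nn
import torch.nn.functional as F

class CausalConv1d(nn.Module):
    # A 1D convolution that only sees past samples, implemented by left padding.
    def __init__(self, in_ch, out_ch, kernel_size=2, dilation=1):
        super().__init__()
        self.pad = (kernel_size - 1) * dilation
        self.conv = nn.Conv1d(in_ch, out_ch, kernel_size, dilation=dilation)

    def forward(self, x):  # x: (B, C, T)
        return self.conv(F.pad(x, (self.pad, 0)))

# Dilations 1, 2, 4, ... double the receptive field with each layer, which is
# how WaveNet covers long contexts with comparatively few layers.
stack = nn.Sequential(*[CausalConv1d(16, 16, dilation=2**i) for i in range(6)])
print(stack(torch.zeros(1, 16, 100)).shape)  # torch.Size([1, 16, 100])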

Presentation

A detailed presentation with theoretical background, architectural considerations and experiments can be found below.

The presentation source as well as all generated images are public domain. In case you find them useful, please leave a citation (see References below). All presentation sources can be found in etc/presentation. The presentation is written in markdown using Marp, graph diagrams are created using yEd.

If you spot errors or have suggestions for improvement, please let me know by opening an issue.

Installation

To install, run

pip install git+https://github.com/cheind/autoregressive.git#egg=autoregressive[dev]

which requires Python 3.9 and a recent PyTorch > 1.9

Usage

The library comes with a set of pre-trained models in models/. The following commands use those models to make various predictions. Many listed commands come with additional parameters; use --help to get additional information.

1D Fourier series

Sample new signals from scratch

python -m autoregressive.scripts.wavenet_signals sample --config "models/fseries_q127/config.yaml" --ckpt "models/fseries_q127/xxxxxx.ckpt" --condition 4 --horizon 1000

The default model conditions on the periodicity of the signal. For the pre-trained model, the condition is an integer in the range [0..4], corresponding to periods of 5-10 seconds.


Predict the shape of partially observable curves.

python -m autoregressive.scripts.wavenet_signals predict --config "models/fseries_q127/config.yaml" --ckpt "models/fseries_q127/xxxxxx.ckpt" --horizon 1500 --num_observed 50 --num_trajectories 20 --num_curves 1 --show_confidence true

2D MNIST

To sample from the class-conditional model

python -m autoregressive.scripts.wavenet_mnist sample --config "models/mnist_q2/config.yaml" --ckpt "models/mnist_q2/xxxxxx.ckpt"

Generate images conditioned on the digit class and observed pixels.

python -m autoregressive.scripts.wavenet_mnist predict --config "models/mnist_q2/config.yaml" --ckpt "models/mnist_q2/xxxxxx.ckpt" 

To perform classification

python -m autoregressive.scripts.wavenet_mnist classify --config "models/mnist_q2/config.yaml" --ckpt "models/mnist_q2/xxxxxx.ckpt"

Train

To train / reproduce a model

python -m autoregressive.scripts.train fit --config "models/mnist_q2/config.yaml"

Progress is logged to TensorBoard:

tensorboard --logdir lightning_logs

To generate a training configuration file for a specific dataset use

python -m autoregressive.scripts.train fit --data autoregressive.datasets.FSeriesDataModule --print_config > fseries_config.yaml

Test

To run the tests

pytest

References

@misc{heindl2021autoreg, 
  title={Autoregressive Models}, 
  journal={PROFACTOR Journal Club}, 
  author={Heindl, Christoph},
  year={2021},
  howpublished={\url{https://github.com/cheind/autoregressive}}
}

@article{oord2016wavenet,
  title={Wavenet: A generative model for raw audio},
  author={Oord, Aaron van den and Dieleman, Sander and Zen, Heiga and Simonyan, Karen and Vinyals, Oriol and Graves, Alex and Kalchbrenner, Nal and Senior, Andrew and Kavukcuoglu, Koray},
  journal={arXiv preprint arXiv:1609.03499},
  year={2016}
}

@article{paine2016fast,
  title={Fast wavenet generation algorithm},
  author={Paine, Tom Le and Khorrami, Pooya and Chang, Shiyu and Zhang, Yang and Ramachandran, Prajit and Hasegawa-Johnson, Mark A and Huang, Thomas S},
  journal={arXiv preprint arXiv:1611.09482},
  year={2016}
}

@article{oord2016conditional,
  title={Conditional image generation with pixelcnn decoders},
  author={Oord, Aaron van den and Kalchbrenner, Nal and Vinyals, Oriol and Espeholt, Lasse and Graves, Alex and Kavukcuoglu, Koray},
  journal={arXiv preprint arXiv:1606.05328},
  year={2016}
}
Comments
  • Switch to quantized encoding and one-hot

    discussion https://github.com/ibab/tensorflow-wavenet/issues/83 https://github.com/ibab/tensorflow-wavenet/issues/219

    Also, we could then use the Gumbel trick to reparametrize the categorical distribution and allow unrolling during training.

    In issue 83 above, note that they mention a larger kernel for the initial convolution.

    opened by cheind 4
  • add image support

    Early support in the image-support branch (v48, on GPUs). Seems to be quite a nice model so far: R=347, global condition=class.

    Top row: observed 392 px; bottom row: generated 392 px (wavenet-mnist-r347-o392-g392)

    Top row: observed 200 px; bottom row: generated 584 px (wavenet-mnist-r347-o200-g584)

    Top row: observed 1 px; bottom row: generated 783 px (wavenet-mnist-r347-o1-g783b)

    Top row: observed 200 px; bottom row: generated 584 px (wavenet-mnist-r347-o200-g584b)

    opened by cheind 3
  • add positional encoding

    An attempt to tell the model where it is in the sequence. A good start would be the Transformer positional-encoding procedure.

    See e.g https://kazemnejad.com/blog/transformer_architecture_positional_encoding/

    We could add the positional embedding as local conditioning, which would have to be supported first.

    opened by cheind 3
  • validation should use an accuracy measure instead of loss

    Consider DTW/FastDTW to account for phase shifts.

    https://medium.com/walmartglobaltech/time-series-similarity-using-dynamic-time-warping-explained-9d09119e48ec https://github.com/slaypni/fastdtw/blob/master/fastdtw/fastdtw.py

    opened by cheind 3
  • add Gumbel-noise based sampler

    This allows us to employ the reparametrization trick on categoricals and hence backprop through multiple steps that require sampling; see https://www.youtube.com/watch?v=JFgXEbgcT7g. A minimal Gumbel-softmax sketch is given after the comments list below.

    opened by cheind 2
  • Add conditioning

    The paper also describes local and global conditioning. In time-series prediction we can treat these similarly to positional encoding, i.e. season, time, etc. Would this be similar to feeding in time as a second channel input?

    https://stats.stackexchange.com/questions/388130/wavenet-global-and-local-conditioning

    opened by cheind 2
  • plot weight histograms

    if self.histograms:
        layer = 'layer{}'.format(layer_index)
        tf.histogram_summary(layer + '_filter', weights_filter)
        tf.histogram_summary(layer + '_gate', weights_gate)
        tf.histogram_summary(layer + '_dense', weights_dense)
        tf.histogram_summary(layer + '_skip', weights_skip)

    opened by cheind 1
  • config.yaml files contain a `fit:` root element

    This seems to cause problems when trying to actually perform a fit:

    python -m autoregressive.scripts.train fit --config models\fseries_q127\config.yaml
    usage: train.py [-h] [--config CONFIG] [--print_config [={comments,skip_null}+]] {fit,validate,test,predict,tune} ...
    train.py: error: Configuration check failed :: Key "fit.data" is required but not included in config object or its value is None.

    However,

    python -m autoregressive.scripts.train --config models\fseries_q127\config.yaml fit

    works.

    With the most current version,

    python -m autoregressive.scripts.train fit --data autoregressive.datasets.MNISTDataModule --print_config > config.yaml

    does not generate a `fit:` root element.
    opened by cheind 0
  • quantizer theory

    https://colab.research.google.com/github/spatialaudio/digital-signal-processing-lecture/blob/master/quantization/linear_uniform_characteristic.ipynb#scrollTo=Oc-w6mUgXiiP

    opened by cheind 0
  • What is meant by global conditioning only?

    I am attempting to use your wavenet implementation to model some climate data, where my condition vector changes with time. The code mentions only global conditioning is currently supported. What exactly does this mean from an architecture perspective?

    opened by btickell 1
  • Reconsider the softmax distribution

    PixelCNN++ argues that the tendency of neighboring intensities to correlate is not captured by the softmax distribution. Instead, they propose a mixture of logistic distributions (similar to a normal but with heavier tails). https://arxiv.org/pdf/1701.05517.pdf

    opened by cheind 2
  • Decouple input and output representations

    Currently we assume that what we get as input is what we will predict as output (just shifted). However, thinking towards other research areas, it might make sense to rework this more generally:

    model
      input: BxIxT
      output: BxQxT
    

    where I might match Q but does not have to. In training we would then have code like the following:

    def training_step(self, batch):
        inputs = batch['x']                  # (B, I, T)
        if 't' in batch:
            targets = batch['t']             # allows us to provide alternative targets
        elif I == Q:                         # input channels match output quantization levels
            targets = inputs[..., 1:]        # default next-step prediction targets
            inputs = inputs[..., :-1]
        else:
            raise ValueError(...)

        logits = self.forward(inputs)        # (B, Q, T)
        loss = ce(logits, targets)           # cross-entropy over the Q output levels
        return loss
    

    What's more, we need to think about input transforms. Currently, one-hot encoding is hardwired into the model. We might instead consider a differentiable input_transform that is passed to the model upon initialization. This would allow us to use differentiable embedding strategies.

    opened by cheind 2
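
Regarding the Gumbel-noise sampler discussed in the comments above: a minimal sketch, assuming PyTorch's built-in gumbel_softmax (this is not the repository's sampler), of how a categorical sample can stay differentiable so that training can be unrolled through steps that require sampling.

import torch
import torch.nn.functional as F

logits = torch.randn(4, 127, requires_grad=True)   # logits over Q=127 quantization levels

# Gumbel-softmax sample: (near) one-hot, but differentiable w.r.t. the logits.
# hard=True returns exact one-hot vectors with a straight-through gradient.
sample = F.gumbel_softmax(logits, tau=1.0, hard=True)

values = torch.linspace(-1.0, 1.0, 127)             # hypothetical quantization bin centres
loss = (sample * values).sum()                      # any differentiable loss on the sample
loss.backward()
print(sample.shape, logits.grad.abs().sum() > 0)    # torch.Size([4, 127]) tensor(True)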