TTS-GAN: A Transformer-based Time-Series Generative Adversarial Network
This repository contains code from the paper "TTS-GAN: A Transformer-based Time-Series Generative Adversarial Network"
Abstract: Time-series datasets used in machine learning applications are often small, making the training of deep neural network architectures ineffective. For time series, the suite of data augmentation tricks we can use to expand the size of a dataset is limited by the need to preserve the basic properties of the signal. Data generated by a Generative Adversarial Network (GAN) can serve as another data augmentation tool. RNN-based GANs, however, cannot effectively model long sequences of data points with irregular temporal relations. To tackle these problems, we introduce TTS-GAN, a transformer-based GAN that can successfully generate realistic synthetic time-series sequences of arbitrary length, similar to the original ones. Both the generator and discriminator networks of the GAN model are built using a pure transformer encoder architecture. We use visualizations to demonstrate the similarity of real and generated time series, and a simple classification task to show how synthetically generated data can augment real data and improve classification accuracy.
A transformer GAN generates synthetic time-series data.
The TTS-GAN Architecture
The TTS-GAN model architecture is shown in the figure above. It contains two main parts: a generator and a discriminator. Both are built on the transformer encoder architecture. An encoder is a composition of two compound blocks: the first is a multi-head self-attention module, and the second is a feed-forward MLP with a GELU activation function. A normalization layer is applied before each of the two blocks, a dropout layer is added after each block, and both blocks employ residual connections.
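The encoder block described above can be sketched in PyTorch as follows. This is a minimal sketch: the hyperparameter values (embedding dimension, number of heads, MLP expansion ratio, dropout rate) are illustrative assumptions, not the paper's exact settings.

```python
import torch
import torch.nn as nn

class EncoderBlock(nn.Module):
    """One pre-norm transformer encoder block: normalization before each
    sub-block, dropout after it, and residual connections around both."""
    def __init__(self, embed_dim=64, n_heads=4, mlp_ratio=4, dropout=0.1):
        super().__init__()
        self.norm1 = nn.LayerNorm(embed_dim)  # norm applied before attention
        self.attn = nn.MultiheadAttention(embed_dim, n_heads,
                                          dropout=dropout, batch_first=True)
        self.drop1 = nn.Dropout(dropout)      # dropout after the attention block
        self.norm2 = nn.LayerNorm(embed_dim)  # norm applied before the MLP
        self.mlp = nn.Sequential(             # feed-forward MLP with GELU
            nn.Linear(embed_dim, mlp_ratio * embed_dim),
            nn.GELU(),
            nn.Linear(mlp_ratio * embed_dim, embed_dim),
        )
        self.drop2 = nn.Dropout(dropout)      # dropout after the MLP block

    def forward(self, x):
        # residual connection around the self-attention sub-block
        h = self.norm1(x)
        x = x + self.drop1(self.attn(h, h, h, need_weights=False)[0])
        # residual connection around the MLP sub-block
        x = x + self.drop2(self.mlp(self.norm2(x)))
        return x

block = EncoderBlock()
out = block(torch.randn(2, 30, 64))  # (batch, number of patches, embed_dim)
```

Stacking several such blocks yields the encoder used by both the generator and the discriminator; only their input and output heads would differ.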
The time series data processing step
We view a time-series data sequence as an image with a height equal to 1. The number of time steps is the width of the image, W. A time-series sequence can have a single channel or multiple channels, which correspond to the number of channels (RGB) of an image, C. So an input sequence can be represented as a matrix of size (Batch Size, C, 1, W). We then choose a patch size N to divide a sequence into W / N patches, and append a soft positional encoding value to the end of each patch; the positional value is learned during model training. Each patch then has the data shape (Batch Size, C, 1, (W/N) + 1). This process is shown in the figure above.
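A minimal PyTorch sketch of this patching step, assuming illustrative sizes (3 channels, 150 time steps, patch size 15). Note one simplification: where the text appends the learned positional value to each patch, this sketch uses the common alternative of adding a learnable positional embedding to each patch's projection.

```python
import torch
import torch.nn as nn

class PatchEmbedding(nn.Module):
    """Split a (B, C, 1, W) sequence into W/N patches of width N, project
    each patch to an embedding, and add a learnable ("soft") positional
    encoding. Sizes and names here are illustrative assumptions."""
    def __init__(self, channels=3, seq_len=150, patch_size=15, embed_dim=64):
        super().__init__()
        assert seq_len % patch_size == 0, "W must be divisible by N"
        self.n_patches = seq_len // patch_size        # W / N patches
        self.proj = nn.Linear(channels * patch_size, embed_dim)
        # one learnable positional vector per patch, trained with the model
        self.pos = nn.Parameter(torch.zeros(1, self.n_patches, embed_dim))

    def forward(self, x):                 # x: (B, C, 1, W)
        b, c, _, w = x.shape
        x = x.squeeze(2)                  # (B, C, W)
        x = x.reshape(b, c, self.n_patches, -1)   # (B, C, W/N, N)
        x = x.permute(0, 2, 1, 3).reshape(b, self.n_patches, -1)
        return self.proj(x) + self.pos    # (B, W/N, embed_dim)

emb = PatchEmbedding()(torch.randn(2, 3, 1, 150))
```

The resulting sequence of patch embeddings is what the transformer encoder blocks consume.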
Repository Structure
- Several images of the TTS-GAN project
- Saved pre-trained GAN model checkpoints
- The UniMiB dataset dataLoader used for loading GAN model training/testing data
- Loads real running and jumping data from the UniMiB dataset
- Loads synthetic running and jumping data from the pre-trained GAN models
- The GAN model training and evaluation functions
- The main GAN model training file
- Helper functions to draw t-SNE and PCA plots
- The AdamW optimizer file
- The argument-parsing function used to read training parameters for train_GAN.py
- Run this file to start training the Jumping GAN model
- Run this file to start training the Running GAN model
To train the Running data GAN model:
To train the Jumping data GAN model:
A simple example of visualizing the similarity between the synthetic running and jumping data and the real running and jumping data:
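One common way to visualize this similarity is to project real and synthetic sequences onto a shared 2-D PCA basis and scatter-plot the two groups. The sketch below uses only NumPy; the function name `pca_2d` and the random placeholder data are illustrative assumptions, and the repository's own plotting helpers may differ.

```python
import numpy as np

def pca_2d(real, synthetic):
    """Project real and synthetic sequences (flattened to vectors) onto the
    top-2 principal components of the pooled data, for a side-by-side plot."""
    x = np.vstack([real.reshape(len(real), -1),
                   synthetic.reshape(len(synthetic), -1)])
    x = x - x.mean(axis=0)                      # center the pooled data
    _, _, vt = np.linalg.svd(x, full_matrices=False)
    coords = x @ vt[:2].T                       # (n_real + n_syn, 2)
    return coords[:len(real)], coords[len(real):]

# placeholder data shaped like the (B, C, 1, W) sequences described above
rng = np.random.default_rng(0)
real = rng.normal(size=(50, 3, 1, 150))         # 50 "real" sequences
syn = rng.normal(size=(50, 3, 1, 150))          # 50 "synthetic" sequences
r2d, s2d = pca_2d(real, syn)
```

If the GAN output is realistic, the two point clouds should overlap heavily; t-SNE can be substituted for PCA for a nonlinear view of the same comparison.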