Official PyTorch implementation of the NIPS'17 paper: PredRNN: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal LSTMs.

Overview

PredRNN: A Recurrent Neural Network for Spatiotemporal Predictive Learning

The predictive learning of spatiotemporal sequences aims to generate future images by learning from the historical context, where the visual dynamics are believed to have modular structures that can be learned with compositional subsystems.

First version at NeurIPS 2017

This repo contains a PyTorch implementation of PredRNN (2017) [paper], a recurrent network with a pair of memory cells that operate in nearly independent transitions and ultimately form unified representations of the complex environment.

Concretely, besides the original memory cell of the LSTM, this network features a zigzag memory flow that propagates in both bottom-up and top-down directions across all layers, enabling the visual dynamics learned at different levels of the RNN to communicate.
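For intuition, here is a minimal PyTorch sketch of that zigzag flow. This is not the repo's exact code: the SpatioTemporalLSTMCell and its cell(x, h, c, m) -> (h, c, m) interface are assumed for illustration.

    import torch
    import torch.nn as nn

    class ZigzagStack(nn.Module):
        # Sketch of PredRNN's zigzag memory flow, assuming cells with the
        # interface cell(x, h, c, m) -> (h, c, m) and a constant channel count.
        def __init__(self, cells):
            super().__init__()
            self.cells = nn.ModuleList(cells)

        def forward(self, frames):  # frames: (batch, time, channels, H, W)
            B, T, C, H, W = frames.shape
            h = [torch.zeros(B, C, H, W, device=frames.device) for _ in self.cells]
            c = [torch.zeros(B, C, H, W, device=frames.device) for _ in self.cells]
            m = torch.zeros(B, C, H, W, device=frames.device)  # shared spatiotemporal memory
            outputs = []
            for t in range(T):
                x = frames[:, t]
                for l, cell in enumerate(self.cells):
                    inp = x if l == 0 else h[l - 1]
                    # M travels bottom-up through the layers at time t, then is
                    # handed from the top layer back to layer 0 at time t + 1:
                    # this zigzag lets dynamics at different levels communicate.
                    h[l], c[l], m = cell(inp, h[l], c[l], m)
                outputs.append(h[-1])
            return torch.stack(outputs, dim=1)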

New in PredRNN-V2 (2021)

This repo also includes the implementation of PredRNN-V2 (2021) [paper], which improves PredRNN in the following two aspects.

1. Memory Decoupling

We find that the pair of memory cells in PredRNN can contain undesirably redundant features, and we therefore present a memory decoupling loss that encourages the two cells to learn modular structures of visual dynamics.

[Figure: memory decoupling]
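A minimal sketch of such a decoupling penalty, based on our reading of the paper; the repo's actual implementation may differ in details, e.g. where the memory increments delta_c and delta_m are taken.

    import torch.nn.functional as F

    def decoupling_loss(delta_c, delta_m, eps=1e-8):
        # delta_c, delta_m: increments of the temporal cell C and the
        # spatiotemporal memory M at one layer, shape (batch, channels, H, W).
        # Penalizing the absolute per-channel cosine similarity pushes the two
        # memory cells toward learning decoupled (non-redundant) features.
        c = delta_c.flatten(2)                           # (B, C, H*W)
        m = delta_m.flatten(2)
        sim = F.cosine_similarity(c, m, dim=2, eps=eps)  # (B, C)
        return sim.abs().mean()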

2. Reverse Scheduled Sampling

Reverse scheduled sampling is a new curriculum learning strategy for sequence-to-sequence RNNs. As opposed to conventional scheduled sampling, it gradually changes the training inputs of the PredRNN encoder from the previously generated frames to the previous ground-truth frames. Benefits: (1) it makes training converge faster by closing the gap between encoder and forecaster; (2) it encourages the model to learn more from the long-term input context.

[Figure: reverse scheduled sampling]
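A sketch of the schedule, under assumed start/end values and annealing length; the repo's exact flags and annealing curve (e.g. linear vs. exponential) may differ.

    import numpy as np

    def rss_mask(itr, shape, r_start=0.0, r_end=1.0, anneal_iters=50000,
                 rng=np.random):
        # The probability r of feeding the TRUE previous frame to the encoder
        # grows with the training iteration, so the model first consumes its
        # own predictions and gradually switches to ground truth (the reverse
        # of classic scheduled sampling).
        r = r_start + (r_end - r_start) * min(itr / anneal_iters, 1.0)
        # 1 -> use ground-truth frame, 0 -> use previously generated frame.
        return (rng.random_sample(shape) < r).astype(np.float32)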

Evaluation in LPIPS

LPIPS correlates better with human perceptual judgments; lower is better.

             Moving MNIST   KTH action
PredRNN          0.109        0.204
PredRNN-V2       0.071        0.139
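For reference, the metric can be computed with the lpips package (pip install lpips); this is a hedged example, not necessarily the repo's evaluation script.

    import torch
    import lpips

    loss_fn = lpips.LPIPS(net='alex')        # AlexNet backbone, the usual default
    pred = torch.rand(1, 3, 64, 64) * 2 - 1  # LPIPS expects images in [-1, 1]
    target = torch.rand(1, 3, 64, 64) * 2 - 1
    print(loss_fn(pred, target).item())      # lower = perceptually closer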

Prediction examples

[Figure: Moving MNIST predictions]

[Figure: KTH action predictions]

[Figure: radar predictions]

Get Started

  1. Install Python 3.7, PyTorch 1.3, and OpenCV 3.4.
  2. Download data. This repo contains code for two datasets: the Moving MNIST dataset and the KTH action dataset.
  3. Train the model using the bash scripts below. The learned model will be saved in the --save_dir folder, and the generated future frames in the --gen_frm_dir folder.
  4. You can get pretrained models from here.
cd mnist_script/
sh predrnn_mnist_train.sh
sh predrnn_v2_mnist_train.sh

cd kth_script/
sh predrnn_kth_train.sh
sh predrnn_v2_kth_train.sh

Citation

If you find this repo useful, please cite the following papers.

@inproceedings{wang2017predrnn,
  title={{PredRNN}: Recurrent Neural Networks for Predictive Learning Using Spatiotemporal {LSTM}s},
  author={Wang, Yunbo and Long, Mingsheng and Wang, Jianmin and Gao, Zhifeng and Yu, Philip S},
  booktitle={Advances in Neural Information Processing Systems},
  pages={879--888},
  year={2017}
}

@misc{wang2021predrnn,
  title={{PredRNN}: A Recurrent Neural Network for Spatiotemporal Predictive Learning},
  author={Wang, Yunbo and Wu, Haixu and Zhang, Jianjin and Gao, Zhifeng and Wang, Jianmin and Yu, Philip S and Long, Mingsheng},
  year={2021},
  eprint={2103.09504},
  archivePrefix={arXiv},
}
Comments
  • Do you have the kth_action dataset?


    Do you have the kth_action dataset? I used the kth_action dataset from the official website and preprocessed the videos into frames, but I cannot get the training set of 108,717 sequences and test set of 4,086 sequences that the paper mentions. I get the following:

    training set:
    there are 183861 pictures
    there are 30379 sequences
    test set:
    there are 105855 pictures
    there are 17145 sequences
    

    Can you give me the preprocessing method or the preprocessed dataset?

    opened by toddwyl 9
  • Question about reshape_patch function


    What is the role of reshape_patch in core/utils/preprocess? I found that setting the patch_size parameter can significantly speed up training and reduce CUDA memory usage.

    I tested the reshape_patch function with the following code:

        import cv2
        import numpy as np

        img = cv2.imread('cat.jpeg', 0)                      # grayscale, shape (H, W)
        img = img[np.newaxis, np.newaxis, :, :, np.newaxis]  # -> (1, 1, H, W, 1)
        img_patched = reshape_patch(img, 3)                  # from core/utils/preprocess
    

    I displayed the original grayscale image and img_patched: [screenshot]

    So is this function used to reduce the spatial resolution of the image?
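    From its behavior, reshape_patch looks like a space-to-depth transform; here is my guess at it in numpy (an illustrative sketch, not the repo's exact code):

        import numpy as np

        def space_to_depth(x, p):
            # Guess at what reshape_patch does: fold each non-overlapping
            # p x p pixel block into the channel axis, so H and W shrink by p
            # while channels grow by p * p. Input: (batch, seq, H, W, C).
            B, T, H, W, C = x.shape
            x = x.reshape(B, T, H // p, p, W // p, p, C)
            x = x.transpose(0, 1, 2, 4, 3, 5, 6)   # gather the p x p blocks
            return x.reshape(B, T, H // p, W // p, p * p * C)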

    And thanks for your good work!

    opened by Prevalenter 5
  • predicted results


    @wuhaixu2016 Hello, thank you for sharing your work. My test results on the KTH dataset are shown below. [screenshot] My understanding is that we input 20 frames and predict 19 frames. Why, then, are these two sequences almost identical instead of showing predicted future actions?

    Looking forward to your reply, thank you so much!!!

    opened by buaa-luzhi 4
  • Prediction images don't change.


    Hi,

    I utilized PredRNN and your training strategy (i.e., combining reverse scheduled sampling and scheduled sampling) for soil moisture forecasting. We use 7 days of soil moisture to predict the following 7 days. However, I found that the predictions cannot capture the evolution of soil moisture over the forecast steps and give the same soil moisture pattern at step 8 (see attached figure).

    Can you give me some suggestions? Thanks a lot !

    Lu


    opened by leelew 3
  • Whether the Moving_MNIST dataset volume in PredRNN++ is the same as the code repo?


    "Hi, this code repo is correct. We will rephrase the paper soon."
    Originally posted by @wuhaixu2016 in https://github.com/thuml/predrnn-pytorch/issues/26#issuecomment-968199072

    Thanks for your reply. However, I couldn't reproduce the PredRNN++ performance on Moving MNIST. So I would like to know whether the Moving MNIST dataset volume in PredRNN++ is the same as in this code repo. If not, which volume setting were the PredRNN++ experiments based on? Looking forward to your reply. Thanks very much!

    opened by ILoveStudying 3
  • misconfigured parameter `num_action_ch` for `action_cond_predrnn_v2`


    Hi there,

    I am having issues using the action-conditional PredRNNV2 for inference.

    The way it seems to work (action_injection=concat): load the actions, grid-repeat them, and concatenate the actual video data with the resulting action tensor channel-wise. Then apply reshape_patch() and pass the input to the model, giving a tensor of shape [batch, seq_length, height // patch_size, width // patch_size, (img_ch + action_ch) * patch_size ** 2].

    For the action-conditional PredRNNV2 model, however, the parameter num_action_ch is used directly as the number of input channels for the conv layers instead of num_action_ch * patch_size ** 2. For me, this leads to runtime shape mismatches in forward(). Is this an error, or did I get something wrong?

    opened by Flunzmas 2
  • Random output images


    Hi, I'm using PredRNN to predict future frames of experimentally acquired data. To augment the data, I cut each frame into overlapping tiles and trained PredRNN on those tiles.

    Currently I'm testing how well PredRNN predicts. To do so, I'm running the trained model on data (test.npz) it has not previously seen. This data is also composed of overlapping tiles, the collection of which represents different frames.

    The problem I have is this: once predictions are made, the predicted tiles are generated in the corresponding folders, but in a seemingly random order; at least, I'm sure they're not in the same order as in the test.npz file that was fed to the trained model. My question is: is it possible to preserve the order? I want to reconstruct frames from the predicted tiles, and if the tile order changes, a proper reconstruction is very difficult.

    Thanks Miguel

    opened by miguel-fc 2
  • What's the validation dataset for KTH_action, TaxiBJ, and Human3.6M?


    Hi, I see that there are only training and test sets for KTH_action, TaxiBJ, and Human3.6M. What about the validation set, or are there no validation sequences in these datasets? Thanks!

    opened by ILoveStudying 2
  • Moving MNIST 3 digits


    Hi, thanks for the repo. I got good results with your pretrained model. May I know how to use the pretrained model for 3-digit Moving MNIST?

    opened by Mareeta26 2
  • Can the length of testing input sequence be different from that of the training input


    Hi, thank you for sharing this great work. I wonder whether the input sequence lengths for training and testing can differ. To be specific, for training a length-10 video clip is fed into the network, while for testing only one frame is used as input.

    opened by chisam0217 2
  • Pretrained model-Unexpected key


    Hi, I tried testing the pretrained PredRNN-V2 model on the Moving MNIST dataset, but I get the following error. Can you please check the reason?

    RuntimeError: Error(s) in loading state_dict for RNN: Unexpected key(s) in state_dict: "adapter.weight".

    Thanks!!

    opened by Mareeta26 1
  • Update import


    "from skimage.measure import compare_ssim" is deprecated. Change it to: "from skimage.metrics import structural_similarity as compare_ssim"

    opened by GTziolas 0
  • FileNotFoundError: [Errno 2] No such file or directory


    When I try to execute the Moving MNIST script with a pretrained model (passing --save_dir and --pretrained_model), I get a FileNotFoundError. I believe the lines of code linked below cause it.

    https://github.com/thuml/predrnn-pytorch/blob/36ba2b63fa96da3a60c32127e48f517e04062201/run.py#L210-L212

    opened by GTziolas 1
  • Generate new Mnist dataset


    Hi, as mentioned in a previous issue here, may I know how you included items like "clips" and "dims" in the MNIST dataset?

    Thank you!

    opened by Mareeta26 0
  • The traffic4cast model


    In the traffic4cast section of your article you wrote:

    To cope with high-dimensional input frames, we apply the autoencoder architecture of U-Net [88] to the network backbone of PredRNN. Specifically, the decoder of U-Net contains four ST-LSTM layers, and the CNN encoder takes both traffic flow maps and spatiotemporal memory states as inputs.

    The code for this model is not available in this GitHub repo; could you please make it available?

    Thanks a lot!

    Kind regards, Sébastien de Blois

    opened by Scienceseb 0
Owner
THUML: Machine Learning Group, School of Software, Tsinghua University