Generative Query Network (GQN) in PyTorch as described in "Neural Scene Representation and Rendering"

Overview

Update 2019/06/24: A model trained on 10% of the Shepard-Metzler dataset has been added. The following notebook explains the main features of this model: nbviewer

Generative Query Network

This is a PyTorch implementation of the Generative Query Network (GQN) described in the DeepMind paper "Neural scene representation and rendering" by Eslami et al. For an introduction to the model and the problem it addresses, see the article by DeepMind.

The current implementation generalises to any of the datasets described in the paper. However, only the Shepard-Metzler dataset is currently supported. To download and convert this dataset, use the provided script:

sh scripts/data.sh data-dir batch-size

The model can be trained in full, in accordance with the paper, by running run-gqn.py or by using the provided training script:

sh scripts/gpu.sh data-dir
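The paper trains with a learning rate and an output pixel standard deviation (sigma) that are both annealed linearly. A plain-Python sketch of such a schedule follows. `linear_anneal` and `schedules` are hypothetical helpers, and the constants (learning rate 5e-4 to 5e-5 over 1.6M steps, sigma 2.0 to 0.7 over 200k steps) are our reading of the paper's supplementary material, so check them against the paper and run-gqn.py before relying on them.

```python
def linear_anneal(step: int, start: float, end: float, n_steps: int) -> float:
    """Interpolate linearly from `start` to `end` over `n_steps`, then hold `end`.

    Hypothetical helper for illustration; not taken from run-gqn.py.
    """
    if n_steps <= 0 or step >= n_steps:
        return end
    return start + (end - start) * (step / n_steps)


# Assumed schedule constants (our reading of the paper; verify before use).
LR_START, LR_END, LR_STEPS = 5e-4, 5e-5, 1_600_000
SIGMA_START, SIGMA_END, SIGMA_STEPS = 2.0, 0.7, 200_000


def schedules(step: int) -> tuple:
    """Return (learning_rate, pixel_sigma) for a given training step."""
    lr = linear_anneal(step, LR_START, LR_END, LR_STEPS)
    sigma = linear_anneal(step, SIGMA_START, SIGMA_END, SIGMA_STEPS)
    return lr, sigma


if __name__ == "__main__":
    for step in (0, 100_000, 200_000, 2_000_000):
        lr, sigma = schedules(step)
        print(f"step={step:>9}  lr={lr:.2e}  sigma={sigma:.3f}")
```

Holding the final value after `n_steps` keeps the schedule well defined for runs longer than the annealing period.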

Implementation

This repository implements all of the representation architectures described in the paper, along with a generative model similar to the one described in "Towards conceptual compression" by Gregor et al.
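Whichever representation network is used, the paper combines the K context observations of a scene by encoding each (image, viewpoint) pair separately and summing the results into a single representation r. The sketch below shows only that aggregation step. `TinyEncoder` and `aggregate` are hypothetical stand-ins for the architectures in this repository, and the sizes (64x64 RGB images, 7-dimensional viewpoints, 256 channels) are illustrative assumptions.

```python
import torch
import torch.nn as nn


class TinyEncoder(nn.Module):
    """Hypothetical per-view encoder: (image, viewpoint) -> feature map."""

    def __init__(self, x_dim: int = 3, v_dim: int = 7, r_dim: int = 256):
        super().__init__()
        self.conv = nn.Conv2d(x_dim, r_dim, kernel_size=4, stride=4)
        self.mix = nn.Conv2d(r_dim + v_dim, r_dim, kernel_size=1)

    def forward(self, x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
        h = torch.relu(self.conv(x))  # (N, r_dim, H/4, W/4)
        # Copy the viewpoint to every spatial position before concatenating.
        v = v.view(v.size(0), -1, 1, 1).expand(-1, -1, h.size(2), h.size(3))
        return self.mix(torch.cat([h, v], dim=1))


def aggregate(encoder: nn.Module, x: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
    """Encode K context views per scene and sum them into one representation.

    x: (B, K, C, H, W) context images; v: (B, K, v_dim) context viewpoints.
    """
    b, k = x.shape[:2]
    # Fold the view dimension into the batch so each view is encoded independently.
    r = encoder(x.reshape(b * k, *x.shape[2:]), v.reshape(b * k, -1))
    # Unfold and sum over views, so the result does not depend on view order.
    return r.view(b, k, *r.shape[1:]).sum(dim=1)


if __name__ == "__main__":
    torch.manual_seed(0)
    enc = TinyEncoder()
    x, v = torch.rand(2, 5, 3, 64, 64), torch.rand(2, 5, 7)
    print(aggregate(enc, x, v).shape)
```

Because the per-view encodings are summed, r is invariant to the order of the context views and the same network accepts any number of them.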

This repository also contains implementations of the DRAW model and the ConvolutionalDRAW model, both described by Gregor et al.
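DRAW and the generator used here both produce an image iteratively: at each step a latent is drawn, a recurrent state is updated, and a "write" of that state is added to a running canvas that is finally mapped to pixel means. The sketch below illustrates only that sampling loop on flat vectors. `DrawLikeGenerator` is hypothetical, `nn.GRUCell` stands in for the LSTM, attention is omitted, and all sizes are assumptions.

```python
import torch
import torch.nn as nn


class DrawLikeGenerator(nn.Module):
    """Hypothetical DRAW-style generator that samples from a N(0, I) prior."""

    def __init__(self, x_dim: int = 784, h_dim: int = 128, z_dim: int = 16, steps: int = 8):
        super().__init__()
        self.z_dim, self.h_dim, self.steps = z_dim, h_dim, steps
        self.cell = nn.GRUCell(z_dim, h_dim)  # stand-in for the decoder LSTM
        self.write = nn.Linear(h_dim, x_dim)  # maps the state to a canvas update

    @torch.no_grad()
    def sample(self, batch_size: int) -> torch.Tensor:
        h = torch.zeros(batch_size, self.h_dim)
        canvas = torch.zeros(batch_size, self.write.out_features)
        for _ in range(self.steps):
            z = torch.randn(batch_size, self.z_dim)  # z_t ~ N(0, I)
            h = self.cell(z, h)                      # update the recurrent state
            canvas = canvas + self.write(h)          # additive canvas refinement
        return torch.sigmoid(canvas)                 # pixel means in (0, 1)


if __name__ == "__main__":
    torch.manual_seed(0)
    print(DrawLikeGenerator().sample(4).shape)
```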

Comments
  • Training time and testing demo

    Hi Jesper,

    Thank you for your great GQN code. I am a little curious about the following: How many epochs does it take to train a model on real images? How much training data do you use (as a percentage of the full training dataset)? Could you show a testing demo?

    Thank you very much!

    Best wishes, Mingjia Chen

    opened by mjchen611 22
  • ConvLSTM did not concat hidden from last round

    In the structure presented in the paper, the hidden state from the previous step is concatenated with the input before the remaining operations are applied. But it seems your LSTM does not use the hidden state from the previous step.

    opened by Tom-the-Cat 7
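For comparison, a generic convolutional LSTM cell in which the previous hidden state is concatenated with the input before the gate convolution, as the issue above describes the paper's structure, can be sketched as follows. `ConvLSTMCell` and its kernel size are assumptions for illustration, not the code in this repository.

```python
import torch
import torch.nn as nn


class ConvLSTMCell(nn.Module):
    """Generic ConvLSTM cell: gates are computed from [input, previous hidden]."""

    def __init__(self, in_channels: int, hidden_channels: int, kernel_size: int = 5):
        super().__init__()
        padding = kernel_size // 2  # keep the spatial size unchanged
        # One convolution produces all four gates from the concatenated tensor.
        self.gates = nn.Conv2d(in_channels + hidden_channels, 4 * hidden_channels,
                               kernel_size, padding=padding)

    def forward(self, x, state):
        h_prev, c_prev = state
        stacked = torch.cat([x, h_prev], dim=1)  # reuse the previous hidden state
        i, f, o, g = self.gates(stacked).chunk(4, dim=1)
        c = torch.sigmoid(f) * c_prev + torch.sigmoid(i) * torch.tanh(g)
        h = torch.sigmoid(o) * torch.tanh(c)
        return h, c


if __name__ == "__main__":
    cell = ConvLSTMCell(8, 16)
    x = torch.rand(2, 8, 16, 16)
    h = c = torch.zeros(2, 16, 16, 16)
    for _ in range(3):
        h, c = cell(x, (h, c))
    print(h.shape)
```

A single convolution over the concatenated tensor is equivalent to separate convolutions over the input and the hidden state whose outputs are added.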
  • Bad images in training

    While playing around with the sm5 dataset, I noticed that some of the images are badly rendered. I'm not sure whether this will pose a problem for training; I just wanted to point it out.

    opened by versatran01 7
  • Question about generator

    In the top docstring of generator.py, you mentioned that

    The inference-generator architecture is conceptually
    similar to the encoder-decoder pair seen in variational
    autoencoders.
    

    I don't quite understand this part and would really appreciate it if you could explain a bit or point me to some related articles. For the generator, I can see how it is similar to a decoder: it takes the latent z, the query viewpoint v and the aggregated representation r, and eventually outputs the image x_mu.

    But I'm a bit confused by the inference network being the counterpart of the encoder.

    opened by versatran01 7
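In VAE terms, the inference network plays the role of the encoder: it sees the target image x together with v and r and outputs a posterior q(z | x, v, r), whereas the generator only sees v and r and provides both a learned prior p(z | v, r) and the decoder. Training samples z from the posterior and penalises its KL divergence from the prior. The sketch below computes that term with `torch.distributions`; the tensors are random placeholders for network outputs and the latent shape is an assumption.

```python
import torch
from torch.distributions import Normal, kl_divergence

torch.manual_seed(0)
B, Z, H, W = 4, 3, 16, 16  # illustrative latent shape

# Placeholder outputs: in a real model the posterior parameters come from the
# inference network (which sees the target image x) and the prior parameters
# from the generator (which only sees the viewpoint v and representation r).
q_mu, q_logstd = torch.randn(B, Z, H, W), 0.1 * torch.randn(B, Z, H, W)
p_mu, p_logstd = torch.zeros(B, Z, H, W), torch.zeros(B, Z, H, W)

posterior = Normal(q_mu, q_logstd.exp())  # "encoder": q(z | x, v, r)
prior = Normal(p_mu, p_logstd.exp())      # learned prior: p(z | v, r)

z = posterior.rsample()  # reparameterised sample passed on to the decoder
kl = kl_divergence(posterior, prior).sum(dim=[1, 2, 3])  # one value per scene
print(z.shape, kl.shape)
```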
  • Loss Change

    Dear Wohlert,

    May I ask you a few questions?

    1. I tried to train this network on the Mazes data from https://github.com/deepmind/gqn-datasets. It seems to contain only 5% of the data, around 110000 samples, rather than the full dataset. Is that right?

    2. I trained for 30000 steps, but the ELBO loss only converged to about 6800, which is very different from the value of around 7 in the supplementary material. What approximate value did you achieve on the data you used?

    3. In the visualisations from the run in question 2, the reconstructions seem reasonable, but the samples are quite bad. Have you run into the same problem?

    Many thanks, Bing

    opened by BingCS 5
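When comparing ELBO values with published numbers, it is worth checking how the per-pixel terms are reduced: a Gaussian negative log-likelihood summed over every pixel and channel is larger than the per-pixel mean by a factor of H x W x C, which is several thousand for image-sized inputs. The plain-Python sketch below only illustrates this scaling with made-up data; it makes no claim about which convention the supplementary material or run-gqn.py uses, and the helper names are hypothetical.

```python
import math
import random


def gaussian_nll(x: float, mu: float, sigma: float) -> float:
    """Negative log density of x under N(mu, sigma^2)."""
    return 0.5 * math.log(2 * math.pi * sigma ** 2) + (x - mu) ** 2 / (2 * sigma ** 2)


def image_nll(pixels, means, sigma: float):
    """Return (summed, averaged) NLL over all pixel/channel values."""
    terms = [gaussian_nll(x, m, sigma) for x, m in zip(pixels, means)]
    return sum(terms), sum(terms) / len(terms)


if __name__ == "__main__":
    random.seed(0)
    n = 64 * 64 * 3  # an illustrative 64x64 RGB image
    pixels = [random.random() for _ in range(n)]
    means = [min(max(p + random.gauss(0, 0.05), 0.0), 1.0) for p in pixels]
    total, mean = image_nll(pixels, means, sigma=0.7)
    print(f"summed: {total:.1f}  averaged: {mean:.4f}  ratio: {total / mean:.0f}")
```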
  • Questions on data preparing

    Hi, Wohlert:

    After converting the data with your scripts, I visualized some of the images in the *.pt files and found pictures like this: Figure_1-1

    What is wrong here? I'm also confused about your batching: if you batch the sequences as you convert them, does that mean you don't batch them again when using the DataLoader?

    Thanks

    opened by Kyridiculous2 5
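If each converted file stores an already-batched chunk of scenes, a Dataset that returns one chunk per index makes the DataLoader's `batch_size` count chunks rather than scenes. The sketch below shows that effect; `ChunkDataset` is hypothetical, and an in-memory list of tensors with toy sizes stands in for the *.pt files.

```python
import torch
from torch.utils.data import DataLoader, Dataset


class ChunkDataset(Dataset):
    """Hypothetical dataset whose items are already-batched chunks of scenes."""

    def __init__(self, chunks):
        self.chunks = chunks  # stand-ins for the tensors stored in *.pt files

    def __len__(self):
        return len(self.chunks)

    def __getitem__(self, idx):
        return self.chunks[idx]  # shape: (scenes_in_chunk, views, C, H, W)


# Three chunks of 12 scenes, each with 15 views of 3x8x8 images (toy sizes).
chunks = [torch.rand(12, 15, 3, 8, 8) for _ in range(3)]

# batch_size=1 loads one chunk (12 scenes) per step; the default collate adds
# a leading dimension of size 1 that has to be squeezed away.
loader = DataLoader(ChunkDataset(chunks), batch_size=1)
for batch in loader:
    scenes = batch.squeeze(0)
    print(batch.shape, scenes.shape)
    break
```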
  • Training crashes at the same spot for both Shepard Metzler datasets

    Some context:

    • I downloaded and converted the datasets via data.sh and set batch size to 12. Note that I am using TensorFlow 1.14 for reading the tfrecord files and converting them.
    • I use gpu.sh to run the training script. I set the batch size to one of [1, 12, 36, 72] and DataParallel to True to use 4 GPUs.

    But after a short time I get the following errors if I use any batch size higher than 1. They happen at iterations 40, 13 and 6 with batch sizes 12, 36 and 72 respectively, for both Shepard-Metzler datasets. Why am I getting these errors? Does a batch size of 1 in the training code mean reading one of the .pt.gz files? If so, setting the batch size to 1 in the training script would actually mean 12. Would that be correct?

    For instance, here's what I get for the 5-part dataset when I set the batch size to 36:

    Epoch [1/200]: [13/1856]   1%|▊ , elbo=-2.1e+4, kl=827, mu=5e-6, sigma=2 [00:21<52:34]
    Current run is terminating due to exception: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    .
    Engine run is terminating due to exception: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    .
    Traceback (most recent call last):
      File "../run-gqn.py", line 183, in <module>
        trainer.run(train_loader, args.n_epochs)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 850, in run
        return self._internal_run()
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 952, in _internal_run
        self._handle_exception(e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 714, in _handle_exception
        self._fire_event(Events.EXCEPTION_RAISED, e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 607, in _fire_event
        func(self, *(event_args + args), **kwargs)
      File "../run-gqn.py", line 181, in handle_exception
        else: raise e
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 937, in _internal_run
        hours, mins, secs = self._run_once_on_dataset()
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 705, in _run_once_on_dataset
        self._handle_exception(e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 714, in _handle_exception
        self._fire_event(Events.EXCEPTION_RAISED, e)
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 607, in _fire_event
        func(self, *(event_args + args), **kwargs)
      File "../run-gqn.py", line 181, in handle_exception
        else: raise e
      File "/usr/local/lib/python3.6/dist-packages/ignite/engine/engine.py", line 655, in _run_once_on_dataset
        batch = next(self._dataloader_iter)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 801, in __next__
        return self._process_data(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py", line 846, in _process_data
        data.reraise()
      File "/usr/local/lib/python3.6/dist-packages/torch/_utils.py", line 385, in reraise
        raise self.exc_type(msg)
    RuntimeError: Caught RuntimeError in DataLoader worker process 13.
    Original Traceback (most recent call last):
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 178, in _worker_loop
        data = fetcher.fetch(index)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 47, in fetch
        return self.collate_fn(data)
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in default_collate
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 79, in <listcomp>
        return [default_collate(samples) for samples in transposed]
      File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/collate.py", line 55, in default_collate
        return torch.stack(batch, 0, out=out)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 12 and 8 in dimension 1 at /pytorch/aten/src/TH/generic/THTensor.cpp:689
    
    opened by Amir-Arsalan 4
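The traceback shows `torch.stack` failing on samples whose leading sizes are 12 and 8, which is consistent with pre-batched chunks of 12 scenes plus a shorter final chunk. One possible workaround, not part of this repository, is a custom `collate_fn` that concatenates chunks along the scene dimension instead of stacking them. `concat_collate` is a hypothetical name, and samples are assumed to be (images, viewpoints) pairs with toy sizes.

```python
import torch
from torch.utils.data import DataLoader


def concat_collate(samples):
    """Hypothetical collate_fn for pre-batched (images, viewpoints) chunks.

    Concatenates along dim 0, so chunks with different numbers of scenes
    (e.g. 12 and 8) can share a batch, unlike the default torch.stack.
    """
    images, viewpoints = zip(*samples)
    return torch.cat(images, dim=0), torch.cat(viewpoints, dim=0)


# Two full chunks and one short final chunk (toy sizes).
data = [
    (torch.rand(12, 15, 3, 8, 8), torch.rand(12, 15, 7)),
    (torch.rand(12, 15, 3, 8, 8), torch.rand(12, 15, 7)),
    (torch.rand(8, 15, 3, 8, 8), torch.rand(8, 15, 7)),
]

loader = DataLoader(data, batch_size=3, collate_fn=concat_collate)
images, viewpoints = next(iter(loader))
print(images.shape, viewpoints.shape)  # 12 + 12 + 8 = 32 scenes in this step
```

The number of scenes per step then varies with the chunks drawn; dropping or padding short chunks at conversion time are alternatives.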
  • AttributeError: 'int' object has no attribute 'size'

    In draw.py, I get this error at line 118 (batch_size = z.size(0)). Sorry if this is obvious; thanks for any help.

    ~ % pip show torch
    Name: torch
    Version: 1.0.1.post2

    opened by DRM-Free 4
  • Increase dimension of viewpoint and representation

    Thanks for this implementation. One question: when expanding the viewpoint and representation to the spatial size of the feature maps, you use .repeat(). Is there a particular reason for this? Could one use interpolate instead?

    In the original paper it says "when concatenating viewpoint v to an image or feature map, its values are ‘broadcast’ in the spatial dimensions to obtain the correct size. "

    The word 'broadcast' is not precisely defined, hence the question.

    opened by versatran01 4
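For a viewpoint vector, which has no spatial extent, "broadcasting" to an H x W feature map amounts to copying the same values to every spatial position. The sketch below compares three ways of doing that in PyTorch: `.repeat()` copies the data, `.expand()` returns a view without copying, and nearest-neighbour `F.interpolate` from a 1x1 map yields identical values. The 7-dimensional viewpoint and 16x16 size are illustrative assumptions.

```python
import torch
import torch.nn.functional as F

B, V, H, W = 2, 7, 16, 16
v = torch.rand(B, V)        # one viewpoint vector per scene
v_map = v.view(B, V, 1, 1)  # add singleton spatial dimensions

repeated = v_map.repeat(1, 1, H, W)                               # materialises H*W copies
expanded = v_map.expand(-1, -1, H, W)                             # view, no copy
interpolated = F.interpolate(v_map, size=(H, W), mode="nearest")  # resample the 1x1 map

print(torch.equal(repeated, expanded), torch.equal(repeated, interpolated))

# Any of them can then be concatenated with a feature map along the channels.
features = torch.rand(B, 32, H, W)
combined = torch.cat([features, expanded], dim=1)  # (B, 32 + V, H, W)
```

For a 1x1 input all three agree, so the choice mainly affects memory: `.expand()` avoids the copy that `.repeat()` makes.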
  • Learning rate change

    Regarding line 113 of run-gqn.py: does this actually change the learning rate of the Adam optimizer? This post suggests a different approach:

    https://stackoverflow.com/questions/48324152/pytorch-how-to-change-the-learning-rate-of-an-optimizer-at-any-given-moment-no

    opened by david-bernstein 4
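In PyTorch, optimisers read the learning rate from `optimizer.param_groups`, so assigning to an attribute such as `optimizer.lr` only creates a new Python attribute and leaves the step size unchanged. The sketch below contrasts the two; `set_learning_rate` is a hypothetical helper, and no claim is made about the exact contents of line 113 of run-gqn.py.

```python
import torch


def set_learning_rate(optimizer: torch.optim.Optimizer, lr: float) -> None:
    """Hypothetical helper: update the rate actually used by optimizer.step()."""
    for group in optimizer.param_groups:
        group["lr"] = lr


param = torch.nn.Parameter(torch.zeros(3))
optimizer = torch.optim.Adam([param], lr=5e-4)

optimizer.lr = 1e-6                     # only sets an unused attribute
print(optimizer.param_groups[0]["lr"])  # still 0.0005

set_learning_rate(optimizer, 1e-6)
print(optimizer.param_groups[0]["lr"])  # now 1e-06
```

For schedules expressed as a factor of the initial rate, `torch.optim.lr_scheduler.LambdaLR` performs the same `param_groups` update on each `scheduler.step()` call.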
  • Using the rooms data?

    I wanted to try your code on the rooms data, but I get the following errors during conversion. What could I be doing wrong? Note that for the rooms data with a moving camera I set the number of camera parameters to 7:

    Traceback (most recent call last):
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 119, in worker
        result = (True, func(*args, **kwds))
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 44, in mapstar
        return list(map(*args))
      File "tfrecord-converter.py", line 66, in convert
        for i, batch in enumerate(batch_process(record)):
      File "tfrecord-converter.py", line 29, in chunk
        for first in iterator:
      File "tfrecord-converter.py", line 40, in process
        'cameras': tf.FixedLenFeature(shape=SEQ_DIM * POSE_DIM, dtype=tf.float32)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 1019, in parse_single_example
        serialized, features, example_names, name
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 1063, in parse_single_example_v2_unoptimized
        return parse_single_example_v2(serialized, features, name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 2089, in parse_single_example_v2
        dense_defaults, dense_shapes, name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/parsing_ops.py", line 2206, in _parse_single_example_v2_raw
        name=name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_parsing_ops.py", line 1164, in parse_single_example
        ctx=_ctx)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/gen_parsing_ops.py", line 1260, in parse_single_example_eager_fallback
        attrs=_attrs, ctx=_ctx, name=name)
      File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/eager/execute.py", line 67, in quick_execute
        six.raise_from(core._status_to_exception(e.code, message), None)
      File "<string>", line 3, in raise_from
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: frames.  Can't parse serialized Example. [Op:ParseSingleExample]
    """
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "tfrecord-converter.py", line 98, in <module>
        pool.map(f, records)
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 266, in map
        return self._map_async(func, iterable, mapstar, chunksize).get()
      File "/usr/lib/python3.6/multiprocessing/pool.py", line 644, in get
        raise self._value
    tensorflow.python.framework.errors_impl.InvalidArgumentError: Key: frames.  Can't parse serialized Example. [Op:ParseSingleExample]
    
    opened by Amir-Arsalan 3
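The error is raised while parsing the `frames` feature, which suggests that the fixed lengths declared in the converter do not match what the rooms records contain. As far as we recall DeepMind's gqn-datasets reader, rooms scenes store 10 views (Shepard-Metzler scenes store 15), and raw records hold 5 camera values per view (x, y, z, yaw, pitch) that are only expanded to 7 after parsing. Treat the table below as an assumption to verify against that reader; `SEQUENCE_SIZE` and `expected_feature_lengths` are hypothetical names.

```python
from typing import Dict

# Views stored per scene in the raw tfrecords, as we recall them from
# DeepMind's gqn-datasets reader (assumption: verify before use).
SEQUENCE_SIZE: Dict[str, int] = {
    "shepard_metzler_5_parts": 15,
    "shepard_metzler_7_parts": 15,
    "rooms_ring_camera": 10,
    "rooms_free_camera_no_object_rotations": 10,
    "rooms_free_camera_with_object_rotations": 10,
}

RAW_CAMERA_PARAMS = 5  # x, y, z, yaw, pitch before any sin/cos expansion


def expected_feature_lengths(name: str) -> Dict[str, int]:
    """Flattened lengths to declare for the 'frames' and 'cameras' features."""
    seq = SEQUENCE_SIZE[name]
    return {"frames": seq, "cameras": seq * RAW_CAMERA_PARAMS}


if __name__ == "__main__":
    print(expected_feature_lengths("rooms_ring_camera"))
```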
Releases: 0.1
Owner: Jesper Wohlert