Overview

glow-pytorch

PyTorch implementation of Glow, Generative Flow with Invertible 1x1 Convolutions (https://arxiv.org/abs/1807.03039)

Usage:

python train.py PATH

Since the trainer uses torchvision's ImageFolder, the input directory should be structured like this even when there is only one class. (Currently this implementation does not incorporate a class classification loss.) A quick sanity check follows the listing.

PATH/class1
PATH/class2
...
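
As a quick sanity check, a minimal sketch (assuming torchvision is installed); ImageFolder treats each subdirectory of PATH as one class:

    from torchvision import datasets

    # ImageFolder maps each subdirectory of PATH to one class label.
    dataset = datasets.ImageFolder("PATH")
    print(dataset.classes)  # e.g. ['class1', 'class2']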

Notes

Sample

I have trained the model on the vanilla CelebA dataset, and it seems to work well. I found that the learning rate (I used 1e-4 without scheduling), a learned prior, the number of bits (5 in this case), and using a sigmoid function instead of an exponential at the affine coupling layer are all beneficial to training; a sketch of the sigmoid variant follows.
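
For illustration, a minimal sketch of the sigmoid-based scale in an affine coupling layer (the variable names and the +2 shift are illustrative, not necessarily the repo's exact code; see the coupling layer in model.py for the real thing):

    import torch

    x = torch.randn(2, 3, 8, 8)        # one half of the coupling's channels
    net_out = torch.randn(2, 6, 8, 8)  # stand-in for the coupling network output

    log_s, t = net_out.chunk(2, 1)
    s = torch.sigmoid(log_s + 2)       # bounded scale; tends to train stably
    # s = torch.exp(log_s)             # the exponential form from the paper
    out = (x + t) * s
    logdet = torch.sum(torch.log(s).view(x.shape[0], -1), 1)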

In my experiments, the LU-decomposed invertible convolution was much faster than the plain version, so the LU-decomposed version is used by default; the sketch below shows where the saving comes from.
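
Roughly, the LU parameterization keeps the permutation P fixed and learns L, U, and the log-scale of U's diagonal separately, so the log-determinant needed by the loss is a sum of learned log-scales instead of a full slogdet at every step. A standalone sketch of why that works:

    import numpy as np
    from scipy import linalg

    n = 4
    w = np.linalg.qr(np.random.randn(n, n))[0]  # random rotation, as at init
    p, l, u = linalg.lu(w)                      # P stays fixed; L and U are learned
    s = np.diag(u)                              # diagonal of U, stored as log|s|

    # log|det W| reduces to sum(log|s|): P contributes det +-1 and L has a
    # unit diagonal, so only U's diagonal matters.
    print(np.allclose(np.sum(np.log(np.abs(s))), np.linalg.slogdet(w)[1]))  # True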

Progression of samples

Progression of samples during training, sampled once every 100 iterations.

Comments
  • Question on how to reconstruct images

    Hi!

    Thanks a lot for this repo, it's really great! I am wondering how to reconstruct images, since the current reverse method seems to take a list of images with different sizes.

    I tried writing a method:

        def reverse_data(self, z):
            # Walk the blocks from last to first, feeding each block's
            # output into the next one.
            for i, block in enumerate(self.blocks[::-1]):
                if i == 0:
                    eps = torch.randn_like(z)
                    input = block.reverse(z, eps)
                else:
                    eps = torch.zeros_like(input)
                    input = block.reverse(input, eps)

            return input
    

    But I am unsure whether this is correct. I would appreciate any help!

    Thank you!

    opened by MrHuff 6
  • Bits per dimension

    Do you know how to map your loss values to bits-per-dimension results (see Table 2 in the paper)? I'm having a hard time coming up with a formula for the correspondence. Some Reddit post mentions subtracting math.log(128) to account for scaling, but it still doesn't seem right.

    I looked at the original implementation in TensorFlow but couldn't figure that out. Would you mind letting me know what you think about it? Also, do you know how close your implementation is to the original code in terms of bits per dimension? Thank you.
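
    For reference, the usual conversion (a sketch, assuming log_p and logdet are in nats and the input was quantized to n_bins levels before dequantization noise was added):

        import math

        def bits_per_dim(log_p, logdet, image_size, n_bins, n_channels=3):
            # Total number of dimensions (height x width x channels).
            n_pixel = image_size * image_size * n_channels
            # Discrete log-likelihood: correct the continuous density by the
            # bin width (1 / n_bins) in every dimension.
            nll_nats = -(log_p + logdet - n_pixel * math.log(n_bins))
            # Convert nats to bits and normalize per dimension.
            return nll_nats / (math.log(2) * n_pixel)

    If calc_loss already divides the negative log-likelihood by math.log(2) * n_pixel, its reported loss is already in bits per dimension.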

    opened by tangbinh 5
  • Data Parallel

    Thank you for your code. It looks like you tried to use nn.DataParallel but didn't quite wire it in. Could you share your experience with it?

    For some reason, the loss kept increasing when I used nn.DataParallel with 2 GPUs, regardless of batch size. To make it run with your code, I changed your calc_loss a little by expanding logdet to the same size as log_p (roughly as sketched below). I also tried logdet.mean(), but that didn't work either. I'm not really sure why the logdet values differ between the 2 GPUs, as logdet seems to depend only on the shared weights.
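
    A minimal sketch of that workaround (hypothetical; the actual calc_loss lives in train.py), making logdet per-sample so nn.DataParallel can gather it across GPUs alongside log_p:

        # If logdet comes back as a scalar, expand it to one value per sample
        # so DataParallel concatenates it across devices like log_p.
        if logdet.dim() == 0:
            logdet = logdet.expand(log_p.size(0))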

    opened by tangbinh 5
  • cusolver error: 7

    What I run

    python train.py ../train/ --img_size 128 --batch 8
    

    SPEC

    2 x RTX 2080 Ti. Using two cards or one card does not change the error.

    CODE

    Latest.

    ENVIRONMENT

    # packages in environment at /opt/conda:
    #
    # Name                    Version                   Build  Channel
    _libgcc_mutex             0.1                        main  
    async-generator           1.10                     pypi_0    pypi
    attrs                     20.3.0                   pypi_0    pypi
    backcall                  0.2.0                      py_0  
    bash-kernel               0.7.2                    pypi_0    pypi
    beautifulsoup4            4.9.3              pyhb0f4dca_0  
    blas                      1.0                         mkl  
    bleach                    3.2.1                    pypi_0    pypi
    bzip2                     1.0.8                h7b6447c_0  
    ca-certificates           2020.10.14                    0  
    certifi                   2020.6.20          pyhd3eb1b0_3  
    cffi                      1.14.0           py38he30daa8_1  
    chardet                   3.0.4                 py38_1003  
    conda                     4.9.2            py38h06a4308_0  
    conda-build               3.20.5                   py38_1  
    conda-package-handling    1.6.1            py38h7b6447c_0  
    cryptography              2.9.2            py38h1ba5d50_0  
    cudatoolkit               11.0.221             h6bb024c_0  
    dataclasses               0.6                      pypi_0    pypi
    decorator                 4.4.2                      py_0  
    defusedxml                0.6.0                    pypi_0    pypi
    dnspython                 2.0.0                    pypi_0    pypi
    entrypoints               0.3                      pypi_0    pypi
    filelock                  3.0.12                     py_0  
    freetype                  2.10.4               h5ab3b9f_0  
    future                    0.18.2                   pypi_0    pypi
    gdown                     3.12.2                   pypi_0    pypi
    glob2                     0.7                        py_0  
    icu                       58.2                 he6710b0_3  
    idna                      2.9                        py_1  
    intel-openmp              2020.2                      254  
    ipykernel                 5.3.4                    pypi_0    pypi
    ipython                   7.19.0                   pypi_0    pypi
    ipython_genutils          0.2.0                    py38_0  
    ipywidgets                7.5.1                    pypi_0    pypi
    jedi                      0.17.2                   py38_0  
    jinja2                    2.11.2                     py_0  
    jpeg                      9b                   h024ee3a_2  
    json5                     0.9.5                    pypi_0    pypi
    jsonschema                3.2.0                    pypi_0    pypi
    jupyter                   1.0.0                    pypi_0    pypi
    jupyter-client            6.1.7                    pypi_0    pypi
    jupyter-console           6.2.0                    pypi_0    pypi
    jupyter-core              4.7.0                    pypi_0    pypi
    jupyterlab                2.2.9                    pypi_0    pypi
    jupyterlab-pygments       0.1.2                    pypi_0    pypi
    jupyterlab-server         1.2.0                    pypi_0    pypi
    lcms2                     2.11                 h396b838_0  
    ld_impl_linux-64          2.33.1               h53a641e_7  
    libarchive                3.4.2                h62408e4_0  
    libedit                   3.1.20181209         hc058e9b_0  
    libffi                    3.3                  he6710b0_1  
    libgcc-ng                 9.1.0                hdf63c60_0  
    libgfortran-ng            7.3.0                hdf63c60_0  
    liblief                   0.10.1               he6710b0_0  
    libpng                    1.6.37               hbc83047_0  
    libstdcxx-ng              9.1.0                hdf63c60_0  
    libtiff                   4.1.0                h2733197_1  
    libuv                     1.40.0               h7b6447c_0  
    libxml2                   2.9.10               hb55368b_3  
    lz4-c                     1.9.2                heb0550a_3  
    markupsafe                1.1.1            py38h7b6447c_0  
    mistune                   0.8.4                    pypi_0    pypi
    mkl                       2020.2                      256  
    mkl-service               2.3.0            py38he904b0f_0  
    mkl_fft                   1.2.0            py38h23d657b_0  
    mkl_random                1.1.1            py38h0573a6f_0  
    nbclient                  0.5.1                    pypi_0    pypi
    nbconvert                 6.0.7                    pypi_0    pypi
    nbformat                  5.0.8                    pypi_0    pypi
    ncurses                   6.2                  he6710b0_1  
    nest-asyncio              1.4.3                    pypi_0    pypi
    ninja                     1.10.1           py38hfd86e86_0  
    notebook                  5.7.5                    pypi_0    pypi
    numpy                     1.19.2           py38h54aff64_0  
    numpy-base                1.19.2           py38hfa32c7d_0  
    olefile                   0.46                       py_0  
    openssl                   1.1.1h               h7b6447c_0  
    packaging                 20.4                     pypi_0    pypi
    pandocfilters             1.4.3                    pypi_0    pypi
    parso                     0.7.0                      py_0  
    patchelf                  0.12                 he6710b0_0  
    pexpect                   4.8.0                    py38_0  
    pickleshare               0.7.5                 py38_1000  
    pillow                    8.0.0            py38h9a89aac_0  
    pip                       20.0.2                   py38_3  
    pkginfo                   1.6.0                    py38_0  
    prometheus-client         0.9.0                    pypi_0    pypi
    prompt-toolkit            3.0.8                      py_0  
    psutil                    5.7.2            py38h7b6447c_0  
    ptyprocess                0.6.0                    py38_0  
    py-lief                   0.10.1           py38h403a769_0  
    pycosat                   0.6.3            py38h7b6447c_1  
    pycparser                 2.20                       py_0  
    pygments                  2.7.1                      py_0  
    pyopenssl                 19.1.0                   py38_0  
    pyparsing                 2.4.7                    pypi_0    pypi
    pyrsistent                0.17.3                   pypi_0    pypi
    pysocks                   1.7.1                    py38_0  
    python                    3.8.3                hcff3b4d_0  
    python-dateutil           2.8.1                    pypi_0    pypi
    python-etcd               0.4.5                    pypi_0    pypi
    python-libarchive-c       2.9                        py_0  
    pytorch                   1.7.0           py3.8_cuda11.0.221_cudnn8.0.3_0    pytorch
    pytz                      2020.1                     py_0  
    pyyaml                    5.3.1            py38h7b6447c_0  
    pyzmq                     20.0.0                   pypi_0    pypi
    qtconsole                 4.7.7                    pypi_0    pypi
    qtpy                      1.9.0                    pypi_0    pypi
    readline                  8.0                  h7b6447c_0  
    requests                  2.23.0                   py38_0  
    ripgrep                   12.1.1                        0  
    ruamel_yaml               0.15.87          py38h7b6447c_0  
    scipy                     1.5.2            py38h0b6359f_0  
    send2trash                1.5.0                    pypi_0    pypi
    setuptools                46.4.0                   py38_0  
    six                       1.14.0                   py38_0  
    soupsieve                 2.0.1                      py_0  
    sqlite                    3.31.1               h62c20be_1  
    terminado                 0.9.1                    pypi_0    pypi
    testpath                  0.4.4                    pypi_0    pypi
    tk                        8.6.8                hbc83047_0  
    torchelastic              0.2.1                    pypi_0    pypi
    torchvision               0.8.0                py38_cu110    pytorch
    tornado                   5.1.1                    pypi_0    pypi
    tqdm                      4.46.0                     py_0  
    traitlets                 5.0.5                      py_0  
    typing_extensions         3.7.4.3                    py_0  
    urllib3                   1.25.8                   py38_0  
    wcwidth                   0.2.5                      py_0  
    webencodings              0.5.1                    pypi_0    pypi
    wheel                     0.34.2                   py38_0  
    widgetsnbextension        3.5.1                    pypi_0    pypi
    xz                        5.2.5                h7b6447c_0  
    yaml                      0.1.7                had09818_2  
    zlib                      1.2.11               h7b6447c_3  
    zstd                      1.4.5                h9ceee32_0 
    

    ERROR MESSAGE:

    Namespace(affine=False, batch=8, img_size=128, iter=200000, lr=0.0001, n_bits=5, n_block=4, n_flow=32, n_sample=20, no_lu=False, path='../train/', temp=0.7)
    /workspace/glow-pytorch/model.py:102: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /opt/conda/conda-bld/pytorch_1603729096996/work/torch/csrc/utils/tensor_numpy.cpp:141.)
      w_s = torch.from_numpy(w_s)
    Loss: 2.15042; logP: -2.13823; logdet: 4.98781; lr: 0.0001000:   0%| | 1/200000 
    Traceback (most recent call last):
      File "train.py", line 177, in <module>
        train(args, model, optimizer)
      File "train.py", line 148, in train
        model_single.reverse(z_sample).cpu().data,
      File "/workspace/glow-pytorch/model.py", line 367, in reverse
        input = block.reverse(z_list[-1], z_list[-1], reconstruct=reconstruct)
      File "/workspace/glow-pytorch/model.py", line 322, in reverse
        input = flow.reverse(input)
      File "/workspace/glow-pytorch/model.py", line 239, in reverse
        input = self.invconv.reverse(input)
      File "/workspace/glow-pytorch/model.py", line 136, in reverse
        return F.conv2d(output, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))
    RuntimeError: cusolver error: 7, when calling `cusolverDnCreate(handle)`
    

    EXTRA:

    This bug happens during the reverse calculation that runs when i % 100 == 0. I changed it to i == 1 to reproduce the bug faster.

    Also, changing w_s = torch.from_numpy(w_s) to w_s = torch.from_numpy(w_s.copy()) turns off the warning above, but the error still occurs.

    opened by LyWangPX 3
  • Conditional Gaussian prior parameters produce unnormalized likelihoods

    Hi,

    first off, thank you for your implementation!

    I found that you've deviated from the original OpenAI implementation by producing the prior parameters (mean, logsd) from the intermediate flow splits via an additional convolution: https://github.com/rosinality/glow-pytorch/blob/97081ff115a694cf04aeedbd58447f33d242c879/model.py#L285

    Unfortunately, this little change makes the Gaussian distribution unnormalized. If you think about it: suppose the convolution that produces the prior parameters (mean, logsd) learns to output mean = z_new, logsd = 0. Then the Gaussian prior likelihood for z_new, N(z_new; mean=z_new, sd=1), is always at its maximum, since the query value z_new equals the distribution's mode. So if you integrate this term over all z_new in R (keeping it one-dimensional for simplicity), you end up with a value > 1 (it should in fact be infinite), showing that this way of defining the prior leads to an unnormalized distribution.

    To fix this you need to (as done in the original OpenAI implementation) remove the conditioning of mean and logsd on the split variables and learn them from a fixed input instead (a sketch follows the links below). Relevant pieces of the original implementation:

    Prior is created unconditionally https://github.com/openai/glow/blob/master/model.py#L180

    Fixed zero input https://github.com/openai/glow/blob/master/model.py#L109

    If parameters are learned, apply convolution on fixed input https://github.com/openai/glow/blob/master/model.py#L111

    Create prior from mean and logsd https://github.com/openai/glow/blob/master/model.py#L116
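
    For illustration, a minimal sketch of the unconditional variant described above (the class name is hypothetical; a zero-initialized convolution plays the role of the learned prior head):

        import torch
        from torch import nn

        class UnconditionalPrior(nn.Module):
            def __init__(self, n_channel):
                super().__init__()
                # Zero-initialized conv, so the prior starts as N(0, I).
                self.conv = nn.Conv2d(n_channel, n_channel * 2, 3, padding=1)
                self.conv.weight.data.zero_()
                self.conv.bias.data.zero_()

            def forward(self, z):
                # The conv only ever sees an all-zeros tensor of z's shape, so
                # mean and logsd are learned but never depend on the split variables.
                mean, log_sd = self.conv(torch.zeros_like(z)).chunk(2, 1)
                return mean, log_sd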

    opened by braun-steven 2
  • Something confusing in calculating the loss

    In the paper, when x is continuous, the loss is calculated by feeding $x + U(0, a)$ into the model, but in the code it seems you feed $x + N(0, a)$ into the model?
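
    For reference, the dequantization step as it appears in train.py-style code (a sketch; note that torch.rand_like draws from U(0, 1), so the added noise is uniform on [0, 1/n_bins), not Gaussian):

        # torch.rand_like samples uniformly from [0, 1); dividing by n_bins
        # yields uniform dequantization noise on [0, 1/n_bins).
        noisy_image = image + torch.rand_like(image) / n_bins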

    opened by SiyuWang15 2
  • Why are we sampling z from the standard normal distribution instead of the learned p(z)?

    Hi rosinality,

    Your code is clean and easy to read. Thank you for your effort.

    I have one question: during the sampling process, why are we sampling z from the standard normal distribution (with temperature)? Shouldn't we sample from the learned p(z)? Is it because p(z) depends on the data, so we cannot sample from it? (In the implementation, if I'm understanding it correctly, p(z) has four components; three of them depend on both the data and the model, while the last one depends only on the model.)
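
    For reference, temperature sampling as it appears in train.py-style code (a sketch; z_shapes stands for the per-block latent shapes):

        # Draw each latent from N(0, I) and scale by the temperature; a lower
        # temp concentrates samples near the mode of the base distribution.
        z_sample = [torch.randn(n_sample, *shape) * temp for shape in z_shapes]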

    Thanks.

    opened by icbcbicc 2
  • Why is torch_rand_noise used in the input?

    Hello, and thank you for the great code.

    A question came up while I was using it, so I am leaving it here as an issue.

    Why do you add torch_rand noise to the input?

    I am asking because I could not find an explanation for this in the reference or in the other implementations I looked at.

    Thank you.

    (An answer in either Korean or English is fine. My English is limited, so I only put the key sentence in the title.)

    opened by Meric3 2
  • Act Norm Output issue

    The actnorm function in the paper is

        s * input + b

    but looking at the implementation,

        if self.logdet:
            return self.scale * (input + self.loc), logdet

    it has become

        s * (input + b)

    Is this on purpose, or did I get something wrong in the paper?
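
    For what it's worth, a quick numerical check (names are illustrative) that the two forms are equivalent up to a reparameterization of the bias, with the same scale and therefore the same log-determinant:

        import torch

        s = torch.randn(1, 4, 1, 1).exp()
        loc = torch.randn(1, 4, 1, 1)
        x = torch.randn(2, 4, 8, 8)

        # s * (x + loc) equals s * x + b once b = s * loc, so only the bias
        # parameterization differs; the logdet depends on s alone either way.
        b = s * loc
        print(torch.allclose(s * (x + loc), s * x + b))  # True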

    opened by kelvinaaa2 1
  • Maybe something wrong with the affine parameter in argparse?

    First, thanks for this work. I just generated some samples using this project with the CelebA dataset. However, I got confused when I tried to debug it. As you can see in the screenshot, the default parameter for affine coupling is not correct: the affine parameter stays False even though the default value is True (in the variables view, the watch window, and at the cursor). I think something may be wrong with argparse, which leads to the default False value for the coupling layer.
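
    A plausible cause, sketched below (hypothetical; check how --affine is actually declared in train.py): with argparse's action="store_true", a flag is False unless it is passed explicitly on the command line.

        import argparse

        parser = argparse.ArgumentParser()
        parser.add_argument("--affine", action="store_true")

        print(parser.parse_args([]).affine)            # False
        print(parser.parse_args(["--affine"]).affine)  # True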

    opened by hjumper 1
  • Log determinant has wrong dimensions?

    Issue: I've printed the log_p and logdet returned by the model. It seems that log_p_sum has a size equal to the batch size (which is fine), but logdet contains a single value. Should logdet have the same dimensions as log_p?

    Code for reproducing: add print(log_p.size(), logdet.size()) at line 120 of train.py.

    opened by matejgrcic 1
  • Why with torch.no_grad() when i == 0?

        if i == 0:
            with torch.no_grad():
                log_p, logdet, _ = model.module(
                    image + torch.rand_like(image) / n_bins
                )

                continue

        else:
            log_p, logdet, _ = model(image + torch.rand_like(image) / n_bins)
    
    opened by fido20160817 0
  • Flow not perfectly invertible

    Hi, given an input image tensor x and the Glow model, I tried the following:

        latent = glow(x)[2]
        x_reconstructed = glow.reverse(latent)

    Since it is a normalizing flow, one would expect x_reconstructed to be very similar to x, since the only source of error should be rounding. However, I observe very big differences. Does anybody have an explanation for this?
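
    One thing worth checking (a sketch; the reconstruct flag is assumed from the reverse signature visible in the traceback of the cusolver issue above): the reverse pass may resample the split latents unless it is told to reuse the ones produced by the forward pass.

        # Reuse the forward latents instead of sampling new split variables.
        log_p, logdet, z_outs = glow(x)
        x_reconstructed = glow.reverse(z_outs, reconstruct=True)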

    opened by sidney1505 1
  • Too small value of logP

    I designed a coupling layer by myself. It is a bit sophisticated.

    During training, the values of the loss and logP are not desirable. How does this happen? Maybe it is due to the data itself, or to the structure of the coupling layer. Does anybody know about this?

    opened by fido20160817 0
  • z_list

    Hi, thanks for your nice work! I am new to the Glow model, so I have some basic questions that I could not resolve even with searching.

    The flow model translates the input $x$ into a latent-space code $z$ through a sequence of $h$ transformations. In my understanding, we only need the last output $z$ to reconstruct the input $x$. Why don't we use the learned $z$ instead of a random z_list? I appreciate your answer and hope you have a good day!

    opened by ZhouCX117 9
Owner: Kim Seonghyeon (no side-effects)