PyTorch implementation of Glow

Overview

glow-pytorch

PyTorch implementation of Glow, Generative Flow with Invertible 1x1 Convolutions (https://arxiv.org/abs/1807.03039)

Usage:

python train.py PATH

Since the trainer uses torchvision's ImageFolder, the input directory should be structured like this even when there is only one class. (Currently this implementation does not incorporate a class classification loss.)

PATH/class1
PATH/class2
...

Notes

Sample

I have trained the model on the vanilla CelebA dataset, and it seems to work well. I found that the learning rate (I used 1e-4 without scheduling), a learned prior, the number of bits (in this case, 5), and using a sigmoid instead of an exponential in the affine coupling layer are all beneficial to training the model.
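
A minimal sketch of such a sigmoid-scaled affine coupling (illustrative names; net stands for whatever convolutional stack produces the coupling parameters):

    import torch

    def affine_coupling(x, net):
        # Split channels: one half conditions the transform of the other.
        x_a, x_b = x.chunk(2, dim=1)
        log_s, t = net(x_a).chunk(2, dim=1)
        # sigmoid(log_s + 2) keeps the scale in (0, 1) and starts near 1,
        # which tends to train more stably than exp(log_s).
        s = torch.sigmoid(log_s + 2)
        y_b = (x_b + t) * s
        logdet = torch.log(s).view(x.size(0), -1).sum(1)
        return torch.cat([x_a, y_b], dim=1), logdet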

In my case, the LU-decomposed invertible convolution was much faster than the plain version, so I made the LU-decomposed version the default.
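
For reference, a hedged sketch of the LU-decomposed weight from the Glow paper, W = P L (U + diag(s)), with P a fixed permutation, L unit lower-triangular, and U strictly upper-triangular (names are illustrative, not the repo's):

    import torch

    def lu_weight(w_p, w_l, w_u, w_s, l_mask, u_mask):
        # Reassemble W = P @ L @ (U + diag(s)). The log-determinant of the
        # 1x1 convolution then reduces to height * width * sum(log|s|),
        # with no general matrix determinant needed.
        l = w_l * l_mask + torch.eye(w_l.size(0))
        u = w_u * u_mask + torch.diag(w_s)
        return w_p @ l @ u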

Progression of samples

Samples generated during training, drawn once every 100 iterations.

Comments
  • Question on how to reconstruct images

    Hi!

    Thanks a lot for this repo, it's really great! I am wondering how to reconstruct images, since the current reverse method seems to take a list of images with different sizes.

    I tried writing a method:

        def reverse_data(self, z):
            # Walk the blocks in reverse order, feeding each reversed
            # block's output into the next one.
            for i, block in enumerate(self.blocks[::-1]):
                if i == 0:
                    eps = torch.randn_like(z)
                    out = block.reverse(z, eps)
                else:
                    eps = torch.zeros_like(out)
                    out = block.reverse(out, eps)
            return out

    But I am unsure whether this is correct. I would appreciate any help!

    Thank you!
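
    For reference, a hedged sketch of how reconstruction appears to work with the repo's own reverse path (inferred from the traceback in the cusolver issue below, where reverse takes a z_list and a reconstruct flag, and from glow(x)[2] returning the latent list in another issue):

        # Forward pass returns the full list of per-block latents; feeding
        # that list back with reconstruct=True should invert the flow.
        log_p, logdet, z_outs = glow(x)
        x_rec = glow.reverse(z_outs, reconstruct=True)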

    opened by MrHuff 6
  • Bits per dimension

    Do you know how to map your loss values to bits-per-dimension results (see Table 2 in the paper)? I'm having a hard time coming up with a formula for the correspondence. A Reddit post mentions subtracting math.log(128) to account for scaling, but that still doesn't seem right.

    I looked at the original TensorFlow implementation but couldn't figure it out. Would you mind letting me know what you think about it? Also, do you know how close your implementation is to the original code in terms of bits per dimension? Thank you.
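
    For reference, a hedged sketch of the standard nats-to-bits-per-dimension conversion (names are illustrative; n_bins is 2 ** n_bits from the dequantization, and the exact offsets in this repo's calc_loss may differ):

        import math

        def bits_per_dim(log_p, logdet, image_size, n_bits=5, n_channels=3):
            # log_p + logdet is the log-likelihood (in nats) of the image
            # scaled to [0, 1]; subtracting n_pixel * log(n_bins) undoes the
            # scaling, and dividing by n_pixel * ln(2) converts to bits/dim.
            n_bins = 2 ** n_bits
            n_pixel = image_size * image_size * n_channels
            nll_nats = -(log_p + logdet) + n_pixel * math.log(n_bins)
            return nll_nats / (n_pixel * math.log(2))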

    opened by tangbinh 5
  • Data Parallel

    Thank you for your code. It looks like you tried to use nn.DataParallel but didn't quite include it. Can you tell me about your experience with it?

    For some reason, the loss kept increasing when I used nn.DataParallel with 2 GPUs, regardless of batch size. To make it run with your code, I changed your calc_loss a little by expanding logdet to have the same size as log_p. I also tried logdet.mean(), but that didn't work either. I'm not really sure why the logdet values differ between the 2 GPUs, as logdet seems to depend only on shared weights.
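
    A hedged sketch of the workaround described above (assuming logdet comes back as a 0-dim tensor while log_p is per-sample):

        # Broadcast the scalar logdet to one value per sample so that
        # DataParallel's gather step lines it up with log_p.
        if logdet.dim() == 0:
            logdet = logdet.expand(log_p.size(0))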

    opened by tangbinh 5
  • cusolver error: 7

    What I run

    python train.py ../train/ --img_size 128 --batch 8
    

    SPEC

    2 x RTX 2080 Ti. The error is the same with two cards or one.

    CODE

    Latest.

    ENVIRONMENT

    # packages in environment at /opt/conda:
    #
    # Name                    Version                   Build  Channel
    _libgcc_mutex             0.1                        main  
    async-generator           1.10                     pypi_0    pypi
    attrs                     20.3.0                   pypi_0    pypi
    backcall                  0.2.0                      py_0  
    bash-kernel               0.7.2                    pypi_0    pypi
    beautifulsoup4            4.9.3              pyhb0f4dca_0  
    blas                      1.0                         mkl  
    bleach                    3.2.1                    pypi_0    pypi
    bzip2                     1.0.8                h7b6447c_0  
    ca-certificates           2020.10.14                    0  
    certifi                   2020.6.20          pyhd3eb1b0_3  
    cffi                      1.14.0           py38he30daa8_1  
    chardet                   3.0.4                 py38_1003  
    conda                     4.9.2            py38h06a4308_0  
    conda-build               3.20.5                   py38_1  
    conda-package-handling    1.6.1            py38h7b6447c_0  
    cryptography              2.9.2            py38h1ba5d50_0  
    cudatoolkit               11.0.221             h6bb024c_0  
    dataclasses               0.6                      pypi_0    pypi
    decorator                 4.4.2                      py_0  
    defusedxml                0.6.0                    pypi_0    pypi
    dnspython                 2.0.0                    pypi_0    pypi
    entrypoints               0.3                      pypi_0    pypi
    filelock                  3.0.12                     py_0  
    freetype                  2.10.4               h5ab3b9f_0  
    future                    0.18.2                   pypi_0    pypi
    gdown                     3.12.2                   pypi_0    pypi
    glob2                     0.7                        py_0  
    icu                       58.2                 he6710b0_3  
    idna                      2.9                        py_1  
    intel-openmp              2020.2                      254  
    ipykernel                 5.3.4                    pypi_0    pypi
    ipython                   7.19.0                   pypi_0    pypi
    ipython_genutils          0.2.0                    py38_0  
    ipywidgets                7.5.1                    pypi_0    pypi
    jedi                      0.17.2                   py38_0  
    jinja2                    2.11.2                     py_0  
    jpeg                      9b                   h024ee3a_2  
    json5                     0.9.5                    pypi_0    pypi
    jsonschema                3.2.0                    pypi_0    pypi
    jupyter                   1.0.0                    pypi_0    pypi
    jupyter-client            6.1.7                    pypi_0    pypi
    jupyter-console           6.2.0                    pypi_0    pypi
    jupyter-core              4.7.0                    pypi_0    pypi
    jupyterlab                2.2.9                    pypi_0    pypi
    jupyterlab-pygments       0.1.2                    pypi_0    pypi
    jupyterlab-server         1.2.0                    pypi_0    pypi
    lcms2                     2.11                 h396b838_0  
    ld_impl_linux-64          2.33.1               h53a641e_7  
    libarchive                3.4.2                h62408e4_0  
    libedit                   3.1.20181209         hc058e9b_0  
    libffi                    3.3                  he6710b0_1  
    libgcc-ng                 9.1.0                hdf63c60_0  
    libgfortran-ng            7.3.0                hdf63c60_0  
    liblief                   0.10.1               he6710b0_0  
    libpng                    1.6.37               hbc83047_0  
    libstdcxx-ng              9.1.0                hdf63c60_0  
    libtiff                   4.1.0                h2733197_1  
    libuv                     1.40.0               h7b6447c_0  
    libxml2                   2.9.10               hb55368b_3  
    lz4-c                     1.9.2                heb0550a_3  
    markupsafe                1.1.1            py38h7b6447c_0  
    mistune                   0.8.4                    pypi_0    pypi
    mkl                       2020.2                      256  
    mkl-service               2.3.0            py38he904b0f_0  
    mkl_fft                   1.2.0            py38h23d657b_0  
    mkl_random                1.1.1            py38h0573a6f_0  
    nbclient                  0.5.1                    pypi_0    pypi
    nbconvert                 6.0.7                    pypi_0    pypi
    nbformat                  5.0.8                    pypi_0    pypi
    ncurses                   6.2                  he6710b0_1  
    nest-asyncio              1.4.3                    pypi_0    pypi
    ninja                     1.10.1           py38hfd86e86_0  
    notebook                  5.7.5                    pypi_0    pypi
    numpy                     1.19.2           py38h54aff64_0  
    numpy-base                1.19.2           py38hfa32c7d_0  
    olefile                   0.46                       py_0  
    openssl                   1.1.1h               h7b6447c_0  
    packaging                 20.4                     pypi_0    pypi
    pandocfilters             1.4.3                    pypi_0    pypi
    parso                     0.7.0                      py_0  
    patchelf                  0.12                 he6710b0_0  
    pexpect                   4.8.0                    py38_0  
    pickleshare               0.7.5                 py38_1000  
    pillow                    8.0.0            py38h9a89aac_0  
    pip                       20.0.2                   py38_3  
    pkginfo                   1.6.0                    py38_0  
    prometheus-client         0.9.0                    pypi_0    pypi
    prompt-toolkit            3.0.8                      py_0  
    psutil                    5.7.2            py38h7b6447c_0  
    ptyprocess                0.6.0                    py38_0  
    py-lief                   0.10.1           py38h403a769_0  
    pycosat                   0.6.3            py38h7b6447c_1  
    pycparser                 2.20                       py_0  
    pygments                  2.7.1                      py_0  
    pyopenssl                 19.1.0                   py38_0  
    pyparsing                 2.4.7                    pypi_0    pypi
    pyrsistent                0.17.3                   pypi_0    pypi
    pysocks                   1.7.1                    py38_0  
    python                    3.8.3                hcff3b4d_0  
    python-dateutil           2.8.1                    pypi_0    pypi
    python-etcd               0.4.5                    pypi_0    pypi
    python-libarchive-c       2.9                        py_0  
    pytorch                   1.7.0           py3.8_cuda11.0.221_cudnn8.0.3_0    pytorch
    pytz                      2020.1                     py_0  
    pyyaml                    5.3.1            py38h7b6447c_0  
    pyzmq                     20.0.0                   pypi_0    pypi
    qtconsole                 4.7.7                    pypi_0    pypi
    qtpy                      1.9.0                    pypi_0    pypi
    readline                  8.0                  h7b6447c_0  
    requests                  2.23.0                   py38_0  
    ripgrep                   12.1.1                        0  
    ruamel_yaml               0.15.87          py38h7b6447c_0  
    scipy                     1.5.2            py38h0b6359f_0  
    send2trash                1.5.0                    pypi_0    pypi
    setuptools                46.4.0                   py38_0  
    six                       1.14.0                   py38_0  
    soupsieve                 2.0.1                      py_0  
    sqlite                    3.31.1               h62c20be_1  
    terminado                 0.9.1                    pypi_0    pypi
    testpath                  0.4.4                    pypi_0    pypi
    tk                        8.6.8                hbc83047_0  
    torchelastic              0.2.1                    pypi_0    pypi
    torchvision               0.8.0                py38_cu110    pytorch
    tornado                   5.1.1                    pypi_0    pypi
    tqdm                      4.46.0                     py_0  
    traitlets                 5.0.5                      py_0  
    typing_extensions         3.7.4.3                    py_0  
    urllib3                   1.25.8                   py38_0  
    wcwidth                   0.2.5                      py_0  
    webencodings              0.5.1                    pypi_0    pypi
    wheel                     0.34.2                   py38_0  
    widgetsnbextension        3.5.1                    pypi_0    pypi
    xz                        5.2.5                h7b6447c_0  
    yaml                      0.1.7                had09818_2  
    zlib                      1.2.11               h7b6447c_3  
    zstd                      1.4.5                h9ceee32_0 
    

    ERROR MESSAGE:

    Namespace(affine=False, batch=8, img_size=128, iter=200000, lr=0.0001, n_bits=5, n_block=4, n_flow=32, n_sample=20, no_lu=False, path='../train/', temp=0.7)
    /workspace/glow-pytorch/model.py:102: UserWarning: The given NumPy array is not writeable, and PyTorch does not support non-writeable tensors. This means you can write to the underlying (supposedly non-writeable) NumPy array using the tensor. You may want to copy the array to protect its data or make it writeable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at  /opt/conda/conda-bld/pytorch_1603729096996/work/torch/csrc/utils/tensor_numpy.cpp:141.)
      w_s = torch.from_numpy(w_s)
    Loss: 2.15042; logP: -2.13823; logdet: 4.98781; lr: 0.0001000:   0%| | 1/200000 
    Traceback (most recent call last):
      File "train.py", line 177, in <module>
        train(args, model, optimizer)
      File "train.py", line 148, in train
        model_single.reverse(z_sample).cpu().data,
      File "/workspace/glow-pytorch/model.py", line 367, in reverse
        input = block.reverse(z_list[-1], z_list[-1], reconstruct=reconstruct)
      File "/workspace/glow-pytorch/model.py", line 322, in reverse
        input = flow.reverse(input)
      File "/workspace/glow-pytorch/model.py", line 239, in reverse
        input = self.invconv.reverse(input)
      File "/workspace/glow-pytorch/model.py", line 136, in reverse
        return F.conv2d(output, weight.squeeze().inverse().unsqueeze(2).unsqueeze(3))
    RuntimeError: cusolver error: 7, when calling `cusolverDnCreate(handle)`
    

    EXTRA:

    This bug happens during the reverse calculation when i % 100 == 0. I changed it to i == 1 to reproduce the bug faster.

    Also, changing w_s = torch.from_numpy(w_s) to w_s = torch.from_numpy(w_s.copy()) turns off the warning above, but the error still occurs.
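
    One hypothetical workaround (an assumption, not verified on this setup): compute the matrix inverse on the CPU to sidestep the cuSOLVER handle creation, then move the result back to the GPU:

        # Hypothetical patch for model.py's reverse (line 136 in the
        # traceback): invert the 1x1 weight on CPU, then restore device.
        w_inv = weight.squeeze().cpu().inverse().to(weight.device)
        return F.conv2d(output, w_inv.unsqueeze(2).unsqueeze(3))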

    opened by LyWangPX 3
  • Conditional Gaussian prior parameters produce unnormalized likelihoods

    Hi,

    first off, thank you for your implementation!

    I found that you've deviated from the original OpenAI implementation and produce the prior parameters (mean, logsd) from the intermediate flow splits via an additional convolution: https://github.com/rosinality/glow-pytorch/blob/97081ff115a694cf04aeedbd58447f33d242c879/model.py#L285

    Unfortunately, this little change leads to the Gaussian distribution being unnormalized. If you think about it: say the convolution that produces the prior parameters (mean, logsd) learns to output mean = z_new, logsd = 0. Then the Gaussian prior likelihood for z_new, N(z_new; mean=z_new, sd=1), is always at its maximum, since the query value z_new equals the distribution's mode. So if you integrate this term over all z_new in R (keeping it simple and one-dimensional), you end up with a value greater than 1 (in fact infinity), showing that this way of defining the prior yields an unnormalized distribution.
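
    As a worked one-dimensional version of this argument:

        \int_{\mathbb{R}} \mathcal{N}(z;\, \mu = z,\, \sigma = 1)\, dz
            = \int_{\mathbb{R}} \frac{1}{\sqrt{2\pi}}\, dz = \infty

    so with mean = z_new this "density" integrates to infinity rather than 1.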

    To fix this you need to (as is done in the original OpenAI implementation) remove the conditioning of mean and logsd on the split variables and just learn them from a fixed input. Relevant pieces of the original implementation:

    Prior is created unconditionally https://github.com/openai/glow/blob/master/model.py#L180

    Fixed zero input https://github.com/openai/glow/blob/master/model.py#L109

    If parameters are learned, apply convolution on fixed input https://github.com/openai/glow/blob/master/model.py#L111

    Create prior from mean and logsd https://github.com/openai/glow/blob/master/model.py#L116

    opened by braun-steven 2
  • Something confusing in calculating the loss

    In the paper, when x is continuous, the loss is calculated by feeding $x + U(0, a)$ into the model, but in the code it seems you feed $x + N(0, a)$ into the model?
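
    For reference, the dequantization line from train.py (quoted in a later issue in this thread) uses torch.rand_like, which samples from U(0, 1), so the added noise is uniform:

        # x + U(0, 1/n_bins): torch.rand_like draws uniform noise in [0, 1)
        noisy_image = image + torch.rand_like(image) / n_bins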

    opened by SiyuWang15 2
  • Why are we sampling z from the standard normal distribution instead of the learned p(z)?

    Hi rosinality,

    Your code is clean and easy to read. Thank you for your effort.

    I have one question: during sampling, why do we sample z from the standard normal distribution (with temperature)? Shouldn't we sample from the learned p(z)? Is it because p(z) depends on the data, so we cannot sample from it? (In the implementation, if I'm understanding it correctly, p(z) has four components; three of them depend on both the data and the model, while the last depends only on the model.)

    Thanks.
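
    For context, a hedged sketch of sampling with temperature (z_shapes is a hypothetical list of per-block latent shapes; the defaults below come from the Namespace printout in the cusolver issue above):

        import torch

        n_sample, temp = 20, 0.7
        # Scaling standard-normal samples by temp < 1 narrows the prior,
        # trading sample diversity for fidelity.
        z_sample = [torch.randn(n_sample, *shape) * temp for shape in z_shapes]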

    opened by icbcbicc 2
  • Why is torch_rand_noise used as the input?

    Hello. Thank you for the nice code.

    A question came up while using it, so I'm leaving it as an issue.

    Why do you feed torch_rand into the input?

    I'm asking because I couldn't find an explanation for this in the reference or in the other implementations I looked at.

    Thank you.

    (Answers in either Korean or English are fine. My English is lacking, so I only put the key sentence in the title.)

    opened by Meric3 2
  • Act Norm Output issue

    The ActNorm function in the paper is

        s * input + b

    But looking at the implementation,

        if self.logdet:
            return self.scale * (input + self.loc), logdet

    which means it has become

        s * (input + b)

    Is this done on purpose? Or did I misunderstand the paper?
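
    One way to see that the two forms differ only by a reparameterization of the bias (s * (x + b) = s * x + s * b, so learning loc can absorb the difference):

        import torch

        s, b, x = torch.randn(3), torch.randn(3), torch.randn(3)
        # The implementation's form equals the paper's form with b' = s * b.
        assert torch.allclose(s * (x + b), s * x + s * b)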

    opened by kelvinaaa2 1
  • Maybe something wrong with the affine parameter in argparse?

    question

    First, thanks for this work. I generated some samples using this project with the CelebA dataset. However, I got confused when I tried to debug it. As you can see in the first picture, the default parameter for affine coupling is not correct: the affine parameter stays false even though the default value is true (in the variables window, the watch window, and under the cursor). I think something may be wrong with argparse, which leads to the default false value for the coupling layer.
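
    A plausible cause (an assumption, not verified against the repo): if the flag is declared with action="store_true", its default is False unless the flag is passed on the command line, which matches the Namespace(affine=False, ...) printout in the cusolver issue above:

        import argparse

        parser = argparse.ArgumentParser()
        # With action="store_true" the default is False; --affine must be
        # passed explicitly on the command line to get affine=True.
        parser.add_argument("--affine", action="store_true")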

    opened by hjumper 1
  • Log determinant has wrong dimensions?

    Issue: I've printed the log_p and logdet returned by the model. It seems that log_p_sum has size equal to the batch size (which is fine), but logdet contains a single value. Shouldn't logdet have the same dimensions as log_p?

    Code for reproducing: add the following to train.py at line 120:

        print(log_p.size(), logdet.size())

    opened by matejgrcic 1
  • Why with torch.no_grad() when i == 0?

        if i == 0:
            with torch.no_grad():
                log_p, logdet, _ = model.module(
                    image + torch.rand_like(image) / n_bins
                )

                continue

        else:
            log_p, logdet, _ = model(image + torch.rand_like(image) / n_bins)
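
    A likely reason (hedged, based on how Glow implementations usually behave): the first batch is used only for the data-dependent initialization of the ActNorm layers, so no gradients are needed and the optimizer step is skipped via continue. A sketch of such an initialization, with illustrative names:

        import torch

        def actnorm_init(actnorm, x):
            # Set loc/scale so the first batch's outputs have zero mean and
            # unit variance per channel (data-dependent initialization).
            with torch.no_grad():
                flat = x.permute(1, 0, 2, 3).contiguous().view(x.size(1), -1)
                mean = flat.mean(1).view(1, -1, 1, 1)
                std = flat.std(1).view(1, -1, 1, 1)
                actnorm.loc.data.copy_(-mean)
                actnorm.scale.data.copy_(1 / (std + 1e-6))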
    
    opened by fido20160817 0
  • Flow not perfectly invertible

    Hi, given an input image tensor x and the glow model, I tried the following:

        latent = glow(x)[2]
        x_reconstructed = glow.reverse(latent)

    Since it is a normalizing flow, one would expect x_reconstructed to be very similar to x, since the only source of error should be rounding. However, I observe very large differences. Does anybody have an explanation for this?
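
    A hedged guess based on the reverse() signature visible in the cusolver traceback above: without reconstruct=True the blocks may re-sample the split noise instead of reusing the stored latents, which would break exact invertibility:

        x_reconstructed = glow.reverse(latent, reconstruct=True)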

    opened by sidney1505 1
  • Too small value of logP

    I designed a coupling layer by myself. It is a bit sophisticated.

    During training, the values of the loss and logP are not desirable. How does this happen? Maybe it is due to the data itself or to the structure of the coupling layer. Does anybody know about this?

    opened by fido20160817 0
  • z_list

    Hi, thanks for your nice work! I am new to the Glow model, so I have some basic questions that I couldn't resolve even after searching.

    The flow model translates the input $x$ to the latent-space code $z$ through a sequence of $h$ transformations. In my understanding, we only need the final output $z$ to reconstruct the input $x$. Why don't we use the learned $z$ instead of a random z_list? I appreciate your answer and hope you have a good day!
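
    For context, a hedged note: Glow's multi-scale architecture splits off half of the channels as a latent at every block, so the full latent is a list of tensors rather than a single z, and every element is needed for exact reconstruction. Hypothetical shapes for a 64x64 RGB input with n_block=4:

        # z_outs[0]: (B,  6, 32, 32)   <- split after block 1
        # z_outs[1]: (B, 12, 16, 16)   <- split after block 2
        # z_outs[2]: (B, 24,  8,  8)   <- split after block 3
        # z_outs[3]: (B, 96,  4,  4)   <- final block output (no split)
        log_p, logdet, z_outs = glow(x)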

    opened by ZhouCX117 9
Owner

Kim Seonghyeon (rosinality)