Implementation of 'lightweight' GAN, proposed in ICLR 2021, in PyTorch. High-resolution image generation that can be trained within a day or two

Overview

512x512 flowers after 12 hours of training, 1 gpu

256x256 flowers after 12 hours of training, 1 gpu

Pizza

'Lightweight' GAN


Implementation of 'lightweight' GAN, proposed in ICLR 2021, in PyTorch. The main contributions of the paper are a skip-layer excitation in the generator, paired with autoencoding self-supervised learning in the discriminator. Quoting the one-line summary: "converge on single gpu with few hours' training, on 1024 resolution sub-hundred images".
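As a rough illustration of the skip-layer excitation idea, here is a minimal PyTorch sketch (not the exact module used in this repository; the class and argument names chan_low / chan_high are made up for illustration): channel-wise gates are computed from a low-resolution feature map and multiplied onto a higher-resolution feature map, giving a cheap long-range skip connection.

import torch
from torch import nn

class SkipLayerExcitation(nn.Module):
    # sketch: gates derived from a low-res feature map modulate a high-res feature map
    def __init__(self, chan_low, chan_high):
        super().__init__()
        self.to_gate = nn.Sequential(
            nn.AdaptiveAvgPool2d(4),
            nn.Conv2d(chan_low, chan_high, 4),   # 4x4 spatial -> 1x1
            nn.LeakyReLU(0.1),
            nn.Conv2d(chan_high, chan_high, 1),
            nn.Sigmoid(),
        )

    def forward(self, feat_low, feat_high):
        return feat_high * self.to_gate(feat_low)  # gates broadcast over H x W

# e.g. excite a 128x128 feature map using an 8x8 one
sle = SkipLayerExcitation(chan_low=512, chan_high=64)
out = sle(torch.randn(1, 512, 8, 8), torch.randn(1, 64, 128, 128))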

Install

$ pip install lightweight-gan

Use

One command

$ lightweight_gan --data ./path/to/images --image-size 512

The model will be saved to ./models/{name} every 1000 iterations, and samples from the model will be saved to ./results/{name}. name defaults to default.

Training settings

Pretty self-explanatory for deep learning practitioners.

$ lightweight_gan \
    --data ./path/to/images \
    --name {name of run} \
    --batch-size 16 \
    --gradient-accumulate-every 4 \
    --num-train-steps 200000

Augmentation

Augmentation is essential for Lightweight GAN to work effectively in a low-data setting.

By default, the augmentation types are set to translation and cutout, with color omitted. You can include color as well with the following.

$ lightweight_gan --data ./path/to/images --aug-prob 0.25 --aug-types [translation,cutout,color]

Test augmentation

You can test and see how your images will be augmented before they are passed into the neural network (if you use augmentation). Let's see how it works on this image:

Basic usage

To augment your image, pass --aug-test and the path to your image via --data:

lightweight_gan \
    --aug-test \
    --data ./path/to/lena.jpg

This will create the file lena_augs.jpg, which will look something like this:

Options

You can use some options to change the result:

  • --image-size 256 to change the size of the image tiles in the result. Default: 256.
  • --aug-types [color,cutout,translation] to combine several augmentations. Default: [cutout,translation].
  • --batch-size 10 to change the number of images in the result image. Default: 10.
  • --num-image-tiles 5 to change the number of tiles in the result image. Default: 5.

Try this command:

lightweight_gan \
    --aug-test \
    --data ./path/to/lena.jpg \
    --batch-size 16 \
    --num-image-tiles 4 \
    --aug-types [color,translation]

The result will look something like this:

Types of augmentations

This library contains several types of built-in augmentations.
Some of them run by default; others can be enabled via the --aug-types option:

  • Horizontal flip (runs by default, not configurable, applied in the AugWrapper class);
  • color randomly changes brightness, saturation and contrast;
  • cutout creates random black boxes on the image;
  • offset randomly shifts the image along the x- and y-axes, wrapping it around;
    • offset_h shifts only along the x-axis;
    • offset_v shifts only along the y-axis;
  • translation randomly moves the image on a canvas with a black background.

The full set of augmentations is --aug-types [color,cutout,offset,translation].
The general recommendation is to use augmentations suited to your data, and as many as possible; then, after some time of training, disable the ones that are most destructive to the images.
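For intuition, here is a rough, self-contained sketch of what a cutout augmentation does (one random black box per image). It is an illustration only, not the repo's actual augmentation code; random_cutout and frac are made-up names.

import torch

def random_cutout(images, frac=0.3):
    # images: (batch, channels, height, width); zero out one random box per image
    b, _, h, w = images.shape
    ch, cw = int(h * frac), int(w * frac)
    ys = torch.randint(0, h - ch + 1, (b,))
    xs = torch.randint(0, w - cw + 1, (b,))
    out = images.clone()
    for i in range(b):
        y0, x0 = int(ys[i]), int(xs[i])
        out[i, :, y0:y0 + ch, x0:x0 + cw] = 0.
    return out

augmented = random_cutout(torch.rand(4, 3, 256, 256))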

Color

Cutout

Offset

Only x-axis:

Only y-axis:

Translation

Mixed precision

You can turn on automatic mixed precision with one flag, --amp.

You should expect it to be 33% faster and save up to 40% memory.
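For reference, the flag corresponds conceptually to PyTorch's standard automatic mixed precision machinery. Below is a generic, self-contained sketch of that pattern (requires a CUDA device); the model, optimizer and batch are stand-ins, and this is not the repo's actual training loop.

import torch
from torch import nn

model = nn.Linear(128, 1).cuda()                 # stand-in for the discriminator
opt = torch.optim.Adam(model.parameters(), lr=2e-4)
scaler = torch.cuda.amp.GradScaler()

for _ in range(10):
    batch = torch.randn(16, 128, device='cuda')  # stand-in for a batch of images
    opt.zero_grad()
    with torch.cuda.amp.autocast():
        loss = model(batch).mean()               # forward pass runs in half precision where safe
    scaler.scale(loss).backward()                # scale the loss to avoid fp16 gradient underflow
    scaler.step(opt)
    scaler.update()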

Multiple GPUs

Also just one flag to use: --multi-gpus

Generating

Once you have finished training, you can generate samples with one command. You can select which checkpoint number to load from. If --load-from is not specified, it will default to the latest.

$ lightweight_gan \
  --name {name of run} \
  --load-from {checkpoint num} \
  --generate \
  --generate-types {types of result, default: [default,ema]} \
  --num-image-tiles {count of image result}

After running this command, you will get a folder next to the results image folder with the postfix "-generated-{checkpoint num}".

You can also generate interpolations

$ lightweight_gan --name {name of run} --generate-interpolation

Show progress

After creating several model checkpoints, you can generate the training progress as a sequence of images with the following command:

$ lightweight_gan \
  --name {name of run} \
  --show-progress \
  --generate-types {types of result, default: [default,ema]} \
  --num-image-tiles {count of image result}

After running this command you will get a new folder in the results folder, with postfix "-progress". You can convert the images to a video with ffmpeg using the command "ffmpeg -framerate 10 -pattern_type glob -i '*-ema.jpg' out.mp4".

Show progress gif demonstration

Show progress video demonstration

Discriminator output size

The author has kindly let me know that the discriminator output size (5x5 vs 1x1) leads to different results on different datasets (5x5 works better for art than for faces, for example). You can toggle this with a single flag:

# disc output size is by default 1x1
$ lightweight_gan --data ./path/to/art --image-size 512 --disc-output-size 5
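To make the difference concrete, here is a small, self-contained illustration (an editor's sketch, not the repo's code; feat, to_logit_1x1 and to_logit_5x5 are made-up names) of how a final convolution can yield either a single logit per image or a 5x5 grid of per-patch logits.

import torch
from torch import nn

feat = torch.randn(1, 256, 8, 8)        # hypothetical final discriminator feature map
to_logit_1x1 = nn.Conv2d(256, 1, 8)     # kernel covers the whole map -> output (1, 1, 1, 1)
to_logit_5x5 = nn.Conv2d(256, 1, 4)     # smaller kernel slides over the map -> output (1, 1, 5, 5)
print(to_logit_1x1(feat).shape, to_logit_5x5(feat).shape)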

Attention

You can add linear + axial attention to specific resolution layers with the following:

# make sure there are no spaces between the values within the brackets []
$ lightweight_gan --data ./path/to/images --image-size 512 --attn-res-layers [32,64] --aug-prob 0.25

Bonus

You can also train with transparent images

$ lightweight_gan --data ./path/to/images --transparent

Or greyscale

$ lightweight_gan --data ./path/to/images --greyscale

Alternatives

If you want the current state-of-the-art GAN, you can find it at https://github.com/lucidrains/stylegan2-pytorch

Citations

@inproceedings{
    anonymous2021towards,
    title={Towards Faster and Stabilized {\{}GAN{\}} Training for High-fidelity Few-shot Image Synthesis},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=1Fqg133qRaI},
    note={under review}
}
@inproceedings{
    anonymous2021global,
    title={Global Self-Attention Networks},
    author={Anonymous},
    booktitle={Submitted to International Conference on Learning Representations},
    year={2021},
    url={https://openreview.net/forum?id=KiFeuZu24k},
    note={under review}
}
@misc{cao2020global,
    title={Global Context Networks},
    author={Yue Cao and Jiarui Xu and Stephen Lin and Fangyun Wei and Han Hu},
    year={2020},
    eprint={2012.13375},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
@misc{qin2020fcanet,
    title={FcaNet: Frequency Channel Attention Networks},
    author={Zequn Qin and Pengyi Zhang and Fei Wu and Xi Li},
    year={2020},
    eprint={2012.11879},
    archivePrefix={arXiv},
    primaryClass={cs.CV}
}
@misc{sinha2020topk,
    title={Top-k Training of GANs: Improving GAN Performance by Throwing Away Bad Samples},
    author={Samarth Sinha and Zhengli Zhao and Anirudh Goyal and Colin Raffel and Augustus Odena},
    year={2020},
    eprint={2002.06224},
    archivePrefix={arXiv},
    primaryClass={stat.ML}
}

What I cannot create, I do not understand - Richard Feynman

Comments
  • Troubles with global context module in 0.15.0

    Troubles with global context module in 0.15.0

    @lucidrains

    After updating to this version https://github.com/lucidrains/lightweight-gan/releases/tag/0.15.0 I can't continue training my network and had to start it from zero. The previous version was at 117k batches of 4 (468k images, around 66 hours of training) and was pretty good. In the new version 0.15.0, on the same dataset with the same parameters (--image-size 1024 --aug-types [color,offset_h] --aug-prob 1 --amp --batch-size 7), after 77k batches of 7 (539k images, around 49 hours of training) I see some bugs that look like oil puddles. Have you encountered this, or do you know how to avoid it?


    In the previous version, with sle-spatial, I didn't see anything like this.

    opened by Dok11 9
  • What is sle_spatial?

    What is sle_spatial?

    I have seen this function argument mentioned in this issue:

    https://github.com/lucidrains/lightweight-gan/issues/14#issuecomment-733432989

    What is sle_spatial?

    opened by woctezuma 8
  • unable to load save model. please try downgrading the package to the version specified by the saved model

    unable to load save model. please try downgrading the package to the version specified by the saved model

    I have had the following problem since today. How can I solve it?

    continuing from previous epoch - 118 loading from version 0.21.4 unable to load save model. please try downgrading the package to the version specified by the saved model Traceback (most recent call last): File "/opt/conda/bin/lightweight_gan", line 8, in sys.exit(main()) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 193, in main fire.Fire(train_from_folder) File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 141, in Fire component_trace = _Fire(component, args, parsed_flag_args, context, name) File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 466, in _Fire component, remaining_args = _CallAndUpdateTrace( File "/opt/conda/lib/python3.8/site-packages/fire/core.py", line 681, in _CallAndUpdateTrace component = fn(*varargs, **kwargs) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 184, in train_from_folder run_training(0, 1, model_args, data, load_from, new, num_train_steps, name, seed, use_aim, aim_repo, aim_run_hash) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/cli.py", line 59, in run_training model.load(load_from) File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1603, in load raise e File "/opt/conda/lib/python3.8/site-packages/lightweight_gan/lightweight_gan.py", line 1600, in load self.GAN.load_state_dict(load_data['GAN']) File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1497, in load_state_dict raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format( RuntimeError: Error(s) in loading state_dict for LightweightGAN: Missing key(s) in state_dict: "G.layers.0.0.2.1.weight", "G.layers.0.0.2.1.bias", "G.layers.0.0.4.weight", "G.layers.0.0.4.bias", "G.layers.0.0.4.running_mean", "G.layers.0.0.4.running_var", "G.layers.1.0.2.1.weight", "G.layers.1.0.2.1.bias", "G.layers.1.0.4.weight", "G.layers.1.0.4.bias", "G.layers.1.0.4.running_mean", "G.layers.1.0.4.running_var", "G.layers.2.0.2.1.weight", "G.layers.2.0.2.1.bias", "G.layers.2.0.4.weight", "G.layers.2.0.4.bias", "G.layers.2.0.4.running_mean", "G.layers.2.0.4.running_var", "G.layers.3.0.2.1.weight", "G.layers.3.0.2.1.bias", "G.layers.3.0.4.weight", "G.layers.3.0.4.bias", "G.layers.3.0.4.running_mean", "G.layers.3.0.4.running_var", "G.layers.3.2.fn.to_lin_q.weight", "G.layers.3.2.fn.to_lin_kv.net.0.weight", "G.layers.3.2.fn.to_lin_kv.net.1.weight", "G.layers.3.2.fn.to_kv.weight", "G.layers.4.0.2.1.weight", "G.layers.4.0.2.1.bias", "G.layers.4.0.4.weight", "G.layers.4.0.4.bias", "G.layers.4.0.4.running_mean", "G.layers.4.0.4.running_var", "G.layers.5.0.2.1.weight", "G.layers.5.0.2.1.bias", "G.layers.5.0.4.weight", "G.layers.5.0.4.bias", "G.layers.5.0.4.running_mean", "G.layers.5.0.4.running_var", "D.residual_layers.3.1.fn.to_lin_q.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D.residual_layers.3.1.fn.to_kv.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.weight", "GE.layers.0.0.2.1.weight", "GE.layers.0.0.2.1.bias", "GE.layers.0.0.4.weight", "GE.layers.0.0.4.bias", "GE.layers.0.0.4.running_mean", "GE.layers.0.0.4.running_var", 
"GE.layers.1.0.2.1.weight", "GE.layers.1.0.2.1.bias", "GE.layers.1.0.4.weight", "GE.layers.1.0.4.bias", "GE.layers.1.0.4.running_mean", "GE.layers.1.0.4.running_var", "GE.layers.2.0.2.1.weight", "GE.layers.2.0.2.1.bias", "GE.layers.2.0.4.weight", "GE.layers.2.0.4.bias", "GE.layers.2.0.4.running_mean", "GE.layers.2.0.4.running_var", "GE.layers.3.0.2.1.weight", "GE.layers.3.0.2.1.bias", "GE.layers.3.0.4.weight", "GE.layers.3.0.4.bias", "GE.layers.3.0.4.running_mean", "GE.layers.3.0.4.running_var", "GE.layers.3.2.fn.to_lin_q.weight", "GE.layers.3.2.fn.to_lin_kv.net.0.weight", "GE.layers.3.2.fn.to_lin_kv.net.1.weight", "GE.layers.3.2.fn.to_kv.weight", "GE.layers.4.0.2.1.weight", "GE.layers.4.0.2.1.bias", "GE.layers.4.0.4.weight", "GE.layers.4.0.4.bias", "GE.layers.4.0.4.running_mean", "GE.layers.4.0.4.running_var", "GE.layers.5.0.2.1.weight", "GE.layers.5.0.2.1.bias", "GE.layers.5.0.4.weight", "GE.layers.5.0.4.bias", "GE.layers.5.0.4.running_mean", "GE.layers.5.0.4.running_var", "D_aug.D.residual_layers.3.1.fn.to_lin_q.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_lin_kv.net.1.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_q.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_lin_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.weight". Unexpected key(s) in state_dict: "G.layers.0.0.2.weight", "G.layers.0.0.2.bias", "G.layers.0.0.3.bias", "G.layers.0.0.3.running_mean", "G.layers.0.0.3.running_var", "G.layers.0.0.3.num_batches_tracked", "G.layers.1.0.2.weight", "G.layers.1.0.2.bias", "G.layers.1.0.3.bias", "G.layers.1.0.3.running_mean", "G.layers.1.0.3.running_var", "G.layers.1.0.3.num_batches_tracked", "G.layers.2.0.2.weight", "G.layers.2.0.2.bias", "G.layers.2.0.3.bias", "G.layers.2.0.3.running_mean", "G.layers.2.0.3.running_var", "G.layers.2.0.3.num_batches_tracked", "G.layers.3.0.2.weight", "G.layers.3.0.2.bias", "G.layers.3.0.3.bias", "G.layers.3.0.3.running_mean", "G.layers.3.0.3.running_var", "G.layers.3.0.3.num_batches_tracked", "G.layers.3.2.fn.to_kv.net.0.weight", "G.layers.3.2.fn.to_kv.net.1.weight", "G.layers.4.0.2.weight", "G.layers.4.0.2.bias", "G.layers.4.0.3.bias", "G.layers.4.0.3.running_mean", "G.layers.4.0.3.running_var", "G.layers.4.0.3.num_batches_tracked", "G.layers.5.0.2.weight", "G.layers.5.0.2.bias", "G.layers.5.0.3.bias", "G.layers.5.0.3.running_mean", "G.layers.5.0.3.running_var", "G.layers.5.0.3.num_batches_tracked", "D.residual_layers.3.1.fn.to_kv.net.0.weight", "D.residual_layers.3.1.fn.to_kv.net.1.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight", "GE.layers.0.0.2.weight", "GE.layers.0.0.2.bias", "GE.layers.0.0.3.bias", "GE.layers.0.0.3.running_mean", "GE.layers.0.0.3.running_var", "GE.layers.0.0.3.num_batches_tracked", "GE.layers.1.0.2.weight", "GE.layers.1.0.2.bias", "GE.layers.1.0.3.bias", "GE.layers.1.0.3.running_mean", "GE.layers.1.0.3.running_var", "GE.layers.1.0.3.num_batches_tracked", "GE.layers.2.0.2.weight", "GE.layers.2.0.2.bias", "GE.layers.2.0.3.bias", "GE.layers.2.0.3.running_mean", 
"GE.layers.2.0.3.running_var", "GE.layers.2.0.3.num_batches_tracked", "GE.layers.3.0.2.weight", "GE.layers.3.0.2.bias", "GE.layers.3.0.3.bias", "GE.layers.3.0.3.running_mean", "GE.layers.3.0.3.running_var", "GE.layers.3.0.3.num_batches_tracked", "GE.layers.3.2.fn.to_kv.net.0.weight", "GE.layers.3.2.fn.to_kv.net.1.weight", "GE.layers.4.0.2.weight", "GE.layers.4.0.2.bias", "GE.layers.4.0.3.bias", "GE.layers.4.0.3.running_mean", "GE.layers.4.0.3.running_var", "GE.layers.4.0.3.num_batches_tracked", "GE.layers.5.0.2.weight", "GE.layers.5.0.2.bias", "GE.layers.5.0.3.bias", "GE.layers.5.0.3.running_mean", "GE.layers.5.0.3.running_var", "GE.layers.5.0.3.num_batches_tracked", "D_aug.D.residual_layers.3.1.fn.to_kv.net.0.weight", "D_aug.D.residual_layers.3.1.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.1.fn.fn.to_kv.net.1.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.0.weight", "D_aug.D.to_shape_disc_out.3.fn.fn.to_kv.net.1.weight". size mismatch for G.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). size mismatch for G.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for G.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]). size mismatch for GE.layers.0.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.1.0.3.weight: copying a param with shape torch.Size([1024]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.2.0.3.weight: copying a param with shape torch.Size([512]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.0.3.weight: copying a param with shape torch.Size([256]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.3.2.fn.to_out.weight: copying a param with shape torch.Size([256, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([256, 1024, 1, 1]). 
size mismatch for GE.layers.4.0.3.weight: copying a param with shape torch.Size([128]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for GE.layers.5.0.3.weight: copying a param with shape torch.Size([64]) from checkpoint, the shape in current model is torch.Size([1]). size mismatch for D_aug.D.residual_layers.3.1.fn.to_out.weight: copying a param with shape torch.Size([128, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([128, 1024, 1, 1]). size mismatch for D_aug.D.to_shape_disc_out.1.fn.fn.to_out.weight: copying a param with shape torch.Size([64, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([64, 1024, 1, 1]). size mismatch for D_aug.D.to_shape_disc_out.3.fn.fn.to_out.weight: copying a param with shape torch.Size([32, 512, 1, 1]) from checkpoint, the shape in current model is torch.Size([32, 1024, 1, 1]).

    opened by sebastiantrella 7
  • Greyscale image generation

    Greyscale image generation

    Hi,

    thank you for this repo, I've been playing with it a bit and it seems very good! I am trying to generate greyscale images, so I modified the channel accordingly

    init_channel = 4 if transparent else 1

    Unfortunately, this seemed to have no effect, as the generated images are still RGB (even though they converge towards greyscale over time). Even weirder, IMO, is that I can modify the number of channels for the generator and keep the original 3 for the discriminator without any issue.

    I have also changed this part to no effect

    convert_image_fn = partial(convert_image_to, 'RGBA' if transparent else 'L')
    num_channels = 1 if not transparent else 4

    Am I missing something here?

    opened by stefanorosss 7
  • Getting NoneType is not subscriptable when trying to start training.

    Getting NoneType is not subscriptable when trying to start training.

    I've been able to train models before, but after changing my dataset I'm getting this error.

    My trace:

    File "/usr/local/lib/python3.6/dist-packages/lightweight_gan/lightweight_gan.py", line 1356, in load
        name = checkpoints[-1]
    TypeError: 'NoneType' object is not subscriptable

    opened by TomCallan 6
  • Optimal parameters for Google Colab

    Optimal parameters for Google Colab

    Hello,

    First of all, thank you for sharing your code and insights with the rest of us!

    As for your code, I plan to run it for 12 hours on Google Colab, similar to the set-up shown in the README.

    My dataset consists of images at 256x256 resolution, and I have started training with the following command line:

    !lightweight_gan \
     --data {image_dir} \
     --disc-output-size 5 \
     --aug-prob 0.25 \
     --aug-types [translation,cutout,color] \
     --amp \
    

    I have noticed that the expected training time is 112.5 hours for 150k iterations (the default setting), which is consistent with the average time of 2.7 seconds per iteration shown in the log. However, this is ~9 times more than what is shown in the README. So I wonder if I am doing something wrong, and I see 2 solutions.

    First, I could decrease the number of iterations so that it takes 12 hours, by choosing 16k iterations instead of 150k with:

     --num-train-steps 16000 \
    

    Is it what you have done for the results shown in the README?

    Second, I have noticed that I am only using 3.8 GB of GPU memory, so I could increase the batch size, as you mentioned in https://github.com/lucidrains/lightweight-gan/issues/13#issuecomment-732486110. Edit: however, the training time increases with a larger batch size. For instance, I am now using 7.2 GB of GPU memory, and it takes 8.2 seconds per iteration, with the following:

     --batch-size 32 \
     --gradient-accumulate-every 4 \
    
    opened by woctezuma 6
  • Added Experimentation Tracking.

    Added Experimentation Tracking.

    Added Experimentation Tracking using Aim.

    Now you can:

    • Track all the model hyperparameters and architectural choices.
    • Track all types of losses.
    • Filter all the experiments with respect to hyperparameters or the architecture.
    • Group and aggregate w.r.t. all the trackables to dive into granular experimentation assessment.
    • Track the generated images to see how the model improves.

    opened by hnhnarek 5
  • Aim installation error

    Aim installation error

    I'm trying to run the generator after training, to generate fake samples using the following command

    lightweight_gan --generate --load-from 299

    I get this following error:

    Traceback (most recent call last):
      File "C:\anaconda3\lib\runpy.py", line 197, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "C:\anaconda3\lib\runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "C:\anaconda3\Scripts\lightweight_gan.exe\__main__.py", line 7, in <module>
      File "C:\anaconda3\lib\site-packages\lightweight_gan\cli.py", line 195, in main
        fire.Fire(train_from_folder)
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 141, in Fire
        component_trace = _Fire(component, args, parsed_flag_args, context, name)
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 466, in _Fire
        component, remaining_args = _CallAndUpdateTrace(
      File "C:\anaconda3\lib\site-packages\fire\core.py", line 681, in _CallAndUpdateTrace
        component = fn(*varargs, **kwargs)
      File "C:\anaconda3\lib\site-packages\lightweight_gan\cli.py", line 158, in train_from_folder
        model = Trainer(**model_args)
      File "C:\anaconda3\lib\site-packages\lightweight_gan\lightweight_gan.py", line 1057, in __init__
        self.run = self.aim.Run(run_hash=aim_run_hash, repo=aim_repo)
    AttributeError: 'Trainer' object has no attribute 'aim'
    

    and when I try to run pip install aim, I get a dependency error with aimrocks

      ERROR: Command errored out with exit status 1:
       command: 'C:\Anaconda3\envs\aerialweb\python.exe' 'C:\Anaconda3\envs\aerialweb\lib\site-packages\pip' install --ignore-installed --no-user --prefix 'C:\Users\ahmed\AppData\Local\Temp\pip-build-env-b2ysw94t\overlay' --no-warn-script-location --no-binary :none: --only-binary :none: -i https://pypi.org/simple -- setuptools 'cython >= 3.0.0a9' 'aimrocks == 0.2.1'
           cwd: None
      Complete output (12 lines):
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
      Collecting setuptools
        Using cached setuptools-59.6.0-py3-none-any.whl (952 kB)
      Collecting cython>=3.0.0a9
        Using cached Cython-3.0.0a10-py2.py3-none-any.whl (1.1 MB)
      ERROR: Could not find a version that satisfies the requirement aimrocks==0.2.1 (from versions: 0.1.3a14, 0.2.0.dev1, 0.2.0)
      ERROR: No matching distribution found for aimrocks==0.2.1
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -pencv-python (c:\anaconda3\envs\aerialweb\lib\site-packages)
      WARNING: Ignoring invalid distribution -cipy (c:\anaconda3\envs\aerialweb\lib\site-packages)
    

    What is aimrocks and what does it actually do? I am unable to find a matching distribution or even a wheel file to install it manually. Please help

    opened by demiahmed 4
  • Can't find "__main__" module (sorry if noob question)

    Can't find "__main__" module (sorry if noob question)

    Hello, I hope it's not too much of a noob question, I don't have any background in coding.

    After creating the env and installing PyTorch, I ran "python setup.py install" and then "python lightweight_gan --data /source --image-size 512" (I filled a "source" folder with pictures of fishes), but I get the error "can't find '__main__' module". More exactly: C:\Programmes perso\Logiciels\Anaconda\envs\lightweightgan\python.exe: can't find '__main__' module in 'C:\Programmes perso\Logiciels\LightweightGan\lightweight_gan'. I tried to copy and rename some of the other modules (__init__, lightweight_gan...); the code seems to start to run but stops before doing anything. So I guess some file must be missing, or did I do something wrong?

    Thanks a lot for the repo and have a nice day

    opened by SPotaufeux 4
  • Hard cutoff straight lines/boxes of nothing in generated images

    Hard cutoff straight lines/boxes of nothing in generated images

    Hello! Training on Google Colab with

    !lightweight_gan --data my/images/ --name my-name --image-size 256 --transparent --dual-contrast-loss --num-train-steps 250000
    

    I'm at 250k iterations over the course of 5 days at 2s/it, and have gotten strange results with boxes.

    I've circled some examples of this below.

    My training data is 22k images of 256x256 .pngs that do not contain large hard edges or boxes like this. They're video game sprites, with hard edges limited to at most 10x10px.

    Are there any argument tweaks that would decrease the chance of the model learning that transparent boxes are good? Would converting to a white background help?

    Thank you!

    opened by timendez 4
  • Amount of training steps

    Amount of training steps

    If I bring down the number of training steps from 150 000 to 30 000, will the trained model be overall bad? Does it really need the 100 000 or 150 000 training steps?

    opened by MartimQS 4
  • Executing with a trailing \ in the arguments sets the variable new to the truthy value '\\' and deletes all progress

    Executing with a trailing \ in the arguments sets the variable new to the truthy value '\\' and deletes all progress

    A rather frustrating issue:

    calling it with a trailing \ like lightweight_gan --data full_cleaned256/ --name c256 --models_dir models --results_dir results \

    sets the variable new to the truthy value '\' and deletes all progress.

    This might well be an issue with Fire but might be mitigated or fixed here too, I am unsure about that.

    Thanks. Jonas

    opened by deklesen 0
  • Projecting generated Images to Latent Space

    Projecting generated Images to Latent Space

    Is there any way to reverse engineer the generated images into the latent space?

    I am trying to embed fresh RGB images, as well as ones generated by the Generator, into the latent space so I can find their nearest neighbours, much like AI image editing tools do.

    I plan to convert my RGB image into tensor embeddings based on my trained model and tweak the feature vectors.

    How can I achieve this with lightweight-gan?

    opened by demiahmed 0
  • Discriminator Loss converges to 0 while Generator loss pretty high

    Discriminator Loss converges to 0 while Generator loss pretty high

    I am trying to train with a custom image dataset for about 600,000 epochs. At about halfway, my D_loss converges to 0 while my G_loss stays put at 2.5

    My evaluation outputs are slowly starting to fade out to either black or white.

    Is there anything I could do to tweak my model? Either by increasing the threshold for the discriminator or by training the generator only?

    opened by demiahmed 3
  • loss implementation differs from paper

    loss implementation differs from paper

    Hi,

    Thanks for this amazing implementation! I have a question concerning the loss implementation, as it seems to differ from the original equations. The screenshot below shows the GAN loss as presented in the paper:

    paper_losses

    • in red, the discriminator loss (D loss) on the true labels,
    • in green the D loss on labels for fake generated images,
    • and in blue, the generator loss (G loss) on labels for fake images.

    This makes sense to me. Since it is assumed that D outputs values between 0 and 1 (0 = fake, 1 = real):

    • in red, we want D to output 1 for true images → assuming D indeed outputs 1 for true images: -min(0, -1 + D(x)) = 0, which is indeed the minimum achievable,
    • in green, we want D to output 0 (from the discriminator's perspective) for fake images → assuming D indeed outputs 0 for fake images: -min(0, -1 - D(x^)) = 1, which is the minimum achievable if D only outputs values between 0 and 1,
    • in blue, we want D to output 1 (from the generator's perspective) for fake images: the equation follows directly.
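
    (In code, the three hinge terms above translate roughly as follows. This is an illustrative sketch of the paper's formulation, not taken from either codebase; real_out and fake_out stand for the raw discriminator outputs on real and generated images.)

    import torch
    import torch.nn.functional as F

    def d_loss(real_out, fake_out):
        # -min(0, -1 + D(x)) == relu(1 - D(x));  -min(0, -1 - D(x^)) == relu(1 + D(x^))
        return F.relu(1. - real_out).mean() + F.relu(1. + fake_out).mean()

    def g_loss(fake_out):
        # the generator wants the discriminator to score its samples as real
        return -fake_out.mean()

    # example: raw scores for a batch of 8 images
    print(d_loss(torch.randn(8), torch.randn(8)), g_loss(torch.randn(8)))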

    Now, the way the authors implement this in the code provided in the supplementary materials of the paper is as follows (the colors match the ones in the picture above).

    og_code_loss_d_real og_code_loss_d_fake og_code_loss_g

    Except for the strange randomness involved (already explained in https://github.com/lucidrains/lightweight-gan/issues/11), their implementation is a one-to-one match with the paper equations.


    The way it is implemented in this repo, however, is quite different, and I do not understand why.

    lighweight_gan_losses

    Let's start with the discriminator loss :

    • in red, you want D to output small values (negative if allowed), to set this term as small as possible (0 if D can output negative values)
    • in green, you want D to output values as large as possible (larger or equal to 1) to cancel this term out as well

    For the generator loss :

    • in blue, you want the opposite of green, that is for D to output values as small as possible

    This implementation seems meaningful, and yields coherent results (as proven by the examples). It also seems to me that D is not limited to outputting values between 0 and 1, but can output any real value (I might be wrong). I am just wondering why this choice was made: could you perhaps elaborate on why you decided to implement the loss differently from the original paper?

    opened by maximeraafat 1
  • showing results while training ?

    showing results while training ?

    How do I show generator results after every epoch during training?

    this is my current configuration

     lightweight_gan \
      --data "/content/dataset/Dataset/" \
      --num-train-steps 100000 \
      --image-size 128 \
      --name GAN2DBlood5k \
      --batch-size 32 \
      --gradient-accumulate-every 5 \
      --disc-output-size 1 \
      --dual-contrast-loss \
      --attn-res-layers [] \
      --calculate_fid_every 1000 \
      --greyscale \
      --amp
    

    Using --show-progress only works after training. Also, it seems that there are no longer checkpoints per epoch.

    opened by galaelized 2
Releases (1.1.1)
Owner
Phil Wang
Working with Attention. It's all we need.
AOT-GAN for High-Resolution Image Inpainting (codebase for image inpainting)

AOT-GAN for High-Resolution Image Inpainting Arxiv Paper | AOT-GAN: Aggregated Contextual Transformations for High-Resolution Image Inpainting Yanhong

Multimedia Research 214 Jan 3, 2023
A fast poisson image editing implementation that can utilize multi-core CPU or GPU to handle a high-resolution image input.

Poisson Image Editing - A Parallel Implementation Jiayi Weng (jiayiwen), Zixu Chen (zixuc) Poisson Image Editing is a technique that can fuse two imag

Jiayi Weng 110 Dec 27, 2022
This is an official pytorch implementation of Lite-HRNet: A Lightweight High-Resolution Network.

Lite-HRNet: A Lightweight High-Resolution Network Introduction This is an official pytorch implementation of Lite-HRNet: A Lightweight High-Resolution

HRNet 675 Dec 25, 2022
A Fast and Stable GAN for Small and High Resolution Imagesets - pytorch

A Fast and Stable GAN for Small and High Resolution Imagesets - pytorch The official pytorch implementation of the paper "Towards Faster and Stabilize

Bingchen Liu 455 Jan 8, 2023
A collection of pre-trained StyleGAN2 models trained on different datasets at different resolution.

Awesome Pretrained StyleGAN2 A collection of pre-trained StyleGAN2 models trained on different datasets at different resolution. Note the readme is a

Justin 1.1k Dec 24, 2022
FuseDream: Training-Free Text-to-Image Generation with Improved CLIP+GAN Space Optimization

FuseDream This repo contains code for our paper (paper link): FuseDream: Training-Free Text-to-Image Generation with Improved CLIP+GAN Space Optimizat

XCL 191 Dec 31, 2022
TransGAN: Two Transformers Can Make One Strong GAN

[Preprint] "TransGAN: Two Transformers Can Make One Strong GAN", Yifan Jiang, Shiyu Chang, Zhangyang Wang

VITA 1.5k Jan 7, 2023
Lite-HRNet: A Lightweight High-Resolution Network

LiteHRNet Benchmark Based on MMsegmentation Cityscapes FCN resize concat config mIoU last mAcc last eval last mIoU best mAcc best eval bes

null 16 Dec 12, 2022
Boosting Monocular Depth Estimation Models to High-Resolution via Content-Adaptive Multi-Resolution Merging

Boosting Monocular Depth Estimation Models to High-Resolution via Content-Adaptive Multi-Resolution Merging This repository contains an implementation

Computational Photography Lab @ SFU 1.1k Jan 2, 2023
Source code, datasets and trained models for the paper Learning Advanced Mathematical Computations from Examples (ICLR 2021), by François Charton, Amaury Hayat (ENPC-Rutgers) and Guillaume Lample

Maths from examples - Learning advanced mathematical computations from examples This is the source code and data sets relevant to the paper Learning a

Facebook Research 171 Nov 23, 2022
Unofficial pytorch implementation of the paper "Dynamic High-Pass Filtering and Multi-Spectral Attention for Image Super-Resolution"

DFSA Unofficial pytorch implementation of the ICCV 2021 paper "Dynamic High-Pass Filtering and Multi-Spectral Attention for Image Super-Resolution" (p

null 2 Nov 15, 2021
Official PyTorch implementation of "VITON-HD: High-Resolution Virtual Try-On via Misalignment-Aware Normalization" (CVPR 2021)

VITON-HD — Official PyTorch Implementation VITON-HD: High-Resolution Virtual Try-On via Misalignment-Aware Normalization Seunghwan Choi*1, Sunghyun Pa

Seunghwan Choi 250 Jan 6, 2023
Implementation for HFGI: High-Fidelity GAN Inversion for Image Attribute Editing

HFGI: High-Fidelity GAN Inversion for Image Attribute Editing High-Fidelity GAN Inversion for Image Attribute Editing Update: We released the inferenc

Tengfei Wang 371 Dec 30, 2022
Official repository for "Restormer: Efficient Transformer for High-Resolution Image Restoration". SOTA for motion deblurring, image deraining, denoising (Gaussian/real data), and defocus deblurring.

Restormer: Efficient Transformer for High-Resolution Image Restoration Syed Waqas Zamir, Aditya Arora, Salman Khan, Munawar Hayat, Fahad Shahbaz Khan,

Syed Waqas Zamir 906 Dec 30, 2022
Code for Two-stage Identifier: "Locate and Label: A Two-stage Identifier for Nested Named Entity Recognition"

Code for Two-stage Identifier: "Locate and Label: A Two-stage Identifier for Nested Named Entity Recognition", accepted at ACL 2021. For details of the model and experiments, please see our paper.

tricktreat 87 Dec 16, 2022
This repository contains the implementation of Deep Detail Enhancment for Any Garment proposed in Eurographics 2021

Deep-Detail-Enhancement-for-Any-Garment Introduction This repository contains the implementation of Deep Detail Enhancment for Any Garment proposed in

null 40 Dec 13, 2022
This repo contains the implementation of the algorithm proposed in Off-Belief Learning, ICML 2021.

Off-Belief Learning Introduction This repo contains the implementation of the algorithm proposed in Off-Belief Learning, ICML 2021. Environment Setup

Facebook Research 32 Jan 5, 2023
A PyTorch Reimplementation of TecoGAN: Temporally Coherent GAN for Video Super-Resolution

TecoGAN-PyTorch Introduction This is a PyTorch reimplementation of TecoGAN: Temporally Coherent GAN for Video Super-Resolution (VSR). Please refer to

null 165 Dec 17, 2022