Can We Find Neurons that Cause Unrealistic Images in Deep Generative Networks?
Artifact Detection/Correction - Official PyTorch Implementation
This repo provides the official PyTorch implementation of the following paper:
Can We Find Neurons that Cause Unrealistic Images in Deep Generative Networks?
Hwanil Choi, Wonjoon Chang, Jaesik Choi*
Korea Advanced Institute of Science and Technology (KAIST)
Abstract
Even though Generative Adversarial Networks (GANs) have shown a remarkable ability to generate high-quality images, they do not always guarantee that the generated images will be photorealistic. Sometimes they generate images that contain defective or unnatural objects, which are referred to as 'artifacts'. Why artifacts emerge, and how they can be detected and removed, has not been sufficiently studied. To analyze this, we first hypothesize that rarely activated neurons and frequently activated neurons play different roles in the image generation process. By analyzing the statistics and roles of those neurons, we empirically show that rarely activated neurons are related to failures in generating diverse objects and lead to artifacts. In addition, we propose a correction method, called 'sequential ablation', that repairs the defective parts of generated images without high computational cost or manual effort.
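To make the idea of "rarely activated" versus "frequently activated" neurons concrete, here is a minimal, framework-free sketch. It is not the authors' implementation: it assumes that a channel counts as activated for a sample when its spatially averaged activation exceeds a threshold, and the helpers `activation_frequency`, `rarely_activated`, and `rare_activation_score`, as well as the thresholds `act_threshold` and `threshold`, are hypothetical names chosen for illustration.

```python
from typing import List, Sequence

# Hypothetical layout: features[c] = flattened spatial activations of channel c
# for one generated sample (in practice a slice of a generator feature map).
Features = Sequence[Sequence[float]]


def channel_means(features: Features) -> List[float]:
    """Spatially averaged activation of every channel."""
    return [sum(values) / len(values) for values in features]


def activation_frequency(samples: Sequence[Features], act_threshold: float = 0.0) -> List[float]:
    """Fraction of samples in which each channel's mean activation exceeds act_threshold.

    The 'mean > act_threshold' criterion is an assumption made for this sketch.
    """
    counts = [0] * len(samples[0])
    for features in samples:
        for c, mean in enumerate(channel_means(features)):
            if mean > act_threshold:
                counts[c] += 1
    return [n / len(samples) for n in counts]


def rarely_activated(freqs: Sequence[float], threshold: float = 0.1) -> List[int]:
    """Indices of channels that fire in fewer than `threshold` of the samples."""
    return [c for c, f in enumerate(freqs) if f < threshold]


def rare_activation_score(features: Features, rare: Sequence[int], act_threshold: float = 0.0) -> float:
    """Share of the rare channels that fire for this sample (higher = more suspicious)."""
    if not rare:
        return 0.0
    means = channel_means(features)
    return sum(1 for c in rare if means[c] > act_threshold) / len(rare)
```

The frequencies only need to be computed once over many generated samples and can then be reused to score new images; caching them is presumably the purpose of the saved frequencies loaded through `--freq_path`, though the exact statistic the repository stores may differ.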
https://arxiv.org/abs/1812.04948
Dependencies
- PyTorch 1.4.0
- Python 3.6
- CUDA 10.0.x
- cuDNN 7.6.3
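A quick way to compare an environment with the versions listed above is a small check script such as the one below. It is a sketch, not part of the repository: `parse_version` and `at_least` are hypothetical helpers, and treating the listed versions as minimums (rather than exact pins) is an assumption.

```python
import sys

# Versions taken from the Dependencies list; "at least" semantics is an assumption.
REQUIRED_PYTHON = (3, 6)
REQUIRED_TORCH = "1.4.0"


def parse_version(text: str) -> tuple:
    """Turn '1.4.0' or '1.10.0+cu102' into a tuple of ints, dropping local suffixes."""
    core = text.split("+")[0]
    return tuple(int(part) for part in core.split(".") if part.isdigit())


def at_least(installed: str, required: str) -> bool:
    """True if the installed version is equal to or newer than the required one."""
    return parse_version(installed) >= parse_version(required)


if __name__ == "__main__":
    print("Python OK:", sys.version_info[:2] >= REQUIRED_PYTHON)
    try:
        import torch

        print("PyTorch", torch.__version__, "OK:", at_least(torch.__version__, REQUIRED_TORCH))
        # Both values can be None on CPU-only builds.
        print("CUDA (build):", torch.version.cuda, "cuDNN:", torch.backends.cudnn.version())
    except ImportError:
        print("PyTorch is not installed")
```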
GenForce
Pre-Trained Models (Official)
| Dataset \ Model | PGGAN | StyleGAN2 |
|---|---|---|
| CelebA-HQ (Official) | 1024 x 1024 | X |
| FFHQ (Official) | X | 1024 x 1024 |
| LSUN-Church (Official) | 256 x 256 | 256 x 256 |
| LSUN-Cat (Official) | 256 x 256 | 256 x 256 |
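The table above can be mirrored in a lookup so that a `--model`/`--dataset`/`--resolution` combination is validated before any weights are loaded. This is a hypothetical helper, not part of the repository; the lowercase dataset keys follow the `--dataset` examples (`ffhq`, `cat`, `church`), and `celebahq` is a guessed key for CelebA-HQ.

```python
# (model, dataset) -> resolution of the official weights, copied from the table.
# Pairs marked "X" in the table are simply absent. Dataset keys are assumptions.
PRETRAINED = {
    ("pggan", "celebahq"): 1024,
    ("pggan", "church"): 256,
    ("pggan", "cat"): 256,
    ("stylegan2", "ffhq"): 1024,
    ("stylegan2", "church"): 256,
    ("stylegan2", "cat"): 256,
}


def check_config(model: str, dataset: str, resolution: int) -> None:
    """Raise ValueError if no official weights match the requested configuration."""
    key = (model.lower(), dataset.lower())
    if key not in PRETRAINED:
        raise ValueError(f"no official {model} weights for {dataset}")
    expected = PRETRAINED[key]
    if resolution != expected:
        raise ValueError(
            f"{model}/{dataset} weights are {expected}x{expected}, got {resolution}"
        )
```

For example, `check_config("stylegan2", "ffhq", 1024)` passes silently, while `check_config("pggan", "ffhq", 1024)` raises because that cell is marked "X".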
For the following implementation, download the StyleGAN2 FFHQ weights into the current directory. Otherwise, set the '--weight_path' option to the directory that contains them.
More pre-trained weights are available in genforce-model-zoo.
- Optional: StyleGAN3
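Because the weights are expected either in the current directory or in the directory given by `--weight_path`, a missing checkpoint is easiest to catch up front. The sketch below assumes a hypothetical file-name pattern `<model>_<dataset><resolution>.pth`; `resolve_weight_file` is an illustrative helper, so adjust the pattern to the names of the files you actually downloaded.

```python
from pathlib import Path


def resolve_weight_file(weight_path: str, model: str, dataset: str, resolution: int) -> Path:
    """Return the checkpoint expected under --weight_path, or fail with a clear message.

    The '<model>_<dataset><resolution>.pth' naming is an assumption for this sketch.
    """
    candidate = Path(weight_path).expanduser() / f"{model}_{dataset}{resolution}.pth"
    if not candidate.is_file():
        raise FileNotFoundError(
            f"{candidate} not found: download the weights into this directory "
            "or point --weight_path at the directory that holds them"
        )
    return candidate
```

With the default `--weight_path ./`, the helper looks in the current working directory, matching the instruction above.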
Implementation
- Options
optional arguments:
-h, --help show this help message and exit
--gpu GPU GPU index number
--batch_size BATCH_SIZE
batch size for preprocessing and generation
--sample_size SAMPLE_SIZE
number of samples used to compute neuron statistics
--freq_path FREQ_PATH
path for loading saved neuron activation frequencies
--model MODEL pggan, stylegan2
--dataset DATASET ffhq, cat, church, etc.
--resolution RESOLUTION
dataset resolution
--weight_path WEIGHT_PATH
directory containing the pre-trained weights
--detection DETECTION
run normal/artifact detection
--correction CORRECTION
run the correction task
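The options above map naturally onto `argparse`. The parser below is a sketch that mirrors the help text; it is not the repository's `main.py`, and all default values are illustrative guesses. Since `--detection` and `--correction` are passed as the strings `True`/`False`, it uses a hypothetical `str2bool` converter: with `type=bool`, argparse would turn the string `"False"` into `True`, because any non-empty string is truthy.

```python
import argparse


def str2bool(value: str) -> bool:
    """Parse 'True'/'False'-style strings (type=bool would treat 'False' as True)."""
    lowered = value.lower()
    if lowered in ("true", "t", "yes", "1"):
        return True
    if lowered in ("false", "f", "no", "0"):
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean, got {value!r}")


def build_parser() -> argparse.ArgumentParser:
    # Defaults are assumptions for illustration, not the values used by main.py.
    p = argparse.ArgumentParser(description="Artifact detection/correction (sketch)")
    p.add_argument("--gpu", type=int, default=0, help="GPU index number")
    p.add_argument("--batch_size", type=int, default=30,
                   help="batch size for preprocessing and generation")
    p.add_argument("--sample_size", type=int, default=30000,
                   help="number of samples used to compute neuron statistics")
    p.add_argument("--freq_path", type=str, default="./stats",
                   help="path for loading saved neuron activation frequencies")
    p.add_argument("--model", type=str, default="stylegan2", help="pggan, stylegan2")
    p.add_argument("--dataset", type=str, default="ffhq", help="ffhq, cat, church, etc.")
    p.add_argument("--resolution", type=int, default=1024, help="dataset resolution")
    p.add_argument("--weight_path", type=str, default="./",
                   help="directory containing the pre-trained weights")
    p.add_argument("--detection", type=str2bool, default=False,
                   help="run normal/artifact detection")
    p.add_argument("--correction", type=str2bool, default=False,
                   help="run the correction task")
    return p
```

For example, `build_parser().parse_args(["--detection", "True", "--correction", "False"])` yields `detection=True` and `correction=False`.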
- Usage
python main.py --gpu 0 --batch_size 30 --sample_size 30000 --freq_path ./stats \
--model stylegan2 --dataset ffhq --resolution 1024 --weight_path ./ \
--detection True --correction True
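With `--correction True`, the command applies the paper's 'sequential ablation'. The exact procedure is defined in the paper and the repository code; the following is only a hypothetical, framework-free reading of the idea. Layers run in order, and the channels flagged for each layer (for example, rarely activated ones) are zeroed right after that layer, so later layers only see already-corrected features. `ablate`, `sequential_ablation`, and the `rare` mapping are illustrative names. In PyTorch, the same effect is usually obtained with `register_forward_hook`, whose hook may return a modified output.

```python
from typing import Callable, Dict, List, Sequence

Channels = List[List[float]]  # channels x flattened spatial positions
Layer = Callable[[Channels], Channels]


def ablate(features: Channels, channels: Sequence[int]) -> Channels:
    """Return a copy of `features` with the given channels set to zero."""
    drop = set(channels)
    return [[0.0] * len(ch) if c in drop else list(ch) for c, ch in enumerate(features)]


def sequential_ablation(x: Channels, layers: Sequence[Layer],
                        rare: Dict[int, Sequence[int]]) -> Channels:
    """Run layers in order, ablating rare[i] right after layer i.

    Later layers therefore receive features from which the flagged
    channels of earlier layers have already been removed.
    """
    for i, layer in enumerate(layers):
        x = layer(x)
        if i in rare:
            x = ablate(x, rare[i])
    return x
```

With two layers that double every activation and `rare = {0: [1]}`, channel 1 is zeroed after the first layer and stays zero at the output, while channel 0 is unaffected.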
If you are working on a remote server, enable X11 forwarding (for example, by connecting with ssh -X) so that the results can be displayed.
You can also run our code in Jupyter Notebook, which offers more flexibility. Use the 'notebook.ipynb' file.