Pytorch implementation of AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks

Overview

AttnGAN

Pytorch implementation for reproducing AttnGAN results in the paper AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks by Tao Xu, Pengchuan Zhang, Qiuyuan Huang, Han Zhang, Zhe Gan, Xiaolei Huang, Xiaodong He. (This work was performed when Tao was an intern with Microsoft Research).

Dependencies

python 2.7

Pytorch

In addition, please add the project folder to PYTHONPATH and pip install the following packages (see the one-line command after the list):

  • python-dateutil
  • easydict
  • pandas
  • torchfile
  • nltk
  • scikit-image
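
For example:

    pip install python-dateutil easydict pandas torchfile nltk scikit-image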

Data

  1. Download our preprocessed metadata for birds and coco and save both to data/
  2. Download the birds image data and extract it to data/birds/
  3. Download the coco dataset and extract the images to data/coco/
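
After these steps, data/ should look roughly like this (directory names are assumptions based on the usual archive contents; adjust to what you actually extracted):

    data/
      birds/
        CUB_200_2011/   # extracted bird images
        ...             # preprocessed birds metadata
      coco/
        images/         # extracted coco images
        ...             # preprocessed coco metadata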

Training

  • Pre-train DAMSM models:

    • For bird dataset: python pretrain_DAMSM.py --cfg cfg/DAMSM/bird.yml --gpu 0
    • For coco dataset: python pretrain_DAMSM.py --cfg cfg/DAMSM/coco.yml --gpu 1
  • Train AttnGAN models:

    • For bird dataset: python main.py --cfg cfg/bird_attn2.yml --gpu 2
    • For coco dataset: python main.py --cfg cfg/coco_attn2.yml --gpu 3
  • *.yml files are example configuration files for training/evaluating our models.

Pretrained Model

Sampling

  • Run python main.py --cfg cfg/eval_bird.yml --gpu 1 to generate examples from captions in files listed in "./data/birds/example_filenames.txt". Results are saved to DAMSMencoders/.
  • Change the eval_*.yml files to generate images from other pre-trained models.
  • Input your own sentences in "./data/birds/example_captions.txt" if you want to generate images from customized captions (one per line; see the example after this list).
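
A customized "./data/birds/example_captions.txt" (hypothetical contents, for illustration only) could read:

    the bird has a red head and a short orange beak
    this small bird has a white belly and dark brown wings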

Validation

  • To generate images for all captions in the validation dataset, change B_VALIDATION to True in the eval_*.yml (see the snippet after this list), and then run python main.py --cfg cfg/eval_bird.yml --gpu 1
  • We compute inception score for models trained on birds using StackGAN-inception-model.
  • We compute inception score for models trained on coco using improved-gan/inception_score.
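
For example, in cfg/eval_bird.yml the only change is the flag itself (a sketch; all other keys in the file stay as they are):

    B_VALIDATION: True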

Examples generated by AttnGAN [Blog]

[bird example image] [coco example image]

Creating an API

Evaluation code, embedded into a callable containerized API, is included in the eval/ folder.

Citing AttnGAN

If you find AttnGAN useful in your research, please consider citing:

@inproceedings{Tao18attngan,
  author    = {Tao Xu and Pengchuan Zhang and Qiuyuan Huang and Han Zhang and Zhe Gan and Xiaolei Huang and Xiaodong He},
  title     = {AttnGAN: Fine-Grained Text to Image Generation with Attentional Generative Adversarial Networks},
  booktitle = {{CVPR}},
  year      = {2018}
}


Comments
  • Telemetry key error

    Hi, I'm getting this error when running python main.py --cfg cfg/eval_coco.yml --gpu 1:

    Traceback (most recent call last):
      File "main.py", line 12, in <module>
        enable(os.environ["TELEMETRY"])
      File "/usr/lib/python2.7/UserDict.py", line 40, in __getitem__
        raise KeyError(key)
    KeyError: 'TELEMETRY'
    
    opened by sleebapaul 14
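
    A minimal workaround sketch (not from the thread), assuming the telemetry call is optional: read the variable with os.environ.get so a missing TELEMETRY key is skipped instead of raising KeyError.

    import os

    # Hypothetical guard for the enable(...) call quoted above; `enable` is
    # whatever main.py already imports for telemetry.
    telemetry_key = os.environ.get("TELEMETRY")
    if telemetry_key:
        enable(telemetry_key)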
  • The inception score on coco dataset

    Hi, in my experiment the inception score of the pretrained model on the coco dataset is 16.16. Do you know why the IS is not stable? How can I get the IS to 25.89?

    opened by liushanyuan18 4
  • External dataset

    Hey. First of all, I need to say that the work and the results are absolutely amazing. Thank you for your work and for sharing the code with the community.

    Going through the steps to run your code with your data is quite easy, and I get the same results as yours. But when I tried to figure out how to test your architecture on external data, I ran into the issue of the preprocessed metadata you provide for each dataset.

    So could you please list the steps for feeding external data (i.e., a bunch of images with captions) to your model, including pretraining DAMSM and the embedding vectors? I suppose this would be very useful for extending your research to broader areas.

    opened by ZhukIvan 3
  • Discriminator loss function differs from paper?

    Can someone explain the function of cond_wrong_errD in the discriminator loss? It does not seem to be part of the discriminator loss given in the paper (Eq. 5). Also, it does not make sense to me: why ignore the last entry in the batch?

    cond_wrong_logits = netD.COND_DNET(real_features[:(batch_size - 1)], conditions[1:batch_size])
    cond_wrong_errD = nn.BCELoss()(cond_wrong_logits, fake_labels[1:batch_size])
    
    opened by davidstap 2
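
    One reading of those lines (an interpretation, not an authoritative answer): this is a matching-aware term in the style of GAN-CLS (Reed et al., 2016). Shifting the sentence embeddings by one position against the real image features yields batch_size - 1 mismatched (real image, wrong caption) pairs that the discriminator learns to reject; the last entry is dropped simply because the shift leaves one image without a partner.

    # Annotated sketch of the quoted lines:
    cond_wrong_logits = netD.COND_DNET(
        real_features[:(batch_size - 1)],  # real image features 0 .. B-2
        conditions[1:batch_size])          # sentence embeddings 1 .. B-1 (mismatched)
    # mismatched pairs are pushed toward the fake label
    cond_wrong_errD = nn.BCELoss()(cond_wrong_logits, fake_labels[1:batch_size])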
  • Fixed noise is used for training

    Hi, thanks for sharing your brilliant work. I am a little confused by the code in trainer.py line 227, where a set of fixed noise is sampled before the epoch loop, which means the whole subsequent training process uses that fixed noise. As far as I know about GANs, the noise should not be fixed: if a fixed set of noise vectors is used for training, the model loses generalization ability, meaning that when the input noise vector is randomly sampled, the model cannot generate a good image. When I used your code and trained model to generate images with randomly sampled noise, the results were very bad. Please let me know what I have misunderstood or set incorrectly. Thanks in advance!

    opened by wtliao 1
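
    For comparison, a minimal sketch (with assumed names nz and netG, not the repo's trainer) of the more common pattern, where the noise is resampled every iteration:

    import torch

    nz, batch_size = 100, 16
    for step in range(1000):                 # stand-in for the batch loop
        noise = torch.randn(batch_size, nz)  # fresh z each iteration, not fixed
        # fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)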
  • Problem with AttnGAN training

    Thanks for this excellent work! I have a problem with AttnGAN training, and I am not sure whether I have understood the following code correctly.

    /code/trainer.py
    line 246 ~ 271
    imgs, captions, cap_lens, class_ids, keys = prepare_data(data)
    errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                              sent_emb, real_labels, fake_labels)
    

    I think netsD[i] is the i-th discriminator used in AttnGAN; in the paper there are three discriminators. I also think that imgs corresponds to the captions and has a size equal to the batch size. So how can the index i refer to both netsD and imgs?

    opened by ecfish 1
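
    One way to read this (an interpretation, not from the thread): prepare_data returns imgs as a list with one tensor per resolution rather than one per sample, so each entry is the full batch at the scale its discriminator handles.

    # Assumed shapes for the three branches (sketch):
    #   imgs[0]: (B, 3, 64, 64), imgs[1]: (B, 3, 128, 128), imgs[2]: (B, 3, 256, 256)
    for i in range(len(netsD)):  # one discriminator per resolution
        errD = discriminator_loss(netsD[i], imgs[i], fake_imgs[i],
                                  sent_emb, real_labels, fake_labels)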
  • Issue with training "bird_attnDCGAN2.yml"

    Traceback (most recent call last):
      File "/AttnGAN-master3/code/main.py", line 140, in <module>
        algo.train()
      File "\AttnGAN-master3\code\trainer.py", line 285, in train
        sent_emb, real_labels, fake_labels)
      File "\AttnGAN-master3\code\miscc\losses.py", line 143, in discriminator_loss
        cond_real_logits = netD.COND_DNET(real_features, conditions)
      File "\Anaconda3\envs\torch-env\lib\site-packages\torch\nn\modules\module.py", line 489, in __call__
        result = self.forward(*input, **kwargs)
      File "\AttnGAN-master3\code\model.py", line 567, in forward
        h_c_code = torch.cat((h_code, c_code), 1)
    RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 1. Got 4 and 1 in dimension 2 at c:\a\w\1\s\tmp_conda_3.6_091443\conda\conda-bld\pytorch_1544087948354\work\aten\src\thc\generic/THCTensorMath.cu:83

    Process finished with exit code 1

    Training with the config file "bird_attn2.yml" worked normally, but when I tried to use the DCGAN version, it failed with the errors above.

    opened by weili-git 1
  • upBlock

    def upBlock(in_planes, out_planes):
        block = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='nearest'),
            conv3x3(in_planes, out_planes * 2),
            nn.BatchNorm2d(out_planes * 2),
            GLU())
        return block
    

    I think out_planes shouldn't be multiplied by 2, so that the channels can decrease step by step, because the in_planes argument fed into the function is already twice out_planes.

    opened by lsabrinax 1
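
    A possible explanation (a sketch of the standard GLU definition, assuming AttnGAN's GLU follows it): GLU halves the channel count by splitting its input in two and gating one half with the sigmoid of the other, so the conv must produce out_planes * 2 channels for the block to end at out_planes.

    import torch
    import torch.nn as nn

    class GLU(nn.Module):
        # split channels in half and gate one half with the other
        def forward(self, x):
            nc = x.size(1) // 2
            return x[:, :nc] * torch.sigmoid(x[:, nc:])

    x = torch.randn(1, 2 * 32, 8, 8)  # what conv3x3(in_planes, out_planes * 2) emits
    print(GLU()(x).shape)             # torch.Size([1, 32, 8, 8]) -> out_planes channels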
  • Images for new text

    Where should I place my text if I want to generate new images? I have successfully reproduced your images, but when I put my own captions into example_captions.txt they don't come up. So where should I put new text?

    opened by ASH1998 1
  • Why does netG generate different data for same input?

    In many places, the fake images are generated via:

    fake_imgs, _, _, _ = netG(noise, sent_emb, words_embs, mask)

    The netG is of class G_NET defined in https://github.com/taoxugit/AttnGAN/blob/master/code/model.py#L397.

    When I keep noise, sent_emb, words_embs and mask constant and rerun the generation, I get different fake images. Shouldn't the model produce a constant output for a constant input? Is there any stochastic behaviour in G_NET?

    opened by arunpatro 1
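
    A likely source of the randomness (an interpretation, not an official answer): G_NET contains a conditioning-augmentation stage (CA_NET) that draws a reparameterized Gaussian sample from the sentence embedding on every forward pass, so even with noise held constant the conditioning code differs between calls. Roughly:

    import torch

    # Stand-ins for CA_NET's outputs; in the model, mu and logvar are computed
    # from sent_emb by a small fully connected layer.
    mu, logvar = torch.zeros(1, 100), torch.zeros(1, 100)
    std = (logvar * 0.5).exp()
    c_code = mu + std * torch.randn_like(std)  # fresh epsilon on every call

    Fixing the torch random seed before each call (or bypassing the sampling) should make the outputs repeatable.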
  • global ignore in /data excludes example files

    The global .gitignore in the /data directory currently excludes example files such as /data/birds/example_filenames.txt and /data/birds/example_captions.txt.

    Preserving these files by ignoring more selectively would help new users test this network more easily.

    opened by cyremur 1
  • OpenCL with AMD GPU, instead of CUDA, Nvidia GPU

    I was exploring the codebase and wanted to test some ideas on it, but my machine has an AMD GPU and thus no CUDA support; I would have to use OpenCL instead. I am not sure whether some tweaks would make it run on my machine, or whether rewriting is the only way to achieve this.

    opened by paxF3E 1
  • Resume the training?

    Can someone help me resume training? I don't have a GPU of my own; currently I am training in Colab, which only gives 12 hours, and that is not enough. So can someone help me resume training from the saved "netG" model? Please!

    opened by AnishMachamasi 0
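
    A minimal sketch (hypothetical checkpoint path, and assuming the trainer loads TRAIN.NET_G when it is non-empty, as the eval configs do for sampling): point the training .yml at the last generator checkpoint saved before the Colab session ended.

    TRAIN:
      NET_G: 'models/netG_epoch_120.pth'  # hypothetical checkpoint path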
  • What is the use of class and masks?

    if class_ids is not None:
        scores0.data.masked_fill_(masks, -float('inf'))


    This code is used in some of the generator losses and in the DAMSM loss. Why is the class needed when we want to generate images, not classify them?

    opened by rohit901 0
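
    One interpretation (not an official answer): class_ids are not used for classification; they mark which images in a batch share a category, so that a caption is not penalized for also matching another image of the same class. Those same-class pairs are filled with -inf and thus drop out of the softmax over mismatched pairs. A self-contained sketch:

    import numpy as np
    import torch

    class_ids = np.array([3, 7, 3, 5])        # batch of 4 samples, two share class 3
    masks = []
    for i in range(len(class_ids)):
        mask = (class_ids == class_ids[i]).astype(np.uint8)
        mask[i] = 0                           # keep the true (i, i) pair itself
        masks.append(mask)
    masks = torch.from_numpy(np.stack(masks)).bool()

    scores = torch.randn(4, 4)                # caption-image similarity scores
    scores.masked_fill_(masks, -float('inf')) # same-class negatives are ignored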
  • How do you implement R-precision for the CUB dataset? How many images did you generate?

    Please, someone help me. The validation set for CUB is 2933 images. How many images did you generate to calculate R-precision? 30000? In the code, what size did you choose for R? Is the 30,000 condition considered?

    R_count = 0
    R = np.zeros(2928)
    ...
    if R_count >= 30000:
        sum = np.zeros(8)
        np.random.shuffle(R)
        for i in range(8):
            sum[i] = np.average(R[i * 3000:(i + 1) * 3000 - 1])
        R_mean = np.average(sum) * 100
        R_std = np.std(sum) * 100
        print("R mean:{:.2f} std:{:.2f}".format(R_mean, R_std))
        cont = False

    opened by fm5o1 0