TransGAN: Two Transformers Can Make One Strong GAN

Overview

Code used for TransGAN: Two Transformers Can Make One Strong GAN.

Main Pipeline

Visual Results

Prepare FID statistic file

mkdir fid_stat

Download the pre-calculated statistics (Google Drive) to ./fid_stat.

Environment

pip install -r requirements.txt

Note: the PyTorch version must be <= 1.3.0.
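
The version constraint above can be checked before running anything else. The following is a minimal sketch, not part of the repository: the helper names `parse_version` and `is_supported` are hypothetical, and the only thing taken from the README is the `<= 1.3.0` bound.

```python
import re


def parse_version(text):
    """Turn a version string such as '1.3.0+cu100' into a tuple like (1, 3, 0)."""
    parts = []
    # Drop any local build suffix after '+', then keep the leading digits of each field.
    for piece in text.split("+")[0].split(".")[:3]:
        match = re.match(r"\d+", piece)
        parts.append(int(match.group()) if match else 0)
    parts += [0] * (3 - len(parts))
    return tuple(parts)


def is_supported(installed, maximum="1.3.0"):
    """Return True when the installed version does not exceed the README's bound."""
    return parse_version(installed) <= parse_version(maximum)


if __name__ == "__main__":
    try:
        import torch  # only needed for the live check
        print(torch.__version__, "supported:", is_supported(torch.__version__))
    except ImportError:
        print("PyTorch is not installed")
```

Running the file prints the installed version and whether it satisfies the bound; the comparison helpers themselves need only the standard library.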

Training

Coming soon

Testing

First, download the checkpoint (Google Drive) to ./pretrained_weight.

# cifar-10
sh exps/cifar10_test.sh

# stl-10
sh exps/stl10_test.sh

Acknowledgement

The FID code and the CIFAR-10 statistics file are from https://github.com/bioinf-jku/TTUR (official). The codebase is from AutoGAN.

Comments
  • About the patchsize in the Generator

    Dear author: Thanks for your work. In your model_search code for CIFAR and for the 256 size, I cannot find where the patch size is used in the Generator, only in the Discriminator. Can you tell me where it is in your code? Thank you very much.

    opened by destructive-observer 16
  • Training fails

    I've tried using the function train in functions.py, and training seems to fail:

      File "functional.py", line 86, in adam
        exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
    RuntimeError: output with shape [768] doesn't match the broadcast shape [3, 48, 1, 768]
    opened by anj1 8
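
For readers hitting the same traceback: the message says that the in-place update on `exp_avg` (shape [768]) would need a result of shape [3, 48, 1, 768] once `grad` is broadcast against it, and an in-place operation cannot change the shape of its output tensor. The sketch below reproduces that shape rule in plain Python; `broadcast_shape` and `check_inplace` are hypothetical helpers, and the thread does not say why the optimizer state and the gradient disagree in shape.

```python
from itertools import zip_longest


def broadcast_shape(a, b):
    """Return the broadcast shape of two shapes, or raise ValueError if incompatible."""
    result = []
    # Align shapes from the trailing dimension; missing leading dimensions count as 1.
    for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
        if x != y and x != 1 and y != 1:
            raise ValueError(f"shapes {a} and {b} cannot be broadcast")
        result.append(max(x, y))
    return tuple(reversed(result))


def check_inplace(out_shape, other_shape):
    """An in-place op is valid only if the broadcast result equals the output's own shape."""
    target = broadcast_shape(out_shape, other_shape)
    return target == tuple(out_shape), target


ok, target = check_inplace((768,), (3, 48, 1, 768))
print(ok, target)  # False (3, 48, 1, 768)
```

Comparing the shape of each parameter with the shape of its gradient and of its optimizer state is one way to locate the offending tensor.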
  • UnboundLocalError: local variable 'gen_net' referenced before assignment

    Excuse me, sorry to disturb you. When I run train_derived.py, the terminal shows "UnboundLocalError: local variable 'gen_net' referenced before assignment". I think there are too many if, elif and else statements. How can I fix it? Thanks.

    opened by Mr-AlanLiu 6
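
This error typically appears when a variable is assigned only inside `if`/`elif` branches and none of them matches, so the later use finds no value. The sketch below shows the pattern and one possible guard. The function names and the model names "small" and "large" are hypothetical and only mirror the shape of the problem, not the repository's code.

```python
def build_generator_buggy(name):
    # gen_net is assigned only when one of the branches matches.
    if name == "small":
        gen_net = "small generator"
    elif name == "large":
        gen_net = "large generator"
    return gen_net  # raises UnboundLocalError when no branch matched


def build_generator_fixed(name):
    if name == "small":
        gen_net = "small generator"
    elif name == "large":
        gen_net = "large generator"
    else:
        # Fail early with a message that names the unexpected value.
        raise ValueError(f"unknown generator name: {name!r}")
    return gen_net


try:
    build_generator_buggy("medium")
except UnboundLocalError as err:
    print("buggy version:", type(err).__name__)

try:
    build_generator_fixed("medium")
except ValueError as err:
    print("fixed version:", err)
```

With the explicit `else`, a mistyped or unsupported model name is reported directly instead of surfacing later as an unbound variable.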
  • Evolution of the generations during training

    Hello, I was wondering if you have samples available anywhere of the GAN's outputs at different stages of training: before training, after N epochs, and so on. That would give an idea of how the model reaches its goal and what to expect while training, so I can tell whether I'm way off track. I have been trying to build something very similar with limited success, and I always have to fall back on some convolutional layers (like putting a Conv2D after every attention layer) to get any relevant results...

    Thanks!

    opened by pabloppp 6
  • Question about GAN training.

    Hi, thanks for the work. Following the training code in functions.py, it seems that you did not freeze D when training G. When dis_optimizer.step() runs, the gradient from G training will also be used to update D's parameters. So I wonder whether I missed something, or whether there is a bug here. Thanks a lot!

    opened by ForeverFancy 5
  • Linear Unflatten layer

    As I understand it, your paper aims to remove convolution layers completely. But in the code, for the linear unflatten layer (which produces the RGB image), I see you use conv2d. Why do you use it? Is there any way to get the RGB image without conv2d?

    opened by nthcode 5
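
One general fact is relevant here: a convolution with kernel size 1 and stride 1 applies the same weight matrix to every pixel independently, so it computes the same function as a linear layer applied to each token followed by a reshape. Whether the repository's layer is such a 1x1 convolution is an assumption, since the thread does not say. The sketch below shows the convolution-free form in plain Python; `linear_per_token`, `unflatten` and all numeric values are hypothetical.

```python
def linear_per_token(tokens, weight, bias):
    """Apply one shared linear map to each token.

    tokens: N x C nested list, weight: 3 x C, bias: length 3. Returns N x 3.
    """
    return [
        [sum(w * x for w, x in zip(row, tok)) + b for row, b in zip(weight, bias)]
        for tok in tokens
    ]


def unflatten(tokens_rgb, height, width):
    """Reshape N = height * width RGB tokens into an H x W x 3 nested list (row-major)."""
    assert len(tokens_rgb) == height * width
    return [tokens_rgb[r * width:(r + 1) * width] for r in range(height)]


tokens = [[1.0, 2.0], [0.0, 1.0], [3.0, 0.0], [1.0, 1.0]]  # 4 tokens, C = 2
weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]              # 3 output channels x C
bias = [0.0, 0.0, 0.5]

image = unflatten(linear_per_token(tokens, weight, bias), 2, 2)
print(image[0][1])  # [0.0, 1.0, 1.5]
```

In a tensor library the same idea would be a linear layer over the channel dimension followed by a reshape to (H, W, 3).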
  • About the role of function “get_attn_mask”.

    Thanks for your work!

    While looking at the code in model/Celeba64_TransGAN.py, I noticed that the function “get_attn_mask” seems to play a role in the training process. Can you point out the specific role of this function?

    Thank you~~

    opened by NNNNAI 5
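
The thread does not record what the function does. One common use of an attention mask in this setting is to restrict each token to a local window of neighbours; treating `get_attn_mask` as such a band mask is an assumption made here for illustration. The sketch builds that kind of mask in plain Python; `band_attention_mask` is a hypothetical name.

```python
def band_attention_mask(num_tokens, window):
    """Return an N x N 0/1 mask where token i may attend to token j iff |i - j| <= window."""
    return [
        [1 if abs(i - j) <= window else 0 for j in range(num_tokens)]
        for i in range(num_tokens)
    ]


for row in band_attention_mask(4, 1):
    print(row)
# [1, 1, 0, 0]
# [1, 1, 1, 0]
# [0, 1, 1, 1]
# [0, 0, 1, 1]
```

Such a mask is usually applied to the attention scores before the softmax, so that positions marked 0 receive no weight.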
  • RuntimeError: The size of tensor a (5) must match the size of tensor b (4097) at non-singleton dimension 1

    Before the first iteration, I get the error above. What should I do?

    Here is my cfg.py:

    import argparse


    def str2bool(v):
        if v.lower() in ('yes', 'true', 't', 'y', '1'):
            return True
        elif v.lower() in ('no', 'false', 'f', 'n', '0'):
            return False
        else:
            raise argparse.ArgumentTypeError('Boolean value expected.')


    def parse_args():
        parser = argparse.ArgumentParser()
        parser.add_argument('--world-size', default=-1, type=int,
                            help='number of nodes for distributed training')
        parser.add_argument('--rank', default=-1, type=int,
                            help='node rank for distributed training')
        parser.add_argument('--loca_rank', default=-1, type=int,
                            help='node rank for distributed training')
        parser.add_argument('--dist-url', default='tcp://224.66.41.62:23456', type=str,
                            help='url used to set up distributed training')
        parser.add_argument('--dist-backend', default='nccl', type=str,
                            help='distributed backend')
        parser.add_argument('--seed', default=12345, type=int,
                            help='seed for initializing training. ')
        parser.add_argument('--gpu', default=0, type=int,
                            help='GPU id to use.')
        parser.add_argument('--multiprocessing-distributed', action='store_true',
                            help='Use multi-processing distributed training to launch '
                                 'N processes per node, which has N GPUs. This is the '
                                 'fastest way to use PyTorch for either single node or '
                                 'multi node data parallel training')
        parser.add_argument('--max_epoch', type=int, default=200,
                            help='number of epochs of training')
        parser.add_argument('--max_iter', type=int, default=10,
                            help='set the max iteration number')
        parser.add_argument('-gen_bs', '--gen_batch_size', type=int, default=4,
                            help='size of the batches')
        parser.add_argument('-dis_bs', '--dis_batch_size', type=int, default=4,
                            help='size of the batches')
        parser.add_argument('--g_lr', type=float, default=0.0002,
                            help='adam: gen learning rate')
        parser.add_argument('--wd', type=float, default=0,
                            help='adamw: gen weight decay')
        parser.add_argument('--d_lr', type=float, default=0.0002,
                            help='adam: disc learning rate')
        parser.add_argument('--ctrl_lr', type=float, default=3.5e-4,
                            help='adam: ctrl learning rate')
        parser.add_argument('--lr_decay', action='store_true',
                            help='learning rate decay or not')
        parser.add_argument('--beta1', type=float, default=0.0,
                            help='adam: decay of first order momentum of gradient')
        parser.add_argument('--beta2', type=float, default=0.9,
                            help='adam: decay of first order momentum of gradient')
        parser.add_argument('--num_workers', type=int, default=0,
                            help='number of cpu threads to use during batch generation')
        parser.add_argument('--latent_dim', type=int, default=128,
                            help='dimensionality of the latent space')
        parser.add_argument('--img_size', type=int, default=256,
                            help='size of each image dimension')
        parser.add_argument('--channels', type=int, default=3,
                            help='number of image channels')
        parser.add_argument('--n_critic', type=int, default=1,
                            help='number of training steps for discriminator per iter')
        parser.add_argument('--val_freq', type=int, default=20,
                            help='interval between each validation')
        parser.add_argument('--print_freq', type=int, default=100,
                            help='interval between each verbose')
        parser.add_argument('--load_path', type=str,
                            help='The reload model path')
        parser.add_argument('--exp_name', type=str, default='Test',
                            help='The name of exp')
        parser.add_argument('--d_spectral_norm', type=str2bool, default=False,
                            help='add spectral_norm on discriminator?')
        parser.add_argument('--g_spectral_norm', type=str2bool, default=False,
                            help='add spectral_norm on generator?')
        parser.add_argument('--dataset', type=str, default='cifar10',
                            help='dataset type')
        parser.add_argument('--data_path', type=str, default='./data',
                            help='The path of data set')
        parser.add_argument('--init_type', type=str, default='normal',
                            choices=['normal', 'orth', 'xavier_uniform', 'false'],
                            help='The init type')
        parser.add_argument('--gf_dim', type=int, default=64,
                            help='The base channel num of gen')
        parser.add_argument('--df_dim', type=int, default=64,
                            help='The base channel num of disc')
        parser.add_argument('--gen_model', type=str, default='ViT_custom_rp',
                            help='path of gen model')
        parser.add_argument('--dis_model', type=str, default='ViT_custom_rp',
                            help='path of dis model')
        parser.add_argument('--controller', type=str, default='controller',
                            help='path of controller')
        parser.add_argument('--eval_batch_size', type=int, default=100)
        parser.add_argument('--num_eval_imgs', type=int, default=50000)
        parser.add_argument('--bottom_width', type=int, default=4,
                            help="the base resolution of the GAN")
        parser.add_argument('--random_seed', type=int, default=12345)

        # search
        parser.add_argument('--shared_epoch', type=int, default=15,
                            help='the number of epoch to train the shared gan at each search iteration')
        parser.add_argument('--grow_step1', type=int, default=25,
                            help='which iteration to grow the image size from 8 to 16')
        parser.add_argument('--grow_step2', type=int, default=55,
                            help='which iteration to grow the image size from 16 to 32')
        parser.add_argument('--max_search_iter', type=int, default=90,
                            help='max search iterations of this algorithm')
        parser.add_argument('--ctrl_step', type=int, default=30,
                            help='number of steps to train the controller at each search iteration')
        parser.add_argument('--ctrl_sample_batch', type=int, default=1,
                            help='sample size of controller of each step')
        parser.add_argument('--hid_size', type=int, default=100,
                            help='the size of hidden vector')
        parser.add_argument('--baseline_decay', type=float, default=0.9,
                            help='baseline decay rate in RL')
        parser.add_argument('--rl_num_eval_img', type=int, default=5000,
                            help='number of images to be sampled in order to get the reward')
        parser.add_argument('--num_candidate', type=int, default=10,
                            help='number of candidate architectures to be sampled')
        parser.add_argument('--topk', type=int, default=5,
                            help='preserve topk models architectures after each stage')
        parser.add_argument('--entropy_coeff', type=float, default=1e-3,
                            help='to encourage the exploration')
        parser.add_argument('--dynamic_reset_threshold', type=float, default=1e-3,
                            help='var threshold')
        parser.add_argument('--dynamic_reset_window', type=int, default=500,
                            help='the window size')
        parser.add_argument('--arch', nargs='+', type=int,
                            help='the vector of a discovered architecture')
        parser.add_argument('--optimizer', type=str, default="adam",
                            help='optimizer')
        parser.add_argument('--loss', type=str, default="hinge",
                            help='loss function')
        parser.add_argument('--n_classes', type=int, default=0,
                            help='classes')
        parser.add_argument('--phi', type=float, default=1,
                            help='wgan-gp phi')
        parser.add_argument('--grow_steps', nargs='+', type=int, default=[50, 100, 150],
                            help='the vector of a discovered architecture')
        parser.add_argument('--D_downsample', type=str, default="avg",
                            help='downsampling type')
        parser.add_argument('--fade_in', type=float, default=1,
                            help='fade in step')
        parser.add_argument('--d_depth', type=int, default=7,
                            help='Discriminator Depth')
        parser.add_argument('--g_depth', type=str, default="5,4,2",
                            help='Generator Depth')
        parser.add_argument('--g_norm', type=str, default="ln",
                            help='Generator Normalization')
        parser.add_argument('--d_norm', type=str, default="ln",
                            help='Discriminator Normalization')
        parser.add_argument('--g_act', type=str, default="gelu",
                            help='Generator activation Layer')
        parser.add_argument('--d_act', type=str, default="gelu",
                            help='Discriminator activation layer')
        parser.add_argument('--patch_size', type=int, default=4,
                            help='Discriminator Depth')
        parser.add_argument('--fid_stat', type=str, default="./fid_stat/fid_camera.npz",
                            help='Discriminator Depth')
        parser.add_argument('--diff_aug', type=str, default="None",
                            help='differentiable augmentation type')
        parser.add_argument('--accumulated_times', type=int, default=1,
                            help='gradient accumulation')
        parser.add_argument('--g_accumulated_times', type=int, default=1,
                            help='gradient accumulation')
        parser.add_argument('--num_landmarks', type=int, default=64,
                            help='number of landmarks')
        parser.add_argument('--d_heads', type=int, default=4,
                            help='number of heads')
        parser.add_argument('--dropout', type=float, default=0.,
                            help='dropout ratio')
        parser.add_argument('--ema', type=float, default=0.995,
                            help='ema')
        parser.add_argument('--ema_warmup', type=float, default=0.,
                            help='ema warm up')
        parser.add_argument('--ema_kimg', type=int, default=500,
                            help='ema thousand images')
        parser.add_argument('--latent_norm', action='store_true',
                            help='latent vector normalization')
        parser.add_argument('--ministd', action='store_true',
                            help='mini batch std')
        parser.add_argument('--g_mlp', type=int, default=4,
                            help='generator mlp ratio')
        parser.add_argument('--d_mlp', type=int, default=4,
                            help='discriminator mlp ratio')
        parser.add_argument('--g_window_size', type=int, default=8,
                            help='generator mlp ratio')
        parser.add_argument('--d_window_size', type=int, default=8,
                            help='discriminator mlp ratio')
        parser.add_argument('--show', action='store_true',
                            help='show')

        opt = parser.parse_args()

        return opt
    
    opened by sanersbug 4
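
A note on the numbers in this error: with `--img_size 256` and `--patch_size 4` as in the posted cfg.py, a patch grid has 64 x 64 = 4096 entries, and one extra class token gives 4097, the second size in the message. The other size, 5, is what a 2 x 2 grid plus one token would give. The sketch below does that arithmetic. `num_tokens` is a hypothetical helper, and the presence of a class token and the origin of the 2 x 2 grid are assumptions that the thread does not confirm.

```python
def num_tokens(img_size, patch_size, cls_token=True):
    """Number of tokens for a square image split into non-overlapping square patches."""
    if img_size % patch_size != 0:
        raise ValueError("img_size must be divisible by patch_size")
    per_side = img_size // patch_size
    return per_side * per_side + (1 if cls_token else 0)


print(num_tokens(256, 4))  # 4097, matches "tensor b (4097)"
print(num_tokens(8, 4))    # 5, one way to arrive at "tensor a (5)"
```

If the arithmetic matches, the likely place to look is a tensor (for example a positional embedding) that was sized for a different image or patch size than the input actually uses.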
  • model mismatch

    Hi, thank you for your great work. I ran into a problem when testing the model you provide at https://drive.google.com/file/d/1Td9baoNua6jNtVvsnJczW1u4QWa1w9sl/view?usp=sharing. It seems that the network of the trained model is different from the one in model_search.

    opened by yallien 4
  • Generated Images Have Some Blocking Artifact

    Because TransGAN generates images patch-wise, I found some blocking artifacts in your generated images. I assume the authors are already aware of this. Are there any tricks to eliminate these artifacts?

    opened by wsxtyrdd 4
  • About the Patch Splitting Image.

    Can you explain more about using Conv2d to split the image into patches in the Discriminator? The convolution changes the image values, and it does not split the image into the cells of a grid as the Vision Transformer does. I have also seen other implementations use Unfold to split the image.

            self.fRGB_1 = nn.Conv2d(3, embed_dim//4, kernel_size=6, stride=4, padding=1)
            self.fRGB_2 = nn.Conv2d(3, embed_dim//4, kernel_size=10, stride=8, padding=1)
            self.fRGB_3 = nn.Conv2d(3, embed_dim//2, kernel_size=18, stride=16, padding=1)
    
    opened by GajuuzZ 3
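
The three layers quoted in this issue can be checked with the standard convolution output-size formula, floor((size + 2 * padding - kernel) / stride) + 1. The sketch below applies it to the quoted kernel, stride and padding values. The 32-pixel input is an assumption chosen for illustration, and `conv_out_size` is a hypothetical helper.

```python
def conv_out_size(size, kernel, stride, padding):
    """Output length along one axis for a convolution with dilation 1."""
    return (size + 2 * padding - kernel) // stride + 1


# (kernel, stride) pairs from fRGB_1, fRGB_2 and fRGB_3; padding is 1 in all three.
for kernel, stride in [(6, 4), (10, 8), (18, 16)]:
    grid = conv_out_size(32, kernel, stride, 1)
    overlap = kernel - stride
    print(f"kernel={kernel} stride={stride}: {grid}x{grid} tokens, overlap={overlap} px")
# kernel=6 stride=4: 8x8 tokens, overlap=2 px
# kernel=10 stride=8: 4x4 tokens, overlap=2 px
# kernel=18 stride=16: 2x2 tokens, overlap=2 px
```

Each kernel is 2 pixels wider than its stride, so neighbouring windows overlap, unlike an Unfold-style split where kernel and stride are equal. The resulting grids still have the sizes a stride-4, stride-8 and stride-16 patch split would give.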
  • about test

    I have now experimented with a framework in which the generator is TransGAN and the discriminator is AutoGAN, but it does not seem to converge. After 320 epochs, the FID is 130. What tricks did you use in your experiments? Thank you.

    opened by maoshen5 9
  • Could you give us the CelebA pretrained model, please?

    Could you provide a 256x256 CelebA pretrained model? Thank you so much. By the way, how long did it take you to train a 256x256 CelebA model? How many GPUs did you use, and which model?

    opened by zzr525 1
  • train

    UnboundLocalError: local variable 'gen_net' referenced before assignment. Is there a conflict between the global variable and the local variable?

    opened by chelseaalex 1
  • test error

    I trained this model and downloaded cifar_checkpoint, but:

    => calculate fid score
      0%| | 0/3125 [00:02<?, ?it/s]
    Traceback (most recent call last):
      File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1377, in _do_call
        return fn(*args)
      File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1360, in _run_fn
        return self._call_tf_sessionrun(options, feed_dict, fetch_list,
      File "/root/miniconda3/lib/python3.8/site-packages/tensorflow/python/client/session.py", line 1453, in _call_tf_sessionrun
        return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
    tensorflow.python.framework.errors_impl.UnimplementedError: 2 root error(s) found.
      (0) UNIMPLEMENTED: DNN library is not found.
         [[{{node FID_Inception_Net/conv/Conv2D}}]]
         [[FID_Inception_Net/pool_3/_3]]
      (1) UNIMPLEMENTED: DNN library is not found.
         [[{{node FID_Inception_Net/conv/Conv2D}}]]
    0 successful operations.
    0 derived errors ignored.

    I tried reducing the batch size, but it still does not work.

    opened by maoshen5 1
  • One error

    RuntimeError: CUDA out of memory. Tried to allocate 1024.00 MiB (GPU 0; 7.94 GiB total capacity; 7.03 GiB already allocated; 144.81 MiB free; 7.12 GiB reserved in total by PyTorch)

    I hit this error after running "python exps/cifar_train" in the terminal; it appears right after "path:logs/cifar_train_2022_03_22_19_29_36 0%| |0/1563 [00:00<?, it/s]". I know it means CUDA is out of memory, but I am only running this one program, and the images have not even been loaded yet. Has the author or anyone else run into this, and how did you deal with it?

    opened by wudiduojimone 7
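
A rough way to see why a transformer can exhaust memory on the very first step: a single self-attention map stores batch x heads x tokens x tokens values, which grows quadratically with the number of tokens. The sketch below estimates that size. `attention_map_mib` is a hypothetical helper, and the batch size, head count and token count are illustrative values chosen for the example, not settings confirmed in the thread.

```python
def attention_map_mib(batch, heads, tokens, bytes_per_value=4):
    """Size in MiB of one float32 attention map of shape (batch, heads, tokens, tokens)."""
    return batch * heads * tokens * tokens * bytes_per_value / 2**20


# Illustrative values only: 1024 tokens corresponds to one token per pixel of a 32x32 image.
print(attention_map_mib(batch=64, heads=4, tokens=1024))  # 1024.0
print(attention_map_mib(batch=16, heads=4, tokens=1024))  # 256.0
```

Because the estimate is linear in the batch size, lowering the generator and discriminator batch sizes is usually the first thing to try on a GPU with about 8 GiB of memory.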
Owner
VITA
Visual Informatics Group @ University of Texas at Austin
This repo holds code for TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation

TransUNet This repo holds code for TransUNet: Transformers Make Strong Encoders for Medical Image Segmentation Usage

null 1.4k Jan 4, 2023
Implementation of 'lightweight' GAN, proposed in ICLR 2021, in Pytorch. High resolution image generations that can be trained within a day or two

512x512 flowers after 12 hours of training, 1 gpu 256x256 flowers after 12 hours of training, 1 gpu Pizza 'Lightweight' GAN Implementation of 'lightwe

Phil Wang 1.5k Jan 2, 2023
This framework implements the data poisoning method found in the paper Adversarial Examples Make Strong Poisons

Adversarial poison generation and evaluation. This framework implements the data poisoning method found in the paper Adversarial Examples Make Strong

null 31 Nov 1, 2022
Code for Two-stage Identifier: "Locate and Label: A Two-stage Identifier for Nested Named Entity Recognition"

Code for Two-stage Identifier: "Locate and Label: A Two-stage Identifier for Nested Named Entity Recognition", accepted at ACL 2021. For details of the model and experiments, please see our paper.

tricktreat 87 Dec 16, 2022
FuseDream: Training-Free Text-to-Image Generation with Improved CLIP+GAN Space Optimization

FuseDream This repo contains code for our paper (paper link): FuseDream: Training-Free Text-to-Image Generation with Improved CLIP+GAN Space Optimizat

XCL 191 Dec 31, 2022
DR-GAN: Automatic Radial Distortion Rectification Using Conditional GAN in Real-Time

DR-GAN: Automatic Radial Distortion Rectification Using Conditional GAN in Real-Time Introduction This is official implementation for DR-GAN (IEEE TCS

Kang Liao 18 Dec 23, 2022
Code for our NeurIPS 2021 paper Mining the Benefits of Two-stage and One-stage HOI Detection

CDN Code for our NeurIPS 2021 paper "Mining the Benefits of Two-stage and One-stage HOI Detection". Contributed by Aixi Zhang*, Yue Liao*, Si Liu, Mia

null 71 Dec 14, 2022
Code for Mining the Benefits of Two-stage and One-stage HOI Detection

Status: Archive (code is provided as-is, no updates expected) PPO-EWMA [Paper] This is code for training agents using PPO-EWMA and PPG-EWMA, introduce

OpenAI 33 Dec 15, 2022
A PyTorch implementation of "From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network" (ICCV2021)

From Two to One: A New Scene Text Recognizer with Visual Language Modeling Network The official code of VisionLAN (ICCV2021). VisionLAN successfully a

null 81 Dec 12, 2022
[Preprint] "Chasing Sparsity in Vision Transformers: An End-to-End Exploration" by Tianlong Chen, Yu Cheng, Zhe Gan, Lu Yuan, Lei Zhang, Zhangyang Wang

Chasing Sparsity in Vision Transformers: An End-to-End Exploration Codes for [Preprint] Chasing Sparsity in Vision Transformers: An End-to-End Explora

VITA 64 Dec 8, 2022
Make a Turtlebot3 follow a figure 8 trajectory and create a robot arm and make it follow a trajectory

HW2 - ME 495 Overview Part 1: Makes the robot move in a figure 8 shape. The robot starts moving when launched on a real turtlebot3 and can be paused a

Devesh Bhura 0 Oct 21, 2022
DIT is a DTLS MitM proxy implemented in Python 3. It can intercept, manipulate and suppress datagrams between two DTLS endpoints and supports psk-based and certificate-based authentication schemes (RSA + ECC).

DIT - DTLS Interception Tool DIT is a MitM proxy tool to intercept DTLS traffic. It can intercept, manipulate and/or suppress DTLS datagrams between t

null 52 Nov 30, 2022
In this project, two programs can help you take full advantage of time on the model training with a remote server

In this project, two programs can help you take full advantage of time on the model training with a remote server, which can push notifications to your phone about the information during model training, like the model indices and unexpected interrupts. Then you can do something in time for your work.

GrayLee 8 Dec 27, 2022
Transfer style api - An API to use with Tranfer Style App, where you can use two image and transfer the style

Transfer Style API It's an API to use with Tranfer Style App, where you can use

Brian Alejandro 1 Feb 13, 2022
Tensors and Dynamic neural networks in Python with strong GPU acceleration

PyTorch is a Python package that provides two high-level features: Tensor computation (like NumPy) with strong GPU acceleration Deep neural networks b

null 61.4k Jan 4, 2023
Tensors and Dynamic neural networks in Python with strong GPU acceleration

PyTorch is a Python package that provides two high-level features: Tensor computation (like NumPy) with strong GPU acceleration Deep neural networks b

null 46.1k Feb 13, 2021
FrankMocap: A Strong and Easy-to-use Single View 3D Hand+Body Pose Estimator

FrankMocap pursues an easy-to-use single view 3D motion capture system developed by Facebook AI Research (FAIR). FrankMocap provides state-of-the-art 3D pose estimation outputs for body, hand, and body+hands in a single system. The core objective of FrankMocap is to democratize the 3D human pose estimation technology, enabling anyone (researchers, engineers, developers, artists, and others) can easily obtain 3D motion capture outputs from videos and images.

Facebook Research 1.9k Jan 7, 2023
This repo is developed for Strong Baseline For Vehicle Re-Identification in Track 2 Ai-City-2021 Challenges

A STRONG BASELINE FOR VEHICLE RE-IDENTIFICATION This paper is accepted to the IEEE Conference on Computer Vision and Pattern Recognition Workshop(CVPR

Cybercore Co. Ltd 78 Dec 29, 2022