
HiSD: Image-to-image Translation via Hierarchical Style Disentanglement

Official pytorch implementation of paper "Image-to-image Translation via Hierarchical Style Disentanglement".


HiSD is a state-of-the-art image-to-image translation method, offering both scalability to multiple labels and controllable diversity with impressive disentanglement.

The styles used to manipulate each tag in our method can not only be generated from random noise but also be extracted from reference images!

The styles can also be smoothly interpolated.


All translations are produced by a single unified HiSD model trained end-to-end.

Easy Use (for Both Jupyter Notebook and Python Script)

Download the pretrained checkpoint from Baidu Drive (password: ihxf) or Google Drive, then put it in the root of this repo.

Open "easy_use.ipynb" and you can manipolate the facial attributes by yourself!

If you haven't installed Jupyter, use "easy_use.py".

The script translates "examples/input_0.jpg" to add bangs generated from random noise and glasses extracted from "examples/reference_glasses_0.jpg".
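
Under the hood, both modes compose the same pieces from the paper: an encoder E, per-tag translators T, a decoder G, mappers M (noise to style) and extractors F (image to style). Below is a minimal sketch of the flow, assuming the networks have been loaded from the checkpoint under these hypothetical names; see "easy_use.py" for the actual names and loading code.

import torch

# A sketch only: E (encoder), T (translator), G (decoder), M (mapper) and
# F (extractor) are hypothetical handles to the networks loaded from the
# checkpoint; x and x_reference are preprocessed image tensors.
with torch.no_grad():
    e = E(x)                                         # content code of the input
    z = torch.randn(1, latent_dim)                   # random noise for 'Bangs'
    e = T(e, M(z, tag_bangs, attr_with), tag_bangs)  # latent-guided step
    s = F(x_reference, tag_glasses)                  # style from a reference
    e = T(e, s, tag_glasses)                         # reference-guided step
    x_translated = G(e)                              # decode the final image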

Quick Start

Clone this repo:

git clone https://github.com/imlixinyang/HiSD.git
cd HiSD/

Install the dependencies (Anaconda is recommended):

conda create -n HiSD python=3.6.6
conda activate HiSD
conda install -y pytorch=1.0.1 torchvision=0.2.2  cudatoolkit=10.1 -c pytorch
pip install pillow tqdm tensorboardx pyyaml
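
To verify the environment, a quick sanity check (independent of this repo) should print the versions and CUDA availability:

python -c "import torch, torchvision; print(torch.__version__, torchvision.__version__, torch.cuda.is_available())"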

Download the dataset.

We recommend downloading CelebA-HQ from the CelebAMask-HQ repo. Either way, you should end up with a dataset folder like:

celeba_or_celebahq
 - img_dir
   - img0
   - img1
   - ...
 - train_label.txt

Preprocess the dataset.

In our paper, we use the first 3,000 images as the test set and the remaining 27,000 for training. Carefully check the first few (always two) lines of the label file, whose format differs from the rest.

python proprecessors/celeba-hq.py --img_path $your_image_path --label_path $your_label_path --target_path datasets --start 3002 --end 30002

Then you will get several ".txt" files in "datasets/", each consisting of lines with the absolute path of an image followed by its tag-irrelevant conditions (Age and Gender by default).
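
For instance, a file such as "datasets/Bangs_with.txt" can be read back like this (a minimal sketch, assuming one absolute path followed by whitespace-separated condition values per line):

# Minimal sketch: parse one generated list file.
with open('datasets/Bangs_with.txt') as f:
    for line in f:
        parts = line.split()
        path, conditions = parts[0], [float(v) for v in parts[1:]]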

Almost any custom dataset can be converted into a special case of the HiSD format. We provide a script for custom datasets. You need to organize the folder like:

your_training_set
 - Tag0
   - attribute0
     - img0
     - img1
     - ...
   - attribute1
     - ...
 - Tag1
 - ...

For example, AFHQ (one tag with three attributes; remember to split the training and test sets first):

AFHQ_training
  - Category
    - cat
      - img0
      - img1
      - ...
    - dog
      - ...
    - wild
      - ...

Then run:

python proprecessors/custom.py --imgs $your_training_set --target_path datasets/custom.txt

For other datasets, please code the preprocessor by yourself.
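
As a starting point, a bare-bones preprocessor for the folder layout above might look like this. It is only a sketch (not the repo's script) and writes a single constant tag-irrelevant condition per image:

import os, sys

def preprocess(root, target_dir='datasets'):
    # One output file per (tag, attribute) pair, one absolute path per line.
    os.makedirs(target_dir, exist_ok=True)
    for tag in sorted(os.listdir(root)):
        tag_dir = os.path.join(root, tag)
        if not os.path.isdir(tag_dir):
            continue
        for attr in sorted(os.listdir(tag_dir)):
            attr_dir = os.path.join(tag_dir, attr)
            if not os.path.isdir(attr_dir):
                continue
            with open(os.path.join(target_dir, f'{tag}_{attr}.txt'), 'w') as out:
                for name in sorted(os.listdir(attr_dir)):
                    img = os.path.abspath(os.path.join(attr_dir, name))
                    out.write(f'{img} 0\n')  # '0' is a placeholder condition

if __name__ == '__main__':
    preprocess(sys.argv[1])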

Here are some links for downloading other available datasets:

A dataset in bold means we have tested the generalization of HiSD on it.

Train.

Following "configs/celeba-hq.yaml" to make the config file fit your machine and dataset.

For a single 1080Ti and CelebA-HQ, you can directly run:

python core/train.py --config configs/celeba-hq.yaml --gpus 0
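
Since tensorboardx is among the dependencies, you can presumably watch the loss curves with TensorBoard while training (the exact log path under "outputs/" may differ on your setup):

tensorboard --logdir outputs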

The samples and checkpoints are saved in the "outputs/" dir. For the CelebA-HQ dataset, the samples during the first 200k iterations show the tag 'Glasses' translated to the attribute 'with'.

Test.

Modify the 'steps' dict in the first few lines of 'core/test.py' and run:

python core/test.py --config configs/celeba-hq.yaml --checkpoint $your_checkpoint --input_path $your_input_path --output_path results

$your_input_path can be either an image file or a folder of images. The default 'steps' add bangs and glasses to every image using random latent-guided styles.
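
For orientation, each entry in 'steps' chains one translation; the keys below are only illustrative of that default behavior, not a guaranteed schema (check 'core/test.py' for the real one):

# Illustrative only: two latent-guided steps (bangs, then glasses).
steps = [
    {'type': 'latent-guided', 'tag': 0, 'attribute': 0, 'seed': None},  # Bangs -> 'with'
    {'type': 'latent-guided', 'tag': 1, 'attribute': 0, 'seed': None},  # Eyeglasses -> 'with'
]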

Evaluation metrics.

We use FID for quantitative comparison. For more details, please refer to the paper.
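
The repo does not bundle an FID script; one common choice is the pytorch-fid package, which compares a folder of real images of the target attribute against a folder of translated results:

pip install pytorch-fid
python -m pytorch_fid $real_dir $translated_dir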

License

Licensed under CC BY-NC-SA 4.0 (Attribution-NonCommercial-ShareAlike 4.0 International).

The code is released for academic research use only. For other use, please contact me at [email protected].

Citation

If our paper helps your research, please cite it in your publications:

@misc{li2021imagetoimage,
      title={Image-to-image Translation via Hierarchical Style Disentanglement}, 
      author={Xinyang Li and Shengchuan Zhang and Jie Hu and Liujuan Cao and Xiaopeng Hong and Xudong Mao and Feiyue Huang and Yongjian Wu and Rongrong Ji},
      year={2021},
      eprint={2103.01456},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

I have tried my best to make the code easy to understand and to modify, because I felt very lucky to have started from the clear and readily comprehensible code of MUNIT when I was a beginner.

If you have any problem, please feel free to contact me at [email protected] or raise an issue.


Comments
  • Quick Start-Test issue

    I have completed the first few steps of Quick Start (downloading and preprocessing the dataset). I skipped train.py because of a GPU problem, so I went straight to test.py. Following your suggestion that $your_input_path can be either an image file or a folder of images, I tried:

    python core/test.py --config configs/celeba-hq.yaml --checkpoint configs/checkpoint_256_celeba-hq.pt--input_path CelebAMask-HQ/CelebAMask-HQ/CelebA-HQ-img --output_path result

    which gives: test.py: error: unrecognized arguments: CelebAMask-HQ/CelebAMask-HQ/CelebA-HQ-img

    and:

    python core/test.py --config configs/celeba-hq.yaml --checkpoint configs/checkpoint_256_celeba-hq.pt--input_path CelebAMask-HQ/CelebAMask-HQ/CelebA-HQ-img/0.jpg --output_path result

    which gives: test.py: error: unrecognized arguments: CelebAMask-HQ/CelebAMask-HQ/CelebA-HQ-img/0.jpg

    I don't know if this is because I haven't yet modified the 'steps' dict in the first few lines of 'core/test.py'. If it is, could you tell me how to modify it? As a junior who has only studied deep learning for one or two months, this is a bit difficult for me. Thanks a lot.

    opened by HeX-2000 7
  • A few questions about training tricks

    Sorry to bother you. I'm doing some work based on your model and would like to ask a few questions about training tricks:

    1. Regarding how long to train: I watch the loss curve together with the sample images generated during training, and when the loss stabilizes and the generated images look OK, I consider training done. Do you have any additional advice?

    2. Regarding model collapse, can you give some examples of it in this task?

    3. Do you have any tips for setting the loss-function coefficients?

    Looking forward to your reply~

    opened by zhushuqi2333 6
  • Hello author, a few questions about the discriminator

    class Dis(nn.Module):
        def __init__(self, hyperparameters):
            super().__init__()
            self.tags = hyperparameters['tags']
            channels = hyperparameters['discriminators']['channels']  # [64, 128, 256, 512, 1024, 2048]
            self.conv = nn.Sequential(
                nn.Conv2d(hyperparameters['input_dim'], channels[0], 1, 1, 0),
                *[DownBlock(channels[i], channels[i + 1]) for i in range(len(channels) - 1)],
                nn.AdaptiveAvgPool2d(1),
            )
            self.fcs = nn.ModuleList([nn.Sequential(
                nn.Conv2d(channels[-1] +  # 2048
                          # ALI part which is not shown in the original submission but helps disentangle the extracted style.
                          hyperparameters['style_dim'] +  # 256
                          # Tag-irrelevant part. Sec.3.4
                          self.tags[i]['tag_irrelevant_conditions_dim'],  # 2 2 2
                          # One for translated, one for cycle. Eq.4
                          len(self.tags[i]['attributes'] * 2), 1, 1, 0),  # 4 4 6
            ) for i in range(len(self.tags))])  # i selects which of the three tags to use

        def forward(self, x, s, y, i):
            f = self.conv(x)
            # concatenate along the channel dimension
            fsy = torch.cat([f, tile_like(s, f), tile_like(y, f)], 1)
            return self.fcs[i](fsy).view(f.size(0), 2, -1)

    Hello author, there are a few points about the discriminator I don't quite understand, and I hope you can explain them:

    1. I don't quite understand how the discriminator manages not to change the two irrelevant labels.

    2. In the discriminator's forward, for the final .view(f.size(0), 2, -1): the first dimension is batch_size, but I don't understand what the second dimension is or why it is 2. Is the third dimension the controlled attributes?

    3. For the generator's adversarial loss, why does the real image use the sum of the means of [:, 0] and [:, 1], while the two fake images use the mean of [:, 0] and [:, 1] respectively? The code:

    def calc_gen_loss_real(self, x, s, y, i, j):
        loss = 0
        out = self.forward(x, s, y, i)[:, :, j]  # select the attribute, e.g. [8, 2, 2] -> [8, 2]
        loss += out[:, 0].mean()
        loss += out[:, 1].mean()
        return loss

    def calc_gen_loss_fake_trg(self, x, s, y, i, j):
        out = self.forward(x, s, y, i)[:, :, j]
        loss = -out[:, 0].mean()
        return loss

    def calc_gen_loss_fake_cyc(self, x, s, y, i, j):
        out = self.forward(x, s, y, i)[:, :, j]
        loss = -out[:, 1].mean()
        return loss

    I hope you can resolve my confusion. Thank you!

    opened by zhushuqi2333 6
  • Some questions about the experiments in the paper

    Hello author, thank you very much for your patient replies to my previous questions. I am now building some work on top of your framework and need to run comparison experiments. I saw that your paper uses SDIT for the reference-guided experiments; I read that paper too, and found that it feeds a random vector z through a mapping inside the discriminator and uses the discriminator to constrain the style. I would also like to run reference-guided experiments: should I feed the source image x into the discriminator to obtain the style, and then feed the content the generator needs into the generator to produce the image? I would appreciate your guidance, thanks!

    opened by zhushuqi2333 5
  • A little problem about the batchsize

    Hi, your work is so impressive that I want to reproduce your experiments, but I have a problem with the batch-size config. I use a single GTX 1080Ti GPU as you did, but a batch size of 8 seems too big for it; I can only set it to 4 or 5. Can you suggest a solution? Thanks a lot.

    opened by xiaolongcheng 5
  • Multi-tag task

    tags:
      -
        name: Bangs
        tag_irrelevant_conditions_dim: 2
        attributes:
          -
            name: 'with'
            filename: datasets/Bangs_with.txt
          -
            name: 'without'
            filename: datasets/Bangs_without.txt
      -
        name: Eyeglasses
        tag_irrelevant_conditions_dim: 2
        attributes:
          -
            name: 'with'
            filename: datasets/Eyeglasses_with.txt
          -
            name: 'without'
            filename: datasets/Eyeglasses_without.txt

    If I want to experiment on a multi-tag task, how should I set it up in the celeba-hq.yaml config?

    opened by jackskr666 4
  • The codes may have some typos?

    Hi, thanks for sharing the code for this awesome work. I am wondering whether there is a typo in trainer.py L66:

    loss_gen_adv = self.dis.calc_gen_loss_real(x, s, y, i, j) + \
                   self.dis.calc_gen_loss_fake_trg(x_trg, s_trg.detach(), y, i, j_trg) + \
                   self.dis.calc_gen_loss_fake_cyc(x_cyc, s.detach(), y, i, j)

    When updating the generator, why use the real image for calculating the loss?

    opened by HelenMao 4
  • About celeba-hq.yaml

    Thanks for your great work. I have some questions about the training configuration. There are differences between celeba-hq.yaml and celeba-hq_256.yaml in the configs folder, such as the normalization in different channels. What is the reason for this? And if I want to train the model at a size of 512 or 1024, how should I set things up? I would appreciate your reply.

    opened by Hpjhpjhs 3
  • Release checkpoint for Celeb-HQ dataset

    Thank you for open-sourcing this. The code is very helpful for my understanding of face attribute transfer. You have released the 256-size checkpoint; could you also release the 1024-size checkpoint for learning and use? Thank you very much.

    opened by huang-1030 3
  • About the design of non-translation

    In your paper, you designed non-translation, self-translation, and cycle-translation. In the non-translation path the generator just encodes the input, e = encoder(input), and then decodes it, output = decoder(e). You have clearly achieved great results with HiSD designed like this. As a beginner in GANs, I am just wondering why you don't set the style code to 0 and use a translator in the non-translation path, i.e. e = encoder(input), e_trg = translator(e, 0), output = decoder(e_trg). I have seen other works do this. Would it make any difference?

    opened by xiaolongcheng 2
  • Question about the paper

    Hi, thanks for your beautiful work. I want to know the reasoning behind the design of the m and f in the translator. Is there any reference work, or did you design it by experiment? Also, your paper mentions that "The attention mask in our translator is both spatial-wise and channel-wise." Can you explain that in more detail?

    opened by diaodeyi 2
  • Questions about degraded image quality in generated results

    Hello author, as a beginner I have benefited a lot from your work. I have a few questions:

    1. While reproducing the paper, I found that when translating the two attributes Young and Glasses separately, the following occurred at around 110k iterations (it does not happen with other combinations). What causes this? (two sample images attached)

    2. Following question 1, do different label combinations affect the training result?

    3. Does more training always give better results, or should the best point be found by visually inspecting the generated samples?

    4. In reference-guided translation, can the reference image be constrained by a label?

    Looking forward to your reply, many thanks! Wishing you success in your research.

    opened by irine1210 3
  • Data imbalance

    Dear author, I appreciate your work, but I have some questions about data imbalance. The number of faces with glasses is only 1,377, while the number without glasses is 25,622. I want to know the effects of this data imbalance. Thanks~

    opened by ryanqiutu 1
  • Another size of image

    Thank you for your research! I am trying to produce a new "checkpoint_512_celeba-hq.pt" for 512x512 images, but something still goes wrong in the test.py phase.

    The error message is: RuntimeError: cuDNN error: CUDNN_STATUS_EXECUTION_FAILED

    What can I do now? Do you have the checkpoint_512_celeba-hq.pt file?

    opened by brian411227 5