HairCLIP: Design Your Hair by Text and Reference Image

Overview

This repository hosts the official PyTorch implementation of the paper: "HairCLIP: Design Your Hair by Text and Reference Image".

Our single framework supports hairstyle and hair color editing, individually or jointly, and conditional inputs can come from either the image or the text domain.

Tianyi Wei¹, Dongdong Chen², Wenbo Zhou¹, Jing Liao³, Zhentao Tan¹, Lu Yuan², Weiming Zhang¹, Nenghai Yu¹
¹University of Science and Technology of China, ²Microsoft Cloud AI, ³City University of Hong Kong

Abstract

Hair editing is an interesting and challenging problem in computer vision and graphics. Many existing methods require well-drawn sketches or masks as conditional inputs for editing; however, these interactions are neither straightforward nor efficient. In order to free users from the tedious interaction process, this paper proposes a new hair editing interaction mode, which enables manipulating hair attributes individually or jointly based on the texts or reference images provided by users. For this purpose, we encode the image and text conditions in a shared embedding space and propose a unified hair editing framework by leveraging the powerful image-text representation capability of the Contrastive Language-Image Pre-Training (CLIP) model. With the carefully designed network structures and loss functions, our framework can perform high-quality hair editing in a disentangled manner. Extensive experiments demonstrate the superiority of our approach in terms of manipulation accuracy, visual realism of editing results, and irrelevant attribute preservation.
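
As a mental model of the pipeline described above, the sketch below shows how the pieces could fit together: the condition (text or reference image) is encoded with CLIP, a hair mapper predicts an offset in StyleGAN's W+ latent space, and the generator decodes the edited latent. This is a conceptual illustration only, not the authors' implementation; `clip_model`, `hair_mapper`, `stylegan_generator`, and `preprocess` are placeholder names.

    import torch

    # Conceptual sketch (not the official implementation). Text and reference-image
    # conditions are encoded into CLIP's shared embedding space; a trained hair mapper
    # predicts an offset for the W+ latent code (shape [1, 18, 512]), and the StyleGAN
    # generator decodes the edited latent into the edited face image.

    @torch.no_grad()
    def edit_hair(w_plus, condition_embedding, hair_mapper, stylegan_generator):
        """Apply a hair edit to a W+ latent code using a CLIP condition embedding."""
        delta_w = hair_mapper(w_plus, condition_embedding)  # predicted latent offset
        return stylegan_generator(w_plus + delta_w)         # decode the edited latent

    # Text and image conditions land in the same embedding space, e.g. (placeholders):
    #   text_emb  = clip_model.encode_text(clip.tokenize(["red hair"]).cuda())
    #   image_emb = clip_model.encode_image(preprocess(reference_img).unsqueeze(0).cuda())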

Comparison

Comparison to Text-Driven Image Manipulation Methods

Comparison to Hair Transfer Methods

Application

Hair Interpolation

Generalization Ability to Unseen Descriptions

Cross-Modal Conditional Inputs

To Do

  • Release testing code
  • Release pretrained model
  • Release training code

Comments

  • Expect hair color change only, but the hairstyle of some results changes

    I want to change only the hair color on FFHQ data, but the hairstyle of some of the results changes as well. Did I do something wrong? The following is my command:

    python scripts/inference.py \
    --exp_dir=./experiment \
    --checkpoint_path=../pretrained_models/hairclip.pt \
    --latents_test_path=./latents.pt \
    --editing_type=color \
    --input_type=text \
    --color_description=red

    00001-0000-red hair

    opened by kasim0226 7
  • Demo Play?

    Hi. 🤗 This is awesome work. 👍 Thanks to all of you, the contributors. 🌹 I am wondering whether you have any plans to make a public demo on Hugging Face Spaces or a similar platform. 🤔

    opened by ZenMoore 3
  • Can I use my own image for testing?

    Hello, can I use my own image for testing? I found that the input is a test_face.pt file (the test dataset?), and I did not find where an input image is read in the code. The only thing that looks like an input image is w (torch.Size([1, 18, 512])), but that is not the shape of a picture.

    opened by liuzhuangyuan 2
  • The generated image is quite different from the reference image

    I tested the model and found that the hairstyle of the generated image is quite different from that of the reference image. Here is my test script. The reference image is selected from the CelebAMask-HQ dataset. Is there a problem with my test process?

    python scripts/inference.py \
    --exp_dir=../outputs/0321/ \
    --checkpoint_path=../pretrained_models/hairclip.pt \
    --latents_test_path=../pretrained_models/test_faces.pt \
    --editing_type=both \
    --input_type=image_image \
    --color_ref_img_test_path=../input/16 \
    --hairstyle_ref_img_test_path=../input/16 \
    --num_of_ref_img 1

    opened by 1273545169 2
  • About pretrained UNet inference

    mask_512 = (torch.unsqueeze(torch.max(labels_predict, 1)[1], 1)==13).float()

    1. Why does hair equal 13, and the background not equal 13?
    2. The UNet inference results have 19 channels; what do they mean?

    (A sketch of how this mask extraction appears to work is included after this comments list.)

    opened by eeric 2
  • About the training details

    Thank you for your great project!

    In the paper, you said, "We train and evaluate our hair mapper on the CelebA-HQ dataset. Since we use e4e [43] as our inversion encoder, we follow its division of the training set and test set." However, I found that e4e used the FFHQ dataset for training and the CelebA-HQ test dataset for evaluation, so I am confused. My question is: how did you split the training and test sets on the CelebA-HQ dataset?

    opened by bb12346 2
  • About the modulation module

    Hi, great work! I have a question about the modulation module of the mapper network. I assume the dimensions of x and e should be 1x1xC. If so, what are the mean and std of x? A channel-wise average? And what are the output dimensions of fr(e) and fb(e)? (One possible reading of this module is sketched after this comments list.)

    Thanks.

    opened by janchen0611 2
  • Error when testing with two images

    Input command:

    E:\Linux\XSpace\papers\HairCLIP\mapper>python scripts/inference.py --exp_dir=E:\Linux\XSpace\papers\HairCLIP\data\exp --checkpoint_path=F:\Dataset\CelebA\Data\hairclip.pt --latents_test_path=F:\Dataset\CelebA\Data\test_faces.pt --editing_type=color --input_type=image --hairstyle_description="hairstyle_list.txt" --color_ref_img_test_path=E:\Linux\XSpace\papers\HairCLIP\data\ref

    An error occurred at x = clip_model.encode_image(masked_generated_renormed) in latent_mappers.py. The error message is as follows:

    *** RuntimeError: The following operation failed in the TorchScript interpreter.
    Traceback of TorchScript, serialized code (most recent call last):
      File "code/torch/multimodal/model/multimodal_transformer/___torch_mangle_9591.py", line 19, in encode_image
        return (_0).forward(input, )
      ...
      File "code/torch/torch/nn/modules/activation/___torch_mangle_9369.py", line 38, in forward
        attn_output_weights = torch.bmm(q2, torch.transpose(k0, 1, 2))
    Traceback of TorchScript, original code (most recent call last):
      /opt/conda/lib/python3.7/site-packages/torch/nn/functional.py(4294): multi_head_attention_forward
      ...
      /root/workspace/multimodal-pytorch/multimodal/model/multimodal_transformer.py(221): visual_forward
    RuntimeError: cublas runtime error : unknown error at C:/cb/pytorch_1000000000000/work/aten/src/THC/THCBlas.cu:225

    (Pdb) img_tensor.shape
    torch.Size([1, 3, 1024, 1024])

    Is the size of the input tensor wrong?

    opened by hello-lx 1
  • Is this normal speed?

    Hello, I want to ask whether the speed of running inference.py for testing is normal. This is my command:

    cd mapper
    python scripts/inference.py \
    --exp_dir=/home/ps/HairCLIP/mapper/path/to/experiment \
    --checkpoint_path=/home/ps/HairCLIP/pretrained_models/hairclip.pt \
    --latents_test_path=/home/ps/HairCLIP/mapper/path/to/test_faces.pt \
    --editing_type=hairstyle \
    --input_type=text \
    --hairstyle_description="/home/ps/HairCLIP/mapper/hairstyle_list.txt"

    opened by sunhaha123 1
  • Hairstyles can only show so much

    First of all, thanks for your excellent work! There are many hairstyles in hairstyle_list.txt, but after trying all of them I found only a few distinct styles in the result images; most results more or less repeat the following images.

    • cornrows cut hairstyle

    • crew cut hairstyle (the dots on the left lens of the glasses in the right image are the mouse cursor)

    The following is my command:

    python scripts/inference.py \
    --exp_dir=../result/test_1/ \
    --checkpoint_path=../pretrained_models/hairclip.pt \
    --latents_test_path=../inference_data/test_1/latent.pt \
    --editing_type=hairstyle \
    --input_type=text \
    --hairstyle_description="hairstyle_list.txt"

    What's the problem? Should I train with my own dataset?

    I list some hairstyles that have the same effect:

      1. The same as cornrows: crown braid hairstyle, dreadlocks hairstyle, finger waves hairstyle, french braid hairstyle, and so on.
      2. The same as crew cut hairstyle: caesar cut hairstyle, dido flip hairstyle, extensions hairstyle, fade hairstyle, fauxhawk hairstyle, frosted tops hairstyle, full crown hairstyle, harvard clip hairstyle, high and tight hairstyle, hime cut hairstyle, hi-top fade hairstyle, and so on.
    opened by ZziTaiLeo 1
  • Hosting the HairCLIP model

    Hi!

    First off, thank you for your work!

    I'm trying to create a Colab notebook to play with your model, but since the weights are hosted on Google Drive, its download limits seem to prevent me from simply downloading them with gdown or wget.

    Could I download the model and move it to another hosting service (e.g., archive.org) to avoid this issue? Of course, I would add references to all the authors and parties involved.

    Again, thanks for your work!

    opened by ouhenio 1
  • F and C

    Hello. I noticed that the network structure diagram in the paper may be drawn incorrectly: F should stand for fine, meaning high-level semantic information, and C should stand for coarse, meaning low-level semantic information.

    opened by 123456klk1 0
  • Question about the dataset split (train.pt and test.pt)

    @wty-ustc Thank you for the amazing work! I tried to split CelebA-HQ using the official list_eval_partition.txt and got 24183/2993/2824 images for the training/validation/test splits. However, I found that the length of train.pt is 24176, so I'm confused about which data you used.

    opened by ssxxx1a 0
  • About video hair editing

    Thank you for your great work! Do you think video hair editing based on HairCLIP is achievable? I made a small attempt, but the hairstyle region is still hard to control, and consistency of the hairstyle across frames is quite difficult to maintain. Can you give me some insights about video hairstyle editing?

    opened by ZziTaiLeo 0
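
Regarding the "About pretrained UNet inference" question above: the quoted line suggests the face parser outputs 19 class channels (the CelebAMask-HQ label set), the per-pixel argmax gives a label map, and class index 13 corresponds to hair. Below is a minimal sketch of that reading; the label ordering and the parser network itself are assumptions inferred from the issue, not confirmed details.

    import torch

    def hair_mask_from_parsing(labels_predict: torch.Tensor) -> torch.Tensor:
        """Extract a binary hair mask from face-parsing logits.

        Assumes labels_predict has shape [B, 19, H, W] (one channel per
        CelebAMask-HQ class) and that class index 13 is 'hair', matching the
        line quoted in the issue above.
        """
        label_map = torch.max(labels_predict, 1)[1]    # [B, H, W] per-pixel class index
        mask = (label_map.unsqueeze(1) == 13).float()  # [B, 1, H, W], 1.0 where hair
        return mask

    # Usage sketch (parser_net is a hypothetical face-parsing network):
    #   labels_predict = parser_net(image_512)
    #   mask_512 = hair_mask_from_parsing(labels_predict)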
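
On the "About the modulation module" question above: the question describes normalizing the mapper feature x by its own mean and std and then scaling and shifting it with two learned functions of the CLIP condition embedding e. The sketch below is one plausible reading with channel-wise statistics; the class name, layer shapes, and exact normalization are assumptions, not the paper's definition.

    import torch
    import torch.nn as nn

    class ConditionModulation(nn.Module):
        """Sketch of a condition-modulated layer: normalize x, then scale and shift
        it with functions of the condition embedding e (one plausible reading of
        the module discussed in the issue above, not the paper's exact form)."""

        def __init__(self, dim: int, cond_dim: int):
            super().__init__()
            self.f_gamma = nn.Linear(cond_dim, dim)  # per-channel scale predicted from e
            self.f_beta = nn.Linear(cond_dim, dim)   # per-channel shift predicted from e

        def forward(self, x: torch.Tensor, e: torch.Tensor) -> torch.Tensor:
            # x: [B, C] mapper feature; e: [B, cond_dim] CLIP condition embedding.
            mu = x.mean(dim=1, keepdim=True)            # channel-wise mean of x
            sigma = x.std(dim=1, keepdim=True) + 1e-8   # channel-wise std of x
            x_norm = (x - mu) / sigma
            return self.f_gamma(e) * x_norm + self.f_beta(e)
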
Owner
Ph.D Student @ University of Science and Technology of China