Sound-guided Semantic Image Manipulation - Official PyTorch Code (CVPR 2022)

Overview

🔉 Sound-guided Semantic Image Manipulation (CVPR2022)

Official PyTorch Implementation

Teaser image

Sound-guided Semantic Image Manipulation
IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR) 2022

Paper : https://arxiv.org/abs/2112.00007
Project Page: https://kuai-lab.github.io/cvpr2022sound/
Seung Hyun Lee, Wonseok Roh, Wonmin Byeon, Sang Ho Yoon, Chanyoung Kim, Jinkyu Kim*, and Sangpil Kim*

Abstract: The recent success of generative models shows that leveraging a multi-modal embedding space makes it possible to manipulate an image using text information. However, manipulating an image with sources other than text, such as sound, is not easy due to the dynamic characteristics of those sources. In particular, sound can convey vivid emotions and dynamic expressions of the real world. Here, we propose a framework that directly encodes sound into the multi-modal (image-text) embedding space and manipulates an image from that space. Our audio encoder is trained to produce a latent representation from an audio input, which is forced to be aligned with image and text representations in the multi-modal embedding space. We use a direct latent optimization method based on the aligned embeddings for sound-guided image manipulation. We also show that our method can mix different modalities, i.e., text and audio, which enriches the variety of image modifications. Experiments on zero-shot audio classification and semantic-level image classification show that our proposed model outperforms other text- and sound-guided state-of-the-art methods.

💾 Installation

For all the methods described in the paper, it is required to have:

Specific requirements for each method are described in its section. To install CLIP, run the following commands:

conda install --yes -c pytorch pytorch=1.7.1 torchvision cudatoolkit=<CUDA_VERSION>
pip install ftfy regex tqdm gdown
pip install git+https://github.com/openai/CLIP.git

🔨 Method

Method image

1. CLIP-based Contrastive Latent Representation Learning.

Dataset Curation.

We create an audio-text pair dataset from the VGGSound dataset. We also used the AudioSet dataset, following the script below.

  1. Download vggsound.csv from the link.
  2. Run download.py to download the audio files of the VGGSound dataset.
  3. Run curate.py to preprocess the audio files (WAV to mel-spectrogram).
cd soundclip
python3 download.py
python3 curate.py
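The contents of curate.py are not reproduced in this README. The function below is a rough, NumPy-only sketch of the mel-spectrogram post-processing. It mirrors the steps in the StyleGAN3 snippet under Issues:

- convert power to dB relative to the maximum, re-implementing librosa.power_to_db with its defaults amin=1e-10 and top_db=80;
- rescale to [0, 1] with /80 + 1;
- center-crop, or zero-pad on the right, to 864 frames.

The function name normalize_and_fix_length is hypothetical. Whether curate.py uses the same constants is an assumption.

```python
import numpy as np

def normalize_and_fix_length(mel_power: np.ndarray, time_length: int = 864) -> np.ndarray:
    """Hypothetical helper: scale a power mel-spectrogram (n_mels, frames) to [0, 1]
    and center-crop or zero-pad it to `time_length` frames."""
    # Same as librosa.power_to_db(S, ref=np.max) with default amin=1e-10, top_db=80.
    amin, top_db = 1e-10, 80.0
    ref = mel_power.max()
    db = 10.0 * np.log10(np.maximum(amin, mel_power)) - 10.0 * np.log10(max(amin, ref))
    db = np.maximum(db, db.max() - top_db)  # clip to an 80 dB dynamic range
    scaled = db / 80.0 + 1.0                # dB in [-80, 0] -> [0, 1]

    n_mels, frames = scaled.shape
    if frames >= time_length:
        start = (frames - time_length) // 2  # center crop long clips
        return scaled[:, start:start + time_length]
    out = np.zeros((n_mels, time_length), dtype=scaled.dtype)
    out[:, :frames] = scaled  # short clips are zero-padded on the right
    return out
```

In the StyleGAN3 snippet, this array is then resized with cv2.resize before being fed to the audio encoder.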

Training.

python3 train.py
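train.py is not shown here. According to the abstract, the audio encoder is trained so that its output aligns with image and text representations in CLIP's embedding space. A common way to do this is a CLIP-style symmetric InfoNCE loss over a batch of matching (audio, text) pairs.

The sketch below assumes that objective and a temperature of 0.07 (CLIP's initial value); neither is confirmed as what train.py uses. It also omits the self-supervised term referenced in the results table.

```python
import numpy as np

def log_softmax(logits: np.ndarray, axis: int) -> np.ndarray:
    """Numerically stable log-softmax along `axis`."""
    shifted = logits - logits.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))

def symmetric_contrastive_loss(audio_emb, text_emb, temperature=0.07):
    """Assumed CLIP-style InfoNCE: row i of audio_emb and text_emb is a matching pair."""
    a = audio_emb / np.linalg.norm(audio_emb, axis=1, keepdims=True)
    t = text_emb / np.linalg.norm(text_emb, axis=1, keepdims=True)
    logits = a @ t.T / temperature  # (batch, batch) scaled cosine similarities
    idx = np.arange(len(a))
    loss_a2t = -log_softmax(logits, axis=1)[idx, idx].mean()  # audio -> text
    loss_t2a = -log_softmax(logits, axis=0)[idx, idx].mean()  # text -> audio
    return 0.5 * (loss_a2t + loss_t2a)
```

Matching pairs sit on the diagonal of the logits matrix. Every other entry in the same row or column acts as a negative, so the loss falls as audio embeddings move toward their own captions and away from the others.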

2. Sound-Guided Image Manipulation.

Direct Latent Code Optimization.

The code relies on the StyleCLIP PyTorch implementation.

python3 optimization/run_optimization.py --lambda_similarity 0.002 --lambda_identity 0.0 --truncation 0.7 --lr 0.1 --audio_path "./audiosample/explosion.wav" --ckpt ./pretrained_models/landscape.pt --stylegan_size 256
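Direct latent code optimization adjusts a latent w so that the embedding of the generated image moves toward the target audio embedding, while a penalty keeps w close to its starting point. The StyleGAN3 snippet under Issues uses the same pattern: 0.1 * distance + mean((w - w_init)^2).

Below is a toy NumPy sketch of that idea. The assumptions are:

- a fixed random matrix A stands in for "CLIP image encoder applied to G(w)";
- plain gradient descent replaces the optimizer used by run_optimization.py;
- lam is a hypothetical regularization weight. The defaults lam=0.002 and lr=0.1 simply reuse numbers from the command above for illustration. How lam relates to --lambda_similarity or --lambda_identity is not documented here.

```python
import numpy as np

def cosine_distance_and_grad(w, A, target):
    """Return 1 - cos(A @ w, target) and its gradient w.r.t. w.

    A is a toy linear stand-in for the generator + CLIP image encoder;
    target is assumed to be a unit-norm (e.g. audio) embedding.
    """
    e = A @ w
    n = np.linalg.norm(e)
    cos = e @ target / n
    # d cos / d e = target / |e| - cos * e / |e|^2
    d_cos_de = target / n - cos * e / n**2
    return 1.0 - cos, -(A.T @ d_cos_de)

def optimize_latent(A, target, w_init, lam=0.002, lr=0.1, steps=200):
    """Gradient descent on cosine distance + lam * ||w - w_init||^2 (toy sketch)."""
    w = w_init.copy()
    history = []
    for _ in range(steps):
        dist, grad = cosine_distance_and_grad(w, A, target)
        diff = w - w_init
        history.append(dist + lam * np.sum(diff**2))  # loss before this step
        w = w - lr * (grad + 2.0 * lam * diff)
    return w, history
```

A larger lam keeps the edit closer to the original image. A smaller lam lets the sound embedding dominate.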

⛳ Results

Zero-shot Audio Classification Accuracy.

Model                           | Supervised Setting | Zero-Shot | ESC-50 | UrbanSound 8K
ResNet50                        | ✅                 | -         | 66.8%  | 71.3%
Ours (Without Self-Supervised)  | -                  | -         | 58.7%  | 63.3%
✨ Ours (Logistic Regression)   | -                  | -         | 72.2%  | 66.8%
Wav2clip                        | -                  | ✅        | 41.4%  | 40.4%
AudioCLIP                       | -                  | ✅        | 69.4%  | 68.8%
Ours (Without Self-Supervised)  | -                  | ✅        | 49.4%  | 45.6%
✨ Ours                         | -                  | ✅        | 57.8%  | 45.7%

Manipulation Results.

LSUN. LSUN image

FFHQ. FFHQ image

To see more diverse examples, please visit our project page!

Citation

@article{lee2021sound,
    title={Sound-Guided Semantic Image Manipulation},
    author={Lee, Seung Hyun and Roh, Wonseok and Byeon, Wonmin and Yoon, Sang Ho and Kim, Chan Young and Kim, Jinkyu and Kim, Sangpil},
    journal={arXiv preprint arXiv:2112.00007},
    year={2021}
}
Issues
  • add Gradio Demo for cvpr 2022 call for demos

    Hi, would you be interested in adding sound-guided-semantic-image-manipulation to Hugging Face as a Gradio Web Demo for CVPR 2022 call for Demos? The Hub offers free hosting, and it would make your work more accessible and visible to the rest of the ML community. Models/datasets/spaces(web demos) can be added to a user account or organization similar to github.

    more info on CVPR call for demos: https://huggingface.co/CVPR

    and here are guides for adding web demo to the org

    How to add a Space: https://huggingface.co/blog/gradio-spaces

    Please let us know if you would be interested and if you have any questions, we can also help with the technical implementation.

    opened by AK391
  • About StyleGAN3

    The main code is borrowed from https://github.com/ouhenio/StyleGAN3-CLIP-notebooks.

    StyleGAN3 + our CLIP-based sound representation:

    import sys
    
    import io
    import os, time, glob
    import pickle
    import shutil
    import numpy as np
    from PIL import Image
    import torch
    import torch.nn.functional as F
    import requests
    import torchvision.transforms as transforms
    import torchvision.transforms.functional as TF
    import clip
    import unicodedata
    import re
    from tqdm import tqdm
    from torchvision.transforms import Compose, Resize, ToTensor, Normalize
    from einops import rearrange
    from collections import OrderedDict
    
    import timm
    import librosa
    import cv2
    
    def make_transform(translate, angle):
        m = np.eye(3)
        s = np.sin(angle/360.0*np.pi*2)
        c = np.cos(angle/360.0*np.pi*2)
        m[0][0] = c
        m[0][1] = s
        m[0][2] = translate[0]
        m[1][0] = -s
        m[1][1] = c
        m[1][2] = translate[1]
        return m
        
    class AudioEncoder(torch.nn.Module):
        def __init__(self):
            super(AudioEncoder, self).__init__()
            self.conv = torch.nn.Conv2d(1, 3, (3, 3))
            self.feature_extractor = timm.create_model("resnet18", num_classes=512, pretrained=True)
    
        def forward(self, x):
            x = self.conv(x)
            x = self.feature_extractor(x)
            return x
    
    def copyStateDict(state_dict):
        if list(state_dict.keys())[0].startswith("module"):
            start_idx = 1
        else:
            start_idx = 0
        new_state_dict = OrderedDict()
        for k, v in state_dict.items():
            name = ".".join(k.split(".")[start_idx:])
            new_state_dict[name] = v
        return new_state_dict
    
    class CLIP(object):
      def __init__(self):
        clip_model = "ViT-B/32"
        self.model, _ = clip.load(clip_model)
        self.model = self.model.requires_grad_(False)
        self.normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],
                                              std=[0.26862954, 0.26130258, 0.27577711])
    
      @torch.no_grad()
      def embed_text(self, prompt):
          "Normalized clip text embedding."
          return norm1(self.model.encode_text(clip.tokenize(prompt).to(device)).float())
    
      def embed_cutout(self, image):
          "Normalized clip image embedding."
          # return norm1(self.model.encode_image(self.normalize(image)))
          return norm1(self.model.encode_image(image))
    
    tf = Compose([
      Resize(224),
      lambda x: torch.clamp((x+1)/2,min=0,max=1),
      ])
    
    def norm1(prompt):
        "Normalize to the unit sphere."
        return prompt / prompt.square().sum(dim=-1,keepdim=True).sqrt()
    
    def spherical_dist_loss(x, y):
        x = F.normalize(x, dim=-1)
        y = F.normalize(y, dim=-1)
        return (x - y).norm(dim=-1).div(2).arcsin().pow(2).mul(2)
    
    def prompts_dist_loss(x, targets, loss):
        if len(targets) == 1: # Keeps results consistent with the previous method for single-objective guidance
          return loss(x, targets[0])
        distances = [loss(x, target) for target in targets]
        return torch.stack(distances, dim=-1).sum(dim=-1)  
    
    class MakeCutouts(torch.nn.Module):
        def __init__(self, cut_size, cutn, cut_pow=1.):
            super().__init__()
            self.cut_size = cut_size
            self.cutn = cutn
            self.cut_pow = cut_pow
    
        def forward(self, input):
            sideY, sideX = input.shape[2:4]
            max_size = min(sideX, sideY)
            min_size = min(sideX, sideY, self.cut_size)
            cutouts = []
            for _ in range(self.cutn):
                size = int(torch.rand([])**self.cut_pow * (max_size - min_size) + min_size)
                offsetx = torch.randint(0, sideX - size + 1, ())
                offsety = torch.randint(0, sideY - size + 1, ())
                cutout = input[:, :, offsety:offsety + size, offsetx:offsetx + size]
                cutouts.append(F.adaptive_avg_pool2d(cutout, self.cut_size))
            return torch.cat(cutouts)
    
    make_cutouts = MakeCutouts(224, 32, 0.5)
    
    def embed_image(image):
      n = image.shape[0]
      cutouts = make_cutouts(image)
      embeds = clip_model.embed_cutout(cutouts)
      embeds = rearrange(embeds, '(cc n) c -> cc n c', n=n)
      return embeds
    
    def run(timestring):
      torch.manual_seed(seed)
    
      # Init
      # Sample 32 inits and choose the one closest to prompt
    
      with torch.no_grad():
        qs = []
        losses = []
        for _ in range(8):
          q = (G.mapping(torch.randn([4,G.mapping.z_dim], device=device), None, truncation_psi=0.7) - G.mapping.w_avg) / w_stds
          images = G.synthesis(q * w_stds + G.mapping.w_avg)
          embeds = embed_image(images.add(1).div(2))
          loss = prompts_dist_loss(embeds, targets, spherical_dist_loss).mean(0)
          i = torch.argmin(loss)
          qs.append(q[i])
          losses.append(loss[i])
        qs = torch.stack(qs)
        losses = torch.stack(losses)
        i = torch.argmin(losses)
        q = qs[i].unsqueeze(0).requires_grad_()
    
      w_init = (q * w_stds + G.mapping.w_avg).detach().clone()
      # Sampling loop
      q_ema = q
      opt = torch.optim.AdamW([q], lr=0.03, betas=(0.0,0.999))
      loop = tqdm(range(steps))
      for i in loop:
        opt.zero_grad()
        w = q * w_stds + G.mapping.w_avg
        image = G.synthesis(w , noise_mode='const')
        embed = embed_image(image.add(1).div(2))
        loss = 0.1 *  prompts_dist_loss(embed, targets, spherical_dist_loss).mean() + ((w - w_init) ** 2).mean()
        # loss = prompts_dist_loss(embed, targets, spherical_dist_loss).mean()
        loss.backward()
        opt.step()
        loop.set_postfix(loss=loss.item(), q_magnitude=q.std().item())
    
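        # Update the EMA of q and render it without building an autograd graph;
        # this image is only saved, never backpropagated through.
        with torch.no_grad():
          q_ema = q_ema * 0.9 + q * 0.1
    
          final_code = q_ema * w_stds + G.mapping.w_avg
          # Copy w layers 6 onward from the initial latent so only layers 0-5 are edited.
          final_code[:,6:,:] = w_init[:,6:,:]
          image = G.synthesis(final_code, noise_mode='const')
    
        if i % 10 == 9 or i % 10 == 0:
          # display(TF.to_pil_image(tf(image)[0]))
          print(f"Image {i}/{steps} | Current loss: {loss.item():.4f}")
          pil_image = TF.to_pil_image(image[0].add(1).div(2).clamp(0,1).cpu())
          os.makedirs(f'samples/{timestring}', exist_ok=True)
          pil_image.save(f'samples/{timestring}/{i:04}.jpg')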
    
    
    device = torch.device('cuda:0')
    print('Using device:', device, file=sys.stderr)
    
    model_url = "./pretrained_models/stylegan3-r-afhqv2-512x512.pkl"
    
    with open(model_url, 'rb') as fp:
      G = pickle.load(fp)['G_ema'].to(device)
    
    zs = torch.randn([100000, G.mapping.z_dim], device=device)
    w_stds = G.mapping(zs, None).std(0)
    
    m = make_transform([0,0], 0)
    m = np.linalg.inv(m)
    G.synthesis.input.transform.copy_(torch.from_numpy(m))
    # audio_paths = "./audio/sweet-kitty-meow.wav"
    #audio_paths = "./audio/dog-sad.wav"
    audio_paths = "./audio/cartoon-voice-laugh.wav"
    steps = 200
    seed = 14 + 22
    #seed = 22
    
    audio_paths = [frase.strip() for frase in audio_paths.split("|") if frase]
    
    clip_model = CLIP()
    audio_encoder = AudioEncoder()
    audio_encoder.load_state_dict(copyStateDict(torch.load("./pretrained_models/resnet18.pth", map_location=device)))
    audio_encoder = audio_encoder.to(device)
    audio_encoder.eval()
    
    targets = []
    n_mels = 128
    time_length = 864
    resize_resolution = 512
    
    for audio_path in audio_paths:
        y, sr = librosa.load(audio_path, sr=44100)
        audio_inputs = librosa.feature.melspectrogram(y=y, sr=sr, n_mels=n_mels)
        audio_inputs = librosa.power_to_db(audio_inputs, ref=np.max) / 80.0 + 1
    
        zero = np.zeros((n_mels, time_length))
        h, w = audio_inputs.shape
        if w >= time_length:
            j = (w - time_length) // 2
            audio_inputs = audio_inputs[:,j:j+time_length]
        else:
            j = (time_length - w) // 2
            zero[:,:w] = audio_inputs[:,:w]
            audio_inputs = zero
        
        audio_inputs = cv2.resize(audio_inputs, (n_mels, resize_resolution))
        audio_inputs = np.array([audio_inputs])
        audio_inputs = torch.from_numpy(audio_inputs.reshape((1, 1, n_mels, resize_resolution))).float().to(device)
        with torch.no_grad():
            audio_embedding = audio_encoder(audio_inputs)
            audio_embedding = audio_embedding / audio_embedding.norm(dim=-1, keepdim=True)
        targets.append(audio_embedding)
    
    timestring = time.strftime('%Y%m%d%H%M%S')
    run(timestring)
    
    opened by lsh3163
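The spherical_dist_loss used in the snippet above has a simple geometric reading. For unit vectors x and y at angle θ, ||x − y|| = 2·sin(θ/2), so 2·arcsin(||x − y||/2)² = 2·(θ/2)² = θ²/2, which is half the squared angle between the embeddings.

The NumPy port below checks that identity numerically. Both function bodies are illustrative re-implementations, not code from this repository.

```python
import numpy as np

def spherical_dist_loss(x, y):
    """NumPy port of spherical_dist_loss from the StyleGAN3 snippet."""
    x = x / np.linalg.norm(x, axis=-1, keepdims=True)
    y = y / np.linalg.norm(y, axis=-1, keepdims=True)
    return 2.0 * np.arcsin(np.linalg.norm(x - y, axis=-1) / 2.0) ** 2

def half_squared_angle(x, y):
    """theta^2 / 2, where theta is the angle between x and y."""
    cos = np.sum(x * y, axis=-1) / (np.linalg.norm(x, axis=-1) * np.linalg.norm(y, axis=-1))
    theta = np.arccos(np.clip(cos, -1.0, 1.0))
    return theta ** 2 / 2.0
```

prompts_dist_loss in the snippet sums this distance over every entry of targets. Appending, for example, a CLIP text embedding to targets next to the audio embeddings is therefore one way to steer the same optimization with both modalities.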
Owner: CVLAB, Department of Artificial Intelligence, Korea University