Robust fine-tuning of zero-shot models


This repository contains code for the paper Robust fine-tuning of zero-shot models by Mitchell Wortsman*, Gabriel Ilharco*, Jong Wook Kim, Mike Li, Simon Kornblith, Rebecca Roelofs, Raphael Gontijo-Lopes, Hannaneh Hajishirzi, Ali Farhadi, Hongseok Namkoong, Ludwig Schmidt.

Abstract

Large pre-trained models such as CLIP offer consistent accuracy across a range of data distributions when performing zero-shot inference (i.e., without fine-tuning on a specific dataset). Although existing fine-tuning approaches substantially improve accuracy in-distribution, they also reduce out-of-distribution robustness. We address this tension by introducing a simple and effective method for improving robustness: ensembling the weights of the zero-shot and fine-tuned models. Compared to standard fine-tuning, the resulting weight-space ensembles provide large accuracy improvements out-of-distribution, while matching or improving in-distribution accuracy. On ImageNet and five derived distribution shifts, weight-space ensembles improve out-of-distribution accuracy by 2 to 10 percentage points while increasing in-distribution accuracy by nearly 1 percentage point relative to standard fine-tuning. These improvements come at no additional computational cost during fine-tuning or inference.

Summary figure


Compared to standard fine-tuning, weight-space ensembles for fine-tuning (WiSE-FT) improve out-of-distribution (OOD) accuracy without decreasing in-distribution (ID) performance. Top left: Zero-shot CLIP models exhibit high effective robustness and moderate in-distribution accuracy, while standard fine-tuning (end-to-end or with a linear classifier) attains higher ID accuracy and less effective robustness. Top right: Our method linearly interpolates between the zero-shot and fine-tuned models with a mixing coefficient alpha in [0,1]. Bottom: On five distribution shifts derived from ImageNet (ImageNetV2, ImageNet-R, ImageNet Sketch, ObjectNet, and ImageNet-A), WiSE-FT improves average OOD accuracy by 8.7 percentage points (pp) when fine-tuning end-to-end (+2.1 pp when fine-tuning a linear classifier) while maintaining ID accuracy.

Code

Overview

WiSE-FT can be implemented in a few lines of code in addition to standard fine-tuning, as shown below. See src/wise_ft.py for more details.

# Load models
zeroshot = ImageClassifier.load(zeroshot_checkpoint)
finetuned = ImageClassifier.load(finetuned_checkpoint)
theta_0 = zeroshot.state_dict()
theta_1 = finetuned.state_dict()

# make sure checkpoints are compatible
assert set(theta_0.keys()) == set(theta_1.keys())

# interpolate between checkpoints with mixing coefficient alpha
theta = {
    key: (1-alpha) * theta_0[key] + alpha * theta_1[key]
    for key in theta_0.keys()
}

# update the model according to the new weights
finetuned.load_state_dict(theta)

# evaluate
evaluate(finetuned, args)
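
To sweep over several mixing coefficients, the same snippet can be wrapped in a loop. Below is a minimal sketch, assuming the ImageClassifier and evaluate helpers used above and an arbitrary list of candidate alpha values:

# Sweep over mixing coefficients and evaluate each interpolated model
for alpha in [0.0, 0.25, 0.5, 0.75, 1.0]:
    theta = {
        key: (1 - alpha) * theta_0[key] + alpha * theta_1[key]
        for key in theta_0.keys()
    }
    finetuned.load_state_dict(theta)
    evaluate(finetuned, args)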

Install dependencies

conda env create
conda activate wiseft

Add directory to PYTHONPATH:

cd wise-ft
export PYTHONPATH="$PYTHONPATH:$PWD"

Download data

When necessary, please refer to datasets.md for instructions on how to download datasets.

Run WiSE-FT

Sample command when zero-shot and fine-tuned model checkpoints are already available:

python src/wise_ft.py   \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --load=models/zeroshot.pt,models/finetuned.pt  \
    --results-db=results.jsonl  \
    --save=models/wiseft  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
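
Each evaluation appends records to the file passed via --results-db. Below is a minimal sketch for loading and inspecting those records, assuming only that the file is in JSON Lines format (the exact keys depend on the script):

import json

# Read the results database written by src/wise_ft.py (one JSON object per line)
with open('results.jsonl') as f:
    results = [json.loads(line) for line in f]

# Inspect each record to see which fields (mixing coefficient, per-dataset accuracies) are stored
for record in results:
    print(record)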

Sample command for running WiSE-FT from scratch using ViT-B/32:

python src/wise_ft.py   \
    --train-dataset=ImageNet  \
    --epochs=10  \
    --lr=0.00003  \
    --batch-size=512  \
    --cache-dir=cache  \
    --model=ViT-B/32  \
    --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --template=openai_imagenet_template  \
    --results-db=results.jsonl  \
    --save=models/wiseft/ViTB32  \
    --data-location=~/data \
    --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0

Note: the flag --freeze-encoder controls whether only a linear classifier is fine-tuned, or if all weights are fine-tuned (end-to-end).
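
Conceptually, freezing the encoder amounts to disabling gradients for everything except the classification head. The sketch below only illustrates the idea; it is not the repository's exact implementation, and it assumes the model exposes image_encoder and classification_head attributes:

# Illustrative sketch of what --freeze-encoder does conceptually (hypothetical helper, not the repo's code).
# Assumes the model exposes `image_encoder` and `classification_head` attributes.
def set_trainable_parts(model, freeze_encoder: bool):
    for param in model.image_encoder.parameters():
        param.requires_grad = not freeze_encoder   # encoder is frozen when training only the linear head
    for param in model.classification_head.parameters():
        param.requires_grad = True                 # the linear classifier is always trained
    return model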

Plotting results

Sample command for generating a scatter plot:

python src/scatter_plot.py  \
    --eval-datasets=ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
    --results-db=results.jsonl  \
    --save plots
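
For a custom plot built directly from the results database, a minimal matplotlib sketch is shown below. The key names used here (e.g., 'ImageNet:top1') are hypothetical placeholders; adapt them to whatever your results.jsonl actually contains:

import json
import matplotlib.pyplot as plt

# Hypothetical key names; inspect one line of results.jsonl and adjust as needed
ood_datasets = ['ImageNetV2', 'ImageNetR', 'ImageNetA', 'ImageNetSketch']

with open('results.jsonl') as f:
    rows = [json.loads(line) for line in f]

# ID accuracy vs. average OOD accuracy, one point per mixing coefficient
id_acc = [r['ImageNet:top1'] for r in rows]
ood_acc = [sum(r[d + ':top1'] for d in ood_datasets) / len(ood_datasets) for r in rows]

plt.scatter(id_acc, ood_acc)
plt.xlabel('ImageNet (ID) top-1 accuracy')
plt.ylabel('Average OOD top-1 accuracy')
plt.savefig('wiseft_scatter.png')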

We show samples of expected behavior below when running the commands above using ViT-B/16 (models can be downloaded here):

Scatter plots (ID vs. OOD accuracy) for ImageNet-Sketch, ImageNet-A, ImageNet-R, ImageNetV2, and ObjectNet.

Citing

If you find this repository useful, please consider citing:

@article{wortsman2021robust,
  title={Robust fine-tuning of zero-shot models},
  author={Wortsman, Mitchell and Ilharco, Gabriel and Kim, Jong Wook and Li, Mike and Kornblith, Simon and Roelofs, Rebecca and Gontijo-Lopes, Raphael and Hajishirzi, Hannaneh and Farhadi, Ali and Namkoong, Hongseok and Schmidt, Ludwig},
  journal={arXiv preprint arXiv:2109.01903},
  note={\url{https://arxiv.org/abs/2109.01903}},
  year={2021}
}
Comments
  • Model Interpolation of Models of Different Size (#layers, hidden_size, intermediate_size, attention_head)

    I am trying to interpolate two model checkpoints (ViT-B/16 transformers) with different numbers of layers, hidden sizes, attention heads, and intermediate sizes, using the interpolation approach shown in the paper. Could you provide or suggest minimal sample code or pseudocode for this? The method shown in the paper only works when both models share the same architecture.

    opened by sanyalsunny111 3
  • Replicating few-shot results

    In Table 7 of the paper, there are results showing that WiSE-FT with a linear classifier and the ViT-B/16 backbone reaches 73% accuracy on 16-shot ImageNet. It was mentioned that the learning rate was 10e-5 and the model was trained for 10 epochs, but even with this information I still cannot replicate the result reported in the paper. Could you provide an exact command, or additional hyperparameters (e.g., batch size, number of warmup steps), so that this result can be reproduced?

    opened by samuelyu2002 3
  • Added first pass of Fisher computation and merging

    Fisher computation for the fine-tuned model works, as does merging. However, I run into errors when computing the Fisher for the zero-shot model.

    Here is the error. The input to the classification head has shape [1, 3, 224, 224] for the zero-shot model.

    Traceback (most recent call last):                                                                                                                              
      File "src/models/fisher.py", line 135, in <module>
        compute_fisher(args)
      File "src/models/fisher.py", line 105, in compute_fisher
        logits = utils.get_logits(inputs, model)
      File "/home/owner/Desktop/projects/wise-ft/src/models/utils.py", line 76, in get_logits
        return classifier(inputs)
      File "/home/owner/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/home/owner/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/nn/parallel/data_parallel.py", line 159, in forward
        return self.module(*inputs[0], **kwargs[0])
      File "/home/owner/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/home/owner/Desktop/projects/wise-ft/src/models/modeling.py", line 77, in forward
        outputs = self.classification_head(inputs)
      File "/home/owner/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/home/owner/Desktop/projects/wise-ft/src/models/modeling.py", line 52, in forward
        return super().forward(inputs)
      File "/home/owner/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/nn/modules/linear.py", line 93, in forward
        return F.linear(input, self.weight, self.bias)
      File "/home/owner/miniconda3/envs/wiseft/lib/python3.6/site-packages/torch/nn/functional.py", line 1692, in linear
        output = input.matmul(weight.t())
    RuntimeError: mat1 dim 1 must match mat2 dim 0
    
    
    opened by mmatena 3
  • Question about Table 2 in the paper

    How can we get the results shown in Table 2 of the paper? Do we need to design text prompts for each task and use them to initialize the classification head?

    I tried adding a classification head with randomly initialized weights, but got poor results with WiSE-FT.

    opened by vtddggg 2
  • zero-shot model

    Hi,

    I would like to apply the WiSE-FT method to other tasks or pretrained models (e.g., BERT, GPT). In this context, the so-called zero-shot model is simply the original model without fine-tuning, right? And the zero-shot model parameters are just the directly loaded pretrained parameters?

    Thank you!

    opened by LindgeW 2
  • Does fine-tuning only tweak the image encoder?

    First of all, thanks for sharing the codebase. I briefly went through the code, and it seems like only the image encoder is fine-tuned; is that right? If so, I'm curious whether you have tried tweaking both the image and text encoders.

    opened by yushuinanrong 2
  • Finetuning configs for more models

    Hi, dear authors. In this repository you have provided an example for fine-tuning ViT-B/32:

    python src/wise_ft.py   \
        --train-dataset=ImageNet  \
        --epochs=10  \
        --lr=0.00003  \
        --batch-size=512  \
        --cache-dir=cache  \
        --model=ViT-B/32  \
        --eval-datasets=ImageNet,ImageNetV2,ImageNetR,ImageNetA,ImageNetSketch  \
        --template=openai_imagenet_template  \
        --results-db=results.jsonl  \
        --save=models/wiseft/ViTB32  \
        --data-location=~/data \
        --alpha 0 0.1 0.2 0.3 0.4 0.5 0.6 0.7 0.8 0.9 1.0
    

    By running it, I get the following WiSE-FT results at alpha=0.5:

    ImageNet Top-1 accuracy: 0.7554
    ImageNetR Top-1 accuracy: 0.7145
    ImageNetA Top-1 accuracy: 0.3452
    ImageNetSketch Top-1 accuracy: 0.4696
    
    • Does this result match yours? Since I cannot find official results for ViT-B/32 in the paper, I just want to ensure that I ran the code correctly.
    • What hyperparameter configurations do you recommend for other models, such as ViT-L, ViT-B, etc.?
    opened by vtddggg 1
  • Poor performance on ResNet.

    Although fine-tuning the ViT models gives good performance, I found that performance is poor with the ResNet models. How should the CLIP model be fine-tuned when using the pre-trained ResNet backbones? Thanks.

    opened by jingzhengli 3