Domain Generalization with MixStyle, ICLR'21.

This repo contains the code of our ICLR'21 paper, "Domain Generalization with MixStyle".

The OpenReview link is https://openreview.net/forum?id=6xHJ37MVxxp.

########## Updates ############

12-04-2021: A variable self._activated is added to MixStyle to better control the computational flow. To deactivate MixStyle without modifying the model code, one can do

def deactivate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(False)

model.apply(deactivate_mixstyle)

Similarly, to activate MixStyle, one can do

def activate_mixstyle(m):
    if type(m) == MixStyle:
        m.set_activation_status(True)

model.apply(activate_mixstyle)
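
For example, here is a minimal usage sketch, assuming hypothetical model, images, labels and criterion objects, that temporarily disables MixStyle to obtain a deterministic forward pass while the model stays in train mode, and then re-enables it:

import torch

model.apply(deactivate_mixstyle)
with torch.no_grad():
    clean_outputs = model(images)  # MixStyle layers return their input unchanged

model.apply(activate_mixstyle)
loss = criterion(model(images), labels)  # MixStyle is applied again from here on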

Note that MixStyle has been included in Dassl.pytorch. See the code for details.

05-03-2021: You might also be interested in our recently released survey on domain generalization at https://arxiv.org/abs/2103.02503, which summarizes the ten-year development in domain generalization, with coverage on the history, datasets, related problems, methodologies, potential directions, and so on.

##############################

A brief introduction: The key idea of MixStyle is to probabilistically mix instance-level feature statistics of training samples across source domains. MixStyle improves model robustness to domain shift by implicitly synthesizing new domains at the feature level, which regularizes the training of convolutional neural networks. The idea is largely inspired by neural style transfer, which has shown that feature statistics closely capture image style, so arbitrary style transfer can be achieved by swapping the feature statistics between a content image and a style image.

MixStyle is very easy to implement. Below we show the PyTorch code of MixStyle.

import random
import torch
import torch.nn as nn


class MixStyle(nn.Module):
    """MixStyle.

    Reference:
      Zhou et al. Domain Generalization with MixStyle. ICLR 2021.
    """

    def __init__(self, p=0.5, alpha=0.1, eps=1e-6):
        """
        Args:
          p (float): probability of using MixStyle.
          alpha (float): parameter of the Beta distribution.
          eps (float): scaling parameter to avoid numerical issues.
        """
        super().__init__()
        self.p = p
        self.beta = torch.distributions.Beta(alpha, alpha)
        self.eps = eps
        self.alpha = alpha

        self._activated = True

    def __repr__(self):
        return f'MixStyle(p={self.p}, alpha={self.alpha}, eps={self.eps})'

    def set_activation_status(self, status=True):
        self._activated = status

    def forward(self, x):
        if not self.training or not self._activated:
            return x

        if random.random() > self.p:
            return x

        B = x.size(0)

        # Compute instance-level (per-sample, per-channel) mean and std
        mu = x.mean(dim=[2, 3], keepdim=True)
        var = x.var(dim=[2, 3], keepdim=True)
        sig = (var + self.eps).sqrt()
        mu, sig = mu.detach(), sig.detach()
        x_normed = (x-mu) / sig

        # Sample mixing coefficients from Beta(alpha, alpha)
        lmda = self.beta.sample((B, 1, 1, 1))
        lmda = lmda.to(x.device)

        # Mix each sample's statistics with those of a randomly permuted sample
        perm = torch.randperm(B)
        mu2, sig2 = mu[perm], sig[perm]
        mu_mix = mu*lmda + mu2 * (1-lmda)
        sig_mix = sig*lmda + sig2 * (1-lmda)

        # Re-style the normalized features with the mixed statistics
        return x_normed*sig_mix + mu_mix
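
As a quick sanity check (a minimal sketch with arbitrary feature dimensions), MixStyle can be applied to any 4D feature map of shape (batch, channels, height, width) and preserves that shape:

mixstyle = MixStyle(p=0.5, alpha=0.1)
mixstyle.train()  # MixStyle is a no-op in eval mode

feat = torch.randn(8, 64, 32, 32)  # e.g. features from an early conv block
out = mixstyle(feat)
print(out.shape)  # torch.Size([8, 64, 32, 32])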

How do you apply MixStyle to your CNN models? Say you are using a ResNet architecture and want to apply MixStyle after the 1st and 2nd residual blocks. You first instantiate the MixStyle module with

self.mixstyle = MixStyle(p=0.5, alpha=0.1)

during network construction (in __init__()), and then apply MixStyle in the forward pass like

def forward(self, x):
    x = self.conv1(x) # 1st convolution layer
    x = self.res1(x) # 1st residual block
    x = self.mixstyle(x)
    x = self.res2(x) # 2nd residual block
    x = self.mixstyle(x)
    x = self.res3(x) # 3rd residual block
    x = self.res4(x) # 4th residual block
    ...
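
Putting the pieces together, below is a minimal, self-contained sketch of a ResNet-style backbone with MixStyle inserted after the first two residual stages. The stem conv1, the stage modules res1 to res4 and the classifier head are placeholders supplied by the caller, not the actual model code in this repo:

import torch.nn as nn


class ResNetWithMixStyle(nn.Module):
    """Hypothetical wrapper; only the placement of MixStyle follows the paper."""

    def __init__(self, conv1, res1, res2, res3, res4, head, p=0.5, alpha=0.1):
        super().__init__()
        self.conv1 = conv1
        self.res1, self.res2 = res1, res2
        self.res3, self.res4 = res3, res4
        self.head = head
        self.mixstyle = MixStyle(p=p, alpha=alpha)

    def forward(self, x):
        x = self.conv1(x)
        x = self.mixstyle(self.res1(x))  # after the 1st residual block
        x = self.mixstyle(self.res2(x))  # after the 2nd residual block
        x = self.res3(x)                 # no MixStyle near the prediction layer
        x = self.res4(x)
        return self.head(x)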

In our paper, we have demonstrated the effectiveness of MixStyle on three tasks: image classification, person re-identification, and reinforcement learning. The source code for reproducing all experiments can be found in mixstyle-release/imcls, mixstyle-release/reid, and mixstyle-release/rl, respectively.

Takeaways on applying MixStyle to your tasks:

  • Applying MixStyle to multiple lower layers is generally better
  • Do not apply MixStyle to the layer closest to the prediction layer
  • Different tasks might favor different combinations of insertion points; a configurable sketch is shown after this list
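
One way to experiment with different combinations (a sketch that assumes the backbone is organized as a list of stages, not the code used in this repo) is to make the insertion points a constructor argument, e.g. a hypothetical ms_layers parameter:

import torch.nn as nn


class ConfigurableMixStyleNet(nn.Module):
    """Hypothetical backbone in which the MixStyle placement is configurable."""

    def __init__(self, blocks, head, ms_layers=(1, 2), p=0.5, alpha=0.1):
        super().__init__()
        self.blocks = nn.ModuleList(blocks)  # e.g. [stem, res1, res2, res3, res4]
        self.head = head
        self.ms_layers = set(ms_layers)      # indices of blocks followed by MixStyle
        self.mixstyle = MixStyle(p=p, alpha=alpha)

    def forward(self, x):
        for i, block in enumerate(self.blocks):
            x = block(x)
            if i in self.ms_layers:
                x = self.mixstyle(x)
        return self.head(x)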

For more analytical studies, please read our paper at https://openreview.net/forum?id=6xHJ37MVxxp.

To cite MixStyle in your publications, please use the following bibtex entry

@inproceedings{zhou2021mixstyle,
  title={Domain Generalization with MixStyle},
  author={Zhou, Kaiyang and Yang, Yongxin and Qiao, Yu and Xiang, Tao},
  booktitle={ICLR},
  year={2021}
}
Comments
  • I can't run this code

    I'm following the README, and I get these errors:

    FileNotFoundError: [Errno 2] No such file or directory: '/pub/data/cuiby/yjworkspace/Dassl.pytorch/data/office_home_dg/art/train'
    [Errno 2] No such file or directory: '/pub/data/cuiby/yjworkspace/Dassl.pytorch/data/office_home_dg/clipart/train'

    I can't find the instructions for that anywhere

    opened by PPJ001 9
  • Evaluate on PACS

    Hi @KaiyangZhou , Thanks for sharing the code. I have the following questions:

    1. How many times did you repeat the experiments on PACS?
    2. Would you mind sharing the standard deviation of your PACS performances?
    3. How did you select the model to report your results?

    Thanks!

    opened by GA-17a 3
  • Possible small bug in ssdg1.sh/ssdg2.sh

    Excuse me, when I run this:

    CUDA_VISIBLE_DEVICES=0 bash ssdg2.sh ssdg_pacs resnet18_ms_l123

    I get this: ssdg2.sh: line 62: syntax error: unexpected end of file

    At first, I thought it might be a file-format issue, i.e. that the shell script was written under Windows but run under Linux.

    However, I found that there is a "\" on line 59 which isn't in dg.sh. After deleting the "\", the code works.

    I'm not familiar with Shell, so I don't know if it's a mistake.

    Looking forward to your reply; I will close this soon.

    opened by judgingalready 1
  • Issue on the PACS dataset

    Hi, I found that the link for PACS (http://www.eecs.qmul.ac.uk/~dl307/project_iccv2017) is unavailable now. Could you share a link to download PACS? Thanks.

    opened by ericxian1997 1
  • Need clarification

    I was looking at the ResNet-with-MixStyle implementation and found https://github.com/KaiyangZhou/Dassl.pytorch/blob/master/dassl/modeling/backbone/resnet.py as well as https://github.com/KaiyangZhou/mixstyle-release/blob/master/reid/models/resnet_ms.py and https://github.com/KaiyangZhou/mixstyle-release/blob/master/reid/models/resnet_ms2.py

    I was not able to tell these ResNet implementations apart; all of them have MixStyle layers. Could you briefly explain the differences between them? Thank you for the clarification.

    opened by noaman1989 0
  • Accuracy of Sketch

    ① On the PACS dataset, I cannot reproduce the result reported in the paper when the target domain is Sketch. Is the number of epochs set incorrectly?

    ② As noted in the paper, when the target domain is Sketch the method performs much worse than the other methods, yet for the other target domains it performs better. What causes this? It seems interesting.

    opened by PPJ001 1
  • Great work! But I met some issues

    Great work! It achieves amazing results on image classification generalization. But when I run the re-ID code, I get this error:

    Traceback (most recent call last):
      File "main.py", line 214, in <module>
        main()
      File "main.py", line 139, in main
        cfg.merge_from_list(args.opts)
      File "/opt/conda/lib/python3.8/site-packages/yacs/config.py", line 223, in merge_from_list
        _assert_with_logging(
      File "/opt/conda/lib/python3.8/site-packages/yacs/config.py", line 545, in _assert_with_logging
        assert cond, msg
    AssertionError: Override list has odd length: ['osnet_x1_0_ms23_a0d1', 'data.save_dir', 'output/osnet_x1_0_ms23_a0d1/market2duke']; it must be a list of pairs

    The same traceback appears again with 'output/osnet_x1_0_ms23_a0d1/duke2market' in the override list.

    Another problem: I can't find instructions on how to organize the data or where to download it. Should I just put Market1501 and Duke in the same folder as "data_dir"?

    opened by a791702141 2
  • Object detection

    Thanks to the author for sharing these research results. If I want to apply MixStyle to an object detection task, do I only need to insert it after the shallow feature layers?

    opened by 1320414730 0