[arxiv]
CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation
This is the official repository for CDTrans: Cross-domain Transformer for Unsupervised Domain Adaptation.
Introduction
Unsupervised domain adaptation (UDA) aims to transfer knowledge learned from a labeled source domain to a different, unlabeled target domain. Most existing UDA methods focus on learning domain-invariant feature representations, at either the domain level or the category level, using convolutional neural network (CNN)-based frameworks. With the success of Transformers in various tasks, we find that the cross-attention in a Transformer is robust to noisy input pairs, which yields better feature alignment; in this paper we therefore adopt the Transformer for the challenging UDA task. Specifically, to generate accurate input pairs, we design a two-way center-aware labeling algorithm to produce pseudo labels for target samples. Along with the pseudo labels, we propose a weight-sharing triple-branch Transformer framework that applies self-attention for source/target feature learning and cross-attention for source-target domain alignment. This design explicitly encourages the framework to learn discriminative domain-specific and domain-invariant representations simultaneously. The proposed method is dubbed CDTrans (cross-domain Transformer), and it is one of the first attempts to solve UDA tasks with a pure Transformer solution. Extensive experiments show that our method achieves the best performance on the public UDA benchmarks Office-Home, Office-31, VisDA-2017, and DomainNet.
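To make the difference between the self-attention and cross-attention branches concrete, here is a minimal pure-Python sketch of scaled dot-product attention on toy token features. It is an illustration only, not the repository's implementation: the helper names (`softmax`, `attention`), the toy vectors, the omission of learned query/key/value projections, and the choice of source tokens as queries with target tokens as keys and values are all assumptions made for this example.

```python
import math

def softmax(row):
    """Numerically stable softmax over a list of floats."""
    m = max(row)
    exps = [math.exp(x - m) for x in row]
    total = sum(exps)
    return [e / total for e in exps]

def attention(queries, keys, values):
    """Scaled dot-product attention over lists of equal-length vectors."""
    scale = math.sqrt(len(keys[0]))
    out = []
    for q in queries:
        # Similarity of this query to every key, scaled by sqrt(dim).
        scores = [sum(a * b for a, b in zip(q, k)) / scale for k in keys]
        weights = softmax(scores)
        # Weighted average of the value vectors.
        out.append([sum(w * v[j] for w, v in zip(weights, values))
                    for j in range(len(values[0]))])
    return out

# Toy token features (2 tokens, dimension 2) for one source/target image pair.
source_tokens = [[1.0, 0.0], [0.0, 1.0]]
target_tokens = [[0.9, 0.1], [0.2, 0.8]]

# Self-attention: queries, keys, and values come from the same domain.
source_self = attention(source_tokens, source_tokens, source_tokens)
target_self = attention(target_tokens, target_tokens, target_tokens)

# Cross-attention (assumed direction): source queries attend to target keys/values.
cross = attention(source_tokens, target_tokens, target_tokens)
```

Each row of `cross` is a convex combination of the target token vectors, so a source token is represented by the target content it most resembles. That is the intuition behind using cross-attention for source-target alignment.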
Results
Table 1 [UDA results on Office-31]
Methods | Avg. | A->D | A->W | D->A | D->W | W->A | W->D |
Baseline(DeiT-S) | 86.7 | 87.6 | 86.9 | 74.9 | 97.7 | 73.5 | 99.6 |
CDTrans(DeiT-S) | 90.4 | 94.6 | 93.5 | 78.4 | 98.2 | 78 | 99.6 |
Baseline(DeiT-B) | 88.8 | 90.8 | 90.4 | 76.8 | 98.2 | 76.4 | 100 |
CDTrans(DeiT-B) | 92.6 | 97 | 96.7 | 81.1 | 99 | 81.9 | 100 |
Table 2 [UDA results on Office-Home]
Methods | Avg. | Ar->Cl | Ar->Pr | Ar->Re | Cl->Ar | Cl->Pr | Cl->Re | Pr->Ar | Pr->Cl | Pr->Re | Re->Ar | Re->Cl | Re->Pr |
Baseline(DeiT-S) | 69.8 | 55.6 | 73 | 79.4 | 70.6 | 72.9 | 76.3 | 67.5 | 51 | 81 | 74.5 | 53.2 | 82.7 |
CDTrans(DeiT-S) | 74.7 | 60.6 | 79.5 | 82.4 | 75.6 | 81.0 | 82.3 | 72.5 | 56.7 | 84.4 | 77.0 | 59.1 | 85.5 |
Baseline(DeiT-B) | 74.8 | 61.8 | 79.5 | 84.3 | 75.4 | 78.8 | 81.2 | 72.8 | 55.7 | 84.4 | 78.3 | 59.3 | 86 |
CDTrans(DeiT-B) | 80.5 | 68.8 | 85 | 86.9 | 81.5 | 87.1 | 87.3 | 79.6 | 63.3 | 88.2 | 82 | 66 | 90.6 |
Table 3 [UDA results on VisDA-2017]
Methods | Per-class | plane | bcycl | bus | car | horse | knife | mcycl | person | plant | sktbrd | train | truck |
Baseline(DeiT-B) | 67.3 | 98.1 | 48.1 | 84.6 | 65.2 | 76.3 | 59.4 | 94.5 | 11.8 | 89.5 | 52.2 | 94.5 | 34.1 |
CDTrans(DeiT-B) | 88.4 | 97.7 | 86.39 | 86.87 | 83.33 | 97.76 | 97.16 | 95.93 | 84.08 | 97.93 | 83.47 | 94.59 | 55.3 |
Table 4 [UDA results on DomainNet]
Base-S | clp | info | pnt | qdr | rel | skt | Avg. | CDTrans-S | clp | info | pnt | qdr | rel | skt | Avg. |
clp | - | 21.2 | 44.2 | 15.3 | 59.9 | 46.0 | 37.3 | clp | - | 25.3 | 52.5 | 23.2 | 68.3 | 53.2 | 44.5 |
info | 36.8 | - | 39.4 | 5.4 | 52.1 | 32.6 | 33.3 | info | 47.6 | - | 48.3 | 9.9 | 62.8 | 41.1 | 41.9 |
pnt | 47.1 | 21.7 | - | 5.7 | 60.2 | 39.9 | 34.9 | pnt | 55.4 | 24.5 | - | 11.7 | 67.4 | 48.0 | 41.4 |
qdr | 25.0 | 3.3 | 10.4 | - | 18.8 | 14.0 | 14.3 | qdr | 36.6 | 5.3 | 19.3 | - | 33.8 | 22.7 | 23.5 |
rel | 54.8 | 23.9 | 52.6 | 7.4 | - | 40.1 | 35.8 | rel | 61.5 | 28.1 | 56.8 | 12.8 | - | 47.2 | 41.3 |
skt | 55.6 | 18.6 | 42.7 | 14.9 | 55.7 | - | 37.5 | skt | 64.3 | 26.1 | 53.2 | 23.9 | 66.2 | - | 46.7 |
Avg. | 43.9 | 17.7 | 37.9 | 9.7 | 49.3 | 34.5 | 32.2 | Avg. | 53.08 | 21.86 | 46.02 | 16.3 | 59.7 | 42.44 | 39.9 |
Base-B | clp | info | pnt | qdr | rel | skt | Avg. | CDTrans-B | clp | info | pnt | qdr | rel | skt | Avg. |
clp | - | 24.2 | 48.9 | 15.5 | 63.9 | 50.7 | 40.6 | clp | - | 29.4 | 57.2 | 26.0 | 72.6 | 58.1 | 48.7 |
info | 43.5 | - | 44.9 | 6.5 | 58.8 | 37.6 | 38.3 | info | 57.0 | - | 54.4 | 12.8 | 69.5 | 48.4 | 48.4 |
pnt | 52.8 | 23.3 | - | 6.6 | 64.6 | 44.5 | 38.4 | pnt | 62.9 | 27.4 | - | 15.8 | 72.1 | 53.9 | 46.4 |
qdr | 31.8 | 6.1 | 15.6 | - | 23.4 | 18.9 | 19.2 | qdr | 44.6 | 8.9 | 29.0 | - | 42.6 | 28.5 | 30.7 |
rel | 58.9 | 26.3 | 56.7 | 9.1 | - | 45.0 | 39.2 | rel | 66.2 | 31.0 | 61.5 | 16.2 | - | 52.9 | 45.6 |
skt | 60.0 | 21.1 | 48.4 | 16.6 | 61.7 | - | 41.6 | skt | 69.0 | 29.6 | 59.0 | 27.2 | 72.5 | - | 51.5 |
Avg. | 49.4 | 20.2 | 42.9 | 10.9 | 54.5 | 39.3 | 36.2 | Avg. | 59.9 | 25.3 | 52.2 | 19.6 | 65.9 | 48.4 | 45.2 |
Requirements
Installation
pip install -r requirements.txt
(Environment: Python 3.7, V100 GPU, CUDA 10.1, cudatoolkit 10.1.)
Prepare Datasets
Download the UDA datasets: Office-31, Office-Home, VisDA-2017, and DomainNet.
Then unzip them and arrange them under the data directory as follows. (Note that each dataset folder must contain the txt files that list the path and label of each image; these files are already provided in data/the_dataset of this project.)
data
├── OfficeHomeDataset
│ │── class_name
│ │ └── images
│ └── *.txt
├── domainnet
│ │── class_name
│ │ └── images
│ └── *.txt
├── office31
│ │── class_name
│ │ └── images
│ └── *.txt
├── visda
│ │── train
│ │ │── class_name
│ │ │ └── images
│ │ └── *.txt
│ └── validation
│ │── class_name
│ │ └── images
│ └── *.txt
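Before training, it can help to confirm that the images named in a list file are actually on disk. The sketch below is a hypothetical helper, not part of this repository. It assumes each non-empty line of a list file has the form `<image path> <integer label>`, separated by whitespace, and that paths are relative to a root directory you supply; check one of the provided txt files to confirm the format before relying on it.

```python
import os

def parse_list_file(list_path):
    """Return (image_path, label) pairs from a list file.

    Assumed line format: "<image path> <integer label>".
    """
    samples = []
    with open(list_path) as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            # Split on the last whitespace so paths containing spaces survive.
            path, label = line.rsplit(maxsplit=1)
            samples.append((path, int(label)))
    return samples

def missing_images(samples, root="."):
    """Return the listed image paths that do not exist under root."""
    return [path for path, _ in samples
            if not os.path.isfile(os.path.join(root, path))]
```

For example, `missing_images(parse_list_file("data/office31/amazon_list.txt"))` would return the listed paths that cannot be found relative to the current directory. The file name `amazon_list.txt` is a placeholder; an empty result means every listed image was located.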
Prepare DeiT-trained Models
For a fair comparison in terms of pre-training data, we initialize our ViT-based model with DeiT parameters. Download the ImageNet-pretrained transformer models (DeiT-Small, DeiT-Base) and move them to the ./data/pretrainModel directory.
Training
We use 1 GPU for pre-training and 2 GPUs for UDA, each with 16 GB of memory.
Scripts.
Command format
bash scripts/[pretrain/uda]/[office31/officehome/visda/domainnet]/run_*.sh [deit_base/deit_small]
For example
DeiT-Base scripts
# Office-31 Source: Amazon -> Target: Dslr, Webcam
bash scripts/pretrain/office31/run_office_amazon.sh deit_base
bash scripts/uda/office31/run_office_amazon.sh deit_base
# Office-Home Source: Art -> Target: Clipart, Product, Real_World
bash scripts/pretrain/officehome/run_officehome_Ar.sh deit_base
bash scripts/uda/officehome/run_officehome_Ar.sh deit_base
# VisDA-2017 Source: train -> Target: validation
bash scripts/pretrain/visda/run_visda.sh deit_base
bash scripts/uda/visda/run_visda.sh deit_base
# DomainNet Source: Clipart -> Target: painting, quickdraw, real, sketch, infograph
bash scripts/pretrain/domainnet/run_domainnet_clp.sh deit_base
bash scripts/uda/domainnet/run_domainnet_clp.sh deit_base
DeiT-Small scripts
Replace deit_base with deit_small to reproduce the DeiT-Small results. An example of training on Office-31 is as follows:
# Office-31 Source: Amazon -> Target: Dslr, Webcam
bash scripts/pretrain/office31/run_office_amazon.sh deit_small
bash scripts/uda/office31/run_office_amazon.sh deit_small
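The command format above can also be assembled programmatically, for instance when queuing several runs. The following is a small sketch with hypothetical helper names (`build_command`, `two_stage_commands`); it only builds the command strings in the pretrain-then-UDA order used in the examples and does not execute anything. The script file name must match an existing `run_*.sh` file in the corresponding directory.

```python
STAGES = ("pretrain", "uda")
DATASETS = ("office31", "officehome", "visda", "domainnet")
MODELS = ("deit_base", "deit_small")

def build_command(stage, dataset, script, model):
    """Build one training command following the documented format."""
    if stage not in STAGES:
        raise ValueError(f"unknown stage: {stage}")
    if dataset not in DATASETS:
        raise ValueError(f"unknown dataset: {dataset}")
    if model not in MODELS:
        raise ValueError(f"unknown model: {model}")
    return f"bash scripts/{stage}/{dataset}/{script} {model}"

def two_stage_commands(dataset, script, model):
    """Return the pretrain command followed by the UDA command."""
    return [build_command(stage, dataset, script, model) for stage in STAGES]

for command in two_stage_commands("office31", "run_office_amazon.sh", "deit_small"):
    print(command)
```

The validation step catches typos such as `deit-small` before a long job is submitted.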
Evaluation
# For example VisDA-2017
python test.py --config_file 'configs/uda.yml' \
    MODEL.DEVICE_ID "('0')" \
    TEST.WEIGHT "('../logs/uda/vit_base/visda/transformer_best_model.pth')" \
    DATASETS.NAMES 'VisDA' DATASETS.NAMES2 'VisDA' \
    OUTPUT_DIR '../logs/uda/vit_base/visda/' \
    DATASETS.ROOT_TRAIN_DIR './data/visda/train/train_image_list.txt' \
    DATASETS.ROOT_TRAIN_DIR2 './data/visda/train/train_image_list.txt' \
    DATASETS.ROOT_TEST_DIR './data/visda/validation/valid_image_list.txt'
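The evaluation command consists of `--config_file` followed by alternating key/value pairs, which appear to override entries of the config file (this reading is an assumption based on the command's shape). The sketch below uses a hypothetical helper, `build_test_command`, to assemble the same argument list from a dictionary so the overrides are easier to edit for another dataset. It only constructs and prints the command.

```python
import shlex

def build_test_command(config_file, overrides):
    """Build the argv list: python test.py --config_file <file> KEY VALUE ..."""
    argv = ["python", "test.py", "--config_file", config_file]
    for key, value in overrides.items():
        argv += [key, str(value)]
    return argv

# Values copied from the VisDA-2017 example above.
overrides = {
    "MODEL.DEVICE_ID": "('0')",
    "TEST.WEIGHT": "('../logs/uda/vit_base/visda/transformer_best_model.pth')",
    "DATASETS.NAMES": "VisDA",
    "DATASETS.NAMES2": "VisDA",
    "OUTPUT_DIR": "../logs/uda/vit_base/visda/",
    "DATASETS.ROOT_TRAIN_DIR": "./data/visda/train/train_image_list.txt",
    "DATASETS.ROOT_TRAIN_DIR2": "./data/visda/train/train_image_list.txt",
    "DATASETS.ROOT_TEST_DIR": "./data/visda/validation/valid_image_list.txt",
}

argv = build_test_command("configs/uda.yml", overrides)
# Quote each argument so the printed line can be pasted into a shell.
print(" ".join(shlex.quote(arg) for arg in argv))
```

Passing `argv` to `subprocess.run` from the repository root would launch the evaluation without going through a shell, which avoids the nested quoting around values such as `('0')`.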
Acknowledgement
Codebase from TransReID