Official PyTorch implementation of "Meta-Learning with Task-Adaptive Loss Function for Few-Shot Learning" (ICCV2021 Oral)

Overview

MeTAL - Meta-Learning with Task-Adaptive Loss Function for Few-Shot Learning (ICCV2021 Oral)

Sungyong Baik, Janghoon Choi, Heewon Kim, Dohee Cho, Jaesik Min, Kyoung Mu Lee


The code is based on the public code of MAML++, whose reimplementation of MAML is used as the baseline. The code also includes an implementation of ALFA.

[Paper-arXiv] [Video]

Requirements

Ubuntu 18.04

  • Anaconda3
  • Python==3.7.10
  • PyTorch==1.4
  • numpy==1.19.2

To install requirements, first download Anaconda3 and then run the following:

conda create -n metal python=3.7.10
conda activate metal
bash install.sh
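If install.sh cannot be used, a manual setup matching the listed requirements might look like the following (the exact CUDA build is not specified here and install.sh presumably pins any remaining dependencies; adjust for your system):

conda install pytorch=1.4 -c pytorch
pip install numpy==1.19.2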

Datasets

For miniImageNet, the dataset can be downloaded from the link provided in the MAML++ public code. Make a directory named 'datasets' and place the downloaded miniImageNet under the 'datasets' directory.
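As a rough sketch of the expected layout (the path below matches the dataset_path printed in the training log further down; the train/val/test subfolders are assumed from the pre-split MAML++ dataset package):

datasets/
└── mini_imagenet_full_size/
    ├── train/
    ├── val/
    └── test/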

Training

To train a model, run the following command in the experiments_scripts directory:

bash MeTAL.sh $GPU_ID
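For example, to train on GPU 0 (assuming $GPU_ID is the CUDA device index passed to the script):

bash MeTAL.sh 0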

Evaluation

After training finishes, evaluation is performed automatically. To run the evaluation manually, use the same command:

bash MeTAL.sh $GPU_ID

Results

Few-shot classification accuracy on miniImageNet:

Model        Backbone   1-shot Accuracy   5-shot Accuracy
MAML         4-CONV     49.64 ± 0.31%     64.99 ± 0.27%
MeTAL        4-CONV     52.63 ± 0.37%     70.52 ± 0.29%
ALFA+MAML    4-CONV     50.58 ± 0.51%     69.12 ± 0.47%
ALFA+MeTAL   4-CONV     57.75 ± 0.38%     74.10 ± 0.43%
MAML         ResNet12   58.60 ± 0.42%     69.54 ± 0.38%
MeTAL        ResNet12   59.64 ± 0.38%     76.20 ± 0.19%
ALFA+MAML    ResNet12   59.74 ± 0.49%     77.96 ± 0.41%
ALFA+MeTAL   ResNet12   66.61 ± 0.28%     81.43 ± 0.29%

Reference

@InProceedings{baik2021meta,
 title={Meta-Learning with Task-Adaptive Loss Function for Few-Shot Learning},
 author={Baik, Sungyong and Choi, Janghoon and Kim, Heewon and Cho, Dohee and Min, Jaesik and Lee, Kyoung Mu},
 booktitle={International Conference on Computer Vision (ICCV)},
 year={2021}
}

Comments
  • Naming confusion from code to paper, and the outer-loop procedure

    MeTAL is an inspiring effort for meta-learning!!

    I want to understand some details in your work:

    1. Do generated_alpha_params and generated_beta_params represent the \gamma and \beta for the affine transformation in https://github.com/baiksung/MeTAL/blob/master/inner_loop_optimizers.py#L137?
    2. Can I understand your update for the meta-learner as follows? (See the sketch after this list.)
    • In the inner loop, the base model is trained with the adaptive loss generated via \gamma and \beta from the meta-learner.
    • In the outer loop, the query set is used to check the performance of the base model produced by the inner loop, and those losses are used to update the meta-learner so that it generates a more generalizable loss.
    3. Supposing my comprehension of your work is correct, could you answer one of my puzzles: when the same task's data is used for both the inner- and outer-loop updates, isn't it cheating for the meta-learner to generate robust, generalized parameters that make the inner loop perform well? It feels like giving the validation dataset to the model for training.
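    To make my question concrete, here is a minimal PyTorch sketch of how I currently read the procedure. The names (functional_forward, meta_loss_net, meta_opt) are my own placeholders rather than your actual code, and I assume meta_opt holds both the initialization and the meta-learner parameters:

    import torch
    import torch.nn.functional as F

    def inner_outer_step(base_model, meta_loss_net, meta_opt,
                         support_x, support_y, query_x, query_y,
                         inner_lr=0.01, inner_steps=5):
        # Inner loop: start each task from the shared initialization (MAML-style).
        fast_weights = dict(base_model.named_parameters())
        for _ in range(inner_steps):
            logits = base_model.functional_forward(support_x, fast_weights)  # assumed helper
            ce = F.cross_entropy(logits, support_y, reduction='none')
            # The meta-learner generates affine parameters (gamma, beta) that
            # transform the per-sample loss into a task-adaptive loss.
            gamma, beta = meta_loss_net(ce.detach(), logits.detach())
            adapted_loss = (gamma * ce + beta).mean()
            grads = torch.autograd.grad(adapted_loss, list(fast_weights.values()),
                                        create_graph=True)  # keep graph for the outer loop
            fast_weights = {n: w - inner_lr * g
                            for (n, w), g in zip(fast_weights.items(), grads)}

        # Outer loop: plain cross-entropy on the query set checks the adapted model
        # and updates both the initialization and the loss-generating meta-learner.
        query_logits = base_model.functional_forward(query_x, fast_weights)
        outer_loss = F.cross_entropy(query_logits, query_y)
        meta_opt.zero_grad()
        outer_loss.backward()
        meta_opt.step()
        return outer_loss.item()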

    Thanks for your work again!! I am looking forward to your reply!

    opened by MagicDevilZhang 2
  • The training process got stuck

    When I try to train the model, I prepared everything as the README says, but I don't know why the whole process can't keep going after the print of layer_dict:

    batch_size 2 <class 'int'>
    image_height 84 <class 'int'>
    image_width 84 <class 'int'>
    image_channels 3 <class 'int'>
    reset_stored_filepaths False <class 'bool'>
    reverse_channels False <class 'bool'>
    num_of_gpus 1 <class 'int'>
    indexes_of_folders_indicating_class [-3, -2] <class 'list'>
    train_val_test_split [0.64, 0.16, 0.2] <class 'list'>
    samples_per_iter 1 <class 'int'>
    labels_as_int False <class 'bool'>
    seed 104 <class 'int'>
    gpu_to_use 4 <class 'int'>
    num_dataprovider_workers 4 <class 'int'>
    max_models_to_save 5 <class 'int'>
    dataset_name mini_imagenet_full_size <class 'str'>
    dataset_path datasets/datasets/mini_imagenet_full_size
    dataset_path datasets/mini_imagenet_full_size <class 'str'>
    reset_stored_paths False <class 'bool'>
    experiment_name MeTAL <class 'str'>
    architecture_name None <class 'NoneType'>
    continue_from_epoch latest <class 'str'>
    dropout_rate_value 0.0 <class 'float'>
    num_target_samples 15 <class 'int'>
    second_order True <class 'bool'>
    total_epochs 100 <class 'int'>
    total_iter_per_epoch 500 <class 'int'>
    min_learning_rate 0.001 <class 'float'>
    meta_learning_rate 0.001 <class 'float'>
    meta_opt_bn False <class 'bool'>
    task_learning_rate 0.1 <class 'float'>
    norm_layer batch_norm <class 'str'>
    max_pooling True <class 'bool'>
    per_step_bn_statistics False <class 'bool'>
    num_classes_per_set 5 <class 'int'>
    cnn_num_blocks 4 <class 'int'>
    number_of_training_steps_per_iter 5 <class 'int'>
    number_of_evaluation_steps_per_iter 5 <class 'int'>
    cnn_num_filters 48 <class 'int'>
    cnn_blocks_per_stage 1 <class 'int'>
    num_samples_per_class 5 <class 'int'>
    name_of_args_json_file experiment_config/MeTAL.json <class 'str'>
    backbone 4-CONV <class 'str'>
    attenuate False <class 'bool'>
    alfa False <class 'bool'>
    random_init False <class 'bool'>
    meta_loss True <class 'bool'>
    train_seed 0 <class 'int'>
    val_seed 0 <class 'int'>
    sets_are_pre_split True <class 'bool'>
    evaluate_on_test_set_only False <class 'bool'>
    num_evaluation_tasks 600 <class 'int'>
    multi_step_loss_num_epochs 15 <class 'int'>
    minimum_per_task_contribution 0.01 <class 'float'>
    learnable_per_layer_per_step_inner_loop_learning_rate False <class 'bool'>
    enable_inner_loop_optimizable_bn_params False <class 'bool'>
    evalute_on_test_set_only False <class 'bool'>
    learnable_batch_norm_momentum False <class 'bool'>
    load_into_memory False <class 'bool'>
    init_inner_loop_learning_rate 0.01 <class 'float'>
    init_inner_loop_weight_decay 0.0005 <class 'float'>
    learnable_bn_gamma True <class 'bool'>
    learnable_bn_beta True <class 'bool'>
    total_epochs_before_pause 101 <class 'int'>
    first_order_to_second_order_epoch -1 <class 'int'>
    weight_decay 0.0 <class 'float'>
    num_stages 4 <class 'int'>
    conv_padding True <class 'bool'>
    use_multi_step_loss_optimization False <class 'bool'>
    use GPU 0
    GPU ID 0
    Using max pooling
    torch.Size([2, 48, 84, 84])
    torch.Size([2, 48, 42, 42])
    torch.Size([2, 48, 21, 21])
    torch.Size([2, 48, 10, 10])
    VGGNetwork build torch.Size([2, 5])
    meta network params
    layer_dict.conv0.conv.weight torch.Size([48, 3, 3, 3])
    layer_dict.conv0.conv.bias torch.Size([48])
    layer_dict.conv0.norm_layer.running_mean torch.Size([48])
    layer_dict.conv0.norm_layer.running_var torch.Size([48])
    layer_dict.conv0.norm_layer.bias torch.Size([48])
    layer_dict.conv0.norm_layer.weight torch.Size([48])
    layer_dict.conv1.conv.weight torch.Size([48, 48, 3, 3])
    layer_dict.conv1.conv.bias torch.Size([48])
    layer_dict.conv1.norm_layer.running_mean torch.Size([48])
    layer_dict.conv1.norm_layer.running_var torch.Size([48])
    layer_dict.conv1.norm_layer.bias torch.Size([48])
    layer_dict.conv1.norm_layer.weight torch.Size([48])
    layer_dict.conv2.conv.weight torch.Size([48, 48, 3, 3])
    layer_dict.conv2.conv.bias torch.Size([48])
    layer_dict.conv2.norm_layer.running_mean torch.Size([48])
    layer_dict.conv2.norm_layer.running_var torch.Size([48])
    layer_dict.conv2.norm_layer.bias torch.Size([48])
    layer_dict.conv2.norm_layer.weight torch.Size([48])
    layer_dict.conv3.conv.weight torch.Size([48, 48, 3, 3])
    layer_dict.conv3.conv.bias torch.Size([48])
    layer_dict.conv3.norm_layer.running_mean torch.Size([48])
    layer_dict.conv3.norm_layer.running_var torch.Size([48])
    layer_dict.conv3.norm_layer.bias torch.Size([48])
    layer_dict.conv3.norm_layer.weight torch.Size([48])
    layer_dict.linear.weights torch.Size([5, 1200])
    layer_dict.linear.bias torch.Size([5])

    After this is printed, the training process doesn't continue; it just holds in this state without crashing. @baiksung, would you mind telling me what happened and how to solve it? Thanks a lot.

    opened by Ren-Zixin 0