PyTorch META-DATASET (Few-shot classification benchmark)

Overview

This repo contains a PyTorch implementation of meta-dataset and a unified implementation of some few-shot methods. This repo may be useful to you if you:

  • want some pre-trained ImageNet models in PyTorch for META-DATASET;
  • want to benchmark your method on META-DATASET (but do not want to mix your PyTorch code with the original TensorFlow implementation);
  • are looking for a codebase to visualize few-shot episodes.

Benefits over original code:

  1. This repo can be properly seeded, so the same random series of episodes can be reproduced if needed;
  2. Data shuffling is performed without using a buffer, hence reducing the memory consumption;
  3. Better results can be obtained using this repo thanks to an enhanced way of resizing images. More details in the paper.
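To illustrate the first point, here is a minimal sketch of how a dedicated, seeded random generator makes a series of episodes repeatable. It is not the sampler used in this repo: `sample_episode`, `episode_series` and their parameters are hypothetical names chosen for the illustration, and classes and examples are represented by plain integers.

```python
import random


def sample_episode(rng, num_classes, n_ways, n_shots, examples_per_class):
    """Draw one episode: n_ways classes, then n_shots example indices per class."""
    classes = rng.sample(range(num_classes), n_ways)
    return {c: rng.sample(range(examples_per_class), n_shots) for c in classes}


def episode_series(seed, num_episodes, **kwargs):
    # A dedicated generator keeps the series independent of the global random state.
    rng = random.Random(seed)
    return [sample_episode(rng, **kwargs) for _ in range(num_episodes)]


settings = dict(num_classes=20, n_ways=5, n_shots=3, examples_per_class=50)
run_a = episode_series(seed=0, num_episodes=4, **settings)
run_b = episode_series(seed=0, num_episodes=4, **settings)
print(run_a == run_b)  # True: same seed, same series of episodes
```

Passing the generator explicitly, rather than relying on the module-level `random` functions, is what keeps the series unaffected by other code that also draws random numbers.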

Note that this code also includes the original implementation for comparison (using the PyTorch workaround proposed by the authors). If you wish to use the original implementation, set the option loader_version: 'tf' in base.yaml (by default set to pytorch).

Yet to do:

  1. Add more methods
  2. Test for the multi-source setting

1. Setting up

Please carefully follow the instructions below to get started.

1.1 Requirements

The present code was developed and tested in Python 3.8. The list of requirements is provided in requirements.txt:

pip install -r requirements.txt

1.2 Data

To download META-DATASET, please follow the detailed instructions provided at meta-dataset to obtain the .tfrecords converted data. Once done, make sure all converted datasets are in a single folder, and run the following script to produce index files:

bash scripts/make_records/make_index_files.sh <path_to_converted_data>

This may take a few minutes. Once all this is done, set the path variable in config/base.yaml to your data folder.

1.3 Download pre-trained models

We provide Resnet-18 and WRN-28-10 models trained on the training split of ILSVRC_2012 at checkpoints. All non-episodic baselines use the same checkpoint, stored in the standard folder. The results (averaged over 600 episodes) obtained with the provided Resnet-18 are summarized below:

Inductive methods    | Architecture | ILSVRC | Omniglot | Aircraft | Birds | Textures | Quick Draw | Fungi | VGG Flower | Traffic Signs | MSCOCO | Mean
Finetune             | Resnet-18    | 59.8   | 60.5     | 63.5     | 80.6  | 80.9     | 61.5       | 45.2  | 91.1       | 55.1          | 41.8   | 64.0
ProtoNet             | Resnet-18    | 48.2   | 46.7     | 44.6     | 53.8  | 70.3     | 45.1       | 38.5  | 82.4       | 42.2          | 38.0   | 51.0
SimpleShot           | Resnet-18    | 60.0   | 54.2     | 55.9     | 78.6  | 77.8     | 57.4       | 49.2  | 90.3       | 49.6          | 44.2   | 61.7

Transductive methods | Architecture | ILSVRC | Omniglot | Aircraft | Birds | Textures | Quick Draw | Fungi | VGG Flower | Traffic Signs | MSCOCO | Mean
BD-CSPN              | Resnet-18    | 60.5   | 54.4     | 55.2     | 80.9  | 77.9     | 57.3       | 50.0  | 91.7       | 47.8          | 43.9   | 62.0
TIM-GD               | Resnet-18    | 63.6   | 65.6     | 66.4     | 85.6  | 84.7     | 65.8       | 57.5  | 95.6       | 65.2          | 50.9   | 70.1

See Sect. 1.4 and 1.5 to reproduce these results.

1.4 Train models from scratch (optional)

To train your model from scratch, run the scripts/train.sh script:

bash scripts/train.sh <method> <architecture> <dataset>

method is to be chosen among all method-specific config files in config/, architecture in ['resnet18', 'wideres2810'], and dataset among all datasets (as named by the META-DATASET converted folders). Note that the arguments passed to src/train.py and src/eval.py are resolved with the following precedence: base_config < method_config < opts arguments, so an option set in a later source overrides the same option set in an earlier one.
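The precedence rule can be sketched with plain dictionaries. This is only an illustration of the ordering base_config < method_config < opts, not the repo's actual config loader: `merge_configs` is a hypothetical helper, and the option values (including the method name `"TIM_GD"`) are made up for the example.

```python
def merge_configs(base_config, method_config, opts):
    """Merge three flat option dicts; later sources override earlier ones."""
    merged = dict(base_config)    # lowest priority
    merged.update(method_config)  # overrides base options
    merged.update(opts)           # highest priority (command-line style options)
    return merged


base_config = {"loader_version": "pytorch", "gpus": [0], "iter": 1}
method_config = {"method": "TIM_GD", "iter": 100}
opts = {"gpus": [0, 1, 2, 3]}

config = merge_configs(base_config, method_config, opts)
print(config["iter"])            # 100, taken from method_config
print(config["gpus"])            # [0, 1, 2, 3], taken from opts
print(config["loader_version"])  # pytorch, untouched base value
```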

Multiprocessing: This code supports distributed training. To use it, set the gpus option accordingly (for instance gpus: [0, 1, 2, 3]).

1.5 Test your models

Once your model is trained (or once the pre-trained models are downloaded), you can evaluate it on the test split of each dataset by running:

bash scripts/test.sh <method> <architecture> <base_dataset> <test_dataset>

Results will be saved in a subfolder of results/ whose name contains a unique hash of the config, so two runs share a result folder if and only if all their hyperparameters are the same.
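How the hash is computed is not described here. As a sketch of the general idea only, a config can be mapped to a short, order-independent identifier as follows; `config_hash`, the use of SHA-1 over a JSON dump, and the 8-character length are all assumptions made for the illustration.

```python
import hashlib
import json


def config_hash(config, length=8):
    """Return a short hex digest that depends only on the config's content."""
    # Sorting keys makes the serialization independent of insertion order.
    serialized = json.dumps(config, sort_keys=True)
    return hashlib.sha1(serialized.encode("utf-8")).hexdigest()[:length]


run_1 = {"method": "SimpleShot", "arch": "resnet18", "iter": 1}
run_2 = {"iter": 1, "arch": "resnet18", "method": "SimpleShot"}  # same options, other order
run_3 = {"method": "SimpleShot", "arch": "resnet18", "iter": 50}

print(config_hash(run_1) == config_hash(run_2))  # True: identical hyperparameters
print(config_hash(run_1) == config_hash(run_3))  # False: one hyperparameter differs
```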

2. Visualization of results

2.1 Training metrics

During training, the training loss and validation accuracy are recorded and saved as .npy files in the checkpoint folder. You can then use src/plot.py to plot these metrics (even while training is still running).

Example 1: Plot the metrics of the standard (=non episodic) resnet-18 on ImageNet:

python src/plot.py --folder checkpoints/ilsvrc_2012/ilsvrc_2012/resnet18/standard/

Example 2: Plot the metrics of all Resnet-18 models trained on ImageNet:

python src/plot.py --folder checkpoints/ilsvrc_2012/ilsvrc_2012/resnet18/

2.2 Inference metrics

For methods that perform test-time optimization (for instance MAML, TIM, Finetune, ...), method-specific metrics are plotted in real time (versus test iterations) and averaged over test episodes, which makes unexpected behavior easier to spot. Such metrics are implemented in src/metrics/, and the choice of which metric to plot is specified through the eval_metrics option in the method .yaml config file. An example with the TIM method is provided below.

2.3 Visualization of episodes

By setting the option visu: True at inference, you can visualize samples of episodes. An example of such visualization is given below:

The samples will be saved in results/. All relevant options can be found in the base.yaml file, in the EVAL-VISU section.

3. Incorporate your own method

This code was designed to allow easy incorporation of new methods.

Step 1: Add your method .py file to src/methods/ by following the template provided in src/methods/method.py.

Step 2: Add the corresponding import in src/methods/__init__.py.

Step 3: Add your method .yaml config file including the required options episodic_training and method (name of the class corresponding to your method). Also make sure that if your method performs test-time optimization, you also properly set the option iter that specifies the number of optimization steps performed at inference (this argument is also used to plot the inference metrics, see section 2.2).
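As a rough sanity check for Step 3, the required options can be verified before launching a run. The sketch below treats the method config as a plain dictionary (for instance after loading the .yaml file); `check_method_config` is a hypothetical helper that is not part of this repo, and `"MyMethod"` is a placeholder class name.

```python
REQUIRED_OPTIONS = ("episodic_training", "method")


def check_method_config(config, test_time_optimization=False):
    """Return the list of required options missing from a method config."""
    required = list(REQUIRED_OPTIONS)
    if test_time_optimization:
        # Methods optimizing at inference also need the number of steps.
        required.append("iter")
    return [name for name in required if name not in config]


my_config = {"episodic_training": False, "method": "MyMethod"}
print(check_method_config(my_config))                               # []
print(check_method_config(my_config, test_time_optimization=True))  # ['iter']
```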

4. Contributions

Contributions are more than welcome. In particular, if you want to add methods/pre-trained models, do make a pull-request.

5. Citation

If you find this repo useful for your research, please consider citing the following paper:

@misc{boudiaf2021mutualinformation,
      title={Mutual-Information Based Few-Shot Classification}, 
      author={Malik Boudiaf and Ziko Imtiaz Masud and Jérôme Rony and Jose Dolz and Ismail Ben Ayed and Pablo Piantanida},
      year={2021},
      eprint={2106.12252},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Additionally, do not hesitate to file issues if you encounter problems, or reach out directly to Malik Boudiaf.

6. Acknowledgments

I thank the authors of meta-dataset for releasing their code, and the author of the open-source TFRecord reader for open-sourcing an awesome PyTorch-compatible TFRecordReader! Also, big thanks to @hkervadec for his thorough code review!

Comments
  • shuffle buffer issue?

    I suspect that your reader code is affected by the meta-dataset shuffle buffer issue 54. I did a full run with your reader and the results were mostly consistent with what I would get using the official meta-dataset reader, except for traffic signs (and a couple of other datasets), where the results were more optimistic, as happens when the data is not shuffled. From a quick look through your code, it seems that the shuffle buffer mechanism is not used.

    opened by jfb54 8
  • Buggy code in example.py?

    example.py has the following lines of code:

            use_bilevel_ontology_list = [False]*len(datasets)
            # Enable ontology aware sampling for Omniglot and ImageNet.
            if 'omniglot' in datasets:
                use_bilevel_ontology_list[datasets.index('omniglot')] = True
            if 'imagenet' in datasets:
                use_bilevel_ontology_list[datasets.index('imagenet')] = True
    
            use_bilevel_ontology_list = use_bilevel_ontology_list
            use_dag_ontology_list = [False]*len(datasets)
    

    shouldn't it be:

            use_bilevel_ontology_list = [False]*len(datasets)
            use_dag_ontology_list = [False]*len(datasets)
    
            # Enable ontology aware sampling for Omniglot and ImageNet.
            if 'omniglot' in datasets:
                use_bilevel_ontology_list[datasets.index('omniglot')] = True
            if 'ilsvrc_2012' in datasets:
                use_dag_ontology_list[datasets.index('ilsvrc_2012')] = True
    

    as the 'imagenet' dataset is actually called 'ilsvrc_2012' and it should use the DAG ontology.

    opened by jfb54 4
  • Feature Request: Ideally the episodes generated would be repeatable for a specified seed.

    The official meta-dataset reader is not deterministic and repeatable, which is frustrating as test runs cannot be directly compared. If your reader could be deterministic (given a seed), that would be a huge win.

    opened by jfb54 2
  • Meta-batch size hard coded to 1

    Hi, thank you for your implementation.

    The meta-training dataloader seems to have a batch size of 1 hard-coded in it. I would like to train MAML on this, and the default meta-batch size there is 4. Is there any particular reason why the meta-batch size is hard-coded to 1?

    Thank you!

    opened by sudarshan1994 2
  • Unexpected behavior from min_examples_in_class

    The official meta-dataset does not expose the parameter min_examples_in_class (as far as I know). I set it to 1 because I wanted every class to have at least one example, and I get the following error message when I turn on use_bilevel_hierarchy for Omniglot (which is standard for Meta-Dataset):

    "use_bilevel_hierarchy" is incompatible with "min_examples_in_class"
    

    I don't understand why this restriction is required.

    opened by jfb54 2
  • squeeze() needs to be added to support, support_labels, query, query_labels

    When support, query, support_labels, query_labels are returned from the DataLoader, the 1st dimension is 1 in size which is redundant and will usually not work properly when fed into a network. A squeeze(x, dim=0) will fix this.

    opened by jfb54 2
  • Normalize should map image tensors in the range -1 to 1 to be compatible with Meta-Dataset

    Instead of the ImageNet friendly normalize transform:

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    

    use the following:

    normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
    

    as Meta-Dataset guarantees that all images will be normalized in the range -1 to 1.
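The arithmetic behind this request can be checked by hand. `transforms.Normalize` computes `(x - mean) / std` per channel, so for inputs in [0, 1] (the range produced by `transforms.ToTensor`) the two settings give different output ranges. The sketch below reproduces that formula in plain Python for a single channel; `normalize` and `output_range` are illustrative helpers, not torchvision functions.

```python
def normalize(x, mean, std):
    """Per-channel normalization formula: (x - mean) / std."""
    return (x - mean) / std


def output_range(mean, std):
    """Range reached by normalized values when inputs lie in [0, 1]."""
    return normalize(0.0, mean, std), normalize(1.0, mean, std)


# mean = std = 0.5 maps [0, 1] exactly onto [-1, 1].
print(output_range(0.5, 0.5))  # (-1.0, 1.0)

# ImageNet statistics (first channel) give a wider, asymmetric range.
low, high = output_range(0.485, 0.229)
print(round(low, 3), round(high, 3))  # -2.118 2.249
```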

    opened by jfb54 2
  • DataConfig and EpisodeConfig constructors should directly accept arguments as opposed to encapsulated in an argparse.Namespace

    First of all, thanks for writing this library. I have been using the official MetaDataset TensorFlow reader in conjunction with PyTorch (which works, but is very resource hungry), so I thought I would try out your library. It now works, but I needed to work around several issues. I'll file a series of issues for these.

    The first issue I hit was the fact that the DataConfig and EpisodeConfig constructors took their arguments via an argparse.Namespace. This won't work for a real meta-dataset application as you typically need to set up several DataLoaders (say for train, validate and test) and you usually need to specify different parameters for each (e.g. max_support_set_size). My workaround was to modify your code and make constructors with a conventional set of arguments. Would be great if you could make the change!
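One possible shape for such a constructor, sketched with a dataclass. This is not the library's `EpisodeConfig`: the class name `EpisodeConfigSketch`, the default values, and the `from_namespace` helper are assumptions made for the illustration; only the option names `max_support_set_size` and `min_examples_in_class` come from the discussion on this page.

```python
import argparse
from dataclasses import dataclass, fields


@dataclass
class EpisodeConfigSketch:
    # Illustrative defaults, not the library's.
    max_support_set_size: int = 500
    min_examples_in_class: int = 0

    @classmethod
    def from_namespace(cls, args):
        """Build a config from an argparse.Namespace, ignoring unrelated flags."""
        known = {f.name for f in fields(cls)}
        return cls(**{k: v for k, v in vars(args).items() if k in known})


# Different loaders can now get different settings directly.
train_config = EpisodeConfigSketch(max_support_set_size=500)
test_config = EpisodeConfigSketch(max_support_set_size=100)

# Scripts built around argparse keep working through the helper.
legacy = EpisodeConfigSketch.from_namespace(
    argparse.Namespace(max_support_set_size=250, unrelated_flag=True)
)
print(train_config, test_config, legacy)
```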

    opened by jfb54 2
  • no bash file

    Hi, thanks for such a nice contribution. I am just a bit confused: your data preparation instructions say that once you obtain the converted data, you can run: bash scripts/make_records/make_index_files.sh <path_to_converted_data>

    However, I cannot find /make_records/make_index_files.sh file in this repo

    opened by RongKaiWeskerMA 1
  • Too many unexpected Errors.

    I tried to run this repository to reproduce the results. However, there are many errors. For example, when I set gpu: [1], the code ran on gpu=0. At the same time, when running forward_call, there is an error regarding asynchronously report...

    opened by HongduanTian 0
  • Sampling from episodic loader gives error - "Key image doesn't exist (select from [])!"

    When sampling from the episodic loader, all usually goes fine until I get the following error:

    Traceback (most recent call last):
      File "/home/patrick/pytorch-meta-dataset/pytorch_meta_dataset/pipeline.py", line 145, in get_next
        sample_dic = next(self.class_datasets[class_id])
    TypeError: 'TFRecordDataset' object is not an iterator
    During handling of the above exception, another exception occurred:
    Traceback (most recent call last):
      File "/home/patrick/pytorch-meta-dataset/pytorch_meta_dataset/pipeline.py", line 219, in get_next
        dataset = next(self.dataset_list[source_id])
      File "/home/patrick/pytorch-meta-dataset/pytorch_meta_dataset/pipeline.py", line 121, in __iter__
        sample_dic = self.get_next(class_id)
      File "/home/patrick/pytorch-meta-dataset/pytorch_meta_dataset/pipeline.py", line 148, in get_next
        sample_dic = next(self.class_datasets[class_id])
      File "/home/patrick/pytorch-meta-dataset/pytorch_meta_dataset/utils.py", line 23, in cycle_
        yield next(iterator)
      File "/home/patrick/pytorch-meta-dataset/pytorch_meta_dataset/tfrecord/reader.py", line 222, in example_loader
        feature_dic = extract_feature_dict(example.features, description, typename_mapping)
      File "/home/patrick/pytorch-meta-dataset/pytorch_meta_dataset/tfrecord/reader.py", line 162, in extract_feature_dict
        raise KeyError(f"Key {key} doesn't exist (select from {all_keys})!")
    KeyError: "Key image doesn't exist (select from [])!"
    During handling of the above exception, another exception occurred:
    Traceback (most recent call last):
      File "/home/patrick/pytorch-meta-dataset/pytorch_meta_dataset/pipeline.py", line 201, in __iter__
        next_e = self.get_next(rand_source)
      File "/home/patrick/pytorch-meta-dataset/pytorch_meta_dataset/pipeline.py", line 222, in get_next
        dataset = next(self.dataset_list[source_id])
    StopIteration
    

    Just for info - I used an older version of your repo (https://github.com/mboudiaf/pytorch-meta-dataset/tree/c6d6922003380342ab2e3509425d96307aa925c5). I am sampling from the episodic loader. I use

    episodic_dataset = pipeline.make_episode_pipeline(dataset_spec_list=all_dataset_specs,
                                                      split=split,
                                                      data_config=data_config,
                                                      episode_descr_config=episod_config)
    episodic_loader = DataLoader(dataset=episodic_dataset,
                                 batch_size=meta_batch_size,
                                 num_workers=data_config.num_workers,
                                 worker_init_fn=seeded_worker_fn)
    # Sample a batch of size [B, N*K, C, H, W] from the episodic loader via next(iter(episodic_loader)),
    # where B = meta_batch_size, N*K = n_ways*k_shots, C = channels, H = image height, W = image width.
    

    Do you know what may be causing the KeyError: "Key image doesn't exist (select from [])!" StopIteration error? For the above error, I am setting 5-way 15-shots for train and 5-way 5-shot for test/validation, and meta_batch_size 2 for train and 4 for test/val.

    Thanks a lot in advance!

    opened by patricks-lab 0
  • Training the fine-tuned base line with standard supervised learning with union/concatenation of labels

    Hi @mboudiaf, I wanted to train the fine-tuned baseline from Meta-Dataset (MDS), i.e. concatenate/union all the datasets and all the labels and then train with normal supervised learning. Is this the right way to do it:

    https://github.com/mboudiaf/pytorch-meta-dataset/blob/c6d6922003380342ab2e3509425d96307aa925c5/example.py#L173

    I am mainly asking because there needs to be some sort of relabeling that takes all the dataset labels into account, and I wanted to know how that was done.

    Thank you!

    opened by brando90 0
  • Learn2Learn support?

    Is there learn2learn support for this? See https://github.com/learnables/learn2learn/issues/286

    I am happy to help add it. What would be a good starting point?

    opened by brando90 3
  • How to run the code correctly?

    Hi, thank you for your outstanding work! I ran into the following problems when trying to use your code. After processing all the data, I ran the command bash scripts/train.sh protonet resnet18 ilsvrc_2012, but got the error: "use_hierarchy" is incompatible with "num_ways". So I tried setting num_ways to -1 to avoid this error, but strange things still happened: RuntimeError: stack expects each tensor to be equal size, but got [50, 3, 84, 84] at entry 0 and [30, 3, 84, 84] at entry 1. This is the first question that puzzles me. What's more, how do I get a pre-trained model on ImageNet by training from scratch? There does not seem to be an option to obtain a pre-trained model in the method options. I am looking forward to your reply. Thank you!

    opened by Fei-Long121 2
  • ResNet structure

    Hi,

    The original TensorFlow implementation uses the standard structure for the first convolution layer, i.e., 7x7 kernel size, stride 2, padding 3, followed by a 3x3 max pooling layer (link), while in your implementation this layer uses a 3x3 kernel size and no max pooling (link). As a result, the feature map is much larger and costs more memory. I also noticed that in the PAMI version of TIM the authors claim that the PyTorch versions of the baselines are much better than the original versions. I wonder if the performance boost comes from this modification. The 'larger' version of ResNet does not seem very practical for meta-dataset, since it leads to OOM when trained with ProtoNet or other episodic methods. Please let me know if I have misunderstood the code.

    Thanks.

    opened by loadder 5
Owner
Malik Boudiaf