Prototypical Networks for Few-shot Learning in PyTorch

Overview

A simple alternative implementation of Prototypical Networks for Few-shot Learning (paper, code) in PyTorch.

Prototypical Networks

As shown in the reference paper, Prototypical Networks are trained to embed sample features in a vector space. At each episode (iteration), a number of samples for a subset of classes are selected and sent through the model. For each class c in the subset, the features of n_support samples are used to estimate the class prototype (their barycentre in the vector space); the distances between the remaining n_query samples and their class barycentre can then be minimized.
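To make the idea of a prototype concrete, here is a minimal framework-free sketch in plain Python. The 2-D vectors stand in for model embeddings and are made-up values; the helper names `prototype` and `squared_euclidean` are illustrative and are not taken from this repo.

```python
def prototype(embeddings):
    """Barycentre (element-wise mean) of a list of equal-length vectors."""
    dim = len(embeddings[0])
    return [sum(vec[d] for vec in embeddings) / len(embeddings) for d in range(dim)]


def squared_euclidean(a, b):
    """Squared Euclidean distance between two vectors."""
    return sum((x - y) ** 2 for x, y in zip(a, b))


# Hypothetical support embeddings: two classes, n_support = 2 each.
support = {
    "class_a": [[0.0, 0.0], [2.0, 0.0]],
    "class_b": [[4.0, 4.0], [6.0, 4.0]],
}
prototypes = {label: prototype(vecs) for label, vecs in support.items()}

# A query embedding is assigned to the class with the closest prototype.
query = [1.5, 0.5]
distances = {label: squared_euclidean(query, proto) for label, proto in prototypes.items()}
nearest = min(distances, key=distances.get)
print(prototypes, nearest)
```

During training the distances are not just compared but turned into a loss, so that the embedding itself is pushed to place query samples near their own class barycentre.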

T-SNE

After training, you can compute the t-SNE of the features generated by the model (not done in this repo; more info about t-SNE here). The figure below is a sample from the paper.

[Figure: t-SNE of the embeddings, as shown in the reference paper]

Omniglot Dataset

Kudos to @ludc for his contribution: https://github.com/pytorch/vision/pull/46. We will switch to the official dataset once it is added to torchvision, provided it doesn't require big changes to the code.

Dataset splits

We implemented the Vinyals splitting method as in [Matching Networks for One Shot Learning]. That should be the same method used in the paper (in fact, the split files are downloaded from the "official" repo). We then apply the same rotations described there. This way, results obtained by running this code should be comparable with those reported in the reference paper.

Prototypical Batch Sampler

As described in its PyDoc, this class is used to generate the indexes of each batch for a prototypical training algorithm.

In particular, the object is instantiated by passing the list of labels for the dataset; the sampler then infers the total number of classes and creates a set of indexes for each class in the dataset. At each episode the sampler selects n_classes random classes and returns (n_support + n_query) sample indexes for each of the selected classes.
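The behaviour described above can be sketched without any framework. This is only an illustration under stated assumptions: the class name `EpisodeSampler`, its argument names and the `seed` parameter are hypothetical and do not mirror the repo's actual sampler, which works on tensors.

```python
import random
from collections import defaultdict


class EpisodeSampler:
    """Hypothetical stand-in for an episodic batch sampler (names are assumptions)."""

    def __init__(self, labels, n_classes, n_samples, iterations, seed=None):
        self.n_classes = n_classes    # classes drawn per episode
        self.n_samples = n_samples    # n_support + n_query per class
        self.iterations = iterations  # episodes per epoch
        self.rng = random.Random(seed)
        # Build the list of dataset indexes belonging to each class.
        self.indexes_per_class = defaultdict(list)
        for idx, label in enumerate(labels):
            self.indexes_per_class[label].append(idx)

    def __iter__(self):
        classes = list(self.indexes_per_class)
        for _ in range(self.iterations):
            batch = []
            # Pick random classes, then random samples within each class.
            for c in self.rng.sample(classes, self.n_classes):
                batch.extend(self.rng.sample(self.indexes_per_class[c], self.n_samples))
            yield batch

    def __len__(self):
        return self.iterations


# Toy dataset: 10 classes with 20 samples each.
labels = [i // 20 for i in range(200)]
sampler = EpisodeSampler(labels, n_classes=3, n_samples=10, iterations=4, seed=0)
episodes = list(sampler)
print(len(episodes), len(episodes[0]))
```

An object of this shape yields one list of indexes per episode, which is the form a PyTorch `DataLoader` expects from its `batch_sampler` argument.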

Prototypical Loss

Compute the loss as in the cited paper, mostly inspired by this code by one of its authors.

In prototypical_loss.py both loss function and loss class à la PyTorch are implemented.

The function takes as input the batch of embeddings produced by the model, the samples' ground truths, and the number n_support of samples to be used as support samples. Episode classes are inferred from the target list, n_support samples are randomly extracted for each class, and their class barycentres are computed. The distance of each remaining sample's embedding from each class barycentre is then computed, followed by the probability of each sample belonging to each episode class. Finally, the loss is computed from the probabilities of wrong predictions (for the query samples), as usual in classification problems.
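The steps above can be sketched in plain Python. This is not the code in prototypical_loss.py: the function name `prototypical_loss_sketch` is hypothetical, it assumes the batch is already shuffled so that the first n_support samples of each class can serve as support, and it assumes the loss is the mean negative log-probability of the true class under a softmax over negative squared Euclidean distances.

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over a list of floats."""
    m = max(scores)
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


def prototypical_loss_sketch(embeddings, targets, n_support):
    classes = sorted(set(targets))
    by_class = {c: [i for i, t in enumerate(targets) if t == c] for c in classes}

    # One prototype per class: the mean of its support embeddings.
    prototypes = []
    for c in classes:
        support = [embeddings[i] for i in by_class[c][:n_support]]
        dim = len(support[0])
        prototypes.append([sum(v[d] for v in support) / len(support) for d in range(dim)])

    losses, correct = [], 0
    for k, c in enumerate(classes):
        for i in by_class[c][n_support:]:  # the remaining samples are queries
            dists = [sum((x - p) ** 2 for x, p in zip(embeddings[i], proto))
                     for proto in prototypes]
            log_p = log_softmax([-d for d in dists])
            losses.append(-log_p[k])  # negative log-probability of the true class
            predicted = max(range(len(classes)), key=lambda j: log_p[j])
            correct += int(predicted == k)
    return sum(losses) / len(losses), correct / len(losses)


# Made-up embeddings: two well-separated classes, 2 support + 1 query each.
embeddings = [[0.0, 0.0], [0.2, 0.0], [0.1, 0.1],
              [5.0, 5.0], [5.2, 5.0], [5.1, 4.9]]
targets = [0, 0, 0, 1, 1, 1]
loss, acc = prototypical_loss_sketch(embeddings, targets, n_support=2)
print(loss, acc)
```

Because the two toy classes are far apart, every query lands next to its own prototype, so the accuracy is 1.0 and the loss is close to zero.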

Training

Please note that the training code is here just for demonstration purposes.

To train the Protonet on this task, cd into this repo's src root folder and execute:

$ python train.py

The script takes the following command line options:

  • dataset_root: root directory where the dataset is stored; defaults to '../dataset'

  • nepochs: number of epochs to train for; defaults to 100

  • learning_rate: learning rate for the model; defaults to 0.001

  • lr_scheduler_step: StepLR learning rate scheduler step; defaults to 20

  • lr_scheduler_gamma: StepLR learning rate scheduler gamma; defaults to 0.5

  • iterations: number of episodes per epoch; defaults to 100

  • classes_per_it_tr: number of random classes per episode for training; defaults to 60

  • num_support_tr: number of samples per class to use as support for training; defaults to 5

  • num_query_tr: number of samples per class to use as query for training; defaults to 5

  • classes_per_it_val: number of random classes per episode for validation; defaults to 5

  • num_support_val: number of samples per class to use as support for validation; defaults to 5

  • num_query_val: number of samples per class to use as query for validation; defaults to 15

  • manual_seed: input for the manual seed initializations; defaults to 7

  • cuda: enables CUDA (a store_true flag, off unless passed)

Running the command without arguments trains the model with the default hyperparameter values (producing the 5-shot, 5-way result reported in the Performances section when --cuda is used).
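As a quick illustration of what the two StepLR options imply, the sketch below evaluates the step-decay rule lr = learning_rate * gamma ** (epoch // step) with the default values. The helper `step_lr` is hypothetical, and the sketch assumes the scheduler is stepped once per epoch.

```python
def step_lr(initial_lr, step_size, gamma, epoch):
    """Learning rate at a given epoch under step decay."""
    return initial_lr * gamma ** (epoch // step_size)


# Defaults from the option list: learning_rate=0.001, step=20, gamma=0.5.
schedule = {epoch: step_lr(0.001, 20, 0.5, epoch) for epoch in (0, 19, 20, 40, 99)}
for epoch, lr in schedule.items():
    print(f"epoch {epoch:3d}: lr = {lr:.6g}")
```

Under these assumptions the learning rate is halved every 20 epochs, so it is halved four times over the default 100 epochs.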

Performances

We are trying to reproduce the performance reported in the reference paper; we will update our best results here.

Model           | 1-shot (5-way Acc.) | 5-shot (5-way Acc.) | 1-shot (20-way Acc.) | 5-shot (20-way Acc.)
Reference Paper | 98.8%               | 99.7%               | 96.0%                | 98.9%
This repo       | 98.5%**             | 99.6%*              | 95.1%°               | 98.6%°°

* achieved using default parameters (using --cuda option)

** achieved running python train.py --cuda -nsTr 1 -nsVa 1

° achieved running python train.py --cuda -nsTr 1 -nsVa 1 -cVa 20

°° achieved running python train.py --cuda -nsTr 5 -nsVa 5 -cVa 20

Helpful links

.bib citation

Cite the paper as follows (copy-pasted from arXiv for you):

@article{DBLP:journals/corr/SnellSZ17,
  author    = {Jake Snell and
               Kevin Swersky and
               Richard S. Zemel},
  title     = {Prototypical Networks for Few-shot Learning},
  journal   = {CoRR},
  volume    = {abs/1703.05175},
  year      = {2017},
  url       = {http://arxiv.org/abs/1703.05175},
  archivePrefix = {arXiv},
  eprint    = {1703.05175},
  timestamp = {Wed, 07 Jun 2017 14:41:38 +0200},
  biburl    = {http://dblp.org/rec/bib/journals/corr/SnellSZ17},
  bibsource = {dblp computer science bibliography, http://dblp.org}
}

License

This project is licensed under the MIT License

Copyright (c) 2018 Daniele E. Ciriello, Orobix Srl (www.orobix.com).

Comments
  • How do I make a prediction?

    Thank you for your work!

    I've trained a model for a few epochs, and now I'd like to make predictions with it. I load it:

    model = ProtoNet().cuda()
    model.load_state_dict(torch.load('./output/best_model.pth'))
    

    I load 15 labeled data points, for a total of 3 labels:

    # x.size() -> (15)
    # y.size() -> (15, 64, 64)
    x, y = load_data()
    

    I load a single datapoint I want to predict

    to_predict = torch.Tensor(1, 64, 64)
    

    I now would like to few-shot train on 5 examples per class and then predict a class for my to_predict. How do I go about that?

    question 
    opened by ale316 7
  • dimension error

    Running python train.py --cuda -nsTr 1 -nsVa 1 gives me a runtime error in euclidean_dist ("dimension out of range (expected to be in range of [-1, 0], but got 1)")

    It seems the code is trying to compute a distance between query_samples (a matrix size 360x64) and prototypes (a vector length 60)- so yeah, I can see why it's producing this error. But why are query_samples and prototypes these sizes, to begin with? Shouldn't they have the same value in size(1)?

    bug 
    opened by annkennedy 6
  • Mini Imagenet Results

    opened by madiltalay 5
  • Questions about dataset

    Hi I have a question about image loading in omniglot_dataset.py

    Why use 1 - array & why use transpose and reshape as below?

    https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/master/src/omniglot_dataset.py#L179-L180

    Thanks

    question 
    opened by liulu112601 5
  • Prediction new problem

    Hi guys! I am stuck at this point and don't quite understand it. I already have the trained model. For testing, my understanding is that I should compute the embedding of my sample and see which centroid is closest in order to classify it. If I trained with 5000 classes, I don't understand why, in the test phase, it is necessary to pass a support set and a query set.

    Taking @ale316's implementation, predict(support_x, support_y, query_x, query_y=None), we pass support sets from the training set and queries whose y we do not know (so we set it to None), but how does it take the 5000 classes into account if the support set only includes, say, 30 of them?

    My idea is to create all the training embeddings and then generate the centroids; after that, generate the embeddings for the test samples and pick the class with the smallest Euclidean distance to the centroids. Am I correct? But you don't do that.

    Sorry for the spam here, but for posterity:

    I wrote a function that should return predictions given:

    • a tensor support_x of size (n_support, 1024)
    • a tensor support_y of size (n_support,)
    • a tensor query_x of size (n_query, 1024)
    import torch
    import torch.nn.functional as F

    # Assumption: euclidean_dist is the helper defined in this repo's prototypical_loss.py
    from prototypical_loss import euclidean_dist


    def predict(support_x, support_y, query_x, query_y=None):
        # query_y is accepted for symmetry with the loss function but is not used
        support_x = support_x.to('cpu')
        support_y = support_y.to('cpu')
        query_x = query_x.to('cpu')

        classes = torch.unique(support_y)

        # for each class, get the indexes of the support samples with that label
        support_idxs = [support_y.eq(c).nonzero().squeeze(1) for c in classes]

        # take the mean of the support embeddings of each class to create a centroid
        prototypes = torch.stack([support_x[idx_list].mean(0) for idx_list in support_idxs])

        # find the euclidean distances between each query_x and each centroid
        dists = euclidean_dist(query_x, prototypes)

        # run the negated distances through log-softmax
        log_p_y = F.log_softmax(-dists, dim=1)

        # y_hat has shape (n_query,): the idx of the closest centroid for each query_x
        _, y_hat = log_p_y.max(1)
        # no squeeze() here, so a single query still yields a one-element list
        labels = [classes[i] for i in y_hat]

        return labels
    

    Now, I still have some doubts:

    • Is this correct? Edit: yes it is
    • Why do we even softmax the distances, instead of just taking the min?

    Originally posted by @MarioProjects in https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/issues/8#issuecomment-450552865

    opened by MarioProjects 2
  • Loss Backpropagation

    During the learning stage, the loss isn't backpropagated to the model and I am obtaining the same accuracy and loss even after training for a huge number of epochs.

    opened by vatsalsaglani 1
  • Testing Error

    When testing, the network calculates loss using the number of supports specified for the training regime. This appears to be an error, but perhaps I've missed something?

    _, acc = loss_fn(model_output, target=y, n_support=opt.num_support_tr)

    To me, this line should be

    _, acc = loss_fn(model_output, target=y, n_support=opt.num_support_val)

    opened by ralphatobe 1
  • Something wrong when i set num_query_val to 1.

    Hello, thank you for the code provided. When I set num_query_val to 1, val_acc and test_acc never change and always stay around num_query_val / classes_per_it_val. Can you tell me what went wrong? Thank you! Looking forward to your reply.

    opened by P-DX 1
  • Question about batching function

    Hello,

    I have been trying to implement prototypical networks on a different dataset (audio), and I have been having some difficulty with training, my loss seems to be reducing very slowly.

    I had a question about your batching function. If we assume there are 32 query points in a batch, and 3-5 support points per class.

    Is each query point getting a label (1-32) and is this arbitrary? Or are datapoints given labels based on the whole training set? Put in another way, Does a given query point get the same label (necessarily) in different mini-batches. If you could give me an example that would be super helpful.

    Thanks, Gautam

    help wanted question 
    opened by gautamb85 1
  • Fix prototypical_loss bug #23

    target_inds.shape == (n_classes, n_query, 1), so target_inds.squeeze().shape is supposed to be (n_classes, n_query). When n_query = 1 (one-shot), however, it becomes (n_classes,), which leads to a wrong acc_val. Using target_inds.squeeze(2) instead of target_inds.squeeze() fixes it.

    opened by rcy17 0
  • Sampling without replacement

    In the original paper in "Algorithm 1", they mention that each batch is sampled "without replacement":

    ... RANDOMSAMPLE(S, N) denotes a set of N elements chosen uniformly at random from set S, without replacement.

    whereas your sampler class clearly samples with replacement, since you even pass the number of iterations as an argument to the class constructor. May I ask why?

    opened by astrocyted 0
  • FileNotFoundError

    FileNotFoundError                         Traceback (most recent call last)
    in
        252
        253 if __name__ == '__main__':
    --> 254     main()

    in main()
        208     init_seed(options)
        209
    --> 210     tr_dataloader = init_dataloader(options, 'train')
        211     val_dataloader = init_dataloader(options, 'val')
        212     # trainval_dataloader = init_dataloader(options, 'trainval')

    in init_dataloader(opt, mode)
         47
         48 def init_dataloader(opt, mode):
    ---> 49     dataset = init_dataset(opt, mode)
         50     sampler = init_sampler(opt, dataset.y, mode)
         51     dataloader = torch.utils.data.DataLoader(dataset, batch_sampler=sampler)

    in init_dataset(opt, mode)
         23
         24 def init_dataset(opt, mode):
    ---> 25     dataset = OmniglotDataset(mode=mode, root=opt.dataset_root)
         26     n_classes = len(np.unique(dataset.y))
         27     if n_classes < opt.classes_per_it_tr or n_classes < opt.classes_per_it_val:

    E:\学习\jupyter\prototypical net\Prototypical-Networks-for-Few-shot-Learning-PyTorch-master\src\omniglot_dataset.py in __init__(self, mode, root, transform, target_transform, download)
         53             raise RuntimeError(
         54                 'Dataset not found. You can use download=True to download it')
    ---> 55         self.classes = get_current_classes(os.path.join(
         56             self.root, self.splits_folder, mode + '.txt'))
         57         self.all_items = find_items(os.path.join(

    E:\学习\jupyter\prototypical net\Prototypical-Networks-for-Few-shot-Learning-PyTorch-master\src\omniglot_dataset.py in get_current_classes(fname)
        159
        160 def get_current_classes(fname):
    --> 161     with open(fname) as f:
        162         classes = f.read().replace('/', os.sep).splitlines()
        163     return classes

    FileNotFoundError: [Errno 2] No such file or directory: '..\dataset\splits\vinyals\train.txt'

    opened by zhougoodman 1
  • error of omniglot_dataset.py

    When I run the following program, I get the error below. How can I solve it?

    opt = get_parser().parse_args()
    mode = "train"
    OmniglotDataset(mode=mode, root=opt.dataset_root)
    

    == Downloading https://raw.githubusercontent.com/jakesnell/prototypical-networks/master/data/omniglot/splits/vinyals/test.txt
    Traceback (most recent call last):

      File "", line 1, in
        dataset = OmniglotDataset(mode=mode, root=opt.dataset_root)

      File "C:\Users\lenovo\Desktop\Prototypical-Networks-for-Few-shot-Learning-PyTorch-master\src\omniglot_dataset.py", line 50, in __init__
        self.download()

      File "C:\Users\lenovo\Desktop\Prototypical-Networks-for-Few-shot-Learning-PyTorch-master\src\omniglot_dataset.py", line 113, in download
        with open(file_path, 'wb') as f:

    OSError: [Errno 22] Invalid argument: '..\dataset\splits\vinyals\https://raw.githubusercontent.com/jakesnell/prototypical-networks/master/data/omniglot/splits/vinyals/test.txt'

    opened by bbcenglish 0
  • query in loading images

    Hi.

    Thanks for the repo. I have a couple of queries in the load_img function.

    https://github.com/orobix/Prototypical-Networks-for-Few-shot-Learning-PyTorch/blob/5e18a5e5b369903092f683d434efb12c7c40a83c/src/omniglot_dataset.py#L166-L181

    1. Is it necessary to do x = 1.0 - torch.from_numpy(x) ? I understand this is to have 1s in the region of interest (the character) but does it really help?

    2. Why do you take the transpose (rotates the image again) at the end?

    Thanks!

    opened by Gateway2745 2