PyTorch implementation of Set Transformer

Overview

set_transformer

Official PyTorch implementation of the paper Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks.

Requirements

  • Python 3
  • torch >= 1.0
  • matplotlib
  • scipy
  • tqdm
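
The dependencies can be installed with pip, for example:

pip install torch matplotlib scipy tqdm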

Abstract

Many machine learning tasks such as multiple instance learning, 3D shape recognition, and few-shot image classification are defined on sets of instances. Since solutions to such problems do not depend on the order of elements of the set, models used to address them should be permutation invariant. We present an attention-based neural network module, the Set Transformer, specifically designed to model interactions among elements in the input set. The model consists of an encoder and a decoder, both of which rely on attention mechanisms. In an effort to reduce computational complexity, we introduce an attention scheme inspired by inducing point methods from sparse Gaussian process literature. It reduces the computation time of self-attention from quadratic to linear in the number of elements in the set. We show that our model is theoretically attractive and we evaluate it on a range of tasks, demonstrating the state-of-the-art performance compared to recent methods for set-structured data.
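
To make the inducing-point idea concrete, here is a minimal sketch of an ISAB-style block built on PyTorch's own nn.MultiheadAttention. This is an illustration only, not the repository's modules.py implementation, and it assumes a recent PyTorch with batch_first support:

import torch
import torch.nn as nn

class ISABSketch(nn.Module):
    """Attend m learned inducing points to the n inputs, then attend the inputs
    back to the result, so the cost is O(n*m) rather than O(n^2)."""
    def __init__(self, dim, num_heads, num_inds):
        super().__init__()
        self.inducing = nn.Parameter(torch.randn(1, num_inds, dim))
        self.attn1 = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.attn2 = nn.MultiheadAttention(dim, num_heads, batch_first=True)

    def forward(self, x):                       # x: (batch, n, dim)
        ind = self.inducing.expand(x.size(0), -1, -1)
        h, _ = self.attn1(ind, x, x)            # (batch, m, dim); no n-by-n matrix
        out, _ = self.attn2(x, h, h)            # (batch, n, dim)
        return out

y = ISABSketch(dim=64, num_heads=4, num_inds=16)(torch.randn(2, 1000, 64))
print(y.shape)  # torch.Size([2, 1000, 64])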

Experiments

This repository implements the maximum value regression (section 5.1), amortized clustering (section 5.3), and point cloud classification (section 5.5) experiments in the paper.

Maximum Value Regression

This experiment is reproduced in max_regression_demo.ipynb.

Amortized Clustering

To run the amortized clustering experiment with Set Transformer, run

python run.py --net=set_transformer

To run the same experiment with Deep Sets, run

python run.py --net=deepset

Point Cloud Classification

We used the same preprocessed ModelNet40 dataset used in the DeepSets paper. We cannot publicly share this file due to copyright and license issues. To run this code, you must obtain the preprocessed dataset "ModelNet40_cloud.h5". We recommend using multiple GPUs for this experiment; we used 8 Tesla P40s.

To run the point cloud classification experiment, run

python main_pointcloud.py --batch_size 256 --num_pts 100
python main_pointcloud.py --batch_size 256 --num_pts 1000
python main_pointcloud.py --batch_size 256 --num_pts 5000

The hyperparameters here were only minimally tuned, yet they reproduced the results in the paper. Further tuning would likely yield better results.

Reference

If you found the provided code useful, please consider citing our work.

@InProceedings{lee2019set,
    title={Set Transformer: A Framework for Attention-based Permutation-Invariant Neural Networks},
    author={Lee, Juho and Lee, Yoonho and Kim, Jungtaek and Kosiorek, Adam and Choi, Seungjin and Teh, Yee Whye},
    booktitle={Proceedings of the 36th International Conference on Machine Learning},
    pages={3744--3753},
    year={2019}
}
Comments
  • Question about model's input

    Hi juho-lee, I have many sets, each of which has a different size. I want to feed a mini-batch of sets to the Set Transformer model, but I find that every set in a mini-batch must have the same size. Have you ever faced this problem? How did you deal with it? Padding or other methods? Thank you!
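
    One common workaround (a sketch only, not something this repository implements) is to zero-pad every set in the mini-batch to the largest set size and keep a boolean mask marking the real elements:

    import torch

    def pad_sets(sets):
        """Zero-pad a list of (n_i, dim) tensors to (batch, max_n, dim) plus a mask."""
        max_n = max(s.size(0) for s in sets)
        batch = torch.zeros(len(sets), max_n, sets[0].size(1))
        mask = torch.zeros(len(sets), max_n, dtype=torch.bool)
        for i, s in enumerate(sets):
            batch[i, :s.size(0)] = s
            mask[i, :s.size(0)] = True
        return batch, mask

    Note that the MAB in this repository does not take such a mask, so the padded rows would still participate in the attention softmax unless the module is modified or batches are grouped by set size.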

    opened by Qiu-dot 12
  • question about the network architecture for set transformer

    Hi @yoonholee,

    Thanks a lot for adding the code for the point cloud part. Looking into the network, it appears that SAB modules are not included in the decoder part. Is that because appending SAB modules to enhance the expressiveness of the representations would increase the time complexity? It seems that classification accuracy would improve by doing so. Have you performed any related experiments?

    THX!

    opened by amiltonwong 3
  • Why does LayerNorm default to False?

    Not an issue, but a question: why does the LayerNorm option default to False? In particular, the point cloud example does not use LayerNorm.

    Can you comment on the importance of having the nested LayerNorm activated for the model? That is, the paper offers no exposition on having LayerNorm activated versus not.

    Thanks!

    opened by mathDR 2
  • PMA implementation missing rFF?

    Dear Juho,

    First of all, thank you for the implementation! It has been very helpful to my understanding of the architecture.

    I ran into an apparent discrepancy between the code and the paper, and I was wondering if you could help clear it up. In particular, it seems to me that the PMA implementation is missing the row-wise feed-forward layer that is mentioned in the paper:

    PMA(S, Z) = MAB(S, rFF(Z))
    

    The PMA code:

    class PMA(nn.Module):
        def __init__(self, dim, num_heads, num_seeds, ln=False):
            super(PMA, self).__init__()
            self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
            nn.init.xavier_uniform_(self.S)
            self.mab = MAB(dim, dim, dim, num_heads, ln=ln)
    
        def forward(self, X):
            return self.mab(self.S.repeat(X.size(0), 1, 1), X)
    

    To me this reads PMA(S, X) = MAB(S, X), rather than the MAB(S, rFF(X)) of the paper.
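
    For comparison, a minimal sketch of the paper's formulation, assuming the MAB class from this repository's modules.py and using a single Linear + ReLU as the row-wise feed-forward layer (the exact form of rFF is an assumption):

    import torch
    import torch.nn as nn
    from modules import MAB  # assumes this repository's modules.py is importable

    class PMAWithRFF(nn.Module):
        """Sketch of PMA(S, Z) = MAB(S, rFF(Z)): apply a row-wise feed-forward
        layer to Z before pooling with multihead attention."""
        def __init__(self, dim, num_heads, num_seeds, ln=False):
            super(PMAWithRFF, self).__init__()
            self.S = nn.Parameter(torch.Tensor(1, num_seeds, dim))
            nn.init.xavier_uniform_(self.S)
            self.rff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU())  # assumed form of rFF
            self.mab = MAB(dim, dim, dim, num_heads, ln=ln)

        def forward(self, X):
            return self.mab(self.S.repeat(X.size(0), 1, 1), self.rff(X))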

    Thanks!

    Tim

    opened by Timsey 2
  • Question about Deep Sets Implementation

    Hi @juho-lee,

    First of all, thanks for making this code publicly available. It's very useful.

    One question, though. I am looking at your implementation of the Zaheer et al. network ("Deep Sets"). In that paper, we have something like rho(sum(phi(x))), where the sum is over the elements of the set (I believe you call this a set pooling method in your paper).

    In your DeepSet class, we have a succession of Linear -> ReLU -> Linear -> ReLU layers that operate on the entire set, which is then pooled at the end.

    Could you explain a little about why these are equivalent?
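
    For what it's worth, the equivalence follows from the fact that nn.Linear applied to a (batch, n, dim) tensor acts independently on each set element. A minimal sketch with hypothetical layer sizes:

    import torch
    import torch.nn as nn

    # nn.Linear on a (batch, n, dim) tensor acts on each element separately, so the
    # Linear/ReLU stack plays the role of phi; the sum over dim=1 is the pooling;
    # whatever comes after the pooling plays the role of rho.
    phi = nn.Sequential(nn.Linear(3, 128), nn.ReLU(), nn.Linear(128, 128), nn.ReLU())
    rho = nn.Linear(128, 10)

    x = torch.randn(4, 100, 3)        # 4 sets, each with 100 three-dimensional elements
    out = rho(phi(x).sum(dim=1))      # (4, 10), invariant to permutations along dim=1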

    opened by arnavs 2
  • LayerNorm

    Dear Juho, thanks for making the code public! One quick question: if I read the code correctly, LayerNorm is never used in any of the three examples open-sourced in this repo, is that correct? If so, is it because it gives slightly inferior performance? And have you tried moving the LayerNorm layer inside the skip connections instead of before/after them, as done in several more recent papers, so that there is a connection directly from output to input? Thanks in advance, and looking forward to your reply!

    opened by jingweiz 1
  • License

    Hello,

    Just read your paper and was very happy to see that you've made this implementation available. Would you be willing to add a license to this repo (MIT, for instance), so that others can build on this code?

    opened by cfoster0 1
  • Question on dim_split in MAB

    Hello,

    Could you please explain why dim_split is needed in MAB? For example, if I have a batch of shape 2x387x768, I see that the A tensor has shape 24x387x387 because it uses Q_ instead of Q.
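
    As a sketch of what the split does (shapes taken from the example above, with an assumed num_heads of 12): the head dimension is folded into the batch dimension so that a single bmm computes attention for every head at once.

    import torch

    # Sketch only: Q_ is used on both sides just to show the shapes.
    batch, n, dim, heads = 2, 387, 768, 12
    Q = torch.randn(batch, n, dim)
    Q_ = torch.cat(Q.split(dim // heads, dim=2), dim=0)   # (batch*heads, n, dim/heads) = (24, 387, 64)
    A = torch.softmax(Q_.bmm(Q_.transpose(1, 2)) / dim ** 0.5, dim=2)
    print(A.shape)                                         # torch.Size([24, 387, 387])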

    Would appreciate your response!

    Thank you! Sharmi 

    opened by BSharmi 0
  • Question about ISAB

    Not sure if I understand the Induced Set Attention Block correctly.

    So basically SAB is a transformer without positional encoding (and dropout?). In the paper, you said that SAB is "too expensive for large sets", but the set size here corresponds to the maximum sequence length in a transformer, which is usually 512. Why not just use SAB for the Set Transformer? Is there any reason other than efficiency to use ISAB?

    opened by zhiqihuang 0
  • Inputs of the SetTransformer

    Hi,

    Could you please explain the meanings of the inputs of SetTransformer:

    dim_input, num_outputs, dim_output, num_inds=32, dim_hidden=128, num_heads=4, ln=False
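
    As an illustrative (hypothetical) instantiation, assuming the SetTransformer class in this repository's models.py and ModelNet40-style point clouds:

    import torch
    from models import SetTransformer  # assumes this repository's models.py

    # Each element is a 3-d point (dim_input=3), the decoder pools to a single output
    # vector per set (num_outputs=1), and that vector holds one logit per class
    # (dim_output=40 for ModelNet40).
    net = SetTransformer(dim_input=3, num_outputs=1, dim_output=40,
                         num_inds=32, dim_hidden=128, num_heads=4, ln=False)
    logits = net(torch.randn(8, 100, 3))   # 8 point clouds of 100 points each
    print(logits.shape)                    # expected: torch.Size([8, 1, 40])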

    Thanks.

    opened by SRL94 1
  • A little puzzle about the implementation details.

    Hi juho-lee! I have two small puzzles about your paper. In Section 1 (Introduction), you say: "A model for set-input problems should satisfy two critical requirements. First, it should be permutation invariant: the output of the model should not change under any permutation of the elements in the input set. Second, such a model should be able to process input sets of any size." But after reading the whole paper, I still don't see how you tackle these two requirements. For the first, I guess you remove the positional embedding from the original Transformer? As for the second, I have no idea how you achieved it. Thank you!

    opened by Jack-mi 2
  • MAB Implementation diverges from Paper

    Dear Juho,

    is it possible that the implementation of the MAB diverges from the paper?

    In more detail: The paper states

    Multihead(Q, K, V; λ, ω) = concat(O_1, ..., O_h) W^O
    H = LayerNorm(X + Multihead(X, Y, Y; ω))
    MAB(X, Y) = LayerNorm(H + rFF(H))
    

    but the code does

    A = torch.softmax(Q_.bmm(K_.transpose(1,2))/math.sqrt(self.dim_V), 2)
    O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)  # This is output of multihead
    O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
    O = O + F.relu(self.fc_o(O))
    O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
    
    • It seems that the matrix W_O is not being used in the code at all to mix the output of the different heads?

    • The skip connection Q_ + A.bmm(V_) also diverges from what is stated in the paper, given that Q_ is derived from Q, which gets linearly transformed via Q = self.fc_q(Q) in the first line of forward() and is therefore no longer equal to the original query. (On second thought, this may be necessary, since the output of the MAB has a different shape than the input. In that case, the paper is imprecise.)
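
    For reference, a sketch of the paper's formulation written with nn.MultiheadAttention, whose internal output projection plays the role of W_O. This assumes dim_Q = dim_V and a recent PyTorch with batch_first support; it is not the repository's code:

    import torch
    import torch.nn as nn

    class PaperStyleMAB(nn.Module):
        """Sketch of MAB(X, Y) as written in the paper; nn.MultiheadAttention's
        out_proj supplies the W_O that mixes the heads."""
        def __init__(self, dim, num_heads):
            super().__init__()
            self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
            self.rff = nn.Sequential(nn.Linear(dim, dim), nn.ReLU(), nn.Linear(dim, dim))
            self.ln0 = nn.LayerNorm(dim)
            self.ln1 = nn.LayerNorm(dim)

        def forward(self, X, Y):
            H = self.ln0(X + self.attn(X, Y, Y)[0])   # LayerNorm(X + Multihead(X, Y, Y))
            return self.ln1(H + self.rff(H))          # LayerNorm(H + rFF(H))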

    Thanks a lot and best wishes Jannik

    opened by jlko 5
  • 4-D equivalent?

    What if I have a set of matrices instead of a set of vectors? Is it possible to extend the Set Transformer framework to cover that scenario?

    I played around with it a little (including making some small tweaks) but got bogged down with the .bmm call in the MAB module:

    RuntimeError: Expected 3-dimensional tensor, but got 4-dimensional tensor for argument #1 'batch1' (while checking arguments for bmm)
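
    One simple workaround (a sketch under the assumption that each matrix element can be treated as a flat feature vector, not something the repository supports directly) is to reshape the 4-D batch to 3-D before the attention modules:

    import torch

    x4d = torch.randn(8, 50, 16, 4)        # (batch, set size, rows, cols)
    x3d = x4d.flatten(start_dim=2)         # (8, 50, 64): one vector per matrix
    # x3d can now be fed to a model built from the 3-D modules, e.g. with dim_input = 16 * 4.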
    
    opened by zabzug-pfpt 3
Owner

Juho Lee