Official PyTorch implementation of the paper "SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training"

Overview

This repository is the official PyTorch implementation of SAINT. The paper is available on arXiv:

SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training

Requirements

We recommend using Anaconda or Miniconda to manage Python. Our code has been tested with python=3.8 on Linux.

To create a new environment with conda

conda create -n saint_env python=3.8
conda activate saint_env

We recommend installing the latest versions of the pytorch, torchvision, einops, pandas, wget, and scikit-learn packages.

You can install them using

conda install pytorch torchvision -c pytorch
conda install -c conda-forge einops 
conda install -c conda-forge pandas 
conda install -c conda-forge python-wget 
conda install -c anaconda scikit-learn 

Make sure the following requirements are met

  • torch>=1.8.1
  • torchvision>=0.9.1

Optional

We used wandb for logging, but it is optional.

conda install -c conda-forge wandb 

Training & Evaluation

In each of our experiments, we use a single Nvidia GeForce RTX 2080Ti GPU.

First download the processed datasets from this link into the folder ./data

To train the model(s) in the paper, run this command:

python train.py  --dataset <dataset_name> --attentiontype <attention_type> 

Pretraining is useful when only a few training samples are available. A sample command looks like this:

python train.py  --dataset <dataset_name> --attentiontype <attention_type> --pretrain --pt_tasks <pretraining_task_touse> --pt_aug <augmentations_on_data_touse> --ssl_avail_y <Number_of_labeled_samples>

To train on all 16 datasets, run the provided bash scripts: train.sh for supervised learning, and train_pt.sh for pretraining and semi-supervised learning.

bash train.sh
bash train_pt.sh
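
The contents of train.sh are not reproduced here. As a rough illustration of what looping over the datasets involves, the sketch below builds one train.py command per dataset name listed under Arguments. The helper name build_train_command and the default attention type 'colrow' are assumptions made for this example; only the --dataset and --attentiontype flags come from this README.

```python
import shlex

# Dataset names as listed under "Arguments" in this README.
DATASETS = [
    "1995_income", "bank_marketing", "qsar_bio", "online_shoppers",
    "blastchar", "htru2", "shrutime", "spambase", "philippine", "mnist",
    "arcene", "volkert", "creditcard", "arrhythmia", "forest", "kdd99",
]


def build_train_command(dataset, attention_type="colrow"):
    """Return the argv list for one supervised run (hypothetical helper)."""
    return ["python", "train.py",
            "--dataset", dataset,
            "--attentiontype", attention_type]


# One command per dataset; printed instead of executed so the sketch has no side effects.
commands = [build_train_command(name) for name in DATASETS]
for cmd in commands:
    print(shlex.join(cmd))
```

To actually launch the runs, each argv list could be passed to subprocess.run from the repository root.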

Arguments

  • --dataset : Dataset name. We support only the 16 datasets discussed in the paper. Supported datasets are ['1995_income','bank_marketing','qsar_bio','online_shoppers','blastchar','htru2','shrutime','spambase','philippine','mnist','arcene','volkert','creditcard','arrhythmia','forest','kdd99']
  • --embedding_size : Size of the feature embeddings
  • --transformer_depth : Depth of the model. Number of stages.
  • --attention_heads : Number of attention heads in each Attention layer.
  • --cont_embeddings : Style of embedding continuous data.
  • --attentiontype : Variant of SAINT. 'col' refers to SAINT-s variant, 'row' is SAINT-i, and 'colrow' refers to SAINT.
  • --pretrain : To enable pretraining
  • --pt_tasks : Losses we want to use for pretraining. Multiple arguments can be passed.
  • --pt_aug : Types of data augmentations used in pretraining. Multiple arguments are allowed. We support only mixup and CutMix right now.
  • --ssl_avail_y : Number of labeled samples used in semi-supervised experiments. The default is 0, which means all samples are labeled (the fully supervised case).
  • --pt_projhead_style : Projection head style used in contrastive pipeline.
  • --nce_temp : Temperature used in contrastive loss function.
  • --active_log : Enables logging to wandb. This is optional.
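
To make the flag semantics above concrete, here is a minimal argparse sketch covering a subset of the listed arguments. It is not the repository's actual parser: the function name build_parser, the argument types, the default for --attentiontype, and the spelling of the augmentation values ('mixup', 'cutmix') are assumptions. Only the flag names, the three documented attention variants, and the default of 0 for --ssl_avail_y are taken from this README.

```python
import argparse

# Mapping from --attentiontype values to the variant names used in this README.
VARIANT_NAMES = {"col": "SAINT-s", "row": "SAINT-i", "colrow": "SAINT"}


def build_parser():
    """Build an illustrative parser for a subset of the documented flags."""
    parser = argparse.ArgumentParser(description="Illustrative SAINT flags")
    parser.add_argument("--dataset", required=True, type=str)
    # Default value is an assumption; the README only documents the choices.
    parser.add_argument("--attentiontype", default="colrow",
                        choices=sorted(VARIANT_NAMES))
    parser.add_argument("--pretrain", action="store_true")
    # "Multiple arguments can be passed" maps naturally to nargs="*".
    parser.add_argument("--pt_tasks", nargs="*", default=[])
    parser.add_argument("--pt_aug", nargs="*", default=[])
    # 0 means every sample is labeled, i.e. the supervised case.
    parser.add_argument("--ssl_avail_y", type=int, default=0)
    parser.add_argument("--active_log", action="store_true")
    return parser


# Example: a pretraining run with 50 labeled samples (values are illustrative).
args = build_parser().parse_args([
    "--dataset", "bank_marketing", "--attentiontype", "row",
    "--pretrain", "--pt_aug", "mixup", "cutmix", "--ssl_avail_y", "50",
])
print(VARIANT_NAMES[args.attentiontype], args.pt_aug, args.ssl_avail_y)
```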

Evaluation

We select the best model by evaluating on the validation set. After training completes, the AUROC (for binary classification datasets) or accuracy (for multiclass classification datasets) of the best model on the test set is printed. If wandb is enabled, these values are logged to the 'test_auroc_bestep' and 'test_accuracy_bestep' variables.
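
The selection rule described here can be sketched in a few lines: keep the epoch with the highest validation metric and report the test metric recorded at that same epoch. The function name select_best_epoch, the record layout, and the numbers are made up for illustration and are not taken from the repository or the paper.

```python
def select_best_epoch(history):
    """Return the record with the highest validation metric.

    `history` is a list of dicts with keys 'epoch', 'valid_metric' and
    'test_metric'. Ties are resolved in favour of the earliest epoch.
    """
    if not history:
        raise ValueError("history must contain at least one epoch")
    best = history[0]
    for record in history[1:]:
        if record["valid_metric"] > best["valid_metric"]:
            best = record
    return best


# Toy values: epoch 2 has the best test score, but epoch 1 wins on validation,
# so epoch 1's test score is the one that gets reported.
history = [
    {"epoch": 0, "valid_metric": 0.81, "test_metric": 0.80},
    {"epoch": 1, "valid_metric": 0.86, "test_metric": 0.84},
    {"epoch": 2, "valid_metric": 0.85, "test_metric": 0.87},
]
best = select_best_epoch(history)
print(best["epoch"], best["test_metric"])
```

Selecting on validation rather than test performance keeps the reported test number from being tuned on the test set itself.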

Acknowledgements

We would like to thank the following public repo from which we borrowed various utilities.

License

This repository is released under the Apache 2.0 license as found in the LICENSE file.

Cite us

@article{somepalli2021saint,
  title={SAINT: Improved Neural Networks for Tabular Data via Row Attention and Contrastive Pre-Training},
  author={Somepalli, Gowthami and Goldblum, Micah and Schwarzschild, Avi and Bruss, C Bayan and Goldstein, Tom},
  journal={arXiv preprint arXiv:2106.01342},
  year={2021}
}

Comments
  • Regression Task

    Thanks for awesome work.

    I'm using a tabular dataset for a regression task. I would like to predict the last column (float values) in the picture below.

    image

    I'm not sure how I should set up the network, especially these two parameters:

    categories = tuple(cat_dims),
    num_continuous = len(con_idxs) 
    

    For now I'm using

    con_idxs = [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21]
    

    If I change the last column values to int using train[target] = train[target].astype(int) and use the following as cat dims, it starts training, but I want to predict floating-point values.

    cat_dims = np.append(np.array(cat_dims),np.array([50])).astype(int)
    

    If I don't convert the target to int, it throws the following error:

     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
    RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.DoubleTensor instead (while checking arguments for embedding)
    
    opened by asadabbas09 4
  • Benchmark results difference

    Hi, I was trying to reproduce the benchmark (xgboost and lightgbm) results, but I can't get the same numbers shown in your paper.

    I used this to split the dataset into train, valid and test: https://github.com/somepago/saint/blob/e0ee763b0c23136ae03fa7cff71ff3b2ce4ba647/data.py#L190

    I used early stopping on the validation set, collected the test performance as the final result, and reran the experiment with 5 different seeds (0, 1, ..., 5), as you do for the SAINT model.

    I used standard parameters for xgboost and lightgbm with some regularization.

    I used the dataset you provide in the following link: https://drive.google.com/file/d/1mJtWP9mRP0a10d1rT6b3ksYkp4XOpM0r/view?usp=sharing

    The results i get are:

    | Model\Dataset | Bank  | Blastchar | Arrhythmia | Arcene | Forest | Shoppers | Income | Volkert |
    | ------------- | ----- | --------- | ---------- | ------ | ------ | -------- | ------ | ------- |
    | lightgbm      | 93.46 | 83.71     | 93.18      | 85.25  | 99.79  | 93.23    | 92.03  | 71.46   |
    | xgboost       | 93.41 | 83.67     | 93.13      | 87.66  | 99.71  | 92.62    | 92.36  | 70.32   |

    My experiments show a clear improvement over the benchmark results, as shown below:

    | Model\Dataset | Bank   | Blastchar | Arrhythmia | Arcene | Forest | Shoppers | Income | Volkert |
    | ------------- | ------ | --------- | ---------- | ------ | ------ | -------- | ------ | ------- |
    | lightgbm      | +0.069 | +0.54     | +4.45      | +4.2   | +6.5   | +0.03    | -0.54  | +3.55   |
    | xgboost       | +0.45  | +1.89     | +11.14     | +6.25  | +4.38  | +0.11    | 0.05   | +1.37   |

    Can you share the code used to calculate benchmark results?

    I also used fairly standard parameters to train xgboost and lightgbm:

    xgboost:
    - max_depth=8
    - learning_rate=0.01
    - tree_method='hist'
    - subsample=0.75
    - colsample_bytree=0.75
    - reg_alpha=0.5
    - reg_lambda=0.5

    lightgbm:
    - learning_rate=0.01
    - max_depth=-1
    - num_leaves=2**8
    - lambda_l1=0.5
    - lambda_l2=0.5
    - feature_fraction=0.75
    - bagging_fraction=0.75
    - bagging_freq=1

    I think I can improve these results further by tuning these parameters more.

    opened by DavideStenner 3
  • Handling missing values

    Hello, thanks for the code and the awesome work.

    I was impressed by your paper and have a question about the code.

    In the paper (p. 14, Data preprocessing), it says: "Each feature (or columns) has a different missing value token to account for missing data."

    However, I found that the code just fills missing values with the average value for continuous features.

    I wonder whether the missing-value token embedding is used only for categorical data.

    It was very exciting to read the paper and I hope to apply the algorithm to my dataset soon!

    opened by deepneuralnetworks 2
  • reps = model_saint.transformer(x_categ_enc, x_cont_enc)==nan

    I applied this great model to regression, but the value is nan in the model.transformer part.

    class RowColTransformer(nn.Module):
    ~~~~~~~~~
        def forward(self, x, x_cont=None, mask = None):
            if x_cont is not None:
                x = torch.cat((x,x_cont),dim=1)
            _, n, _ = x.shape
            print("TRANFOERMR")
            if self.style == 'colrow':
                for attn1, ff1, attn2, ff2 in self.layers: 
                    x = attn1(x)##here x==nan
    

    Did this happen during implementation? If anyone has used it for their own data, please let me know.

    these are hyper params

    model_saint = SAINT(
        categories = tuple(cat_dims.values()),#len(cat_dims)==2
        num_continuous = len(numerical_features)+1,         
        dim =128,                           
        dim_out = 1,                       
        depth = 6,                       
        heads = 8,                         
        attn_dropout = 0.1,             
        ff_dropout = 0.1,                  
        mlp_hidden_mults = (4, 2),       
        continuous_mean_std = None, 
        cont_embeddings = "MLP",
        attentiontype = 'col',
        final_mlp_style = 'sep',
        y_dim = 1
        )
    
    optim:AdamW(model_saint.parameters(), lr=1e-3,weight_decay=5e-5)
    BATCH_size=256
    
    opened by abebe9849 2
  • Question: only continuous variables (no category)

    Is it possible to use SAINT for tabular data that contains only continuous variables, with no categorical ones?

    We need to pass two parameters to the SAINT model: x_categ and x_cont. Do I need to pass some torch.empty tensor as x_categ? What should be passed as the "categories" parameter to the SAINT model? An empty tuple?

    opened by AaronBlare 1
  • Problem about dataset id of Openml

    Thank you for sharing your great work! When I tried to evaluate your results on all the datasets listed in your paper (e.g., Bank, Blastchar, Arrhythmia, ...), I ran into a problem with data_openml.py. The dataset ids listed in the file (1487, 44, ...) do not match the datasets listed in the paper (bank, blastchar, ...). For example, with id 1487, calling openml.datasets.get_dataset(1487) returns the ozone-level-8hr dataset. Could you give me some suggestions on how to evaluate your results on the datasets listed in the paper?

    Many thanks!

    opened by LeCongThuong 1
  • issue about data

    Hi, I found this in your paper: “Results are averaged over 5 trials and 14 binary classification datasets.” However, your code lists only 'binary' [1487,44,1590,42178,1111,31,42733,1494,1017,4134]. Could you provide the other datasets?

    opened by zhenye234 1
  • Including target data while pretraining

    Hello, thanks for the awesome work and the code.

    I've applied your code to some datasets and have some questions.

    During pretraining, line 323 in data.py concatenates the categorical features and the target. The concatenated categorical data passes through the embedding layer, and the result is used as input to the transformer. I couldn't find any code that separates the target from the concatenated data before it is passed into the transformer.

    Is it okay to include target data while pretraining the model?

    Thank you.

    opened by deepneuralnetworks 1
  • Module is absent when importing

    Hello!

    There is an import on line 4 of the pretraining.py file:

    from baselines.data_openml import data_prep_openml,task_dset_ids,DataSetCatCon

    However, the repo has no baselines folder, so an error occurs when I try to use pretraining from the train file.

    Thanks!

    opened by wallykop 1
  • Attention plotting code

    Is there any progress on the release of the plotting code?

    Thanks in advance!


    Hi Fabien, I will release the attention plotting code in the next version. I am busy with another project rn, I am targeting the end of November for this. If you need it urgently let me know.

    Originally posted by @somepago in https://github.com/somepago/saint/issues/8#issuecomment-945202273

    opened by wallykop 0
  • Environment: packages not found

    Thank you so much for sharing this impressive work. I failed to create the environment. Is there anything I can do to fix this error? The error details are listed below:

    K:\library\saint>conda env create -f saint_environment.yml
    Collecting package metadata (repodata.json): done
    Solving environment: failed
    
    ResolvePackageNotFound:
      - gmp==6.2.1=h58526e2_0
      - certifi==2021.5.30=py38h578d9bd_0
      - lame==3.100=h7f98852_1001
      - promise==2.3=py38h578d9bd_3
      - jupyter_core==4.7.1=py38h578d9bd_0
      - libglib==2.68.3=h3e27bee_0
      - setuptools==49.6.0=py38h578d9bd_3
      - ffmpeg==4.3=hf484d3e_0
      - markupsafe==2.0.1=py38h497a2fe_0
      - libprotobuf==3.17.2=h780b84a_0
      - libgomp==9.3.0=h2828fa1_19
      - protobuf==3.17.2=py38h709712a_0
      - yaml==0.2.5=h516909a_0
      - gst-plugins-base==1.14.0=hbbd80ab_1
      - freetype==2.10.4=h0708190_1
      - pcre==8.45=h9c3ff4c_0
      - tornado==6.1=py38h497a2fe_1
      - _openmp_mutex==4.5=1_gnu
      - debugpy==1.3.0=py38h709712a_0
      - xgboost==1.4.0=py38h578d9bd_0
      - expat==2.4.1=h9c3ff4c_0
      - kiwisolver==1.3.1=py38h1fd1430_1
      - pyzmq==22.1.0=py38h2035c66_0
      - glib==2.68.3=h9c3ff4c_0
      - tk==8.6.10=h21135ba_1
      - pysocks==1.7.1=py38h578d9bd_3
      - websocket-client==0.57.0=py38h578d9bd_4
      - ipython==7.25.0=py38hd0cf306_1
      - numpy-base==1.20.2=py38hfae3a4d_0
      - libffi==3.3=h58526e2_2
      - nbconvert==6.1.0=py38h578d9bd_0
      - libuuid==2.32.1=h7f98852_1000
      - numpy==1.20.2=py38h2d18471_0
      - mkl_random==1.2.2=py38h1abd341_0
      - pthread-stubs==0.4=h36c2ea0_1001
      - libpng==1.6.37=h21135ba_2
      - mkl_fft==1.3.0=py38h42c9631_2
      - chardet==4.0.0=py38h578d9bd_1
      - readline==8.1=h46c0cb4_0
      - psutil==5.8.0=py38h497a2fe_1
      - shortuuid==1.0.1=py38h578d9bd_4
      - gstreamer==1.14.0=h28cd5cc_2
      - ld_impl_linux-64==2.35.1=hea4e1c9_2
      - libgcc-ng==9.3.0=h2828fa1_19
      - xorg-libxau==1.0.9=h7f98852_0
      - mistune==0.8.4=py38h497a2fe_1004
      - pytorch==1.8.1=py3.8_cuda11.1_cudnn8.0.5_0
      - libunistring==0.9.10=h14c3975_0
      - fontconfig==2.13.1=hba837de_1005
      - importlib-metadata==4.6.1=py38h578d9bd_0
      - glib-tools==2.68.3=h9c3ff4c_0
      - libuv==1.41.0=h7f98852_0
      - click==8.0.1=py38h578d9bd_0
      - xorg-libxdmcp==1.1.3=h7f98852_0
      - mkl-service==2.4.0=py38h497a2fe_0
      - watchdog==0.10.4=py38h578d9bd_0
      - pillow==8.2.0=py38he98fc37_0
      - py-xgboost==1.4.0=py38h578d9bd_0
      - qt==5.9.7=h5867ecd_1
      - libidn2==2.3.1=h7f98852_0
      - brotlipy==0.7.0=py38h497a2fe_1001
      - libwebp-base==1.2.0=h7f98852_2
      - cryptography==3.4.7=py38ha5dfef3_0
      - gettext==0.19.8.1=h0b5b191_1005
      - scikit-learn==0.23.2=py38h0573a6f_0
      - libxcb==1.13=h7f98852_1003
      - argon2-cffi==20.1.0=py38h497a2fe_2
      - sqlite==3.35.5=h74cdb3f_0
      - nettle==3.6=he412f7d_0
      - openssl==1.1.1k=h7f98852_0
      - matplotlib==3.4.2=py38h578d9bd_0
      - anyio==3.2.1=py38h578d9bd_0
      - jedi==0.18.0=py38h578d9bd_2
      - libxml2==2.9.12=h03d6c58_0
      - sniffio==1.2.0=py38h578d9bd_1
      - xz==5.2.5=h516909a_1
      - wget==1.20.1=h22169c7_0
      - mkl==2021.2.0=h06a4308_296
      - libiconv==1.16=h516909a_0
      - jpeg==9b=h024ee3a_2
      - ca-certificates==2021.5.30=ha878542_0
      - gnutls==3.6.13=h85f3911_1
      - matplotlib-base==3.4.2=py38hcc49a3a_0
      - libgfortran-ng==7.3.0=hdf63c60_0
      - lcms2==2.12=h3be6417_0
      - icu==58.2=hf484d3e_1000
      - libxgboost==1.4.0=h9c3ff4c_0
      - pandoc==2.14.0.3=h7f98852_0
      - libsodium==1.0.18=h36c2ea0_1
      - dbus==1.13.18=hb2f20db_0
      - pandas==1.2.4=py38h1abd341_0
      - pyyaml==5.4.1=py38h497a2fe_0
      - zstd==1.4.9=ha95c52a_0
      - cudatoolkit==11.1.1=h6406543_8
      - python==3.8.10=h49503c6_1_cpython
      - _libgcc_mutex==0.1=conda_forge
      - zeromq==4.3.4=h9c3ff4c_0
      - pyrsistent==0.17.3=py38h497a2fe_2
      - cffi==1.14.5=py38ha65f79e_0
      - openh264==2.1.1=h780b84a_0
      - libtiff==4.2.0=h85742a9_0
      - lz4-c==1.9.3=h9c3ff4c_0
      - scipy==1.6.2=py38had2a1c9_1
      - ipykernel==6.0.2=py38hd0cf306_0
      - ninja==1.10.2=h4bd325d_0
      - pyqt==5.9.2=py38h05f1152_4
      - intel-openmp==2021.2.0=h06a4308_610
      - sip==4.19.13=py38he6710b0_0
      - zlib==1.2.11=h516909a_1010
      - bzip2==1.0.8=h7f98852_4
      - ncurses==6.2=h58526e2_4
      - libstdcxx-ng==9.3.0=h6de172a_19
      - terminado==0.10.1=py38h578d9bd_0
    
    opened by IvoryLu 1
  • Plotting attention for explainability

    Hello Gowthami,

    Thank you for this project. It shows an uplift in performance over xgboost for my use case. It would be a great help to get the attention plotting code (both self-attention and inter-sample attention) for the SAINT implementation, as shown in the paper.

    opened by isamgul 0
  • fillna continuous data

    Hello, I'm a beginner interested in tabular learning. Your superb paper, SAINT, impressed me a lot, but I've had some problems understanding your code.

    For https://github.com/somepago/saint/blob/e288e84c77a54cfd2ffb55a53678fb7cbbb16630/old_version/data.py#L233 or https://github.com/somepago/saint/blob/e288e84c77a54cfd2ffb55a53678fb7cbbb16630/data_openml.py#L89:

    a) Why is it train.loc[train_indices, col] rather than train.loc[:, col]? Validation and test data may also contain NaN.

    b) Why is it train.fillna rather than train[col].fillna? It may fill NaN in other columns as well.

    I think the correct expression should be train[col].fillna(train.loc[:, col].mean(), inplace=True).

    I'm not sure whether I am correct. I would appreciate it if you can reply. Thank you very much!
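
The two points raised in this question can be compared on a toy frame. The sketch below is a hypothetical helper, not the repository's code: it computes the mean from the training rows only, so that no validation or test statistics enter preprocessing, and then fills NaN in every row of that one column without touching other columns. The function name impute_with_train_mean, the column names, and the values are assumptions made for illustration.

```python
import numpy as np
import pandas as pd


def impute_with_train_mean(df, cont_cols, train_indices):
    """Fill NaN in each continuous column with that column's training-row mean."""
    out = df.copy()  # leave the caller's frame untouched
    for col in cont_cols:
        # Series.mean() skips NaN by default, so missing training values are ignored.
        train_mean = out.loc[train_indices, col].mean()
        # Assign back per column so only this column is filled.
        out[col] = out[col].fillna(train_mean)
    return out


df = pd.DataFrame({
    "a": [1.0, 3.0, np.nan, np.nan],   # rows 2 and 3 stand in for valid/test rows
    "b": [10.0, np.nan, 30.0, 50.0],
})
train_indices = [0, 1]
filled = impute_with_train_mean(df, ["a", "b"], train_indices)
print(filled)
```

Here column a is filled with 2.0 (the mean of training rows 0 and 1) in the non-training rows as well, and column b is filled with 10.0, its own training mean, rather than a value derived from column a.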

    opened by Hahahah3 1
Owner
Gowthami Somepalli