deep-table

deep-table implements various state-of-the-art deep learning and self-supervised learning algorithms for tabular data using PyTorch.

Design

Architecture

Each pretraining/fine-tuning model is decomposed into two modules: Encoder and Head.

Encoder

The Encoder consists of an Embedding and a Backbone.

  • Embedding tokenizes continuous/categorical features or simply normalizes them.
  • Backbone processes the tokenized features.
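To make the difference between "tokenized" and "simply normalized" concrete, here is a minimal pure-Python sketch. It assumes an FT-Transformer-style tokenizer (each continuous value scales a learned per-feature vector and adds a bias; each categorical value looks up a row in a table) versus plain standardization. The function names and weights are hypothetical illustrations, not deep-table's FeatureEmbedding implementation.

```python
from typing import Dict, List


def normalize(values: List[float], mean: float, std: float) -> List[float]:
    """'Simply normalized': each continuous value stays a single scalar."""
    return [(v - mean) / std for v in values]


def tokenize_continuous(x: float, weight: List[float], bias: List[float]) -> List[float]:
    """'Tokenized': a scalar becomes a d-dimensional token x * w + b."""
    return [x * w + b for w, b in zip(weight, bias)]


def tokenize_categorical(category: str, table: Dict[str, List[float]]) -> List[float]:
    """A categorical value becomes a token by looking up its row in a table."""
    return table[category]


# Hypothetical parameters with token dimension d = 3
age_weight = [0.5, -1.0, 2.0]
age_bias = [0.0, 0.1, -0.1]
education_table = {"Bachelors": [1.0, 0.0, 0.0], "Masters": [0.0, 1.0, 0.0]}

print(normalize([30.0, 50.0], mean=40.0, std=10.0))  # [-1.0, 1.0]
print(tokenize_continuous(2.0, age_weight, age_bias))
print(tokenize_categorical("Masters", education_table))  # [0.0, 1.0, 0.0]
```

With tokenization, every feature (continuous or categorical) ends up as a vector of the same size, which is what lets a Transformer-style Backbone treat features as a sequence of tokens.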

Pretraining/Fine-tuning Head

The Pretraining/Fine-tuning Head uses the Encoder module for training.
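The Encoder/Head split can be pictured as plain composition: an Encoder chains an Embedding and a Backbone, and a head only depends on the Encoder's output. The sketch below is a hypothetical pure-Python illustration with toy computations; deep-table's actual modules are PyTorch modules with learned parameters.

```python
from typing import Callable, List

Vector = List[float]


class Encoder:
    """Encoder = Embedding followed by Backbone."""

    def __init__(self, embedding: Callable[[Vector], Vector], backbone: Callable[[Vector], Vector]):
        self.embedding = embedding
        self.backbone = backbone

    def __call__(self, features: Vector) -> Vector:
        return self.backbone(self.embedding(features))


class LinearHead:
    """A toy fine-tuning head: one weighted sum on top of the encoder output."""

    def __init__(self, encoder: Encoder, weights: Vector, bias: float = 0.0):
        self.encoder = encoder
        self.weights = weights
        self.bias = bias

    def __call__(self, features: Vector) -> float:
        hidden = self.encoder(features)
        return sum(w * h for w, h in zip(self.weights, hidden)) + self.bias


def toy_embedding(xs: Vector) -> Vector:
    # Stand-in for an Embedding: rescale every feature
    return [x / 10.0 for x in xs]


def toy_backbone(hs: Vector) -> Vector:
    # Stand-in for a Backbone: elementwise ReLU
    return [max(0.0, h) for h in hs]


encoder = Encoder(toy_embedding, toy_backbone)
head = LinearHead(encoder, weights=[1.0, 2.0])
print(head([10.0, -20.0]))  # 1.0
```

Because the head only talks to the encoder, the same encoder object can be trained with one head and then reused under another, which mirrors the pretraining then fine-tuning workflow described in Step 3.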

Implemented Methods

Available Modules

Encoder - Embedding

  • FeatureEmbedding
  • TabTransformerEmbedding

Encoder - Backbone

  • MLPBackbone
  • FTTransformerBackbone
  • SAINTBackbone

Model - Head

  • MLPHeadModel

Model - Pretraining

  • DenoisingPretrainModel
  • SAINTPretrainModel
  • TabTransformerPretrainModel
  • VIMEPretrainModel

How To Use

Step 0. Install

python setup.py install

# Installation with pip
pip install -e .
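After installing, a quick sanity check is to ask importlib whether the package can be found. This assumes the import name is deep_table, as used in the code examples in this README; is_installed is a hypothetical helper.

```python
import importlib.util


def is_installed(module_name: str) -> bool:
    """Return True if `module_name` can be imported in the current environment."""
    return importlib.util.find_spec(module_name) is not None


if __name__ == "__main__":
    print("deep_table installed:", is_installed("deep_table"))
```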

Step 1. Define config.json

You have to define at least three configs:

  1. encoder
  2. model
  3. trainer

The minimum configurations are as follows:

from omegaconf import OmegaConf

encoder_config = OmegaConf.create({
    "embedding": {
        "name": "FeatureEmbedding",
    },
    "backbone": {
        "name": "FTTransformerBackbone",
    }
})

model_config = OmegaConf.create({
    "name": "MLPHeadModel"
})

trainer_config = OmegaConf.create({
    "max_epochs": 1,
})

Other parameters can also be changed through config.json if you want.
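If the three configs live in a single config.json, one possible layout is a top-level object with encoder, model, and trainer keys. That layout and the split_configs/load_configs helpers are assumptions for illustration. The sketch reads the JSON with the standard json module and fails loudly if a required section is missing.

```python
import json
from typing import Any, Dict, Tuple

REQUIRED_SECTIONS = ("encoder", "model", "trainer")
Config = Dict[str, Any]


def split_configs(config: Config) -> Tuple[Config, Config, Config]:
    """Return the (encoder, model, trainer) sections, raising if one is missing."""
    missing = [key for key in REQUIRED_SECTIONS if key not in config]
    if missing:
        raise KeyError(f"config.json is missing sections: {missing}")
    return config["encoder"], config["model"], config["trainer"]


def load_configs(path: str) -> Tuple[Config, Config, Config]:
    """Read a JSON file from `path` and split it into the three sections."""
    with open(path) as f:
        return split_configs(json.load(f))


# The minimal settings from the OmegaConf snippet, written as JSON text
EXAMPLE_JSON = """
{
  "encoder": {
    "embedding": {"name": "FeatureEmbedding"},
    "backbone": {"name": "FTTransformerBackbone"}
  },
  "model": {"name": "MLPHeadModel"},
  "trainer": {"max_epochs": 1}
}
"""

encoder_dict, model_dict, trainer_dict = split_configs(json.loads(EXAMPLE_JSON))
print(encoder_dict["backbone"]["name"], model_dict["name"], trainer_dict["max_epochs"])
```

Each returned dict can then be wrapped with OmegaConf.create(...) in the same way as the literals in the minimal configuration.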

Step 2. Define Datamodule

from deep_table.data.data_module import TabularDatamodule


datamodule = TabularDatamodule(
    train=train_df,
    val=val_df,
    test=test_df,
    task="binary",
    dim_out=1,
    categorical_columns=["education", "occupation", ...],
    continuous_columns=["age", "hours-per-week", ...],
    target=["income"],
    num_categories=110,
)
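The value num_categories=110 is not explained here. One plausible reading, assumed rather than taken from the deep-table source, is the total number of distinct category values across all categorical columns (the size of one shared embedding table). The sketch counts that on plain Python rows; count_categories and its reserved_per_column argument (for extra per-column slots such as an "unknown" token, if preprocessing adds them) are hypothetical.

```python
from typing import Dict, Iterable, List, Set


def count_categories(
    rows: Iterable[Dict[str, str]],
    categorical_columns: List[str],
    reserved_per_column: int = 0,
) -> int:
    """Sum the number of distinct values in each categorical column.

    reserved_per_column is a hypothetical knob that adds extra slots per
    column, e.g. an "unknown" token, in case the preprocessing reserves one.
    """
    seen: Dict[str, Set[str]] = {col: set() for col in categorical_columns}
    for row in rows:
        for col in categorical_columns:
            seen[col].add(row[col])
    return sum(len(values) + reserved_per_column for values in seen.values())


rows = [
    {"education": "Bachelors", "occupation": "Sales", "age": "39"},
    {"education": "HS-grad", "occupation": "Sales", "age": "50"},
    {"education": "Masters", "occupation": "Tech-support", "age": "28"},
]
print(count_categories(rows, ["education", "occupation"]))  # 5
print(count_categories(rows, ["education", "occupation"], reserved_per_column=1))  # 7
```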

Step 3. Run Training

from deep_table.estimators.base import Estimator
from deep_table.utils import get_scores


estimator = Estimator(
    encoder_config,      # Encoder architecture
    model_config,        # model settings (learning rate, scheduler...)
    trainer_config,      # training settings (epoch, gpu...)
)

estimator.fit(datamodule)
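predict = estimator.predict(datamodule.dataloader(split="test"))
target = test_df["income"].values  # ground-truth labels of the test split (assumes test_df from Step 2)
get_scores(predict, target, task="binary")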
>>> {'accuracy': array([0.8553...]),
     'AUC': array([0.9111...]),
     'F1 score': array([0.9077...]),
     'cross_entropy': array([0.3093...])}
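For reference, two of the binary metrics above can be computed by hand. This minimal sketch assumes the predictions are probabilities of the positive class and uses a 0.5 threshold for accuracy; get_scores may compute them differently, and AUC and F1 are omitted here.

```python
import math
from typing import List


def binary_accuracy(probs: List[float], labels: List[int], threshold: float = 0.5) -> float:
    """Fraction of samples whose thresholded prediction matches the label."""
    correct = sum(int((p >= threshold) == bool(y)) for p, y in zip(probs, labels))
    return correct / len(labels)


def binary_cross_entropy(probs: List[float], labels: List[int], eps: float = 1e-12) -> float:
    """Mean negative log-likelihood of the labels; eps guards against log(0)."""
    total = 0.0
    for p, y in zip(probs, labels):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(labels)


probs = [0.9, 0.2, 0.6, 0.4]
labels = [1, 0, 0, 1]
print(binary_accuracy(probs, labels))  # 0.5
print(round(binary_cross_entropy(probs, labels), 4))
```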

If you want to train a model with pretraining, write as follows:

from omegaconf import OmegaConf

from deep_table.estimators.base import Estimator
from deep_table.utils import get_scores


pretrain_model_config = OmegaConf.create({
    "name": "SAINTPretrainModel"
})

pretrain_model = Estimator(encoder_config, pretrain_model_config, trainer_config)
pretrain_model.fit(datamodule)

estimator = Estimator(encoder_config, model_config, trainer_config)
estimator.fit(datamodule, from_pretrained=pretrain_model)

See notebooks/train_adult.ipynb for more details.
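Conceptually, fine-tuning from a pretrained model reuses the Encoder weights learned during pretraining while the new head starts fresh. The sketch shows that idea on plain dictionaries keyed like a PyTorch state_dict. The "encoder." key prefix and the transfer_encoder_weights helper are assumptions; this is not necessarily how Estimator.fit(from_pretrained=...) is implemented.

```python
from typing import Dict, List

StateDict = Dict[str, List[float]]


def transfer_encoder_weights(
    pretrained: StateDict, target: StateDict, prefix: str = "encoder."
) -> StateDict:
    """Return a copy of `target` whose encoder entries come from `pretrained`.

    Only keys that start with `prefix` and exist in both dicts are copied,
    so the target's head parameters are left untouched.
    """
    updated = dict(target)
    for key, value in pretrained.items():
        if key.startswith(prefix) and key in target:
            updated[key] = list(value)
    return updated


pretrained_state = {
    "encoder.embedding.weight": [0.3, -0.7],
    "encoder.backbone.weight": [1.5],
    "pretrain_head.weight": [9.9],  # pretraining-only head, not transferred
}
finetune_state = {
    "encoder.embedding.weight": [0.0, 0.0],
    "encoder.backbone.weight": [0.0],
    "head.weight": [0.1],
}
merged = transfer_encoder_weights(pretrained_state, finetune_state)
print(merged)
```

With real PyTorch modules, the same filtering is typically applied to model.state_dict() and loaded with load_state_dict(..., strict=False).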

Custom Datasets

You can use your own datasets.

  1. Prepare datasets and create a DataFrame
  2. Preprocess the DataFrame
  3. Create your own datamodule using TabularDatamodule

Example code is shown below.

import os
import sys

import pandas as pd

# Make the repository root importable when running from a subdirectory such as notebooks/
sys.path.append(os.path.abspath(".."))
from deep_table.data.data_module import TabularDatamodule
from deep_table.preprocess import CategoryPreprocessor


# 1. Prepare datasets and create a DataFrame
iris = pd.read_csv("https://raw.githubusercontent.com/mwaskom/seaborn-data/master/iris.csv")
# The CSV is ordered by species, so shuffle before slicing it into splits
iris = iris.sample(frac=1, random_state=0).reset_index(drop=True)

# 2. Preprocess the DataFrame (encode the categorical target column "species")
category_preprocessor = CategoryPreprocessor(categorical_columns=["species"], use_unk=False)
iris = category_preprocessor.fit_transform(iris)

# 3. Create a TabularDatamodule
datamodule = TabularDatamodule(
    train=iris.iloc[:20],
    val=iris.iloc[20:40],
    test=iris.iloc[40:],
    task="multiclass",
    dim_out=3,
    categorical_columns=[],
    continuous_columns=["sepal_length", "sepal_width", "petal_length", "petal_width"],
    target=["species"],
    num_categories=0,
)

See notebooks/custom_dataset.ipynb for the full training example.
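CategoryPreprocessor's behavior is not documented here. As an illustration only, a category encoder with a use_unk option might map each known value to an integer code and, when use_unk is enabled, reserve a code for values unseen during fitting. SimpleCategoryEncoder is a hypothetical pure-Python stand-in, not deep-table's implementation.

```python
from typing import Dict, List, Optional


class SimpleCategoryEncoder:
    """Hypothetical category encoder with an optional 'unknown' code."""

    def __init__(self, use_unk: bool = False) -> None:
        self.use_unk = use_unk
        self.mapping: Dict[str, int] = {}

    def fit(self, values: List[str]) -> "SimpleCategoryEncoder":
        # Code 0 is reserved for unknown values when use_unk=True
        offset = 1 if self.use_unk else 0
        self.mapping = {value: i + offset for i, value in enumerate(sorted(set(values)))}
        return self

    def transform(self, values: List[str]) -> List[int]:
        codes: List[int] = []
        for value in values:
            code: Optional[int] = self.mapping.get(value)
            if code is None:
                if not self.use_unk:
                    raise KeyError(f"unseen category: {value!r}")
                code = 0
            codes.append(code)
        return codes

    def fit_transform(self, values: List[str]) -> List[int]:
        return self.fit(values).transform(values)


species = ["setosa", "versicolor", "virginica", "setosa"]
print(SimpleCategoryEncoder(use_unk=False).fit_transform(species))  # [0, 1, 2, 0]

unk_encoder = SimpleCategoryEncoder(use_unk=True).fit(species)
print(unk_encoder.transform(["virginica", "unknown-species"]))  # [3, 0]
```

For a target column like species, every label is known in advance, which is one reason use_unk=False is a reasonable choice there.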

Custom Models

You can also use your own Embedding, Backbone, and Model by passing them as arguments, as shown below.

estimator = Estimator(
    encoder_config, model_config, trainer_config,
    custom_embedding=YourEmbedding, custom_backbone=YourBackbone, custom_model=YourModel
)

If custom modules are set, the name attribute in the corresponding config is overwritten.

See notebooks/custom_model.ipynb for more details.
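The override rule can be pictured as a name-based lookup that a custom class short-circuits. The registry, resolve_module function, and placeholder classes below are hypothetical and only illustrate the precedence described above; they are not Estimator internals.

```python
from typing import Dict, Optional, Type


class FeatureEmbedding:  # placeholder for a built-in module
    pass


class YourEmbedding:  # placeholder for a user-defined module
    pass


EMBEDDINGS: Dict[str, Type] = {"FeatureEmbedding": FeatureEmbedding}


def resolve_module(
    config: Dict[str, str], registry: Dict[str, Type], custom: Optional[Type] = None
) -> Type:
    """Return the custom class if one is given, otherwise look up config["name"]."""
    if custom is not None:
        config["name"] = custom.__name__  # the config's name is overwritten
        return custom
    return registry[config["name"]]


embedding_config = {"name": "FeatureEmbedding"}
print(resolve_module(embedding_config, EMBEDDINGS).__name__)  # FeatureEmbedding
print(resolve_module(embedding_config, EMBEDDINGS, custom=YourEmbedding).__name__)  # YourEmbedding
print(embedding_config["name"])  # YourEmbedding
```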
