A new mini-batch framework for optimal transport in deep generative models, deep domain adaptation, approximate Bayesian computation, color transfer, and gradient flow.

Overview

BoMb-OT

Python3 implementation of the papers On Transportation of Mini-batches: A Hierarchical Approach and Improving Mini-batch Optimal Transport via Partial Transportation.

Please CITE our papers whenever this repository is used to help produce published results or incorporated into other software.

@article{nguyen2021transportation,
      title={On Transportation of Mini-batches: A Hierarchical Approach}, 
      author={Khai Nguyen and Dang Nguyen and Quoc Nguyen and Tung Pham and Hung Bui and Dinh Phung and Trung Le and Nhat Ho},
      journal={arXiv preprint arXiv:2102.05912},
      year={2021},
}
@article{nguyen2021improving,
      title={Improving Mini-batch Optimal Transport via Partial Transportation}, 
      author={Khai Nguyen and Dang Nguyen and Tung Pham and Nhat Ho},
      journal={arXiv preprint arXiv:2108.09645},
      year={2021},
}

This implementation is made by Khai Nguyen and Dang Nguyen. README is on updating process.

Requirement

  • python 3.6
  • pytorch 1.7.1
  • torchvision
  • numpy
  • tqdm
  • geomloss
  • POT
  • matplotlib
  • cvxpy

What is included?

The scalable implementation of the batch of mini-batches scheme and the conventional averaging scheme of mini-batch transportation types: optimal transport (OT), partial optimal transport (POT), unbalanced optimal transport (UOT), sliced optimal transport for:

  • Deep Generative Models
  • Deep Domain Adaptation
  • Approximate Bayesian Computation
  • Color Transfer
  • Gradient Flow

Deep Adaptation on digits datasets (DeepDA/digits)

Code organization

cfg.py : this file contains arguments for training.

methods.py : this file implements the training process of the deep DA.

models.py : this file contains the architecture of the genertor and the classifier.

train_digits.py: running file for deep DA.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of mini-batch deep DA method (jdot, jumbot, jpmbot)

--source_ds : source dataset

--target_ds : target dataset

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--eta1 : weight of embedding loss

--eta2 : weight of transportation loss

--k : number of mini-batches

--mbsize : mini-batch size

--n_epochs : number of running epochs

--test_interval : interval of two continuous test phase

--lr : initial learning rate

--data_dir : path to dataset

--reg : OT regularization coefficient for Sinkhorn algorithm

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Change the number of mini-batches $k$

bash sh/exp_mOT_change_k.sh
bash sh/exp_BoMbOT_change_k.sh

Change the mini-batch size $m$

bash sh/exp_mOT_change_m.sh
bash sh/exp_BoMbOT_change_m.sh

Deep Adaptation on Office-Home and VisDA datasets (DeepDA/office)

Code organization

data_list.py : this file contains functions to create dataset.

evaluate.py : this file is used to evaluate model trained on VisDA dataset.

lr_schedule.py : this file implements the learning rate scheduler.

network.py : this file contains the architecture of the genertor and the classifier.

pre_process.py : this file implements preprocessing techniques.

train.py : this file implements the training process for both datasets.

Terminologies

--net : architecture type of the generator

--dset : name of the dataset

--test_interval : interval of two continuous test phase

--s_dset_path : path to source dataset

--stratify_source : use stratify sampling

--s_dset_path : path to target dataset

--batch_size : training batch size

--stop_step : number of iterations

--ot_type : type of OT loss (balanced, unbalanced, partial)

--eta1 : weight of embedding loss ($\alpha$ in equation 10)

--eta2 : weight of transportation loss ($\lambda_t$ in equation 10)

--epsilon : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on Office-Home

bash sh/train_home.sh

Train on VisDA

bash sh/train_visda.sh

Deep Generative model (DeepGM)

Code organization

Celeba_generator.py, Cifar_generator.py : these files contain the architecture of the generator on CelebA and CIFAR10 datasets, and include some self-function to compute losses of corresponding baselines.

experiments.py : this file contains some functions for generating images.

fid_score.py: this file is used to compute the FID score.

gen_images.py: read saved models to produce 10000 images to calculate FID.

inception.py: this file contains the architecture of Inception Net V3.

main_celeba.py, main_cifar.py : running files on the corresponding datasets.

utils.py : this file contains implementation of utility functions.

Terminologies

--method : type of OT loss (OT, UOT, POT, sliced)

--reg : OT regularization coefficient for Sinkhorn algorithm

--tau : marginal penalization coefficient in UOT

--mass : fraction of masses in POT

--k : number of mini-batches

--m : mini-batch size

--epochs : number of epochs at k = 1. The actual running epochs is calculated by multiplying this value by the value of k.

--lr : initial learning rate

--latent-size : latent size of the generator

--datadir : path to dataset

--L : number of projections when using slicing approach

--bomb : Using Batch of Mini-batches

--ebomb : Using entropic Batch of Mini-batches

--breg : OT regularization coefficient for entropic Batch of Mini-batches

Train on CIFAR10

CUDA_VISIBLE_DEVICES=0 python main_cifar.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 100 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Train on CELEBA

CUDA_VISIBLE_DEVICES=0 python main_celeba.py --method POT --reg 0 --tau 1 \
    --mass 0.7 --k 2 --m 200 --epochs 100 --lr 5e-4 --latent-size 32 --datadir ./data

Gradient Flow (GradientFlow)

python main.py

Color Transfer (Color Transfer)

python main.py  --m=100 --T=10000 --source images/s1.bmp --target images/t1.bmp --cluster

Terminologies

--k : number of mini-batches

--m : the size of mini-batches

--T : the number of steps

--cluster: K mean clustering to compress images

--palette: show color palette

--source: Path to the source image

Acknowledgment

The structure of DeepDA is largely based on JUMBOT and ALDA. The structure of ABC is largely based on SlicedABC. We are very grateful for their open sources.

You might also like...
This is a batch script created to WEB-DL.
This is a batch script created to WEB-DL.

widevine-L3-WEB-DL-Script This is a batch script created to WEB-DL. Works well with .mpd files , for m3u8 please use n_m3u8 program (not included in t

Batch Python Program Verify

Batch Python Program Verify About As a TA(teaching assistant) of Programming Class, it is very annoying to test students' homework assignments one by

A tool to determine optimal projects for Gridcoin crunchers. Maximize your magnitude!

FindTheMag FindTheMag helps optimize your BOINC client for Gridcoin mining. You can group BOINC projects into two groups: "preferred" projects and "mi

Batch generate asset browser previews
Batch generate asset browser previews

When dealing with hundreds of library files it becomes tedious to mark their contents as assets. Using python to automate the process is a perfect fit

A python package that computes an optimal motion plan for approaching a red light

redlight_approach redlight_approach is a Python package that computes an optimal motion plan during traffic light approach. RLA_demo.mov Given the par

easy_sbatch - Batch submitting Slurm jobs with script templates

easy_sbatch - Batch submitting Slurm jobs with script templates

Supply Chain will be a SAAS platfom to provide e-logistic facilites with most optimal

Shipp It Welcome To Supply Chain App [ Shipp It ] In "Shipp It" we are creating a full solution[web+app] for a entire supply chain from receiving orde

Batch obfuscator based on the obfuscation method used by the trick bot launcher

Batch obfuscator based on the obfuscation method used by the trick bot launcher

A python package for batch import of resume attachments to be parsed in HrFlow.

HrFlow Importer Description A python package for batch import of resume attachments to be parsed in HrFlow. hrflow-importer is an open-source project

Owner
Khai Ba Nguyen
I am currently an AI Resident at VinAI Research, Vietnam.
Khai Ba Nguyen
Convert three types of color in your clipboard and paste it to the color property (gamma correct)

ColorPaster [Blender Addon] Convert three types of color in your clipboard and paste it to the color property (gamma correct) How to Use Hover your mo

null 13 Oct 31, 2022
TB Set color display - Add-on for Blender to set multiple objects and material Display Color at once.

TB_Set_color_display Add-on for Blender with operations to transfer name between object, data, materials and action names Set groups of object's or ma

null 1 Jun 1, 2022
Code for Crowd counting via unsupervised cross-domain feature adaptation.

CDFA-pytorch Code for Unsupervised crowd counting via cross-domain feature adaptation. Pre-trained models Google Drive Baidu Cloud : t4qc Environment

Guanchen Ding 6 Dec 11, 2022
Python bindings for `ign-msgs` and `ign-transport`

Python Ignition This project aims to provide Python bindings for ignition-msgs and ignition-transport. It is a work in progress... C++ and Python libr

Rhys Mainwaring 3 Nov 8, 2022
Antchain-MPC is a library of MPC (Multi-Parties Computation)

Antchain-MPC Antchain-MPC is a library of MPC (Multi-Parties Computation). It include Morse-STF: A tool for machine learning using MPC. Others: Commin

Alipay 37 Nov 22, 2022
A script that will warn you, by opening a new browser tab, when there are new content in your favourite websites.

web check A script that will warn you, by opening a new browser tab, when there are new content in your favourite websites. What it does The script wi

Jaime Álvarez 52 Mar 15, 2022
Mini-calculadora escrita como exemplo para uma palestra relâmpago sobre `git bisect`

Calculadora Mini-calculadora criada para uma palestra relâmpado sobre git bisect. Tem até uma colinha! Exemplo de uso Modo interativo $ python -m calc

Eduardo Cuducos 3 Dec 14, 2021
My collection of mini-projects in various languages

Mini-Projects My collection of mini-projects in various languages About: This repository consists of a number of small projects. Most of these "mini-p

Siddhant Attavar 1 Jul 11, 2022
creates a batch file that uses adb to auto-install apks into the Windows Subsystem for Android and registers it as the default application to open apks.

wsa-apktool creates a batch file that uses adb to auto-install apks into the Windows Subsystem for Android and registers it as the default application

Aditya Vikram 3 Apr 5, 2022
We want to check several batch of web URLs (1~100 K) and find the phishing website/URL among them.

We want to check several batch of web URLs (1~100 K) and find the phishing website/URL among them. This module is designed to do the URL/web attestation by using the API from NUS-Phishperida-Project.

null 3 Dec 28, 2022