PyTorch implementation of Federated Learning with Non-IID Data and of federated learning algorithms, including FedAvg and FedProx.

Overview

Federated Learning with Non-IID Data

This is an implementation of the following paper:

Yue Zhao, Meng Li, Liangzhen Lai, Naveen Suda, Damon Civin, Vikas Chandra. Federated Learning with Non-IID Data. arXiv:1806.00582.

Paper

TL;DR: Previous federated optimization algorithms (such as FedAvg and FedProx) converge to stationary points of a mismatched objective function because of heterogeneity in the data distribution. In this paper, the authors propose a data-sharing strategy that improves training on non-IID data by creating a small subset of data that is globally shared among all edge devices.

Abstract: Federated learning enables resource-constrained edge compute devices, such as mobile phones and IoT devices, to learn a shared model for prediction, while keeping the training data local. This decentralized approach to train models provides privacy, security, regulatory and economic benefits. In this work, we focus on the statistical challenge of federated learning when local data is non-IID. We first show that the accuracy of federated learning reduces significantly, by up to ~55% for neural networks trained for highly skewed non-IID data, where each client device trains only on a single class of data. We further show that this accuracy reduction can be explained by the weight divergence, which can be quantified by the earth mover’s distance (EMD) between the distribution over classes on each device and the population distribution. As a solution, we propose a strategy to improve training on non-IID data by creating a small subset of data which is globally shared between all the edge devices. Experiments show that accuracy can be increased by ~30% for the CIFAR-10 dataset with only 5% globally shared data.
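The effect described in the abstract can be illustrated with a small, framework-free sketch. It assumes that the distance between a client's class distribution and the population distribution is computed as the sum of absolute differences between class probabilities; the helper names (`class_distribution`, `emd`) and the toy label lists are hypothetical and are not taken from this repository.

```python
from collections import Counter


def class_distribution(labels, num_classes):
    """Return the empirical class-probability vector of a list of labels."""
    counts = Counter(labels)
    total = len(labels)
    return [counts.get(c, 0) / total for c in range(num_classes)]


def emd(p, q):
    """Sum of absolute differences between two class distributions
    (the form of the distance assumed in this sketch)."""
    return sum(abs(a - b) for a, b in zip(p, q))


num_classes = 10

# Hypothetical population: 100 samples per class, i.e. a uniform distribution.
population = [c for c in range(num_classes) for _ in range(100)]
p_global = class_distribution(population, num_classes)

# Highly skewed client: it holds samples of a single class only.
client = [0] * 100
emd_skewed = emd(class_distribution(client, num_classes), p_global)

# Data sharing: add a small class-balanced shared subset to the client
# (50 samples, which is 5% of the hypothetical population).
shared = [c for c in range(num_classes) for _ in range(5)]
emd_shared = emd(class_distribution(client + shared, num_classes), p_global)

print(f"distance without shared data: {emd_skewed:.3f}")
print(f"distance with shared data:    {emd_shared:.3f}")
```

In this toy setting the shared subset pulls the client's class distribution toward the population distribution, so the distance shrinks. This is the intuition behind the data-sharing strategy, not a reproduction of the paper's experiments.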

Requirements

The implementation runs on:

  • Python 3.8
  • PyTorch 1.6.0
  • CUDA 10.1
  • cuDNN 7.6.5

Federated Learning Algorithms

Currently, this repository supports the following federated learning algorithms:

  • FedAvg
  • FedProx
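As a rough, framework-agnostic illustration of how FedAvg and FedProx differ, the sketch below shows the two ingredients in plain Python: the sample-weighted averaging of client parameters on the server (FedAvg), and the extra gradient contributed by the FedProx proximal term (mu / 2) * ||w - w_global||^2 on the client. The function names, the dict-of-lists weight format, and the toy numbers are assumptions made for this example; the repository itself operates on PyTorch models.

```python
def fedavg(client_weights, client_sizes):
    """Average client parameters, weighting each client by its sample count."""
    total = sum(client_sizes)
    averaged = {}
    for key in client_weights[0]:
        length = len(client_weights[0][key])
        averaged[key] = [
            sum(w[key][i] * n for w, n in zip(client_weights, client_sizes)) / total
            for i in range(length)
        ]
    return averaged


def fedprox_gradient(grad, local_w, global_w, mu):
    """Add the gradient of (mu / 2) * ||w - w_global||^2, which is
    mu * (w - w_global), to an existing local gradient."""
    return [g + mu * (w - wg) for g, w, wg in zip(grad, local_w, global_w)]


# Two hypothetical clients, each holding one parameter vector named "fc".
clients = [{"fc": [1.0, 2.0]}, {"fc": [3.0, 6.0]}]
sizes = [100, 300]

# Server step: the client with more samples has more influence on the average.
global_weights = fedavg(clients, sizes)

# Client step under FedProx: with a zero task gradient, the proximal term
# alone pulls the local weights back toward the global weights.
prox_grad = fedprox_gradient(
    [0.0, 0.0], clients[0]["fc"], global_weights["fc"], mu=0.1
)

print(global_weights)
print(prox_grad)
```

With mu set to 0, `fedprox_gradient` returns the local gradient unchanged, so the client update reduces to plain FedAvg local training.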

Launch Experiments

An example launch script is shown below.

python main.py \
    --all_clients \
    --fed fedavg \
    --gpu 0 \
    --seed 1 \
    --sampling noniid \
    --sys_homo \
    --num_channels 3 \
    --dataset cifar

Explanations of arguments:

  • fed: federated optimization algorithm (for example fedavg or fedprox)
  • mu: coefficient of the proximal term used by FedProx
  • sampling: data sampling method (for example noniid)
  • alpha: random portion of the global dataset
  • dataset: name of the dataset
  • rounds: total number of communication rounds
  • sys_homo: assume no system heterogeneity
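For readers who want to see how such flags might be wired up, here is a minimal `argparse` sketch. The flag names come from the launch script and the list above; every type, default value, and help string is an assumption made for illustration and may differ from the parser actually used by `main.py`.

```python
import argparse


def build_parser():
    """Build a parser mirroring the documented flags (types and defaults are assumed)."""
    parser = argparse.ArgumentParser(description="Federated learning launcher (sketch)")
    parser.add_argument("--fed", type=str, default="fedavg",
                        help="federated optimization algorithm")
    parser.add_argument("--mu", type=float, default=0.01,
                        help="proximal-term coefficient for fedprox (assumed default)")
    parser.add_argument("--sampling", type=str, default="noniid",
                        help="sampling method")
    parser.add_argument("--alpha", type=float, default=0.05,
                        help="random portion of global dataset (assumed default)")
    parser.add_argument("--dataset", type=str, default="cifar",
                        help="name of dataset")
    parser.add_argument("--rounds", type=int, default=50,
                        help="total number of communication rounds (assumed default)")
    parser.add_argument("--sys_homo", action="store_true",
                        help="no system heterogeneity")
    parser.add_argument("--all_clients", action="store_true")
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--seed", type=int, default=1)
    parser.add_argument("--num_channels", type=int, default=3)
    return parser


# Parse an explicit argument list instead of sys.argv so the sketch is self-contained.
args = build_parser().parse_args(
    ["--fed", "fedprox", "--mu", "0.1", "--sampling", "noniid",
     "--sys_homo", "--dataset", "cifar"]
)
print(args.fed, args.mu, args.sys_homo, args.all_clients)
```

Boolean switches such as `sys_homo` are modelled here with `action="store_true"`, which matches how they appear without a value in the launch script.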

Acknowledgements

This implementation refers to http://doi.org/10.5281/zenodo.4321561


Comments
  • fix update.py

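    net.zero_grad()
    log_probs = net(images)
    loss = self.loss_func(log_probs, labels)

    # FedProx: https://arxiv.org/abs/1812.06127
    # Add the proximal term (mu / 2) * ||w - w_t||^2 to the loss.
    if self.args.fed == 'fedprox':
        if iter > 0:
            for w, w_t in zip(local_net.parameters(), net.parameters()):
                # Keep w_t attached to the graph so the term contributes to the
                # gradient of net; only the reference weights w are detached.
                loss += self.args.mu / 2. * torch.pow(torch.norm(w.detach() - w_t), 2)
                ### w_t.grad.data += self.args.mu * (w_t.data - w.data)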
    

    net.zero_grad() keeps the gradients of net at zero, so this line of code makes no sense:

    w_t.grad.data += self.args.mu * (w_t.data - w.data)
    
    opened by 13015517713 0
  • Can AlexNet be applied to the CIFAR-10 dataset by using the code you published? Thank you!

    I have learned a lot from reading your paper and code; thank you for sharing. I noticed an AlexNet network in the net.py file, but after I changed the code slightly to use AlexNet, I ran into some problems: for example, the test accuracy was worse than with the CNN, or even very low. Can AlexNet be applied to the CIFAR-10 dataset using the code you published? Thank you!

    opened by realcly 1
Owner
Youngjoon Lee
AI Research Scientist