[CVPR'21] FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space

by Quande Liu, Cheng Chen, Jing Qin, Qi Dou, Pheng-Ann Heng.

Introduction

This repository is for our CVPR 2021 paper 'FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space'.

Usage

  1. Start with a demo of continuous frequency space interpolation among federated clients:
    python freq_space_interpolation_demo.py

  2. Prepare the dataset, then extract the amplitude spectrum of the samples in each local client with the function in dataset/prepare_dataset.py (a minimal sketch of this step follows the list):

  3. Organize the data (saved as .npy) and the amplitude spectra of the local clients in the following structure:

      ├── dataset
         ├── client1
            ├── data_npy
                ├── sample1.npy, sample2.npy, xxxx
            ├── freq_amp_npy
                ├── amp_sample1.npy, amp_sample2.npy, xxxx
         ├── clientxxx
         ├── clientxxx
    
  4. Train the federated learning model with ELCFS:

    python train_ELCFS.py
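
For reference, a minimal sketch of what the amplitude-spectrum extraction in step 2 typically computes (illustrative only; the actual extract_amp_spectrum in dataset/prepare_dataset.py may differ in axis layout and normalization):

    import numpy as np

    def extract_amp_spectrum_sketch(img_np):
        """Amplitude spectrum of an H x W x C image array (illustrative sketch)."""
        # 2D FFT over the spatial axes, one transform per channel
        fft = np.fft.fft2(img_np, axes=(0, 1))
        # the magnitude of the complex spectrum is the amplitude spectrum
        return np.abs(fft)

    # e.g. save the spectrum alongside the image data:
    # img = np.load('dataset/client1/data_npy/sample1.npy')
    # np.save('dataset/client1/freq_amp_npy/amp_sample1.npy', extract_amp_spectrum_sketch(img))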

Citation

If this repository is useful for your research, please consider citing:

@inproceedings{liu2021feddg,
  title={FedDG: Federated Domain Generalization on Medical Image Segmentation via Episodic Learning in Continuous Frequency Space},
  author={Liu, Quande and Chen, Cheng and Qin, Jing and Dou, Qi and Heng, Pheng-Ann},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
  year={2021}
}

Acknowledgement

Some of the code is adapted from SAML and FDA. The datasets used in this paper are downloaded from Prostate and Fundus.

Comments
  • Question about interpolated amplitude spectrum

    Question about interpolated amplitude spectrum

    Dear Quande, congratulations on your paper being accepted! I have a question about the paper. You mention that the interpolated amplitude spectrum is calculated by formula (2). From my perspective, formula (2) means adding the central region of the amplitude spectrum (AS) from the distribution bank, multiplied by λ, to the non-central region of the AS from the local image, multiplied by (1−λ), so the new AS's central region should contain only the contribution computed from the distribution bank. However, in "freq_space_interpolation_demo.py" line 37, the central region of the new AS is computed by adding the central region of the original AS from the local image to the central region of the AS from the distribution bank (I left out λ for brevity), which I believe does not match formula (2). Is this a mistake, or did I misunderstand it?
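
    For concreteness, here is a hedged sketch of the two computations being compared (the array names, λ handling, and mask construction are illustrative, not the repository's exact code):

        import numpy as np

        def central_mask(h, w, ratio=0.1):
            # binary mask selecting the low-frequency (central) region of a shifted spectrum
            b = int(min(h, w) * ratio)
            mask = np.zeros((h, w))
            mask[h // 2 - b:h // 2 + b, w // 2 - b:w // 2 + b] = 1.0
            return mask

        # amp_local: amplitude spectrum of the local image (after fftshift)
        # amp_bank:  amplitude spectrum drawn from the distribution bank
        # lam:       interpolation ratio in [0, 1]
        def central_interp_reading_of_eq2(amp_local, amp_bank, lam, mask):
            # one reading of formula (2): the central region is a convex combination
            return ((1 - lam) * amp_local + lam * amp_bank) * mask + amp_local * (1 - mask)

        def central_interp_as_reported_in_demo(amp_local, amp_bank, lam, mask):
            # what the comment reports seeing at line 37 of the demo:
            # the local central region is kept and the bank's contribution is added on top
            return (amp_local + lam * amp_bank) * mask + amp_local * (1 - mask)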

    opened by zzzqzhou 4
  • Request about Release the pre-processed Dataset

    Request about Release the pre-processed Dataset

    Dear Author, I ran into some problems at step 2, "Prepare the dataset": it is not clear what exactly goes into the .npy files (are they converted from the raw images, produced after some pre-processing, or something else?), and there is no reference. Could you therefore release the re-organized dataset that we could directly apply in training? I mean the "/Dataset" directory with the clientxxx subfolders and properly placed .npy files in it. The dataset for either task would help a lot. Thank you very much for your attention, and congratulations!

    opened by xxliang99 3
  • Question about “test_ELCFS.py”.

    Question about “test_ELCFS.py”.

    Dear author, in code file "test_ELCFS.py" line 63, you use "test_net.train()" in the test stage. However, one would typically call "model.eval()" during testing. Is there a reason for doing so?
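
    For context, the usual PyTorch convention the comment refers to (illustrative only):

        import torch.nn as nn

        model = nn.Sequential(nn.Conv2d(3, 8, 3), nn.BatchNorm2d(8), nn.Dropout(0.5))

        model.train()  # dropout active; BatchNorm uses per-batch statistics
        model.eval()   # dropout disabled; BatchNorm uses stored running statistics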

    opened by zzzqzhou 2
  • Questions about data augmentation in fundus dataset.

    Questions about data augmentation in fundus dataset.

    Hi, in your released code I cannot find the data augmentation operations in train_ELCFS.py that are described in your paper; only the ToTensor() operation is used, which seems odd.

    opened by BurningFr 1
  • It seems that some documents are absent?

    It seems that some documents are absent?

    Dear author, thanks for your great work and for releasing the project. While rebuilding FedDG, it seems that some files are missing:

    • https://github.com/liuquande/FedDG-ELCFS/blob/d59321cc72f09571a27777579d66b98c193631e0/train_ELCFS.py#L61
    • https://github.com/liuquande/FedDG-ELCFS/blob/d59321cc72f09571a27777579d66b98c193631e0/train_ELCFS.py#L153

    We sincerely hope to get your help. It would be great if you could answer our doubts or share related documents.

    opened by lichen14 1
  • The mask for some data in the prostate dataset is empty

    The mask for some data in the prostate dataset is empty

    Hello, some masks in the prostate dataset are empty, which leads to NaN values appearing at the beginning of training. How can this situation be handled?
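
    A common workaround (not from this repository) is to smooth the Dice computation so that an empty mask cannot cause a division by zero:

        import numpy as np

        def dice_score(pred, target, smooth=1e-5):
            """Dice coefficient that stays finite even when pred and target are both empty."""
            pred = pred.astype(np.float32)
            target = target.astype(np.float32)
            intersection = (pred * target).sum()
            return (2.0 * intersection + smooth) / (pred.sum() + target.sum() + smooth)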

    opened by yunpengt 1
  • Prostate dataset and dataloader

    Prostate dataset and dataloader

    Hi Quande,

    Thanks for sharing this exciting research work! I would like to learn more about the implementation of the prostate task. The main scripts (train_ELCFS.py, test_ELCFS.py, and freq_space_interpolation_demo.py) seem to be tailored to the Fundus dataset. Is there a dataloader script for the prostate dataset, or do the Fundus and prostate datasets share the same dataloader script and require the same preprocessing and folder structure? Another concern: would it be possible to reproduce the prostate experiment with this repository?

    Thanks very much for your time and attention!

    opened by franciszchen 1
  • Obtained Mismatching Reproducing Results

    Obtained Mismatching Reproducing Results

    Dear Quande,

    Thank you for your time. I modified the data-preparation step as you explained in the previous issue #8. However, I am sorry to report that I still could not obtain the approximate Dice scores reported in the paper.

    I tried several times with different data pre-processing methods. Below is a description of my modifications.

    1. At first, the only Python file I edited myself was prepare_dataset.py.

      • I resized the input images and labels (except domain 2) to 384×384, then converted the image data to numpy arrays. Their amplitude spectra were extracted using the provided extract_amp_spectrum function in prepare_dataset.py.

      • The label mask was split into 2 masks, Disc and Cup. Pixels with gray level 255 are set to 0 as background, and gray levels 128 and 0 are set to 1 as foreground in the Disc and Cup masks, respectively. After all the steps above, the final numpy file is organized as [image_np, disc_mask_np, cup_mask_np] with shape [384, 384, 5]; the image data keeps its original 3 channels, and the 2 masks are binary arrays.

      • For Domain 2, I initially cropped each image into 2 individual samples that share the same label mask; the label mask I used was cropped from the left half of the raw mask. The sample counts of the domains therefore become 101, 318, 400, and 400, respectively.

    To show that I modified prepare_dataset.py correctly, here are some examples of the Disc and Cup backgrounds and contours, extracted as in the original code (a sketch of my preparation pipeline follows the image).

    [image: examples of extracted Disc/Cup contours and backgrounds]
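
    For reference, a sketch of the preparation pipeline described above (the file paths, gray-level convention, and resize method are my assumptions, not the authors' released code):

        import numpy as np
        from PIL import Image

        def prepare_sample(img_path, mask_path, size=384):
            image = np.asarray(Image.open(img_path).resize((size, size)), dtype=np.float32)
            mask = np.asarray(Image.open(mask_path).resize((size, size), Image.NEAREST))
            disc = (mask <= 128).astype(np.float32)  # gray levels 0 and 128 -> disc foreground
            cup = (mask == 0).astype(np.float32)     # gray level 0 -> cup foreground
            # stack the image and the two binary masks into a [384, 384, 5] array
            return np.concatenate([image, disc[..., None], cup[..., None]], axis=-1)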

    2. Noticing that the paper explains the best performance is reached when the interpolation ratio is a random number in [0, 1], while in the provided code it is fixed to 1 (meaning all information is transferred), I changed it to a random number in the function low_freq_mutate_np of fundus_dataloader.py. Below is the code:

    Original:

    def low_freq_mutate_np(amp_src, amp_trg, L=0.1):
        .......
        #ratio = random.randint(1, 10) / 10
    
        a_src[:, h1:h2, w1:w2] = a_trg[:, h1:h2, w1:w2]
        #a_src[:, h1:h2, w1:w2] = a_src[:, h1:h2, w1:w2] * ratio + a_trg[:, h1:h2, w1:w2] * (1 - ratio)
        ......
    

    Modified:

    def low_freq_mutate_np(amp_src, amp_trg, L=0.1):
        ......
        ratio = random.randint(1, 10) / 10
    
        # a_src[:, h1:h2, w1:w2] = a_trg[:, h1:h2, w1:w2]
        a_src[:, h1:h2, w1:w2] = a_src[:, h1:h2, w1:w2] * ratio + a_trg[:, h1:h2, w1:w2] * (1 - ratio)
        ......
    
    3. With the results still unsatisfactory, I further modified fundus_dataloader.py to add the random rotation and flipping step. I noticed you mentioned in issue #6 that the data augmentation steps are performed offline; supposing that the locally kept data remains unchanged during training, I added these steps to __getitem__ in fundus_dataloader.py, reusing the provided RandomRotFlip() function from the same file. Below is the code I added.

    Original code:

    def __getitem__(self, idx):
            raw_file = self.image_list[idx]
    
            mask_patches = []
            raw_inp = np.load(raw_file)
            image_patch = raw_inp[..., 0:3] 
            mask_patch = raw_inp[..., 3:]
            image_patches = image_patch.copy()
          
            # image_patches = 
            # print (image_patch.dtype)
            # print (mask_patch.dtype)
            disc_contour, disc_bg, cup_contour, cup_bg = _get_coutour_sample(mask_patch)
            # print ('raw', np.min(image_patch), np.max(image_patch))
            for tar_freq_domain in np.random.choice(self.freq_site_index, 2):
                   ......
    
    

    Modified code:

    def __init__(self, unseen_site_idx, client_idx=None, freq_site_idx=None, split='train', transform=None):
            self.unseen_site_idx = unseen_site_idx
            self.client_idx = client_idx    
            ......
    def __getitem__(self, idx):
            raw_file = self.image_list[idx]
    
            mask_patches = []
    
            raw_inp = np.load(raw_file)
            image_patch = raw_inp[..., 0:3]
            mask_patch = raw_inp[..., 3:]
    
            # image_patches = 
            # print (image_patch.dtype)
            # print (mask_patch.dtype)
    
            if self.client_idx != self.unseen_site_idx:
                sample = {"image": image_patch, "label": mask_patch}
                preprocessor = RandomRotFlip()
                sample = preprocessor(sample)
                mask_patch = sample["label"]
                image_patch = sample["image"]
    
            image_patches = image_patch.copy()
    
            disc_contour, disc_bg, cup_contour, cup_bg = _get_coutour_sample(mask_patch)
            # print ('raw', np.min(image_patch), np.max(image_patch))
            for tar_freq_domain in np.random.choice(self.freq_site_index, 2):
                     .......
    

    To show that all the modification steps are correct, the raw and transformed images are compared below. [image: raw vs. transformed samples]

    • The image above was taken under the scheme where the interpolation ratio is chosen at random from [0, 1].

    The Dice scores of the Disc after training for 100 epochs are reported in the following table; "1, 2, 3" refer to the modification steps above, along with their corresponding combinations.

    [image: table of Disc Dice scores for each combination of modification steps]

    I trained only once for each configuration, but the performance I reproduced is clearly much worse than that reported in the paper, which left me quite confused. As the data-preparation steps are absent from the provided code, I am sincerely asking for your help in checking my preparation steps.

    Another question: from the provided code, I did not observe any difference in how data from domain 2 is treated compared with the other domains. If 2 views are generated from a single image in domain 2, as explained in issue #8, the aggregation weights should depend on [101, 318, 400, 400], but in line 63 of train_ELCFS.py the numbers are still [101, 159, 400, 400] (a small sketch of how these counts enter the aggregation weights follows the code below). If the 2 views were instead stored in a single .npy file, the code below, from lines 37~45 of fundus_dataloader.py, would raise errors because of the change in the numpy array's size.

    def __getitem__(self, idx):
            raw_file = self.image_list[idx]
    
            mask_patches = []
    
            raw_inp = np.load(raw_file)
            image_patch = raw_inp[..., 0:3]
            mask_patch = raw_inp[..., 3:]
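
    For illustration, here is how the per-client sample counts would enter FedAvg-style aggregation weights (a hedged sketch using the two sets of counts discussed above):

        import numpy as np

        counts_two_views = np.array([101, 318, 400, 400])  # if domain 2 yields 2 views per image
        counts_in_code = np.array([101, 159, 400, 400])    # the numbers at line 63 of train_ELCFS.py

        # FedAvg-style aggregation weights are proportional to the local sample counts
        print(counts_two_views / counts_two_views.sum())
        print(counts_in_code / counts_in_code.sum())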
    

    Thank you for your time. I was greatly inspired by your contribution, and I sincerely hope that a satisfactory result can be reproduced. I apologize if I have made any mistake; please do not hesitate to point it out.

    Thank you very much!

    Best, Vivian

    opened by xxliang99 5
  • The train/val split of Prostate Datasets

    The train/val split of Prostate Datasets

    Hi, I cannot find the train/val split strategy or file list in your paper or in the released code of either FedDG or SAML. Could you please share the dataset split strategy or file list?

    Thanks a lot!

    opened by BurningFr 1
  • Question about data pre-processing "center crop"

    Question about data pre-processing "center crop"

    Dear Quande,

    I read the previous issues and understand that the data pre-processing was finished before the data were converted and saved as .npy files. I noticed that the paper states, "For pre-processing, we center crop a 800 × 800 disc region for these data uniformly, then resize the cropped region to 384×384 as network input". Does this mean that only the 800×800 region around the disc is saved as the input, or the central 800×800 region of the whole image? I only found the resize procedure in the provided prepare_dataset.py, so I am confused about the center crop.

    Furthermore, images from domain 2 of Fundus contain 2 ROIs, one placed on each side of the image. If the center crop is performed on each disc region, it would generate 2 ROIs per image; if it is performed on the whole image, the cropped region would not contain any ROI. Does this mean that 2 input samples are generated and saved from each single raw image in domain 2? (A sketch of the two interpretations follows.)
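
    For concreteness, a sketch of the two "center crop" interpretations in question (illustrative only, since the pre-processing code is not released):

        import numpy as np

        def crop_around_disc(image, disc_center, size=800):
            # interpretation 1: a size x size window centered on the disc
            cy, cx = disc_center
            half = size // 2
            return image[cy - half:cy + half, cx - half:cx + half]

        def crop_image_center(image, size=800):
            # interpretation 2: the central size x size window of the whole image
            h, w = image.shape[:2]
            return image[(h - size) // 2:(h + size) // 2, (w - size) // 2:(w + size) // 2]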

    I would be very grateful if any clarification on "center crop" could be made. @liuquande

    Thank you for your time! I was deeply inspired by your great contribution.

    opened by xxliang99 2