Sleep_Staging_Knowledge Distillation
This codebase implements a knowledge distillation approach for ECG-based sleep staging, assisted by an EEG-based sleep staging model. Knowledge distillation is applied in two ways: softmax distillation and attention-transfer-based feature training. Their combination constitutes the proposed model.
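For orientation, below is a minimal sketch of how these two distillation signals are commonly combined with the ordinary classification loss. The temperature, loss weights, and tensor shapes are illustrative assumptions, not the exact values used in this codebase.

```python
import torch
import torch.nn.functional as F

def attention_map(feat):
    # Collapse the channel dimension into a spatial attention map and L2-normalise it
    # (standard attention-transfer formulation; assumed, not taken verbatim from this repo).
    att = feat.pow(2).mean(dim=1)
    return F.normalize(att.flatten(1), dim=1)

def kd_loss(student_logits, teacher_logits, student_feat, teacher_feat,
            labels, T=4.0, alpha=0.5, beta=100.0):
    # CL: classification loss on the ground-truth sleep stages.
    cl = F.cross_entropy(student_logits, labels)
    # SD: softmax distillation, temperature-scaled KL between teacher and student outputs.
    sd = F.kl_div(F.log_softmax(student_logits / T, dim=1),
                  F.softmax(teacher_logits / T, dim=1),
                  reduction="batchmean") * T * T
    # AT: attention transfer between bottleneck feature maps.
    at = F.mse_loss(attention_map(student_feat), attention_map(teacher_feat))
    return cl + alpha * sd + beta * at
```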
The code is implemented with the PyTorch Lightning framework. Dependencies are listed in requirements.txt.
RESEARCH
DATASET
Montreal Archive of Sleep Studies (MASS) - the complete 200-subject dataset is used.
- SS1 and SS3 subsets follow the AASM guidelines
- SS2, SS4, and SS5 subsets follow the R&K guidelines
KNOWLEDGE DISTILLATION FRAMEWORK
The knowledge distillation framework uses U-Time, with minor modifications, as the base model.
Knowledge distillation improves the KD_model's bottleneck features over those of the ECG_Base model, as measured against the EEG_Base model's features. Two cases illustrate this:
Case 1 : KD_model predicts correctly while ECG_Base predicts incorrectly
Case 2 : KD_model predicts incorrectly while ECG_Base predicts correctly
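A rough sketch of what the feature-training stage might look like: the EEG teacher is frozen and only an attention-transfer loss on the bottleneck features drives the ECG student. Function and attribute names here are assumptions for illustration; the actual logic lives in models/FEAT_TRAINING.py.

```python
import torch
import torch.nn.functional as F

def attention_map(feat):
    # Channel-pooled, L2-normalised attention map of a bottleneck feature tensor.
    return F.normalize(feat.pow(2).mean(dim=1).flatten(1), dim=1)

def feature_training_step(ecg_student, eeg_teacher, ecg, eeg, optimizer):
    # The EEG teacher is frozen; only its bottleneck features serve as targets.
    eeg_teacher.eval()
    with torch.no_grad():
        _, teacher_feat = eeg_teacher(eeg)   # assumed to return (logits, bottleneck features)
    _, student_feat = ecg_student(ecg)
    loss = F.mse_loss(attention_map(student_feat), attention_map(teacher_feat))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```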
Run Training
Run train.py from the 3_class or 4_class directory.
To train the baseline models (ecg_base / eeg_base):
python train.py --model_type <"base model type"> --model_ckpt_name <"ckpt name">
To run knowledge distillation (AT: attention transfer, SD: softmax distillation, CL: classification loss, WCE: weighted cross-entropy):
- Feature Training
python train.py --model_type "feat_train" --model_ckpt_name <"ckpt name"> --eeg_baseline_path <"eeg base ckpt path">
- Feat_Temp (AT+SD+CL)
python train.py --model_type "Feat_Temp" --model_ckpt_name <"ckpt name"> --feat_path <"path to feature trained ckpt">
- Feat_WCE (AT+CL)
python train.py --model_type "feat_wce" --model_ckpt_name <"ckpt name"> --feat_path <"path to feature trained ckpt">
- KD-Temp (SD+CL)
python train.py --model_type "kd_temp" --model_ckpt_name <"ckpt name"> --eeg_baseline_path <"eeg base ckpt path">
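The two-stage variants (Feat_Temp and Feat_WCE) expect a feature-trained checkpoint from the Feature Training stage first. A hedged sketch of how the stages chain together in PyTorch Lightning is shown below; the class names, constructor arguments, and checkpoint paths are illustrative assumptions, and the actual wiring lives in train.py and utils/model_utils.py.

```python
import pytorch_lightning as pl
from models.FEAT_TRAINING import FeatTraining   # assumed class name
from models.FEAT_TEMP import FeatTemp           # assumed class name

# Stage 1: attention-transfer feature training against the frozen EEG baseline
# (dataloader / DataModule wiring omitted for brevity).
stage1 = FeatTraining(eeg_baseline_path="checkpoints/eeg_base.ckpt")  # assumed ctor args
pl.Trainer(max_epochs=50).fit(stage1)

# Stage 2: softmax distillation + classification loss, warm-started from stage 1.
stage2 = FeatTemp(feat_path="checkpoints/feat_train.ckpt")            # assumed ctor args
pl.Trainer(max_epochs=50).fit(stage2)
```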
Run Testing
Run test.py from the 3_class or 4_class directory.
To test from a checkpoint:
python test.py --model_type <"model type"> --test_ckpt <"path to checkpoint">
Additional arguments can be passed for training and testing as required.
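Under the hood, testing from a checkpoint typically amounts to the following Lightning pattern. The class name and paths are placeholders and the dataloader wiring is omitted; the actual logic is in test.py and the utils modules.

```python
import pytorch_lightning as pl
from models.KD_TEMP import KDTemp   # assumed class name; pick the model matching --model_type

# Restore the trained weights and run the test loop.
model = KDTemp.load_from_checkpoint("checkpoints/kd_temp.ckpt")  # placeholder path
trainer = pl.Trainer(accelerator="auto", devices=1)
trainer.test(model)
```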
Reproducing experiments
Checkpoints for reproducing the test results can be found at this link.
Directory Map
Dataset Splitting:
Splits the data into train-val-test sets for both the 4-class and 3-class cases (AASM and R&K); a minimal subject-wise splitting sketch follows the directory tree below.
├── Dataset_split
├── Data_split_3class_AllData30s_R_K.py
├── Data_split_3class_AllData_AASM.py
├── Data_split_AllData_30s_R_K.py
└── Data_split_All_Data_AASM.py
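As referenced above, here is a minimal sketch of the kind of subject-wise split these scripts perform. The split ratios, seed, and function name are assumptions; see the scripts themselves for the exact logic.

```python
import random

def split_subjects(subject_ids, val_frac=0.1, test_frac=0.2, seed=42):
    # Shuffle subjects once, then carve out val/test so no subject leaks across splits.
    ids = sorted(subject_ids)
    random.Random(seed).shuffle(ids)
    n_test = int(len(ids) * test_frac)
    n_val = int(len(ids) * val_frac)
    test = ids[:n_test]
    val = ids[n_test:n_test + n_val]
    train = ids[n_test + n_val:]
    return train, val, test
```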
3 Class Classification:
Run train.py with the necessary arguments to train 3-class sleep staging.
├── 3_class
│ ├── datasets
│ │ ├── __init__.py
│ │ └── mass.py
│ │
│ ├── models
│ │ ├── __init__.py
│ │ ├── ecg_base.py
│ │ ├── eeg_base.py
│ │ ├── FEAT_TEMP.py
│ │ ├── FEAT_TRAINING.py
│ │ ├── FEAT_WCE.py
│ │ └── KD_TEMP.py
│ │
│ ├── test.py
│ ├── train.py
│ └── utils
│ ├── __init__.py
│ ├── arg_utils.py
│ ├── callback_utils.py
│ ├── dataset_utils.py
│ └── model_utils.py
4 Class Classification:
Run train.py with the necessary arguments to train 4-class sleep staging.
├── 4_class
│ ├── datasets
│ │ ├── __init__.py
│ │ └── mass.py
│ │
│ ├── models
│ │ ├── __init__.py
│ │ ├── ecg_base.py
│ │ ├── eeg_base.py
│ │ ├── FEAT_TEMP.py
│ │ ├── FEAT_TRAINING.py
│ │ ├── FEAT_WCE.py
│ │ └── KD_TEMP.py
│ │
│ ├── test.py
│ ├── train.py
│ └── utils
│ ├── __init__.py
│ ├── arg_utils.py
│ ├── callback_utils.py
│ ├── dataset_utils.py
│ └── model_utils.py