# SUMO - Slim U-Net trained on MODA
Implementation of the SUMO (Slim U-Net trained on MODA) model as described in:
TODO: add reference to paper once available
## Installation Guide
On Linux with anaconda or miniconda installed, the project can be set up by running the following commands to clone the repository, create a new environment and install the required dependencies:
```bash
git clone https://github.com/dslaborg/sumo.git
cd sumo
conda env create --file environment.yaml
conda activate sumo
```
## Scripts - Quick Guide
### Running and evaluating an experiment
The main model training and evaluation procedures are implemented in `bin/train.py` and `bin/eval.py` using the PyTorch Lightning framework. A chosen configuration used to train the model is called an *experiment*, and the evaluation is carried out using a configuration and the result folder of a training run.
#### train.py
Trains the model as specified in the corresponding configuration file, writes its log to the console, and saves a log file, intermediate results for TensorBoard and model checkpoints to a result directory.
Arguments:

* `-e NAME, --experiment NAME`: name of the experiment to run, for which a `NAME.yaml` file has to exist in the `config` directory; default is `default`
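
As an example, the following hypothetical invocation (assuming the conda environment is activated and the command is run from the project root) trains a model with the configuration in `config/final.yaml`:

```bash
# train a model using the experiment configuration in config/final.yaml
python bin/train.py -e final
```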
#### eval.py
Evaluates a trained model, either on the validation or the test data, and reports the achieved metrics.
Arguments:

* `-e NAME, --experiment NAME`: name of the configuration file that should be used for evaluation, for which a `NAME.yaml` file has to exist in the `config` directory; usually equals the experiment used to train the model; default is `default`
* `-i PATH, --input PATH`: path containing the model that should be evaluated; the given input can either be a model checkpoint, which is then used directly, or the output directory of a `train.py` execution, in which case the best model is used from `PATH/models/`; if the configuration has cross validation enabled, the output directory is expected and the best model per fold is obtained from `PATH/fold_*/models/`; no default value
* `-t, --test`: if given, the test data is used instead of the validation data
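
A hypothetical evaluation call, reusing the experiment configuration of the training run above and reporting test set metrics (the result directory name is illustrative):

```bash
# evaluate the best checkpoint of a finished training run on the test data
python bin/eval.py -e final -i output/final_run/ -t
```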
### Further example scripts
In addition to the scripts used to create the figures in our manuscript (`spindle_analysis.py`, `spindle_analysis_correlations.py` and `spindle_detection_example.py`), the `scripts` directory contains two scripts that demonstrate the usage of this project.
#### create_data_splits.py
Demonstrates the procedure used to split the data into test and non-test subjects and the subsequent creation of a hold-out validation set and (alternatively) cross validation folds.
Arguments:

* `-i PATH, --input PATH`: path containing the (necessary) input data, as produced by the MODA file `MODA02_genEEGVectBlock.m`; relative paths start from the `scripts` directory; default is `../input/`
* `-o PATH, --output PATH`: path in which the generated data splits should be stored; relative paths start from the `scripts` directory; default is `../output/datasets_{datetime}`
* `-n NUMBER, --n_datasets NUMBER`: number of random split candidates drawn/generated; default is `25`
* `-t FRACTION, --test FRACTION`: proportion of the data that is used as test data; `0<=FRACTION<=1`; default is `0.2`
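
As a sketch (the output directory name is illustrative), the split generation could be run from within the `scripts` directory like this:

```bash
cd scripts
# draw 25 random split candidates, keeping 20% of the subjects as test data
python create_data_splits.py -i ../input/ -o ../output/my_splits -n 25 -t 0.2
```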
#### predict_plain_data.py
Demonstrates how to predict spindles with a trained SUMO model on arbitrary EEG data, which is expected as a dict with the keys representing the EEG channels and the values the corresponding data vectors; a sketch of how such an input file can be created is shown below the argument list.
Arguments:

* `-d PATH, --data_path PATH`: path containing the input data, either in `.pickle` or `.npy` format, as a dict with the channel name as key and the EEG data as value; relative paths start from the `scripts` directory; no default value
* `-m PATH, --model_path PATH`: path containing the model checkpoint, which should be used to predict spindles; relative paths start from the `scripts` directory; default is `../output/final.ckpt`
* `-g NUMBER, --gpus NUMBER`: number of GPUs to use; if `0` is given, calculations are done on the CPU; default is `0`
* `-sr RATE, --sample_rate RATE`: sample rate of the provided data; default is `100.0`
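
The following minimal sketch creates a dummy input file in the expected format; the file name, channel names and random data are purely illustrative stand-ins for real EEG recordings:

```python
# create a pickle file in the format expected by predict_plain_data.py:
# a dict mapping channel names to the corresponding 1-d EEG data vectors
import pickle

import numpy as np

sample_rate = 100.0  # has to match the -sr/--sample_rate argument
duration_s = 60  # one minute of dummy data per channel

data = {
    "C3": np.random.randn(int(duration_s * sample_rate)),
    "C4": np.random.randn(int(duration_s * sample_rate)),
}

with open("../input/eeg_data.pickle", "wb") as f:
    pickle.dump(data, f)
```

The spindle predictions could then be obtained by running `python predict_plain_data.py -d ../input/eeg_data.pickle` from within the `scripts` directory.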
## Project Setup
The project is set up as follows:
* `bin/`: contains the `train.py` and `eval.py` scripts, which are used for model training and subsequent evaluation in experiments (as configured within the `config` directory) using the PyTorch Lightning framework
* `config/`: contains the configurations of the experiments, configuring how to train or evaluate the model
    * `default.yaml`: provides a sensible default configuration
    * `final.yaml`: contains the configuration used to train the final model checkpoint (`output/final.ckpt`)
    * `predict.yaml`: configuration that can be used to predict spindles on arbitrary data, e.g. by using the script at `scripts/predict_plain_data.py`
* `input/`: should contain the used input files, e.g. the EEG data and annotated spindles as produced by the MODA repository and transformed as demonstrated in the `scripts/create_data_splits.py` file
* `output/`: contains the output generated by experiment runs or scripts, e.g. the created figures
    * `final.ckpt`: the final model checkpoint, on which the test data performance, as reported in the paper, was obtained
* `scripts/`: various scripts used to create the plots of our paper and to demonstrate the usage of this project
    * `a7/`: Python implementation of the A7 algorithm as described in: Karine Lacourse, Jacques Delfrate, Julien Beaudry, Paul E. Peppard and Simon C. Warby. "A sleep spindle detection algorithm that emulates human expert spindle scoring." Journal of Neuroscience Methods 316 (2019): 3-11.
    * `create_data_splits.py`: demonstrates the procedure by which the data set splits were obtained, including the evaluation of the A7 algorithm
    * `predict_plain_data.py`: demonstrates the prediction of spindles on arbitrary EEG data, using a trained model checkpoint
    * `spindle_analysis.py`, `spindle_analysis_correlations.py`, `spindle_detection_example.py`: scripts used to create some of the figures used in our paper
* `sumo/`: the implementation of the SUMO model and the used classes and functions; for more information see the docstrings
## Configuration Parameters
The configuration of an experiment is implemented using YAML configuration files. These files must be placed within the `config` directory and must match the name passed as `--experiment` to the `eval.py` or `train.py` script. The `default.yaml` file is always loaded as a set of default configuration parameters, and parameters specified in an additional file overwrite these default values. Any parameters or groups of parameters that should be `None` have to be configured as either `null` or `Null`, following the YAML definition.
The available parameters are as follows; an illustrative example configuration is given below the list:
* `data`: configuration of the used input data; optional, can be `None` if spindles should be annotated on arbitrary EEG data
    * `directory` and `file_name`: the input file containing the `Subject` objects (see `scripts/create_data_splits.py`) is expected to be located at `${directory}/${file_name}`, where relative paths start from the root project directory; the file should be a (pickled) dict with the name of a data set as key and the list of corresponding subjects as value; default is `input/subjects.pickle`
    * `split`: describes the keys of the data sets to be used, specifying either `train` and `validation`, or `cross_validation`, and optionally `test`
        * `cross_validation`: can be either an integer k>=2, in which case the keys `fold_0`, ..., `fold_{k-1}` are expected to exist, or a list of keys
    * `batch_size`: size of the minibatches used during training; default is `12`
    * `preprocessing`: whether z-scoring should be performed on the EEG data; default is `True`
* `experiment`: definition of the performed experiment; mandatory
    * `model`: definition of the model configuration; mandatory
        * `n_classes`: number of output parameters; default is `2`
        * `activation`: name of an activation function as defined in the `torch.nn` package; default is `ReLU`
        * `depth`: number of layers of the U excluding the last layer; default is `2`
        * `channel_size`: number of filters of the convolutions in the first layer; default is `16`
        * `pools`: list containing the size of the pooling and upsampling operations; has to contain as many values as the value of `depth`; default is `[4;4]`
        * `convolution_params`: parameters used by the Conv1d modules
        * `moving_avg_size`: width of the moving average filter; default is `42`
    * `train`: configuration used in training the model; mandatory
        * `n_epochs`: maximal number of epochs to be run before stopping training; default is `800`
        * `early_stopping`: number of epochs without any improvement in the `val_f1_mean` metric, after which training is stopped; default is `300`
        * `optimizer`: configuration of an optimizer as defined in the `torch.optim` package; contains `class_name` (default is `Adam`) and parameters, which are passed to the constructor of the used optimizer class
        * `lr_scheduler`: the learning rate scheduler to be used; optional, default is `None`
        * `loss`: configuration of a loss function, as defined either in the `sumo.loss` package (`GeneralizedDiceLoss`) or the `torch.nn` package; contains `class_name` (default is `GeneralizedDiceLoss`) and parameters, which are passed to the constructor of the used loss class
    * `validation`: configuration used in evaluating the model; mandatory
        * `overlap_threshold_step`: step size of the overlap thresholds used to calculate the (validation) F1 scores
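
As an illustration, a hypothetical `config/my_experiment.yaml` could overwrite a few of the defaults described above; all values and the exact layout of the `split` keys shown here are assumptions based on the parameter descriptions, not a verbatim copy of a provided configuration:

```yaml
# hypothetical my_experiment.yaml; parameters given here overwrite the
# values loaded from default.yaml, everything else keeps its default
data:
  split:            # assumed layout: names of the data set keys to use
    train: train
    validation: validation
    test: test
  batch_size: 12

experiment:
  model:
    depth: 2
    pools: [4, 4]   # one value per layer, i.e. as many entries as depth
  train:
    n_epochs: 800
    optimizer:
      class_name: Adam
    lr_scheduler: null  # YAML null/Null corresponds to Python's None
```

Such a configuration would then be selected with `python bin/train.py -e my_experiment`.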