I was testing the AUC margin loss on the NIH Chest X-ray dataset in a multi-label binary classification setting, and I wasn't able to improve my validation ROC AUC score at all.
I noticed a huge gap between train and validation loss. Train loss starts around 0.1 and drops to 0.00x within a couple of epochs (1.2k iterations per epoch at batch size 64), but validation loss barely moves. I tried training both from ImageNet weights and from my own best checkpoint. Starting from ImageNet weights, validation loss doesn't decrease and validation AUC oscillates between 0.49 and 0.51. Starting from my best checkpoint, validation loss again doesn't decrease and validation AUC changes only on the order of 1e-6. In short, there is essentially no learning.
My imbalance ratios are
[0.10288084, 0.02472351, 0.11868534, 0.17722083, 0.05124866,
0.05639493, 0.01206743, 0.04725294, 0.04162504, 0.02054049,
0.02244024, 0.01503746, 0.03019087, 0.00202462]
which are calculated as n_positive_samples / (n_positive_samples + n_negative_samples). I use the AUCM_MultiLabel loss with the imbalance ratios above.
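A minimal sketch of that computation (plain Python; the function name is mine, purely illustrative):

```python
def imbalance_ratios(labels):
    """Per-class positive fraction for a multi-label 0/1 target matrix.

    labels: list of rows, each a list of 0/1 ints, shape (N, num_classes).
    Returns n_positive / (n_positive + n_negative) for each class, i.e. the
    positive fraction of each column.
    """
    n_samples = len(labels)
    n_classes = len(labels[0])
    ratios = []
    for c in range(n_classes):
        n_pos = sum(row[c] for row in labels)
        ratios.append(n_pos / n_samples)
    return ratios

# toy example: 4 samples, 2 classes
print(imbalance_ratios([[1, 0], [0, 0], [1, 1], [0, 0]]))  # → [0.5, 0.25]
```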
I tried the PESG optimizer with learning rates from 1e-1 to 1e-5, but the result was always the same. I use gamma=500, margin=1.0, and weight_decay=1e-5. Finally, I update the learning rate and regularizer after the first epoch.
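My mental model of that stagewise update (an assumption inferred from the decay_factor argument mentioned in the tips, not LibAUC's actual implementation) is that update_regularizer simply shrinks the learning rate by the decay factor at each stage boundary:

```python
def stagewise_decay(lr, decay_factor=10):
    # Toy model of a stagewise update: divide the learning rate by decay_factor
    # at the end of a stage (here, after the first epoch).
    return lr / decay_factor

lr = 0.1
lr = stagewise_decay(lr)  # applied once after the first epoch
print(lr)
```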
The only thing that came to mind was that the models might use different heads, so I checked libauc.models, but those match mine as well. My model outputs sigmoid-activated logits of shape (N, 14).
Basically, I did everything except the last item in the Useful Tips section:
- Your data should have binary labels 0,1 and 1 is the minority class and 0 is the majority class
- Compute the imbalance_ratio from your train set and pass it to AUCMLoss(imratio=xxx)
- Adopt a proper initial learning rate, e.g., lr=[0.1, 0.05] usually works better
- Choose libauc.optimizers.PESG to optimize AUCMLoss(imratio=xxx)
- Use optimizer.update_regularizer(decay_factor=10) to update learning rate and regularizer in stagewise
- Add activation layer, e.g., torch.sigmoid(logits), before passing model outputs to loss function
- Reshape both variables preds and targets to (N, 1) before calling loss function
I don't think the final point applies in the multi-label scenario, because AUCM_MultiLabel already iterates over the second dimension of the predictions and labels, computes the loss for each label, and accumulates it into total_loss. At this point I can't figure out what I'm missing.
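To make the point concrete, here is an illustrative sketch of that per-label accumulation pattern (not LibAUC's actual code; a toy MSE stands in for the per-label AUC margin loss). Since the loss itself slices out each column, reshaping preds and targets to (N, 1) by hand adds nothing:

```python
def multilabel_loss(per_label_loss, preds, targets):
    """Accumulate a binary loss over each label column.

    preds, targets: (N, num_labels) nested lists; per_label_loss takes two
    length-N sequences (one label's predictions and targets) and returns a float.
    """
    num_labels = len(preds[0])
    total = 0.0
    for i in range(num_labels):
        col_preds = [row[i] for row in preds]      # effectively the (N, 1) slice
        col_targets = [row[i] for row in targets]
        total += per_label_loss(col_preds, col_targets)
    return total

# toy per-label loss: mean squared error
mse = lambda p, t: sum((pi - ti) ** 2 for pi, ti in zip(p, t)) / len(p)
print(multilabel_loss(mse, [[0.9, 0.1], [0.2, 0.8]], [[1, 0], [0, 1]]))  # ≈ 0.05
```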
I'm also providing a portion of my training logs.
train_loss: 0.081627 - lr: 0.00011867: 100% 1227/1227 [08:49<00:00, 2.32it/s]
val_loss: 0.550620: 100% 88/88 [00:41<00:00, 2.11it/s]
Epoch 1 - Training Loss: 0.081627 - Validation Loss: 0.550620 - Validation Micro Average ROC AUC Score: 0.893272 - Macro Average ROC AUC Score: 0.855293
Saving model to ../models/efficientnet/efficientnetv2_s_auc.pt (validation loss decreased from inf to 0.550620)
Reducing learning rate to 0.00001 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.069679 - lr: 0.00092689: 100% 1227/1227 [08:54<00:00, 2.29it/s]
val_loss: 0.409439: 100% 88/88 [00:28<00:00, 3.12it/s]
Epoch 2 - Training Loss: 0.069679 - Validation Loss: 0.409439 - Validation Micro Average ROC AUC Score: 0.892139 - Macro Average ROC AUC Score: 0.855510
Saving model to ../models/efficientnet/efficientnetv2_s_auc.pt (validation loss decreased from 0.550620 to 0.409439)
Reducing learning rate to 0.00009 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.061587 - lr: 0.00025882: 100% 1227/1227 [08:52<00:00, 2.30it/s]
val_loss: 0.359555: 100% 88/88 [00:27<00:00, 3.14it/s]
Epoch 3 - Training Loss: 0.061587 - Validation Loss: 0.359555 - Validation Micro Average ROC AUC Score: 0.891677 - Macro Average ROC AUC Score: 0.855467
Saving model to ../models/efficientnet/efficientnetv2_s_auc.pt (validation loss decreased from 0.409439 to 0.359555)
Reducing learning rate to 0.00003 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.053611 - lr: 0.00073131: 100% 1227/1227 [08:51<00:00, 2.31it/s]
val_loss: 0.343951: 100% 88/88 [00:27<00:00, 3.15it/s]
Epoch 4 - Training Loss: 0.053611 - Validation Loss: 0.343951 - Validation Micro Average ROC AUC Score: 0.890916 - Macro Average ROC AUC Score: 0.855372
Saving model to ../models/efficientnet/efficientnetv2_s_auc.pt (validation loss decreased from 0.359555 to 0.343951)
Reducing learning rate to 0.00007 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.048404 - lr: 0.00049360: 100% 1227/1227 [08:51<00:00, 2.31it/s]
val_loss: 0.351462: 100% 88/88 [00:28<00:00, 3.13it/s]
Epoch 5 - Training Loss: 0.048404 - Validation Loss: 0.351462 - Validation Micro Average ROC AUC Score: 0.890097 - Macro Average ROC AUC Score: 0.855150
Reducing learning rate to 0.00005 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.042902 - lr: 0.00047681: 100% 1227/1227 [08:51<00:00, 2.31it/s]
val_loss: 0.351917: 100% 88/88 [00:27<00:00, 3.16it/s]
Epoch 6 - Training Loss: 0.042902 - Validation Loss: 0.351917 - Validation Micro Average ROC AUC Score: 0.889476 - Macro Average ROC AUC Score: 0.855386
Reducing learning rate to 0.00005 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.039794 - lr: 0.00074670: 100% 1227/1227 [08:49<00:00, 2.32it/s]
val_loss: 0.359203: 100% 88/88 [00:27<00:00, 3.16it/s]
Epoch 7 - Training Loss: 0.039794 - Validation Loss: 0.359203 - Validation Micro Average ROC AUC Score: 0.889033 - Macro Average ROC AUC Score: 0.855199
Reducing learning rate to 0.00007 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.035798 - lr: 0.00024610: 100% 1227/1227 [08:52<00:00, 2.31it/s]
val_loss: 0.363825: 100% 88/88 [00:27<00:00, 3.15it/s]
Epoch 8 - Training Loss: 0.035798 - Validation Loss: 0.363825 - Validation Micro Average ROC AUC Score: 0.887879 - Macro Average ROC AUC Score: 0.855122
Reducing learning rate to 0.00002 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.033877 - lr: 0.00093589: 100% 1227/1227 [08:53<00:00, 2.30it/s]
val_loss: 0.368633: 100% 88/88 [00:28<00:00, 3.12it/s]
Epoch 9 - Training Loss: 0.033877 - Validation Loss: 0.368633 - Validation Micro Average ROC AUC Score: 0.887842 - Macro Average ROC AUC Score: 0.854925
Reducing learning rate to 0.00009 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.031171 - lr: 0.00011414: 100% 1227/1227 [08:50<00:00, 2.31it/s]
val_loss: 0.373346: 100% 88/88 [00:27<00:00, 3.15it/s]
Epoch 10 - Training Loss: 0.031171 - Validation Loss: 0.373346 - Validation Micro Average ROC AUC Score: 0.887144 - Macro Average ROC AUC Score: 0.855042
Reducing learning rate to 0.00001 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.034546 - lr: 0.00099968: 100% 1227/1227 [08:50<00:00, 2.31it/s]
val_loss: 0.368587: 100% 88/88 [00:27<00:00, 3.15it/s]
Epoch 11 - Training Loss: 0.034546 - Validation Loss: 0.368587 - Validation Micro Average ROC AUC Score: 0.887355 - Macro Average ROC AUC Score: 0.855000
Reducing learning rate to 0.00010 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.033058 - lr: 0.00012381: 100% 1227/1227 [08:49<00:00, 2.32it/s]
val_loss: 0.374046: 100% 88/88 [00:28<00:00, 3.10it/s]
Epoch 12 - Training Loss: 0.033058 - Validation Loss: 0.374046 - Validation Micro Average ROC AUC Score: 0.886223 - Macro Average ROC AUC Score: 0.854960
Reducing learning rate to 0.00001 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.030131 - lr: 0.00091735: 100% 1227/1227 [08:50<00:00, 2.31it/s]
val_loss: 0.379541: 100% 88/88 [00:27<00:00, 3.18it/s]
Epoch 13 - Training Loss: 0.030131 - Validation Loss: 0.379541 - Validation Micro Average ROC AUC Score: 0.886065 - Macro Average ROC AUC Score: 0.854729
Reducing learning rate to 0.00009 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.027916 - lr: 0.00027196: 100% 1227/1227 [08:52<00:00, 2.30it/s]
val_loss: 0.385755: 100% 88/88 [00:27<00:00, 3.16it/s]
Epoch 14 - Training Loss: 0.027916 - Validation Loss: 0.385755 - Validation Micro Average ROC AUC Score: 0.883915 - Macro Average ROC AUC Score: 0.854616
Reducing learning rate to 0.00003 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.025726 - lr: 0.00071566: 100% 1227/1227 [08:53<00:00, 2.30it/s]
val_loss: 0.387343: 100% 88/88 [00:27<00:00, 3.17it/s]
Epoch 15 - Training Loss: 0.025726 - Validation Loss: 0.387343 - Validation Micro Average ROC AUC Score: 0.884937 - Macro Average ROC AUC Score: 0.854766
Reducing learning rate to 0.00007 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.024255 - lr: 0.00051047: 100% 1227/1227 [08:55<00:00, 2.29it/s]
val_loss: 0.393512: 100% 88/88 [00:27<00:00, 3.14it/s]
Epoch 16 - Training Loss: 0.024255 - Validation Loss: 0.393512 - Validation Micro Average ROC AUC Score: 0.883381 - Macro Average ROC AUC Score: 0.854501
Reducing learning rate to 0.00005 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.022581 - lr: 0.00046013: 100% 1227/1227 [08:54<00:00, 2.30it/s]
val_loss: 0.400079: 100% 88/88 [00:27<00:00, 3.14it/s]
Epoch 17 - Training Loss: 0.022581 - Validation Loss: 0.400079 - Validation Micro Average ROC AUC Score: 0.883271 - Macro Average ROC AUC Score: 0.854286
Reducing learning rate to 0.00005 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.021331 - lr: 0.00076182: 100% 1227/1227 [08:52<00:00, 2.31it/s]
val_loss: 0.396512: 100% 88/88 [00:27<00:00, 3.15it/s]
Epoch 18 - Training Loss: 0.021331 - Validation Loss: 0.396512 - Validation Micro Average ROC AUC Score: 0.882822 - Macro Average ROC AUC Score: 0.854248
Reducing learning rate to 0.00008 @ T=1227!
Updating regularizer @ T=1227!
train_loss: 0.020128 - lr: 0.00023381: 100% 1227/1227 [08:51<00:00, 2.31it/s]
val_loss: 0.401782: 100% 88/88 [00:28<00:00, 3.13it/s]
Epoch 19 - Training Loss: 0.020128 - Validation Loss: 0.401782 - Validation Micro Average ROC AUC Score: 0.882145 - Macro Average ROC AUC Score: 0.854304
Early stopping (validation loss didn't decrease for 15 epochs)
Best validation loss: 0.343951