代码地址:
https://github.com/Sharpiless/yolov5-knowledge-distillation
教师模型:
python train.py --weights weights/yolov5m.pt \
--cfg models/yolov5m.yaml --data data/voc.yaml --epochs 50 \
--batch-size 8 --device 0 --hyp data/hyp.scratch.yaml
蒸馏训练:
python train.py --weights weights/yolov5s.pt \
--cfg models/yolov5s.yaml --data data/voc.yaml --epochs 50 \
--batch-size 8 --device 0 --hyp data/hyp.scratch.yaml \
--t_weights yolov5m.pt --distill
训练参数:
--weights:预训练模型
--t_weights:教师模型权重
--distill:使用知识蒸馏进行训练
--dist_loss:l2或者kl
--temperature:使用知识蒸馏时的温度
使用《Object detection at 200 Frames Per Second》中的损失
这篇文章分别对这几个损失函数做出改进,具体思路为只有当teacher network的objectness value高时,才学习bounding box坐标和class probabilities。
实验结果:
这里假设VOC2012中新增加的数据为无标签数据(2k张)。
教师模型 | 训练方法 | 蒸馏损失 | P | R | mAP50 |
---|---|---|---|---|---|
无 | 正常训练 | 不使用 | 0.7756 | 0.7115 | 0.7609 |
Yolov5l | output based | l2 | 0.7585 | 0.7198 | 0.7644 |
Yolov5l | output based | KL | 0.7417 | 0.7207 | 0.7536 |
Yolov5m | output based | l2 | 0.7682 | 0.7436 | 0.7976 |
Yolov5m | output based | KL | 0.7731 | 0.7313 | 0.7931 |
参数和细节正在完善,支持KL散度、L2 logits损失和Sigmoid蒸馏损失等
1. 正常训练:
2. L2蒸馏损失:
我的公众号:
关于作者
AI Studio:https://aistudio.baidu.com/aistudio/personalcenter/thirdview/67156