News
- [2021-9-6]: v0.0.0 has been released.
Introduction
GeDML is an easy-to-use generalized deep metric learning library, which contains:
- State-of-the-art DML algorithms: We provide 18+ loss functions and 6+ sampling strategies, and divide these algorithms into three categories (i.e., collectors, selectors, and losses).
- Bridge between DML and SSL: We attempt to bridge the gap between deep metric learning and self-supervised learning through specially designed modules, such as collectors.
- Auxiliary modules to assist in building: We also encapsulate the upper-level interface so that users can start programs quickly, and separate the code from the configs for convenient hyper-parameter management.
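The three categories compose into a simple pipeline: a collector produces embeddings, a selector picks the tuples to train on, and a loss scores them. A minimal pure-Python sketch of this idea (the function names here are illustrative, not the actual gedml API):

```python
import math

def default_collector(embeddings, labels):
    """Pass embeddings through unchanged (cf. DefaultCollector)."""
    return embeddings, labels

def dense_pair_selector(embeddings, labels):
    """Select all (i, j) pairs with a same/different-label flag."""
    pairs = []
    for i in range(len(labels)):
        for j in range(i + 1, len(labels)):
            pairs.append((i, j, labels[i] == labels[j]))
    return pairs

def contrastive_loss(embeddings, pairs, margin=1.0):
    """Pull positives together, push negatives past the margin."""
    total = 0.0
    for i, j, is_pos in pairs:
        dist = math.dist(embeddings[i], embeddings[j])
        total += dist ** 2 if is_pos else max(0.0, margin - dist) ** 2
    return total / max(len(pairs), 1)

embeddings = [(0.0, 0.0), (0.1, 0.0), (1.0, 1.0)]
labels = [0, 0, 1]
emb, lab = default_collector(embeddings, labels)
pairs = dense_pair_selector(emb, lab)
loss = contrastive_loss(emb, pairs)
```

Because each stage only consumes the previous stage's output, any collector, selector, or loss can be swapped independently.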
Installation
Pip
```bash
pip install gedml
```
Framework
This project is modular in design. The pipeline diagram is as follows:
Code structure
Method
Collectors
method | description |
---|---|
BaseCollector | Base class |
DefaultCollector | Do nothing |
ProxyCollector | Maintain a set of proxies |
MoCoCollector | paper: Momentum Contrast for Unsupervised Visual Representation Learning |
SimSiamCollector | paper: Exploring Simple Siamese Representation Learning |
HDMLCollector | paper: Hardness-Aware Deep Metric Learning |
DAMLCollector | paper: Deep Adversarial Metric Learning |
DVMLCollector | paper: Deep Variational Metric Learning |
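The MoCo-style collector relies on a momentum (exponential moving average) encoder update. A minimal sketch of that update rule, assuming plain Python lists in place of PyTorch parameters (illustrative, not gedml's MoCoCollector code):

```python
# MoCo keeps a slowly moving "key" encoder alongside the trained
# "query" encoder: key <- m * key + (1 - m) * query per parameter.
def momentum_update(key_params, query_params, m=0.999):
    return [m * k + (1.0 - m) * q for k, q in zip(key_params, query_params)]

key = [0.0, 1.0]
query = [1.0, 0.0]
key = momentum_update(key, query, m=0.9)
```

With a large m, the key encoder changes slowly, which keeps the negatives in MoCo's queue consistent across iterations.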
Losses
classifier-based
method | description |
---|---|
CrossEntropyLoss | Cross entropy loss for unsupervised methods |
LargeMarginSoftmaxLoss | paper: Large-Margin Softmax Loss for Convolutional Neural Networks |
ArcFaceLoss | paper: ArcFace: Additive Angular Margin Loss for Deep Face Recognition |
CosFaceLoss | paper: CosFace: Large Margin Cosine Loss for Deep Face Recognition |
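These classifier-based losses differ mainly in how they penalize the target-class logit before the softmax. A sketch of the CosFace and ArcFace margins (illustrative only, not gedml's implementation):

```python
import math

def cosface_logit(cos_theta, margin=0.35):
    # CosFace: additive margin in cosine space.
    return cos_theta - margin

def arcface_logit(cos_theta, margin=0.5):
    # ArcFace: additive margin in angle space.
    theta = math.acos(max(-1.0, min(1.0, cos_theta)))
    return math.cos(theta + margin)

penalized = cosface_logit(0.8)
```

Both shrink the true-class logit, forcing the embedding to be more confidently aligned with its class weight than plain softmax would require.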
pair-based
method | description |
---|---|
ContrastiveLoss | paper: Learning a Similarity Metric Discriminatively, with Application to Face Verification |
MarginLoss | paper: Sampling Matters in Deep Embedding Learning |
TripletLoss | paper: Learning local feature descriptors with triplets and shallow convolutional neural networks |
AngularLoss | paper: Deep Metric Learning with Angular Loss |
CircleLoss | paper: Circle Loss: A Unified Perspective of Pair Similarity Optimization |
FastAPLoss | paper: Deep Metric Learning to Rank |
LiftedStructureLoss | paper: Deep Metric Learning via Lifted Structured Feature Embedding |
MultiSimilarityLoss | paper: Multi-Similarity Loss With General Pair Weighting for Deep Metric Learning |
NPairLoss | paper: Improved Deep Metric Learning with Multi-class N-pair Loss Objective |
SignalToNoiseRatioLoss | paper: Signal-To-Noise Ratio: A Robust Distance Metric for Deep Metric Learning |
PosPairLoss | paper: Exploring Simple Siamese Representation Learning |
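Many of these pair-based losses follow the same pattern: pull an anchor toward a positive and push it past a negative by a margin. A small worked example of the triplet loss (a sketch, not gedml's code):

```python
import math

# Triplet loss: L = max(0, d(a, p) - d(a, n) + margin)
def triplet_loss(anchor, positive, negative, margin=0.2):
    d_ap = math.dist(anchor, positive)
    d_an = math.dist(anchor, negative)
    return max(0.0, d_ap - d_an + margin)

# The negative is closer than (positive distance + margin) allows,
# so the loss is positive and gradients would push it away.
loss = triplet_loss((0.0, 0.0), (0.3, 0.4), (0.0, 0.6))
```

When the negative is already far enough away, the hinge clips the loss to zero and the triplet contributes no gradient, which is exactly why sampling strategies (the selectors above) matter.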
proxy-based
method | description |
---|---|
ProxyLoss | paper: No Fuss Distance Metric Learning Using Proxies |
ProxyAnchorLoss | paper: Proxy Anchor Loss for Deep Metric Learning |
SoftTripleLoss | paper: SoftTriple Loss: Deep Metric Learning Without Triplet Sampling |
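Proxy-based losses replace sampled positives and negatives with learnable per-class proxies. A minimal sketch in the spirit of ProxyLoss/ProxyNCA (illustrative, not the library code):

```python
import math

# Each class keeps a proxy; an example is attracted to its own class
# proxy and repelled from the others via a softmax over distances.
def proxy_nca_loss(embedding, proxies, label):
    dists = [math.dist(embedding, p) for p in proxies]
    exp_neg = [math.exp(-d) for d in dists]
    return -math.log(exp_neg[label] / sum(exp_neg))

proxies = [(0.0, 0.0), (10.0, 10.0)]
loss = proxy_nca_loss((0.5, 0.5), proxies, label=0)
```

Since every example only interacts with the proxies rather than with other examples, these losses avoid tuple sampling entirely.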
Selectors
method | description |
---|---|
BaseSelector | Base class |
DefaultSelector | Do nothing |
DenseTripletSelector | Select all triples |
DensePairSelector | Select all pairs |
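A dense selector simply enumerates every valid tuple implied by the labels. For example, a sketch of dense triplet selection (illustrative, not the actual DenseTripletSelector code):

```python
# Enumerate every (anchor, positive, negative) index triple:
# positive shares the anchor's label, negative does not.
def dense_triplets(labels):
    triplets = []
    for a in range(len(labels)):
        for p in range(len(labels)):
            if p == a or labels[p] != labels[a]:
                continue
            for n in range(len(labels)):
                if labels[n] != labels[a]:
                    triplets.append((a, p, n))
    return triplets

triplets = dense_triplets([0, 0, 1, 1])
```

The triple count grows rapidly with batch size, which is why non-trivial selectors subsample or weight these tuples instead of using them all.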
Quickstart
Please set the environment variable `WORKSPACE` first to indicate where to manage your project.
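For example (the path here is just a placeholder):

```python
import os

# Point GeDML at a project directory before creating any objects.
os.environ.setdefault("WORKSPACE", "./workspace")
```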
Initialization
Use `ConfigHandler` to create all objects.
```python
config_handler = ConfigHandler()
config_handler.get_params_dict()
objects_dict = config_handler.create_all()
```
Start
Use `manager` to automatically call `trainer` and `tester`.
```python
manager = utils.get_default(objects_dict, "managers")
manager.run()
```
Or directly use `trainer` and `tester`.
```python
trainer = utils.get_default(objects_dict, "trainers")
tester = utils.get_default(objects_dict, "testers")
recorder = utils.get_default(objects_dict, "recorders")

# start to train
utils.func_params_mediator(
    [objects_dict],
    trainer.__call__
)

# start to test
metrics = utils.func_params_mediator(
    [
        {"recorders": recorder},
        objects_dict,
    ],
    tester.__call__
)
```
Document
For more information, please refer to the documentation and its specific guides.
Configs
We will continually update the optimal parameters for the different configs on TsinghuaCloud.
Code Reference
- KevinMusgrave / pytorch-metric-learning
- KevinMusgrave / powerful-benchmarker
- Confusezius / Deep-Metric-Learning-Baselines
- facebookresearch / moco
- PatrickHua / SimSiam
- ujjwaltiwari / Deep_Variational_Metric_Learning
- idstcv / SoftTriple
- wzzheng / HDML
- google-research / simclr
- kunhe / FastAP-metric-learning
- wy1iu / LargeMargin_Softmax_Loss
- tjddus9597 / Proxy-Anchor-CVPR2020
- facebookresearch / deit
TODO:
- add parameter assertions
- support both distributed and non-distributed methods
- write GitHub Actions to automate unit testing, package publishing, and docs building
- add a cross-validation split protocol