Segmentation models with pretrained backbones. PyTorch.

Overview

Python library with Neural Networks for Image
Segmentation based on PyTorch.


The main features of this library are:

  • High level API (just two lines to create a neural network)
  • 9 model architectures for binary and multi-class segmentation (including the legendary Unet)
  • 113 available encoders
  • All encoders have pre-trained weights for faster and better convergence

📚 Project Documentation 📚

Visit the Read the Docs project page or read the following README to learn more about the Segmentation Models PyTorch (SMP for short) library.

📋 Table of contents

  1. Quick start
  2. Examples
  3. Models
    1. Architectures
    2. Encoders
    3. Timm Encoders
  4. Models API
    1. Input channels
    2. Auxiliary classification output
    3. Depth
  5. Installation
  6. Competitions won with the library
  7. Contributing
  8. Citing
  9. License

⏳ Quick start

1. Create your first Segmentation model with SMP

A segmentation model is just a PyTorch nn.Module, which can be created as easily as:

import segmentation_models_pytorch as smp

model = smp.Unet(
    encoder_name="resnet34",        # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights="imagenet",     # use `imagenet` pre-trained weights for encoder initialization
    in_channels=1,                  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=3,                      # model output channels (number of classes in your dataset)
)
  • see table with available model architectures
  • see table with available encoders and their corresponding weights

2. Configure data preprocessing

All encoders have pretrained weights. Preparing your data the same way as during weight pre-training may give you better results (a higher metric score and faster convergence). This is relevant only for images with 1, 2 or 3 channels, and it is not necessary if you train the whole model rather than only the decoder.

from segmentation_models_pytorch.encoders import get_preprocessing_fn

preprocess_input = get_preprocessing_fn('resnet18', pretrained='imagenet')
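For ImageNet weights, this preprocessing is expected to scale pixels to [0, 1] and normalize each channel with the ImageNet mean and standard deviation. The NumPy sketch below reproduces that step on an HWC image and then moves channels first for PyTorch. The mean/std values are the standard ImageNet statistics and are an assumption here (SMP looks up the settings per encoder); `imagenet_preprocess` is a hypothetical helper, not part of SMP.

```python
import numpy as np

# Standard ImageNet statistics (assumed; check the settings for your encoder).
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def imagenet_preprocess(image: np.ndarray) -> np.ndarray:
    """Hypothetical helper: uint8 HWC RGB image -> normalized float32 CHW array."""
    x = image.astype(np.float32) / 255.0                 # scale to [0, 1]
    x = (x - IMAGENET_MEAN) / IMAGENET_STD               # per-channel normalization
    return np.ascontiguousarray(x.transpose(2, 0, 1))    # HWC -> CHW for PyTorch


image = np.full((4, 6, 3), 255, dtype=np.uint8)  # a tiny all-white RGB image
chw = imagenet_preprocess(image)
print(chw.shape)  # (3, 4, 6)
```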

Congratulations! You are done! Now you can train your model with your favorite framework!

💡 Examples

  • Training a model for car segmentation on the CamVid dataset here.
  • Training an SMP model with Catalyst (high-level framework for PyTorch), TTAch (TTA library for PyTorch) and Albumentations (fast image augmentation library) - here.
  • Training an SMP model with the PyTorch Lightning framework - here (clothes binary segmentation by @teranus).

📦 Models

Architectures

  • Unet
  • Unet++
  • MAnet
  • Linknet
  • FPN
  • PSPNet
  • PAN
  • DeepLabV3
  • DeepLabV3+

Encoders

The following is a list of encoders supported by SMP. Select the appropriate family of encoders, then pick a specific encoder and its pre-trained weights from the table (the encoder_name and encoder_weights parameters).

ResNet
Encoder Weights Params, M
resnet18 imagenet / ssl / swsl 11M
resnet34 imagenet 21M
resnet50 imagenet / ssl / swsl 23M
resnet101 imagenet 42M
resnet152 imagenet 58M
ResNeXt
Encoder Weights Params, M
resnext50_32x4d imagenet / ssl / swsl 22M
resnext101_32x4d ssl / swsl 42M
resnext101_32x8d imagenet / instagram / ssl / swsl 86M
resnext101_32x16d instagram / ssl / swsl 191M
resnext101_32x32d instagram 466M
resnext101_32x48d instagram 826M
ResNeSt
Encoder Weights Params, M
timm-resnest14d imagenet 8M
timm-resnest26d imagenet 15M
timm-resnest50d imagenet 25M
timm-resnest101e imagenet 46M
timm-resnest200e imagenet 68M
timm-resnest269e imagenet 108M
timm-resnest50d_4s2x40d imagenet 28M
timm-resnest50d_1s4x24d imagenet 23M
Res2Ne(X)t
Encoder Weights Params, M
timm-res2net50_26w_4s imagenet 23M
timm-res2net101_26w_4s imagenet 43M
timm-res2net50_26w_6s imagenet 35M
timm-res2net50_26w_8s imagenet 46M
timm-res2net50_48w_2s imagenet 23M
timm-res2net50_14w_8s imagenet 23M
timm-res2next50 imagenet 22M
RegNet(x/y)
Encoder Weights Params, M
timm-regnetx_002 imagenet 2M
timm-regnetx_004 imagenet 4M
timm-regnetx_006 imagenet 5M
timm-regnetx_008 imagenet 6M
timm-regnetx_016 imagenet 8M
timm-regnetx_032 imagenet 14M
timm-regnetx_040 imagenet 20M
timm-regnetx_064 imagenet 24M
timm-regnetx_080 imagenet 37M
timm-regnetx_120 imagenet 43M
timm-regnetx_160 imagenet 52M
timm-regnetx_320 imagenet 105M
timm-regnety_002 imagenet 2M
timm-regnety_004 imagenet 3M
timm-regnety_006 imagenet 5M
timm-regnety_008 imagenet 5M
timm-regnety_016 imagenet 10M
timm-regnety_032 imagenet 17M
timm-regnety_040 imagenet 19M
timm-regnety_064 imagenet 29M
timm-regnety_080 imagenet 37M
timm-regnety_120 imagenet 49M
timm-regnety_160 imagenet 80M
timm-regnety_320 imagenet 141M
GERNet
Encoder Weights Params, M
timm-gernet_s imagenet 6M
timm-gernet_m imagenet 18M
timm-gernet_l imagenet 28M
SE-Net
Encoder Weights Params, M
senet154 imagenet 113M
se_resnet50 imagenet 26M
se_resnet101 imagenet 47M
se_resnet152 imagenet 64M
se_resnext50_32x4d imagenet 25M
se_resnext101_32x4d imagenet 46M
SK-ResNe(X)t
Encoder Weights Params, M
timm-skresnet18 imagenet 11M
timm-skresnet34 imagenet 21M
timm-skresnext50_32x4d imagenet 25M
DenseNet
Encoder Weights Params, M
densenet121 imagenet 6M
densenet169 imagenet 12M
densenet201 imagenet 18M
densenet161 imagenet 26M
Inception
Encoder Weights Params, M
inceptionresnetv2 imagenet / imagenet+background 54M
inceptionv4 imagenet / imagenet+background 41M
xception imagenet 22M
EfficientNet
Encoder Weights Params, M
efficientnet-b0 imagenet 4M
efficientnet-b1 imagenet 6M
efficientnet-b2 imagenet 7M
efficientnet-b3 imagenet 10M
efficientnet-b4 imagenet 17M
efficientnet-b5 imagenet 28M
efficientnet-b6 imagenet 40M
efficientnet-b7 imagenet 63M
timm-efficientnet-b0 imagenet / advprop / noisy-student 4M
timm-efficientnet-b1 imagenet / advprop / noisy-student 6M
timm-efficientnet-b2 imagenet / advprop / noisy-student 7M
timm-efficientnet-b3 imagenet / advprop / noisy-student 10M
timm-efficientnet-b4 imagenet / advprop / noisy-student 17M
timm-efficientnet-b5 imagenet / advprop / noisy-student 28M
timm-efficientnet-b6 imagenet / advprop / noisy-student 40M
timm-efficientnet-b7 imagenet / advprop / noisy-student 63M
timm-efficientnet-b8 imagenet / advprop 84M
timm-efficientnet-l2 noisy-student 474M
timm-efficientnet-lite0 imagenet 4M
timm-efficientnet-lite1 imagenet 5M
timm-efficientnet-lite2 imagenet 6M
timm-efficientnet-lite3 imagenet 8M
timm-efficientnet-lite4 imagenet 13M
MobileNet
Encoder Weights Params, M
mobilenet_v2 imagenet 2M
timm-mobilenetv3_large_075 imagenet 1.78M
timm-mobilenetv3_large_100 imagenet 2.97M
timm-mobilenetv3_large_minimal_100 imagenet 1.41M
timm-mobilenetv3_small_075 imagenet 0.57M
timm-mobilenetv3_small_100 imagenet 0.93M
timm-mobilenetv3_small_minimal_100 imagenet 0.43M
DPN
Encoder Weights Params, M
dpn68 imagenet 11M
dpn68b imagenet+5k 11M
dpn92 imagenet+5k 34M
dpn98 imagenet 58M
dpn107 imagenet+5k 84M
dpn131 imagenet 76M
VGG
Encoder Weights Params, M
vgg11 imagenet 9M
vgg11_bn imagenet 9M
vgg13 imagenet 9M
vgg13_bn imagenet 9M
vgg16 imagenet 14M
vgg16_bn imagenet 14M
vgg19 imagenet 20M
vgg19_bn imagenet 20M

* ssl, swsl - semi-supervised and weakly-supervised learning on ImageNet (repo).

Timm Encoders


PyTorch Image Models (a.k.a. timm) has a lot of pretrained models and an interface that allows these models to be used as encoders in SMP. However, not all models are supported:

  • transformer models do not have features_only functionality implemented
  • some models do not have appropriate strides

Total number of supported encoders: 467

🔁 Models API

  • model.encoder - pretrained backbone that extracts features at different spatial resolutions
  • model.decoder - depends on the model architecture (Unet/Linknet/PSPNet/FPN)
  • model.segmentation_head - last block that produces the required number of mask channels (also includes optional upsampling and activation)
  • model.classification_head - optional block that creates a classification head on top of the encoder
  • model.forward(x) - sequentially passes x through the model's encoder, decoder and segmentation head (and the classification head, if specified)
Input channels

The in_channels parameter allows you to create models that process tensors with an arbitrary number of channels. If you use pretrained weights from ImageNet, the weights of the first convolution are reused. In the 1-channel case the new weights are the sum of the first convolution layer's weights over its input channels; otherwise the channels are populated as new_weight[:, i] = pretrained_weight[:, i % 3] and then scaled with new_weight * 3 / new_in_channels.

import torch
import segmentation_models_pytorch as smp

model = smp.FPN('resnet34', in_channels=1)
mask = model(torch.ones([1, 1, 64, 64]))
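The weight-reuse rule described above can be sketched in plain PyTorch. `adapt_first_conv_weight` is a hypothetical helper written for illustration, not SMP's internal function; it takes a pretrained first-convolution kernel of shape (out_channels, 3, kH, kW) and returns one for the requested number of input channels.

```python
import torch


def adapt_first_conv_weight(weight: torch.Tensor, new_in_channels: int) -> torch.Tensor:
    """Hypothetical helper: reuse an RGB first-conv kernel for new_in_channels inputs."""
    default_in = weight.shape[1]  # 3 for RGB-pretrained encoders
    if new_in_channels == 1:
        # Gray-scale: collapse the RGB filters into a single channel by summing them.
        return weight.sum(dim=1, keepdim=True)
    out_ch, _, kh, kw = weight.shape
    new_weight = torch.empty(out_ch, new_in_channels, kh, kw, dtype=weight.dtype)
    for i in range(new_in_channels):
        # Cycle through the pretrained channels: R, G, B, R, G, B, ...
        new_weight[:, i] = weight[:, i % default_in]
    # Rescale so the summed response stays comparable to the 3-channel case.
    return new_weight * default_in / new_in_channels


w = torch.randn(8, 3, 7, 7)  # stand-in for a pretrained 7x7 stem kernel
print(adapt_first_conv_weight(w, 1).shape)  # torch.Size([8, 1, 7, 7])
print(adapt_first_conv_weight(w, 6).shape)  # torch.Size([8, 6, 7, 7])
```

With 6 input channels each pretrained filter is used twice and multiplied by 3 / 6 = 0.5.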
Auxiliary classification output

All models support the aux_params parameter, which is set to None by default. If aux_params = None, the auxiliary classification output is not created; otherwise the model produces not only a mask but also a label output of shape NC. The classification head consists of GlobalPooling->Dropout(optional)->Linear->Activation(optional) layers, which can be configured by aux_params as follows:

import torch
import segmentation_models_pytorch as smp

aux_params=dict(
    pooling='avg',             # one of 'avg', 'max'
    dropout=0.5,               # dropout ratio, default is None
    activation='sigmoid',      # activation function, default is None
    classes=4,                 # define number of output labels
)
model = smp.Unet('resnet34', classes=4, aux_params=aux_params)
x = torch.rand(1, 3, 64, 64)   # example batch (N, C, H, W)
mask, label = model(x)
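To make the GlobalPooling->Dropout->Linear->Activation chain concrete, here is a plain PyTorch sketch of an equivalent head applied to the deepest encoder feature map. It is an illustrative reimplementation, not SMP's own class; the 512 input channels assume a resnet34 encoder.

```python
import torch
from torch import nn

# Illustrative stand-in for the auxiliary head configured by aux_params above.
classification_head = nn.Sequential(
    nn.AdaptiveAvgPool2d(1),  # pooling='avg': (N, C, H, W) -> (N, C, 1, 1)
    nn.Flatten(),             # -> (N, C)
    nn.Dropout(p=0.5),        # dropout=0.5 (inactive in eval mode)
    nn.Linear(512, 4),        # classes=4; 512 = resnet34's deepest channel count
    nn.Sigmoid(),             # activation='sigmoid'
)

deepest_features = torch.rand(2, 512, 8, 8)  # e.g. encoder output for two images
classification_head.eval()
labels = classification_head(deepest_features)
print(labels.shape)  # torch.Size([2, 4])
```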
Depth

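The encoder_depth parameter specifies the number of downsampling operations in the encoder, so you can make your model lighter by specifying a smaller depth.

# Unet needs one decoder_channels entry per decoder block, i.e. encoder_depth entries
model = smp.Unet('resnet34', encoder_depth=4, decoder_channels=(256, 128, 64, 32))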
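Because each downsampling step halves the spatial size, an input whose height or width is not a multiple of 2 ** encoder_depth (32 for the default depth of 5) produces skip connections that no longer line up, and the decoder's torch.cat fails with a size mismatch. A common workaround is to pad the input. `pad_to_multiple` is a hypothetical helper, and zero padding at the bottom and right is just one choice.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, encoder_depth: int = 5) -> torch.Tensor:
    """Hypothetical helper: zero-pad H and W up to a multiple of 2 ** encoder_depth."""
    divisor = 2 ** encoder_depth
    h, w = x.shape[-2:]
    pad_h = (divisor - h % divisor) % divisor
    pad_w = (divisor - w % divisor) % divisor
    # F.pad takes (left, right, top, bottom) for the last two dimensions.
    return F.pad(x, (0, pad_w, 0, pad_h))


x = torch.rand(1, 3, 100, 100)
print(pad_to_multiple(x, encoder_depth=5).shape)  # torch.Size([1, 3, 128, 128])
print(pad_to_multiple(x, encoder_depth=4).shape)  # torch.Size([1, 3, 112, 112])
```

After inference, crop the prediction back to the original size, e.g. mask[..., :100, :100].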

🛠 Installation

PyPI version:

$ pip install segmentation-models-pytorch

Latest version from source:

$ pip install git+https://github.com/qubvel/segmentation_models.pytorch

πŸ† Competitions won with the library

The Segmentation Models package is widely used in image segmentation competitions. Here you can find competitions, the names of the winners and links to their solutions.

🤝 Contributing

Run test
$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev pytest -p no:cacheprovider
Generate table
$ docker build -f docker/Dockerfile.dev -t smp:dev . && docker run --rm smp:dev python misc/generate_table.py

📝 Citing

@misc{Yakubovskiy:2019,
  Author = {Pavel Yakubovskiy},
  Title = {Segmentation Models Pytorch},
  Year = {2020},
  Publisher = {GitHub},
  Journal = {GitHub repository},
  Howpublished = {\url{https://github.com/qubvel/segmentation_models.pytorch}}
}

🛡️ License

The project is distributed under the MIT License.

Comments
  • RuntimeError: Error(s) in loading state_dict for Unet

    I get the error below when I try to load the trained weight file of my se_resnext50 Unet for ensembling:

    RuntimeError: Error(s) in loading state_dict for Unet: Missing key(s) in state_dict: "decoder.blocks.0.conv1.0.weight", "decoder.blocks.0.conv1.1.weight", "decoder.blocks.0.conv1.1.bias", "decoder.blocks.0.conv1.1.running_mean", "decoder.blocks.0.conv1.1.running_var", "decoder.blocks.0.conv2.0.weight", "decoder.blocks.0.conv2.1.weight", "decoder.blocks.0.conv2.1.bias", "decoder.blocks.0.conv2.1.running_mean", "decoder.blocks.0.conv2.1.running_var", "decoder.blocks.1.conv1.0.weight", "decoder.blocks.1.conv1.1.weight", "decoder.blocks.1.conv1.1.bias", "decoder.blocks.1.conv1.1.running_mean", "decoder.blocks.1.conv1.1.running_var", "decoder.blocks.1.conv2.0.weight", "decoder.blocks.1.conv2.1.weight", "decoder.blocks.1.conv2.1.bias", "decoder.blocks.1.conv2.1.running_mean", "decoder.blocks.1.conv2.1.running_var", "decoder.blocks.2.conv1.0.weight", "decoder.blocks.2.conv1.1.weight", "decoder.blocks.2.conv1.1.bias", "decoder.blocks.2.conv1.1.running_mean", "decoder.blocks.2.conv1.1.running_var", "decoder.blocks.2.conv2.0.weight", "decoder.blocks.2.conv2.1.weight", "decoder.blocks.2.conv2.1.bias", "decoder.blocks.2.conv2.1.running_mean", "decoder.blocks.2.conv2.1.running_var", "decoder.blocks.3.conv1.0.weight", "decoder.blocks.3.conv1.1.weight", "decoder.blocks.3.conv1.1.bias", "decoder.blocks.3.conv1.1.running_mean", "decoder.blocks.3.conv1.1.running_var", "decoder.blocks.3.conv2.0.weight", "decoder.blocks.3.conv2.1.weight", "decoder.blocks.3.conv2.1.bias", "decoder.blocks.3.conv2.1.running_mean", "decoder.blocks.3.conv2.1.running_var", "decoder.blocks.4.conv1.0.weight", "decoder.blocks.4.conv1.1.weight", "decoder.blocks.4.conv1.1.bias", "decoder.blocks.4.conv1.1.running_mean", "decoder.blocks.4.conv1.1.running_var", "decoder.blocks.4.conv2.0.weight", "decoder.blocks.4.conv2.1.weight", "decoder.blocks.4.conv2.1.bias", "decoder.blocks.4.conv2.1.running_mean", "decoder.blocks.4.conv2.1.running_var", "segmentation_head.0.weight", "segmentation_head.0.bias". 
Unexpected key(s) in state_dict: "decoder.layer1.block.0.block.0.weight", "decoder.layer1.block.0.block.1.weight", "decoder.layer1.block.0.block.1.bias", "decoder.layer1.block.0.block.1.running_mean", "decoder.layer1.block.0.block.1.running_var", "decoder.layer1.block.0.block.1.num_batches_tracked", "decoder.layer1.block.1.block.0.weight", "decoder.layer1.block.1.block.1.weight", "decoder.layer1.block.1.block.1.bias", "decoder.layer1.block.1.block.1.running_mean", "decoder.layer1.block.1.block.1.running_var", "decoder.layer1.block.1.block.1.num_batches_tracked", "decoder.layer2.block.0.block.0.weight", "decoder.layer2.block.0.block.1.weight", "decoder.layer2.block.0.block.1.bias", "decoder.layer2.block.0.block.1.running_mean", "decoder.layer2.block.0.block.1.running_var", "decoder.layer2.block.0.block.1.num_batches_tracked", "decoder.layer2.block.1.block.0.weight", "decoder.layer2.block.1.block.1.weight", "decoder.layer2.block.1.block.1.bias", "decoder.layer2.block.1.block.1.running_mean", "decoder.layer2.block.1.block.1.running_var", "decoder.layer2.block.1.block.1.num_batches_tracked", "decoder.layer3.block.0.block.0.weight", "decoder.layer3.block.0.block.1.weight", "decoder.layer3.block.0.block.1.bias", "decoder.layer3.block.0.block.1.running_mean", "decoder.layer3.block.0.block.1.running_var", "decoder.layer3.block.0.block.1.num_batches_tracked", "decoder.layer3.block.1.block.0.weight", "decoder.layer3.block.1.block.1.weight", "decoder.layer3.block.1.block.1.bias", "decoder.layer3.block.1.block.1.running_mean", "decoder.layer3.block.1.block.1.running_var", "decoder.layer3.block.1.block.1.num_batches_tracked", "decoder.layer4.block.0.block.0.weight", "decoder.layer4.block.0.block.1.weight", "decoder.layer4.block.0.block.1.bias", "decoder.layer4.block.0.block.1.running_mean", "decoder.layer4.block.0.block.1.running_var", "decoder.layer4.block.0.block.1.num_batches_tracked", "decoder.layer4.block.1.block.0.weight", "decoder.layer4.block.1.block.1.weight", 
"decoder.layer4.block.1.block.1.bias", "decoder.layer4.block.1.block.1.running_mean", "decoder.layer4.block.1.block.1.running_var", "decoder.layer4.block.1.block.1.num_batches_tracked", "decoder.layer5.block.0.block.0.weight", "decoder.layer5.block.0.block.1.weight", "decoder.layer5.block.0.block.1.bias", "decoder.layer5.block.0.block.1.running_mean", "decoder.layer5.block.0.block.1.running_var", "decoder.layer5.block.0.block.1.num_batches_tracked", "decoder.layer5.block.1.block.0.weight", "decoder.layer5.block.1.block.1.weight", "decoder.layer5.block.1.block.1.bias", "decoder.layer5.block.1.block.1.running_mean", "decoder.layer5.block.1.block.1.running_var", "decoder.layer5.block.1.block.1.num_batches_tracked", "decoder.final_conv.weight", "decoder.final_conv.bias".

    Stale 
    opened by mobassir94 28
  • cannot import name 'container_abcs' from 'torch._six'

    I have been encountering this since today: 19-Jun-2021

    ImportError Traceback (most recent call last)
    in ()
          4 get_ipython().system('pip install -U segmentation-models-pytorch')
          5
    ----> 6 import segmentation_models_pytorch as smp
          7
          8

    11 frames
    /usr/local/lib/python3.7/dist-packages/timm/models/layers/helpers.py in ()
          4 """
          5 from itertools import repeat
    ----> 6 from torch._six import container_abcs
          7
          8

    ImportError: cannot import name 'container_abcs' from 'torch._six' (/usr/local/lib/python3.7/dist-packages/torch/_six.py)

    opened by lifeischaotic 27
  • How to use metrics for multi-class binary mask target and multi-class multi-channel output?

    I saw in the documentation that the metrics for multilabel prediction require (batch, num_class, height, width). But I have a single-channel multi-class mask as the target, where each pixel is labeled with its class.

    How do I use this for that scenario?

    Also, this seems to be computing it per batch. How do I do it per epoch?

    import torch
    import segmentation_models_pytorch as smp

    # let's assume we have multilabel prediction for 3 classes
    output = torch.rand([10, 3, 256, 256])
    target = torch.rand([10, 3, 256, 256]).round().long()

    # first compute statistics for true positives, false positives, false negative and
    # true negative "pixels"
    tp, fp, fn, tn = smp.metrics.get_stats(output, target, mode='multilabel', threshold=0.5)

    # then compute metrics with required reduction (see metric docs)
    iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="micro")
    f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro")
    f2_score = smp.metrics.fbeta_score(tp, fp, fn, tn, beta=2, reduction="micro")
    accuracy = smp.metrics.accuracy(tp, fp, fn, tn, reduction="macro")
    recall = smp.metrics.recall(tp, fp, fn, tn, reduction="micro-imagewise")
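For a single-channel target holding class indices, one approach is to convert the network output to class indices too (argmax over the channel dimension) and accumulate true positives, false positives and false negatives per class over the whole epoch, computing IoU only at the end. The sketch below does this in plain PyTorch; `ClasswiseIoU` is a hypothetical helper, not part of SMP.

```python
import torch
import torch.nn.functional as F


class ClasswiseIoU:
    """Hypothetical accumulator of per-class IoU over many batches (e.g. one epoch)."""

    def __init__(self, num_classes: int):
        self.num_classes = num_classes
        self.tp = torch.zeros(num_classes, dtype=torch.long)
        self.fp = torch.zeros(num_classes, dtype=torch.long)
        self.fn = torch.zeros(num_classes, dtype=torch.long)

    def update(self, logits: torch.Tensor, target: torch.Tensor) -> None:
        # logits: (N, C, H, W) raw scores; target: (N, H, W) integer class indices.
        pred = logits.argmax(dim=1)
        for c in range(self.num_classes):
            is_pred, is_true = pred == c, target == c
            self.tp[c] += (is_pred & is_true).sum()
            self.fp[c] += (is_pred & ~is_true).sum()
            self.fn[c] += (~is_pred & is_true).sum()

    def compute(self) -> torch.Tensor:
        denom = (self.tp + self.fp + self.fn).float()
        iou = self.tp.float() / denom.clamp(min=1)
        # Convention: a class absent from both predictions and targets scores 1.0.
        return torch.where(denom > 0, iou, torch.ones_like(iou))


metric = ClasswiseIoU(num_classes=3)
for _ in range(2):  # stand-in for the batches of one epoch
    target = torch.randint(0, 3, (4, 16, 16))
    logits = F.one_hot(target, num_classes=3).permute(0, 3, 1, 2).float()  # perfect scores
    metric.update(logits, target)
print(metric.compute())  # tensor([1., 1., 1.])
```

Create a fresh instance at the start of each epoch so the counts do not carry over.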
    
    Stale 
    opened by sarmientoj24 21
  • Add Apple's MobileOne encoder

    Hello,

    I added support for Apple's MobileOne encoder.

    Paper: Link

    There were very few changes I had to make to their official github repo: Link

    It works with all decoders and has impressive inference time for images with 256x256:

    | Encoder-Decoder | Inference time in vanilla torch |
    | --- | --- |
    | mobileone_s1_pspnet_256 | 0.0313718318939209 |
    | mobileone_s0_pan_256 | 0.03421592712402344 |
    | mobileone_s2_pspnet_256 | 0.036206960678100586 |
    | mobileone_s3_pspnet_256 | 0.04711484909057617 |
    | mobileone_s1_pan_256 | 0.05329489707946777 |
    | mobileone_s0_linknet_256 | 0.05789995193481445 |
    | mobileone_s0_deeplabv3plus_256 | 0.058853864669799805 |
    | mobileone_s0_fpn_256 | 0.07664108276367188 |
    | mobileone_s4_pspnet_256 | 0.0768282413482666 |
    | mobileone_s1_deeplabv3plus_256 | 0.07886672019958496 |
    | mobileone_s2_pan_256 | 0.07946181297302246 |
    | mobileone_s3_pan_256 | 0.09101414680480957 |
    | mobileone_s1_fpn_256 | 0.09615683555603027 |
    | mobileone_s1_linknet_256 | 0.09956574440002441 |
    | mobileone_s2_fpn_256 | 0.11291790008544922 |
    | mobileone_s0_unet_256 | 0.11676502227783203 |
    | mobileone_s2_linknet_256 | 0.12518310546875 |
    | mobileone_s3_deeplabv3plus_256 | 0.12642478942871094 |
    | mobileone_s2_deeplabv3plus_256 | 0.1289658546447754 |
    | mobileone_s3_fpn_256 | 0.1370537281036377 |
    | mobileone_s4_pan_256 | 0.14015984535217285 |
    | mobileone_s1_unet_256 | 0.15249204635620117 |
    | mobileone_s3_linknet_256 | 0.15824413299560547 |
    | mobileone_s4_deeplabv3plus_256 | 0.16476082801818848 |
    | mobileone_s0_manet_256 | 0.17203474044799805 |
    | mobileone_s2_unet_256 | 0.17334604263305664 |
    | mobileone_s4_fpn_256 | 0.182358980178833 |
    | mobileone_s3_unet_256 | 0.20330286026000977 |
    | mobileone_s4_linknet_256 | 0.21462082862854004 |
    | mobileone_s0_deeplabv3_256 | 0.22992897033691406 |
    | mobileone_s4_unet_256 | 0.24337363243103027 |
    | mobileone_s0_unetplusplus_256 | 0.29451799392700195 |
    | mobileone_s1_deeplabv3_256 | 0.31217503547668457 |
    | mobileone_s1_manet_256 | 0.3140380382537842 |
    | mobileone_s1_unetplusplus_256 | 0.5090749263763428 |
    | mobileone_s2_deeplabv3_256 | 0.5372707843780518 |
    | mobileone_s3_deeplabv3_256 | 0.5489542484283447 |
    | mobileone_s2_unetplusplus_256 | 0.5728631019592285 |
    | mobileone_s4_deeplabv3_256 | 0.638185977935791 |
    | mobileone_s2_manet_256 | 0.6446411609649658 |
    | mobileone_s3_manet_256 | 0.6838269233703613 |
    | mobileone_s3_unetplusplus_256 | 0.6991360187530518 |
    | mobileone_s4_manet_256 | 0.748121976852417 |
    | mobileone_s4_unetplusplus_256 | 0.9898359775543213 |

    opened by kevinpl07 19
  • Class weights for Losses

    Hi, love using this library.

    I have encountered a problem: my datasets are very imbalanced. They have multiple classes, but these classes take up less than 2% of the image area (they are mainly small objects); the rest is background, and it seems that Unet fails to predict them accurately.

    Using your segmentation_models for Tensorflow library I was able to use class weights for losses and it increased model prediction accuracy.

    Is it possible to use class weights with this library? Is there a code snippet?

    Best Regards, Augustas
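Class weights do not have to come from SMP itself: PyTorch's nn.CrossEntropyLoss accepts a per-class `weight` tensor and works on raw logits of shape (N, C, H, W) with an integer mask of shape (N, H, W), which is what an SMP model with activation=None produces. The sketch below derives weights from inverse pixel frequency; that weighting scheme is one assumption among several, and `inverse_frequency_weights` is a hypothetical helper.

```python
import torch
from torch import nn


def inverse_frequency_weights(masks: torch.Tensor, num_classes: int) -> torch.Tensor:
    """Hypothetical helper: weight each class by 1 / (its share of labelled pixels)."""
    counts = torch.bincount(masks.flatten(), minlength=num_classes).float()
    freq = counts / counts.sum()
    weights = 1.0 / freq.clamp(min=1e-6)           # rare classes get large weights
    return weights / weights.sum() * num_classes   # normalize so the weights average to 1


num_classes = 3
masks = torch.zeros(2, 32, 32, dtype=torch.long)  # mostly background (class 0)
masks[:, :4, :4] = 1                              # small objects of class 1
masks[:, -2:, -2:] = 2                            # even smaller objects of class 2

weights = inverse_frequency_weights(masks, num_classes)
criterion = nn.CrossEntropyLoss(weight=weights)

logits = torch.randn(2, num_classes, 32, 32)  # stand-in for model(x) with activation=None
loss = criterion(logits, masks)
print(weights, loss.item())
```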

    opened by augasur 19
  • Feature: support `timm` features_only functionality

    I've noticed more and more timm backbones being added here, which is great, but a lot of the effort currently duplicates features of timm, i.e. tracking channel numbers, modifying the networks, etc.

    timm has a features_only arg in the model factory that returns a model set up as a backbone to produce pyramid features. It has a .feature_info attribute you can query to find out the number of channels of each output, the approximate reduction factor, etc.

    I've adapted the unet and deeplab impl here in the past to use this successfully, although it was a quick hack-and-train job, nothing to serve as a clean example.

    If this were supported, any timm model (vit excluded right now) could be used as a backbone in a generic fashion, just by passing the model name string to the creation fn, possibly with a small config mapping of model types to index specifications (some models have slightly different out_indices alignment to strides if they happen to be a stride 64 model, or don't have a stride=2 feature, etc). All tap points are the latest possible point for a given feature map stride. Some, but not all, of the timm backbones also support an output_stride= arg that will dilate the blocks appropriately for network strides of 8 or 16.

    Some references:

    • https://rwightman.github.io/pytorch-image-models/feature_extraction/#multi-scale-feature-maps-feature-pyramid
    • https://github.com/rwightman/efficientdet-pytorch/blob/92bb66fd0cf91d0e23fe8b10cba97e2f0bb9884f/effdet/efficientdet.py#L554-L569

    For most of the models, the features are extracted by flattening part of the backbone model via a wrapper. A few models where the feature taps are embedded deep within the model use hooks, which causes some issues with torchscript, but that will likely be fixed soon in PyTorch.

    opened by rwightman 18
  • How to modify the sample code for multiple classifications. I have modified it according to the readme file, but the result after training is a single classification, and the masks of other categories are empty.

    Thank you very much for any help, your code is so cool! Hi, I am using the segmentation code from the example. With my own dataset I can segment individual categories perfectly. But when I tried to segment images with multiple categories, only one category is segmented. My images include three categories and a background. What I have done: I set ACTIVATION = 'softmax2d' and changed the number of classes to 4 (including background). The current result is that training outputs 4 masks, but only one class and the background class contain anything; the other two classes are empty. Thank you again!

    Stale 
    opened by siyangbing 17
  • update diceloss

    In the master branch, DiceLoss is implemented so that the loss is computed over all class masks together, while it should be the mean of the per-class Dice losses. As a result, multi-class segmentation does not work well. This update should correct the problem.

    Stale 
    opened by julienguegan 13
  • Conversion PyTorch => ONNX => TensorRT

    Hi,

    I'm trying to convert a segmentation model (ENCODER = efficientnet-b2, DECODER = FPN) to ONNX and afterwards to TensorRT (TRT). Converting to ONNX seems to work but I can't get the conversion to TRT right. I tried the 'torch2trt' library but couldn't succeed...

    Does anyone have experience with this?

    Running: efficientnet-pytorch==0.6.3 onnx==1.7.0 segmentation-models-pytorch==0.1.0 torch==1.6.0 torch2trt==0.1.0 torchvision==0.7.0

    Thanks in advance,

    Michiel

    Stale 
    opened by michieljanssen97 13
  • Size mismatch occurs in UNet model at 5th stage

    I used the SMP library to create a UNet model with the following configurations: model = smp.Unet(encoder_name='resnet50', encoder_weights='imagenet', in_channels=3, classes=30)

    However, I have also tried with other encoders (including the default resnet34) and the error seems to appear for every encoder that I choose. I am training it on a custom dataset of which the dimensions of the images are: w=320, h=192

    My code runs fine until one of the final steps in the decoder block. The error traces back to smp/unet/decoder.py. When I'm running a training epoch, the error occurs in def forward(self, x, skip=None) of decoder.py

    def forward(self, x, skip=None):
        x = F.interpolate(x, scale_factor=2, mode="nearest")
        if skip is not None:
            x = torch.cat([x, skip], dim=1)
            x = self.attention1(x)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.attention2(x)
        return x
    

    For the first steps, everything runs fine and the dimensions of 'x' match with 'skip'. Below you can find a list of the dimensions of both x and skip as I go through the decoder:

    STEP 1
    x.shape
    Out[1]: torch.Size([1, 2048, 14, 20])
    skip.shape
    Out[2]: torch.Size([1, 1024, 14, 20])
    STEP 2
    x.shape
    Out[3]: torch.Size([1, 256, 28, 40])
    skip.shape
    Out[4]: torch.Size([1, 512, 28, 40])
    STEP 3
    x.shape
    Out[5]: torch.Size([1, 128, 56, 80])
    skip.shape
    Out[6]: torch.Size([1, 256, 55, 80])
    STEP 4
    x.shape
    Out[7]: torch.Size([1, 128, 56, 80])
    skip.shape
    Out[8]: torch.Size([1, 256, 55, 80])
    STEP 5
    x.shape
    Out[9]: torch.Size([1, 3, 192, 320])
    skip.shape
    Out[10]: torch.Size([1, 256, 55, 80])
    

    Around step 3, a mismatch between the tensors starts occurring, which causes the error. The traceback is shown in the block below. What I find weird is that I have used the exact same codebase with a different dataset that only had 6 classes, and there was no issue in that case. I am also unsure where this is happening, as I cannot find the root cause.

    Traceback (most recent call last):
      File "/Users/fc/Desktop/ct/segmentation_code/main.py", line 141, in
        trainer.train()
      File "/Users/fc/Desktop/ct/segmentation_code/ops/trainer.py", line 44, in train
        self.train_logs = self.train_epoch.run(self.trainloader)
      File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/utils/train.py", line 47, in run
        loss, y_pred = self.batch_update(x, y)
      File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/utils/train.py", line 87, in batch_update
        prediction = self.model.forward(x)
      File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/base/model.py", line 16, in forward
        decoder_output = self.decoder(*features)
      File "/Users/fc/miniconda3/envs/ct/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/unet/decoder.py", line 119, in forward
        x = decoder_block(x, skip)
      File "/Users/fc/miniconda3/envs/ct/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/Users/fc/.local/lib/python3.8/site-packages/segmentation_models_pytorch/unet/decoder.py", line 38, in forward
        x = torch.cat([x, skip], dim=1)
    RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 56 but got size 55 for tensor number 1 in the list.

    Stale 
    opened by Fritskee 12
  • Unet decoder upsampling

    Hi

    I am using a Unet model with the encoder set to 'resnet34' and the pretrained weights are imagenet.

    When I look at the model, I do not see where the upsampling occurs. The convolutions on the encoder side are there (although the downsampling seemingly occurs after the intended layer, e.g. downsampling from layer 1 to layer 2 only occurs after layer 2), but I do not see where the upsampling takes place on the decoder side.

    I also do not see the centre block convolutions.

    Could someone please explain where the upsampling occurs?

    My model for reference:

    resnet34 Unet model.txt
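In the Unet decoder forward quoted in the size-mismatch issue above, upsampling is a call to F.interpolate(x, scale_factor=2, mode="nearest") inside forward(). Because it is a function call rather than a submodule, it does not show up when the model is printed. The toy block below is a simplified illustration, not SMP's actual DecoderBlock, that shows the same effect.

```python
import torch
import torch.nn.functional as F
from torch import nn


class ToyDecoderBlock(nn.Module):
    """Simplified illustration: upsample x2, concatenate the skip, then convolve."""

    def __init__(self, in_channels: int, skip_channels: int, out_channels: int):
        super().__init__()
        self.conv = nn.Conv2d(in_channels + skip_channels, out_channels, kernel_size=3, padding=1)

    def forward(self, x, skip):
        x = F.interpolate(x, scale_factor=2, mode="nearest")  # the upsampling happens here
        x = torch.cat([x, skip], dim=1)
        return self.conv(x)


block = ToyDecoderBlock(in_channels=512, skip_channels=256, out_channels=256)
print(block)  # lists only the Conv2d; the interpolation is not a submodule
out = block(torch.rand(1, 512, 4, 4), torch.rand(1, 256, 8, 8))
print(out.shape)  # torch.Size([1, 256, 8, 8])
```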

    opened by DamienLopez1 12
  • Recommended way to load pretrained weights for encoder from checkpoint file.

    I have a pretrained model checkpoint that I would like to use as the encoder weights for a segmentation model, and then train this segmentation model on a new task. It looks like the only options for the encoder_weights argument are strings naming the pretrained weights listed in the table. Is there a workaround to, for example, load some other pretrained resnet50 backbone from a checkpoint file as the encoder weights of an smp model?

    opened by nilsleh 0
  • MixVisionTransformer in combination with PAN fails with "encoder does not support dilated mode"

    import segmentation_models_pytorch as smp
    
    smp.PAN(encoder_name="mit_b0")
    

    raises the exception:

    ValueError: MixVisionTransformer encoder does not support dilated mode
    

    Since the default PAN uses dilation, is this config incompatible at the moment?

    If we use a PAN configuration that does not use dilation, the error of course does not appear:

    smp.PAN(encoder_name="mit_b0", encoder_output_stride=32)
    

    I have not yet tested whether an output stride of 32 still delivers comparable results. My guess would be that the default stride of 16 preserves a lot more information, which might be beneficial for better performance.

    Is there any way to get it to work with dilation?

    opened by Daniel451 2
  • How to compute metrics for each class in multi class segmentation

    I would like to compute the metrics individually for each class, i.e. get a (1xC) vector as output, where C is the number of classes. I was trying the following, but it throws an error:

    import torch
    import segmentation_models_pytorch as smp

    output = torch.rand([10, 3, 256, 256])
    target = torch.rand([10, 1, 256, 256]).round().long()

    # first compute statistics for true positives, false positives, false negative and
    # true negative "pixels"
    tp, fp, fn, tn = smp.metrics.get_stats(output, target, mode='multiclass', num_classes=3)

    # then compute metrics with required reduction (see metric docs)
    iou_score = smp.metrics.iou_score(tp, fp, fn, tn, reduction="macro-imagewise")
    f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="macro-imagewise")
    false_negatives = smp.metrics.false_negative_rate(tp, fp, fn, tn, reduction=None)
    recall = smp.metrics.recall(tp, fp, fn, tn, reduction=None)
    

    The error:

    ValueError: For ``multiclass`` mode ``target`` should be one of the integer types, got torch.float32.
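Judging from the snippet itself, the float32 tensor in the error must be output, because target has already been cast with .long(). In multiclass mode, get_stats expects integer class indices for both tensors, and the mode name it checks for is "multiclass" (no space). A self-contained alternative that yields the requested length-C vector is to take the argmax over the class dimension and count per-class overlaps directly. The following is a sketch in plain PyTorch. `per_class_iou` is a hypothetical helper, not part of smp.metrics:

```python
import torch


def per_class_iou(pred: torch.Tensor, target: torch.Tensor, num_classes: int) -> torch.Tensor:
    """IoU for each class, pooled over every pixel of every image in the batch.

    pred, target: integer tensors of class indices with identical shapes, e.g. (N, H, W).
    Returns a float tensor of shape (num_classes,). Classes absent from both
    pred and target get NaN rather than a misleading 0 or 1.
    """
    if pred.shape != target.shape:
        raise ValueError(f"shape mismatch: {tuple(pred.shape)} vs {tuple(target.shape)}")
    classes = torch.arange(num_classes, device=pred.device).view(-1, 1)  # (C, 1)
    pred_masks = pred.reshape(1, -1) == classes      # (C, P) boolean, one row per class
    target_masks = target.reshape(1, -1) == classes  # (C, P)
    tp = (pred_masks & target_masks).sum(dim=1)
    fp = (pred_masks & ~target_masks).sum(dim=1)
    fn = (~pred_masks & target_masks).sum(dim=1)
    union = (tp + fp + fn).float()
    return tp.float() / union  # 0 / 0 gives NaN for absent classes


# Shapes from the question: logits (N, C, H, W), target (N, 1, H, W)
logits = torch.rand(10, 3, 256, 256)
target = torch.randint(0, 3, (10, 1, 256, 256))
pred = logits.argmax(dim=1)  # (N, H, W) class indices, dtype int64
iou = per_class_iou(pred, target.squeeze(1), num_classes=3)
print(iou.shape)  # torch.Size([3])
```

Staying within SMP, the analogous route would be to pass `pred` and `target.squeeze(1)` to `smp.metrics.get_stats(..., mode="multiclass", num_classes=3)`, which returns per-image, per-class counts of shape (N, C). Summing those over dim 0 before calling, e.g., `smp.metrics.iou_score(..., reduction=None)` should give one score per class. Treat this as a hedged pointer and check it against the metrics docs for your version.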
    
    opened by santurini 1
  • Softmax activation function throws deprecation warning

    When defining an smp model in __init__() as:

    self.base = smp.Unet(encoder_name='resnet50', encoder_weights='imagenet',
                         in_channels=3, classes=7,
                         activation='softmax')
    

    This throws the following warning once the model is called:

    ~/anaconda3/envs/geo/lib/python3.7/site-packages/segmentation_models_pytorch/base/modules.py:116: UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument. return self.activation(x)

    As the softmax is selected by passing 'softmax' as an argument, I'm not sure where or how to include dim, as the warning suggests. Many thanks!

    See also this closed issue (closed without a stated resolution). Originally posted by @vdplasthijs in https://github.com/qubvel/segmentation_models.pytorch/issues/169#issuecomment-1334066128
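A minimal sketch, assuming the goal is simply to name the class dimension explicitly: build the model with activation=None and apply nn.Softmax(dim=1) to its (N, C, H, W) output yourself. For a 4D input, PyTorch's deprecated implicit choice is also dim=1, so the warning does not by itself indicate wrong probabilities here. Some SMP versions additionally accept activation="softmax2d", which applies softmax over dim=1. Treat that as an assumption to check against base/modules.py in your installed version. Plain PyTorch:

```python
import warnings

import torch
import torch.nn as nn

# Stand-in for a segmentation model's raw output: (N, classes, H, W)
logits = torch.randn(2, 7, 8, 8)

# Softmax over the class dimension, stated explicitly so no
# implicit-dimension deprecation warning is emitted
softmax_over_classes = nn.Softmax(dim=1)

with warnings.catch_warnings():
    warnings.simplefilter("error")  # escalate any warning to an exception
    probs = softmax_over_classes(logits)

# Each pixel now carries a probability distribution over the 7 classes
print(torch.allclose(probs.sum(dim=1), torch.ones(2, 8, 8)))  # True
```

In a module like the one in the question, this could look like `self.base = nn.Sequential(smp.Unet(..., activation=None), nn.Softmax(dim=1))`, provided the model returns a single tensor (no auxiliary classification head). Note that nn.CrossEntropyLoss expects raw logits, so when training with it the softmax is usually applied only at inference time.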

    opened by vdplasthijs 0
  • AttributeError: module 'segmentation_models_pytorch' has no attribute 'utils'

    I was following the cars example, but I keep getting this error. I installed the module with pip in both suggested ways and checked that utils is present in the repo, yet the error persists. What am I doing wrong?

    This is the code:

    !pip install git+https://github.com/qubvel/segmentation_models.pytorch
    # !pip install -q segmentation-models-pytorch
    
    train_epoch = smp.utils.train.TrainEpoch(
        model, 
        loss=loss, 
        metrics=metrics, 
        optimizer=optimizer,
        device=DEVICE,
        verbose=True,
    )
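The v0.3.0 release notes further down say that the utils module now has to be imported explicitly, i.e. `from segmentation_models_pytorch import utils`. The general Python rule behind the AttributeError is that importing a package does not make its submodules available as attributes unless the package's __init__ imports them. The following stdlib-only sketch uses xml and xml.dom as stand-ins (nothing SMP-specific) and runs fresh interpreters, so no earlier import can mask the effect. `submodule_visible` is a hypothetical helper:

```python
import subprocess
import sys


def submodule_visible(package: str, submodule: str, import_explicitly: bool) -> bool:
    """Report whether `package.submodule` is reachable as an attribute in a fresh interpreter."""
    if import_explicitly:
        setup = f"import {package}.{submodule}"  # explicit submodule import
    else:
        setup = f"import {package}"  # package only
    code = f"{setup}; print(hasattr({package}, {submodule!r}))"
    # -S skips site-packages startup hooks, so nothing else can pre-import the submodule
    result = subprocess.run(
        [sys.executable, "-S", "-c", code], capture_output=True, text=True, check=True
    )
    return result.stdout.strip() == "True"


print(submodule_visible("xml", "dom", import_explicitly=False))  # False
print(submodule_visible("xml", "dom", import_explicitly=True))   # True
```

By analogy, adding `from segmentation_models_pytorch import utils` and then calling `utils.train.TrainEpoch(...)` should work on versions that still ship the module.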
    
    opened by santurini 2
Releases(v0.3.1)
  • v0.3.1(Nov 30, 2022)

  • v0.3.0(Jul 29, 2022)

    Updates

    • Added smp.metrics module with different metrics based on the confusion matrix, see docs
    • Added a new notebook with a training example using pytorch-lightning
    • Improved handling of the incorrect input image size error (checking that image height and width are divisible by 2^n)
    • Codebase refactoring and style checks (black, flake8)
    • Minor typo fixes and bug fixes
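A common way to satisfy this input-size check is to zero-pad the input up to the next multiple of the output stride (typically 32) and crop the prediction back to the original size afterwards. This is sketched here as an assumption, not as an SMP feature. `pad_to_multiple` is a hypothetical helper:

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, multiple: int = 32):
    """Zero-pad the bottom/right of an (N, C, H, W) tensor so H and W are multiples of `multiple`.

    Returns the padded tensor and the original (H, W), so predictions can be cropped back.
    """
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    # F.pad lists padding from the last dimension backwards: (left, right, top, bottom)
    return F.pad(x, (0, pad_w, 0, pad_h)), (h, w)


image = torch.rand(1, 3, 100, 130)
padded, (h, w) = pad_to_multiple(image)
print(padded.shape)  # torch.Size([1, 3, 128, 160])
# After running the model on `padded`, crop its output back with mask[..., :h, :w]
```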

    Breaking changes

    • The utils module is going to be deprecated; if you still need it, import it explicitly: from segmentation_models_pytorch import utils

    Thanks a lot to all contributors!

  • v0.2.1(Nov 18, 2021)

  • v0.2.0(Jul 5, 2021)

    Updates

    • New architecture: MANet (#310)
    • New encoders from timm: mobilenetv3 (#355) and gernet (#344)
    • New loss functions in smp.losses module (smp.utils.losses will be deprecated in future versions)
    • New pretrained weight initialization for the first convolution if in_channels > 3
    • Updated timm version (0.4.12)
    • Bug fixes and docs improvement
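The release note does not spell out how the first convolution is initialized when in_channels > 3. One plausible approach, sketched here under that assumption and not necessarily SMP's exact code, is to reuse the pretrained RGB filters cyclically for the extra channels and rescale so activations stay in a similar range. `expand_first_conv_weight` is a hypothetical helper:

```python
import torch


def expand_first_conv_weight(weight: torch.Tensor, new_in_channels: int) -> torch.Tensor:
    """Map a pretrained first-conv weight (out, in, kH, kW) to `new_in_channels` inputs."""
    out_ch, old_in, k_h, k_w = weight.shape
    new_weight = torch.empty(
        out_ch, new_in_channels, k_h, k_w, dtype=weight.dtype, device=weight.device
    )
    for i in range(new_in_channels):
        new_weight[:, i] = weight[:, i % old_in]  # channel i reuses pretrained filter i mod old_in
    # Scale down so summing over more input channels does not inflate activations
    return new_weight * (old_in / new_in_channels)


pretrained = torch.randn(64, 3, 7, 7)  # e.g. a ResNet stem convolution
expanded = expand_first_conv_weight(pretrained, 6)
print(expanded.shape)  # torch.Size([64, 6, 7, 7])
```

When the new channel count is a multiple of the old one (6 = 2 x 3 here), the summed filter over input channels equals the original one exactly. For other counts it is only approximately preserved.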

    Thanks to @azkalot1 @JulienMaille @originlake @Kupchanski @loopdigga96 @zurk @nmerty @ludics @Vozf @markson14 and others!

  • v0.1.3(Dec 13, 2020)

    Updates

    • New architecture Unet++ (#279)
    • New encoders RegNet, ResNest, SK-Net, Res2Net (#286)
    • Updated timm version (0.3.2)
    • Improved docstrings and typehints for models
    • Project documentation on https://smp.readthedocs.io

    Thanks to @azkalot1 for the new encoders and architecture!

  • v0.1.2(Sep 28, 2020)

  • v0.1.1(Sep 26, 2020)

    Updates

    • New decoders DeepLabV3, DeepLabV3+, PAN
    • New backbones (encoders) timm-efficientnet*
    • New pretrained weights (ssl, wsl) for resnets
    • New pretrained weights (advprop) for efficientnets

    And some small fixes.

    Thanks @IlyaDobrynin @gavrin-s @lizmisha @suitre77 @thisisiron @phamquiluan and all other contributors!

  • V0.1.0(Dec 9, 2019)

    Updates

    1. New backbones (mobilenet, efficientnet, inception)
    2. depth and in_channels options for all models
    3. Auxiliary classification output

    Note!

    Model architectures have changed; use previous versions for weight compatibility!

  • v0.0.3(Sep 28, 2019)

  • v0.0.2(Sep 19, 2019)

Owner
Pavel Yakubovskiy
Facebook Research 605 Jan 2, 2023
Pretrained SOTA Deep Learning models, callbacks and more for research and production with PyTorch Lightning and PyTorch

Pytorch Lightning 1.4k Jan 1, 2023
(ImageNet pretrained models) The official pytorch implementation of the TPAMI paper "Res2Net: A New Multi-scale Backbone Architecture"

Res2Net The official pytorch implementation of the paper "Res2Net: A New Multi-scale Backbone Architecture" Our paper is accepted by IEEE Transactions o

Res2Net Applications 928 Dec 29, 2022
Pretrained Pytorch face detection (MTCNN) and recognition (InceptionResnet) models

Face Recognition Using Pytorch Python 3.7 3.6 3.5 Status This is a repository for Inception Resnet (V1) models in pytorch, pretrained on VGGFace2 and

Tim Esler 3.3k Jan 4, 2023
Official PyTorch implementation and pretrained models of the paper Self-Supervised Classification Network

Self-Classifier: Self-Supervised Classification Network Official PyTorch implementation and pretrained models of the paper Self-Supervised Classificat

Elad Amrani 24 Dec 21, 2022
Base pretrained models and datasets in pytorch (MNIST, SVHN, CIFAR10, CIFAR100, STL10, AlexNet, VGG16, VGG19, ResNet, Inception, SqueezeNet)

This is a playground for pytorch beginners, which contains predefined models on popular dataset. Currently we support mnist, svhn cifar10, cifar100 st

Aaron Chen 2.4k Dec 28, 2022
Implementation of Squeezenet in pytorch, pretrained models on Cifar 10 data to come

Pytorch Squeeznet Pytorch implementation of Squeezenet model as described in https://arxiv.org/abs/1602.07360 on cifar-10 Data. The definition of Sque

gaurav pathak 86 Oct 28, 2022
Repository providing a wide range of self-supervised pretrained models for computer vision tasks.

Hierarchical Pretraining: Research Repository This is a research repository for reproducing the results from the project "Self-supervised pretraining

Colorado Reed 53 Nov 9, 2022
Pretrained models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet.

Matthias Wright 169 Dec 26, 2022
This project provides an unsupervised framework for mining and tagging quality phrases on text corpora with pretrained language models (KDD'21).

UCPhrase: Unsupervised Context-aware Quality Phrase Tagging To appear on KDD'21...[pdf] This project provides an unsupervised framework for mining and

Xiaotao Gu 146 Dec 22, 2022
Using pretrained language models for biomedical knowledge graph completion.

LMs for biomedical KG completion This repository contains code to run the experiments described in: Scientific Language Models for Biomedical Knowledg

Rahul Nadkarni 41 Nov 30, 2022
Measuring and Improving Consistency in Pretrained Language Models

ParaRel This repository contains the code and data for the paper: Measuring and Improving Consistency in Pretrained Language Models as well as the

Yanai Elazar 26 Dec 2, 2022
Reference implementation of code generation projects from Facebook AI Research. General toolkit to apply machine learning to code, from dataset creation to model training and evaluation. Comes with pretrained models.

This repository is a toolkit to do machine learning for programming languages. It implements tokenization, dataset preprocessing, model training and m

Facebook Research 408 Jan 1, 2023
A library for finding knowledge neurons in pretrained transformer models.

knowledge-neurons An open source repository replicating the 2021 paper Knowledge Neurons in Pretrained Transformers by Dai et al., and extending the t

EleutherAI 96 Dec 21, 2022
VisualGPT: Data-efficient Adaptation of Pretrained Language Models for Image Captioning

VisualGPT Our Paper VisualGPT: Data-efficient Adaptation of Pretrained Language Models for Image Captioning Main Architecture of Our VisualGPT Downloa

Vision CAIR Research Group, KAUST 140 Dec 28, 2022
YOLOv5 🚀 is a family of object detection architectures and models pretrained on the COCO dataset

YOLOv5 🚀 is a family of object detection architectures and models pretrained on the COCO dataset, and represents Ultralytics open-source research int

ι˜Ώζ‰ 73 Dec 16, 2022
Music Source Separation; Train & Eval & Inference pipelines and pretrained models we used for 2021 ISMIR MDX Challenge.

Music Source Separation with Channel-wise Subband Phase Aware ResUnet (CWS-PResUNet) Introduction This repo contains the pretrained Music Source Separ

Lau 100 Dec 25, 2022
The PASS dataset: pretrained models and how to get the data - PASS: Pictures without humAns for Self-Supervised Pretraining

Yuki M. Asano 249 Dec 22, 2022
LWCC: A LightWeight Crowd Counting library for Python that includes several pretrained state-of-the-art models.

LWCC: A LightWeight Crowd Counting library for Python LWCC is a lightweight crowd counting framework for Python. It wraps four state-of-the-art models

Matija Teršek 39 Dec 28, 2022