BiSeNet based on PyTorch

Overview

BiSeNet

BiSeNet based on PyTorch 0.4.1 and Python 3.6

Dataset

Download the CamVid dataset from Google Drive or Baidu Yun (extraction code: 6xw4).

Pretrained model

Download best_dice_loss_miou_0.655.pth from Google Drive or Baidu Yun (extraction code: 6y3e) and put it in ./checkpoints

Demo

python demo.py

Result

(Result images: original image, ground truth, prediction)

Train

python train.py

Use TensorBoard to see the real-time loss and accuracy.
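As a rough sketch of how such curves are typically logged (the repo targets PyTorch 0.4.1, where tensorboardX provides the SummaryWriter; the log directory and tag names below are illustrative, not necessarily the ones used by train.py):

    from tensorboardX import SummaryWriter  # pip install tensorboardX

    # Hypothetical log directory and tag names -- check train.py for the ones actually used.
    writer = SummaryWriter('runs/bisenet_camvid')
    for epoch in range(3):
        # Dummy values standing in for the metrics computed by the training/validation loops.
        writer.add_scalar('loss/train', 1.0 / (epoch + 1), epoch)
        writer.add_scalar('precision/val', 0.80 + 0.01 * epoch, epoch)
        writer.add_scalar('miou/val', 0.50 + 0.02 * epoch, epoch)
    writer.close()

Then launch TensorBoard with tensorboard --logdir runs and open it in a browser.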

(TensorBoard curves: loss on train, pixel precision on val, mIoU on val)

Test

python test.py

Result

class  Bicyclist  Building  Car   Pole  Fence  Pedestrian  Road  Sidewalk  SignSymbol  Sky   Tree  miou
iou    0.61       0.80      0.86  0.35  0.37   0.59        0.88  0.81      0.28        0.91  0.73  0.655

This time I trained the model with dice loss and got a better result than with cross-entropy loss. I did not use any special training strategy, so you can get a much better result than this repo by using task-specific strategies.
This repo is mainly for proving the effectiveness of the model.
I also tried some simplified versions of BiSeNet, but they did not seem to perform very well on the CamVid dataset.
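For reference, here is a minimal sketch of a multi-class soft Dice loss of the kind described above. It is only an illustration and not necessarily identical to the loss used in this repo; it assumes the target is already one-hot encoded with shape [N, C, H, W].

    import torch
    import torch.nn.functional as F

    def soft_dice_loss(logits, target_one_hot, epsilon=1e-5):
        """Multi-class soft Dice loss.

        logits:         raw network output, shape [N, C, H, W]
        target_one_hot: one-hot ground truth, shape [N, C, H, W]
        """
        probs = F.softmax(logits, dim=1)
        dims = (0, 2, 3)  # sum over batch and spatial dims, keep the class dim
        intersection = torch.sum(probs * target_one_hot, dims)
        cardinality = torch.sum(probs + target_one_hot, dims)
        dice_per_class = (2.0 * intersection + epsilon) / (cardinality + epsilon)
        return 1.0 - dice_per_class.mean()

    # Tiny usage example: 2 classes, one 4x4 image
    labels = torch.randint(0, 2, (1, 4, 4)).long()
    target = torch.zeros(1, 2, 4, 4).scatter_(1, labels.unsqueeze(1), 1.0)
    print(soft_dice_loss(torch.randn(1, 2, 4, 4), target))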

Speed

Method     640×320  1280×720  1920×1080
Paper      129.4    47.9      23
This Repo  126.8    53.7      23.6

This table shows the speed comparison between the paper and my implementation; a rough measurement sketch follows the notes below.

  1. The numbers in the first row are the input image resolutions.
  2. The numbers in the second and third rows are FPS.
  3. The results are based on ResNet-18.
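For context, a rough sketch of how FPS numbers like these are commonly measured (illustrative only, not necessarily the benchmarking script used for the table above; measure_fps and its arguments are made-up names):

    import time
    import torch

    def measure_fps(model, height, width, n_iters=100, device='cuda'):
        """Rough FPS measurement with a dummy input of the given resolution."""
        model = model.to(device).eval()
        dummy = torch.randn(1, 3, height, width, device=device)
        with torch.no_grad():
            for _ in range(10):              # warm-up iterations
                model(dummy)
            torch.cuda.synchronize()         # make sure warm-up kernels have finished
            start = time.time()
            for _ in range(n_iters):
                model(dummy)
            torch.cuda.synchronize()         # wait for all kernels before stopping the clock
        return n_iters / (time.time() - start)

    # e.g. measure_fps(model, 720, 1280) for the 1280x720 column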

Future work

  • Finish real-time segmentation with a camera or a pre-loaded video

Reference

Comments
  • loss function

    The loss function in the original paper is composed of two parts, but you only use the output of the feature fusion module to calculate the loss. Also, the loss they use is cross entropy, while here you use binary cross entropy. Is there any reason for these changes? Thanks!
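    For reference, a rough sketch of the joint loss described in the paper: a principal cross-entropy loss on the final output plus weighted auxiliary cross-entropy losses on the two intermediate outputs. The function and variable names below are illustrative and not taken from this repo.

        import torch.nn as nn

        criterion = nn.CrossEntropyLoss()

        def joint_loss(main_out, aux_out1, aux_out2, label, alpha=1.0):
            """Principal loss on the fused output plus weighted auxiliary losses,
            in the spirit of the BiSeNet paper; `label` holds class indices [N, H, W]."""
            loss_p = criterion(main_out, label)
            loss_aux = criterion(aux_out1, label) + criterion(aux_out2, label)
            return loss_p + alpha * loss_aux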

    opened by WellYoungIOE 15
  • I think this version is not very stable.

    When I trained the first time, the training metric (mIoU) was normal, but the second time the mIoU stayed fixed at around 0.12. I thought about it for a long time but didn't solve the problem, so I want to ask whether you have encountered this issue.

    opened by YuGuii 7
  • AttributeError: 'BiSeNet' object has no attribute 'module'

    load model from ./checkpoints/epoch_295.pth ...
    Traceback (most recent call last):
      File "c:\Users\Administrator\Desktop\BiSeNet-master\BiSeNet-master\demo.py", line 80, in <module>
        main(params)
      File "c:\Users\Administrator\Desktop\BiSeNet-master\BiSeNet-master\demo.py", line 60, in main
        model.module.load_state_dict(torch.load(args.checkpoint_path))
      File "C:\Program Files\Python\lib\site-packages\torch\nn\modules\module.py", line 518, in __getattr__
        type(self).__name__, name))
    AttributeError: 'BiSeNet' object has no attribute 'module'

    Where do I put the dataset?
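    For what it's worth, this error usually means the model was not wrapped in nn.DataParallel when the checkpoint was loaded, so there is no .module attribute. Below is a hedged sketch of a load that handles both the wrapped and the plain case; it is illustrative, not the repo's exact demo.py code.

        import torch

        def load_checkpoint(model, checkpoint_path):
            """Load weights whether or not `model` is wrapped in nn.DataParallel."""
            state_dict = torch.load(checkpoint_path, map_location='cpu')
            # Unwrap DataParallel if present, otherwise load into the model directly.
            target = model.module if hasattr(model, 'module') else model
            target.load_state_dict(state_dict)
            return model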

    opened by Eileen2014 6
  • train wrong

    When I train on my data, the following error happens:

    os@os-l3:/disk3t-2/zym/BiSeNet-PyTorch$ python train.py
    epoch 0, lr 0.001000:   0%| | 0/4963 [00:00<?, ?it/s]
    Traceback (most recent call last):
      File "train.py", line 157, in <module>
        main(params)
      File "train.py", line 141, in main
        train(args, model, optimizer, dataloader_train, dataloader_val, csv_path)
      File "train.py", line 56, in train
        output = model(data)
      File "/home/os/.local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/os/.local/lib/python2.7/site-packages/torch/nn/parallel/data_parallel.py", line 121, in forward
        return self.module(*inputs[0], **kwargs[0])
      File "/home/os/.local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
        result = self.forward(*input, **kwargs)
      File "/disk3t-2/zym/BiSeNet-PyTorch/model/build_BiSeNet.py", line 97, in forward
        cx1 = self.attention_refinement_module1(cx1)
      File "/home/os/.local/lib/python2.7/site-packages/torch/nn/modules/module.py", line 477, in __call__
        result = self.forward(*input, **kwargs)
      File "/disk3t-2/zym/BiSeNet-PyTorch/model/build_BiSeNet.py", line 40, in forward
        assert self.in_channels == x.size(1), 'in_channels and out_channels should all be {}'.format(x.size(1))
    AssertionError: in_channels and out_channels should all be 256

    opened by zhangyunming 4
  • Wrong label for Seq05VD_f02610_L.png

    Hi, I converted the labels from RGB mode to 'P' mode based on class_dict.csv, and found that some points/coords have a wrong label, i.e. their color is not among the colors defined in class_dict.csv. Here is how I convert RGB mode to 'P' mode and check the labels:

    import numpy as np
    from PIL import Image
    from utils import get_label_info

    # Build a palette from the class colors defined in class_dict.csv
    csvfile = 'class_dict.csv'
    labelinfo = get_label_info(csvfile)
    palette = []
    for key in labelinfo:
        for item in labelinfo[key][:3]:  # first three entries are R, G, B
            palette.append(item)

    ref_image = Image.new(mode='P', size=(1, 1))
    ref_image.putpalette(palette)

    # Quantize the RGB label image against the reference palette
    img_rgb = Image.open('rgb_mode.png')
    img_p = img_rgb.quantize(palette=ref_image)

    # Check for pixels whose palette index falls outside the 32 defined classes
    p_arr = np.array(img_p)
    rgb_arr = np.array(img_rgb)
    img_p_wrong = p_arr[p_arr > 31]
    img_rgb_wrong = rgb_arr[p_arr > 31]

    I'm not sure how to deal with this image; currently I simply remove it from my training/val/test sets. After converting RGB mode to 'P' mode, I don't need to use one-hot encoding for the labels anymore. It seems that the one-hot encoding & decoding slow down training.

    Also, the code where you calculate accuracy seems weird (https://github.com/ooooverflow/BiSeNet/blob/master/utils.py#L103): pred & label have shape 1xCxHxW, where C is the number of channels, which is 3 in this case, so pred[:, :, h, w] == label[:, :, h, w] compares three channel values rather than checking one pixel's predicted class.
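    A hedged sketch of what I take the suggestion to be: compute per-pixel accuracy on H x W class-index maps (for example after an argmax or the 'P'-mode conversion above) instead of comparing the three RGB channels separately. This is only an illustration, not the repo's utils.py.

        import numpy as np

        def pixel_accuracy(pred_indices, label_indices):
            """Per-pixel accuracy on H x W class-index maps (not RGB images)."""
            pred_indices = np.asarray(pred_indices)
            label_indices = np.asarray(label_indices)
            assert pred_indices.shape == label_indices.shape
            # A pixel counts as correct when its single class index matches.
            return float(np.mean(pred_indices == label_indices))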

    opened by hubutui 3
  • This Code only supports pytorch==0.4.1

    First I met the same error as https://github.com/ooooverflow/BiSeNet/issues/1 using pytorch==0.4.0, and I met another error using pytorch==1.0.0; when I changed the version to 0.4.1 it was solved. But at training time it only works well with resnet101: when I change build_contextpath to "resnet18", I meet the same error as https://github.com/ooooverflow/BiSeNet/issues/3.

    opened by CPFLAME 2
  • train val accuracy is not as high as mentioned. Plus res101 accuracy curve is not stable

    I train with context_path resnet101 and resnet18. First question: for both, the validation accuracy hardly reaches 0.9 and mostly stops at 0.88-0.89. Second question: while training resnet101, the validation accuracy fluctuates a lot after epoch 100 and drops a lot after about epoch 180+. Could you please share your training parameters, like lr, batch size, GPU count, crop_height, crop_width, and any detailed tricks?

    opened by 1093842024 2
  • Problems running demo.py to process an image (Missing key(s))

    RuntimeError: Error(s) in loading state_dict for BiSeNet: Missing key(s) in state_dict: "supervision1.weight", "supervision1.bias", "supervision2.weight", "supervision2.bias".
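    The missing supervision1/supervision2 keys belong to the auxiliary supervision heads, so the checkpoint and the model built in demo.py apparently disagree on whether those heads exist. If the auxiliary heads are not needed at inference time, one possible workaround is a non-strict load (strict=False is a standard load_state_dict argument); the function name below is illustrative, not the repo's code.

        import torch

        def load_ignoring_aux_heads(model, checkpoint_path):
            """Tolerate missing/extra auxiliary supervision keys when loading."""
            state_dict = torch.load(checkpoint_path, map_location='cpu')
            result = model.load_state_dict(state_dict, strict=False)
            # On newer PyTorch versions `result` lists the missing/unexpected keys;
            # check that only the supervision heads appear there.
            print(result)
            return model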

    opened by VguanwenV 1
  • Poor performance

    Hey, I git cloned the repo, downloaded CamVid, changed the path, and trained for 300 epochs, but got poor performance: 0.182, not the 94.1 or 93.2 reported in README.md. Any suggestions?

    Here is the val curve. (image)

    opened by hubutui 1
  • The val stage during training is very slow

    During training I found that the val stage is very slow, even slower than training for one epoch. Is there any way to optimize it? Also, because of the error "TypeError: can't convert CUDA tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.", I changed lines 32 and 38 of train.py to predict = predict.data.cpu().numpy() and label = label.data.cpu().numpy(). Does this have any impact?
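    Moving the tensors to the CPU before calling .numpy() is required and does not change the metric values. For val speed, wrapping the pass in torch.no_grad() is usually the bigger win; below is a hedged sketch of such a validation loop under that assumption (illustrative only, not the repo's exact train.py).

        import torch

        def validate(model, dataloader_val):
            model.eval()
            accuracies = []
            with torch.no_grad():                        # no autograd bookkeeping during val
                for data, label in dataloader_val:       # assumes label holds class indices [N, H, W]
                    data = data.cuda()
                    predict = model(data).argmax(dim=1)  # [N, H, W] predicted class indices
                    predict = predict.cpu().numpy()      # CUDA tensor -> host numpy
                    label = label.numpy()
                    accuracies.append(float((predict == label).mean()))
            model.train()
            return sum(accuracies) / len(accuracies)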

    opened by ZTZ-99 0
  • modifications

    1. Support both resnet18 and resnet101; solves #3 and #9.
    2. Remove some errors on pytorch 1.0.
    3. Support more optimizers, like sgd and adam (smoother accuracy during training).
    opened by seanXYZ 0
  • On the problem of Spatial Path

    In Section 3.1 (Spatial Path), there is the sentence "Based on this observation, we propose a Spatial Path to preserve the spatial size of the original input image and encode affluent spatial information." But in the same section, this path extracts output feature maps that are 1/8 of the original image. What does it mean to preserve the spatial size of the original input image? Why does the paper say it preserves the spatial size? I'm looking forward to your reply; I'm very confused about this question. Thank you very much!

    opened by yiyi-today 0
  • Right Dice Loss?

    class DiceLoss(nn.Module):
        def __init__(self):
            super().__init__()
            self.epsilon = 1e-5

        def forward(self, output, target):
            # print(output.shape)
            # print(target.shape)
            assert output.size() == target.size(), "'input' and 'target' must have the same shape"
            # softmax over the class dimension
            output = F.softmax(output, dim=1)
            # flatten the tensors to [num_classes, B*H*W]
            output = flatten(output)
            target = flatten(target)
            # intersect = (output * target).sum(-1).sum() + self.epsilon
            # denominator = ((output + target).sum(-1)).sum() + self.epsilon

            intersect = (output * target).sum(-1)
            denominator = (output + target).sum(-1)
            # dice is in (0, 0.5)
            dice = intersect / denominator
            dice = torch.mean(dice)
            # 1 - dice is in (0.5, 1) ???
            return 1 - dice
            # return 1 - 2. * intersect / denominator

    Shouldn't the intersection be doubled, i.e. 2 * intersection over the union?
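    For reference, the standard (soft) Dice coefficient does include the factor of 2, matching the commented-out return line:

        \mathrm{Dice}(P, T) \;=\; \frac{2\,|P \cap T|}{|P| + |T|}
        \;\approx\; \frac{2 \sum_i p_i t_i}{\sum_i p_i + \sum_i t_i},
        \qquad L_{\mathrm{Dice}} \;=\; 1 - \mathrm{Dice}(P, T).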

    opened by hitsz-zuoqi 0
  • FileNotFoundError: [Errno 2] No such file or directory: '/PI_Blackfriars_Sys_1_4/Room_34_SetTempHeat.csv'

    FileNotFoundError                         Traceback (most recent call last)
    in
          5 #df
          6 for site_name in df['SiteName'].unique():
    ----> 7     df[df['SiteName'] == site_name].to_csv('{}.csv'.format(site_name))

    ~\AppData\Local\Continuum\anaconda3\lib\site-packages\pandas\core\generic.py in to_csv(self, path_or_buf, sep, na_rep, float_format, columns, header, index, index_label, mode, encoding, compression, quoting, quotechar, line_terminator, chunksize, tupleize_cols, date_format, doublequote, escapechar, decimal)
       3018                 doublequote=doublequote,
       3019                 escapechar=escapechar, decimal=decimal)
    -> 3020             formatter.save()
       3021
       3022         if path_or_buf is None:

    ~\AppData\Local\Continuum\anaconda3\lib\site-packages\pandas\io\formats\csvs.py in save(self)
        155         f, handles = _get_handle(self.path_or_buf, self.mode,
        156                                  encoding=self.encoding,
    --> 157                                  compression=self.compression)
        158         close = True
        159

    ~\AppData\Local\Continuum\anaconda3\lib\site-packages\pandas\io\common.py in _get_handle(path_or_buf, mode, encoding, compression, memory_map, is_text)
        422     elif encoding:
        423         # Python 3 and encoding
    --> 424         f = open(path_or_buf, mode, encoding=encoding, newline="")
        425     elif is_text:
        426         # Python 3 and no explicit encoding

    FileNotFoundError: [Errno 2] No such file or directory: '/PI_Blackfriars_Sys_1_4/Room_34_SetTempHeat.csv'

    Can you please tell me why this is happening and how to solve this problem?
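    This looks unrelated to BiSeNet. The leading '/' in the file name suggests that the SiteName values themselves contain slashes, so to_csv interprets them as a path into a directory that does not exist. A hedged sketch of one possible fix (sanitize the name and write into an output directory; the function name and out_dir below are illustrative):

        import os
        import pandas as pd

        def export_per_site(df, out_dir='site_csvs'):
            """Write one CSV per SiteName, replacing slashes so they are not treated as directories."""
            os.makedirs(out_dir, exist_ok=True)
            for site_name in df['SiteName'].unique():
                safe_name = str(site_name).replace('/', '_').strip('_')
                df[df['SiteName'] == site_name].to_csv(os.path.join(out_dir, '{}.csv'.format(safe_name)))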

    opened by atulmishra835 0
  • I think there is a problem in the code

    Please see this function from utils.py:

        def one_hot_it_v11_dice(label, label_info):
            semantic_map = []
            void = np.zeros(label.shape[:2])
            for index, info in enumerate(label_info):
                color = label_info[info][:3]
                class_11 = label_info[info][3]
                if class_11 == 1:
                    equality = np.equal(label, color)
                    class_map = np.all(equality, axis=-1)
                    semantic_map.append(class_map)
                else:
                    equality = np.equal(label, color)
                    class_map = np.all(equality, axis=-1)
                    void[class_map] = 1
            semantic_map.append(void)
            semantic_map = np.stack(semantic_map, axis=-1).astype(np.float)
            return semantic_map

    The variable "semantic_map" is a Python list, but the function only has two append call sites, so in the loss computation phase the "output and target have different shape" error happens, because the output shape is [batch, num_class, w, h] but the target shape is [batch, 2, w, h]. I can't guarantee I'm absolutely right.

    opened by mshmoon 6