This repository contains an implementation of the Permutohedral Attention Module in Pytorch

Overview

Permutohedral_attention_module

This repository contains an implementation of the Permutohedral Attention Module (http://arxiv.org/abs/1907.00641) in PyTorch. Our code was initially modelled on the NiftyNet CRF-as-RNN implementation (https://niftynet.readthedocs.io/en/dev/_modules/niftynet/layer/crf.html#CRFAsRNNLayer).

This repository contains two versions of a HashTable in PyTorch. One is written in plain PyTorch (used in http://arxiv.org/abs/1907.00641). The other uses a custom CUDA kernel that needs to be compiled and bound to PyTorch; this is the latest version and the one that should now be used. In addition, the repository contains an implementation of the CRF-as-RNN, which is widely used for segmentation regularization, especially in medical imaging.

The repository also contains all the files needed to reproduce the experimental results presented in the paper "Permutohedral Attention Module for Efficient Non-Local Neural Networks". If you have trouble reproducing the results, or if you find a misunderstanding or a mistake, please do not hesitate to contact us at: [email protected].

Comments
  • Try the Permutohedral lattice


    Hi, first, very nice work! Second, I tried to use your code for the permutohedral lattice as a simple filter, and I think there might be a problem, because when I run this code on a 2D RGB image:

    import numpy as np
    import torch
    from PIL import Image

    from permuthohedral_lattice import PermutohedralLattice

    img = np.asarray(Image.open("small_input.bmp"))

    indices = np.reshape(np.indices(img.shape[:2]), (2, -1))[None, :]
    rgb = np.reshape(img, (3, -1))[None, :]

    pl = PermutohedralLattice.apply

    out = pl(torch.from_numpy(indices / 5.0).cuda().float(),
             torch.from_numpy(rgb / 0.125).cuda().float())

    output = out.squeeze().cpu().numpy()
    output = np.reshape(output, img.shape)
    result = Image.fromarray((output / output.max() * 255).astype(np.uint8))
    result.save('out.bmp')
    

    I get this image:

    [image: out]

    I see two problems with this image: the image is duplicated, and there are horizontal black stripes. Do you know what might be causing this? Thanks

    opened by jgsimard 7
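A likely cause of both artifacts in this issue is the array layout, not the lattice. PIL returns an (H, W, 3) channels-last array, so np.reshape(img, (3, -1)) does not give one row per colour channel. Instead, each row holds a third of the interleaved R, G, B values, which is consistent with a repeated image. Reshaping the (3, H*W) output straight back to img.shape has the same problem. The code in the "invalid device function" issue below transposes to channels-first with np.transpose(im, (2, 0, 1)) before reshaping. The NumPy-only sketch below illustrates the difference on a tiny synthetic image standing in for small_input.bmp. The helper names to_channel_rows and from_channel_rows are hypothetical.

```python
import numpy as np


def to_channel_rows(img: np.ndarray) -> np.ndarray:
    """(H, W, C) channels-last image -> (1, C, H*W) array with one row per channel."""
    h, w, c = img.shape
    return np.transpose(img, (2, 0, 1)).reshape(c, h * w)[None, :]


def from_channel_rows(out: np.ndarray, h: int, w: int) -> np.ndarray:
    """(1, C, H*W) filtered output -> (H, W, C) image."""
    out = np.squeeze(out, axis=0)  # (C, H*W)
    return np.transpose(out, (1, 0)).reshape(h, w, out.shape[0])


# Tiny synthetic "RGB" image: red grows with the row, green with the column, blue is constant.
h, w = 4, 6
img = np.zeros((h, w, 3))
img[..., 0] = np.arange(h)[:, None]
img[..., 1] = np.arange(w)[None, :]
img[..., 2] = 7.0

wrong = np.reshape(img, (3, -1))[None, :]  # every row mixes R, G and B values
right = to_channel_rows(img)

print(np.array_equal(right[0, 0], img[..., 0].ravel()))     # True: row 0 is the red channel
print(np.array_equal(wrong[0, 0], img[..., 0].ravel()))     # False: row 0 is interleaved RGB
print(np.array_equal(from_channel_rows(right, h, w), img))  # True: the round trip restores the image
```

The position features from np.indices are already channels-first, (2, H, W), so only the colour features and the output need the transpose.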
  • PermutohedralLattice.apply


    Hi, here is the second question.

    I have an RGB image of height h and width w. My initial thought was to compute the affinity matrix W and then perform the other necessary matrix products. However, I think the whole point of pl is to compute the product W * IMAGE_OR_SOFTMAXSCORES efficiently.

    W is of size (h * w, h * w).

    W is the matrix defined here, and in eq. 4 here. It is also how you define it in your code:

    https://github.com/SamuelJoutard/Permutohedral_attention_module/blob/c86b8108fbfcf73ce300197e57cccbdfa25386ff/CRF/crf.py#L92

    It also corresponds to the first term of eq. 3 of the CRF paper here, and to eq. 6 here.

    We often need to compute St * W * (1-S). Here S is the matrix of softmax scores, * is the matrix product, and t denotes the transpose.

    My question is how to use your PermutohedralLattice.apply to compute either:

    1. St * W
    2. W * (1-S), or simply W * Z.

    I need to be able to do both, especially the second operation. Z has the same shape as S. I want to compare against another C++ implementation. For a technical reason, that implementation first evaluates W * Z (where Z is simply S) and then multiplies St by the result of W * Z.

    thanks

    From your code here,

    https://github.com/SamuelJoutard/Permutohedral_attention_module/blob/c86b8108fbfcf73ce300197e57cccbdfa25386ff/CRF/crf.py#L102

    here is what I did; I then got the error reported in the other issue. It is your code, applied to dummy inputs: an RGB image, and softmax scores with a single plane filled with ones. I expect the output, norm_1, to have the same number of elements as ones:

    import numpy as np
    import torch
    from PAM_cuda.pl import PermutohedralLattice

    np.random.seed(0)
    n, c, h, w = 32, 3, 224, 225

    img = np.random.rand(n, c, h, w) * 255
    img = torch.from_numpy(img).float().cuda()
    img = torch.clip(img, 0, 255)

    npx = h * w
    spatial_x, spatial_y = torch.meshgrid(
        torch.arange(h).cuda(),
        torch.arange(w).cuda()
    )
    spatial = torch.stack([spatial_x, spatial_y], dim=0)  # 3D tensor: (2, h, w)
    # Duplicate the coordinates along the batch dimension
    spatial = spatial.unsqueeze(0).repeat(n, 1, 1, 1)  # 4D tensor: (n, 2, h, w)
    spatial = spatial.float().detach()
    spatial = torch.reshape(spatial, (n, spatial.size(1), -1))  # (n, 2, h*w)
    # Create the bilateral kernel features
    # Features for the first term of eq (3) in [1]
    img_fea = torch.reshape(img, (n, img.size(1), -1))  # (n, 3, h*w)
    _alpha = 1
    _beta = 1
    features_1 = torch.cat([spatial / _alpha, img_fea / _beta], dim=1)  # (n, 5, h*w)
    ones = torch.ones((n, 1, npx)).cuda()
    pl = PermutohedralLattice.apply
    norm_1 = pl(features_1, ones)  # expected to have the same shape as ones: (n, 1, h*w)
    

    thanks for your help

    opened by sbelharbi 4
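For small inputs, the products asked about in this issue can be checked against an explicit dense computation. The sketch below assumes that PermutohedralLattice.apply(features, values) approximates a Gaussian filter over the already-scaled features, out[:, i] ≈ sum_j exp(-||f_i - f_j||^2 / 2) * values[:, j]. Under that assumption it plays the role of W * Z, with the values laid out as (channels, points), i.e. Z transposed. The exact kernel form is an assumption to check against the implementation, not a statement about it. This covers the factor 1/2 and whether the j = i term is kept. Under that assumption, W * (1-S) costs one filtering call, St * W * (1-S) is then an ordinary matrix product with its result, and the normaliser pl(features_1, ones) has the same shape as ones. dense_gaussian_filter is a hypothetical helper, and the lattice result is only approximate, so comparisons should use a tolerance.

```python
import numpy as np


def dense_gaussian_filter(features: np.ndarray, values: np.ndarray) -> np.ndarray:
    """Exact O(N^2) filter: out[:, i] = sum_j exp(-||f_i - f_j||^2 / 2) * values[:, j].

    features: (d, N) point features, already divided by their bandwidths
              (like spatial / _alpha and img_fea / _beta in the issue).
    values:   (c, N) signal to filter, in the (channels, points) layout of the issue.
    Assumed to mirror what the lattice approximates; the lattice itself is not used here.
    """
    diff = features[:, :, None] - features[:, None, :]  # (d, N, N)
    W = np.exp(-0.5 * np.sum(diff ** 2, axis=0))        # (N, N) affinity matrix
    return values @ W  # W is symmetric, so values @ W == (W @ values.T).T


rng = np.random.default_rng(0)
N, K = 50, 4                        # 50 "pixels", 4 classes
features = rng.normal(size=(5, N))  # e.g. 2 spatial + 3 colour features
S = rng.random((N, K))
S /= S.sum(axis=1, keepdims=True)   # rows sum to 1, like softmax scores

# W * (1-S): filter every class map at once; in this layout the values are (1 - S).T
WZ_T = dense_gaussian_filter(features, (1 - S).T)  # (K, N), equal to (W @ (1 - S)).T
# St * W * (1-S): one more ordinary matrix product with the filtered result
St_W_1mS = S.T @ WZ_T.T                            # (K, K)

# Normaliser, as in norm_1 = pl(features_1, ones): same shape as the ones input
norm = dense_gaussian_filter(features, np.ones((1, N)))  # (1, N)
print(WZ_T.shape, St_W_1mS.shape, norm.shape)            # (4, 50) (4, 4) (1, 50)
```

St * W on its own is the transpose of W * S, because W is symmetric. It can therefore be obtained from the same call with S.T as the values.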
  • RuntimeError: CUDA error: invalid device function


    Hi, thanks for this code.

    I have 2 related questions.

    Q1. When running this example on my machine, I got this error:

    Traceback (most recent call last):
      File "x/permut.py", line 29, in <module>
        torch.from_numpy(rgb / 0.125).cuda().float())
      File "x/PAM_cuda/pl.py", line 20, in forward
        rank, barycentric, blur_neighbours1, blur_neighbours2, indices = PermutohedralLattice.prepare(feat)
      File "x/PAM_cuda/pl.py", line 116, in prepare
        _ = HT_opp.insert(table, n_entries, loc[scit].type(torch.cuda.IntTensor), loc_hash[scit].type(torch.cuda.IntTensor))
    RuntimeError: CUDA error: invalid device function
    CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
    For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
    
    Process finished with exit code 139 (interrupted by signal 11: SIGSEGV)
    

    It is a segfault. The installation was done using:

    python setup.py build
    python setup.py install

    When running with $ CUDA_LAUNCH_BLOCKING=1 python permut.py, I got this:

    Traceback (most recent call last):
      File "permut.py", line 29, in <module>
        torch.from_numpy(rgb / 0.125).cuda().float())
      File "x/PAM_cuda/pl.py", line 20, in forward
        rank, barycentric, blur_neighbours1, blur_neighbours2, indices = PermutohedralLattice.prepare(feat)
      File "x/PAM_cuda/pl.py", line 116, in prepare
        _ = HT_opp.insert(table, n_entries, loc[scit].type(torch.cuda.IntTensor), loc_hash[scit].type(torch.cuda.IntTensor))
    RuntimeError: CUDA error: invalid device function
    Segmentation fault (core dumped)
    

    The code used is:

    import sys
    from os.path import dirname, abspath

    import re
    import torch.nn as nn
    import torch
    import torch.nn.functional as F

    # path stuff

    from PAM_cuda.pl import PermutohedralLattice

    if __name__ == '__main__':
        import numpy as np
        import cv2
        import matplotlib.pyplot as plt

        im = cv2.imread("dog.png")  # (H, W, 3), BGR channel order
        indices = np.reshape(np.indices(im.shape[:2]), (2, -1))[None, :]  # (1, 2, H*W)
        im = np.transpose(im, (2, 0, 1))  # channels first: (3, H, W)
        rgb = np.reshape(im, (3, -1))[None, :]  # (1, 3, H*W)

        pl = PermutohedralLattice.apply

        out = pl(torch.from_numpy(indices / 5.0).cuda().float(),
                 torch.from_numpy(rgb / 0.125).cuda().float())

        output = out.squeeze().cpu().numpy()  # (3, H*W)
        output = np.transpose(output, (1, 0))  # (H*W, 3)
        output = np.reshape(output, (im.shape[1], im.shape[2], 3))  # (H, W, 3)

        # Show the filtered output next to the input; [..., ::-1] converts BGR to RGB for display
        fig, axes = plt.subplots(1, 2)
        axes[0].imshow((output / output.max())[..., ::-1])
        axes[1].imshow(np.transpose(im, (1, 2, 0))[..., ::-1])
        plt.show()
    

    Any idea how to fix this?

    I will post the other question in a separate issue. Thanks for your help.

    Environment info:
    - conda virtual env: conda create -n env_test python=3.7
    - python 3.7.9
    - pytorch 1.9.0, installed with: conda install pytorch==1.9.0 torchvision==0.10.0 cudatoolkit=11.1 -c pytorch -c nvidia
    - cv2 4.1.2
    - nvcc --version: NVIDIA (R) Cuda compiler driver, Copyright (c) 2005-2018 NVIDIA Corporation, Built on Sat_Aug_25_21:08:01_CDT_2018, Cuda compilation tools, release 10.0, V10.0.130
    - CUDA version reported by nvidia-smi: 11.1
    - GPU: P100
    - nvidia-smi: NVIDIA-SMI 455.32.00, Driver Version: 455.32.00

    So far, I have tested only on one server, where I expected the example to work. Let me know if you need more info. The virtual env is managed with conda.

    Thanks

    opened by sbelharbi 3
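One detail in the environment info of this issue stands out. nvcc reports CUDA 10.0, while PyTorch was installed with cudatoolkit 11.1. The CUDA runtime raises "invalid device function" when a kernel is not available for the running device architecture (a P100 has compute capability 6.0). A toolkit that differs from the one PyTorch was built with is a plausible contributor, though not a confirmed diagnosis. The sketch below compares the nvcc release with torch.version.cuda. It also prints the GPU compute capability in the format of the TORCH_CUDA_ARCH_LIST environment variable read by torch.utils.cpp_extension. The helper names and the report() entry point are hypothetical.

```python
import re
import subprocess
from typing import Optional, Tuple


def nvcc_release(nvcc_output: str) -> Optional[str]:
    """Extract 'X.Y' from the 'release X.Y' field of `nvcc --version` output."""
    match = re.search(r"release (\d+)\.(\d+)", nvcc_output)
    return f"{match.group(1)}.{match.group(2)}" if match else None


def same_major_minor(a: str, b: str) -> bool:
    """Compare two CUDA version strings such as '11.1' and '11.1.105' on major.minor only."""
    return a.split(".")[:2] == b.split(".")[:2]


def arch_list_entry(capability: Tuple[int, int]) -> str:
    """(6, 0) -> '6.0', the format used by TORCH_CUDA_ARCH_LIST."""
    return f"{capability[0]}.{capability[1]}"


def report() -> None:
    """Print the values to compare on the machine that raises the error (needs torch and nvcc)."""
    import torch

    nvcc_out = subprocess.run(["nvcc", "--version"], capture_output=True, text=True).stdout
    release = nvcc_release(nvcc_out)
    print("PyTorch built with CUDA:", torch.version.cuda)
    print("nvcc release:", release)
    if release is not None and torch.version.cuda is not None:
        print("toolkits agree:", same_major_minor(release, torch.version.cuda))
    if torch.cuda.is_available():
        cap = torch.cuda.get_device_capability()
        print("suggested TORCH_CUDA_ARCH_LIST:", arch_list_entry(cap))
```

If the toolkits disagree, one option is to install a CUDA toolkit matching torch.version.cuda and rebuild. For example: CUDA_HOME=/path/to/cuda-11.1 TORCH_CUDA_ARCH_LIST="6.0" python setup.py install, where /path/to/cuda-11.1 is a placeholder. This assumes setup.py builds PAM_cuda through torch.utils.cpp_extension, which reads both variables.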
  • Fully-connected 3D CRF


    I have added an implementation of a fully-connected 3D CRF for segmentation in PyTorch (CRF/crf.py), using the efficient permutohedral lattice implementation of this repo.

    opened by LucasFidon 0
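To show what such a CRF computes, the sketch below runs dense mean-field inference, the update that CRF-as-RNN unrolls, with a single Gaussian kernel and a Potts compatibility. It builds an explicit N x N kernel in NumPy, so it only scales to toy sizes. The permutohedral lattice exists to replace exactly that kernel-times-Q product with a fast approximation. This is not the code in CRF/crf.py. The single kernel, the weight w, the absence of kernel normalisation and the function names are illustrative assumptions.

```python
import numpy as np


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    z = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)


def gaussian_kernel(features: np.ndarray) -> np.ndarray:
    """features: (N, d), already scaled by their bandwidths -> (N, N) kernel k(f_i, f_j)."""
    sq_dist = ((features[:, None, :] - features[None, :, :]) ** 2).sum(axis=-1)
    return np.exp(-0.5 * sq_dist)


def mean_field(logits: np.ndarray, features: np.ndarray, w: float = 1.0, n_iter: int = 5) -> np.ndarray:
    """Dense mean-field inference for a fully-connected CRF with a Potts compatibility.

    logits: (N, L) unary scores (negative unary energies); features: (N, d).
    With the Potts model, penalising disagreeing labels equals rewarding agreeing
    ones up to a per-pixel constant that softmax removes, so each update is
    Q = softmax(logits + w * sum_{j != i} k_ij Q_j).
    """
    K = gaussian_kernel(features)
    np.fill_diagonal(K, 0.0)  # exclude the j == i term
    Q = softmax(logits)
    for _ in range(n_iter):
        message = K @ Q  # the product the permutohedral lattice approximates
        Q = softmax(logits + w * message)
    return Q


# Toy 1D "image": 6 pixels on a line, 2 classes, one noisy unary in the middle.
positions = np.arange(6, dtype=float)[:, None] / 2.0  # spatial feature / bandwidth
logits = np.array([[2, 0], [2, 0], [-0.5, 0], [2, 0], [2, 0], [2, 0]], dtype=float)
Q = mean_field(logits, positions, w=1.0)
print(softmax(logits).argmax(axis=1), Q.argmax(axis=1))
```

In this toy example the unaries alone label pixel 2 as class 1. The messages from its neighbours pull it back to class 0, which is the regularising effect the CRF is used for.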
  • Which file should I use?


    Hello! I am trying to use PAM as a plug-and-play module for my task. Should I use PAM or PAM_cuda? Could you explain in detail how to use it? I'm a beginner with CUDA.

    opened by Stephen0808 0
  • Need some unit tests


    Unit tests for the different components will be needed in the longer run. Examples of such tests for the permutohedral lattice and the CRF can be found in:
    - NiftyNet: https://github.com/NifTK/NiftyNet/blob/dev/tests/crf_test.py
    - the C++ code at http://graphics.stanford.edu/projects/drf/
    - pydensecrf: https://github.com/lucasb-eyer/pydensecrf/tree/master/tests
    - https://github.com/MiguelMonteiro/permutohedral_lattice/tree/master/Tests
    - https://github.com/MiguelMonteiro/CRFasRNNLayer/tree/master/Tests

    opened by tvercaut 0
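As a starting point for such tests, some properties can be checked without a reference implementation. For fixed features, splatting, blurring and slicing are all linear in the values, so the filter output should be linear in the values. Shuffling the points (features and values together) should also shuffle the output the same way, up to floating-point error. The sketch below checks both properties for any function with the (features, values) -> output signature of PermutohedralLattice.apply. It drops the batch dimension and uses NumPy arrays instead of CUDA tensors, demonstrated on an exact dense Gaussian filter. The function names and tolerances are assumptions; the lattice is approximate, so the tolerances would need tuning for it.

```python
from typing import Callable

import numpy as np

# Filter signature assumed here: features (d, N), values (c, N) -> output (c, N), no batch dimension.
FilterFn = Callable[[np.ndarray, np.ndarray], np.ndarray]


def check_linearity(filter_fn: FilterFn, features: np.ndarray, rng: np.random.Generator,
                    channels: int = 3, rtol: float = 1e-5, atol: float = 1e-6) -> bool:
    """filter(f, a*u + b*v) should equal a*filter(f, u) + b*filter(f, v) for fixed features f."""
    n = features.shape[1]
    u = rng.normal(size=(channels, n))
    v = rng.normal(size=(channels, n))
    a, b = 2.0, -0.5
    lhs = filter_fn(features, a * u + b * v)
    rhs = a * filter_fn(features, u) + b * filter_fn(features, v)
    return bool(np.allclose(lhs, rhs, rtol=rtol, atol=atol))


def check_permutation_equivariance(filter_fn: FilterFn, features: np.ndarray, rng: np.random.Generator,
                                   channels: int = 3, rtol: float = 1e-5, atol: float = 1e-6) -> bool:
    """Shuffling the points (features and values together) should shuffle the output the same way."""
    n = features.shape[1]
    values = rng.normal(size=(channels, n))
    perm = rng.permutation(n)
    out = filter_fn(features, values)
    out_perm = filter_fn(features[:, perm], values[:, perm])
    return out_perm.shape == values.shape and bool(np.allclose(out_perm, out[:, perm], rtol=rtol, atol=atol))


def dense_gaussian_filter(features: np.ndarray, values: np.ndarray) -> np.ndarray:
    """Exact O(N^2) stand-in: out[:, i] = sum_j exp(-||f_i - f_j||^2 / 2) * values[:, j]."""
    diff = features[:, :, None] - features[:, None, :]
    return values @ np.exp(-0.5 * (diff ** 2).sum(axis=0))


rng = np.random.default_rng(0)
features = rng.normal(size=(5, 40))  # e.g. (x, y, r, g, b) for 40 points
print(check_linearity(dense_gaussian_filter, features, rng))
print(check_permutation_equivariance(dense_gaussian_filter, features, rng))
```

To test the CUDA version, a thin wrapper would add the batch dimension with [None], move the inputs to the GPU as float tensors, call PermutohedralLattice.apply, and convert the result back to NumPy. That wrapper is not shown here.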
Owner
Samuel JOUTARD