
⚠️ ⚠️ This code is old and does not support the latest versions of PyTorch, especially since the change in the FFT interface. ⚠️ ⚠️

Spherical CNNs

Equivariant CNNs for the sphere and SO(3) implemented in PyTorch

(Figure: Equivariance)

Overview

This library contains a PyTorch implementation of the rotation equivariant CNNs for spherical signals (e.g. omnidirectional images, signals on the globe) as presented in [1]. Equivariant networks for the plane are available here.

Dependencies

(commands to install all the dependencies in a new conda environment)

conda create --name cuda9 python=3.6 
conda activate cuda9

# s2cnn deps
#conda install pytorch torchvision cuda90 -c pytorch # get correct command line at http://pytorch.org/
conda install -c anaconda cupy  
pip install pynvrtc joblib

# lie_learn deps
conda install -c anaconda cython  
conda install -c anaconda requests  

# shrec17 example dep
conda install -c anaconda scipy  
conda install -c conda-forge rtree shapely  
conda install -c conda-forge pyembree  
pip install "trimesh[easy]"  

Installation

To install, run

$ python setup.py install

Usage

Please have a look at the examples.

Please cite [1] in your work when using this library in your experiments.

Design choices for Spherical CNN Architectures

Spherical CNNs come with different choices of grids and grid hyperparameters, which at first glance are not obviously related to those of conventional CNNs. The s2_near_identity_grid and so3_near_identity_grid are the preferred choices, since they correspond to spatially localized kernels defined at the north pole and rotated over the sphere via the action of SO(3). In contrast, s2_equatorial_grid and so3_equatorial_grid define line-like (or ring-like) kernels around the equator.
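
To make the difference concrete, here is a minimal sketch of constructing both grid styles and passing one to a convolution layer. The constructor signatures follow the bundled examples; the bandwidths and feature counts are placeholder values, not recommendations.

import numpy as np
from s2cnn import S2Convolution, s2_equatorial_grid, s2_near_identity_grid

# Spatially localized kernel: rings around the north pole, rotated over
# the sphere via the action of SO(3).
grid_local = s2_near_identity_grid(max_beta=np.pi / 8, n_alpha=8, n_beta=3)

# Ring-like kernel along the equator.
grid_ring = s2_equatorial_grid(max_beta=0, n_alpha=32, n_beta=1)

# An S2 convolution maps a spherical signal of bandwidth b_in to a signal
# on SO(3) of bandwidth b_out; the grid determines where the kernel is sampled.
conv = S2Convolution(nfeature_in=1, nfeature_out=16, b_in=30, b_out=10, grid=grid_local)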

To clarify the possible parameter choices for s2_near_identity_grid:

max_beta:

Adapts the size of the kernel as an angle measured from the north pole. Conventional CNNs on flat space usually use a fixed kernel size but pool the signal spatially; this spatial pooling gives the kernels in later layers an effectively increased field of view. One can emulate pooling by a factor of 2 in spherical CNNs by halving the signal bandwidth and doubling max_beta (see the sketch after these parameter descriptions).

n_beta:

Number of rings of the kernel around the north pole, equally spaced in [β=0, β=max_beta]. The choice n_beta=1 corresponds to a small 3x3 kernel in conv2d, since in both cases the resulting kernel consists of one central pixel and one ring around the center.

n_alpha:

Gives the number of learned parameters per ring around the pole. By default these values are equally spaced in azimuth. A sensible number depends on the bandwidth and max_beta, since a higher resolution or larger spatial extent allows sampling finer kernels without producing aliased results. In practice this value is typically set to a low constant like 6 or 8. A reduced signal bandwidth is then counteracted by an increased max_beta to emulate spatial pooling.
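
The interplay of bandwidth, max_beta, and n_alpha can be made concrete with a small sketch (the parameter values are illustrative assumptions; in a real network the second layer would use the SO(3) grid described next):

import numpy as np
from s2cnn import s2_near_identity_grid

# Grid for a layer operating at bandwidth 32: narrow kernel, n_alpha kept low.
grid_layer1 = s2_near_identity_grid(max_beta=np.pi / 16, n_alpha=6, n_beta=1)

# After a convolution whose b_out halves the bandwidth to 16, the next grid
# doubles max_beta while keeping n_alpha fixed, emulating pooling by a factor of 2.
grid_layer2 = s2_near_identity_grid(max_beta=np.pi / 8, n_alpha=6, n_beta=1)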

The so3_near_identity_grid has two additional parameters, max_gamma and n_gamma. SO(3) can be seen as a (principal) fiber bundle SO(3)→S², with the sphere S² as base space and a fiber SO(2) attached to each point. The additional parameters control the grid on the fiber in the following way:

max_gamma:

The kernel spans the fiber SO(2) over γ∈[0, max_gamma]. The fiber SO(2) encodes the kernel responses for every sampled orientation at a given position on the sphere. Setting max_gamma < 2π means the kernel does not see the responses of all kernel orientations simultaneously, which is generally unfavorable. Steerable CNNs [3] always use max_gamma=2π.

n_gamma:

Number of learned parameters on the fiber. Typically set equal to n_alpha, i.e. to a low value like 6 or 8.

See the deep model of the MNIST example for how to adapt these parameters over layers; a compressed sketch follows.
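
As a compressed sketch in the spirit of that model (the class name, feature sizes, and bandwidths are invented for illustration; the layer API follows the bundled examples):

import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from s2cnn import (S2Convolution, SO3Convolution, so3_integrate,
                   s2_near_identity_grid, so3_near_identity_grid)

class TinySphericalCNN(nn.Module):
    # S2 conv -> SO3 conv -> rotation-invariant readout via so3_integrate.
    def __init__(self, b_in=30, n_classes=10):
        super().__init__()
        grid_s2 = s2_near_identity_grid(max_beta=np.pi / 16, n_alpha=6, n_beta=1)
        grid_so3 = so3_near_identity_grid(max_beta=np.pi / 8, n_alpha=6, n_beta=1,
                                          max_gamma=2 * np.pi, n_gamma=6)
        # First layer: spherical input of bandwidth b_in -> signal on SO(3).
        self.conv1 = S2Convolution(nfeature_in=1, nfeature_out=8,
                                   b_in=b_in, b_out=b_in // 2, grid=grid_s2)
        # Second layer at halved bandwidth; max_beta is doubled to compensate.
        self.conv2 = SO3Convolution(nfeature_in=8, nfeature_out=16,
                                    b_in=b_in // 2, b_out=b_in // 4, grid=grid_so3)
        self.fc = nn.Linear(16, n_classes)

    def forward(self, x):              # x: [batch, 1, 2*b_in, 2*b_in]
        x = F.relu(self.conv1(x))      # -> signal on SO(3)
        x = F.relu(self.conv2(x))
        x = so3_integrate(x)           # integrate over SO(3) -> [batch, 16]
        return self.fc(x)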

Feedback

For questions and comments, feel free to contact us: geiger.mario (gmail), taco.cohen (gmail), jonas (argmin.xyz).

License

MIT

References

[1] Taco S. Cohen, Mario Geiger, Jonas Köhler, Max Welling, Spherical CNNs. International Conference on Learning Representations (ICLR), 2018.

[2] Taco S. Cohen, Mario Geiger, Jonas Köhler, Max Welling, Convolutional Networks for Spherical Signals. ICML Workshop on Principled Approaches to Deep Learning, 2017.

[3] Taco S. Cohen, Mario Geiger, Maurice Weiler, Intertwiners between Induced Representations (with applications to the theory of equivariant neural networks), ArXiv preprint 1803.10743, 2018.

Comments
  • Usage documentation?

    Sorry to spam you all with multiple issues, but is there any usage documentation associated with your s2cnn package? In particular, I'm curious when to use the different grid types and convolution types. Things like so3_equatorial_grid() vs. so3_soft_grid(), etc.

    I also notice you explicitly call so3_integrate() in the MNIST example. I am wondering why there is a need for explicit integration, and what the operation is doing (I couldn't find that in the papers).

    opened by meder411 14
  • module 's2cnn.ops.gpu.lib_cufft' has no attribute 'destroy'

    This problem was asked in issue 3, but I can't fix it. After adding , encoding='utf-8', the open() call in setup.py now looks like this: long_description=open(os.path.join(os.path.dirname(__file__), "README.md"), encoding='utf-8').read(),

    Then, after running python setup.py install, the encoding issue UnicodeDecodeError: 'ascii' codec can't decode byte 0xc3 in position 1264: ordinal not in range(128) is gone. But when I try to run the shrec17 example, the following problem remains: module 's2cnn.ops.gpu.lib_cufft' has no attribute 'destroy'

    I think the real problem may be the version of CUDA or PyTorch. Here's my system environment:

    • Python 3.6
    • CUDA 9.1.85
    • pytorch 0.4.0

    Can you help me? Thanks a lot!

    opened by zhixuanli 9
  • module 's2cnn.ops.gpu.lib_cufft' has no attribute 'destroy'

    Trying to run the shrec17 example in a conda3 docker image with Python 3.6 [CUDA 9.0, NVIDIA drivers 384.111], I faced the following error:

    {'num_workers': 1, 'batch_size': 32, 'dataset': 'train', 'augmentation': 4, 'model_path': 'model.py', 'log_dir': 'my_run', 'learning_rate': 0.5}
    Downloading http://3dvision.princeton.edu/ms/shrec17-data/train_perturbed.zip
    Unzip data/train_perturbed.zip
    Fix obj files
    Downloading http://3dvision.princeton.edu/ms/shrec17-data/train.csv
    Done!
    402955 paramerters in total
    5555 paramerters in the last layer
    learning rate = 1 and batch size = 32
    transform data/train_perturbed/046114.obj...
    /root/miniconda3/lib/python3.6/site-packages/trimesh/triangles.py:188: RuntimeWarning: divide by zero encountered in true_divide
      center_mass = integrated[1:4] / volume
    /root/miniconda3/lib/python3.6/site-packages/trimesh/triangles.py:188: RuntimeWarning: invalid value encountered in true_divide
      center_mass = integrated[1:4] / volume
    transform data/train_perturbed/005351.obj...
    transform data/train_perturbed/019736.obj...
    transform data/train_perturbed/018758.obj...
    transform data/train_perturbed/029336.obj...
    transform data/train_perturbed/012867.obj...
    transform data/train_perturbed/045223.obj...
    transform data/train_perturbed/025009.obj...
    transform data/train_perturbed/048329.obj...
    transform data/train_perturbed/038370.obj...
    transform data/train_perturbed/037326.obj...
    transform data/train_perturbed/025172.obj...
    transform data/train_perturbed/015628.obj...
    transform data/train_perturbed/038990.obj...
    transform data/train_perturbed/040417.obj...
    transform data/train_perturbed/044571.obj...
    transform data/train_perturbed/038458.obj...
    transform data/train_perturbed/048180.obj...
    transform data/train_perturbed/033437.obj...
    transform data/train_perturbed/030847.obj...
    transform data/train_perturbed/050627.obj...
    transform data/train_perturbed/005628.obj...
    transform data/train_perturbed/045656.obj...
    transform data/train_perturbed/008172.obj...
    transform data/train_perturbed/010100.obj...
    transform data/train_perturbed/024292.obj...
    transform data/train_perturbed/038671.obj...
    transform data/train_perturbed/025215.obj...
    transform data/train_perturbed/032604.obj...
    transform data/train_perturbed/048823.obj...
    transform data/train_perturbed/018781.obj...
    transform data/train_perturbed/040830.obj...
    transform data/train_perturbed/007740.obj...
    Traceback (most recent call last):
      File "train.py", line 135, in <module>
        main(**args.__dict__)
      File "train.py", line 105, in main
        loss, correct = train_step(data, target)
      File "train.py", line 74, in train_step
        prediction = model(data)
      File "/root/miniconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
        result = self.forward(*input, **kwargs)
      File "my_run/model.py", line 47, in forward
        x = self.sequential(x)  # [batch, feature, beta, alpha, gamma]
      File "/root/miniconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
        result = self.forward(*input, **kwargs)
      File "/root/miniconda3/lib/python3.6/site-packages/torch/nn/modules/container.py", line 67, in forward
        input = module(input)
      File "/root/miniconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 357, in __call__
        result = self.forward(*input, **kwargs)
      File "/root/miniconda3/lib/python3.6/site-packages/s2cnn-1.0.0-py3.6-linux-x86_64.egg/s2cnn/nn/soft/s2_conv.py", line 40, in forward
        x = S2_fft_real(b_out=self.b_out)(x) # [l * m, batch, feature_in, complex]
      File "/root/miniconda3/lib/python3.6/site-packages/s2cnn-1.0.0-py3.6-linux-x86_64.egg/s2cnn/nn/soft/gpu/s2_fft.py", line 231, in forward
        return s2_fft(as_complex(x), b_out=self.b_out)
      File "/root/miniconda3/lib/python3.6/site-packages/s2cnn-1.0.0-py3.6-linux-x86_64.egg/s2cnn/nn/soft/gpu/s2_fft.py", line 27, in s2_fft
        output = _s2_fft(x, for_grad=for_grad, b_in=b_in, b_out=b_out) # [l * m, batch, complex]
      File "/root/miniconda3/lib/python3.6/site-packages/s2cnn-1.0.0-py3.6-linux-x86_64.egg/s2cnn/nn/soft/gpu/s2_fft.py", line 43, in _s2_fft
        plan = _setup_fft_plan(b_in, nbatch)
      File "/root/miniconda3/lib/python3.6/site-packages/s2cnn-1.0.0-py3.6-linux-x86_64.egg/s2cnn/nn/soft/gpu/s2_fft.py", line 146, in _setup_fft_plan
        plan = Plan1d_c2c(N=2 * b, batch=nbatch * 2 * b)
      File "/root/miniconda3/lib/python3.6/site-packages/s2cnn-1.0.0-py3.6-linux-x86_64.egg/s2cnn/ops/gpu/torchcufft.py", line 12, in __init__
        self.handler = cufft.plan1d_c2c(N, istride, idist, ostride, odist, batch)
    AttributeError: module 's2cnn.ops.gpu.lib_cufft' has no attribute 'plan1d_c2c'
    Exception ignored in: <bound method Plan1d_c2c.__del__ of <s2cnn.ops.gpu.torchcufft.Plan1d_c2c object at 0x7efc0f667438>>
    Traceback (most recent call last):
      File "/root/miniconda3/lib/python3.6/site-packages/s2cnn-1.0.0-py3.6-linux-x86_64.egg/s2cnn/ops/gpu/torchcufft.py", line 18, in __del__
        cufft.destroy(self.handler)
    AttributeError: module 's2cnn.ops.gpu.lib_cufft' has no attribute 'destroy'
    
    opened by blancaag 7
  • How to decide the bandwidth?

    Firstly, thank you for the great work!

    As the paper mentions: "The maximum frequency b is known as the bandwidth, and is related to the resolution of the spatial grid (Kostelec and Rockmore, 2007)."

    I notice that the bandwidth is set to 30 when you generate the new MNIST dataset. For each S2Convolution and SO3Convolution, the bandwidth is different.

    I wonder how to choose the parameter bandwidth when using S2Convolution or SO3Convolution. Is it also a hyperparameter (an empirical value), or does it need to be calculated meticulously?

    Why did you set bandwidth = 30 when generating the new MNIST dataset?

    Thank you!

    opened by Jiankai-Sun 5
  • RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation

    Hi, I tried to use your s2conv/so3conv in a multi-model setup like the following (the model includes your s2conv/so3conv):

    def train(epoch):
        model.train()
        for batch_idx, (image,target) in enumerate(train_loader):
            image = image.to(device)
            optimizer.zero_grad()
           
            # multi model
            re_image1 = model(image)
            re_image2 = model(image)
            loss = re_image1.abs().mean() + re_image2.abs().mean()
    
            loss.backward()
            optimizer.step()
    

    Then I got the following error.

      File "main.py", line 66, in <module>
        main()
      File "main.py", line 62, in main
        train(epoch)
      File "main.py", line 53, in train
        loss.backward()
      File "/home/hayashi/.python-venv/lib/python3.5/site-packages/torch/tensor.py", line 93, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph)
      File "/home/hayashi/.python-venv/lib/python3.5/site-packages/torch/autograd/__init__.py", line 89, in backward
        allow_unreachable=True)  # allow_unreachable flag
    RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation
    

    There is no error when I use a mono-model like the following:

    def train(epoch):
        model.train()
        for batch_idx, (image,target) in enumerate(train_loader):
            image = image.to(device)
            optimizer.zero_grad()
           
            # mono model
            image1 = model(image)
            loss = image1.abs().mean() 
    
            loss.backward()
            optimizer.step()
    

    So I think this error is not caused by an inplace operation. Do you know the details of this error?

    P.S. I found this error doesn't occur when I use a past version of your s2conv/so3conv (maybe the one for PyTorch v0.3.1). If you can, please republish the past version of s2cnn (for PyTorch v0.3.1).

    opened by udonuser 5
  • Error running example

    Hello,

    I am trying to run the MNIST example, but I am getting many errors running the script /examples/mnist/run.py. It seems like I have the wrong versions of some packages but can't figure out which versions are needed. Fixing one error leads to another...

    Some errors:

    x = torch.fft(torch.stack((x, torch.zeros_like(x)), dim=-1), 2)
    TypeError: 'module' object is not callable

    RuntimeError: view_as_complex is only supported for half, float and double tensors, but got a tensor of scalar type: ComplexFloat

    I have installed all the dependencies:

    conda create --name cuda9 python=3.6 
    conda activate cuda9
    
    # s2cnn deps
    #conda install pytorch torchvision cuda90 -c pytorch # get correct command line at http://pytorch.org/
    conda install -c anaconda cupy  
    pip install pynvrtc joblib
    
    # lie_learn deps
    conda install -c anaconda cython  
    conda install -c anaconda requests  
    
    # shrec17 example dep
    conda install -c anaconda scipy  
    conda install -c conda-forge rtree shapely  
    conda install -c conda-forge pyembree  
    pip install "trimesh[easy]"  
    

    python setup.py install

    Any help? Thanks.

    opened by DavidHribek 4
  • query about feature maps

    Hi all, thanks for your great effort in this research! I have a question on how to get the feature map of an equirectangular image input.

    For example, I used the S2ConvNet_deep model from the MNIST example code to get the feature map of an equirectangular image by cutting off the linear block at the end. I resized the input panoramic image to 100x100 and the initial bandwidth is 50. But the model returns an output of shape (1, 64). How can I get a larger feature map, like a VGG feature extractor?

    Can you please share your precious comments?

    opened by ustundag 4
  • some question when I run gendata.py in /examples/mnist folder

    When I run gendata.py with no rotation in the training and test data, using the command:

            python3 gendata.py --no_rotate_train --no_rotate_test                
    

    And I visualized some results, but some of the generated images seem to have nothing to do with the original images (Figures 1-3).

    I think the spherical images (right) are strange; they have nothing to do with the original images (left). I also tested the algorithm on another image (Figure 4).

    Obviously, something is wrong with gendata.py, but I cannot figure out what, so could you explain it?

    opened by townblack 4
  • cupy.cuda.driver.CUDADriverError: CUDA_ERROR_INVALID_HANDLE: invalid resource handle

    Thank you for publishing such great code! I have a question.

    When I use this spherical convolution in our network, I tried to train our model on multiple GPUs with torch.nn.DataParallel(model).cuda(). However, I got the following error message.

      File "/home/users/.python_venv_3/lib/python3.5/site-packages/torch/nn/modules/module.py", line 491, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/users/.python_venv_3/lib/python3.5/site-packages/s2cnn-1.0.0-py3.5.egg/s2cnn/soft/s2_conv.py", line 40, in forward
      File "/home/users/.python_venv_3/lib/python3.5/site-packages/s2cnn-1.0.0-py3.5.egg/s2cnn/soft/gpu/s2_fft.py", line 225, in forward
      File "/home/users/.python_venv_3/lib/python3.5/site-packages/s2cnn-1.0.0-py3.5.egg/s2cnn/soft/gpu/s2_fft.py", line 27, in s2_fft
      File "/home/users/.python_venv_3/lib/python3.5/site-packages/s2cnn-1.0.0-py3.5.egg/s2cnn/soft/gpu/s2_fft.py", line 51, in _s2_fft
      File "cupy/cuda/function.pyx", line 147, in cupy.cuda.function.Function.__call__
      File "cupy/cuda/function.pyx", line 129, in cupy.cuda.function._launch
      File "cupy/cuda/driver.pyx", line 195, in cupy.cuda.driver.launchKernel
      File "cupy/cuda/driver.pyx", line 75, in cupy.cuda.driver.check_status
    cupy.cuda.driver.CUDADriverError: CUDA_ERROR_INVALID_HANDLE: invalid resource handle
    

    Can we use DataParallel with this spherical convolution? Or is there a plan to implement it in the future?

    opened by syinari0123 4
  • Error occurred in the process of installation

    When I ran the command python setup.py install, I got the following output:

    running install
    running bdist_egg
    running egg_info
    writing s2cnn.egg-info/PKG-INFO
    writing dependency_links to s2cnn.egg-info/dependency_links.txt
    writing requirements to s2cnn.egg-info/requires.txt
    writing top-level names to s2cnn.egg-info/top_level.txt
    reading manifest file 's2cnn.egg-info/SOURCES.txt'
    writing manifest file 's2cnn.egg-info/SOURCES.txt'
    installing library code to build/bdist.linux-x86_64/egg
    running install_lib
    running build_py
    copying s2cnn/ops/gpu/lib_cufft/__init__.py -> build/lib.linux-x86_64-3.6/s2cnn/ops/gpu/lib_cufft
    running build_ext
    generating cffi module 'build/temp.linux-x86_64-3.6/s2cnn.ops.gpu.lib_cufft._lib_cufft.c'
    already up-to-date
    building 's2cnn.ops.gpu.lib_cufft._lib_cufft' extension
    gcc -pthread -B /home/whu/apps/anaconda3/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -DWITH_CUDA -I/home/whu/apps/anaconda3/lib/python3.6/site-packages/torch/utils/ffi/../../lib/include -I/home/whu/apps/anaconda3/lib/python3.6/site-packages/torch/utils/ffi/../../lib/include/TH -I/home/whu/apps/anaconda3/lib/python3.6/site-packages/torch/utils/ffi/../../lib/include/THC -I/usr/local/cuda/include -I/home/whu/apps/anaconda3/include/python3.6m -c build/temp.linux-x86_64-3.6/s2cnn.ops.gpu.lib_cufft._lib_cufft.c -o build/temp.linux-x86_64-3.6/build/temp.linux-x86_64-3.6/s2cnn.ops.gpu.lib_cufft._lib_cufft.o
    gcc -pthread -B /home/whu/apps/anaconda3/compiler_compat -Wl,--sysroot=/ -Wsign-compare -DNDEBUG -g -fwrapv -O3 -Wall -Wstrict-prototypes -fPIC -DWITH_CUDA -I/home/whu/apps/anaconda3/lib/python3.6/site-packages/torch/utils/ffi/../../lib/include -I/home/whu/apps/anaconda3/lib/python3.6/site-packages/torch/utils/ffi/../../lib/include/TH -I/home/whu/apps/anaconda3/lib/python3.6/site-packages/torch/utils/ffi/../../lib/include/THC -I/usr/local/cuda/include -I/home/whu/apps/anaconda3/include/python3.6m -c /home/whu/Documents/s2cnn/s2cnn/ops/gpu/plan_cufft.c -o build/temp.linux-x86_64-3.6/home/whu/Documents/s2cnn/s2cnn/ops/gpu/plan_cufft.o
    gcc -pthread -shared -B /home/whu/apps/anaconda3/compiler_compat -L/home/whu/apps/anaconda3/lib -Wl,-rpath=/home/whu/apps/anaconda3/lib -Wl,--no-as-needed -Wl,--sysroot=/ build/temp.linux-x86_64-3.6/build/temp.linux-x86_64-3.6/s2cnn.ops.gpu.lib_cufft._lib_cufft.o build/temp.linux-x86_64-3.6/home/whu/Documents/s2cnn/s2cnn/ops/gpu/plan_cufft.o -lcufft -o build/lib.linux-x86_64-3.6/s2cnn/ops/gpu/lib_cufft/_lib_cufft.abi3.so
    /home/whu/apps/anaconda3/compiler_compat/ld: cannot find -lcufft
    collect2: error: ld returned 1 exit status
    error: command 'gcc' failed with exit status 1

    opened by sherkwast 4
  • Error with einsum in Equivariance plot

    Hello,

    I'm trying to run the equivariance_plot example but come up with this error. I am using CUDA 11 with PyTorch and CuPy for CUDA 11:

     File "/home/owen/anaconda3/envs/s2cnn/lib/python3.8/site-packages/torch/functional.py", line 344, in einsum
        return _VF.einsum(equation, operands)  # type: ignore
    RuntimeError: expected scalar type Float but found ComplexFloat
    

    Full error log:

    runfile('/home/owen/PycharmProjects/s2cnn/examples/equivariance_plot/main.py', wdir='/home/owen/PycharmProjects/s2cnn/examples/equivariance_plot')
    /home/owen/anaconda3/envs/s2cnn/lib/python3.8/site-packages/torch/cuda/__init__.py:52: UserWarning: CUDA initialization: CUDA unknown error - this may be due to an incorrectly set up environment, e.g. changing env variable CUDA_VISIBLE_DEVICES after program start. Setting the available devices to be zero. (Triggered internally at  /opt/conda/conda-bld/pytorch_1607370172916/work/c10/cuda/CUDAFunctions.cpp:100.)
      return torch._C._cuda_getDeviceCount() > 0
    compute 0.pkl.gz... save 0.pkl.gz... done
    compute 0.pkl.gz... save 0.pkl.gz... done
    Traceback (most recent call last):
      File "/home/owen/anaconda3/envs/s2cnn/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3437, in run_code
        exec(code_obj, self.user_global_ns, self.user_ns)
      File "<ipython-input-2-34cedf137222>", line 1, in <module>
        runfile('/home/owen/PycharmProjects/s2cnn/examples/equivariance_plot/main.py', wdir='/home/owen/PycharmProjects/s2cnn/examples/equivariance_plot')
      File "/home/owen/.local/share/JetBrains/Toolbox/apps/PyCharm-P/ch-0/211.7142.13/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
        pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
      File "/home/owen/.local/share/JetBrains/Toolbox/apps/PyCharm-P/ch-0/211.7142.13/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
        exec(compile(contents+"\n", file, 'exec'), glob, loc)
      File "/home/owen/PycharmProjects/s2cnn/examples/equivariance_plot/main.py", line 102, in <module>
        main()
      File "/home/owen/PycharmProjects/s2cnn/examples/equivariance_plot/main.py", line 76, in main
        y1 = phi(s2_rotation(x, *abc))
      File "/home/owen/PycharmProjects/s2cnn/examples/equivariance_plot/main.py", line 14, in s2_rotation
        x = so3_rotation(x.view(*x.size(), 1).expand(*x.size(), x.size(-1)), a, b, c)
      File "/home/owen/PycharmProjects/s2cnn/s2cnn/soft/so3_rotation.py", line 21, in so3_rotation
        x = SO3_fft_real.apply(x)  # [l * m * n, ..., complex]
      File "/home/owen/PycharmProjects/s2cnn/s2cnn/soft/so3_fft.py", line 453, in forward
        return so3_rfft(x, b_out=ctx.b_out)
      File "/home/owen/PycharmProjects/s2cnn/s2cnn/soft/so3_fft.py", line 110, in so3_rfft
        out = torch.einsum("bmn,zbmnc->mnzc", (wigner[:, s].view(-1, 2 * l + 1, 2 * l + 1), xx))
      File "/home/owen/anaconda3/envs/s2cnn/lib/python3.8/site-packages/torch/functional.py", line 342, in einsum
        return einsum(equation, *_operands)
      File "/home/owen/anaconda3/envs/s2cnn/lib/python3.8/site-packages/torch/functional.py", line 344, in einsum
        return _VF.einsum(equation, operands)  # type: ignore
    RuntimeError: expected scalar type Float but found ComplexFloat
    
    opened by SkirOwen 3
  • No module named 'lie_learn.representations.SO3.irrep_bases'

    @mariogeiger In your provided MNIST example, running S2CNN encounters the following problem with the lie_learn module.

    Traceback (most recent call last):
      File "<frozen importlib._bootstrap>", line 983, in _find_and_load
      File "<frozen importlib._bootstrap>", line 967, in _find_and_load_unlocked
      File "<frozen importlib._bootstrap>", line 677, in _load_unlocked
      File "<frozen importlib._bootstrap_external>", line 728, in exec_module
      File "<frozen importlib._bootstrap>", line 219, in _call_with_frames_removed
      File "C:\Users\peng\anaconda3\envs\pytorch_1.8\lib\site-packages\lie_learn\representations\SO3\wigner_d.py", line 5, in <module>
        from lie_learn.representations.SO3.irrep_bases import change_of_basis_matrix
    ModuleNotFoundError: No module named 'lie_learn.representations.SO3.irrep_bases'
    

    Originally posted by @EricPengShuai in https://github.com/jonkhler/s2cnn/issues/52#issuecomment-893144464

    opened by EricPengShuai 4
  • Error in so3_rotation (Jd matrix size) with custom data

    I'm trying to make the equivariance_plot script work on my own weather data of size (721, 1440).
    When it runs the so3_rotation function I have the following error:

    Traceback (most recent call last):
      File "/home/owen/anaconda3/envs/cupy-test/lib/python3.8/site-packages/IPython/core/interactiveshell.py", line 3441, in run_code
        exec(code_obj, self.user_global_ns, self.user_ns)
      File "<ipython-input-2-678a6fff5083>", line 1, in <module>
        runfile('/home/owen/PycharmProjects/cranfield/irp/ai.py', wdir='/home/owen/PycharmProjects/cranfield/irp')
      File "/home/owen/.local/share/JetBrains/Toolbox/apps/PyCharm-P/ch-0/211.7628.24/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
        pydev_imports.execfile(filename, global_vars, local_vars)  # execute the script
      File "/home/owen/.local/share/JetBrains/Toolbox/apps/PyCharm-P/ch-0/211.7628.24/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
        exec(compile(contents+"\n", file, 'exec'), glob, loc)
      File "/home/owen/PycharmProjects/cranfield/irp/ai.py", line 109, in <module>
        main(date, cycle)
      File "/home/owen/PycharmProjects/cranfield/irp/ai.py", line 80, in main
        y1 = phi(s2_rotation(x, *abc))
      File "/home/owen/PycharmProjects/cranfield/irp/ai.py", line 15, in s2_rotation
        x = so3_rotation(x.view(*x.size(), 1).expand(*x.size(), x.size(-1)), a, b, c)
      File "/home/owen/PycharmProjects/cranfield/irp/s2cnn/soft/so3_rotation.py", line 18, in so3_rotation
        Us = _setup_so3_rotation(b, alpha, beta, gamma, device_type=x.device.type, device_index=x.device.index)
      File "/home/owen/PycharmProjects/cranfield/irp/s2cnn/soft/so3_rotation.py", line 67, in _setup_so3_rotation
        Us = __setup_so3_rotation(b, alpha, beta, gamma)
      File "/home/owen/PycharmProjects/cranfield/irp/s2cnn/utils/decorator.py", line 97, in wrapper
        result = func(*args)
      File "/home/owen/PycharmProjects/cranfield/irp/s2cnn/soft/so3_rotation.py", line 55, in __setup_so3_rotation
        Us = [wigner_D_matrix(l, alpha, beta, gamma,
      File "/home/owen/PycharmProjects/cranfield/irp/s2cnn/soft/so3_rotation.py", line 55, in <listcomp>
        Us = [wigner_D_matrix(l, alpha, beta, gamma,
      File "/home/owen/anaconda3/envs/cupy-test/lib/python3.8/site-packages/lie_learn/representations/SO3/wigner_d.py", line 63, in wigner_D_matrix
        D = rot_mat(alpha=alpha, beta=beta, gamma=gamma, l=l, J=Jd[l])
    IndexError: index 151 is out of bounds for axis 0 with size 151
    

    After trying to debug it, I found that the issue is with the Jd matrix of size 151, defined in lie_learn, when calling line 55 in s2cnn/soft/so3_rotation.py:

    Us = [wigner_D_matrix(l, alpha, beta, gamma,
                              field='complex', normalization='quantum', order='centered', condon_shortley='cs')
              for l in range(b)]
    

    b is set as b = x.size()[-1] // 2, so 720 in my case.

    wigner_D_matrix calls the rot_mat function, which uses Jd[l]; but since Jd has size 151, it throws the error.

    Is there a way to redefine the Jd matrix or do I need to use a different function?

    opened by SkirOwen 0
  • Running MNIST Example Problems

    Hello,

    Thanks for the great work! I have some issues getting this code to run, starting with the example given in the repository. Basically, when I try to run python run.py in the mnist folder, I get a bunch of different errors that I feel I shouldn't have to fix for the code to work, since this is the basic example. The first error is this:

    Traceback (most recent call last):
      File "run.py", line 257, in <module>
        main(args.network)
      File "run.py", line 221, in main
        outputs = classifier(images)
      File "/home/kiran/anaconda3/envs/OCT_latent_space/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "run.py", line 89, in forward
        x = self.conv1(x)
      File "/home/kiran/anaconda3/envs/OCT_latent_space/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/kiran/Desktop/Dev/s2cnn/mnist/s2cnn/soft/s2_conv.py", line 40, in forward
        x = S2_fft_real.apply(x, self.b_out)  # [l * m, batch, feature_in, complex]
      File "/home/kiran/Desktop/Dev/s2cnn/mnist/s2cnn/soft/s2_fft.py", line 233, in forward
        return s2_fft(as_complex(x), b_out=ctx.b_out)
      File "/home/kiran/Desktop/Dev/s2cnn/mnist/s2cnn/soft/s2_fft.py", line 56, in s2_fft
        output[s] = torch.einsum("bm,zbmc->mzc", (wigner[:, s], xx))
      File "/home/kiran/anaconda3/envs/OCT_latent_space/lib/python3.8/site-packages/torch/functional.py", line 297, in einsum
        return einsum(equation, *_operands)
      File "/home/kiran/anaconda3/envs/OCT_latent_space/lib/python3.8/site-packages/torch/functional.py", line 299, in einsum
        return _VF.einsum(equation, operands)  # type: ignore[attr-defined]
    RuntimeError: expected scalar type Float but found ComplexFloat

    If I try to fix this by setting xx to a real tensor, then I get a bunch of errors further down the line. I was wondering if anyone had any advice on this.

    opened by kkokilep 3
  • Correlation Between Spheres

    Hi, thank you so much for your wonderful work, I really appreciate it! The tools are really easy to use. However, I did encounter some problems when trying to calculate the rotation R in SO(3) between two rotated spheres.

    Basically, I followed s2cnn/s2cnn/soft/s2_conv.py and replaced the torch kernel y with another sphere. Now Sphere1 and Sphere2 are passed into S2_fft_real.apply(), with the results Sphere1_FFT and Sphere2_FFT. The correlation result is then calculated by s2_mm(Sphere1, Sphere2) with slight modifications to channels and shape. The correlation is then passed to SO3_ifft_real.apply(), and the argmax of the result gives the ZYZ angles.

    I was wondering if this is the correct way of using the code to calculate rotations between two rotated spheres, because so far the result seems incorrect.

    Thanks in advance!!!

    opened by jessychen1016 4
  • Questions about the computations

    Hi,

    Thank you for the amazing work! Just to be sure: you use spherical harmonics to handle some of your computations, right? I'm searching for a way to compute the spherical spectrum with torch bindings, and I thought your library might already do this.

    opened by daidedou 2
Owner
Jonas Köhler
PhD student @noegroup - Research Scientist Intern @deepmind
MatryODShka: Real-time 6DoF Video View Synthesis using Multi-Sphere Images

Main repo for ECCV 2020 paper MatryODShka: Real-time 6DoF Video View Synthesis using Multi-Sphere Images. visual.cs.brown.edu/matryodshka

Brown University Visual Computing Group 75 Dec 13, 2022
Pytorch Implementations of large number classical backbone CNNs, data enhancement, torch loss, attention, visualization and some common algorithms.

Torch-template-for-deep-learning Pytorch implementations of some **classical backbone CNNs, data enhancement, torch loss, attention, visualization and

Li Shengyan 270 Dec 31, 2022
CNNs for Sentence Classification in PyTorch

Introduction This is the implementation of Kim's Convolutional Neural Networks for Sentence Classification paper in PyTorch. Kim's implementation of t

Shawn Ng 956 Dec 19, 2022
📦 PyTorch based visualization package for generating layer-wise explanations for CNNs.

Explainable CNNs 📦 Flexible visualization package for generating layer-wise explanations for CNNs. It is a common notion that a Deep Learning model i

Ashutosh Hathidara 183 Dec 15, 2022
Implementation of Lie Transformer, Equivariant Self-Attention, in Pytorch

Lie Transformer - Pytorch (wip) Implementation of Lie Transformer, Equivariant Self-Attention, in Pytorch. Only the SE3 version will be present in thi

Phil Wang 78 Oct 26, 2022
Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch.

SE3 Transformer - Pytorch Implementation of SE3-Transformers for Equivariant Self-Attention, in Pytorch. May be needed for replicating Alphafold2 resu

Phil Wang 207 Dec 23, 2022
EGNN - Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch

EGNN - Pytorch Implementation of E(n)-Equivariant Graph Neural Networks, in Pytorch. May be eventually used for Alphafold2 replication. This

Phil Wang 259 Jan 4, 2023
Study of human inductive biases in CNNs and Transformers.

Are Convolutional Neural Networks or Transformers more like human vision? This repository contains the code and fine-tuned models of popular Convoluti

Shikhar Tuli 39 Dec 8, 2022
A light weight data augmentation tool for training CNNs and Viola Jones detectors

hey-daug A light weight data augmentation tool for training CNNs and Viola Jones detectors (Haar Cascades). This tool inflates your data by up to six

Jaiyam Sharma 2 Nov 23, 2019
Spherical CNNs

Spherical CNNs Equivariant CNNs for the sphere and SO(3) implemented in PyTorch Overview This library contains a PyTorch implementation of the rotatio

Jonas Köhler 893 Dec 28, 2022
Training RNNs as Fast as CNNs

News SRU++, a new SRU variant, is released. [tech report] [blog] The experimental code and SRU++ implementation are available on the dev branch which

ASAPP Research 2.1k Jan 1, 2023
GAN-generated image detection based on CNNs

GAN-image-detection This repository contains a GAN-generated image detector developed to distinguish real images from synthetic ones. The detector is

Image and Sound Processing Lab 17 Dec 15, 2022
VOneNet: CNNs with a Primary Visual Cortex Front-End

VOneNet: CNNs with a Primary Visual Cortex Front-End A family of biologically-inspired Convolutional Neural Networks (CNNs). VOneNets have the followi

The DiCarlo Lab at MIT 99 Dec 22, 2022
An implementation of this paper: Relation extraction via Multi-Level attention CNNs

Relation Classification via Multi-Level Attention CNNs An implementation of this paper: Relation Classification via Multi-Level Attention CNNs. Training

Aybss 2 Nov 4, 2022
This repository contains the source code of our work on designing efficient CNNs for computer vision

Efficient networks for Computer Vision This repo contains source code of our work on designing efficient networks for different computer vision tasks:

Sachin Mehta 386 Nov 26, 2022
This repository provides the official implementation of 'Learning to ignore: rethinking attention in CNNs' accepted in BMVC 2021.

inverse_attention This repository provides the official implementation of 'Learning to ignore: rethinking attention in CNNs' accepted in BMVC 2021. Le

Firas Laakom 5 Jul 8, 2022
[CVPRW 2022] Attentions Help CNNs See Better: Attention-based Hybrid Image Quality Assessment Network

Attention Helps CNN See Better: Hybrid Image Quality Assessment Network [CVPRW 2022] Code for Hybrid Image Quality Assessment Network [paper] [code] T

IIGROUP 49 Dec 11, 2022
Authors implementation of LieTransformer: Equivariant Self-Attention for Lie Groups

LieTransformer This repository contains the implementation of the LieTransformer used for experiments in the paper LieTransformer: Equivariant self-at

null 35 Oct 18, 2022
Implementation of E(n)-Transformer, which extends the ideas of Welling's E(n)-Equivariant Graph Neural Network to attention

E(n)-Equivariant Transformer (wip) Implementation of E(n)-Equivariant Transformer, which extends the ideas from Welling's E(n)-Equivariant G

Phil Wang 132 Jan 2, 2023