Geometric Algebra package for JAX

Overview

JAXGA - JAX Geometric Algebra


GitHub | Docs

JAXGA is a Geometric Algebra package built on top of JAX. It can handle high-dimensional algebras because it stores only the non-zero basis blade coefficients. It makes use of JAX's just-in-time (JIT) compilation by first precomputing blade indices and signs and then JITting the function that performs the actual calculations.

Installation

Install using pip: pip install jaxga

Requirements:

  • Python 3
  • JAX

Usage

Unlike most other Geometric Algebra packages, JAXGA does not require pre-specifying an algebra. It can be used either through the MultiVector class or through lower-level functions; the latter are useful, for example, when working with JAX's jit or automatic differentiation.

The MultiVector class provides operator overloading and is constructed with an array of values and their corresponding basis blades. The basis blades are encoded as tuples, for example the multivector 2 e_1 + 4 e_23 would have the values [2, 4] and the basis blade tuple ((1,), (2, 3)).

MultiVector example

import jax.numpy as jnp
from jaxga.mv import MultiVector

a = MultiVector(
    values=2 * jnp.ones([1], dtype=jnp.float32),
    indices=((1,),)
)
# Alternative: 2 * MultiVector.e(1)

b = MultiVector(
    values=4 * jnp.ones([1], dtype=jnp.float32),
    indices=((2, 3),)
)
# Alternative: 4 * MultiVector.e(2, 3)

c = a * b
print(c)

Output: Multivector(8.0 e_{1, 2, 3})

The lower-level functions also operate on values and basis blades. For each operation, a function is provided that takes the blades and returns a function that performs the actual calculation. The returned function is JITted and can also be automatically differentiated with JAX. Furthermore, some operations, such as the geometric product, take a signature function that maps a basis vector index to the square of that basis vector.
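A signature function simply maps a basis vector index to the square of the corresponding basis vector. As a minimal sketch (spacetime_signature below is a made-up name for illustration, not part of jaxga; positive_signature is one of the predefined signatures):

from jaxga.signatures import positive_signature

# positive_signature assigns every basis vector a square of +1,
# i.e. a Euclidean metric.

# Hypothetical custom signature: basis vector 0 squares to -1,
# all others square to +1 (a Minkowski-like metric).
def spacetime_signature(index):
    return -1 if index == 0 else 1

Such a function can then be passed wherever an operation expects a signature.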

Lower-level function example

import jax.numpy as jnp
from jaxga.signatures import positive_signature
from jaxga.ops.multiply import get_mv_multiply

a_values = 2 * jnp.ones([1], dtype=jnp.float32)
a_indices = ((1,),)

b_values = 4 * jnp.ones([1], dtype=jnp.float32)
b_indices = ((2, 3),)

mv_multiply, c_indices = get_mv_multiply(a_indices, b_indices, positive_signature)
c_values = mv_multiply(a_values, b_values)
print("C indices:", c_indices, "C values:", c_values)

Output: C indices: ((1, 2, 3),) C values: [8.]
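Because the returned function is an ordinary JAX function, it can also be composed with jax.grad (or jax.jit and jax.vmap). A minimal sketch, continuing the example above with an arbitrary scalar loss:

import jax

def loss(values):
    # Sum of the product's coefficients; purely for illustration.
    return jnp.sum(mv_multiply(values, b_values))

print(jax.grad(loss)(a_values))  # gradient w.r.t. the coefficients of a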

Some notes

  • Both the MultiVector and lower-level function approaches support batches: the axes after the first one (which indexes the basis blades) are treated as batch axes (see the sketch after this list).
  • The MultiVector class can also take a signature in its constructor (the default squares every basis vector to 1). Operations between MultiVectors with different signatures are undefined.
  • The jaxga.signatures submodule contains a few predefined signature functions.
  • get_mv_multiply and similar functions cache their results by their inputs.
  • The flaxmodules submodule provides Flax (a popular neural network library for JAX) modules with Geometric Algebra operations.
  • Because JAXGA does not fix a specific algebra, the dual operation needs an argument that specifies the dimensionality of the space in which the dual element is taken.
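As a minimal sketch of the batching note above (the batch size 100 is arbitrary; the first axis indexes the basis blades, the second is a batch axis):

import jax.numpy as jnp
from jaxga.mv import MultiVector

# One basis blade (e_1), batch of 100 coefficients: shape [1, 100].
a = MultiVector(
    values=jnp.ones([1, 100], dtype=jnp.float32),
    indices=((1,),)
)

# One basis blade (e_23), same batch shape.
b = MultiVector(
    values=jnp.ones([1, 100], dtype=jnp.float32),
    indices=((2, 3),)
)

c = a * b  # the geometric product is computed independently per batch element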

Benchmarks

N-d vector * N-d vector, batch size 100, N=range(1, 10), CPU

JAXGA stores only the non-zero basis blade coefficients. tfga and clifford, on the other hand, store all elements as full multivectors, including all zeros. As a result, JAXGA performs better than these for high-dimensional algebras.

Below is a benchmark of the geometric product of two vectors with increasing dimensionality from 1 to 9. 100 vectors are multiplied at a time.

[Benchmark plots: JAXGA (CPU), tfga (CPU), clifford]

N-d vector * N-d vector, batch size 100, N=range(1, 50, 5), CPU

Below is a benchmark for higher dimensions that tfga and clifford could not handle. Note that the X axis is not sorted in natural order.

[Benchmark plot: JAXGA (CPU)]

Comments
  • Memory layout question

    Thanks for making this library; I wanted to do something similar a few months ago but other things got in the way; awesome to see that the JAX ecosystem is maturing this fast!

    One design question I had was whether to go with a struct-of-arrays or array-of-structs memory layout. Unless I've misread, your design is the latter; that is, if I vmap over a multivector, the components of the multivector will form the last axis, which in JAX is the contiguous stride==1 axis. If you think about how this would get vectorized on a GPU, that wouldn't be ideal: if every thread in a block works on a single element of the vmapped axis, which is the most straightforward parallelization, then the threads in a warp are not performing contiguous memory accesses. Hence struct-of-arrays is generally preferred on the GPU. Also, if you dig into the deepmind/alphafold repo, you will see that they also use a struct-of-arrays layout for their vector types and the like.

    Now this is all terribly premature optimization as far as the actual goals I'm trying to achieve, but I guess I'm trying to form a bit of a deeper understanding of JAX and TPUs on a low level. So with that in mind: was this a deliberate choice, or something that you have given any thought to?

    opened by EelcoHoogendoorn 18
  • use segment_sum for mv_multiply

    • Previously used a loop with per-index adds, which gets unrolled; now uses segment_sum to sum contributions that share the same output index
    • 10x faster JIT on CPU, 6x faster JIT on GPU
    • 100x slower runtime on CPU, 5x faster runtime on GPU

    Should maybe add a flag for whether to use this one; it's very useful for large algebras where JIT takes very long because of analyzing the unrolled loop. Maybe make it the default on GPU too.
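    (For reference, jax.ops.segment_sum sums entries that share a segment id; the sketch below only illustrates that idea and is not JAXGA's actual kernel.)

    import jax.numpy as jnp
    from jax.ops import segment_sum

    # Products of blade pairs (data) that land on the same output blade
    # (segment_ids) are summed into a single output coefficient.
    data = jnp.array([1.0, 2.0, 3.0, 4.0])
    segment_ids = jnp.array([0, 0, 1, 1])
    print(segment_sum(data, segment_ids, num_segments=2))  # [3. 7.]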

    CPU results show that the segment_sum runtime is very dependent on batch size:

    a_val, a_ind = jnp.array(jnp.ones([5, 10]), dtype=jnp.float32), tuple((i,) for i in range(5))
    b_val, b_ind = jnp.array(jnp.ones([5, 10]), dtype=jnp.float32), tuple((i,i+1) for i in range(5))
    
    new:
    Wall time: 94 ms
    10.8 µs ± 301 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    
    old:
    Wall time: 1.04 s
    10.9 µs ± 152 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    
    ---
    a_val, a_ind = jnp.array(jnp.ones([5, 100]), dtype=jnp.float32), tuple((i,) for i in range(5))
    b_val, b_ind = jnp.array(jnp.ones([5, 100]), dtype=jnp.float32), tuple((i,i+1) for i in range(5))
    
    new:
    Wall time: 227 ms
    48.6 µs ± 640 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    
    old:
    Wall time: 676 ms
    11.6 µs ± 464 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    ---
    a_val, a_ind = jnp.array(jnp.ones([10, 100]), dtype=jnp.float32), tuple((i,) for i in range(5))
    b_val, b_ind = jnp.array(jnp.ones([10, 100]), dtype=jnp.float32), tuple((i,i+1) for i in range(5))
    
    new:
    Wall time: 261 ms
    49.1 µs ± 1.49 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
    
    old:
    Wall time: 687 ms
    11.6 µs ± 196 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    ---
    a_val, a_ind = jnp.array(jnp.ones([10, 1000]), dtype=jnp.float32), tuple((i,) for i in range(5))
    b_val, b_ind = jnp.array(jnp.ones([10, 1000]), dtype=jnp.float32), tuple((i,i+1) for i in range(5))
    
    new:
    Wall time: 256 ms
    1.19 ms ± 69.3 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
    
    old:
    Wall time: 558 ms
    16.9 µs ± 234 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)
    
    opened by RobinKa 0