112 Repositories
Python JAX Libraries
Official PyTorch and JAX implementation of "Efficient-VDVAE: Less is more"
The official PyTorch and JAX implementation of "Efficient-VDVAE: Less is more" (arXiv preprint), by Louay Hazami · Rayhane Mama · Ragavan Thurairatn.
CLOOB training (JAX) and inference (JAX and PyTorch)
cloob-training Pretrained models: there are currently two pretrained CLOOB models in this repo, a 16 epoch and a 32 epoch ViT-B/16 checkpoint.
Second Order Optimization and Curvature Estimation with K-FAC in JAX.
KFAC-JAX: Second Order Optimization with Approximate Curvature in JAX. Installation | Quickstart | Documentation | Examples | Citing KFAC-JAX
A Python library that enables ML teams to share, load, and transform data in a collaborative, flexible, and efficient way 🌰
Squirrel Core: Squirrel is a Python library that enables ML teams to share, load, and transform data in a collaborative, flexible, and efficient way.
Scalable Optical Flow-based Image Montaging and Alignment
SOFIMA SOFIMA (Scalable Optical Flow-based Image Montaging and Alignment) is a tool for stitching, aligning and warping large 2d, 3d and 4d microscopy data.
KoCLIP: Korean port of OpenAI CLIP, in Flax
KoCLIP This repository contains code for KoCLIP, a Korean port of OpenAI's CLIP. This project was conducted as part of Hugging Face's Flax/JAX community week.
A port of muP to JAX/Haiku
MUP for Haiku This is a (very preliminary) port of Yang and Hu et al.'s μP repo to Haiku and JAX. It's not feature complete, and I'm very open to suggestions.
Pretrained models for Jax/Haiku; MobileNet, ResNet, VGG, Xception.
Pre-trained image classification models for Jax/Haiku Jax/Haiku Applications are deep learning models that are made available alongside pre-trained weights.
Repository for fine-tuning Transformers 🤗 based seq2seq speech models in JAX/Flax.
Seq2Seq Speech in JAX A JAX/Flax repository for combining a pre-trained speech encoder model (e.g. Wav2Vec2, HuBERT, WavLM) with a pre-trained text decoder model.
Code for "Continuous-Time Meta-Learning with Forward Mode Differentiation" (ICLR 2022)
Continuous-Time Meta-Learning with Forward Mode Differentiation ICLR 2022 (Spotlight) - Installation - Example - Citation This repository contains the code for the paper.
Mini-hmc-jax - A simple implementation of Hamiltonian Monte Carlo in JAX
mini-hmc-jax This is a simple implementation of Hamiltonian Monte Carlo in JAX.
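For context, a single HMC step in plain JAX looks roughly like the sketch below: a leapfrog integrator followed by a Metropolis accept/reject. This is a generic textbook step, not this repository's code; `log_prob` stands in for your target log-density.

```python
import jax
import jax.numpy as jnp

def hmc_step(key, log_prob, x, step_size=0.1, n_leapfrog=10):
    """One generic HMC step: leapfrog integration + Metropolis accept/reject."""
    key_mom, key_acc = jax.random.split(key)
    grad_fn = jax.grad(log_prob)

    p = jax.random.normal(key_mom, x.shape)           # sample momentum
    x_new, p_new = x, p
    p_new = p_new + 0.5 * step_size * grad_fn(x_new)  # initial half step in momentum
    for _ in range(n_leapfrog):
        x_new = x_new + step_size * p_new             # full step in position
        p_new = p_new + step_size * grad_fn(x_new)    # full step in momentum
    p_new = p_new - 0.5 * step_size * grad_fn(x_new)  # undo the extra half step

    # Metropolis correction using the Hamiltonian (negative log prob + kinetic energy).
    h_old = -log_prob(x) + 0.5 * jnp.sum(p ** 2)
    h_new = -log_prob(x_new) + 0.5 * jnp.sum(p_new ** 2)
    accept = jax.random.uniform(key_acc) < jnp.exp(h_old - h_new)
    return jnp.where(accept, x_new, x)
```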
A lossless neural compression framework built on top of JAX.
Kompressor: a lossless neural compression framework built on top of JAX.
Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable.
Diffrax Numerical differential equation solvers in JAX. Autodifferentiable and GPU-capable. Diffrax is a JAX-based library providing numerical differential equation solvers.
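A minimal solve looks like this (a sketch based on Diffrax's documented diffeqsolve API; check argument names against the current docs):

```python
import jax.numpy as jnp
import diffrax

# Solve dy/dt = -y from t=0 to t=3 with the Dopri5 solver.
def vector_field(t, y, args):
    return -y

term = diffrax.ODETerm(vector_field)
solution = diffrax.diffeqsolve(
    term, diffrax.Dopri5(),
    t0=0.0, t1=3.0, dt0=0.1, y0=jnp.array(1.0),
    saveat=diffrax.SaveAt(ts=jnp.linspace(0.0, 3.0, 31)),
)
print(solution.ys)  # approximately exp(-t) at the requested times
```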
Loopy belief propagation for factor graphs on discrete variables, in JAX!
PGMax implements general factor graphs for discrete probabilistic graphical models (PGMs), and hardware-accelerated differentiable loopy belief propagation (LBP) in JAX.
JAXMAPP: JAX-based Library for Multi-Agent Path Planning in Continuous Spaces
JAXMAPP: JAX-based Library for Multi-Agent Path Planning in Continuous Spaces JAXMAPP is a JAX-based library for multi-agent path planning (MAPP) in continuous spaces.
pybaum provides tools to work with pytrees, a concept borrowed from JAX.
The unified machine learning framework, enabling framework-agnostic functions, layers and libraries.
Rax is a Learning-to-Rank library written in JAX
🦖 Rax: Composable Learning to Rank using JAX Rax is a Learning-to-Rank library written in JAX. Rax provides off-the-shelf implementations of ranking losses and metrics.
100 JAX exercises structured as a course of tutorials for beginners, intermediate users, and experts
JaxTon 💯 JAX exercises. Mission 🚀: to provide 100 JAX exercises structured as a course of tutorials for beginners, intermediate users, and experts.
Jaxtorch (a jax nn library)
Jaxtorch (a jax nn library) This is my jax based nn library. I created this because I was annoyed by the complexity and 'magic'-ness of the popular JAX libraries.
Original Implementation of Prompt Tuning from Lester, et al, 2021
Prompt Tuning This is the code to reproduce the experiments from the EMNLP 2021 paper "The Power of Scale for Parameter-Efficient Prompt Tuning" (Lester et al., 2021).
Evolving neural network parameters in JAX.
Evolving Neural Networks in JAX This repository holds code displaying techniques for applying evolutionary network training strategies in JAX.
Unofficial JAX implementations of Deep Learning models
JAX Models: unofficial JAX implementations of deep learning models, with getting-started, installation, and usage instructions in the README.
Trax — Deep Learning with Clear Code and Speed
Trax — Deep Learning with Clear Code and Speed Trax is an end-to-end library for deep learning that focuses on clear code and speed. It is actively used and maintained in the Google Brain team.
Relaxed-machines - explorations in neuro-symbolic differentiable interpreters
Relaxed Machines Explorations in neuro-symbolic differentiable interpreters. Baby steps: inc_stop. Libraries: JAX, Haiku, Optax.
Advantage Actor Critic (A2C): jax + flax implementation
Advantage Actor Critic (A2C): jax + flax implementation The current version supports only environments with continuous action spaces and was tested on MuJoCo environments.
tree-math: mathematical operations for JAX pytrees
tree-math: mathematical operations for JAX pytrees tree-math makes it easy to implement numerical algorithms that work on JAX pytrees, such as iterative methods.
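tree-math wraps pytrees so that ordinary vector arithmetic works on them directly; for contrast, here is the tree_map boilerplate it replaces, written in plain JAX (this sketch deliberately avoids the library's own API):

```python
import jax
import jax.numpy as jnp

# A pytree of parameters and a matching pytree of gradients.
params = {"w": jnp.ones((3, 3)), "b": jnp.zeros(3)}
grads  = {"w": 0.1 * jnp.ones((3, 3)), "b": 0.2 * jnp.ones(3)}

# Without a vector abstraction, every "vector" operation is a tree_map:
updated = jax.tree_util.tree_map(lambda p, g: p - 0.01 * g, params, grads)

# Even an inner product needs a map followed by a reduction:
dot = sum(jax.tree_util.tree_leaves(
    jax.tree_util.tree_map(lambda a, b: jnp.vdot(a, b), params, grads)))
```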
A minimal TPU compatible Jax implementation of NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis
NeRF Minimal Jax implementation of NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis. Includes Tiny-NeRF results (RGB and depth renders).
PPO Lagrangian in JAX
PPO Lagrangian in JAX This repository implements PPO-Lagrangian in JAX. The implementation is tested on the safety-gym benchmark.
learned_optimization: Training and evaluating learned optimizers in JAX
learned_optimization: Training and evaluating learned optimizers in JAX. learned_optimization is a research codebase for training and evaluating learned optimizers.
Memory Efficient Attention (O(sqrt(n)) for Jax and PyTorch
Memory Efficient Attention This is an unofficial implementation of "Self-attention Does Not Need O(n^2) Memory" for Jax and PyTorch.
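A toy single-head sketch of the underlying idea: process keys/values in chunks with an online (running-max) softmax, so the full n×n score matrix is never materialised at once. This illustrates the technique in general, not the repository's implementation:

```python
import jax
import jax.numpy as jnp

def chunked_attention(q, k, v, chunk_size=128):
    # q: [n, d]; k, v: [m, d], with m a multiple of chunk_size for simplicity.
    n, d = q.shape
    k_chunks = k.reshape(-1, chunk_size, d)
    v_chunks = v.reshape(-1, chunk_size, d)

    def step(carry, kv):
        m_prev, l_prev, o_prev = carry            # running max, normaliser, output
        k_c, v_c = kv
        s = q @ k_c.T / jnp.sqrt(d)               # scores for this chunk: [n, chunk]
        m_new = jnp.maximum(m_prev, s.max(axis=-1))
        corr = jnp.exp(m_prev - m_new)            # rescale previous accumulators
        p = jnp.exp(s - m_new[:, None])
        l_new = l_prev * corr + p.sum(axis=-1)
        o_new = o_prev * corr[:, None] + p @ v_c
        return (m_new, l_new, o_new), None

    init = (jnp.full((n,), -jnp.inf), jnp.zeros((n,)), jnp.zeros((n, d)))
    (m, l, o), _ = jax.lax.scan(step, init, (k_chunks, v_chunks))
    return o / l[:, None]                         # equals softmax(qk^T/sqrt(d)) v
```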
Model parallel transformers in JAX and Haiku
Mesh Transformer JAX: model-parallel transformers in JAX and Haiku, including the pretrained GPT-J-6B model, with model details and zero-shot evaluations in the README.
🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow. 🤗 Transformers provides thousands of pretrained models.
Awesome Treasure of Transformers Models Collection
💁 Awesome Treasure of Transformers Models for Natural Language Processing: contains papers, videos, blogs, and official repos, along with Colab notebooks.
The versatile ocean simulator, in pure Python, powered by JAX.
Veros is the versatile ocean simulator -- it aims to be a powerful tool that makes high-performance ocean modeling approachable and fun. It is written in pure Python and powered by JAX.
Official code for "Maximum Likelihood Training of Score-Based Diffusion Models", NeurIPS 2021 (spotlight)
Maximum Likelihood Training of Score-Based Diffusion Models This repo contains the official implementation for the paper "Maximum Likelihood Training of Score-Based Diffusion Models" (NeurIPS 2021, spotlight).
A small library for creating and manipulating custom JAX Pytree classes
Treeo: a small library for creating and manipulating custom JAX Pytree classes. Light-weight: has no dependencies other than jax.
Jax/Flax implementation of Variational-DiffWave.
jax-variational-diffwave Jax/Flax implementation of Variational-DiffWave (Zhifeng Kong et al., 2020; Diederik P. Kingma et al., 2021).
A demo of how to use JAX to create a simple gravity simulation
JAX Gravity This repo contains a demo of how to use JAX to create a simple gravity simulation. It uses JAX's experimental ode package to solve the differential equations.
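In miniature, the approach looks like this (a hedged sketch using jax.experimental.ode.odeint rather than the repo's code): integrate Newtonian acceleration for a test particle orbiting a unit mass at the origin.

```python
import jax.numpy as jnp
from jax.experimental.ode import odeint

def dynamics(state, t, g=1.0):
    """state = [x, y, vx, vy] for a test particle around a unit mass at the origin."""
    pos, vel = state[:2], state[2:]
    r = jnp.linalg.norm(pos)
    acc = -g * pos / r**3                      # Newtonian gravity: a = -G m r_hat / r^2
    return jnp.concatenate([vel, acc])

y0 = jnp.array([1.0, 0.0, 0.0, 1.0])           # circular-orbit initial conditions
ts = jnp.linspace(0.0, 10.0, 200)
trajectory = odeint(dynamics, y0, ts)          # shape (200, 4)
```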
This is a JAX implementation of Neural Radiance Fields for learning purposes.
learn-nerf This is a JAX implementation of Neural Radiance Fields for learning purposes. I've been curious about NeRF and its follow-up work for a while.
JAX bindings to the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library
JAX bindings to FINUFFT This package provides a JAX interface to (a subset of) the Flatiron Institute Non-uniform Fast Fourier Transform (FINUFFT) library.
Reinforcement learning library in JAX.
A Python toolbox to create adversarial examples that fool neural networks in PyTorch, TensorFlow, and JAX
Foolbox Native: fast adversarial attacks to benchmark the robustness of machine learning models in PyTorch, TensorFlow, and JAX. Foolbox is a Python library to create adversarial examples that fool neural networks.
GAN JAX - A toy project to generate images from GANs with JAX
GAN JAX - A toy project to generate images from GANs with JAX This project aims to bring the power of JAX, a Python framework developed by Google, to GANs.
Flaxformer: transformer architectures in JAX/Flax
Flaxformer is a transformer library for primarily NLP and multimodal research at Google.
Einshape: DSL-based reshaping library for JAX and other frameworks.
Einshape: DSL-based reshaping library for JAX and other frameworks. The jnp.einsum op provides a DSL-based unified interface to matmul and tensordot operations; Einshape aims to offer a similarly unified, DSL-based interface for reshaping operations.
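For intuition, the kinds of grouping, ungrouping, and transposition such a spec language expresses look like this in plain jax.numpy (the spec strings in the comments are illustrative, not Einshape's exact syntax):

```python
import jax.numpy as jnp

x = jnp.arange(2 * 3 * 4).reshape(2, 3, 4)    # shape (2, 3, 4)

# "nhw->n(hw)": group the last two axes into one.
grouped = x.reshape(2, 3 * 4)                  # shape (2, 12)

# "n(hw)c->nhwc" with h=3: ungroup a flattened axis back into two.
y = jnp.arange(2 * 12 * 5).reshape(2, 12, 5)
ungrouped = y.reshape(2, 3, 4, 5)              # shape (2, 3, 4, 5)

# "nhwc->nchw": a transpose expressed by reordering the axis labels.
transposed = ungrouped.transpose(0, 3, 1, 2)   # shape (2, 5, 3, 4)
```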
Conservative Q Learning for Offline Reinforcement Learning in JAX
CQL-JAX This repository implements Conservative Q Learning for Offline Reinforcement Learning in JAX (FLAX).
A JAX-based research framework for writing differentiable numerical simulators with arbitrary discretizations
jaxdf - JAX-based Discretization Framework. ⚠️ This library is still in development; expect breaking changes.
v objective diffusion inference code for JAX.
v-diffusion-jax v objective diffusion inference code for JAX, by Katherine Crowson (@RiversHaveWings) and Chainbreakers AI (@jd_pressman).
Flaxformer: transformer architectures in JAX/Flax
Flaxformer: transformer architectures in JAX/Flax. Flaxformer is a transformer library for primarily NLP and multimodal research at Google.
Geometric Algebra package for JAX
JAXGA - JAX Geometric Algebra. JAXGA is a Geometric Algebra package on top of JAX. It can handle high dimensional algebras by storing only the non-zero basis blade coefficients.
Map single-cell transcriptomes to copy number evolutionary trees.
Map single-cell transcriptomes to copy number evolutionary trees. Check out the tutorial for more information. Installation: $ pip install scatrex
Bayes-Newton: a Gaussian process library in JAX, with a unifying view of approximate Bayesian inference as variants of Newton's algorithm.
Bayes-Newton Bayes-Newton is a library for approximate inference in Gaussian processes (GPs) in JAX (with objax).
Machine Learning with JAX Tutorials
The purpose of this repo is to make it easy to get started with JAX. It contains my "Machine Learning with JAX" series of tutorials (YouTube videos and Jupyter Notebooks) as well as the content I found useful while learning JAX.
A suite of benchmarks for CPU and GPU performance of the most popular high-performance libraries for Python 🚀
JAXDL: JAX (Flax) Deep Learning Library
JAXDL: JAX (Flax) Deep Learning Library. Simple and clean JAX/Flax deep learning algorithm implementations, e.g. Soft Actor-Critic (arXiv:1812.05905) and Transformers.
🤗 Transformers: State-of-the-art Natural Language Processing for Pytorch, TensorFlow, and JAX.
State-of-the-art Natural Language Processing for Jax, PyTorch and TensorFlow. 🤗 Transformers provides thousands of pretrained models.
An example showing how to use jax to train resnet50 on multi-node multi-GPU
jax-multi-gpu-resnet50-example This repo shows how to use jax for multi-node multi-GPU training. The example is adapted from an existing resnet50 example.
Use Jax functions in Pytorch with DLPack
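A minimal round-trip sketch using the standard DLPack utilities in both frameworks (this illustrates the general mechanism; the repo may wrap it differently):

```python
import jax
import jax.numpy as jnp
import torch
import torch.utils.dlpack

# JAX -> PyTorch: export a DLPack capsule and import it on the torch side.
x_jax = jnp.arange(4.0)
x_torch = torch.utils.dlpack.from_dlpack(jax.dlpack.to_dlpack(x_jax))

# PyTorch -> JAX: the reverse direction, zero-copy where the devices allow it.
y_torch = torch.ones(4)
y_jax = jax.dlpack.from_dlpack(torch.utils.dlpack.to_dlpack(y_torch))

# Gradients still have to be bridged explicitly, e.g. by wrapping the JAX
# function with jax.vjp inside a torch.autograd.Function.
```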
A JAX implementation of Broaden Your Views for Self-Supervised Video Learning, or BraVe for short.
BraVe This is a JAX implementation of Broaden Your Views for Self-Supervised Video Learning, or BraVe for short. A trained model is provided in this package.
Reimplementation of the paper "Attention, Learn to Solve Routing Problems!" in jax/flax.
JAX + Attention, Learn To Solve Routing Problems: reimplementation of the paper "Attention, Learn to Solve Routing Problems!" using Jax and Flax.
functorch is a prototype of JAX-like composable function transforms for PyTorch.
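For example, functorch's grad and vmap mirror their JAX namesakes (a small illustrative sketch; per-sample-gradient and other patterns compose the same way):

```python
import torch
from functorch import grad, vmap

def f(x):
    return torch.sin(x).sum()

x = torch.randn(5)
dfdx = grad(f)(x)                        # analogue of jax.grad: equals cos(x)

# Batched dot products via vmap, the analogue of jax.vmap.
a, b = torch.randn(3, 4), torch.randn(3, 4)
batched_dot = vmap(torch.dot)(a, b)      # shape (3,)
```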
🤗 Transformers: State-of-the-art Natural Language Processing for Pytorch, TensorFlow, and JAX.
State-of-the-art Natural Language Processing for Jax, PyTorch and TensorFlow. 🤗 Transformers provides thousands of pretrained models.
A machine learning library for spiking neural networks. Supports training with both torch and jax pipelines, and deployment to neuromorphic hardware.
Rockpool Rockpool is a Python package for developing signal processing applications with spiking neural networks. Rockpool allows you to build networks of spiking neurons.
A Pytree Module system for Deep Learning in JAX
Treex: a Pytree-based Module system for Deep Learning in JAX. Intuitive: modules are simple Python objects that respect object-oriented semantics.
Scenic: A Jax Library for Computer Vision and Beyond
Scenic Scenic is a codebase with a focus on research around attention-based models for computer vision.
Callable PyTrees and filtered JIT/grad transformations = neural networks in JAX.
Equinox: callable PyTrees and filtered JIT/grad transformations = neural networks in JAX. Equinox brings more power to your model building in JAX.
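A small sketch of what that means in practice, based on Equinox's documented Module and filter transforms (treat exact signatures as potentially out of date):

```python
import jax
import jax.numpy as jnp
import equinox as eqx

class TinyModel(eqx.Module):
    linear: eqx.nn.Linear

    def __init__(self, key):
        self.linear = eqx.nn.Linear(2, 1, key=key)

    def __call__(self, x):
        return jax.nn.sigmoid(self.linear(x))

model = TinyModel(jax.random.PRNGKey(0))   # the model itself is a PyTree

@eqx.filter_grad                           # differentiates w.r.t. the model's array leaves
def loss_grad(model, x, y):
    pred = jax.vmap(model)(x)
    return jnp.mean((pred[:, 0] - y) ** 2)

grads = loss_grad(model, jnp.ones((8, 2)), jnp.zeros(8))
```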
NeuralCompression is a Python repository dedicated to research of neural networks that compress data
NeuralCompression is a Python repository dedicated to research of neural networks that compress data. The repository includes tools such as JAX-based entropy coders, image compression models, video compression models, and metrics for image and video evaluation.
🤗 Transformers: State-of-the-art Natural Language Processing for Pytorch, TensorFlow, and JAX.
State-of-the-art Natural Language Processing for Jax, PyTorch and TensorFlow. 🤗 Transformers provides thousands of pretrained models.
Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax
Clockwork VAEs in JAX/Flax Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax, ported from the original implementation.
RoBERTa Marathi Language model trained from scratch during huggingface 🤗 x flax community week
RoBERTa base model for Marathi Language (मराठी भाषा): pretrained on Marathi using a masked language modeling (MLM) objective.
PIX is an image processing library in JAX, for JAX.
PIX PIX is an image processing library in JAX, for JAX. Overview: JAX is a library resulting from the union of Autograd and XLA for high-performance machine learning research.
Hardware accelerated, batchable and differentiable optimizers in JAX.
JAXopt: hardware accelerated (GPU/TPU), batchable and differentiable optimizers in JAX. JAXopt can be installed with pip install jaxopt.
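For orientation, a minimal use of the solver interface (a sketch based on JAXopt's documented GradientDescent API; extra arguments to run are forwarded to the objective):

```python
import jax.numpy as jnp
import jaxopt

# Minimise a simple least-squares objective; any differentiable JAX function works.
def loss(w, X, y):
    return jnp.mean((X @ w - y) ** 2)

X = jnp.ones((10, 3))
y = jnp.zeros(10)

solver = jaxopt.GradientDescent(fun=loss, maxiter=100)
params, state = solver.run(jnp.zeros(3), X=X, y=y)   # returns (params, solver state)
```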
Brax is a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators
Brax is a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators. It's also a suite of learning algorithms to train agents to operate in these environments (PPO, SAC, evolutionary strategy, and direct trajectory optimization are implemented).
Shared code for training sentence embeddings with Flax / JAX
flax-sentence-embeddings This repository will be used to share code for the Flax / JAX community event to train sentence embeddings on 1B+ training pairs.
Implementation and replication of ProGen, Language Modeling for Protein Generation, in Jax
ProGen - (wip) Implementation and replication of ProGen, Language Modeling for Protein Generation, in Pytorch and Jax.
Implementation of FitVid video prediction model in JAX/Flax.
FitVid Video Prediction Model Implementation of the FitVid video prediction model in JAX/Flax. If you find this code useful, please cite it in your paper.
Brax is a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators
Brax is a differentiable physics engine that simulates environments made up of rigid bodies, joints, and actuators. It's also a suite of learning algorithms to train agents to operate in these environments (PPO, SAC, evolutionary strategy, and direct trajectory optimization are implemented).
Python code for "Machine learning: a probabilistic perspective" (2nd edition)
ML Optimizers from scratch using JAX
Toy implementations of some popular ML optimizers using Python/JAX
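The shared skeleton of such from-scratch optimizers is just jax.grad plus a pytree update; a generic SGD sketch (not the repo's code):

```python
import jax
import jax.numpy as jnp

def loss(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y, lr=0.1):
    grads = jax.grad(loss)(params, x, y)
    # SGD: move every parameter leaf against its gradient.
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = {"w": jnp.zeros((3,)), "b": jnp.zeros(())}
x, y = jnp.ones((8, 3)), jnp.ones(8)
for _ in range(100):
    params = sgd_step(params, x, y)
```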
Aggregating Nested Transformers Official Jax Implementation
NesT is a simple method that aggregates nested local transformers on image blocks. The idea makes vision transformers attain better accuracy, data efficiency, and convergence on the ImageNet benchmark. NesT can be scaled to small datasets to match convnet accuracy.
JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"
Optimal Model Design for Reinforcement Learning This repository contains JAX code for the paper Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation.
JAX + dataclasses
jax_dataclasses jax_dataclasses provides a wrapper around dataclasses.dataclass for use in JAX, which enables automatic support for Pytree registration.
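For background, "Pytree registration" in plain JAX looks roughly like the following; the library's decorator automates this bookkeeping for annotated dataclasses (a sketch of the underlying mechanism, not jax_dataclasses' own API):

```python
import dataclasses
import jax
import jax.numpy as jnp

@jax.tree_util.register_pytree_node_class
@dataclasses.dataclass
class Params:
    weight: jnp.ndarray
    bias: jnp.ndarray

    # Tell JAX how to flatten/unflatten instances so they can pass
    # through jit, grad, vmap, etc. like any other pytree.
    def tree_flatten(self):
        return (self.weight, self.bias), None

    @classmethod
    def tree_unflatten(cls, aux, children):
        return cls(*children)

p = Params(jnp.ones((2, 2)), jnp.zeros(2))
doubled = jax.tree_util.tree_map(lambda x: 2 * x, p)
```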
Symbolic Parallel Adaptive Importance Sampling for Probabilistic Program Analysis in JAX
SYMPAIS: Symbolic Parallel Adaptive Importance Sampling for Probabilistic Program Analysis. Overview | Installation | Documentation | Examples | Notebooks
jaxfg - Factor graph-based nonlinear optimization library for JAX.
Factor graphs + nonlinear optimization in JAX
Local Attention - Flax module for Jax
Local Attention - Flax: autoregressive local attention as a Flax module for Jax. Install: $ pip install local-attention-flax
A GPT, made only of MLPs, in Jax
MLP GPT - Jax (wip) A GPT made only of MLPs, in Jax. The specific MLPs used are gMLPs with Spatial Gating Units. A working Pytorch implementation is available separately.
JMP is a Mixed Precision library for JAX.
Mixed precision training [0] is a technique that mixes the use of full and half precision floating point numbers during training to reduce the memory bandwidth requirements and improve the computational efficiency of a given model.
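Conceptually, the policy JMP encapsulates looks like this in plain JAX: keep master parameters in float32, cast to half precision for the compute-heavy parts, and keep reductions and updates in full precision (a sketch of the idea, not JMP's API):

```python
import jax
import jax.numpy as jnp

params = {"w": jnp.ones((512, 512), jnp.float32)}   # master weights in full precision

def loss(params, x):
    # Cast parameters and inputs to half precision for the compute-heavy matmul.
    half = jax.tree_util.tree_map(lambda p: p.astype(jnp.bfloat16), params)
    y = x.astype(jnp.bfloat16) @ half["w"]
    # Keep the reduction/output in float32 for numerical safety.
    return jnp.mean(y.astype(jnp.float32) ** 2)

grads = jax.grad(loss)(params, jnp.ones((64, 512)))
new_params = jax.tree_util.tree_map(lambda p, g: p - 1e-3 * g, params, grads)
```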
Contains code for the paper "Vision Transformers are Robust Learners".
Vision Transformers are Robust Learners This repository contains the code for the paper Vision Transformers are Robust Learners by Sayak Paul* and Pin-Yu Chen.
Standalone pre-training recipe with JAX+Flax
Sabertooth Sabertooth is a standalone pre-training recipe based on JAX+Flax, with data pipelines implemented in Rust. It runs on CPU, GPU, and/or TPU.
Pretrained models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet.
Bayesian optimization in JAX
Plug-n-Play Reinforcement Learning in Python with OpenAI Gym and JAX
coax is built on top of JAX, but it doesn't declare an explicit dependency on the jax Python package, because the appropriate jaxlib build depends on your CUDA version.
Newt - a Gaussian process library in JAX.
Functional tensors for probabilistic programming
Funsor Funsor is a tensor-like library for functions and distributions. See Functional tensors for probabilistic programming for a system description.
Objax: a machine learning framework that provides an object-oriented layer on top of JAX.
Objax Tutorials | Install | Documentation | Philosophy. This is not an officially supported Google product. Objax is an open source machine learning framework.
Model parallel transformers in Jax and Haiku
Mesh Transformer Jax A haiku library using the new(ly documented) xmap operator in Jax for model parallelism of transformers. See enwik8_example.py for an example.
Code for "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations"
Infinitely Deep Bayesian Neural Networks with SDEs This library contains JAX and Pytorch implementations of neural ODEs and Bayesian layers for stochastic variational inference.
Very deep VAEs in JAX/Flax
Very Deep VAEs in JAX/Flax Implementation of the experiments in the paper Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on Images.
Turning SymPy expressions into JAX functions
sympy2jax Turn SymPy expressions into parametrized, differentiable, vectorizable JAX functions. All SymPy floats become trainable input parameters.
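Usage is roughly as follows; the call signature here is recalled from the project's README and should be treated as an assumption that may have changed:

```python
import jax.numpy as jnp
import sympy
from sympy2jax import sympy2jax

x, y = sympy.symbols("x y")
expression = 1.0 * sympy.cos(x) + 3.2 * y

# Returns a JAX function plus the extracted float constants (1.0 and 3.2),
# which become trainable parameters.
f, params = sympy2jax(expression, [x, y])

X = jnp.ones((10, 2))   # assumption: columns correspond to the symbols x and y
out = f(X, params)      # shape (10,)
```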