137 Repositories
Python jax-ppo Libraries
Python code for "Machine learning: a probabilistic perspective" (2nd edition)
ML Optimizers from scratch using JAX
Toy implementations of some popular ML optimizers using Python/JAX
Aggregating Nested Transformer Official Jax Implementation
NesT is a simple method that aggregates nested local transformers on image blocks. The idea lets vision transformers attain better accuracy, data efficiency, and convergence on the ImageNet benchmark. NesT can be scaled to small datasets to match convnet accuracy.
JAX code for the paper "Control-Oriented Model-Based Reinforcement Learning with Implicit Differentiation"
Optimal Model Design for Reinforcement Learning This repository contains JAX code for the paper Control-Oriented Model-Based Reinforcement Learning wi
JAX + dataclasses
jax_dataclasses jax_dataclasses provides a wrapper around dataclasses.dataclass for use in JAX, which enables automatic support for: Pytree registrati
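To illustrate what pytree registration buys you, here is a minimal conceptual sketch in plain Python. It is not the jax_dataclasses API: `Params`, `flatten`, `unflatten`, and `tree_map` are hypothetical names invented for this example. The idea is that a dataclass is split into a flat list of leaf values plus enough structure to rebuild it, so a transformation can operate on the leaves alone.

```python
from dataclasses import dataclass, fields


@dataclass(frozen=True)
class Params:
    # Hypothetical container used only for this illustration.
    weight: float
    bias: float


def flatten(obj):
    """Split a dataclass instance into (leaf values, class)."""
    leaves = [getattr(obj, f.name) for f in fields(obj)]
    return leaves, type(obj)


def unflatten(cls, leaves):
    """Rebuild a dataclass instance from its class and leaf values."""
    names = [f.name for f in fields(cls)]
    return cls(**dict(zip(names, leaves)))


def tree_map(fn, obj):
    """Apply fn to every leaf and return a new instance of the same type."""
    leaves, cls = flatten(obj)
    return unflatten(cls, [fn(leaf) for leaf in leaves])


p = Params(weight=2.0, bias=-1.0)
doubled = tree_map(lambda v: 2 * v, p)
```

A registered pytree lets a library perform this flatten/transform/rebuild cycle automatically, which is what allows structured containers to pass through function transformations.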
Symbolic Parallel Adaptive Importance Sampling for Probabilistic Program Analysis in JAX
SYMPAIS: Symbolic Parallel Adaptive Importance Sampling for Probabilistic Program Analysis
jaxfg - Factor graph-based nonlinear optimization library for JAX.
Factor graphs + nonlinear optimization in JAX
Local Attention - Flax module for Jax
Local Attention - Flax Autoregressive Local Attention - Flax module for Jax Install $ pip install local-attention-flax Usage from jax import random fr
A GPT, made only of MLPs, in Jax
MLP GPT - Jax (wip) A GPT, made only of MLPs, in Jax. The specific MLPs to be used are gMLPs with the Spatial Gating Units. Working Pytorch implementat
JMP is a Mixed Precision library for JAX.
Mixed precision training [0] is a technique that mixes the use of full and half precision floating point numbers during training to reduce the memory bandwidth requirements and improve the computational efficiency of a given model.
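A small NumPy sketch of why mixed precision needs care with half-precision values. It shows a tiny gradient underflowing to zero in float16 and being preserved by loss scaling, a common companion technique that the description above does not mention. This is not JMP's API; the gradient value and scale factor are assumptions picked for illustration.

```python
import numpy as np

# A small gradient value, held in full precision (illustrative value).
true_grad = np.float32(1e-8)

# Casting directly to half precision underflows: the value is below
# the smallest positive float16 number, so it rounds to zero.
naive = np.float16(true_grad)

# Loss scaling: multiply by a constant before the cast so the value
# lands in float16's representable range, then divide in float32.
scale = np.float32(1024.0)  # assumed scale factor
scaled = np.float16(true_grad * scale)
recovered = np.float32(scaled) / scale

print(naive, recovered)
```

The recovered value is only approximately equal to the original, since float16 keeps far fewer significant bits than float32.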
Contains code for the paper "Vision Transformers are Robust Learners".
Vision Transformers are Robust Learners This repository contains the code for the paper Vision Transformers are Robust Learners by Sayak Paul* and Pin
Standalone pre-training recipe with JAX+Flax
Sabertooth Sabertooth is a standalone pre-training recipe based on JAX+Flax, with data pipelines implemented in Rust. It runs on CPU, GPU, and/or TPU, b
Pretrained models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet.
Bayesian optimization in JAX
A collection of various RL algorithms such as policy gradients, DQN and PPO. The goal of this repo is to become a go-to resource for learning about RL: how to visualize, debug and solve RL problems. It also includes playground.py for learning more about OpenAI gym, etc.
Reinforcement Learning (PyTorch) 🤖 + 🍰 = ❤️ This repo will contain PyTorch implementations of various fundamental RL algorithms. It's aimed at making
Plug-n-Play Reinforcement Learning in Python with OpenAI Gym and JAX
coax is built on top of JAX, but it doesn't have an explicit dependence on the jax python package. The reason is that your version of jaxlib will depend on your CUDA version.
Newt - a Gaussian process library in JAX.
Newt
Functional tensors for probabilistic programming
Funsor Funsor is a tensor-like library for functions and distributions. See Functional tensors for probabilistic programming for a system description.
Objax (19 · ⭐ 580, Apache-2) - Objax is a machine learning framework that provides an Object..
Objax This is not an officially supported Google product. Objax is an open source machine learning fr
Model parallel transformers in Jax and Haiku
Mesh Transformer Jax A haiku library using the new(ly documented) xmap operator in Jax for model parallelism of transformers. See enwik8_example.py fo
Code for "Infinitely Deep Bayesian Neural Networks with Stochastic Differential Equations"
Infinitely Deep Bayesian Neural Networks with SDEs This library contains JAX and PyTorch implementations of neural ODEs and Bayesian layers for stocha
Very deep VAEs in JAX/Flax
Very Deep VAEs in JAX/Flax Implementation of the experiments in the paper Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on I
This is the official implementation of Multi-Agent PPO.
MAPPO Chao Yu*, Akash Velu*, Eugene Vinitsky, Yu Wang, Alexandre Bayen, and Yi Wu. Website: https://sites.google.com/view/mappo This repository implem
Turning SymPy expressions into JAX functions
sympy2jax Turn SymPy expressions into parametrized, differentiable, vectorizable, JAX functions. All SymPy floats become trainable input parameters. S
3D Vision functions with end-to-end support for deep learning developers, written in Ivy.
Ivy vision focuses predominantly on 3D vision, with functions for camera geometry, image projections, co-ordinate frame transformations, forward warping, inverse warping, optical flow, depth triangulation, voxel grids, point clouds, signed distance functions, and others. Check out the docs for more info!
Ivy is a templated deep learning framework which maximizes the portability of deep learning codebases.
Ivy is a templated deep learning framework which maximizes the portability of deep learning codebases. Ivy wraps the functional APIs of existing frameworks. Framework-agnostic functions, libraries and layers can then be written using Ivy, with simultaneous support for all frameworks. Ivy currently supports Jax, TensorFlow, PyTorch, MXNet and NumPy. Check out the docs for more info!
Extending JAX with custom C++ and CUDA code
Extending JAX with custom C++ and CUDA code This repository is meant as a tutorial demonstrating the infrastructure required to provide custom ops in
Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.
Elegy Elegy is a framework-agnostic Trainer interface for the Jax ecosystem. Main Features Easy-to-use: Elegy provides a Keras-like high-level API tha
JAX-based neural network library
Haiku: Sonnet for JAX What is Haiku? Haiku i
Fast and Easy Infinite Neural Networks in Python
Neural Tangents Overview Neural Tangents is a high-level neural
Deep learning operations reinvented (for pytorch, tensorflow, jax and others)
einops Flexible and powerful tensor operations for readable and reliable code. Supports numpy, pytorch, tensorflow, and
Flax is a neural network ecosystem for JAX that is designed for flexibility.
Flax: A neural network library and ecosystem for JAX designed for flexibility See
🔮 A refreshing functional take on deep learning, compatible with your favorite libraries
Thinc: A refreshing functional take on deep learning, compatible with your favorite libraries From the makers of spaCy, Prodigy and FastAPI Thinc is a
Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
JAX: Autograd and XLA News: JAX tops
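The three transformations named in the JAX description (differentiate, vectorize, JIT) compose directly. Below is a minimal sketch using `jax.grad`, `jax.vmap`, and `jax.jit`. The `loss` function and the input values are made up for illustration.

```python
import jax
import jax.numpy as jnp


def loss(w, x):
    # Scalar function of a weight vector w and a single input vector x.
    return jnp.sum((w * x) ** 2)


# Differentiate with respect to the first argument (w).
grad_loss = jax.grad(loss)

# Vectorize over a batch of inputs x while sharing the same w.
batched_grad = jax.vmap(grad_loss, in_axes=(None, 0))

# Compile the composed function with XLA.
fast_batched_grad = jax.jit(batched_grad)

w = jnp.array([1.0, 2.0])
xs = jnp.array([[1.0, 1.0], [2.0, 0.5]])
grads = fast_batched_grad(w, xs)  # one gradient row per input in xs
```

Each transformation takes a function and returns a function, which is why they can be stacked in any order that makes sense for the computation.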
FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX.
FedJAX: Federated learning with JAX What is FedJAX? FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX. FedJAX priori
Modular Deep Reinforcement Learning framework in PyTorch. Companion library of the book "Foundations of Deep Reinforcement Learning".
SLM Lab Modular Deep Reinforcement Learning framework in PyTorch. Documentation: https://slm-lab.gitbook.io/slm-lab/