A port of muP to JAX/Haiku

Overview

MUP for Haiku

This is a (very preliminary) port of Yang and Hu et al.'s μP repo to Haiku and JAX. It's not feature complete, and I'm very open to suggestions on improving the usability.

Installation

pip install haiku-mup

Learning rate demo

These plots show the evolution of the optimal learning rate for a 3-hidden-layer MLP on MNIST, trained for 10 epochs (5 trials per lr/width combination).

With standard parameterization, the learning rate optimum (w.r.t. training loss) continues changing as the width increases, but μP keeps it approximately fixed:

Here's the same kind of plot for 3 layer transformers on the Penn Treebank, this time showing Validation loss instead of training loss, scaling both the number of heads and the embedding dimension simultaneously:

Note that the optima have the same value for n_embd=80. That's because the other hyperparameters were tuned using an SP model with that width, so this shouldn't be biased in favor of μP.

Usage

from functools import partial

import jax
import jax.numpy as jnp
import haiku as hk
from optax import adam, chain

from haiku_mup import apply_mup, Mup, Readout

class MyModel(hk.Module):
    def __init__(self, width, n_classes=10):
        super().__init__(name='model')
        self.width = width
        self.n_classes = n_classes

    def __call__(self, x):
        x = hk.Linear(self.width)(x)
        x = jax.nn.relu(x)
        return Readout(2)(x) # 1. Replace output layer with Readout layer

def fn(x, width=100):
    with apply_mup(): # 2. Modify parameter creation with apply_mup()
        return MyModel(width)(x)

mup = Mup()

init_input = jnp.zeros(123)
base_model = hk.transform(partial(fn, width=1))

with mup.init_base(): # 3. Use this context manager when initializing the base model
    hk.init(fn, jax.random.PRNGKey(0), init_input) 

model = hk.transform(fn)

with mup.init_target(): # 4. Use this context manager when initializng the target model
    params = model.init(jax.random.PRNGKey(0), init_input)

model = mup.wrap_model(model) # 5. Modify your model with Mup

optimizer = optax.adam(3e-4)
optimizer = mup.wrap_optimizer(optimizer, adam=True) # 6. Use wrap_optimizer to get layer specific learning rates

# Now the model can be trained as normal

Summary

  1. Replace output layers with Readout layers
  2. Modify parameter creation with the apply_mup() context manager
  3. Initialize a base model inside a Mup.init_base() context
  4. Initialize the target model inside a Mup.init_target() context
  5. Wrap the model with Mup.wrap_model
  6. Wrap optimizer with Mup.wrap_optimizer

Shared Input/Output embeddings

If you want to use the input embedding matrix as the output layer's weight matrix make the following two replacements:

# old: embedding_layer = hk.Embed(*args, **kwargs)
# new:
embedding_layer = haiku_mup.SharedEmbed(*args, **kwargs)
input_embeds = embedding_layer(x)

#old: output = hk.Linear(n_classes)(x)
# new:
output = haiku_mup.SharedReadout()(embedding_layer.get_weights(), x) 
You might also like...
A big endian Gentoo port developed on a Pine64.org RockPro64

Gentoo-aarch64_be A big endian Gentoo port developed on a Pine64.org RockPro64 The endian wars are over... little endian won. As a result, it is incre

FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX.

FedJAX: Federated learning with JAX What is FedJAX? FedJAX is a library for developing custom Federated Learning (FL) algorithms in JAX. FedJAX priori

Flax is a neural network ecosystem for JAX that is designed for flexibility.

Flax: A neural network library and ecosystem for JAX designed for flexibility Overview | Quick install | What does Flax look like? | Documentation See

Deep learning operations reinvented (for pytorch, tensorflow, jax and others)
Deep learning operations reinvented (for pytorch, tensorflow, jax and others)

This video in better quality. einops Flexible and powerful tensor operations for readable and reliable code. Supports numpy, pytorch, tensorflow, and

JAX-based neural network library

Haiku: Sonnet for JAX Overview | Why Haiku? | Quickstart | Installation | Examples | User manual | Documentation | Citing Haiku What is Haiku? Haiku i

Elegy is a framework-agnostic Trainer interface for the Jax ecosystem.

Elegy Elegy is a framework-agnostic Trainer interface for the Jax ecosystem. Main Features Easy-to-use: Elegy provides a Keras-like high-level API tha

Extending JAX with custom C++ and CUDA code

Extending JAX with custom C++ and CUDA code This repository is meant as a tutorial demonstrating the infrastructure required to provide custom ops in

Turning SymPy expressions into JAX functions

sympy2jax Turn SymPy expressions into parametrized, differentiable, vectorizable, JAX functions. All SymPy floats become trainable input parameters. S

Very deep VAEs in JAX/Flax
Very deep VAEs in JAX/Flax

Very Deep VAEs in JAX/Flax Implementation of the experiments in the paper Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on I

Owner
null
Pretrained models for Jax/Haiku; MobileNet, ResNet, VGG, Xception.

Pre-trained image classification models for Jax/Haiku Jax/Haiku Applications are deep learning models that are made available alongside pre-trained we

Alper Baris CELIK 14 Dec 20, 2022
GAN JAX - A toy project to generate images from GANs with JAX

GAN JAX - A toy project to generate images from GANs with JAX This project aims to bring the power of JAX, a Python framework developped by Google and

Valentin Goldité 14 Nov 29, 2022
Mini-hmc-jax - A simple implementation of Hamiltonian Monte Carlo in JAX

mini-hmc-jax This is a simple implementation of Hamiltonian Monte Carlo in JAX t

Martin Marek 6 Mar 3, 2022
CLOOB training (JAX) and inference (JAX and PyTorch)

cloob-training Pretrained models There are two pretrained CLOOB models in this repo at the moment, a 16 epoch and a 32 epoch ViT-B/16 checkpoint train

Katherine Crowson 64 Nov 27, 2022
TensorFlow ROCm port

Documentation TensorFlow is an end-to-end open source platform for machine learning. It has a comprehensive, flexible ecosystem of tools, libraries, a

ROCm Software Platform 622 Jan 9, 2023
hipCaffe: the HIP port of Caffe

Caffe Caffe is a deep learning framework made with expression, speed, and modularity in mind. It is developed by the Berkeley Vision and Learning Cent

ROCm Software Platform 126 Dec 5, 2022
Pytorch port of Google Research's LEAF Audio paper

leaf-audio-pytorch Pytorch port of Google Research's LEAF Audio paper published at ICLR 2021. This port is not completely finished, but the Leaf() fro

Dennis Fedorishin 80 Oct 31, 2022
A data-driven maritime port simulator

PySeidon - A Data-Driven Maritime Port Simulator ?? Extendable and modular software for maritime port simulation. This software uses entity-component

null 6 Apr 10, 2022
A PyTorch port of the Neural 3D Mesh Renderer

Neural 3D Mesh Renderer (CVPR 2018) This repo contains a PyTorch implementation of the paper Neural 3D Mesh Renderer by Hiroharu Kato, Yoshitaka Ushik

Daniilidis Group University of Pennsylvania 1k Jan 9, 2023
Tensorflow port of a full NetVLAD network

netvlad_tf The main intention of this repo is deployment of a full NetVLAD network, which was originally implemented in Matlab, in Python. We provide

Robotics and Perception Group 225 Nov 8, 2022