learned_optimization: Training and evaluating learned optimizers in JAX

Overview

learned_optimization is a research codebase for training learned optimizers. It implements hand-designed and learned optimizers, tasks to meta-train and meta-test them on, and outer-training algorithms such as evolution strategies (ES) and persistent evolution strategies (PES).
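
A typical workflow is to pick a task, initialize its parameters, and apply an optimizer to its loss. Below is a rough, hypothetical sketch of meta-testing a task with plain SGD; the task class is the one referenced in the comments further down, but the Task.init / Task.loss signatures and the task.datasets.train iterator are assumptions about the task interface, so treat the Introduction notebook as the authoritative reference.

import jax
from learned_optimization.tasks import fixed_mlp

task = fixed_mlp.FashionMnistRelu32_8()   # small MLP task on Fashion-MNIST
key = jax.random.PRNGKey(0)
params = task.init(key)                   # initialize the task's parameters

for step in range(10):
  key, data_key = jax.random.split(key)
  batch = next(task.datasets.train)       # one batch of training data
  loss, grads = jax.value_and_grad(task.loss)(params, data_key, batch)
  # plain SGD step; the library also provides hand-designed and learned optimizers
  params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
  print(step, loss)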

Quick Start Colab Notebooks

  • Introduction notebook: Open In Colab
  • Creating custom tasks: Open In Colab

The fastest way to get started is to copy the Introduction notebook and experiment using a free accelerator in Colab (go to Runtime -> Change runtime type to select a TPU or GPU backend).
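
The "Creating custom tasks" notebook above walks through defining new tasks. As a rough, hypothetical sketch of what a minimal Task subclass might look like (the base module path and the init/loss method names are assumptions; the notebook is the authoritative reference):

import jax
import jax.numpy as jnp
from learned_optimization.tasks import base as tasks_base

class ToyQuadratic(tasks_base.Task):
  """Hypothetical custom task: minimize the squared norm of a 10-d vector."""

  def init(self, key):
    return jax.random.normal(key, (10,))

  def loss(self, params, key, data):
    # no data is needed for this synthetic task
    return jnp.sum(params ** 2)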

Local Installation

We strongly recommend using virtualenv to work with this package.

pip3 install virtualenv
git clone git@github.com:google/learned_optimization.git
cd learned_optimization
python3 -m venv env
source env/bin/activate
pip install -e .

Then run the tests to make sure everything is functioning properly.

python3 -m nose

If something is broken please file an issue and we will take a look!

Disclaimer

learned_optimization is not an official Google product.

You might also like...
A machine learning library for spiking neural networks. Supports training with both torch and jax pipelines, and deployment to neuromorphic hardware.

Rockpool Rockpool is a Python package for developing signal processing applications with spiking neural networks. Rockpool allows you to build network

Standalone pre-training recipe with JAX+Flax

Sabertooth Sabertooth is a standalone pre-training recipe based on JAX+Flax, with data pipelines implemented in Rust. It runs on CPU, GPU, and/or TPU, b

A PyTorch implementation of "TokenLearner: What Can 8 Learned Tokens Do for Images and Videos?"

TokenLearner: What Can 8 Learned Tokens Do for Images and Videos? Source: Improving Vision Transformer Efficiency and Accuracy by Learning to Tokenize

Object detection on multiple datasets with an automatically learned unified label space.

Simple multi-dataset detection An object detector trained on multiple large-scale datasets with a unified label space; Winning solution of E

Code for Mesh Convolution Using a Learned Kernel Basis

Mesh Convolution This repository contains the implementation (in PyTorch) of the paper FULLY CONVOLUTIONAL MESH AUTOENCODER USING EFFICIENT SPATIALLY

Official implementation of "Accelerating Reinforcement Learning with Learned Skill Priors", Pertsch et al., CoRL 2020

Accelerating Reinforcement Learning with Learned Skill Priors [Project Website] [Paper] Karl Pertsch1, Youngwoon Lee1, Joseph Lim1 1CLVR Lab, Universi

Self-Learned Video Rain Streak Removal: When Cyclic Consistency Meets Temporal Correspondence

In this paper, we address the problem of rain streaks removal in video by developing a self-learned rain streak removal method, which does not require any clean groundtruth images in the training process.

Learned image compression

Overview Pytorch code of our recent work A Unified End-to-End Framework for Efficient Deep Image Compression. We first release the code for Variationa

Learned Token Pruning for Transformers

LTP: Learned Token Pruning for Transformers Check our paper for more details. Installation We follow the same installation procedure as the original H

Comments
  • Very Large Memory Consumption for Even A Small Dataset

    Dataset: fashion_mnist Dataset Size: 36.42MB (https://www.tensorflow.org/datasets/catalog/fashion_mnist)

    Reproduce the Issue:

    from learned_optimization.tasks import fixed_mlp
    task = fixed_mlp.FashionMnistRelu32_8()
    

    or

    from learned_optimization.tasks.datasets import base
    
    batch_size=128
    image_size=(8, 8)
    splits = ("train[0:80%]", "train[80%:90%]", "train[90%:]", "test")
    stack_channels = 1
    
    dataset = base.preload_tfds_image_classification_datasets(
          "fashion_mnist",
          splits,
          batch_size=batch_size,
          image_size=image_size,
          stack_channels=stack_channels)
    

    Issue Description: As you can see, the original Fashion-MNIST dataset is very small. However, when I run the above code, the memory usage becomes extremely high (10 GB or more).

    In my case, the issue occurs when the program reaches the following line, which is in the function preload_tfds_image_classification_datasets:

      return Datasets(
          *[make_python_iter(split) for split in splits],
          extra_info={"num_classes": num_classes})
    

    Here is the code of make_python_iter:

      def make_python_iter(split: str) -> Iterator[Batch]:
        # load the entire dataset into memory
        dataset = tfds.load(datasetname, split=split, batch_size=-1)
        data = tfds.as_numpy(_image_map_fn(cfg, dataset))
    
        # use a python iterator as this is faster than TFDS.
        def generator_fn():
    
          def iter_fn():
            batches = data["image"].shape[0] // batch_size
            idx = onp.arange(data["image"].shape[0])
            while True:
              # every epoch shuffle indicies
              onp.random.shuffle(idx)
              for bi in range(0, batches):
                idxs = idx[bi * batch_size:(bi + 1) * batch_size]
    
                def index_into(idxs, x):
                  return x[idxs]
    
                yield jax.tree_map(functools.partial(index_into, idxs), data)
    
          return prefetch_iterator.PrefetchIterator(iter_fn(), prefetch_batches)
    
        return ThreadSafeIterator(LazyIterator(generator_fn))
    

    Could you please suggest a way to reduce the huge memory usage? Do you have any idea why it needs so much memory, and has anyone else run into this issue?
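
    An untested workaround could be to preload only a slice of each split while debugging, since make_python_iter materializes whole splits in memory; the split strings use the same TFDS slicing syntax as above:

    from learned_optimization.tasks.datasets import base

    # untested sketch: load only ~10% of the data to bound the resident footprint
    splits = ("train[0%:8%]", "train[8%:9%]", "train[9%:10%]", "test[0%:10%]")
    dataset = base.preload_tfds_image_classification_datasets(
        "fashion_mnist",
        splits,
        batch_size=128,
        image_size=(8, 8),
        stack_channels=1)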

    Thank you very much and looking forward to your comments.

    opened by createmomo 5
  • Monkey-patch `os.path` to support `gs://` URIs

    This monkey-patches the private os.path._get_sep function used by os.path.join to handle paths that start with "gs://". If a path starts with "gs://", we ignore the OS separator and return "/" as the separator to use.

    Background info

    For example, in the pretrained_optimizers.py file, os.path.join is used to concatenate strings for the gs:// URIs (one example). This causes issues on Windows machines because there, path components are joined with "\", the Windows separator.

    Since I understand monkey-patching is always quite fickle, I can also replace all os.path.join usages with a new method that uses '/'.join(paths) for "gs://" paths instead, along the lines of the sketch below.
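
    A rough sketch of that non-patching alternative (the helper name is illustrative, not an existing function in the repository):

    import os

    def join_path(first, *rest):
      # hypothetical helper: build gs:// URIs with "/" regardless of OS,
      # and fall back to os.path.join for ordinary filesystem paths
      if first.startswith("gs://"):
        return "/".join([first.rstrip("/")] + [p.strip("/") for p in rest])
      return os.path.join(first, *rest)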

    opened by janEbert 0
  • License of checkpoints

    What is the license of the checkpoints listed in https://github.com/google/learned_optimization/blob/main/learned_optimization/research/general_lopt/pretrained_optimizers.py ? Is it an Apache 2 license, too?

    opened by maciejjaskowski 0
  • Colab link not working

    The link in the following section doesn't work:

    Build a learned optimizer from scratch

    Simple, self-contained, learned optimizer example that does not depend on the learned_optimization library: Open In Colab

    The link currently points to https://colab.research.google.com/github/google/learned_optimization/blob/main/docs/notebooks/no_dependency_learned_optimizer.ipynb.ipynb. There's a trailing .ipynb that should not be there.

    opened by gahaalt 0
Owner
Google
Google ❤️ Open Source
CLOOB training (JAX) and inference (JAX and PyTorch)

cloob-training Pretrained models There are two pretrained CLOOB models in this repo at the moment, a 16 epoch and a 32 epoch ViT-B/16 checkpoint train

Katherine Crowson 64 Nov 27, 2022
GAN JAX - A toy project to generate images from GANs with JAX

GAN JAX - A toy project to generate images from GANs with JAX This project aims to bring the power of JAX, a Python framework developed by Google and

Valentin Goldité 14 Nov 29, 2022
Mini-hmc-jax - A simple implementation of Hamiltonian Monte Carlo in JAX

mini-hmc-jax This is a simple implementation of Hamiltonian Monte Carlo in JAX t

Martin Marek 6 Mar 3, 2022
Differentiable Optimizers with Perturbations in Pytorch

Differentiable Optimizers with Perturbations in PyTorch This contains a PyTorch implementation of Differentiable Optimizers with Perturbations in Tens

Jake Tuero 54 Jun 22, 2022
TLDR; Train custom adaptive filter optimizers without hand tuning or extra labels.

AutoDSP TLDR; Train custom adaptive filter optimizers without hand tuning or extra labels. About Adaptive filtering algorithms are commonplace in sign

Jonah Casebeer 48 Sep 19, 2022
Repository for open research on optimizers.

Open Optimizers Repository for open research on optimizers. This is a test in sharing research/exploration as it happens. If you use anything from thi

Ariel Ekgren 6 Jun 24, 2022
High-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently.

TL;DR Ignite is a high-level library to help with training and evaluating neural networks in PyTorch flexibly and transparently. Click on the image to

null 4.2k Jan 1, 2023
A library for preparing, training, and evaluating scalable deep learning hybrid recommender systems using PyTorch.

collie_recs Collie is a library for preparing, training, and evaluating implicit deep learning hybrid recommender systems, named after the Border Coll

ShopRunner 97 Jan 3, 2023
A library for preparing, training, and evaluating scalable deep learning hybrid recommender systems using PyTorch.

collie Collie is a library for preparing, training, and evaluating implicit deep learning hybrid recommender systems, named after the Border Collie do

ShopRunner 96 Dec 29, 2022
torchlm aims to build a high-level pipeline for face landmarks detection; it supports training, evaluating, exporting, inference (Python/C++) and 100+ data augmentations

A high-level pipeline for face landmarks detection, supports training, evaluating, exporting, inference and 100+ data augmentations, compatible with torchvision and albumentations, can easily install with pip.

DefTruth 142 Dec 25, 2022