Pretrained models for Jax/Flax: StyleGAN2, GPT2, VGG, ResNet.

Flax Models

A collection of pretrained models in Flax.

About

The goal of this project is to make current deep learning models more easily available for the awesome Jax/Flax ecosystem.

Models

Example Notebooks to play with on Colab

Installation

You will need Python 3.7 or later.

  1. For GPU usage, follow the Jax installation with CUDA.
  2. Then install:
    > pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git

For CPU-only you can skip step 1.
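
After installing, a quick check along these lines can confirm which backend JAX picked up. This is a minimal sketch; `describe_backend` is a hypothetical helper name, not part of flaxmodels, and it only uses `jax.devices()` and `jax.default_backend()`.

```python
import jax


def describe_backend():
    """Return the backend JAX selected and the devices it can see."""
    devices = jax.devices()
    return {
        "backend": jax.default_backend(),  # "cpu", "gpu", or "tpu"
        "device_count": len(devices),
        "devices": [str(d) for d in devices],
    }


if __name__ == "__main__":
    info = describe_backend()
    print(info)
    if info["backend"] == "cpu":
        print("JAX is running on CPU; no accelerator was detected.")
```

If the backend is reported as `cpu` on a machine with a GPU, the CUDA-enabled JAX installation from step 1 is the first thing to re-check.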

Documentation

The documentation for the models is on the individual model pages.

Testing

To run the tests, pytest needs to be installed.

> git clone https://github.com/matthias-wright/flaxmodels.git
> cd flaxmodels
> python -m pytest tests/

Acknowledgments

Thank you to the developers of Jax and Flax. The title image is a photograph of a flax flower, kindly made available by Marta Matyszczyk.

License

Each model has an individual license.

Comments
  • StyleGAN 2 doesn't work with Colab TPU despite successful TPU connection/init?

    When trying to run the StyleGAN 2 training code on Google Colab, I'm getting:

    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    

    But that's after confirming that the TPU is set up correctly (screenshot omitted).

    Here's a minimal example: https://colab.research.google.com/gist/josephrocca/5e64c9906db96f27b583f0a577ef9b4a/debugging-matthias-wright-s-stylegan2-jax-tpu-not-detected.ipynb

    If I set TF_CPP_MIN_LOG_LEVEL=0, I get:

    2021-10-08 16:05:10.421297: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcudart.so.11.0
    2021-10-08 16:05:12.352286: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:171] XLA service 0x55a8d14dddc0 initialized for platform Interpreter (this does not guarantee that XLA will be used). Devices:
    2021-10-08 16:05:12.352348: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:179]   StreamExecutor device (0): Interpreter, <undefined>
    2021-10-08 16:05:12.358082: I external/org_tensorflow/tensorflow/compiler/xla/pjrt/tfrt_cpu_pjrt_client.cc:163] TfrtCpuClient created.
    2021-10-08 16:05:12.371498: E external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
    2021-10-08 16:05:12.371542: I external/org_tensorflow/tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (4b26ff987b6b): /proc/driver/nvidia/version does not exist
    2021-10-08 16:05:12.371984: I external/org_tensorflow/tensorflow/stream_executor/tpu/tpu_platform_interface.cc:74] No TPU platform found.
    WARNING:absl:No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)
    2021-10-08 16:05:12.508418: I tensorflow/core/platform/cloud/google_auth_provider.cc:180] Attempting an empty bearer token since no token was retrieved from files, and GCE metadata check was skipped.
    2021-10-08 16:05:12.545218: I tensorflow/core/platform/cloud/google_auth_provider.cc:180] Attempting an empty bearer token since no token was retrieved from files, and GCE metadata check was skipped.
    2021-10-08 16:05:12.586517: I tensorflow/compiler/jit/xla_cpu_device.cc:41] Not creating XLA devices, tf_xla_enable_xla_devices not set
    2021-10-08 16:05:12.586652: I tensorflow/stream_executor/platform/default/dso_loader.cc:49] Successfully opened dynamic library libcuda.so.1
    2021-10-08 16:05:12.595212: E tensorflow/stream_executor/cuda/cuda_driver.cc:328] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
    2021-10-08 16:05:12.595243: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:156] kernel driver does not appear to be running on this host (4b26ff987b6b): /proc/driver/nvidia/version does not exist
    2021-10-08 16:05:12.595544: I tensorflow/core/platform/cpu_feature_guard.cc:142] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
    To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
    2021-10-08 16:05:12.597344: I tensorflow/compiler/jit/xla_gpu_device.cc:99] Not creating XLA devices, tf_xla_enable_xla_devices not set
    

    Not sure if this problem is specific to the StyleGAN 2 training code, since I haven't tried any of the other models. I'm going to continue trying to debug this tomorrow and will update this post if I find out what's going on here.

    opened by josephrocca 9
  • `data_pipeline.py` needs more changes than suggested in README to support `ImageFolder` datasets

    Some problems I ran into:

    • I wasn't able to get tfds.ImageFolder working with a "flat" folder of images. I had to nest a dummy label folder inside a dummy split folder. I followed the instructions here: https://www.tensorflow.org/datasets/api_docs/python/tfds/folder_dataset/ImageFolder
    • There doesn't seem to be a num_examples property in tfds.core.DatasetInfo, so I had to use builder.info.splits['fake_split'].num_examples where fake_split is the name of my dummy split folder. It does look like there's a total_num_examples property, but I'm not sure how to access it - maybe it's a private field (though I'm not sure if those are possible in Python)?
    • I had to edit pre_process because it was expecting protobufs instead of {image, label} objects.
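
The nesting described in the first bullet can be sketched with the standard library. This assumes the `root/<split>/<label>/<image>` layout from the linked `tfds.ImageFolder` documentation; `nest_flat_folder` and the label name `fake_label` are hypothetical, while `fake_split` matches the split name used in this post.

```python
import shutil
from pathlib import Path

IMAGE_SUFFIXES = {".jpg", ".jpeg", ".png"}


def nest_flat_folder(flat_dir, out_dir, split="fake_split", label="fake_label"):
    """Copy images from a flat folder into out_dir/<split>/<label>/.

    Returns the number of images copied. Non-image files are skipped.
    """
    target = Path(out_dir) / split / label
    target.mkdir(parents=True, exist_ok=True)
    copied = 0
    for path in sorted(Path(flat_dir).iterdir()):
        if path.is_file() and path.suffix.lower() in IMAGE_SUFFIXES:
            shutil.copy2(path, target / path.name)
            copied += 1
    return copied
```

The resulting `out_dir` is what would then be passed as `data_dir` to the pipeline, with the split name given to `as_dataset`.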

    Note that the reason I am using the ImageFolder approach is that the tfrecords approach blew my 3GB dataset up to 200GB, since I think it's storing the raw tensor data? I'm new to this, but it seems like it'd make more sense to just store the data in jpg format since jpg decoding is so fast? That said, even if the tfrecords approach used a reasonable amount of space, I'd probably still prefer the ImageFolder approach since it just seems nicer and more portable. Even better, from my (newbie) perspective, would be the ability to load a tar of images with any internal directory structure.

    Below is my new data_pipeline.py so far. It seems to work okay now, but I haven't got training to work yet as I'm still debugging some stuff. Will update this post if I run into any more problems with data_pipeline.py.

    import tensorflow as tf
    import tensorflow_datasets as tfds
    import jax
    import flax
    import numpy as np
    from PIL import Image
    import os
    from typing import Sequence
    from tqdm import tqdm
    import json
    
    
    def prefetch(dataset, n_prefetch):
        # Taken from: https://github.com/google-research/vision_transformer/blob/master/vit_jax/input_pipeline.py
        ds_iter = iter(dataset)
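        # jax.tree_map has been removed from recent JAX releases; jax.tree_util.tree_map is the equivalent.
        ds_iter = map(lambda x: jax.tree_util.tree_map(lambda t: np.asarray(memoryview(t)), x),
                      ds_iter)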
        if n_prefetch:
            ds_iter = flax.jax_utils.prefetch_to_device(ds_iter, n_prefetch)
        return ds_iter
    
    
    def get_data(data_dir, img_size, img_channels, num_classes, num_devices, batch_size, shuffle_buffer=1000):
        """
        Builds the input pipeline from a tfds.ImageFolder dataset.
    
        Args:
            data_dir (str): Root directory of the dataset.
            img_size (int): Image size for training.
            img_channels (int): Number of image channels.
            num_classes (int): Number of classes, 0 for no classes.
            num_devices (int): Number of devices.
            batch_size (int): Batch size (per device).
            shuffle_buffer (int): Buffer used for shuffling the dataset.
    
        Returns:
            (tf.data.Dataset, dict): Dataset and a dict with 'num_examples' and 'num_classes'.
        """
    
        def pre_process(example):
            # feature = {'height': tf.io.FixedLenFeature([], tf.int64),
            #            'width': tf.io.FixedLenFeature([], tf.int64),
            #            'channels': tf.io.FixedLenFeature([], tf.int64),
            #            'image': tf.io.FixedLenFeature([], tf.string),
            #            'label': tf.io.FixedLenFeature([], tf.int64)}
            # example = tf.io.parse_single_example(serialized_example, feature)
    
            # height = tf.cast(example['height'], dtype=tf.int64)
            # width = tf.cast(example['width'], dtype=tf.int64)
            # channels = tf.cast(example['channels'], dtype=tf.int64)
    
            # image = tf.io.decode_raw(example['image'], out_type=tf.uint8)
            # image = tf.reshape(image, shape=[height, width, channels])
    
            image = example['image']
    
            image = tf.cast(image, dtype='float32')
            image = tf.image.resize(image, size=[img_size, img_size], method='bicubic', antialias=True)
            image = tf.image.random_flip_left_right(image)
            
            image = (image - 127.5) / 127.5
            
            label = tf.one_hot(example['label'], num_classes)
            return {'image': image, 'label': label}
    
        def shard(data):
            # Reshape images from [num_devices * batch_size, H, W, C] to [num_devices, batch_size, H, W, C]
            # because the first dimension will be mapped across devices using jax.pmap
            data['image'] = tf.reshape(data['image'], [num_devices, -1, img_size, img_size, img_channels])
            data['label'] = tf.reshape(data['label'], [num_devices, -1, num_classes])
            return data
    
        # print('Loading TFRecord...')
        # with open(os.path.join(data_dir, 'dataset_info.json'), 'r') as fin:
        #    dataset_info = json.load(fin)
    
        # ds = tf.data.TFRecordDataset(filenames=os.path.join(data_dir, 'dataset.tfrecords'))
        # ds = ds.shuffle(min(dataset_info['num_examples'], shuffle_buffer))
    
        builder = tfds.ImageFolder(data_dir)
        print(builder.info)
        ds = builder.as_dataset(split='fake_split', shuffle_files=True)
        num_examples = builder.info.splits['fake_split'].num_examples
        dataset_info = {'num_examples': num_examples, 'num_classes': 1}
        
        ds = ds.shuffle(min(num_examples, shuffle_buffer))
        ds = ds.map(pre_process, tf.data.AUTOTUNE)
        ds = ds.batch(batch_size * num_devices, drop_remainder=True)
        ds = ds.map(shard, tf.data.AUTOTUNE)
        ds = ds.prefetch(1)
        return ds, dataset_info
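
To make the `shard` step concrete, here is a small NumPy sketch of the same reshape outside of `tf.data`. The sizes are made up for illustration and `shard_batch` is a hypothetical helper, not part of the repository.

```python
import numpy as np


def shard_batch(images, num_devices):
    """Reshape [num_devices * batch, H, W, C] to [num_devices, batch, H, W, C]."""
    if images.shape[0] % num_devices != 0:
        raise ValueError("global batch must be divisible by num_devices")
    return images.reshape((num_devices, -1) + images.shape[1:])


num_devices, per_device_batch = 2, 4
images = np.zeros((num_devices * per_device_batch, 64, 64, 3), dtype=np.float32)
sharded = shard_batch(images, num_devices)
print(sharded.shape)  # (2, 4, 64, 64, 3)
```

The leading axis is the one `jax.pmap` maps across devices, which is why `drop_remainder=True` matters in the batching step: a short final batch could not be split evenly.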
    
    opened by josephrocca 5
  • Hosting models on a more reliable service

    Hey,

    Super cool project! I discovered it as I plan to try to port lpips to JAX (VGG16 and inference only, no training) and I see that the VGG16 part is already done so only the lpips module needs to be ported.

    I noticed that the models are hosted on Dropbox. May I suggest hosting them on the Hugging Face model hub for more reliability and control (versions, etc.)? Storage is also free there, so it's probably a more attractive option!

    opened by borisdayma 5
  • Input preprocessing for VGG

    Hi,

    In the README, it is mentioned that input should be between 0 and 1.

    In the training code, they seem to be between -1 and 1.

    In the torchvision doc, they seem to be loaded between 0 and 1 and then normalized with

    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
    

    Should they be preprocessed as per the torchvision docs?
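
For reference, the two conventions mentioned above can be written out in NumPy. This is only a sketch of the arithmetic; which input range the VGG model in this repository actually expects is the open question of this issue, and the function names are hypothetical.

```python
import numpy as np

# Per-channel statistics quoted from the torchvision snippet above.
IMAGENET_MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
IMAGENET_STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)


def to_unit_range(image):
    """Map an image from [-1, 1] to [0, 1]."""
    return (image + 1.0) / 2.0


def torchvision_normalize(image):
    """Normalize a channels-last image in [0, 1] with the ImageNet mean and std."""
    return (image - IMAGENET_MEAN) / IMAGENET_STD
```

Both functions broadcast over the last axis, so they work for `[H, W, 3]` and `[N, H, W, 3]` arrays alike.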

    opened by borisdayma 4
  • PyPI distribution

    Thank you for this code! We are using it to implement a simple version of lpips loss here. However, I didn't find a way to install our package from PyPI as I don't know how to include a dependency to your repository distribution, and installing flaxmodels from PyPI doesn't work in our case. See this issue for a brief discussion.

    Do you think it would be possible to update your PyPI distribution so we can in turn build a package that uses it as a dependency? Being new to PyPI myself, I'd also be interested to learn about any drawbacks about doing it this way.

    Thanks again!

    opened by pcuenca 2
  • Installation - Numpy version

    Hi @matthias-wright, thanks a lot for releasing this nice, complete package for pretrained models! I recently used your package in a tutorial to extract features from a pre-trained ResNet34, and noticed that the pip installation of the package requires an old numpy version (v1.19.5, setup.py, line 15). However, the current tensorflow package requires a newer numpy (>v1.21), so installing flaxmodels can break an existing tensorflow installation. In this case, Flax throws an error during import regarding checkpoints from tensorflow, because tensorflow has been compiled against a different numpy version than the one flaxmodels installed over it. Re-installing the newest numpy version fixes the issue. Is it possible to change the requirement of the numpy package to >= instead of ==, similar to what is currently used for JAX and Flax?

    opened by phlippe 2
  • How did you convert pytorch model weight to Flax?

    From the doc string of ResNet18, I saw you have the following comments:

    The pretrained parameters are taken from:
        https://github.com/pytorch/vision/blob/master/torchvision/models/resnet.py
    

    May I know how you converted the PyTorch model weights into Flax? I didn't find any reference that touches on this area. Many thanks!
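
As background for this question, here is a sketch of the layout change such a conversion generally involves. It is not confirmed to be the author's procedure. It assumes the usual conventions: PyTorch stores convolution weights as `(out, in, kh, kw)` and linear weights as `(out, in)`, while Flax's `nn.Conv` and `nn.Dense` expect kernels shaped `(kh, kw, in, out)` and `(in, out)`. The helper names are hypothetical.

```python
import numpy as np


def conv_kernel_torch_to_flax(weight):
    """Reorder a conv weight from (out, in, kh, kw) to (kh, kw, in, out)."""
    return np.transpose(weight, (2, 3, 1, 0))


def dense_kernel_torch_to_flax(weight):
    """Reorder a linear weight from (out, in) to (in, out)."""
    return np.transpose(weight, (1, 0))


# Example with random arrays standing in for exported PyTorch tensors.
rng = np.random.default_rng(0)
torch_conv = rng.normal(size=(64, 3, 7, 7))
torch_dense = rng.normal(size=(1000, 512))
print(conv_kernel_torch_to_flax(torch_conv).shape)    # (7, 7, 3, 64)
print(dense_kernel_torch_to_flax(torch_dense).shape)  # (512, 1000)
```

Other parts of a checkpoint (for example batch norm statistics, or a flatten step that depends on channel order) need their own handling and are not covered by this sketch.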

    opened by riven314 2
  • Resnet18 demo gives `OSError: Unable to open file (file signature not found)`

    Demo notebook: flaxmodels/resnet/resnet_demo.ipynb

    Notebook executed on Google Colab (GPU runtime)

    Descriptive error message:

    UnfilteredStackTrace: OSError: Unable to open file (file signature not found)
    
    The stack trace below excludes JAX-internal frames.
    The preceding is the original exception that occurred, unmodified.
    
    --------------------
    
    The above exception was the direct cause of the following exception:
    
    OSError                                   Traceback (most recent call last)
    /usr/local/lib/python3.7/dist-packages/h5py/_hl/files.py in make_fid(name, mode, userblock_size, fapl, fcpl, swmr)
        171         if swmr and swmr_support:
        172             flags |= h5f.ACC_SWMR_READ
    --> 173         fid = h5f.open(name, flags, fapl=fapl)
        174     elif mode == 'r+':
        175         fid = h5f.open(name, h5f.ACC_RDWR, fapl=fapl)
    
    h5py/_objects.pyx in h5py._objects.with_phil.wrapper()
    
    h5py/_objects.pyx in h5py._objects.with_phil.wrapper()
    
    h5py/h5f.pyx in h5py.h5f.open()
    
    OSError: Unable to open file (file signature not found)
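
h5py raises this error when the file it is given does not carry the HDF5 signature. One possible cause (an assumption, not confirmed in this thread) is an incomplete or failed download that left something other than the checkpoint on disk. A minimal standard-library check, with the hypothetical helper `looks_like_hdf5`:

```python
# The 8-byte signature that opens an HDF5 file.
HDF5_SIGNATURE = b"\x89HDF\r\n\x1a\n"


def looks_like_hdf5(path):
    """Return True if the file starts with the HDF5 signature.

    Only offset 0 is checked; HDF5 files with a user block place the
    signature at a later offset and are not detected by this sketch.
    """
    with open(path, "rb") as f:
        return f.read(len(HDF5_SIGNATURE)) == HDF5_SIGNATURE
```

If the check fails for a cached checkpoint, deleting the file and downloading it again is a reasonable first step.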
    
    opened by DevPranjal 2
  • Question about batch norm

    Thank you so much for creating this. For resnet, I am wondering how the implementation of BatchNorm differs from the Flax implementation? Basically, I'm wondering if I can replace ops.BatchNorm with flax.linen.BatchNorm to reduce dependencies? Thanks!

    opened by crawles 2
  • Weight Initialization PRNGKey

    Hi,

    First of all, you did a really good job converting the stylegan2 implementation to jax. I am not sure whether this is an issue or whether the module was intentionally designed this way, but one thing I noticed is that you pass the PRNG key for the weight initialization as a parameter of the flax module.

    This means that initializing the module using flax.Module.init with different PRNG will result in the same weight initialization for both the generator and the discriminator. The only way to produce different initialization would be to pass different PRNGKeys at the creation of the StyleGAN module.

    Here's a minimal example that demonstrates this:

    import jax
    import jax.numpy as jnp
    # `stylegan2` is assumed to be imported from this repository's StyleGAN2 training code.

    # Create a StyleGAN2 Generator Flax.Module
    G_model = stylegan2.generator.Generator(pretrained=None)
    
    # Invoking this method will initialize the module based on the PRNGKey passed (i.e., g_rng)
    def init_g(g_rng):
        z_shape = (4, 512)
    
        @jax.jit
        def _init(*args):
            return G_model.init(*args, train=True)
    
        variables = _init({'params': g_rng}, jnp.ones(z_shape, G_model.dtype))
        return variables['params'], variables['moving_stats'], variables['noise_consts']
    

    Now initializing the module using the flax.Module.init will give the same params:

    params_1, _, _ = init_g(jax.random.PRNGKey(10))
    params_2, _, _ = init_g(jax.random.PRNGKey(58))
    # The following returns True
    jnp.all(params_1['mapping_network']['LinearLayer_0']['weight'] == params_2['mapping_network']['LinearLayer_0']['weight'])
    

    To initialize the module with different seeds, you need to pass the PRNGKey explicitly when creating the Module.

    If this was intended, then I think you also need to split the RNGs whenever the key is passed to submodules. This will produce different random numbers for different weights. Otherwise, creating the same layer twice (for example two ops.LinearLayer with the same hyperparameters, including the PRNGKey) will produce exactly the same parameters for both layers.
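
The effect of splitting can be shown in isolation with plain `jax.random` calls. This is a sketch of the general technique, not the repository's layer code; `init_two_layers` is a hypothetical helper.

```python
import jax
import jax.numpy as jnp


def init_two_layers(key, shape=(4, 4)):
    """Initialize two weight matrices from independent subkeys."""
    key_a, key_b = jax.random.split(key)
    w_a = jax.random.normal(key_a, shape)
    w_b = jax.random.normal(key_b, shape)
    return w_a, w_b


key = jax.random.PRNGKey(0)
w_a, w_b = init_two_layers(key)

# Reusing one key for two draws reproduces the same values.
reused_1 = jax.random.normal(key, (4, 4))
reused_2 = jax.random.normal(key, (4, 4))

print(bool(jnp.all(reused_1 == reused_2)))  # True: same key, same values
print(bool(jnp.all(w_a == w_b)))            # False: the subkeys differ
```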

    Thanks again for the hard work on making the module accessible in JAX 👍

    opened by moabarar 2
  • StyleGAN2 `--fmap_base` only affects generator feature maps

    Hello Matthias,

    We have noticed that adjusting the value of --fmap_base only affects the generator:

    https://github.com/matthias-wright/flaxmodels/blob/0ec7f22bda80c3e3c475e976af92b838cbbc22d4/training/stylegan2/training.py#L78

    but not the discriminator:

    https://github.com/matthias-wright/flaxmodels/blob/0ec7f22bda80c3e3c475e976af92b838cbbc22d4/training/stylegan2/training.py#L101

    Is this intentional?

    In the StyleGAN 2 paper, both G and D receive increased capacity (bottom of page 7):

    This leads us to hypothesize that there is a capacity problem in our networks, which we test by doubling the number of feature maps in the highest-resolution layers of both networks.

    We double the number of feature maps in resolutions 64^2–1024^2 while keeping other parts of the networks unchanged. This increases the total number of trainable parameters in the generator by 22% (25M → 30M) and in the discriminator by 21% (24M → 29M).
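
To illustrate why `fmap_base` changes layer widths, here is a sketch modeled on the `nf` helper in the official TensorFlow StyleGAN2 code. Whether this repository computes widths with exactly this formula and these defaults is an assumption.

```python
def nf(stage, fmap_base=16 << 10, fmap_decay=1.0, fmap_min=1, fmap_max=512):
    """Number of feature maps at a stage, clipped to [fmap_min, fmap_max]."""
    value = int(fmap_base / (2.0 ** (stage * fmap_decay)))
    return max(fmap_min, min(value, fmap_max))


# Compare the default base with a doubled base for the later stages.
for stage in range(5, 10):
    print(stage, nf(stage), nf(stage, fmap_base=32 << 10))
```

With the doubled base, every stage that is not already capped at `fmap_max` gets twice as many feature maps, so passing the flag to only one network widens only that network.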

    opened by MasterScrat 1
  • Use `safetensors` to store tensors instead of `pickle`

    Hi @matthias-wright, I've been playing around for a couple days with your project and it's so cool, thanks for building some pure flax models here 👍🏻

    Don't know if you're aware, but @huggingface developed a new format for storing tensors named safetensors. Most serialized PyTorch models use pickle to store the tensors, which is not very efficient and has some known security issues. So I'd like to know whether you're considering porting the current tensors to safetensors instead.

    I've recently built safejax to make that easy, which means the storage is optimal and safer! If this is something you would consider to improve flaxmodels, please let me know and I can try to help if applicable!

    P.S. Did you consider publishing the Python package to PyPI, tracked through GitHub Releases, so that it attracts more users thanks to the ease of installing with pip from PyPI instead of from source as in the README.md?

    opened by alvarobartt 3
  • Possible bug in StyleGAN training code

    Hi,

    Nice repo! Is this line a bug? I think batch['images'] is N x B x H x W x C, so the indices should be shifted up by 1.

    https://github.com/matthias-wright/flaxmodels/blob/edc6a8571a6d7202bd9f3bc9241221405c083fd4/training/stylegan2/training.py#L239
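
The indexing concern can be illustrated with a hypothetical sharded batch. What the linked line actually indexes is not reproduced here, so the sizes and variable names below are assumptions for illustration only.

```python
import numpy as np

# Hypothetical sharded batch: N devices, B images per device, H x W x C pixels.
N, B, H, W, C = 2, 4, 64, 32, 3
images = np.zeros((N, B, H, W, C), dtype=np.float32)

# Indices written for an unsharded [B, H, W, C] batch pick up the wrong axes.
unsharded_style = images.shape[1:3]  # (B, H) once a device axis is present
# Shifting the indices up by one accounts for the leading device axis.
sharded_style = images.shape[2:4]    # (H, W)

print(unsharded_style, sharded_style)  # (4, 64) (64, 32)
```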

    opened by wilson1yan 1
Owner
Matthias Wright
PhD Student in Computer Vision @ Heidelberg University
Cartoon-StyleGan2 🙃 : Fine-tuning StyleGAN2 for Cartoon Face Generation

Fine-tuning StyleGAN2 for Cartoon Face Generation

Jihye Back 520 Jan 4, 2023
Base pretrained models and datasets in pytorch (MNIST, SVHN, CIFAR10, CIFAR100, STL10, AlexNet, VGG16, VGG19, ResNet, Inception, SqueezeNet)

This is a playground for pytorch beginners, which contains predefined models on popular dataset. Currently we support mnist, svhn cifar10, cifar100 st

Aaron Chen 2.4k Dec 28, 2022
Flax is a neural network ecosystem for JAX that is designed for flexibility.

Flax: A neural network library and ecosystem for JAX designed for flexibility Overview | Quick install | What does Flax look like? | Documentation See

Google 3.9k Jan 2, 2023
Very deep VAEs in JAX/Flax

Very Deep VAEs in JAX/Flax Implementation of the experiments in the paper Very Deep VAEs Generalize Autoregressive Models and Can Outperform Them on I

Jamie Townsend 42 Dec 12, 2022
Standalone pre-training recipe with JAX+Flax

Sabertooth Sabertooth is standalone pre-training recipe based on JAX+Flax, with data pipelines implemented in Rust. It runs on CPU, GPU, and/or TPU, b

Nikita Kitaev 26 Nov 28, 2022
Local Attention - Flax module for Jax

Local Attention - Flax Autoregressive Local Attention - Flax module for Jax Install $ pip install local-attention-flax Usage from jax import random fr

Phil Wang 16 Jun 16, 2022
Implementation of FitVid video prediction model in JAX/Flax.

FitVid Video Prediction Model Implementation of FitVid video prediction model in JAX/Flax. If you find this code useful, please cite it in your paper:

Google Research 62 Nov 25, 2022
Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax

Clockwork VAEs in JAX/Flax Implementation of experiments in the paper Clockwork Variational Autoencoders (project website) using JAX and Flax, ported

Julius Kunze 26 Oct 5, 2022
Reimplementation of the paper "Attention, Learn to Solve Routing Problems!" in jax/flax.

JAX + Attention Learn To Solve Routing Problems Reimplementation of the paper Attention, Learn to Solve Routing Problems! using Jax and Flax. Fully su

Gabriela Surita 7 Dec 1, 2022
JAXDL: JAX (Flax) Deep Learning Library

JAXDL: JAX (Flax) Deep Learning Library Simple and clean JAX/Flax deep learning algorithm implementations: Soft-Actor-Critic (arXiv:1812.05905) Transf

Patrick Hart 4 Nov 27, 2022
Advantage Actor Critic (A2C): jax + flax implementation

Advantage Actor Critic (A2C): jax + flax implementation Current version supports only environments with continuous action spaces and was tested on muj

Andrey 3 Jan 23, 2022
RepVGG: Making VGG-style ConvNets Great Again

RepVGG: Making VGG-style ConvNets Great Again (PyTorch) This is a super simple ConvNet architecture that achieves over 80% top-1 accuracy on ImageNet

null 2.8k Jan 4, 2023
RepVGG: Making VGG-style ConvNets Great Again

This repository is the code that needs to be submitted for OpenMMLab Algorithm Ecological Challenge,the paper is RepVGG: Making VGG-style ConvNets Great Again

Ty Feng 62 May 21, 2022
Quickly comparing your image classification models with the state-of-the-art models (such as DenseNet, ResNet, ...)

Image Classification Project Killer in PyTorch This repo is designed for those who want to start their experiments two days before the deadline and ki

null 349 Dec 8, 2022
GAN JAX - A toy project to generate images from GANs with JAX

GAN JAX - A toy project to generate images from GANs with JAX This project aims to bring the power of JAX, a Python framework developed by Google and

Valentin Goldité 14 Nov 29, 2022
Mini-hmc-jax - A simple implementation of Hamiltonian Monte Carlo in JAX

mini-hmc-jax This is a simple implementation of Hamiltonian Monte Carlo in JAX t

Martin Marek 6 Mar 3, 2022
CLOOB training (JAX) and inference (JAX and PyTorch)

cloob-training Pretrained models There are two pretrained CLOOB models in this repo at the moment, a 16 epoch and a 32 epoch ViT-B/16 checkpoint train

Katherine Crowson 64 Nov 27, 2022
A collection of pre-trained StyleGAN2 models trained on different datasets at different resolution.

Awesome Pretrained StyleGAN2 A collection of pre-trained StyleGAN2 models trained on different datasets at different resolution. Note the readme is a

Justin 1.1k Dec 24, 2022