Includes PyTorch -> Keras model porting code for the ConvNeXt family of models, along with fine-tuning and inference notebooks.

Overview

ConvNeXt-TF

This repository provides TensorFlow / Keras implementations of different ConvNeXt [1] variants. It also provides the TensorFlow / Keras models populated with the original ConvNeXt pre-trained weights available from [2]. These models are not black-box SavedModels: they can be fully expanded into tf.keras.Model objects, and all the usual utility functions (for example, .summary()) can be called on them.

As of today, all the TensorFlow / Keras variants of the models listed here are available in this repository except for the isotropic ones. This list includes the ImageNet-1k as well as ImageNet-21k models.

Refer to the "Using the models" section to get started. Additionally, here's a related blog post that jots down my experience.

Conversion

TensorFlow / Keras implementations are available in models/convnext_tf.py. Conversion utilities are in convert.py.

Models

The converted models are available on TF-Hub.

There should be a total of 15 different models each having two variants: classifier and feature extractor. You can load any model and get started like so:

import tensorflow as tf

model_gcs_path = "gs://tfhub-modules/sayakpaul/convnext_tiny_1k_224/1/uncompressed"
model = tf.keras.models.load_model(model_gcs_path)
# summary() prints the table itself and returns None, so no print() is needed.
model.summary(expand_nested=True)

The model names are interpreted as follows:

  • convnext_large_21k_1k_384: This means that the model was first pre-trained on the ImageNet-21k dataset and was then fine-tuned on the ImageNet-1k dataset. Resolution used during pre-training and fine-tuning: 384x384. large denotes the topology of the underlying model.
  • convnext_large_1k_224: This means that the model was pre-trained on the ImageNet-1k dataset at a resolution of 224x224.
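The naming scheme above can be decoded mechanically. The following is a minimal sketch, not part of this repository: `ConvNeXtName` and `parse_model_name` are hypothetical helpers, and they assume every name follows the `convnext_<topology>_<dataset(s)>_<resolution>` pattern of the two examples.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class ConvNeXtName:
    """Hypothetical container for the fields encoded in a model name."""
    topology: str                    # e.g. "tiny", "large"
    pretrain_dataset: str            # "1k" or "21k"
    finetune_dataset: Optional[str]  # "1k" if fine-tuned afterwards, else None
    resolution: int                  # side of the square input, e.g. 224


def parse_model_name(name: str) -> ConvNeXtName:
    """Split a name such as 'convnext_large_21k_1k_384' into its parts."""
    parts = name.split("_")
    if len(parts) not in (4, 5) or parts[0] != "convnext":
        raise ValueError(f"Unrecognized model name: {name!r}")
    topology, datasets, resolution = parts[1], parts[2:-1], parts[-1]
    # Two dataset tokens mean "pre-trained on the first, fine-tuned on the second".
    finetune = datasets[1] if len(datasets) == 2 else None
    return ConvNeXtName(topology, datasets[0], finetune, int(resolution))


print(parse_model_name("convnext_large_21k_1k_384"))
print(parse_model_name("convnext_large_1k_224"))
```

A parsed name makes it easy to, for instance, pick the input resolution for a preprocessing pipeline without hard-coding it per model.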

Results

Results are on ImageNet-1k validation set (top-1 accuracy).

name                          original acc@1   keras acc@1
convnext_tiny_1k_224          82.1             81.312
convnext_small_1k_224         83.1             82.392
convnext_base_1k_224          83.8             83.28
convnext_base_1k_384          85.1             84.876
convnext_large_1k_224         84.3             83.844
convnext_large_1k_384         85.5             85.376
convnext_base_21k_1k_224      85.8             85.364
convnext_base_21k_1k_384      86.8             86.79
convnext_large_21k_1k_224     86.6             86.36
convnext_large_21k_1k_384     87.5             87.504
convnext_xlarge_21k_1k_224    87.0             86.732
convnext_xlarge_21k_1k_384    87.8             87.68

The differences are primarily due to differences between the library implementations, especially in how PyTorch and TensorFlow implement image resizing. The results can be verified with the code in i1k_eval. Logs are available at this URL.
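To see how large the gaps actually are, the table above can be reduced to one number per model. This is only an illustrative sketch: the figures are copied from the table, and the variable names are made up for the example.

```python
# (name, original acc@1, keras acc@1), copied from the results table.
RESULTS = [
    ("convnext_tiny_1k_224", 82.1, 81.312),
    ("convnext_small_1k_224", 83.1, 82.392),
    ("convnext_base_1k_224", 83.8, 83.28),
    ("convnext_base_1k_384", 85.1, 84.876),
    ("convnext_large_1k_224", 84.3, 83.844),
    ("convnext_large_1k_384", 85.5, 85.376),
    ("convnext_base_21k_1k_224", 85.8, 85.364),
    ("convnext_base_21k_1k_384", 86.8, 86.79),
    ("convnext_large_21k_1k_224", 86.6, 86.36),
    ("convnext_large_21k_1k_384", 87.5, 87.504),
    ("convnext_xlarge_21k_1k_224", 87.0, 86.732),
    ("convnext_xlarge_21k_1k_384", 87.8, 87.68),
]

# Gap in percentage points; positive means the original model scores higher.
gaps = {name: round(orig - keras, 3) for name, orig, keras in RESULTS}
largest_gap = max(gaps, key=gaps.get)
smallest_gap = min(gaps, key=gaps.get)

for name, gap in sorted(gaps.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{name:<28} {gap:+.3f}")
```

In these numbers every 384x384 model shows a smaller gap than its 224x224 counterpart, and the largest gap (under 0.8 points) belongs to the tiny model.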

Using the models

Pre-trained models:

Randomly initialized models:

from models.convnext_tf import get_convnext_model

convnext_tiny = get_convnext_model()
# summary() prints the table itself and returns None, so no print() is needed.
convnext_tiny.summary(expand_nested=True)

To view different model configurations, refer here.

Upcoming (contributions welcome)

  • Align layer initializers (useful if someone wanted to train the models from scratch)
  • Allow the models to accept arbitrary shapes (useful for downstream tasks)
  • Convert the isotropic models as well
  • Fine-tuning notebook (thanks to awsaf49)
  • Off-the-shelf-classification notebook
  • Publish models on TF-Hub

References

[1] ConvNeXt paper: https://arxiv.org/abs/2201.03545

[2] Official ConvNeXt code: https://github.com/facebookresearch/ConvNeXt

Acknowledgements

Comments
  • finetune notebook added

    Hi, I have created a notebook to showcase fine-tuning with ConvNeXt-TF on the Flower Classification Dataset. It supports both TPU and GPU. Let me know if it meets the requirements.

    opened by awsaf49 9
  • RuntimeError: Unable to create link (name already exists)

    @sayakpaul I've tried to run ConvNeXt from TF-Hub but faced this error. Could you please take a look? The error can be reproduced with this script, with the following addition:

    mdckpt = tf.keras.callbacks.ModelCheckpoint(
        "model.h5", 
        monitor='val_accuracy', 
        verbose=1, 
        save_best_only=True,
        save_weights_only=True, 
        mode='max', 
        save_freq='epoch'
    )
    

    Also note that the error can perhaps be addressed by renaming the layer parameters. For example:

    import uuid
    
    # ref web
    def handle_name_exist_issue(model):
        def unique_name():
            return uuid.uuid4().hex.upper()[0:10]
    
        def postprocess_weight_name(name):
            # Attach a unique suffix to the leading scope of the weight name.
            # Splitting only on the first '/' also covers names nested deeper
            # than three levels, which previously fell through and returned None.
            group, sep, rest = name.partition('/')
            if not sep:
                return f'{unique_name()}/{name}'
            return f'{group}{unique_name()}/{rest}'
       
        model._name = model._name + unique_name()
        for layer in model.layers:
            layer._name = layer._name + unique_name()
        for i in range(len(model.weights)):
            model.weights[i]._handle_name = postprocess_weight_name(model.weights[i].name)
        return model
    
    with strategy.scope(): 
      model = get_model(MODEL_PATH)
      model = handle_name_exist_issue(model)
      model.compile(loss=loss, optimizer=optimizer, metrics=["accuracy"])
    
    history = model.fit(train_dataset, validation_data=val_dataset, 
                       epochs=EPOCHS, callbacks=[mdckpt])
    

    Using handle_name_exist_issue resolves the error on Colab (GPU, TF 2.8) but not on Kaggle (TPU, TF 2.4.1). I didn't test on Colab TPU or with other TF versions.

    Have you faced such issues with these models? I also tried other TF-Hub models, and they work as expected. Do you have any suggestion for a general solution to work with the ConvNeXt TF-Hub models on TPU?

    opened by innat 6
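As a note on the workaround in this issue: the "name already exists" message is typically what HDF5 reports when two entries are written under the same name, so the renaming only has to make the names distinct. The sketch below isolates the string logic so it can be tried without a model. `unique_suffix` and `uniquify_weight_name` are hypothetical helpers, and the weight names are invented for illustration.

```python
import uuid


def unique_suffix() -> str:
    """Ten uppercase hex characters, mirroring the snippet in the issue."""
    return uuid.uuid4().hex.upper()[:10]


def uniquify_weight_name(name: str) -> str:
    """Attach a unique suffix to the leading scope of a weight name."""
    group, sep, rest = name.partition("/")
    if not sep:
        # No scope at all: use the suffix as a new leading scope.
        return f"{unique_suffix()}/{name}"
    return f"{group}{unique_suffix()}/{rest}"


# Two weights that ended up with identical names (hypothetical example).
names = ["layer_scale/gamma:0", "layer_scale/gamma:0"]
renamed = [uniquify_weight_name(n) for n in names]
print(renamed)
```

After renaming, the two entries differ in their leading scope while the trailing variable name stays intact. Whether this is enough on a given TensorFlow version is a separate question, as the Kaggle TPU report above shows.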
  • JSONDecodeError When Loading Model on Kaggle TPU TF 2.4.1

    When loading the model in a Kaggle TPU notebook with TensorFlow version 2.4.1, a JSONDecodeError is thrown.

    What is the cause of this error, and how can one load a ConvNeXt model in TensorFlow 2.4.1?

    print(f'tensorflow version: {tf.__version__}')
    print(f'tensorflow keras version: {tf.keras.__version__}')
    print(f'python version: P{sys.version}')

    model_gcs_path = "gs://tfhub-modules/sayakpaul/convnext_tiny_1k_224/1/uncompressed"
    model = tf.keras.models.load_model(model_gcs_path)
    print(model.summary(expand_nested=True))
    

    Gives the following error trace:

    tensorflow version: 2.4.1
    tensorflow keras version: 2.4.0
    python version: P3.7.10 | packaged by conda-forge | (default, Feb 19 2021, 16:07:37) 
    [GCC 9.3.0]
    
    JSONDecodeError                           Traceback (most recent call last)
    /tmp/ipykernel_48/3485362230.py in <module>
          4 
          5 model_gcs_path = "gs://tfhub-modules/sayakpaul/convnext_tiny_1k_224/1/uncompressed"
    ----> 6 model = tf.keras.models.load_model(model_gcs_path)
          7 print(model.summary(expand_nested=True))
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/save.py in load_model(filepath, custom_objects, compile, options)
        210       if isinstance(filepath, six.string_types):
        211         loader_impl.parse_saved_model(filepath)
    --> 212         return saved_model_load.load(filepath, compile, options)
        213 
        214   raise IOError(
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in load(path, compile, options)
        136   # Recreate layers and metrics using the info stored in the metadata.
        137   keras_loader = KerasObjectLoader(metadata, object_graph_def)
    --> 138   keras_loader.load_layers(compile=compile)
        139 
        140   # Generate a dictionary of all loaded nodes.
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in load_layers(self, compile)
        374       self.loaded_nodes[node_metadata.node_id] = self._load_layer(
        375           node_metadata.node_id, node_metadata.identifier,
    --> 376           node_metadata.metadata)
        377 
        378     for node_metadata in metric_list:
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/load.py in _load_layer(self, node_id, identifier, metadata)
        395   def _load_layer(self, node_id, identifier, metadata):
        396     """Load a single layer from a SavedUserObject proto."""
    --> 397     metadata = json_utils.decode(metadata)
        398 
        399     # If node was already created
    
    /opt/conda/lib/python3.7/site-packages/tensorflow/python/keras/saving/saved_model/json_utils.py in decode(json_string)
         67 
         68 def decode(json_string):
    ---> 69   return json.loads(json_string, object_hook=_decode_helper)
         70 
         71 
    
    /opt/conda/lib/python3.7/json/__init__.py in loads(s, encoding, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw)
        359     if parse_constant is not None:
        360         kw['parse_constant'] = parse_constant
    --> 361     return cls(**kw).decode(s)
    
    /opt/conda/lib/python3.7/json/decoder.py in decode(self, s, _w)
        335 
        336         """
    --> 337         obj, end = self.raw_decode(s, idx=_w(s, 0).end())
        338         end = _w(s, end).end()
        339         if end != len(s):
    
    /opt/conda/lib/python3.7/json/decoder.py in raw_decode(self, s, idx)
        353             obj, end = self.scan_once(s, idx)
        354         except StopIteration as err:
    --> 355             raise JSONDecodeError("Expecting value", s, err.value) from None
        356         return obj, end
    
    JSONDecodeError: Expecting value: line 1 column 1 (char 0)
    
    opened by MarkWijkhuizen 6
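As a note on the final line of this trace: "Expecting value: line 1 column 1 (char 0)" is what Python's json module reports when the string it is handed is empty. The sketch below reproduces only that message. Why the metadata would be empty is an assumption here, not something this page confirms; one plausible explanation is that the SavedModel was exported with a newer TensorFlow that stores the Keras metadata in a place TF 2.4.1 does not read.

```python
import json

try:
    # An empty string stands in for the layer metadata the loader tried to decode.
    json.loads("")
except json.JSONDecodeError as err:
    message = str(err)
    print(message)
```

If that assumption holds, loading the model with a TensorFlow version closer to the one used for export would be the first thing to try.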
  • Model arguments for get_convnext_model.

    Hi there.

    Could you provide the arguments for the "get_convnext_model" function that produce models matching the ones listed here: https://github.com/facebookresearch/ConvNeXt#results-and-pre-trained-models

    In particular, I need them to generate "convnext_base_1k_224" and "convnext_xlarge_21k_1k_224".

    opened by vafaei-ar 1
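For reference, the per-variant stage depths and channel widths from the ConvNeXt paper [1] and the official code [2] are collected below. This is a sketch only: the dictionary and `config_for` are hypothetical, the keys `depths` and `dims` follow the naming in the official PyTorch code, and how these values map onto the arguments of get_convnext_model is not shown on this page.

```python
# Stage depths and channel widths per variant, per the ConvNeXt paper [1].
CONVNEXT_CONFIGS = {
    "tiny":   {"depths": (3, 3, 9, 3),  "dims": (96, 192, 384, 768)},
    "small":  {"depths": (3, 3, 27, 3), "dims": (96, 192, 384, 768)},
    "base":   {"depths": (3, 3, 27, 3), "dims": (128, 256, 512, 1024)},
    "large":  {"depths": (3, 3, 27, 3), "dims": (192, 384, 768, 1536)},
    "xlarge": {"depths": (3, 3, 27, 3), "dims": (256, 512, 1024, 2048)},
}


def config_for(model_name: str) -> dict:
    """Look up the configuration from a name like 'convnext_base_1k_224'."""
    topology = model_name.split("_")[1]
    return CONVNEXT_CONFIGS[topology]


print(config_for("convnext_base_1k_224"))
print(config_for("convnext_xlarge_21k_1k_224"))
```

The pre-training dataset and resolution in the name do not change the topology, so both requested models are determined by the "base" and "xlarge" entries respectively.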
  • Huggingface Spaces

    Hi, would you be interested in sharing a web demo on Huggingface Spaces for ConvNeXt-TF?

    It would make this model more accessible, as it would allow people to try out the model directly from the browser. Some other recent machine learning model repos have set up Spaces for easy access, including ConvNeXt:

    github: https://github.com/facebookresearch/ConvNeXt Spaces: https://huggingface.co/spaces/akhaliq/convnext

    github: https://github.com/salesforce/BLIP Spaces: https://huggingface.co/spaces/akhaliq/BLIP

    github: https://github.com/facebookresearch/omnivore Spaces: https://huggingface.co/spaces/akhaliq/omnivore

    Spaces is completely free, and I can help set up a Gradio Space. Here are some getting-started instructions if you'd prefer to do it yourself: https://huggingface.co/blog/gradio-spaces

    opened by AK391 1
  • Proposals

    • Added Conv2D initializations (https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py#L103)
    • Added Dense initializations (https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py#L103)
    • Attempted a fix for Padding on the DWConv in the ConvNext block (https://github.com/facebookresearch/ConvNeXt/blob/main/models/convnext.py#L28)
    opened by soumik12345 1
Owner
Sayak Paul
ML Engineer at @carted | One PR at a time