TRIQ Implementation

Overview

TF-Keras implementation of TRIQ as described in the paper "Transformer for Image Quality Assessment".

Installation

  1. Clone this repository.
  2. Install the required Python packages. The code was developed with PyCharm in Python 3.7. The requirements.txt file was generated by PyCharm, and the code should also run with the latest versions of the packages.

Training a model

An example of training TRIQ can be found in train/train_triq.py. An argument parser could be used, but the authors prefer a dictionary of predefined parameters; converting it to take command-line arguments is straightforward (see the sketch after the parameter block below). In principle, the following parameters can be defined:

args = {}
args['multi_gpu'] = 0  # GPU setting, set to 1 to use multiple GPUs
args['gpu'] = 0  # If having multiple GPUs, specify which GPU to use

args['result_folder'] = r'..\databases\experiments'  # Define result path
args['n_quality_levels'] = 5  # Choose between 1 (MOS prediction) and 5 (distribution prediction)

args['transformer_params'] = [2, 32, 8, 64]  # [num_layers, d_model, num_heads, mlp_dim]

args['train_folders'] = [  # Define folders containing training images
    r'..\databases\train\koniq_normal',
    r'..\databases\train\koniq_small',
    r'..\databases\train\live'
    ]
args['val_folders'] = [  # Define folders containing testing images
    r'..\databases\val\koniq_normal',
    r'..\databases\val\koniq_small',
    r'..\databases\val\live'
    ]
args['koniq_mos_file'] = r'..\databases\koniq10k_images_scores.csv'  # MOS (distribution of scores) file for the KonIQ database
args['live_mos_file'] = r'..\databases\live_mos.csv'  # MOS (standard deviation of scores) file for the LIVE-wild database

args['backbone'] = 'resnet50'  # Choose from ['resnet50', 'vgg16']
args['weights'] = r'..\pretrained_weights\resnet50_weights_tf_dim_ordering_tf_kernels_notop.h5'  # Path of ImageNet pretrained weights
args['initial_epoch'] = 0  # Define initial epoch, useful for fine-tuning

args['lr_base'] = 1e-4 / 2  # Define the base learning rate for the warmup and rate-decay schedule
args['lr_schedule'] = True  # Choose between True and False, indicating whether the learning rate schedule should be used
args['batch_size'] = 32  # Batch size, should be chosen to fit in the GPU memory
args['epochs'] = 120  # Maximal number of epochs; early stopping can optionally be set in the callback

args['image_aug'] = True  # Choose between True and False, indicating whether image augmentation should be used
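
For example, converting the dictionary to command-line arguments could look like the following sketch (a minimal illustration; the flag names are ours, not part of the repository):

    import argparse

    parser = argparse.ArgumentParser(description='Train TRIQ')
    parser.add_argument('--result_folder', default=r'..\databases\experiments')
    parser.add_argument('--n_quality_levels', type=int, default=5, choices=[1, 5])
    parser.add_argument('--backbone', default='resnet50', choices=['resnet50', 'vgg16'])
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--epochs', type=int, default=120)
    parser.add_argument('--lr_base', type=float, default=1e-4 / 2)

    # vars() turns the namespace into the dictionary the training code expects
    args = vars(parser.parse_args())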

Predict image quality using the trained model

After TRIQ has been trained and the weights have been stored in an h5 file, the model can be used to predict the quality of images of arbitrary sizes:

    args = {}
    args['n_quality_levels'] = 5
    args['backbone'] = 'resnet50'
    args['weights'] = r'..\TRIQ.h5'
    model = create_triq_model(n_quality_levels=args['n_quality_levels'],
                              backbone=args['backbone'])
    model.load_weights(args['weights'])

Then use ModelEvaluation to predict the quality of an image set.
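
For a single image, prediction might look like the following sketch (the preprocessing is an assumption for illustration; with n_quality_levels = 5 the model outputs a distribution over the five quality scales, which can be reduced to a predicted MOS by a weighted average):

    import numpy as np
    from PIL import Image

    # Load an image and scale pixel values to [0, 1] (illustrative preprocessing)
    image = np.asarray(Image.open('test_image.jpg'), dtype=np.float32) / 255.
    image = np.expand_dims(image, axis=0)  # add the batch dimension

    distribution = model.predict(image)[0]        # probabilities of quality scales 1..5
    mos = np.sum(distribution * np.arange(1, 6))  # weighted average as predicted MOS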

In the "examples" folder, an example script examples\image_quality_prediction.py is provided to use the trained weights to predict quality of example images. In the "train" folder, an example script train\validation.py is provided to use the trained weights to predict quality of images in folders.

A potential issue is an image shape mismatch: the positional encoding covers a fixed number of positions (maximum_position_encoding = 193 by default), so overly large feature maps must be pooled down further. For example, if an image is too large, line 146 in transformer_iqa.py should be changed to increase the pooling size, e.g., to self.pooling_small = MaxPool2D(pool_size=(4, 4)) or even larger.
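
As a rough guide, the number of transformer tokens for a given input size can be estimated as below. This is a sketch under the assumption of the ResNet50 backbone (five stride-2 stages with 'same' padding) followed by the 2x2 pooling mentioned above, plus one quality embedding token:

    import math

    def num_tokens(height, width):
        # ResNet50 halves the resolution five times ('same' padding rounds up),
        # then TRIQ's 2x2 max-pool halves it once more (rounding down)
        h, w = height, width
        for _ in range(5):
            h, w = math.ceil(h / 2), math.ceil(w / 2)
        h, w = h // 2, w // 2
        return h * w + 1  # +1 for the quality embedding token

    print(num_tokens(768, 1024))   # 193, matches maximum_position_encoding
    print(num_tokens(1440, 1919))  # 661, too large, so increase the pooling size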

Prepare datasets for model training

This work uses two publicly available databases: KonIQ-10k ("KonIQ-10k: An ecologically valid database for deep learning of blind image quality assessment" by V. Hosu, H. Lin, T. Sziranyi, and D. Saupe) and LIVE-wild ("Massive online crowdsourced study of subjective and objective picture quality" by D. Ghadiyaram and A.C. Bovik).

  1. The two databases were merged and then split into training and testing sets. Please see the README in databases for details.

  2. Make the MOS files (note: do NOT include a header line):

    For a database with the score distribution available, the MOS file looks like this (koniq format):

        image path, voter number of quality scale 1, voter number of quality scale 2, voter number of quality scale 3, voter number of quality scale 4, voter number of quality scale 5, MOS or Z-score
        10004473376.jpg,0,0,25,73,7,3.828571429
        10007357496.jpg,0,3,45,47,1,3.479166667
        10007903636.jpg,1,0,20,73,2,3.78125
        10009096245.jpg,0,0,21,75,13,3.926605505
    

    For a database with the standard deviation available, the MOS file looks like this (live format):

        image path, standard deviation, MOS or Z-score
        t1.bmp,18.3762,63.9634
        t2.bmp,13.6514,25.3353
        t3.bmp,18.9246,48.9366
        t4.bmp,18.2414,35.8863
    

    The format of the MOS file ('koniq' or 'live') and the format of the score ('mos' or 'z_score') should also be specified in misc/imageset_handler/get_image_scores; a parsing sketch is given after this list.

  3. In the train script train/train_triq.py, the folders containing training and testing images are provided.

  4. Pretrained ImageNet weights can be downloaded (see the README in .\pretrained_weights) and pointed to in the train script.
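
For illustration, minimal parsers for the two MOS file formats might look like this (a sketch only; the repository's own loader in misc/imageset_handler should be preferred, and normalizing the vote counts into a probability distribution is our assumption):

    import csv

    def read_koniq_mos(mos_file):
        # koniq format: image name, vote counts for quality scales 1-5, MOS or Z-score
        scores = {}
        with open(mos_file) as f:
            for row in csv.reader(f):
                votes = [float(v) for v in row[1:6]]
                scores[row[0]] = ([v / sum(votes) for v in votes], float(row[6]))
        return scores

    def read_live_mos(mos_file):
        # live format: image name, standard deviation, MOS or Z-score
        with open(mos_file) as f:
            return {row[0]: (float(row[1]), float(row[2])) for row in csv.reader(f)}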

Trained TRIQ weights

TRIQ has been trained on KonIQ-10k and LIVE-wild databases, and the weights file can be downloaded here.

State-of-the-art models

Three other models are also included in this work. The original implementations of these metrics are employed and can be found below.

Koncept512 KonIQ-10k: An ecologically valid database for deep learning of blind image quality assessment

SGDNet SGDNet: An end-to-end saliency-guided deep neural network for no-reference image quality assessment

CaHDC End-to-end blind image quality prediction with cascaded deep neural network

Comparison results

We have conducted several experiments to evaluate the performance of TRIQ; please see results.pdf for detailed results.

Error report

In case errors/exceptions are encountered, please first check all the paths. After fixing any path issues, please report remaining errors in Issues.

FAQ

  • To be added

ViT (Vision Transformer) for IQA

This work is heavily inspired by ViT ("An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale"). The module vit_iqa contains an implementation of ViT for IQA, mainly following the implementation of ViT-PyTorch. Pretrained ViT weights can be downloaded here.

Comments
  • Training

    Hello, I want to repeat your work and rewrite it in PyTorch. Can you tell me more details about training? You write that "A base learning rate 5e-5 was used for pretraining"; do you mean pretraining on the same datasets (KonIQ-10k and LIVE-C)?

    opened by zmm96 10
  • Same output for every input image

    def create_triq_model(n_quality_levels,
                          input_shape=(None, None, 3),
                          backbone='resnet50',
                          transformer_params=(2, 32, 8, 64),
                          maximum_position_encoding=193,
                          vis=False):
        chanDim = -1
        # define the model input
        inputs = Input(shape=input_shape)
        filters = (32, 64, 128)
        # loop over the number of filters
        for (i, f) in enumerate(filters):
            # if this is the first CONV layer then set the input
            # appropriately
            if i == 0:
                x = Rescaling(1./255)(inputs)
    
            # CONV => RELU => BN => POOL
            x = Conv2D(f, (3, 3), padding="same")(x)
            x = Activation("relu")(x)
            x = BatchNormalization(axis=chanDim)(x)
            x = MaxPooling2D(pool_size=(2, 2))(x)
        
        x = Conv2D(256, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=chanDim)(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        
        x = ZeroPadding2D(padding=(1, 1))(x)
        x = Conv2D(2048, (3, 3), padding="same")(x)
        x = Activation("relu")(x)
        x = BatchNormalization(axis=chanDim)(x)
        x = MaxPooling2D(pool_size=(2, 2))(x)
        dropout_rate = 0.1
        
        transformer = TriQImageQualityTransformer(
            num_layers=transformer_params[0],
            d_model=transformer_params[1],
            num_heads=transformer_params[2],
            mlp_dim=transformer_params[3],
            dropout=dropout_rate,
            n_quality_levels=n_quality_levels,
            maximum_position_encoding=maximum_position_encoding,
            vis=vis
        )
        outputs = transformer(x)
      
        model = Model(inputs=inputs, outputs=outputs)
        model.summary()
        return model
    
    gpus = tf.config.experimental.list_physical_devices('GPU')
    tf.config.experimental.set_visible_devices(gpus[0], 'GPU')
    input_shape = (564, 504, 3)
    #model = create_triq_model(n_quality_levels=5, input_shape=input_shape, backbone='vgg16')
    model = create_triq_model(n_quality_levels=1, input_shape=input_shape, backbone='resnet50')
    
    from tensorflow.keras.optimizers import Adam
    opt = Adam(learning_rate=0.001, decay=1e-3 / 200)
    model.compile(loss="mean_squared_error", optimizer=opt)
    model.fit(trainImagesX, trainY, validation_data=(valImagesX, valY),
              epochs=108, batch_size=16)
    
    

    In the above code, I have modified the create_triq_model function so that it uses a custom CNN instead of ResNet or VGGNet. The custom CNN's output shape is (18, 16, 2048). This output is fed to TriqImageQualityTransformer.

    The issue is that after training, the model predicts the same value for every input. I have experimented with various hyperparameters; it might output different values under different hyperparameter settings, but for a particular setting it outputs the same value for every input image. One more thing to note: if I do not use a transformer but instead use a plain artificial neural network, then the network trains well.

    Can you please suggest what I am doing wrong here?

    opened by sulakshgupta988 6
  • run image_quality_prediction.py shape error

    Thanks for your work, it is very cool. I tested a JPG image with size 1919 × 1440, and it shows:

    tensorflow.python.framework.errors_impl.InvalidArgumentError:  Incompatible shapes: [1,661,32] vs. [1,193,32]
    	 [[node model/tri_q_image_quality_transformer/add_1 (defined at /data2/zhx3/triq/src/models/transformer_iqa.py:198) ]] [Op:__inference_predict_function_9663]
    
    Errors may have originated from an input operation.
    Input Source operations connected to node model/tri_q_image_quality_transformer/add_1:
     model/tri_q_image_quality_transformer/concat (defined at /data2/zhx3/triq/src/models/transformer_iqa.py:194)
    
    Function call stack:
    predict_function
    
    opened by Usernamezhx 6
  • Input

    Hello, I have been reproducing this project recently, and it feels great. Now I have encountered a problem: I want to input two pictures at a time (or input one and then split it after it enters the model). I have looked for a long time, but I could not find where to modify this; with the None-type input shape I can do nothing. Looking forward to your comments and guidance, thank you very much. - a beginner

    opened by Alen334 5
  • Issue with Training - Generator error

    Hello! I followed all the instructions for training and prepared the data and labels accordingly. When I ran the training script, it ran for a few steps, say 170/2135, and then stopped, throwing exceptions.

    I then changed return np.array(images_aug), np.array(y_scores) to return np.array(images_aug, dtype='object'), np.array(y_scores, dtype='object'), but now the script is just stuck and doesn't consume much GPU memory after a while (700MB/16GB). I even tried training from scratch (without loading the ImageNet pretrained weights), but still no luck.

    My conda env details: tensorflow-gpu==2.1.0, tensorflow_addons==0.8.3, h5py==2.10.0

    opened by nikhilgunti 5
  • About dataset

    Hello, I have a question: the data shapes in the koniq-10k dataset are not consistent. Some are (224, 224), while others are (224, 224, 3), but I did not find where this difference is handled. Can you tell me more about the details? Thanks a lot.

    opened by zmm96 5
  • plcc

    Hello, I would like to ask what PLCC value you get on the training set when training for 120 epochs? I think the result I get is a bit off.

    opened by Alen334 4
  • OOM

    Hello, I have a question. I want to predict all the pictures of KonIQ using the trained model, so I used a loop to process all the pictures in the folder, but I got the following problem; can you help me?

    tensorflow.python.framework.errors_impl.ResourceExhaustedError: OOM when allocating tensor with shape[64,386,514] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [[node model_10/pool1_pad/Pad (defined at E:/Graudate/Code/triq-master-play/src/examples/image_quality_prediction.py:21) ]] Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info. [Op:__inference_predict_function_106843]

    opened by cs19469 4
  • About the tensorflow version

    Hi,

    The requirements.txt file says that the TensorFlow version used in this project is 2.2.0. However, when I tried to run the train_triq.py file, an error occurred: "Tensorflow Addons supports using Python ops for all Tensorflow versions above or equal to 2.3.0 and strictly below 2.5.0 (nightly versions are not supported)". It seems the function tensorflow_addons.activations.gelu does not support TensorFlow 2.2.0.

    I'm not familiar with TensorFlow, so I want to check the TensorFlow version and discuss why this error happened.

    opened by sunguwei 4
  • About koniq-10k dataset

    opened by CharlesWu123 3
  • Handling different size inputs during training

    Hi,

    Could you please tell me how you handled different image sizes as input during the training phase? Let's say we have three images of sizes (1080×1080), (1608×1608), and (2000×2000). If we give these images as input to the network during training, how was this handled? Were the images zero-padded to the maximum resolution? Thanks.

    opened by arp95 3
Owner
Junyong You