TensorFlow 2.x based implementation of EDSR, WDSR and SRGAN for single image super-resolution

Overview

Single Image Super-Resolution with EDSR, WDSR and SRGAN

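A TensorFlow 2.x based implementation of

  • Enhanced Deep Residual Networks for Single Image Super-Resolution (EDSR)
  • Wide Activation for Efficient and Accurate Image Super-Resolution (WDSR)
  • Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network (SRGAN)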

This is a complete rewrite of the old Keras/TensorFlow 1.x based implementation available here. Some parts are still work in progress, but you can already train models as described in the papers via a high-level training API. You can also fine-tune EDSR and WDSR models in an SRGAN context. Training and usage examples are given in the notebooks.

A DIV2K data provider automatically downloads DIV2K training and validation images of a given scale (2, 3, 4 or 8) and downgrade operator ("bicubic", "unknown", "mild" or "difficult").

Important: if you want to evaluate the pre-trained models with a dataset other than DIV2K, please read this comment (and its replies) first.

Environment setup

Create a new conda environment with

conda env create -f environment.yml

and activate it with

conda activate sisr

Introduction

You can find an introduction to single-image super-resolution in this article. It also demonstrates how EDSR and WDSR models can be fine-tuned with SRGAN (see also this section).

Getting started

Examples in this section require the following pre-trained weights (see also the example notebooks):

Pre-trained weights

  • weights-edsr-16-x4.tar.gz
    • EDSR x4 baseline as described in the EDSR paper: 16 residual blocks, 64 filters, 1.52M parameters.
    • PSNR on DIV2K validation set = 28.89 dB (images 801 - 900, 6 + 4 pixel border included).
  • weights-wdsr-b-32-x4.tar.gz
    • WDSR B x4 custom model: 32 residual blocks, 32 filters, expansion factor 6, 0.62M parameters.
    • PSNR on DIV2K validation set = 28.91 dB (images 801 - 900, 6 + 4 pixel border included).
  • weights-srgan.tar.gz
    • SRGAN as described in the SRGAN paper: 1.55M parameters, trained with VGG54 content loss.
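The 1.52M parameters quoted for the EDSR x4 baseline can be checked by hand. The sketch below assumes the baseline layout from the EDSR paper: a 3x3 head conv, residual blocks of two 3x3 convs with no batch normalization, one 3x3 conv after the residual body, one sub-pixel upsampling stage per x2 factor (a 3x3 conv to 4 x filters channels followed by a parameter-free pixel shuffle) and a 3x3 output conv. `conv_params` and `edsr_params` are hypothetical helpers, not part of the repository.

```python
# Parameter count of an EDSR-style network, assuming the EDSR baseline layout.
# Only conv kernels and biases carry parameters; residual scaling, pixel
# shuffle and input normalization layers have none.

# Upsampling stages per scale: x4 is built from two x2 stages (assumption).
UPSAMPLE_FACTORS = {2: [2], 3: [3], 4: [2, 2]}


def conv_params(in_ch: int, out_ch: int, k: int = 3) -> int:
    """Weights plus biases of a k x k Conv2D layer."""
    return k * k * in_ch * out_ch + out_ch


def edsr_params(num_filters: int = 64, num_res_blocks: int = 16, scale: int = 4) -> int:
    total = conv_params(3, num_filters)                                   # head conv
    total += num_res_blocks * 2 * conv_params(num_filters, num_filters)  # residual blocks
    total += conv_params(num_filters, num_filters)                        # conv after the body
    for f in UPSAMPLE_FACTORS[scale]:                                     # sub-pixel stages
        total += conv_params(num_filters, num_filters * f * f)
    total += conv_params(num_filters, 3)                                  # output conv (RGB)
    return total


print(edsr_params())  # 1517571, i.e. about 1.52M
```

If the layout assumption holds, almost 78% of the parameters sit in the 16 residual blocks, which is why the block count is the main size knob.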

After downloading, extract them in the root folder of the project with

tar xvfz weights-<...>.tar.gz

EDSR

from model import resolve_single
from model.edsr import edsr

from utils import load_image, plot_sample

model = edsr(scale=4, num_res_blocks=16)
model.load_weights('weights/edsr-16-x4/weights.h5')

lr = load_image('demo/0851x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

WDSR

from model import resolve_single
from model.wdsr import wdsr_b
from utils import load_image, plot_sample

model = wdsr_b(scale=4, num_res_blocks=32)
model.load_weights('weights/wdsr-b-32-x4/weights.h5')

lr = load_image('demo/0829x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

Weight normalization in WDSR models is implemented with the new WeightNormalization layer wrapper of TensorFlow Addons. In its latest version, this wrapper seems to corrupt weights when running model.predict(...). A workaround is to set model.run_eagerly = True or to compile the model with model.compile(loss='mae') in advance. The issue doesn't arise when calling the model directly with model(...). This still needs further investigation.
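For reference, weight normalization (Salimans and Kingma, 2016) reparameterizes each convolution filter as w = g · v / ‖v‖, with one learned scale g per output channel and the norm taken over all other kernel axes. The NumPy sketch below shows only that arithmetic; it does not reproduce the TensorFlow Addons wrapper, its data-dependent initialization, or the predict() problem described above. `weight_norm_kernel` is a hypothetical helper.

```python
import numpy as np


def weight_norm_kernel(v, g):
    """Weight-normalized conv kernel w = g * v / ||v||.

    v: unnormalized kernel of shape (kh, kw, in_ch, out_ch)
    g: per-output-channel scale of shape (out_ch,)
    The norm is computed separately for every output channel, i.e. over
    all axes except the last one.
    """
    norm = np.sqrt(np.sum(v ** 2, axis=(0, 1, 2), keepdims=True))  # shape (1, 1, 1, out_ch)
    return g * v / norm


rng = np.random.default_rng(0)
v = rng.normal(size=(3, 3, 32, 64))
g = np.linspace(0.5, 2.0, 64)
w = weight_norm_kernel(v, g)

# Every output filter now has exactly the norm given by g.
print(np.allclose(np.linalg.norm(w.reshape(-1, 64), axis=0), g))  # True
```

The design point is that the direction (v) and the magnitude (g) of each filter are learned separately, which is what lets WDSR drop batch normalization.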

SRGAN

from model import resolve_single
from model.srgan import generator
from utils import load_image, plot_sample

model = generator()
model.load_weights('weights/srgan/gan_generator.h5')

lr = load_image('demo/0869x4-crop.png')
sr = resolve_single(model, lr)

plot_sample(lr, sr)

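All three examples pass the LR image through resolve_single. Its source is not reproduced in this README; the NumPy sketch below is an assumption about what such a helper does for a model trained on batches: add a batch axis, run the model, clip the unbounded output to [0, 255], round, cast to uint8 and drop the batch axis. `resolve_single_sketch` and the stand-in model `nearest_x4` are hypothetical. The repository's resolve_single returns a TensorFlow tensor (several of the issues below run into this); under TF 2.x eager execution, calling `.numpy()` on it yields a NumPy array.

```python
import numpy as np


def resolve_single_sketch(model, lr):
    """Super-resolve one HxWx3 uint8 image with a batch-oriented model.

    `model` is any callable mapping a float32 (N, H, W, 3) batch to an
    (N, sH, sW, 3) batch; this mirrors, by assumption, resolve_single.
    """
    batch = np.asarray(lr, dtype=np.float32)[np.newaxis, ...]  # add batch axis
    sr = model(batch)
    sr = np.clip(sr, 0, 255)             # model output is not bounded
    sr = np.round(sr).astype(np.uint8)   # back to 8-bit pixel values
    return sr[0]                         # drop batch axis


def nearest_x4(batch):
    """Stand-in 'model': nearest-neighbour x4 upscaling plus an offset that overflows 255."""
    return batch.repeat(4, axis=1).repeat(4, axis=2) + 10.0


lr = np.full((8, 6, 3), 250, dtype=np.uint8)
sr = resolve_single_sketch(nearest_x4, lr)
print(sr.shape, sr.dtype, sr.max())  # (32, 24, 3) uint8 255
```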

DIV2K dataset

For training and validation on DIV2K images, applications should use the provided DIV2K data loader. It automatically downloads DIV2K images to the .div2k directory and converts them to a different format for faster loading.

Training dataset

from data import DIV2K

train_loader = DIV2K(scale=4,             # 2, 3, 4 or 8
                     downgrade='bicubic', # 'bicubic', 'unknown', 'mild' or 'difficult'
                     subset='train')      # training set = images 001 - 800

# Create a tf.data.Dataset
train_ds = train_loader.dataset(batch_size=16,         # batch size as described in the EDSR and WDSR papers
                                random_transform=True, # random crop, flip, rotate as described in the EDSR paper
                                repeat_count=None)     # repeat iterating over training images indefinitely

# Iterate over LR/HR image pairs
for lr, hr in train_ds:
    ...  # process a batch of LR/HR pairs here

Crop size in HR images is 96x96.
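At scale 4, a 96x96 HR crop corresponds to a 24x24 LR crop. The NumPy sketch below (hypothetical helpers, not the repository's tf.data code) picks the crop offset on the LR grid and multiplies it by the scale so that LR and HR patches cover the same area, then applies the same random horizontal flip and rotation by a multiple of 90 degrees to both, in the spirit of the EDSR paper's augmentation.

```python
import numpy as np


def random_crop(lr, hr, hr_crop_size=96, scale=4, rng=None):
    """Cut aligned patches: (hr_crop_size // scale)^2 from LR, hr_crop_size^2 from HR."""
    rng = np.random.default_rng() if rng is None else rng
    lr_crop_size = hr_crop_size // scale
    h, w = lr.shape[:2]
    # Choose the offset on the LR grid, then scale it, so both patches stay aligned.
    y = int(rng.integers(0, h - lr_crop_size + 1))
    x = int(rng.integers(0, w - lr_crop_size + 1))
    lr_patch = lr[y:y + lr_crop_size, x:x + lr_crop_size]
    hr_patch = hr[y * scale:(y + lr_crop_size) * scale,
                  x * scale:(x + lr_crop_size) * scale]
    return lr_patch, hr_patch


def random_flip_rotate(lr, hr, rng=None):
    """Apply the same random horizontal flip and k * 90 degree rotation to both patches."""
    rng = np.random.default_rng() if rng is None else rng
    if rng.random() < 0.5:
        lr, hr = lr[:, ::-1], hr[:, ::-1]
    k = int(rng.integers(0, 4))
    return np.rot90(lr, k), np.rot90(hr, k)


rng = np.random.default_rng(42)
lr = rng.integers(0, 256, size=(120, 80, 3)).astype(np.uint8)
hr = lr.repeat(4, axis=0).repeat(4, axis=1)  # stand-in for a real HR image
lr_patch, hr_patch = random_flip_rotate(*random_crop(lr, hr, rng=rng), rng=rng)
print(lr_patch.shape, hr_patch.shape)  # (24, 24, 3) (96, 96, 3)
```

Because the stand-in HR image is a nearest-neighbour upscale of the LR image, the HR patch equals the upscaled LR patch exactly, which makes misalignment easy to detect.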

Validation dataset

from data import DIV2K

valid_loader = DIV2K(scale=4,             # 2, 3, 4 or 8
                     downgrade='bicubic', # 'bicubic', 'unknown', 'mild' or 'difficult'
                     subset='valid')      # validation set = images 801 - 900

# Create a tf.data.Dataset
valid_ds = valid_loader.dataset(batch_size=1,           # use batch size of 1 as DIV2K images have different sizes
                                random_transform=False, # use DIV2K images in original size
                                repeat_count=1)         # 1 epoch

# Iterate over LR/HR image pairs
for lr, hr in valid_ds:
    ...  # process an LR/HR pair here

Training

The following training examples use the training and validation datasets described earlier. The high-level training API is designed around steps (= minibatch updates) rather than epochs to better match the descriptions in the papers.
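To relate step counts to epochs: DIV2K has 800 training images, and if the loader draws one random crop per image per pass (an assumption, not stated above), one pass at batch size 16 takes 50 steps. `passes_over_training_set` is a hypothetical helper for this arithmetic.

```python
def passes_over_training_set(steps, batch_size=16, num_images=800, crops_per_image=1):
    """Convert a step budget into full passes over the training images.

    Assumes each pass draws `crops_per_image` random crops from every image.
    """
    steps_per_pass = num_images * crops_per_image / batch_size  # 800 / 16 = 50 steps
    return steps / steps_per_pass


for steps in (300_000, 1_000_000, 200_000):
    print(steps, passes_over_training_set(steps))
# 300000 6000.0
# 1000000 20000.0
# 200000 4000.0
```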

EDSR

from model.edsr import edsr
from train import EdsrTrainer

# Create a training context for an EDSR x4 model with 16 
# residual blocks.
trainer = EdsrTrainer(model=edsr(scale=4, num_res_blocks=16), 
                      checkpoint_dir='.ckpt/edsr-16-x4')
                      
# Train EDSR model for 300,000 steps and evaluate model
# every 1000 steps on the first 10 images of the DIV2K
# validation set. Save a checkpoint only if evaluation
# PSNR has improved.
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000, 
              evaluate_every=1000, 
              save_best_only=True)
              
# Restore from checkpoint with highest PSNR.
trainer.restore()

# Evaluate model on full validation set.
psnr = trainer.evaluate(valid_ds)
print(f'PSNR = {psnr.numpy():.3f}')

# Save weights to separate location.
trainer.model.save_weights('weights/edsr-16-x4/weights.h5')                                    

Interrupting training and restarting it again resumes from the latest saved checkpoint. The trained Keras model can be accessed with trainer.model.
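trainer.evaluate(...) in the examples above reports PSNR. As a reference point, the usual definition is PSNR = 10 · log10(MAX² / MSE). The NumPy sketch below assumes 8-bit images (MAX = 255) and compares whole images, so it does not reproduce any border handling behind the "6 + 4 pixel border" figures quoted for the pre-trained weights. `psnr` here is a standalone helper, not the repository's function.

```python
import numpy as np


def psnr(x1, x2, max_val=255.0):
    """Peak signal-to-noise ratio in dB between two images of equal shape."""
    # Cast to float first: subtracting uint8 arrays would wrap around.
    diff = np.asarray(x1, dtype=np.float64) - np.asarray(x2, dtype=np.float64)
    mse = np.mean(diff ** 2)
    if mse == 0:
        return float("inf")  # identical images
    return float(10.0 * np.log10(max_val ** 2 / mse))


hr = np.zeros((96, 96, 3), dtype=np.uint8)
sr = np.full((96, 96, 3), 10, dtype=np.uint8)  # every pixel off by 10 -> MSE = 100
print(round(psnr(hr, sr), 2))  # 28.13
```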

WDSR

from model.wdsr import wdsr_b
from train import WdsrTrainer

# Create a training context for a WDSR B x4 model with 32 
# residual blocks.
trainer = WdsrTrainer(model=wdsr_b(scale=4, num_res_blocks=32), 
                      checkpoint_dir='.ckpt/wdsr-b-32-x4')

# Train WDSR B model for 300,000 steps and evaluate model
# every 1000 steps on the first 10 images of the DIV2K
# validation set. Save a checkpoint only if evaluation
# PSNR has improved.
trainer.train(train_ds,
              valid_ds.take(10),
              steps=300000, 
              evaluate_every=1000, 
              save_best_only=True)

# Restore from checkpoint with highest PSNR.
trainer.restore()

# Evaluate model on full validation set.
psnr = trainer.evaluate(valid_ds)
print(f'PSNR = {psnr.numpy():.3f}')

# Save weights to separate location.
trainer.model.save_weights('weights/wdsr-b-32-x4/weights.h5')

SRGAN

Generator pre-training

from model.srgan import generator
from train import SrganGeneratorTrainer

# Create a training context for the generator (SRResNet) alone.
pre_trainer = SrganGeneratorTrainer(model=generator(), checkpoint_dir='.ckpt/pre_generator')

# Pre-train the generator with 1,000,000 steps (100,000 works fine too). 
pre_trainer.train(train_ds, valid_ds.take(10), steps=1000000, evaluate_every=1000)

# Save weights of pre-trained generator (needed for fine-tuning with GAN).
pre_trainer.model.save_weights('weights/srgan/pre_generator.h5')

Generator fine-tuning (GAN)

from model.srgan import generator, discriminator
from train import SrganTrainer

# Create a new generator and init it with pre-trained weights.
gan_generator = generator()
gan_generator.load_weights('weights/srgan/pre_generator.h5')

# Create a training context for the GAN (generator + discriminator).
gan_trainer = SrganTrainer(generator=gan_generator, discriminator=discriminator())

# Train the GAN with 200,000 steps.
gan_trainer.train(train_ds, steps=200000)

# Save weights of generator and discriminator.
gan_trainer.generator.save_weights('weights/srgan/gan_generator.h5')
gan_trainer.discriminator.save_weights('weights/srgan/gan_discriminator.h5')
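For orientation, the SRGAN paper defines the generator's perceptual loss as a content loss on VGG feature maps (VGG54 here) plus 10^-3 times the adversarial loss, while the discriminator is trained with binary cross-entropy on real (HR) versus generated (SR) images. The NumPy sketch below shows only how these terms combine; the arrays stand in for VGG features and discriminator probabilities, and whether SrganTrainer scales features or reduces losses differently is not covered. All function names are hypothetical.

```python
import numpy as np

EPS = 1e-7  # keeps log() finite when the discriminator saturates


def bce(target, prob):
    """Mean binary cross-entropy between 0/1 targets and predicted probabilities."""
    prob = np.clip(prob, EPS, 1 - EPS)
    return float(np.mean(-(target * np.log(prob) + (1 - target) * np.log(1 - prob))))


def generator_loss(sr_features, hr_features, d_sr, adv_weight=1e-3):
    """Perceptual loss = content loss (feature MSE) + weighted adversarial loss."""
    content = float(np.mean((sr_features - hr_features) ** 2))
    adversarial = bce(np.ones_like(d_sr), d_sr)  # generator wants D(SR) -> 1
    return content + adv_weight * adversarial


def discriminator_loss(d_hr, d_sr):
    """Real HR images are labelled 1, generated SR images 0."""
    return bce(np.ones_like(d_hr), d_hr) + bce(np.zeros_like(d_sr), d_sr)


d_hr, d_sr = np.array([0.9]), np.array([0.1])  # a confident discriminator
print(round(discriminator_loss(d_hr, d_sr), 4))                              # 0.2107
print(round(generator_loss(np.zeros((2, 4)), np.ones((2, 4)), d_sr), 4))     # 1.0023
```

With the small adversarial weight, the content term dominates the generator loss, which is why starting from a pre-trained generator matters.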

SRGAN for fine-tuning EDSR and WDSR models

It is also possible to fine-tune EDSR and WDSR x4 models with SRGAN. They can be used as a drop-in replacement for the original SRGAN generator. More details are given in this article.

from model.edsr import edsr
from model.srgan import discriminator
from train import SrganTrainer

# Create EDSR generator and init with pre-trained weights
generator = edsr(scale=4, num_res_blocks=16)
generator.load_weights('weights/edsr-16-x4/weights.h5')

# Fine-tune EDSR model via SRGAN training.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)

from model.wdsr import wdsr_b

# Create WDSR B generator and init with pre-trained weights
generator = wdsr_b(scale=4, num_res_blocks=32)
generator.load_weights('weights/wdsr-b-32-x4/weights.h5')

# Fine-tune WDSR B model via SRGAN training.
gan_trainer = SrganTrainer(generator=generator, discriminator=discriminator())
gan_trainer.train(train_ds, steps=200000)

Issues
  • How to train the model for x8 scale?

    Thanks for creating this repo and the very helpful blog post. I tried inference with the x2 and x4 EDSR models and the results are good. Now I'm changing the source code to train an x8 model. In the upsampling method of model/edsr.py, I changed it like this: [screenshot]. The x8 training code ran for 10,000 steps, but the PSNR is only around 15. How should I update the EDSR training code for x8? Thanks in advance.

    opened by canhnht 13
  • Discriminator loss goes to zero

    Hi krasserm,

    I am training the SRGAN model with a slight change: instead of using tf.GradientTape() inside the training loop, I am using tf.keras "train_on_batch".

    I am also doing one-sided label smoothing for discriminator training.

    During training, the fake loss of the discriminator goes close to zero after a few steps, and the generator adversarial loss fluctuates around a small value (0.4~0.6; is that even possible given that the discriminator is doing a very good job?). Only the VGG content loss decreases, and it also stops going down after stabilizing. There is no further improvement in the images after that.

    All hyperparameter values are the same. Can you think of any reason for this? Any suggestions?

    opened by Shubham3101 10
  • load re-trained model error

    At the beginning, 'demo.py' ran successfully with the pre-trained model. Then, after preparing and converting data, I retrained the model with

    python train.py --dataset ./DIV2K_BIN --outdir ./output --profile wdsr-a-8 --scale 2

    It ran normally for 300 epochs, and I got many new h5 files. When I loaded one of the new h5 files, an error occurred:

    TypeError: Expected float32, got {'type': 'ndarray', 'value': [114.44399999999999, 111.4605, 103.02000000000001]} of type 'dict' instead.

    Could you offer me some suggestions?

    opened by leviome 10
  • ValueError: Cannot create group in read only mode.

    When I try to load the trained model using demo.py, there is an error:

    Traceback (most recent call last):
      File "F:/zhangqianqian/super-resolution-master/demo.py", line 78, in <module>
        main(args)
      File "F:/zhangqianqian/super-resolution-master/demo.py", line 48, in main
        model = load_model(args.model)
      File "F:\super-resolution-master\model\__init__.py", line 7, in load_model
        return train._load_model(path)
      File "F:\super-resolution-master\train.py", line 76, in _load_model
        return load_model(path, custom_objects={**_custom_objects, **_custom_objects_backwards_compat})
      File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\saving.py", line 419, in load_model
        model = _deserialize_model(f, custom_objects, compile)
      File "C:\ProgramData\Anaconda3\lib\site-packages\keras\engine\saving.py", line 221, in _deserialize_model
        model_config = f['model_config']
      File "C:\ProgramData\Anaconda3\lib\site-packages\keras\utils\io_utils.py", line 302, in __getitem__
        raise ValueError('Cannot create group in read only mode.')
    ValueError: Cannot create group in read only mode.
    
    opened by kikyo1314 8
  • NotImplementedError: Cannot convert a symbolic Tensor (Cast_80:0) to a numpy array.

    When I run the examples, the resulting super-resolution image cannot be shown because it is a symbolic Tensor. The detailed error is as follows:

    NotImplementedError Traceback (most recent call last)
    in ()
          9 lr = load_image('superresolution/demo/0851x4-crop.png')
         10 sr = resolve_single(model, lr)
    ---> 11 plot_sample(lr,sr)
         12
         13 type(sr)

    8 frames
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py in __array__(self)
        734   def __array__(self):
        735     raise NotImplementedError("Cannot convert a symbolic Tensor ({}) to a numpy"
    --> 736                               " array.".format(self.name))
        737
        738   def __len__(self):

    NotImplementedError: Cannot convert a symbolic Tensor (Cast_80:0) to a numpy array.

    opened by ShaniaShan 7
  • Color shift with srgan model

    Hello

    First of all, thank you for your work on this. I've noticed a slight color shift on inference with the SRGAN model. I tried tweaking the normalize methods to use the DIV2K mean values and retraining the model, but this did not improve the results.

    Thank you

    opened by fjallraven 7
  • ResourceExhaustedError

    I tried to upscale a photo with resolve_and_plot('demo/2.jpg') and got this error: ResourceExhaustedError: OOM when allocating tensor with shape[1,666,1000,64] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:Conv2D]

    How can I fix this? Thanks a lot!

    opened by ruanjiyang 6
  • Issues trying to reproduce results

    Hello,

    I am trying to run the SRGAN notebook only for evaluation and I get the error shown below. I think the problem is that images = [lr, pre_sr, gan_sr], computed from

    pre_sr = resolve_single(pre_generator, lr)
    gan_sr = resolve_single(gan_generator, lr)

    return a Tensor instead of a NumPy array, while the imshow function expects a NumPy array. If so, I don't understand why the output of the resolve_single function is not converted to a NumPy array in the SRGAN notebook and, more importantly, why it works in the notebook shown on GitHub.

    TypeError Traceback (most recent call last)
    in
    ----> 1 resolve_and_plot('demo/0869x4-crop.png')

    in resolve_and_plot(lr_image_path)
         16     for i, (img, title, pos) in enumerate(zip(images, titles, positions)):
         17         plt.subplot(2, 2, pos)
    ---> 18         plt.imshow(img)
         19         plt.title(title)
         20         plt.xticks([])

    ~/anaconda2/envs/keras/lib/python3.6/site-packages/matplotlib/pyplot.py in imshow(X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, data, **kwargs)
       2699         filternorm=filternorm, filterrad=filterrad, imlim=imlim,
       2700         resample=resample, url=url, **({"data": data} if data is not
    -> 2701         None else {}), **kwargs)
       2702     sci(__ret)
       2703     return __ret

    ~/anaconda2/envs/keras/lib/python3.6/site-packages/matplotlib/__init__.py in inner(ax, data, *args, **kwargs)
       1808                         "the Matplotlib list!)" % (label_namer, func.__name__),
       1809                         RuntimeWarning, stacklevel=2)
    -> 1810             return func(ax, *args, **kwargs)
       1811
       1812         inner.__doc__ = _add_data_doc(inner.__doc__,

    ~/anaconda2/envs/keras/lib/python3.6/site-packages/matplotlib/axes/_axes.py in imshow(self, X, cmap, norm, aspect, interpolation, alpha, vmin, vmax, origin, extent, shape, filternorm, filterrad, imlim, resample, url, **kwargs)
       5492                               resample=resample, **kwargs)
       5493
    -> 5494         im.set_data(X)
       5495         im.set_alpha(alpha)
       5496         if im.get_clip_path() is None:

    ~/anaconda2/envs/keras/lib/python3.6/site-packages/matplotlib/image.py in set_data(self, A)
        640         if (self._A.dtype != np.uint8 and
        641                 not np.can_cast(self._A.dtype, float, "same_kind")):
    --> 642             raise TypeError("Image data cannot be converted to float")
        643
        644         if not (self._A.ndim == 2

    TypeError: Image data cannot be converted to float

    opened by RdlP 5
  • training my own data very slowly

    Hi krasserm, thank you for your great work. When I train on my own data, however, it is very slow. I have 4303 training images, and I set iter_steps to 4000 per epoch, 300 epochs and a batch size of 32. After 20 hours, only 2000 iterations of the first epoch have completed. I use one GPU (GTX1080 11G), and GPU utilization stays at zero percent, so I suspect a bottleneck in data loading or something else that slows training down. Could you help me speed up training? Thank you very much.

    opened by yanmenglu 5
  • error in training srgan

    While running gan_trainer.train(train_ds, steps=200000) in example-srgan.ipynb, I get the following error:

    AttributeError: 'Operation' object has no attribute '_graph'

    opened by sronast 4
  • load_model error in demo.py

    I tried to run demo.py with this model, but this error occurred:

    D:\Artificial Intelligence\SuperResolution\Images\wdsr>python demo.py
    Using TensorFlow backend.
    WARNING: Logging before flag parsing goes to stderr.
    W0817 19:36:48.204280  8632 deprecation_wrapper.py:119] From D:\Artificial Intelligence\SuperResolution\Images\wdsr\util.py:30: The name tf.ConfigProto is deprecated. Please use tf.compat.v1.ConfigProto instead.
    
    W0817 19:36:48.204280  8632 deprecation_wrapper.py:119] From D:\Artificial Intelligence\SuperResolution\Images\wdsr\util.py:33: The name tf.Session is deprecated. Please use tf.compat.v1.Session instead.
    
    2019-08-17 19:36:48.228737: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
    2019-08-17 19:36:48.238670: I tensorflow/stream_executor/platform/default/dso_loader.cc:42] Successfully opened dynamic library nvcuda.dll
    2019-08-17 19:36:48.860316: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1640] Found device 0 with properties:
    name: GeForce GTX 960M major: 5 minor: 0 memoryClockRate(GHz): 1.176
    pciBusID: 0000:01:00.0
    2019-08-17 19:36:48.868030: I tensorflow/stream_executor/platform/default/dlopen_checker_stub.cc:25] GPU libraries are statically linked, skip dlopen check.
    2019-08-17 19:36:48.874822: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1763] Adding visible gpu devices: 0
    2019-08-17 19:36:50.439236: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1181] Device interconnect StreamExecutor with strength 1 edge matrix:
    2019-08-17 19:36:50.445774: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1187]      0
    2019-08-17 19:36:50.448722: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1200] 0:   N
    2019-08-17 19:36:50.453298: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1326] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 3276 MB memory) -> physical GPU (device: 0, name: GeForce GTX 960M, pci bus id: 0000:01:00.0, compute capability: 5.0)
    Began Load Image
    W0817 19:36:50.487827  8632 deprecation_wrapper.py:119] From C:\Users\127051\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\backend\tensorflow_backend.py:517: The name tf.placeholder is deprecated. Please use tf.compat.v1.placeholder instead.
    
    Traceback (most recent call last):
      File "demo.py", line 73, in <module>
        main(args)
      File "demo.py", line 41, in main
        model = load_model(args.model)
      File "D:\Artificial Intelligence\SuperResolution\Images\wdsr\model\__init__.py", line 7, in load_model
        return train._load_model(path)
      File "D:\Artificial Intelligence\SuperResolution\Images\wdsr\train.py", line 76, in _load_model
        return load_model(path, custom_objects={**_custom_objects, **_custom_objects_backwards_compat})
      File "C:\Users\127051\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\engine\saving.py", line 419, in load_model
        model = _deserialize_model(f, custom_objects, compile)
      File "C:\Users\127051\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\engine\saving.py", line 225, in _deserialize_model
        model = model_from_config(model_config, custom_objects=custom_objects)
      File "C:\Users\127051\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\engine\saving.py", line 458, in model_from_config
        return deserialize(config, custom_objects=custom_objects)
      File "C:\Users\127051\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\layers\__init__.py", line 55, in deserialize
        printable_module_name='layer')
      File "C:\Users\127051\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\utils\generic_utils.py", line 145, in deserialize_keras_object
        list(custom_objects.items())))
      File "C:\Users\127051\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\engine\network.py", line 1032, in from_config
        process_node(layer, node_data)
      File "C:\Users\127051\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\engine\network.py", line 991, in process_node
        layer(unpack_singleton(input_tensors), **kwargs)
      File "C:\Users\127051\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\engine\base_layer.py", line 457, in __call__
        output = self.call(inputs, **kwargs)
      File "C:\Users\127051\AppData\Local\Programs\Python\Python35\lib\site-packages\keras\layers\core.py", line 687, in call
        return self.function(inputs, **arguments)
      File "/home/martin/Development/extern/wdsr/model/common.py", line 14, in <lambda>
    IndexError: tuple index out of range
    
    opened by aligoglos 4
  • weird output on pre-trained weights

    I tried to run the SRGAN demo in Google Colab. The TensorFlow and Keras versions used here seem quite old, so there were version compatibility issues in Colab, and I switched to TensorFlow 1.13.1.

    I found that in older versions, a TensorFlow session has to be created explicitly to plot the image: https://stackoverflow.com/questions/55315275/why-am-i-getting-typeerror-image-data-cannot-be-converted-to-float

    So I did it as follows:

    import tensorflow
    from model import resolve_single
    from utils import load_image, plot_sample

    pre_generator = generator()
    gan_generator = generator()

    pre_generator.load_weights('/content/weights/srgan/pre_generator.h5')
    gan_generator.load_weights('/content/weights/srgan/gan_generator.h5')

    with tensorflow.Session() as sess:
        sess.run(tensorflow.global_variables_initializer())

        lr = load_image('/content/super-resolution/demo/0851x4-crop.png')

        pre_sr = resolve_single(pre_generator, lr)
        gan_sr = resolve_single(gan_generator, lr)

        pre_sr = sess.run(tensorflow.convert_to_tensor(pre_sr))
        gan_sr = sess.run(tensorflow.convert_to_tensor(gan_sr))

        plt.figure(figsize=(20, 20))

        images = [lr, pre_sr, gan_sr]

        titles = ['LR', 'SR (PRE)', 'SR (GAN)']
        positions = [1, 3, 4]

        for i, (img, title, pos) in enumerate(zip(images, titles, positions)):
            plt.subplot(2, 2, pos)
            plt.imshow(img)
            plt.title(title)
            plt.xticks([])
            plt.yticks([])

        sess.close()

    Now my output looks like this: [screenshot]

    Can anybody help me understand why this is happening? Thanks.

    opened by HardeyPandya 0
  • How should I create caches file?

    I am using a custom dataset for super-resolution. I downscaled the images with the bicubic method in MATLAB. Should I also create a caches folder, as for the DIV2K dataset? How should I generate the cache files? Please help.

    opened by PIjarihd 1
  • How Should I customize data for Custom Dataset Training?

    I am trying to train on a custom dataset, but I am quite confused by the process. What exact changes should I make? Can you please guide me step by step? I studied #62, #8 and #5, but I still could not figure out how to prepare the data. Do I need to create the folders exactly like this? [screenshot] Which code should I run to downscale the images? How should I prepare the caches? Please guide me or provide a tutorial link. Thanks in advance.

    opened by PIjarihd 0
  • ValueError: axes don't match array when loading WDSR pre-trained weights

    Hello everyone.

    I managed to test EDSR and SRGAN on a custom dataset with the pre-trained weights with no problem.

    However, when I load the WDSR weights, I get the error: ValueError: axes don't match array

    I don't know if it is related, but I was getting the following error when calling the wdsr_b model:

    ~/.pyenv/versions/3.6.8/envs/srgan/lib/python3.6/site-packages/keras/layers/wrappers.py in __init__(self, layer, **kwargs)
         44
         45     def __init__(self, layer, **kwargs):
    ---> 46         assert isinstance(layer, Layer)
         47         self.layer = layer
         48         super(Wrapper, self).__init__(**kwargs)

    AssertionError:

    So I changed the conv2d_weightnorm function (as mentioned in https://www.reddit.com/r/tensorflow/comments/dn0hjv/applying_weight_normalization_layer_in_tf_2/) in the following way:

    def conv2d_weightnorm(filters, kernel_size, padding='same', activation=None, **kwargs):
        return Conv2D(filters, kernel_size, padding=padding, activation=activation, **kwargs)

    Does anyone have an idea what I am doing wrong?

    Thank you in advance!

    opened by adavradou 0
  • AssertionError

    AssertionError Traceback (most recent call last)
    in
    ----> 1 model = wdsr_b(scale=scale, num_res_blocks=depth)
          2 model.load_weights(weights_file)

    ~/Desktop/GAN-research/super-resolution/model/wdsr.py in wdsr_b(scale, num_filters, num_res_blocks, res_block_expansion, res_block_scaling)
         12
         13 def wdsr_b(scale, num_filters=32, num_res_blocks=8, res_block_expansion=6, res_block_scaling=None):
    ---> 14     return wdsr(scale, num_filters, num_res_blocks, res_block_expansion, res_block_scaling, res_block_b)
         15
         16

    ~/Desktop/GAN-research/super-resolution/model/wdsr.py in wdsr(scale, num_filters, num_res_blocks, res_block_expansion, res_block_scaling, res_block)
         20
         21     # main branch
    ---> 22     m = conv2d_weightnorm(num_filters, 3, padding='same')(x)
         23     for i in range(num_res_blocks):
         24         m = res_block(m, num_filters, res_block_expansion, kernel_size=3, scaling=res_block_scaling)

    ~/Desktop/GAN-research/super-resolution/model/wdsr.py in conv2d_weightnorm(filters, kernel_size, padding, activation, **kwargs)
         57
         58 def conv2d_weightnorm(filters, kernel_size, padding='same', activation=None, **kwargs):
    ---> 59     return tfa.layers.WeightNormalization(Conv2D(filters, kernel_size, padding=padding, activation=activation, **kwargs), data_init=False)

    ~/.pyenv/versions/3.6.8/envs/srgan/lib/python3.6/site-packages/typeguard/__init__.py in wrapper(*args, **kwargs)
       1030         memo = _CallMemo(python_func, _localns, args=args, kwargs=kwargs)
       1031         check_argument_types(memo)
    -> 1032         retval = func(*args, **kwargs)
       1033         try:
       1034             check_return_type(retval, memo)

    ~/.pyenv/versions/3.6.8/envs/srgan/lib/python3.6/site-packages/tensorflow_addons/layers/wrappers.py in __init__(self, layer, data_init, **kwargs)
         58     @typechecked
         59     def __init__(self, layer: tf.keras.layers, data_init: bool = True, **kwargs):
    ---> 60         super().__init__(layer, **kwargs)
         61         self.data_init = data_init
         62         self._track_trackable(layer, name="layer")

    ~/.pyenv/versions/3.6.8/envs/srgan/lib/python3.6/site-packages/keras/layers/wrappers.py in __init__(self, layer, **kwargs)
         44
         45     def __init__(self, layer, **kwargs):
    ---> 46         assert isinstance(layer, Layer)
         47         self.layer = layer
         48         super(Wrapper, self).__init__(**kwargs)

    AssertionError:

    I wasn't able to download weights for WDSR. Any ideas what the problem is?

    opened by otsebriy 3
Owner
Martin Krasser
Freelance machine learning engineer, software developer and consultant. Mountainbike freerider, bass guitar player.