Firstly, I would like to thank you for your implementation.
Below is my version. The main improvement is the use of PyTorch's masked_fill_.
I believe this is the fastest approach short of writing a customized C++ function.
class PartialConv(nn.Module):
    """Partial convolution layer for image inpainting.

    References:
      - Image Inpainting for Irregular Holes Using Partial Convolutions
        http://masc.cs.gmu.edu/wiki/partialconv/show?time=2018-05-24+21%3A41%3A10
      - https://github.com/naoto0804/pytorch-inpainting-with-partial-conv/blob/master/net.py
      - https://github.com/SeitaroShinagawa/chainer-partial_convolution_image_inpainting/blob/master/common/net.py

    The mask is binary: 0 marks holes, 1 marks valid pixels.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(PartialConv, self).__init__()
        # NOTE(review): removed the original random.seed(0) call — Python's
        # `random` module does not drive torch's weight initialization, so it
        # had no effect here and only clobbered the global RNG state. Seed
        # with torch.manual_seed(...) at the call site if reproducible
        # initialization is needed.
        self.feature_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                      stride, padding, dilation, groups, bias)
        nn.init.kaiming_normal_(self.feature_conv.weight)
        # All-ones, bias-free, frozen convolution: applied to the mask it
        # yields, per output location, the count of valid input pixels in
        # the receptive window.
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size,
                                   stride, padding, dilation, groups,
                                   bias=False)
        torch.nn.init.constant_(self.mask_conv.weight, 1.0)
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    def forward(self, args):
        """Apply the partial convolution.

        Args:
            args: tuple ``(x, mask)`` of equally-shaped tensors; ``mask`` is
                binary (0 = hole, 1 = valid).

        Returns:
            Tuple ``(output, new_mask)``: renormalized features (zeroed at
            locations whose window contained no valid pixel) and the updated
            binary mask.
        """
        x, mask = args
        output = self.feature_conv(x * mask)
        with torch.no_grad():
            output_mask = self.mask_conv(mask)  # per-window valid-pixel sums
        no_update_holes = output_mask == 0
        # Hole locations are overwritten below, so fill their sums with an
        # arbitrary safe value (1.0) to make the division well-defined.
        mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)
        if self.feature_conv.bias is not None:
            # Renormalize only the weight contribution, then re-add the bias;
            # broadcasting makes expand_as unnecessary.
            output_bias = self.feature_conv.bias.view(1, -1, 1, 1)
            output_pre = (output - output_bias) / mask_sum + output_bias
        else:
            # No bias: skip allocating a zeros tensor just to subtract it.
            output_pre = output / mask_sum
        output = output_pre.masked_fill_(no_update_holes, 0.0)
        new_mask = torch.ones_like(output).masked_fill_(no_update_holes, 0.0)
        return output, new_mask
Benchmark:
Your code:
Runtime: 1.7311532497406006
Memory increment on a forward pass: 125.9 MiB
My code:
Runtime: 0.3832552433013916
Memory increment on a forward pass: 57.1 MiB
Output feature difference: 0.0
Mask output difference: 0.0
Codes for the benchmark
import functools
import random
import time

from memory_profiler import profile
import torch
from torch import nn
from torch.nn import functional as F
def proftime(func):
    """Decorator that prints the wall-clock runtime of each call to *func*.

    The wrapped function's return value is passed through unchanged.
    """
    @functools.wraps(func)  # keep func's name/docstring on the wrapper
    def timed(*args, **kw):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # with system clock adjustments and is coarser on some platforms.
        start = time.perf_counter()
        result = func(*args, **kw)
        print(f"Runtime: {time.perf_counter() - start}")
        return result
    return timed
class PConv2d(nn.Module):
    # Baseline partial-convolution implementation (the code being
    # benchmarked against). It selects valid/hole locations with boolean
    # index tensors rather than masked_fill_; left unmodified so the quoted
    # runtime/memory numbers remain reproducible.
    def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0):
        super().__init__()
        # NOTE(review): random.seed(0) does not seed torch's RNG, so the
        # kaiming init below is not actually made reproducible by this call.
        random.seed(0)
        self.conv2d = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding)
        nn.init.kaiming_normal_(self.conv2d.weight)
        # All-ones weights / zero bias: convolving the mask counts the valid
        # pixels in each window.
        self.mask2d = nn.Conv2d(in_ch, out_ch, kernel_size, stride, padding)
        self.mask2d.weight.data.fill_(1.0)
        self.mask2d.bias.data.fill_(0.0)
        # mask is not updated
        for param in self.mask2d.parameters():
            param.requires_grad = False

    @profile
    @proftime
    def forward(self, input, input_mask):
        # http://masc.cs.gmu.edu/wiki/partialconv
        # C(X) = W^T * X + b, C(0) = b, D(M) = 1 * M + 0 = sum(M)
        # W^T* (M .* X) / sum(M) + b = [C(M .* X) - C(0)] / D(M) + C(0)
        # C(0) is computed by convolving an all-zeros input (yields the bias).
        input_0 = input.new_zeros(input.size())
        output = F.conv2d(
            input * input_mask, self.conv2d.weight, self.conv2d.bias,
            self.conv2d.stride, self.conv2d.padding, self.conv2d.dilation,
            self.conv2d.groups)
        output_0 = F.conv2d(input_0, self.conv2d.weight, self.conv2d.bias,
                            self.conv2d.stride, self.conv2d.padding,
                            self.conv2d.dilation, self.conv2d.groups)
        with torch.no_grad():
            # Per-window valid-pixel counts; no gradient needed for the mask.
            output_mask = F.conv2d(
                input_mask, self.mask2d.weight, self.mask2d.bias,
                self.mask2d.stride, self.mask2d.padding, self.mask2d.dilation,
                self.mask2d.groups)
        n_z_ind = (output_mask != 0.0)
        z_ind = (output_mask == 0.0)  # skip all the computation
        # Renormalize valid locations; zero out locations whose window had
        # no valid pixels; binarize the updated mask.
        output[n_z_ind] = \
            (output[n_z_ind] - output_0[n_z_ind]) / output_mask[n_z_ind] + \
            output_0[n_z_ind]
        output[z_ind] = 0.0
        output_mask[n_z_ind] = 1.0
        output_mask[z_ind] = 0.0
        return output, output_mask
class PartialConv(nn.Module):
    # Instrumented duplicate of the proposed PartialConv, decorated with
    # @profile/@proftime so the benchmark script is self-contained; left
    # byte-identical in code so the quoted measurements stay reproducible.
    # reference:
    # Image Inpainting for Irregular Holes Using Partial Convolutions
    # http://masc.cs.gmu.edu/wiki/partialconv/show?time=2018-05-24+21%3A41%3A10
    # https://github.com/naoto0804/pytorch-inpainting-with-partial-conv/blob/master/net.py
    # https://github.com/SeitaroShinagawa/chainer-partial_convolution_image_inpainting/blob/master/common/net.py
    # mask is binary, 0 is holes; 1 is not
    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        super(PartialConv, self).__init__()
        # NOTE(review): random.seed(0) does not seed torch's RNG, so it does
        # not make the kaiming init below reproducible.
        random.seed(0)
        self.feature_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                                      padding, dilation, groups, bias)
        nn.init.kaiming_normal_(self.feature_conv.weight)
        # All-ones, bias-free, frozen conv: counts valid pixels per window.
        self.mask_conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride,
                                   padding, dilation, groups, bias=False)
        torch.nn.init.constant_(self.mask_conv.weight, 1.0)
        for param in self.mask_conv.parameters():
            param.requires_grad = False

    @profile
    @proftime
    def forward(self, args):
        # args is a tuple (x, mask); returns (output, new_mask).
        x, mask = args
        output = self.feature_conv(x * mask)
        if self.feature_conv.bias is not None:
            output_bias = self.feature_conv.bias.view(1, -1, 1, 1).expand_as(output)
        else:
            output_bias = torch.zeros_like(output)
        with torch.no_grad():
            output_mask = self.mask_conv(mask)  # mask sums
        no_update_holes = output_mask == 0
        # because those values won't be used, assign an easy value to compute
        mask_sum = output_mask.masked_fill_(no_update_holes, 1.0)
        # Renormalize the weight contribution only, then re-add the bias;
        # zero out hole locations and binarize the updated mask.
        output_pre = (output - output_bias) / mask_sum + output_bias
        output = output_pre.masked_fill_(no_update_holes, 0.0)
        new_mask = torch.ones_like(output)
        new_mask = new_mask.masked_fill_(no_update_holes, 0.0)
        return output, new_mask
# Baseline ("your") method
model1 = PConv2d(in_ch=256, out_ch=256, kernel_size=3, stride=1, padding=1)
# Proposed ("my") method
model2 = PartialConv(in_channels=256, out_channels=256, kernel_size=3, stride=1,
                     padding=1, dilation=1, groups=1, bias=True)
# Make sure all learnable convolutions share the same weights. (The mask
# convolutions already agree: all-ones weights with no effective bias.)
model2.feature_conv.weight.data.copy_(model1.conv2d.weight.data)
model2.feature_conv.bias.data.copy_(model1.conv2d.bias.data)
# BUGFIX: the original called random.seed(0) here, which does not seed
# torch's RNG — torch.manual_seed is required for torch.randn to be
# reproducible across runs.
torch.manual_seed(0)
x1 = torch.randn(1, 256, 64, 64)
x2 = x1.clone()
mask1 = torch.ones_like(x1)
mask1[:, :, 25:50, 25:50] = 0  # carve a square hole into the mask
mask2 = mask1.clone()
y1 = model1.forward(x1, mask1)
y2 = model2.forward((x2, mask2))
# Both differences should print 0.0 if the two implementations agree.
print(f"Output feature output difference {torch.sum(y2[0] - y1[0])}")
print(f'Mask output difference {torch.sum(y2[1] - y1[1])}')
Some comments:
I see you are using batch norm after the partial convolution. I would suggest disabling the bias term in the convolution right before batch norm, since batch norm itself includes a bias (shift) term that makes the convolution's bias redundant.
In addition, I prefer in place batch norm that is able to save around 20% - 40% memory usage while maintaining fast computation.
In your training script, the default learning rate is 2e-4. I highly recommend using a cyclical learning rate (PyTorch has an implementation). I am using a 0.04-0.08 learning rate range. If you are able to train with a large batch size, the learning rate can move within [0.1, 1] or even higher, which is known as super-convergence.
Personal ad:
I am using partial convolution to build a manga inpainting tool: image segmentation locates the text regions, and inpainting then repairs the background and color. Suggestions and comments are very welcome.