thanks for your contribution!
Here, for some reason, i need to realize the "involution2D,3D" by myself, and I take this project for validation.
However, my results can not be the same as yours. In the begining, i think it may be my fault, but after check i am not sure!!!
So could you help me?
Here is my question:
1、I think the “Tensor.unfold()" use in "involution.py" are not right........( may be ).
Here is the code ( with problems):
‘’‘
input_unfolded = self.pad(input_initial)
.unfold(dimension=2, size=self.kernel_size[0], step=self.stride[0])
.unfold(dimension=3, size=self.kernel_size[1], step=self.stride[1])
.unfold(dimension=4, size=self.kernel_size[2], step=self.stride[2])
input_unfolded = input_unfolded.reshape(batch_size, self.groups, self.out_channels // self.groups,
self.kernel_size[0] * self.kernel_size[1] * self.kernel_size[2], -1)
input_unfolded = input_unfolded.reshape(tuple(input_unfolded.shape[:-1])
+ (out_depth, out_height, out_width))
’‘’
In officials, they use "nn.Unfold()" and this is right.
the Tensor.unfold() returns ”B,C,H,W,K,K“, and the "nn.Unfold()" returns "B,CxKxK,HxW".
So I think the " permute" needed be used if use ”Tensor.unfold()“.
And I give an example for comparsion:
################The Code:##############
def nnUnfold_Tensorunfold():
input = torch.ones((1, 1, 5, 5))
# ----------------nnUnfold----------------- #
Unfold1 = nn.Unfold(3, 1, (3 - 1) // 2, 1)
input_unfolded = Unfold1(input) #====>B,CxKxK,HxW
input_unfolded = input_unfolded.contiguous().view(1,9,5,5)
print("Official: nn.Unfold():",input_unfolded)
# ---------------Tensorunfold--------------- #
pad = nn.ConstantPad2d(padding=(1, 1,1, 1), value=0.)
input = pad(input)
input_unfolded = input
input_unfolded = input_unfolded.unfold(dimension=2, size=3, step=1)
input_unfolded = input_unfolded.unfold(dimension=3, size=3, step=1) #===>B,C,H,W,K,K
before = input_unfolded.contiguous().view(1,9,5,5)
print("Wrong: Tensor.unfold():",before)
after = input_unfolded.permute(0,1,4,5,2,3).contiguous().view(1,9,5,5) #====> permute should be used
print("Right: after permute:",after)
# --------------------------------- #
if name == 'main':
nnUnfold_Tensorunfold()
################The Results:##############
Official: nn.Unfold(): tensor([[[[0., 0., 0., 0., 0.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.]],
[[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.]],
[[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.]],
[[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 0., 0., 0., 0.]],
[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0.]],
[[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[0., 0., 0., 0., 0.]]]])
Wrong: Tensor.unfold(): tensor([[[[0., 0., 0., 0., 1.],
[1., 0., 1., 1., 0.],
[0., 0., 1., 1., 1.],
[1., 1., 1., 0., 0.],
[0., 1., 1., 1., 1.]],
[[1., 1., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[1., 0., 0., 0., 1.],
[1., 0., 1., 1., 0.],
[0., 1., 1., 0., 1.]],
[[1., 0., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 1.],
[1., 1., 1., 0., 1.],
[1., 0., 1., 1., 0.],
[0., 1., 1., 0., 1.],
[1., 0., 1., 1., 1.]],
[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[1., 1., 1., 0., 1.],
[1., 0., 1., 1., 0.],
[0., 1., 1., 0., 1.],
[1., 0., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 0., 1.]],
[[1., 0., 1., 1., 0.],
[0., 1., 1., 0., 1.],
[1., 0., 0., 0., 1.],
[1., 1., 1., 1., 1.],
[0., 0., 0., 1., 1.]],
[[1., 1., 1., 1., 0.],
[0., 0., 1., 1., 1.],
[1., 1., 1., 0., 0.],
[0., 1., 1., 0., 1.],
[1., 0., 0., 0., 0.]]]])
Right: after permute: tensor([[[[0., 0., 0., 0., 0.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.]],
[[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[0., 0., 0., 0., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.]],
[[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.]],
[[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.]],
[[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 1., 1., 1., 1.],
[0., 0., 0., 0., 0.]],
[[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[1., 1., 1., 1., 1.],
[0., 0., 0., 0., 0.]],
[[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[1., 1., 1., 1., 0.],
[0., 0., 0., 0., 0.]]]])
########################################
Maybe i am wrong..... could you help me?