I am trying to use your code on the SVHN dataset. I changed the data loader as shown below, and training fails with the following error. Could you please help me sort out the problem? Thank you in advance.
```python
# Data loader
train_dataset = datasets.SVHN('/media/user/DATA/New_CODE/pytorch-capsule/SVHN', download=True, transform=transform, split='train')
test_dataset = datasets.SVHN('/media/user/DATA/New_CODE/pytorch-capsule/SVHN', download=True, transform=transform, split='test')
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=test_batch_size, shuffle=True)
```
Printed network:

```text
CapsuleNetwork (
  (conv1): CapsuleConvLayer (
    (conv0): Conv2d(1, 256, kernel_size=(9, 9), stride=(1, 1))
    (relu): ReLU (inplace)
  )
  (primary): CapsuleLayer (
    (unit_0): ConvUnit (
      (conv0): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
    )
    (unit_1): ConvUnit (
      (conv0): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
    )
    (unit_2): ConvUnit (
      (conv0): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
    )
    (unit_3): ConvUnit (
      (conv0): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
    )
    (unit_4): ConvUnit (
      (conv0): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
    )
    (unit_5): ConvUnit (
      (conv0): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
    )
    (unit_6): ConvUnit (
      (conv0): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
    )
    (unit_7): ConvUnit (
      (conv0): Conv2d(256, 32, kernel_size=(9, 9), stride=(2, 2))
    )
  )
  (digits): CapsuleLayer (
  )
  (reconstruct0): Linear (160 -> 522)
  (reconstruct1): Linear (522 -> 1176)
  (reconstruct2): Linear (1176 -> 784)
  (relu): ReLU (inplace)
  (sigmoid): Sigmoid ()
)
```
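The printed layers assume MNIST-sized inputs: `conv1` is `Conv2d(1, 256, ...)` (one input channel), and `reconstruct2` outputs 784 = 28 × 28 values. The small pure-Python sketch below applies the standard convolution output-size formula to the printed `conv1` and `primary` layers to compare MNIST (1 × 28 × 28) with SVHN (3 × 32 × 32). It assumes square inputs and that the primary layer yields 32 × grid × grid capsules, as in the original CapsNet paper; how `digits` actually sizes its inputs in this repo is not shown here, so treat the capsule count as an assumption. The helper names are hypothetical.

```python
def conv_out(size, kernel, stride=1, padding=0, dilation=1):
    """Spatial output size of a 2D convolution (standard formula)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


def capsnet_shapes(channels, height):
    """Trace one square image through the printed conv1 and primary layers.

    Assumption (hypothetical): primary capsules = 32 channels * grid * grid.
    """
    after_conv1 = conv_out(height, kernel=9, stride=1)  # conv1: kernel 9, stride 1
    grid = conv_out(after_conv1, kernel=9, stride=2)    # each primary unit: kernel 9, stride 2
    return {
        "conv1_in_channels": channels,                   # must match Conv2d's first argument
        "primary_grid": grid,
        "primary_capsules": 32 * grid * grid,
        "reconstruction_size": channels * height * height,  # pixels reconstruct2 must output
    }


mnist = capsnet_shapes(channels=1, height=28)  # 6x6 grid, 1152 capsules, 784 pixels
svhn = capsnet_shapes(channels=3, height=32)   # 8x8 grid, 2048 capsules, 3072 pixels
print(mnist)
print(svhn)
```

So besides the channel count that triggers the error, SVHN's larger images would also change the primary-capsule grid and the reconstruction size if fed in unchanged.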
```text
Traceback (most recent call last):
  File "main.py", line 195, in <module>
    last_loss = train(epoch)
  File "main.py", line 171, in train
    output = network(data)
  File "/home/mahfuj/pytorch_python3/lib/python3.5/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/media/mahfuj/DATA/New_CODE/pytorch-capsule/capsule_network.py", line 61, in forward
    return self.digits(self.primary(self.conv1(x)))
  File "/home/mahfuj/pytorch_python3/lib/python3.5/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/media/mahfuj/DATA/New_CODE/pytorch-capsule/capsule_conv_layer.py", line 27, in forward
    return self.relu(self.conv0(x))
  File "/home/mahfuj/pytorch_python3/lib/python3.5/site-packages/torch/nn/modules/module.py", line 224, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/mahfuj/pytorch_python3/lib/python3.5/site-packages/torch/nn/modules/conv.py", line 254, in forward
    self.padding, self.dilation, self.groups)
  File "/home/mahfuj/pytorch_python3/lib/python3.5/site-packages/torch/nn/functional.py", line 52, in conv2d
    return f(input, weight, bias)
RuntimeError: Need input.size[1] == 1 but got 3 instead.
```