I wrote this examples with a data loader:
import os
import natsort
from PIL import Image
import torch
import torchvision.transforms as T
from res_mlp_pytorch.res_mlp_pytorch import ResMLP
class LPCustomDataSet(torch.utils.data.Dataset):
'''
Naive Torch Image Dataset Loader
with support for Image loading errors
and Image resizing
'''
def __init__(self, main_dir, transform):
self.main_dir = main_dir
self.transform = transform
all_imgs = os.listdir(main_dir)
self.total_imgs = natsort.natsorted(all_imgs)
def __len__(self):
return len(self.total_imgs)
def __getitem__(self, idx):
img_loc = os.path.join(self.main_dir, self.total_imgs[idx])
try:
image = Image.open(img_loc).convert("RGB")
tensor_image = self.transform(image)
return tensor_image
except:
pass
return None
@classmethod
def collate_fn(self, batch):
'''
Collate filtering not None images
'''
batch = list(filter(lambda x: x is not None, batch))
return torch.utils.data.dataloader.default_collate(batch)
@classmethod
def transform(self,img):
'''
Naive image resizer
'''
transform = T.Compose([
T.Resize(256),
T.CenterCrop(224),
T.ToTensor(),
T.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
return transform(img)
to feed ResMLP
:
model = ResMLP(
image_size = 256,
patch_size = 16,
dim = 512,
depth = 12,
num_classes = 1000
)
batch_size = 2
my_dataset = LPCustomDataSet(os.path.join(os.path.dirname(
os.path.abspath(__file__)), 'data'), transform=LPCustomDataSet.transform)
train_loader = torch.utils.data.DataLoader(my_dataset , batch_size=batch_size, shuffle=False,
num_workers=4, drop_last=True, collate_fn=LPCustomDataSet.collate_fn)
for idx, img in enumerate(train_loader):
pred = model(img) # (1, 1000)
print(idx, img.shape, pred.shape
But I get this error
RuntimeError: Given groups=1, weight of size [256, 256, 1], expected input[1, 196, 512] to have 256 channels, but got 196 channels instead
not sure if LPCustomDataSet.transform
has the correct for the input image