This is not really an issue, more of a question: I'm trying to understand the DPT architecture. I'm quite new to transformers, but I have prior experience with CNNs.
As an example, say I would like to build the DPT-Large depth model with a patch size of 16 and an image size of 384. I can see that the pretrained weights for the original model variants are loaded from timm, but in this example I would like to build the model from scratch.
I have (hopefully) gathered all the essential building blocks into the code snippet below. What I don't fully understand is where and how I should apply the functions forward_flex() and _resize_pos_embed(). My second question: if I wanted to compute a multi-scale loss, on which layers should it be computed? To make the questions more concrete, I've added two small sketches of my current guesses after the snippet. Thanks in advance!
import torch
import torch.nn as nn
import torch.nn.functional as F
class Transpose(nn.Module):
    def __init__(self, dim0, dim1):
        super(Transpose, self).__init__()
        self.dim0 = dim0
        self.dim1 = dim1

    def forward(self, x):
        x = x.transpose(self.dim0, self.dim1)
        return x
class Interpolate(nn.Module):
    def __init__(self, scale_factor):
        super(Interpolate, self).__init__()
        self.scale_factor = scale_factor

    def forward(self, x):
        x = nn.functional.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=True)
        return x
class ResidualConvUnit(nn.Module):
    def __init__(self, features):
        super().__init__()
        self.conv1 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, stride=1, padding=1, bias=True)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        out = self.relu(x)
        out = self.conv1(out)
        out = self.relu(out)
        out = self.conv2(out)
        return out + x
class FeatureFusionBlock(nn.Module):
    def __init__(self, features):
        super(FeatureFusionBlock, self).__init__()
        self.resConfUnit1 = ResidualConvUnit(features)
        self.resConfUnit2 = ResidualConvUnit(features)

    def forward(self, *xs):
        output = xs[0]
        if len(xs) == 2:
            output += self.resConfUnit1(xs[1])
        output = self.resConfUnit2(output)
        output = nn.functional.interpolate(output, scale_factor=2, mode="bilinear", align_corners=True)
        return output
class PatchEmbed(nn.Module):
    def __init__(self, img_size=384, patch_size=16, in_chans=3, embed_dim=1024):
        super().__init__()
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size // patch_size, img_size // patch_size)
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)
        x = x.flatten(2).transpose(1, 2)
        return x
class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = head_dim ** -0.5
        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
class LayerScale(nn.Module):
    def __init__(self, dim, init_values=1e-5, inplace=False):
        super().__init__()
        self.inplace = inplace
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x.mul_(self.gamma) if self.inplace else x * self.gamma
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, bias=True, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
        self.act = act_layer()
        self.drop1 = nn.Dropout(drop)
        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
        self.drop2 = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop1(x)
        x = self.fc2(x)
        x = self.drop2(x)
        return x
class Block(nn.Module):
    def __init__(self, embed_dim=1024, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.norm1 = norm_layer(embed_dim)
        self.attn = Attention(embed_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
        self.norm2 = norm_layer(embed_dim)
        self.mlp = Mlp(in_features=embed_dim, hidden_features=int(embed_dim * mlp_ratio), act_layer=act_layer, drop=drop)

    def forward(self, x):
        x = x + self.attn(self.norm1(x))
        x = x + self.mlp(self.norm2(x))
        return x
class ProjectReadout(nn.Module):
    def __init__(self, in_features, start_index=1):
        super(ProjectReadout, self).__init__()
        self.start_index = start_index
        self.project = nn.Sequential(nn.Linear(2 * in_features, in_features), nn.GELU())

    def forward(self, x):
        readout = x[:, 0].unsqueeze(1).expand_as(x[:, self.start_index:])
        features = torch.cat((x[:, self.start_index:], readout), -1)
        return self.project(features)
class Encoder(nn.Module):
    def __init__(self, embed_dim=1024, num_heads=16, depth=24, norm_layer=nn.LayerNorm):
        super(Encoder, self).__init__()
        self.patch_embed = PatchEmbed(embed_dim=embed_dim)
        self.num_patches = self.patch_embed.num_patches
        # class/readout token: ProjectReadout (start_index=1) expects it at position 0
        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
        self.pos_embed = nn.Parameter(torch.randn(1, self.num_patches + 1, embed_dim) * .02)
        self.pos_drop = nn.Dropout(p=0.1)
        # ViT-Large has 24 separate transformer blocks (no weight sharing between layers)
        self.blocks = nn.ModuleList(
            [Block(embed_dim=embed_dim, num_heads=num_heads, norm_layer=norm_layer) for _ in range(depth)])
        self.norm = norm_layer(embed_dim)

    def forward(self, x):
        x = self.patch_embed(x)
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        x = torch.cat((cls_token, x), dim=1)
        x = x + self.pos_embed  # pos_embed is a parameter, so it is added, not called
        x = self.pos_drop(x)
        outputs = []
        for i, blk in enumerate(self.blocks):
            x = blk(x)
            # DPT-Large reads out the activations after blocks 6, 12, 18 and 24
            if i in (5, 11, 17, 23):
                outputs.append(x)
        layer1, layer2, layer3, layer4 = outputs
        layer4 = self.norm(layer4)
        return layer1, layer2, layer3, layer4
class DPT_VITL_16_384(nn.Module):
    def __init__(self, img_size=384, features=256, embed_dim=1024, in_shape=[256, 512, 1024, 1024]):
        super(DPT_VITL_16_384, self).__init__()
        self.encoder = Encoder(embed_dim=embed_dim)
        self.refinenet1 = FeatureFusionBlock(features)
        self.refinenet2 = FeatureFusionBlock(features)
        self.refinenet3 = FeatureFusionBlock(features)
        self.refinenet4 = FeatureFusionBlock(features)
        self.layer1_rn = nn.Conv2d(in_shape[0], features, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer2_rn = nn.Conv2d(in_shape[1], features, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer3_rn = nn.Conv2d(in_shape[2], features, kernel_size=3, stride=1, padding=1, bias=False)
        self.layer4_rn = nn.Conv2d(in_shape[3], features, kernel_size=3, stride=1, padding=1, bias=False)
        self.readout_oper0 = ProjectReadout(embed_dim, start_index=1)
        self.readout_oper1 = ProjectReadout(embed_dim, start_index=1)
        self.readout_oper2 = ProjectReadout(embed_dim, start_index=1)
        self.readout_oper3 = ProjectReadout(embed_dim, start_index=1)
        self.act_postprocess1 = nn.Sequential(
            self.readout_oper0,
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([img_size // 16, img_size // 16])),
            nn.Conv2d(in_channels=embed_dim, out_channels=in_shape[0], kernel_size=1, stride=1, padding=0),
            nn.ConvTranspose2d(in_channels=in_shape[0], out_channels=in_shape[0], kernel_size=4, stride=4, padding=0, bias=True, dilation=1, groups=1))
        self.act_postprocess2 = nn.Sequential(
            self.readout_oper1,
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([img_size // 16, img_size // 16])),
            nn.Conv2d(in_channels=embed_dim, out_channels=in_shape[1], kernel_size=1, stride=1, padding=0),
            nn.ConvTranspose2d(in_channels=in_shape[1], out_channels=in_shape[1], kernel_size=2, stride=2, padding=0, bias=True, dilation=1, groups=1))
        self.act_postprocess3 = nn.Sequential(
            self.readout_oper2,
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([img_size // 16, img_size // 16])),
            nn.Conv2d(in_channels=embed_dim, out_channels=in_shape[2], kernel_size=1, stride=1, padding=0))
        self.act_postprocess4 = nn.Sequential(
            self.readout_oper3,
            Transpose(1, 2),
            nn.Unflatten(2, torch.Size([img_size // 16, img_size // 16])),
            nn.Conv2d(in_channels=embed_dim, out_channels=in_shape[3], kernel_size=1, stride=1, padding=0),
            nn.Conv2d(in_channels=in_shape[3], out_channels=in_shape[3], kernel_size=3, stride=2, padding=1))
        self.head = nn.Sequential(
            nn.Conv2d(features, features // 2, kernel_size=3, stride=1, padding=1),
            Interpolate(scale_factor=2),
            nn.Conv2d(features // 2, 32, kernel_size=3, stride=1, padding=1),
            nn.ReLU(True),
            nn.Conv2d(32, 1, kernel_size=1, stride=1, padding=0),
            nn.ReLU(True)
        )

    def forward(self, x):
        layer_1, layer_2, layer_3, layer_4 = self.encoder(x)
        # each tapped layer gets its own post-processing module
        layer_1 = self.act_postprocess1(layer_1)
        layer_2 = self.act_postprocess2(layer_2)
        layer_3 = self.act_postprocess3(layer_3)
        layer_4 = self.act_postprocess4(layer_4)
        layer_1_rn = self.layer1_rn(layer_1)
        layer_2_rn = self.layer2_rn(layer_2)
        layer_3_rn = self.layer3_rn(layer_3)
        layer_4_rn = self.layer4_rn(layer_4)
        path_4 = self.refinenet4(layer_4_rn)
        path_3 = self.refinenet3(path_4, layer_3_rn)
        path_2 = self.refinenet2(path_3, layer_2_rn)
        path_1 = self.refinenet1(path_2, layer_1_rn)
        out = self.head(path_1)
        return out
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using {device} device")
model = DPT_VITL_16_384().to(device)
print(model)
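Regarding the first question, my current guess is that _resize_pos_embed() only matters when the input resolution differs from the 384x384 training resolution: the stored 24x24 grid of positional embeddings has to be bilinearly interpolated to the new patch-grid size before it is added to the tokens. Here is a rough sketch of how I understand it; this is my own simplified version with placeholder names, not the code from the repo:

import torch
import torch.nn.functional as F

def resize_pos_embed(pos_embed, grid_h, grid_w, start_index=1):
    # pos_embed: (1, start_index + old_h * old_w, C); the first start_index tokens
    # are the readout/class tokens and are left untouched
    tok_embed, grid_embed = pos_embed[:, :start_index], pos_embed[:, start_index:]
    old_size = int(grid_embed.shape[1] ** 0.5)
    # reshape the grid part to 2D, interpolate to the new grid size, flatten back
    grid_embed = grid_embed.reshape(1, old_size, old_size, -1).permute(0, 3, 1, 2)
    grid_embed = F.interpolate(grid_embed, size=(grid_h, grid_w), mode="bilinear", align_corners=False)
    grid_embed = grid_embed.permute(0, 2, 3, 1).reshape(1, grid_h * grid_w, -1)
    return torch.cat((tok_embed, grid_embed), dim=1)

If that is right, then forward_flex() would be the encoder forward pass that computes grid_h = H // patch_size and grid_w = W // patch_size from the actual input, resizes the positional embedding with something like the function above before adding it, and then runs the blocks as usual; for a fixed 384x384 input the resize would be a no-op. I guess the fixed nn.Unflatten(2, ...) in the act_postprocess modules would then also have to become dynamic. Is that understanding correct?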
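For the multi-scale loss question, here is what I had in mind, just to make it concrete: besides the final output, also supervise some intermediate fusion outputs (e.g. path_2, path_3, path_4 in my forward), each against the ground-truth depth downsampled to that resolution. This is only a sketch of my assumption, not something taken from the paper, and the extra side predictions would need small heads that my model above does not have yet:

import torch
import torch.nn.functional as F

def multiscale_l1_loss(predictions, target, weights=None):
    # predictions: list of depth maps (B, 1, h_i, w_i) at different scales,
    #              e.g. the final output plus side outputs taken from path_2..path_4
    # target:      ground-truth depth (B, 1, H, W)
    weights = weights or [1.0] * len(predictions)
    loss = 0.0
    for pred, w in zip(predictions, weights):
        # downsample the ground truth to the prediction's resolution
        tgt = F.interpolate(target, size=pred.shape[-2:], mode="bilinear", align_corners=True)
        loss = loss + w * F.l1_loss(pred, tgt)
    return loss

Is that the right idea, and if so, should the side predictions be taken from the refinenet paths or from the layer_x_rn features?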