Hey,
I found a difference between the output of the InstanceLoss implemented below and the NT-Xent loss taken from SimCLR (https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py).
Although the two functions look very similar, their outputs seem to be different. Could you please look into it and share your insights?
import torch
import torch.nn as nn


class InstanceLoss(nn.Module):
    def __init__(self, batch_size, temperature, device):
        super(InstanceLoss, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.device = device
        self.mask = self.mask_correlated_samples(batch_size)
        self.criterion = nn.CrossEntropyLoss(reduction="sum")

    def mask_correlated_samples(self, batch_size):
        # Mask out the diagonal and the positive pairs, keeping only the negatives.
        N = 2 * batch_size
        mask = torch.ones((N, N))
        mask = mask.fill_diagonal_(0)
        for i in range(batch_size):
            mask[i, batch_size + i] = 0
            mask[batch_size + i, i] = 0
        mask = mask.bool()
        return mask

    def forward(self, z_i, z_j):
        N = 2 * self.batch_size
        z = torch.cat((z_i, z_j), dim=0)
        # Pairwise similarities as raw dot products, scaled by the temperature.
        sim = torch.matmul(z, z.T) / self.temperature
        sim_i_j = torch.diag(sim, self.batch_size)
        sim_j_i = torch.diag(sim, -self.batch_size)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(N, 1)
        negative_samples = sim[self.mask].reshape(N, -1)
        # The positive logit sits in column 0, so every label is 0.
        labels = torch.zeros(N).to(positive_samples.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= N
        return loss
class NT_Xent(nn.Module):
    """
    More than inspired by https://github.com/Spijkervet/SimCLR/blob/master/modules/nt_xent.py

    Notes
    =====
    With this PyTorch implementation you don't actually need to L2-normalize the inputs;
    the results will be identical, as shown when you run this file.
    """

    def __init__(self, batch_size, temperature, device):
        super(NT_Xent, self).__init__()
        self.batch_size = batch_size
        self.temperature = temperature
        self.mask = self.get_correlated_samples_mask()
        self.device = device
        self.criterion = nn.CrossEntropyLoss(reduction="sum")
        self.similarity_f = nn.CosineSimilarity(dim=2)

    def forward(self, z_i, z_j):
        """
        We do not sample negative examples explicitly.
        Instead, given a positive pair, similar to (Chen et al., 2017), we treat the other
        2(N - 1) augmented examples within a minibatch as negative examples.
        """
        p1 = torch.cat((z_i, z_j), dim=0)
        # Pairwise cosine similarities, scaled by the temperature.
        sim = self.similarity_f(p1.unsqueeze(1), p1.unsqueeze(0)) / self.temperature
        sim_i_j = torch.diag(sim, self.batch_size)
        sim_j_i = torch.diag(sim, -self.batch_size)
        positive_samples = torch.cat((sim_i_j, sim_j_i), dim=0).reshape(self.batch_size * 2, 1)
        negative_samples = sim[self.mask].reshape(self.batch_size * 2, -1)
        labels = torch.zeros(self.batch_size * 2).to(self.device).long()
        logits = torch.cat((positive_samples, negative_samples), dim=1)
        loss = self.criterion(logits, labels)
        loss /= 2 * self.batch_size
        return loss

    def get_correlated_samples_mask(self):
        mask = torch.ones((self.batch_size * 2, self.batch_size * 2), dtype=torch.bool)
        mask = mask.fill_diagonal_(0)
        for i in range(self.batch_size):
            mask[i, self.batch_size + i] = 0
            mask[self.batch_size + i, i] = 0
        return mask
a, b = torch.rand(8, 12), torch.rand(8, 12)
a_norm, b_norm = torch.nn.functional.normalize(a), torch.nn.functional.normalize(b)
cosine_sim = torch.nn.CosineSimilarity()
instance_loss = InstanceLoss(8, 0.5, "cpu")
ntxent_loss = NT_Xent(8, 0.5, "cpu")
print('Cosine')
print(cosine_sim(a, b))
print(cosine_sim(a_norm, b_norm))
print('NT Xent')
print(ntxent_loss(a, b))
print(ntxent_loss(a_norm, b_norm))
print('Instance')
print(instance_loss(a, b))
print(instance_loss(a_norm, b_norm))
Output:
Cosine
tensor([0.6606, 0.7330, 0.7845, 0.8602, 0.6992, 0.8224, 0.7167, 0.7500])
tensor([0.6606, 0.7330, 0.7845, 0.8602, 0.6992, 0.8224, 0.7167, 0.7500])
NT Xent
tensor(2.7081)
tensor(2.7081)
Instance
tensor(3.1286)
tensor(2.7081)
As you can see, InstanceLoss gives a different result for (a, b) than for (a_norm, b_norm), whereas the cosine similarity and NT_Xent outputs are unchanged by the normalization, and the two losses only agree once the inputs are L2-normalized.
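In case it helps narrow things down, here is a minimal diagnostic sketch (my own, not taken from either implementation) that isolates the one step where the two classes differ: InstanceLoss scores pairs with a raw dot product (torch.matmul(z, z.T)), while NT_Xent uses nn.CosineSimilarity. My assumption is that the two only coincide when the rows of z are unit-norm, which would match the outputs above, but I may be missing something.

import torch

torch.manual_seed(0)
z = torch.rand(16, 12)  # stand-in for torch.cat((z_i, z_j), dim=0)

# Similarity matrix the way InstanceLoss computes it: raw dot products
# (the division by temperature is dropped here since both classes use the same value).
dot_sim = torch.matmul(z, z.T)

# Similarity matrix the way NT_Xent computes it: cosine similarity.
cos = torch.nn.CosineSimilarity(dim=2)
cos_sim = cos(z.unsqueeze(1), z.unsqueeze(0))

print(torch.allclose(dot_sim, cos_sim, atol=1e-6))            # False: dot product != cosine on raw inputs

z_norm = torch.nn.functional.normalize(z, dim=1)
dot_sim_norm = torch.matmul(z_norm, z_norm.T)
cos_sim_norm = cos(z_norm.unsqueeze(1), z_norm.unsqueeze(0))

print(torch.allclose(dot_sim_norm, cos_sim_norm, atol=1e-6))  # True once the rows are unit-norm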
Colab notebook:
https://github.com/Spijkervet/SimCLR/blob/master/simclr/modules/nt_xent.py
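If the intent is for InstanceLoss to behave like NT_Xent regardless of the scale of the inputs, one possible fix (my suggestion, not something from either repo) would be to L2-normalize the embeddings before the dot product, e.g. via a small wrapper around the class defined above:

import torch
import torch.nn.functional as F

class NormalizedInstanceLoss(InstanceLoss):
    """Hypothetical variant: L2-normalize the embeddings before the dot product."""
    def forward(self, z_i, z_j):
        return super().forward(F.normalize(z_i, dim=1), F.normalize(z_j, dim=1))

# With this change the two losses should agree even on raw (unnormalized) inputs:
torch.manual_seed(0)
a, b = torch.rand(8, 12), torch.rand(8, 12)
print(NormalizedInstanceLoss(8, 0.5, "cpu")(a, b))  # expected to match the line below
print(NT_Xent(8, 0.5, "cpu")(a, b))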