Hi,
Thanks for the great work! I am trying to reproduce some results and have a question about the batch implementation of the IRM penalty. In Section 3.2 and Appendix D, you suggest using the following for a batch implementation:
from torch.autograd import grad

def compute_penalty(losses, dummy_w):
    # gradients of the two half-batch mean losses w.r.t. the dummy classifier
    g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]
    g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]
    return (g1 * g2).sum()
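For context, here is how I am wiring this up (a minimal sketch with my own toy data and variable names; only the fixed scalar dummy classifier w = 1.0 is from Section 3.2):

import torch
from torch.autograd import grad

torch.manual_seed(0)
x, y = torch.randn(128, 10), torch.randn(128, 1)  # one environment's minibatch
phi = torch.nn.Parameter(torch.randn(10, 1))      # featurizer being learned
dummy_w = torch.nn.Parameter(torch.tensor(1.0))   # fixed dummy classifier w = 1.0

# per-example losses (reduction='none') so compute_penalty can split the batch
losses = torch.nn.functional.mse_loss(x @ phi * dummy_w, y, reduction='none')
penalty = compute_penalty(losses, dummy_w)
penalty.backward()  # create_graph=True above lets gradients reach phi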
I am wondering whether we could instead do the following:
def compute_penalty(losses, dummy_w):
    # squared gradient of the full-batch mean loss
    g = grad(losses.mean(), dummy_w, create_graph=True)[0]
    return (g ** 2).sum()
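To compare the two empirically, I also ran a quick Monte-Carlo check (my own toy setup, not from the paper): average each penalty over many random minibatches drawn from a fixed pool, so the averages approximate each estimator's expectation:

import torch
from torch.autograd import grad

torch.manual_seed(0)
d, n, trials = 10, 64, 2000
phi = torch.randn(d, 1)  # frozen featurizer for the test
pool_x, pool_y = torch.randn(50000, d), torch.randn(50000, 1)

split_avg, sq_avg = 0.0, 0.0
for _ in range(trials):
    idx = torch.randint(0, pool_x.shape[0], (n,))
    dummy_w = torch.nn.Parameter(torch.tensor(1.0))
    losses = torch.nn.functional.mse_loss(
        pool_x[idx] @ phi * dummy_w, pool_y[idx], reduction='none')
    # split-batch estimate: product of two half-batch gradients
    g1 = grad(losses[0::2].mean(), dummy_w, create_graph=True)[0]
    g2 = grad(losses[1::2].mean(), dummy_w, create_graph=True)[0]
    # single-batch estimate: squared full-batch gradient
    g = grad(losses.mean(), dummy_w, create_graph=True)[0]
    split_avg += (g1 * g2).item() / trials
    sq_avg += (g ** 2).item() / trials

print(f"split-batch avg: {split_avg:.4f}   squared-grad avg: {sq_avg:.4f}")

If the bias story below is right, I would expect the second average to sit systematically above the first.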
You mentioned that the former is an "unbiased estimate of the squared gradient norm", but I am not sure why this is the case. If you could provide some explanation, that would be great.
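My tentative reading, which I would be glad to have corrected: write $g_1, g_2$ for the gradients of the two half-batch mean losses and $g$ for the gradient of the full-batch mean loss, all taken w.r.t. the dummy classifier at $w = 1.0$. The two halves use disjoint i.i.d. samples, so $g_1$ and $g_2$ are independent and each is unbiased for the population gradient $\nabla_w R(w)$, giving

$$\mathbb{E}[g_1 \cdot g_2] = \mathbb{E}[g_1] \cdot \mathbb{E}[g_2] = \|\nabla_w R(w)\|^2,$$

whereas the single-batch version also picks up the minibatch gradient noise:

$$\mathbb{E}[\|g\|^2] = \|\mathbb{E}[g]\|^2 + \operatorname{tr} \operatorname{Cov}(g) = \|\nabla_w R(w)\|^2 + \operatorname{tr} \operatorname{Cov}(g).$$

Is that the intended argument?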
Thank you!