Hi Jeff,
I still have trouble training with RLE in my project.
The loss decreases as expected at the beginning, but after some iterations it suddenly increases and finally becomes NaN.
I'm using Adam and a cosine scheduler with a warm-up strategy.

I implemented the regression module as follows:
(This is a hand-pose project; I added two FC heads to predict hand validity and hand type (left/right hand).)

```
class RegressFlow3D(nn.Module):
    """Regression head for 3D joint coordinates with a flow-based training term.

    The module globally average-pools a feature map and feeds the pooled
    vector to four heads: joint coordinates, per-coordinate sigma, hand type
    and hand validity.  When ``labels`` is given it additionally evaluates the
    2D/3D flows on the sigma-normalised residuals and returns the resulting
    ``nf_loss`` term.

    Args:
        cfg: Configuration object; ``cfg.joint_num`` and
            ``cfg.wrist_joint_idx`` are read.
        in_dim: Size of the pooled feature vector fed to every head.
    """

    def __init__(self, cfg, in_dim):
        super().__init__()
        self.num_joints = cfg.joint_num
        self.root_idx = cfg.wrist_joint_idx

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.hand_type_fc = make_linear_layers([in_dim, 128, 1], relu_final=False)
        self.hand_valid_fc = make_linear_layers([in_dim, 64, 1], relu_final=False)
        self.fc_coord = Linear(in_dim, self.num_joints * 3)
        self.fc_sigma = Linear(in_dim, self.num_joints * 3)
        # self.fc_layers = [self.fc_coord, self.fc_sigma]

        # Standard-normal priors and alternating coupling masks for the
        # 2D and 3D flows.
        prior = distributions.MultivariateNormal(torch.zeros(2), torch.eye(2), validate_args=False)
        masks = torch.from_numpy(np.array([[0, 1], [1, 0]] * 3).astype(np.float32))
        prior3d = distributions.MultivariateNormal(torch.zeros(3), torch.eye(3), validate_args=False)
        masks3d = torch.from_numpy(np.array([[0, 0, 1], [1, 1, 0]] * 3).astype(np.float32))
        self.flow2d = RealNVP(nets, nett, masks, prior)
        self.flow3d = RealNVP(nets3d, nett3d, masks3d, prior3d)

    # NOTE(review): the small-gain Xavier initialisation below is disabled, so
    # fc_coord / fc_sigma keep whatever default init `Linear` applies.
    # Confirm this is intended.
    # def _initialize(self):
    #     for m in self.fc_layers:
    #         if isinstance(m, nn.Linear):
    #             nn.init.xavier_uniform_(m.weight, gain=0.01)

    def forward(self, feat, labels=None):
        """Predict joints and, when ``labels`` is given, the flow loss term.

        Args:
            feat: Feature map accepted by ``nn.AdaptiveAvgPool2d``; the first
                dimension is the batch.
            labels: Optional mapping.  ``labels['target_coord']`` must be
                reshapeable to ``(B, num_joints, 3)``.  ``labels['mask3d']``
                must hold one entry per (sample, joint); values ``> 0`` route
                the joint through the 3D flow and values ``< 1`` through the
                2D flow (first two coordinates only).

        Returns:
            ``(pred_jts, hand_type, hand_valid)`` when ``labels`` is ``None``,
            otherwise
            ``(pred_jts, scores, nf_loss, sigma, hand_type, hand_valid)``.
        """
        batch_size = feat.shape[0]
        feat = self.avg_pool(feat).reshape(batch_size, -1)

        hand_type = self.hand_type_fc(feat)
        hand_valid = self.hand_valid_fc(feat)

        # (B, N, 3)
        out_coord = self.fc_coord(feat).reshape(batch_size, self.num_joints, 3)
        # Make the third coordinate relative to the root joint.  Built
        # out-of-place so no view of the linear layer's output is overwritten
        # while autograd is tracking it.
        root_depth = out_coord[:, self.root_idx:self.root_idx + 1, 2:]
        pred_jts = torch.cat((out_coord[:, :, :2], out_coord[:, :, 2:] - root_depth), dim=2)

        if labels is None:
            return pred_jts, hand_type, hand_valid

        gt_uvd = labels['target_coord'].reshape(pred_jts.shape)
        # Flatten so the mask lines up with the (B * N, 3) rows of bar_mu
        # whether it arrives as (B, N) or already flat.
        gt_3d_mask = labels['mask3d'].reshape(-1)

        # NOTE(review): sigma can get as small as 1e-9, which makes bar_mu
        # (and thus the flow inputs) very large; worth checking first if the
        # loss diverges.
        sigma = self.fc_sigma(feat).reshape(batch_size, self.num_joints, -1).sigmoid() + 1e-9
        scores = torch.mean(1 - sigma, dim=2, keepdim=True)

        bar_mu = ((pred_jts - gt_uvd) / sigma).reshape(-1, 3)
        has_3d = gt_3d_mask > 0
        only_2d = gt_3d_mask < 1
        bar_mu_3d = bar_mu[has_3d]
        bar_mu_2d = bar_mu[only_2d][:, :2]

        # One log-density per (sample, joint); rows selected by neither mask
        # stay at zero.  Each flow is skipped when it has no rows to score.
        log_phi = torch.zeros_like(bar_mu[:, 0])
        if bar_mu_3d.shape[0]:
            log_phi[has_3d] = self.flow3d.log_prob(bar_mu_3d)
        if bar_mu_2d.shape[0]:
            log_phi[only_2d] = self.flow2d.log_prob(bar_mu_2d)
        # (B, N, 1)
        log_phi = log_phi.reshape(batch_size, self.num_joints, 1)

        nf_loss = torch.log(sigma) - log_phi
        return pred_jts, scores, nf_loss, sigma, hand_type, hand_valid
```

and the loss as follows:

```
class RLELoss3D(nn.Module):
    """RLE regression loss for 3D joint predictions.

    Adds a residual term ``logQ`` to the flow term ``nf_loss`` supplied by the
    caller and reduces the result to a scalar.

    Args:
        OUTPUT_3D: Accepted for interface compatibility; not read by this
            class.
        size_average: If true, divide the summed loss by the size of the
            first (batch) dimension; otherwise return the plain sum.
    """

    def __init__(self, OUTPUT_3D=False, size_average=True):
        super().__init__()
        self.size_average = size_average
        self.amp = 1 / math.sqrt(2 * math.pi)

    def logQ(self, gt_uv, pred_jts, sigma):
        """Return ``log(sigma / amp) + |gt_uv - pred_jts| / (sqrt(2) * sigma + 1e-9)``.

        ``sigma`` is floored at ``1e-9`` first, so a zero (or negative) sigma
        yields a finite value instead of ``-inf`` / ``nan`` from the
        logarithm.  For ``sigma >= 1e-9`` the result is unchanged.
        """
        sigma = torch.clamp(sigma, min=1e-9)
        return torch.log(sigma / self.amp) + torch.abs(gt_uv - pred_jts) / (math.sqrt(2) * sigma + 1e-9)

    def forward(self, pred_jts, nf_loss, sigma, target):
        """Compute the scalar loss.

        Args:
            pred_jts: Predicted joint coordinates.
            nf_loss: Flow term; must broadcast against ``pred_jts``.
            sigma: Per-coordinate scale; must broadcast against ``pred_jts``.
            target: Ground-truth coordinates, reshapeable to
                ``pred_jts.shape``.

        Returns:
            A scalar tensor: the summed loss, divided by the batch size when
            ``size_average`` is true.
        """
        gt_uv = target.reshape(pred_jts.shape)
        loss = nf_loss + self.logQ(gt_uv, pred_jts, sigma)
        if self.size_average:
            # Divide by at least 1 so an empty batch yields 0 rather than NaN.
            return loss.sum() / max(len(loss), 1)
        return loss.sum()
```

Could you provide any suggestions about debugging?