Cockpit crashes with a non-descriptive error.
Description
I am trying to run Cockpit on a simple network trained with an MSE loss (implemented through custom modules rather than a plain nn.Sequential). As far as I can see, no unsupported operations are involved (Linear, Sequential, Tanh, Identity).
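Cockpit crashes during the backward pass while running the BatchGrad extension: the post-extension hook fails with "'Parameter' object has no attribute 'grad_batch'" (see the stack trace below). Using BackPACK directly to compute the batch gradient works without crashing.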
Steps to Reproduce
- Setup
self.loss_fn = lambda pred, target: ((target - pred) ** 2).mean(dim=-1)
self._cockpit = Cockpit(model.parameters(),
quantities=configuration("economy"))
model = MyCustomModel()
model = extend(model)
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)
- Training Step
pred = model(inp)
loss = self.loss_fn(pred, target)
mean_loss = loss.mean()
optimizer.zero_grad()
info = {
"batch_size": len(loss),
"individual_losses": loss,
"loss": mean_loss,
"optimizer": optimizer,
}
with self._cockpit(step, info=info, debug=True):
create_graph = self._cockpit.create_graph(step)
mean_loss.backward(create_graph=create_graph) # CRASH HERE
optimizer.step()
Source or Possible Fix
Stacktrace
[DEBUG, step 0]
↪Quantities : [<cockpit.quantities.alpha.Alpha object at 0x7fa751af5f10>, <cockpit.quantities.distance.Distance object at 0x7fa74e7c0cd0>, <cockpit.quantities.grad_hist.GradHist1d object at 0x7fa74e7c0d10>, <cockpit.quantities.grad_norm.GradNorm object at 0x7fa74e7c0d90>, <cockpit.quantities.inner_test.InnerTest object at 0x7fa74e7c0dd0>, <cockpit.quantities.loss.Loss object at 0x7fa74e7cd050>, <cockpit.quantities.norm_test.NormTest object at 0x7fa74e7cd090>, <cockpit.quantities.ortho_test.OrthoTest object at 0x7fa74e7cd0d0>, <cockpit.quantities.update_size.UpdateSize object at 0x7fa74e7cd110>]
↪Extensions : [<backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fa74e7cd750>]
↪Hooks : <cockpit.quantities.utils_transforms.BatchGradTransformsHook object at 0x7fa74e7931d0>
↪Create graph: False
↪Save memory : True
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fa74e7cd750> on Linear(in_features=128, out_features=1, bias=True)
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fa74e7cd750> on LinearHead(
(l): Linear(in_features=128, out_features=1, bias=True)
)
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fa74e7cd750> on Identity()
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fa74e7cd750> on Sequential(
(0): Linear(in_features=1, out_features=128, bias=True)
(1): Tanh()
(2): Linear(in_features=128, out_features=128, bias=True)
(3): Tanh()
(4): LinearHead(
(l): Linear(in_features=128, out_features=1, bias=True)
)
(5): Identity()
)
[DEBUG] Running extension <backpack.extensions.firstorder.batch_grad.BatchGrad object at 0x7fa74e7cd750> on MLP(
(layers): Sequential(
(0): Linear(in_features=1, out_features=128, bias=True)
(1): Tanh()
(2): Linear(in_features=128, out_features=128, bias=True)
(3): Tanh()
(4): LinearHead(
(l): Linear(in_features=128, out_features=1, bias=True)
)
(5): Identity()
)
)
Traceback (most recent call last):
File "./envs/pytorch/lib/python3.7/site-packages/backpack/__init__.py", line 169, in run_extension_hook
CTX.get_extension_hook()(module)
File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/hooks/base.py", line 53, in __call__
self.run_hook(param, module)
File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/hooks/base.py", line 80, in run_hook
value = self.module_hook(param, module)
File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/hooks/base.py", line 139, in module_hook
return self.param_hook(param)
File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/utils_transforms.py", line 78, in param_hook
param.grad_batch._param_weakref = weakref.ref(param)
AttributeError: 'Parameter' object has no attribute 'grad_batch'
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "./code/trainers/base_trainer.py", line 112, in backward
mean_loss.backward(create_graph=create_graph)
File "./envs/pytorch/lib/python3.7/site-packages/torch/tensor.py", line 221, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "./envs/pytorch/lib/python3.7/site-packages/torch/autograd/__init__.py", line 132, in backward
allow_unreachable=True) # allow_unreachable flag
File "./envs/pytorch/lib/python3.7/site-packages/backpack/__init__.py", line 151, in hook_run_extensions
run_extension_hook(module)
File "./envs/pytorch/lib/python3.7/site-packages/backpack/__init__.py", line 172, in run_extension_hook
raise RuntimeError(f"Post extensions hook failed: {message}")
RuntimeError: Post extensions hook failed: AttributeError("'Parameter' object has no attribute 'grad_batch'")
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "code/train.py", line 215, in <module>
metrics = main(opts)
File "code/train.py", line 185, in main
training_info = trainer.train(model, dataset, eval_data, logger)
File "./code/trainers/mse_trainer.py", line 128, in train
self.backward(global_step, loss, mean_loss, optimizer)
File "./code/trainers/base_trainer.py", line 112, in backward
mean_loss.backward(create_graph=create_graph)
File "./envs/pytorch/lib/python3.7/site-packages/cockpit/context.py", line 137, in __exit__
self.cp.track(self.global_step, protected_savefields=self.protected_savefields)
File "./envs/pytorch/lib/python3.7/site-packages/cockpit/cockpit.py", line 178, in track
q.track(global_step, self.params, batch_loss)
File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/quantity.py", line 87, in track
iteration, result = self.compute(global_step, params, batch_loss)
File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/quantity.py", line 516, in compute
save_result = self._compute(global_step, params, batch_loss)
File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/quantity.py", line 538, in _compute
self._compute_start(global_step, params, batch_loss)
File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/alpha.py", line 280, in _compute_start
self._save_1st_order_info(global_step, params, batch_loss, point, until)
File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/alpha.py", line 326, in _save_1st_order_info
grad_dict = {id(p): p.grad.data.clone().detach() for p in params}
File "./envs/pytorch/lib/python3.7/site-packages/cockpit/quantities/alpha.py", line 326, in <dictcomp>
grad_dict = {id(p): p.grad.data.clone().detach() for p in params}
AttributeError: 'NoneType' object has no attribute 'data'
🐛 Type: Bug 👷 Status: In Progress