Hello,
Thank you very much for the paper and library.
I'm trying to reproduce the results using the suggested training script in the readme file.
I'm getting the following error:
```
  File "/home/ec2-user/GroupFormer/main.py", line 56, in <module>
    main()
  File "/home/ec2-user/GroupFormer/main.py", line 53, in main
    group_helper.train()
  File "/home/ec2-user/GroupFormer/group/group.py", line 239, in train
    self.train_epoch()
  File "/home/ec2-user/GroupFormer/group/group.py", line 259, in train_epoch
    actions_loss, activities_loss, aux_loss, loss = self.forward(batch)
  File "/home/ec2-user/GroupFormer/group/group.py", line 200, in forward
    actions, activities, aux_loss = self.model(batch[0], batch[1], batch[4], batch[5])
  File "/home/ec2-user/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/GroupFormer/group/utils/distributed_utils.py", line 18, in forward
    return self.module(*inputs, **kwargs)
  File "/home/ec2-user/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/GroupFormer/group/models/__init__.py", line 149, in forward
    actions_scores1, activities_scores1, aux_loss1 = self.head(boxes_features, global_token)
  File "/home/ec2-user/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/GroupFormer/group/models/head/st_plus_tr_cross_cluster.py", line 208, in forward
    group = self.group_tr(group_query,x).reshape(1,B*T,-1)
  File "/home/ec2-user/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/GroupFormer/group/models/transformer.py", line 104, in forward
    output = layer(output, memory, tgt_mask=tgt_mask,
  File "/home/ec2-user/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/GroupFormer/group/models/transformer.py", line 243, in forward
    return self.forward_pre(tgt, memory, memory_mask, memory_key_padding_mask, pos)
  File "/home/ec2-user/GroupFormer/group/models/transformer.py", line 227, in forward_pre
    tgt2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
  File "/home/ec2-user/.local/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ec2-user/.local/lib/python3.9/site-packages/torch/nn/modules/activation.py", line 1153, in forward
    attn_output, attn_output_weights = F.multi_head_attention_forward(
  File "/home/ec2-user/.local/lib/python3.9/site-packages/torch/nn/functional.py", line 5030, in multi_head_attention_forward
    is_batched = _mha_shape_check(query, key, value, key_padding_mask, attn_mask, num_heads)
  File "/home/ec2-user/.local/lib/python3.9/site-packages/torch/nn/functional.py", line 4874, in _mha_shape_check
    assert key.dim() == 3 and value.dim() == 3, \
AssertionError: For batched (3-D) "query", expected "key" and "value" to be 3-D but found 4-D and 4-D tensors respectively
```
Any idea why this happens?
Thanks in advance!