Official PyTorch implementation of "The Center of Attention: Center-Keypoint Grouping via Attention for Multi-Person Pose Estimation" (ICCV 21).

Overview

CenterGroup

This is the official implementation of our ICCV 2021 paper:

The Center of Attention: Center-Keypoint Grouping via Attention for Multi-Person Pose Estimation,
Guillem Brasó, Nikita Kister, Laura Leal-Taixé
We introduce CenterGroup, an attention-based framework to estimate human poses from a set of identity-agnostic keypoints and person center predictions in an image. Our approach uses a transformer to obtain context-aware embeddings for all detected keypoints and centers, and then applies multi-head attention to directly group joints into their corresponding person centers. While most bottom-up methods rely on non-learnable clustering at inference, CenterGroup uses a fully differentiable attention mechanism that we train end-to-end together with our keypoint detector. As a result, our method obtains state-of-the-art performance with up to 2.5x faster inference than competing bottom-up methods.
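
To make the grouping idea concrete, here is a deliberately simplified plain-Python sketch of attention-based grouping for a single joint type. Each person-center embedding acts as a query and the embeddings of the detected keypoints of that type act as keys. Scaled dot-product attention then gives per-center weights over the candidates. This is not the repository's implementation. It assumes the embeddings are already given (in CenterGroup they come from the transformer), uses a single head without learned projections, and the function names (`attend`, `group_joints`) and toy data are hypothetical.

```python
import math
from typing import List, Sequence, Tuple

Vector = Sequence[float]


def softmax(scores: Sequence[float]) -> List[float]:
    """Numerically stable softmax."""
    top = max(scores)
    exps = [math.exp(s - top) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attend(query: Vector, keys: Sequence[Vector]) -> List[float]:
    """Scaled dot-product attention weights of one query over all keys."""
    scale = math.sqrt(len(query))
    scores = [sum(q * k for q, k in zip(query, key)) / scale for key in keys]
    return softmax(scores)


def group_joints(
    center_embs: Sequence[Vector],
    kp_embs: Sequence[Vector],
    kp_xy: Sequence[Tuple[float, float]],
) -> List[Tuple[int, Tuple[float, float]]]:
    """For each person center (query), attend over the candidate keypoints of
    one joint type (keys). Returns, per center, the index of the most-attended
    keypoint (hard assignment) and the attention-weighted location (soft)."""
    results = []
    for center in center_embs:
        weights = attend(center, kp_embs)
        best = max(range(len(weights)), key=weights.__getitem__)
        x = sum(w * px for w, (px, _) in zip(weights, kp_xy))
        y = sum(w * py for w, (_, py) in zip(weights, kp_xy))
        results.append((best, (x, y)))
    return results


if __name__ == "__main__":
    # Two toy centers and two candidate "left wrist" detections (hypothetical data).
    centers = [[4.0, 0.0], [0.0, 4.0]]
    wrists = [[4.0, 0.0], [0.0, 4.0]]
    wrist_xy = [(10.0, 20.0), (50.0, 60.0)]
    for idx, (best, loc) in enumerate(group_joints(centers, wrists, wrist_xy)):
        print(f"center {idx}: keypoint {best}, soft location ({loc[0]:.2f}, {loc[1]:.2f})")
```

The soft location is differentiable with respect to the embeddings. This illustrates why attention-based grouping can be trained end-to-end, whereas a hard, non-learnable clustering step cannot.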

@article{Braso_2021_ICCV,
    author    = {Bras\'o, Guillem and Kister, Nikita and Leal-Taix\'e, Laura},
    title     = {The Center of Attention: Center-Keypoint Grouping via Attention for Multi-Person Pose Estimation},
    journal = {ICCV},
    year      = {2021}
}

Main Results

With the code contained in this repo, you should be able to reproduce the following results.

Results on COCO val2017

Method       Detector         Multi-Scale Test  Input size  AP    AP.5  AP.75  AP (M)  AP (L)
CenterGroup  HigherHRNet-w32  No                512         69.0  87.7  74.4   59.9    75.3
CenterGroup  HigherHRNet-w48  No                640         71.0  88.7  76.5   63.1    75.2
CenterGroup  HigherHRNet-w32  Yes               512         71.9  89.0  78.0   63.7    77.4
CenterGroup  HigherHRNet-w48  Yes               640         73.3  89.7  79.2   66.4    76.7

Results on COCO test2017

Method       Detector         Multi-Scale Test  Input size  AP    AP.5  AP.75  AP (M)  AP (L)
CenterGroup  HigherHRNet-w32  No                512         67.6  88.6  73.6   62.0    75.6
CenterGroup  HigherHRNet-w48  No                640         69.5  89.7  76.0   65.0    76.2
CenterGroup  HigherHRNet-w32  Yes               512         70.3  90.0  76.9   65.4    77.5
CenterGroup  HigherHRNet-w48  Yes               640         71.4  90.5  78.1   67.2    77.5

Results on CrowdPose test

Method       Detector         Multi-Scale Test  Input size  AP    AP.5  AP.75  AP (E)  AP (M)  AP (H)
CenterGroup  HigherHRNet-w48  No                640         67.6  87.6  72.7   74.2    68.1    61.1
CenterGroup  HigherHRNet-w48  Yes               640         70.3  89.1  75.7   77.3    70.8    63.2
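
For reference, AP in these tables is COCO-style keypoint AP. A predicted pose is matched to a ground-truth pose by Object Keypoint Similarity:

OKS = Σ_i exp(-d_i² / (2 s² k_i²)) δ(v_i > 0) / Σ_i δ(v_i > 0)

Here d_i is the distance between predicted and ground-truth keypoint i, s² is the ground-truth object area, k_i is a per-keypoint constant, and v_i is the ground-truth visibility flag. The columns are defined as follows:

- AP.5 and AP.75 use OKS thresholds of 0.5 and 0.75.
- AP averages over the thresholds 0.50:0.05:0.95.
- AP (M) and AP (L) restrict the COCO evaluation to medium and large persons.
- AP (E), AP (M) and AP (H) on CrowdPose restrict it to easy, medium and hard crowd levels.

The plain-Python sketch below computes OKS for one pose pair. It is a simplified illustration, not the official evaluation code, which for instance handles ground truths without labelled keypoints differently. The coordinates and k values in the example are made up.

```python
import math
from typing import Sequence, Tuple

Point = Tuple[float, float]


def oks(
    pred: Sequence[Point],
    gt: Sequence[Point],
    visibility: Sequence[int],
    area: float,
    k: Sequence[float],
) -> float:
    """Object Keypoint Similarity between one predicted and one ground-truth pose.

    Only keypoints labelled in the ground truth (visibility > 0) contribute.
    k[i] is the per-keypoint constant (2 * sigma_i in the COCO definition).
    """
    total, labelled = 0.0, 0
    for (px, py), (gx, gy), v, ki in zip(pred, gt, visibility, k):
        if v > 0:
            d2 = (px - gx) ** 2 + (py - gy) ** 2
            total += math.exp(-d2 / (2.0 * area * ki ** 2))
            labelled += 1
    # Simplification: return 0 when no keypoint is labelled.
    return total / labelled if labelled else 0.0


if __name__ == "__main__":
    gt = [(100.0, 100.0), (120.0, 140.0)]
    pred = [(102.0, 101.0), (300.0, 300.0)]  # second keypoint is unlabelled in gt
    score = oks(pred, gt, visibility=[2, 0], area=1600.0, k=[0.05, 0.05])
    print(f"OKS = {score:.3f}, passes 0.5: {score >= 0.5}, passes 0.75: {score >= 0.75}")
```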

Installation

Please see docs/INSTALL.md

Model Zoo

Please see docs/MODEL_ZOO.md

Evaluation

To evaluate a model, you need to specify its configuration file, its checkpoint, and the number of GPUs to use. All of our configurations and checkpoints are listed in the Model Zoo (docs/MODEL_ZOO.md). For example, to run CenterGroup with a HigherHRNet-w32 detector on a single GPU, run the following:

NUM_GPUS=1
./tools/dist_test.sh configs/centergroup2/coco/higherhrnet_w32_coco_512x512 models/centergroup/centergroup_higherhrnet_w32_coco_512x512.pth $NUM_GPUS 1234

If you want to use multi-scale testing, please add the --multi-scale flag, e.g.:

./tools/dist_test.sh configs/centergroup2/coco/higherhrnet_w32_coco_512x512 models/centergroup/centergroup_higherhrnet_w32_coco_512x512.pth $NUM_GPUS 1234 --multi-scale
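
If you script several evaluation runs, a small helper that assembles the command line keeps the positional arguments in the right order. The sketch below is hypothetical and not part of the repository. It only mirrors the argument order used in the examples above, and it assumes that --multi-scale and --cfg-options can be combined in one call. The meaning of the fourth positional argument (1234 in the examples) is not documented in this README, so the helper passes it through unchanged under the made-up name `extra_arg`.

```python
import shlex
import subprocess
from typing import List, Optional, Sequence


def build_test_cmd(
    config: str,
    checkpoint: str,
    num_gpus: int = 1,
    extra_arg: str = "1234",  # fourth positional argument copied from the README examples
    multi_scale: bool = False,
    cfg_options: Optional[Sequence[str]] = None,
) -> List[str]:
    """Return the argv list for ./tools/dist_test.sh in the order used by the README."""
    cmd = ["./tools/dist_test.sh", config, checkpoint, str(num_gpus), extra_arg]
    if multi_scale:
        cmd.append("--multi-scale")
    if cfg_options:
        # Assumption: each override is a separate key=value token after --cfg-options.
        cmd.extend(["--cfg-options", *cfg_options])
    return cmd


def run_test(cmd: List[str], dry_run: bool = True) -> None:
    """Print the shell-quoted command; only execute it when dry_run is False."""
    print(" ".join(shlex.quote(part) for part in cmd))
    if not dry_run:
        subprocess.run(cmd, check=True)


if __name__ == "__main__":
    run_test(build_test_cmd(
        "configs/centergroup2/coco/higherhrnet_w32_coco_512x512",
        "models/centergroup/centergroup_higherhrnet_w32_coco_512x512.pth",
        multi_scale=True,
    ))
```

Passing `dry_run=False` runs the command from the current directory, so call it from the repository root.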

You can also modify any other config entry with the --cfg-options argument. For example, to disable flip testing, which is enabled by default, you can run:

./tools/dist_test.sh configs/centergroup2/coco/higherhrnet_w32_coco_512x512 models/centergroup/centergroup_higherhrnet_w32_coco_512x512.pth $NUM_GPUS 1234 --cfg-options model.test_cfg.flip_test=False
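
Each --cfg-options entry addresses a nested config field with a dotted key. For example, model.test_cfg.flip_test=False sets the flip_test field inside model.test_cfg. The plain-Python sketch below illustrates that mapping on an ordinary dict. It is an illustration under our own assumptions, not the parser the scripts actually use. Its value parsing (Python literals, otherwise plain strings) may differ from what the real tool accepts.

```python
import ast
from typing import Any, Dict


def parse_value(text: str) -> Any:
    """Read 'False', '1', '0.5' or '[1, 2]' as Python literals; keep anything else as a string."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        return text


def apply_cfg_option(cfg: Dict[str, Any], option: str) -> None:
    """Apply one 'a.b.c=value' override to a nested dict, creating missing levels."""
    key, sep, raw = option.partition("=")
    if not sep:
        raise ValueError(f"expected key=value, got {option!r}")
    *parents, leaf = key.split(".")
    node = cfg
    for name in parents:
        node = node.setdefault(name, {})
    node[leaf] = parse_value(raw)


if __name__ == "__main__":
    cfg = {"model": {"test_cfg": {"flip_test": True}}}
    apply_cfg_option(cfg, "model.test_cfg.flip_test=False")
    print(cfg)  # {'model': {'test_cfg': {'flip_test': False}}}
```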

You may need to modify the checkpoint's path, depending on where you downloaded it, and the entry data_root in the config file, depending on where you stored your data.

Training HigherHRNet with Centers

TODO

Training CenterGroup

TODO

Demo

TODO

Acknowledgements

Our code is based on mmpose, which reimplements HigherHRNet. We thank the authors of these codebases for their great work!

Comments
  • import centergroup error

    Hello, when training CenterGroup with this command:

    python tools/train.py --cfg configs/coco/centergroup/coco/higherhrnet_w32_coco_512x512.py --num_gpus 2 --out output --run_str my_training --cfg-options data.samples_per_gpu=1

    I got the following error:

    File "tools/train.py", line 21, in <module>
        import centergroup
    ModuleNotFoundError: No module named 'centergroup'

    Could you please help me? Thank you.

    opened by FereshtehFeiz
  • How to set the input size?

    Hello, I couldn't find an 'input size' parameter in higherhrnet_w32_coco_512x512.py or higherhrnet_w48_crowdpose_640x640.py. If I want to train CenterGroup from scratch, how should I set hyperparameters such as 'input size', 'epoch', 'optimizer', and 'lr schedule'?

    opened by calmisential
  • Train code

    Hello, sorry to bother you. I'm very interested in your work and would like to learn more about the details. Could you please release the training code at your convenience? Best regards!

    opened by SSSKYue
  • The model and loaded state dict do not match exactly

    I get errors when I use the weights downloaded from the provided links, for example:

    https://vision.in.tum.de/webshare/u/brasoand/centergroup/models/centergroup/centergroup_higherhrnet_w32_coco_512x512.pth

    When I run python tools/test.py configs/centergroup/coco/higherhrnet_w32_coco_512x512.py models/centergroup_higherhrnet_w32_coco_512x512.pth

    I get the following error: OSError: CenterGroup: models/higherhrnet_w_root_w32_coco_512x512.pth is not a checkpoint file

    When I change the config model path to point to the downloaded weights at "models/centergroup_higherhrnet_w32_coco_512x512.pth", I get the following error: RuntimeError: CenterGroup: The model and loaded state dict do not match exactly

    Please let me know how to fix this. Thank you.

    opened by wmcnally
  • Fix validation logging when using multiple GPUs

    Since we only log when self.world_rank == 0, we have to pass rank_zero_only=True to self.log(). Otherwise, we get a deadlock because PyTorch Lightning waits for the other processes to log as well.

    opened by jonathangpk
  • How to download final models and pretrained HigherHRNet?

    Thanks for your excellent work. When I open the links, the provided ckpt files are displayed as text. Could you suggest how to download the pretrained models?

    opened by YHDang
  • Runtime error when starting to train CenterGroup

    TRAINING FROM SCRATCH
    Use load_from_local loader
    NHEADS 4
    NO WEIGHT SHARING!!!
    /anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp.py:20: LightningDeprecationWarning: The `pl.plugins.training_type.ddp.DDPPlugin` is deprecated in v1.6 and will be removed in v1.8. Use `pl.strategies.ddp.DDPStrategy` instead.
      rank_zero_deprecation(
    /anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:287: LightningDeprecationWarning: Passing `Trainer(accelerator='ddp')` has been deprecated in v1.5 and will be removed in v1.7. Use `Trainer(strategy='ddp')` instead.
      rank_zero_deprecation(
    /anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:307: LightningDeprecationWarning: Passing <pytorch_lightning.plugins.training_type.ddp.DDPPlugin object at 0x7fcac5539c40> `strategy` to the `plugins` flag in Trainer has been deprecated in v1.5 and will be removed in v1.7. Use `Trainer(strategy=<pytorch_lightning.plugins.training_type.ddp.DDPPlugin object at 0x7fcac5539c40>)` instead.
      rank_zero_deprecation(
    Using 16bit native Automatic Mixed Precision (AMP)
    /anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/utilities.py:91: PossibleUserWarning: `max_epochs` was not set. Setting it to 1000 epochs. To train without an epoch limit, set `max_epochs=-1`.
      rank_zero_warn(
    /anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:151: LightningDeprecationWarning: Setting `Trainer(checkpoint_callback=<pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint object at 0x7fcac54c0c10>)` is deprecated in v1.5 and will be removed in v1.7. Please consider using `Trainer(enable_checkpointing=<pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint object at 0x7fcac54c0c10>)`.
      rank_zero_deprecation(
    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    HPU available: False, using: 0 HPUs
    Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/2
    TRAINING FROM SCRATCH
    Use load_from_local loader
    NHEADS 4
    NO WEIGHT SHARING!!!
    Initializing distributed: GLOBAL_RANK: 1, MEMBER: 2/2
    ----------------------------------------------------------------------------------------------------
    distributed_backend=nccl
    All distributed processes registered. Starting with 2 processes
    ----------------------------------------------------------------------------------------------------
    
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [4,5]
    LOCAL_RANK: 1 - CUDA_VISIBLE_DEVICES: [4,5]
    
      | Name             | Type                  | Params
    -----------------------------------------------------------
    0 | bu_model         | AssociativeEmbedding_ | 28.6 M
    1 | kp_embed_cnn     | Sequential            | 595 K 
    2 | person_embed_cnn | Sequential            | 595 K 
    3 | pos_embed        | PositionEmbeddingSine | 0     
    4 | group_model      | GroupingModel         | 2.6 M 
    -----------------------------------------------------------
    32.4 M    Trainable params
    54.9 K    Non-trainable params
    32.5 M    Total params
    64.932    Total estimated model params size (MB)
    Sanity Checking: 0it [00:00, ?it/s]loading annotations into memory...
    loading annotations into memory...
    Done (t=0.42s)
    creating index...
    Done (t=0.43s)
    creating index...
    index created!
    index created!
    => num_images: 5000
    => num_images: 5000
    Sanity Checking DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.34s/it]/center-group/mmpose/mmpose/datasets/datasets/bottom_up/bottom_up_coco.py:264: RuntimeWarning: overflow encountered in half_scalars
      area = (np.max(kpt[:, 0]) - np.min(kpt[:, 0])) * (
    /anaconda3/envs/centergroup/lib/python3.8/site-packages/json_tricks/encoders.py:367: UserWarning: json-tricks: numpy scalar serialization is experimental and may work differently in future versions
      warnings.warn('json-tricks: numpy scalar serialization is experimental and may work differently in future versions')
    Loading and preparing results...
    DONE (t=0.00s)
    creating index...
    index created!
    Running per image evaluation...
    Evaluate annotation type *keypoints*
    loading annotations into memory...
    DONE (t=0.66s).
    Accumulating evaluation results...
    DONE (t=0.04s).
     Average Precision  (AP) @[ IoU=0.50:0.95 | area=   all | maxDets= 20 ] =  0.000
     Average Precision  (AP) @[ IoU=0.50      | area=   all | maxDets= 20 ] =  0.000
     Average Precision  (AP) @[ IoU=0.75      | area=   all | maxDets= 20 ] =  0.000
     Average Precision  (AP) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] =  0.000
     Average Precision  (AP) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] =  0.000
     Average Recall     (AR) @[ IoU=0.50:0.95 | area=   all | maxDets= 20 ] =  0.000
     Average Recall     (AR) @[ IoU=0.50      | area=   all | maxDets= 20 ] =  0.000
     Average Recall     (AR) @[ IoU=0.75      | area=   all | maxDets= 20 ] =  0.000
     Average Recall     (AR) @[ IoU=0.50:0.95 | area=medium | maxDets= 20 ] =  0.000
     Average Recall     (AR) @[ IoU=0.50:0.95 | area= large | maxDets= 20 ] =  0.000
    loading annotations into memory...                                                                                                                                
    Done (t=9.15s)
    creating index...
    index created!
    Done (t=9.16s)
    creating index...
    => num_images: 64115
    index created!
    => num_images: 64115
    Epoch 0:   0%|                                                                                                                          | 0/34558 [00:00<?, ?it/s]/anaconda3/envs/centergroup/lib/python3.8/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.
      warnings.warn(warning.format(ret))
    /anaconda3/envs/centergroup/lib/python3.8/site-packages/torch/nn/_reduction.py:42: UserWarning: size_average and reduce args will be deprecated, please use reduction='none' instead.
      warnings.warn(warning.format(ret))
    Epoch 0:   0%|▎                                                                                       | 126/34558 [00:45<3:29:19,  2.74it/s, loss=nan, v_num=0-34]Traceback (most recent call last):
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 722, in _call_and_handle_interrupt
        return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
        return function(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 812, in _fit_impl
        results = self._run(model, ckpt_path=self.ckpt_path)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1237, in _run
        results = self._run_stage()
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1324, in _run_stage
        return self._run_train()
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1354, in _run_train
        self.fit_loop.run()
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 269, in advance
        self._outputs = self.epoch_loop.run(self._data_fetcher)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 208, in advance
        batch_output = self.batch_loop.run(batch, batch_idx)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
        outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 203, in advance
        result = self._run_optimization(
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 256, in _run_optimization
        self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 369, in _optimizer_step
        self.trainer._call_lightning_module_hook(
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1596, in _call_lightning_module_hook
        output = fn(*args, **kwargs)
      File "/center-group/centergroup/models/centergroup.py", line 385, in optimizer_step
        super().optimizer_step(
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1625, in optimizer_step
        optimizer.step(closure=optimizer_closure)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 168, in step
        step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 278, in optimizer_step
        optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 193, in optimizer_step
        return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/native_amp.py", line 85, in optimizer_step
        closure_result = closure()
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 148, in __call__
        self._result = self.closure(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 134, in closure
        step_output = self._step_fn()
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 427, in _training_step
        training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1766, in _call_strategy_hook
        output = fn(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 344, in training_step
        return self.model(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 705, in forward
        output = self.module(*inputs[0], **kwargs[0])
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 82, in forward
        output = self.module.training_step(*inputs, **kwargs)
      File "/center-group/centergroup/models/centergroup.py", line 327, in training_step
        losses = compute_group_loss(preds, person_batch['vis_target'], person_batch['loc_target'],
      File "/center-group/centergroup/core/train_utils.py", line 178, in compute_group_loss
        assert not torch.isnan(person_tp_loss).any() and not torch.isnan(losses['person_loss'])
    AssertionError
    
    During handling of the above exception, another exception occurred:
    
    Traceback (most recent call last):
      File "tools/train.py", line 126, in <module>
        main()
      File "tools/train.py", line 123, in main
        trainer.fit(model, datamodule, ckpt_path=ckpt_path)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 771, in fit
        self._call_and_handle_interrupt(
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 737, in _call_and_handle_interrupt
        self.strategy.reconciliate_processes(traceback.format_exc())
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 446, in reconciliate_processes
        raise DeadlockDetectedException(f"DeadLock detected from rank: {self.global_rank} \n {trace}")
    pytorch_lightning.utilities.exceptions.DeadlockDetectedException: DeadLock detected from rank: 0 
     Traceback (most recent call last):
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 722, in _call_and_handle_interrupt
        return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
        return function(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 812, in _fit_impl
        results = self._run(model, ckpt_path=self.ckpt_path)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1237, in _run
        results = self._run_stage()
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1324, in _run_stage
        return self._run_train()
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1354, in _run_train
        self.fit_loop.run()
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/fit_loop.py", line 269, in advance
        self._outputs = self.epoch_loop.run(self._data_fetcher)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 208, in advance
        batch_output = self.batch_loop.run(batch, batch_idx)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/batch/training_batch_loop.py", line 88, in advance
        outputs = self.optimizer_loop.run(split_batch, optimizers, batch_idx)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 203, in advance
        result = self._run_optimization(
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 256, in _run_optimization
        self._optimizer_step(optimizer, opt_idx, batch_idx, closure)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 369, in _optimizer_step
        self.trainer._call_lightning_module_hook(
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1596, in _call_lightning_module_hook
        output = fn(*args, **kwargs)
      File "/center-group/centergroup/models/centergroup.py", line 385, in optimizer_step
        super().optimizer_step(
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/core/lightning.py", line 1625, in optimizer_step
        optimizer.step(closure=optimizer_closure)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/core/optimizer.py", line 168, in step
        step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 278, in optimizer_step
        optimizer_output = super().optimizer_step(optimizer, opt_idx, closure, model, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 193, in optimizer_step
        return self.precision_plugin.optimizer_step(model, optimizer, opt_idx, closure, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/plugins/precision/native_amp.py", line 85, in optimizer_step
        closure_result = closure()
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 148, in __call__
        self._result = self.closure(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 134, in closure
        step_output = self._step_fn()
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/loops/optimization/optimizer_loop.py", line 427, in _training_step
        training_step_output = self.trainer._call_strategy_hook("training_step", *step_kwargs.values())
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1766, in _call_strategy_hook
        output = fn(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 344, in training_step
        return self.model(*args, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 705, in forward
        output = self.module(*inputs[0], **kwargs[0])
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/torch/nn/modules/module.py", line 889, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/anaconda3/envs/centergroup/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 82, in forward
        output = self.module.training_step(*inputs, **kwargs)
      File "/center-group/centergroup/models/centergroup.py", line 327, in training_step
        losses = compute_group_loss(preds, person_batch['vis_target'], person_batch['loc_target'],
      File "/center-group/centergroup/core/train_utils.py", line 178, in compute_group_loss
        assert not torch.isnan(person_tp_loss).any() and not torch.isnan(losses['person_loss'])
    AssertionError
    
    opened by FereshtehFeiz
  • RuntimeError: The size of tensor a (0) must match the size of tensor b (2) at non-singleton dimension 0

    TRAINING FROM SCRATCH
    NHEADS 4
    NO WEIGHT SHARING!!!
    /environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/plugins/training_type/ddp.py:21: LightningDeprecationWarning: The `pl.plugins.training_type.ddp.DDPPlugin` is deprecated in v1.6 and will be removed in v1.8. Use `pl.strategies.ddp.DDPStrategy` instead.
      "The `pl.plugins.training_type.ddp.DDPPlugin` is deprecated in v1.6 and will be removed in v1.8."
    /environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:293: LightningDeprecationWarning: Passing `Trainer(accelerator='ddp')` has been deprecated in v1.5 and will be removed in v1.7. Use `Trainer(strategy='ddp')` instead.
      f"Passing `Trainer(accelerator={accelerator!r})` has been deprecated"
    /environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/accelerator_connector.py:313: LightningDeprecationWarning: Passing <pytorch_lightning.plugins.training_type.ddp.DDPPlugin object at 0x7f91ff3999d0> `strategy` to the `plugins` flag in Trainer has been deprecated in v1.5 and will be removed in v1.7. Use `Trainer(strategy=<pytorch_lightning.plugins.training_type.ddp.DDPPlugin object at 0x7f91ff3999d0>)` instead.
      f"Passing {plugin} `strategy` to the `plugins` flag in Trainer has been deprecated"
    Using 16bit native Automatic Mixed Precision (AMP)
    /environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/callback_connector.py:152: LightningDeprecationWarning: Setting `Trainer(checkpoint_callback=<pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint object at 0x7f91ff3958d0>)` is deprecated in v1.5 and will be removed in v1.7. Please consider using `Trainer(enable_checkpointing=<pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint object at 0x7f91ff3958d0>)`.
      f"Setting `Trainer(checkpoint_callback={checkpoint_callback})` is deprecated in v1.5 and will "
    GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    IPU available: False, using: 0 IPUs
    HPU available: False, using: 0 HPUs
    Initializing distributed: GLOBAL_RANK: 0, MEMBER: 1/1
    ----------------------------------------------------------------------------------------------------
    distributed_backend=nccl
    All distributed processes registered. Starting with 1 processes
    ----------------------------------------------------------------------------------------------------
    
    LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
    
      | Name             | Type                  | Params
    -----------------------------------------------------------
    0 | bu_model         | AssociativeEmbedding_ | 28.6 M
    1 | kp_embed_cnn     | Sequential            | 595 K 
    2 | person_embed_cnn | Sequential            | 595 K 
    3 | pos_embed        | PositionEmbeddingSine | 0     
    4 | group_model      | GroupingModel         | 2.3 M 
    -----------------------------------------------------------
    32.2 M    Trainable params
    0         Non-trainable params
    32.2 M    Total params
    64.330    Total estimated model params size (MB)
    Sanity Checking: 0it [00:00, ?it/s]loading annotations into memory...
    Done (t=0.47s)
    creating index...
    index created!
    => num_images: 8000
    /environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:245: PossibleUserWarning: The dataloader, val_dataloader 0, does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` (try 6 which is the number of cpus on this machine) in the `DataLoader` init to improve performance.
      category=PossibleUserWarning,
    Sanity Checking DataLoader 0:   0%|                                                                                                            | 0/2 [00:00<?, ?it/s]Traceback (most recent call last):
      File "tools/train.py", line 129, in <module>
        main()
      File "tools/train.py", line 126, in main
        trainer.fit(model, datamodule, ckpt_path=ckpt_path)
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 769, in fit
        self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 719, in _call_and_handle_interrupt
        return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
        return function(*args, **kwargs)
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 809, in _fit_impl
        results = self._run(model, ckpt_path=self.ckpt_path)
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1234, in _run
        results = self._run_stage()
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1321, in _run_stage
        return self._run_train()
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1343, in _run_train
        self._run_sanity_check()
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1411, in _run_sanity_check
        val_loop.run()
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 153, in advance
        dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/loops/base.py", line 204, in run
        self.advance(*args, **kwargs)
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 127, in advance
        output = self._evaluation_step(**kwargs)
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 222, in _evaluation_step
        output = self.trainer._call_strategy_hook("validation_step", *kwargs.values())
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 1763, in _call_strategy_hook
        output = fn(*args, **kwargs)
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/strategies/ddp.py", line 347, in validation_step
        return self.model(*args, **kwargs)
      File "/environment/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/environment/miniconda3/lib/python3.7/site-packages/torch/nn/parallel/distributed.py", line 886, in forward
        output = self.module(*inputs[0], **kwargs[0])
      File "/environment/miniconda3/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/environment/miniconda3/lib/python3.7/site-packages/pytorch_lightning/overrides/base.py", line 93, in forward
        return self.module.validation_step(*inputs, **kwargs)
      File "/home/featurize/center/centergroup/models/centergroup.py", line 151, in validation_step
        return self.forward(**batch, return_loss=False)
      File "/home/featurize/center/centergroup/models/centergroup.py", line 140, in forward
        return self.forward_test(batch, multiscale=multiscale)
      File "/home/featurize/center/centergroup/models/centergroup.py", line 91, in forward_test
        preds, preds_flip = _split_flip_group_preds(preds, batch_idx = person_batch['batch'])        
      File "/home/featurize/center/centergroup/core/flip_utils.py", line 49, in _split_flip_group_preds
        assert (torch.unique(batch_idx) == torch.as_tensor([0, 1], device=batch_idx.device)).all(), "Predictions contain more than 1 image!"
    RuntimeError: The size of tensor a (0) must match the size of tensor b (2) at non-singleton dimension 0
    
    opened by calmisential 4
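Why this appears as a RuntimeError rather than as the assertion message: `torch.unique(batch_idx)` returns a 0-element tensor when `batch_idx` is empty. Comparing a tensor of shape (0,) with the 2-element tensor `[0, 1]` using `==` cannot be broadcast, so PyTorch raises before `assert` evaluates anything. This means `person_batch['batch']` was empty for this sanity-check batch, possibly because no person centers were predicted for the image. If only one of the two images had predictions, shapes (1,) and (2,) would broadcast and the "Predictions contain more than 1 image!" message would appear instead. The sketch below is a hypothetical guard, not part of the CenterGroup code. `split_flip_indices` is an illustrative name, and the sketch assumes the indices are first converted to a plain Python list, e.g. with `batch_idx.tolist()`. It separates the empty case from a genuine image-count mismatch:

```python
from typing import List, Sequence, Tuple


def split_flip_indices(batch_idx: Sequence[int]) -> Tuple[List[int], List[int]]:
    """Hypothetical flip-test guard (not from the CenterGroup repo).

    Flip testing is assumed to batch exactly two images: index 0 is the
    original image and index 1 is its horizontally flipped copy.
    Returns the positions of the predictions that belong to each image.
    """
    unique = sorted(set(batch_idx))
    if not unique:
        # Empty input: in torch, this is the case that fails with a
        # shape-mismatch RuntimeError instead of the assertion message.
        raise ValueError("No predictions to split: no person centers in this batch.")
    if unique != [0, 1]:
        # Covers both "only one image has predictions" and "more than two images".
        raise ValueError(f"Expected predictions from images [0, 1], got {unique}")
    original = [i for i, b in enumerate(batch_idx) if b == 0]
    flipped = [i for i, b in enumerate(batch_idx) if b == 1]
    return original, flipped


# Five predictions: three from the original image, two from the flipped one.
print(split_flip_indices([0, 0, 1, 1, 0]))  # ([0, 1, 4], [2, 3])
```

Raising a descriptive `ValueError` keeps the failure explicit. Whether an empty prediction set should instead be skipped during validation depends on how the rest of `forward_test` consumes the split, which is not shown here.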
  • What is pad_mask?

    Sorry for bothering you. I can't figure out where batch['pad_mask'] comes from. Is it defined by results['pad_mask']?

    Or does it have something to do with the 'mask' in the following image? [image]

    opened by ian860505 5
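The source does not show where CenterGroup creates `batch['pad_mask']`. One hint: the model summary in the issue above lists a `PositionEmbeddingSine` module, the same name as DETR's sine positional encoding, which takes a boolean mask that is True on padded pixels when images of different sizes are batched together. Assuming `pad_mask` follows that convention (an assumption the repo does not confirm), a minimal pure-Python sketch of such a mask looks like this. It assumes each image sits in the top-left corner and is padded at the bottom and right; `make_pad_mask` is a hypothetical helper:

```python
from typing import List, Sequence, Tuple


def make_pad_mask(sizes: Sequence[Tuple[int, int]]) -> List[List[List[bool]]]:
    """Hypothetical DETR-style padding mask (not from the CenterGroup repo).

    sizes: (height, width) of each image before padding.
    Returns mask[b][y][x], which is True where pixel (y, x) of image b is padding.
    """
    # Every image is padded up to the largest height and width in the batch.
    max_h = max(h for h, _ in sizes)
    max_w = max(w for _, w in sizes)
    return [
        [[y >= h or x >= w for x in range(max_w)] for y in range(max_h)]
        for h, w in sizes
    ]


# A 2x3 image and a 1x2 image batched together: the second gets padded.
print(make_pad_mask([(2, 3), (1, 2)])[1])  # [[False, False, True], [True, True, True]]
```

In a PyTorch pipeline the same information would typically be a `torch.bool` tensor of shape (batch, height, width), but that layout is likewise an assumption here.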
Owner
Dynamic Vision and Learning Group
[ICCV 2021] Official Pytorch implementation for Discriminative Region-based Multi-Label Zero-Shot Learning SOTA results on NUS-WIDE and OpenImages

Discriminative Region-based Multi-Label Zero-Shot Learning (ICCV 2021) [arXiv][Project page >> coming soon] Sanath Narayan*, Akshita Gupta*, Salman Kh

Akshita Gupta 54 Nov 21, 2022
(ICCV'21) Official PyTorch implementation of Relational Embedding for Few-Shot Classification

Relational Embedding for Few-Shot Classification (ICCV 2021) Dahyun Kang, Heeseung Kwon, Juhong Min, Minsu Cho [paper], [project homepage] We propose t

Dahyun Kang 82 Dec 24, 2022
[ICCV 2021] Official PyTorch implementation for Deep Relational Metric Learning.

Deep Relational Metric Learning This repository is the official PyTorch implementation of Deep Relational Metric Learning. Framework Datasets CUB-200-

Borui Zhang 39 Dec 10, 2022
official Pytorch implementation of ICCV 2021 paper FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting.

FuseFormer: Fusing Fine-Grained Information in Transformers for Video Inpainting By Rui Liu, Hanming Deng, Yangyi Huang, Xiaoyu Shi, Lewei Lu, Wenxiu

null 77 Dec 27, 2022
Official pytorch implementation of "Scaling-up Disentanglement for Image Translation", ICCV 2021.

Official pytorch implementation of "Scaling-up Disentanglement for Image Translation", ICCV 2021.

Aviv Gabbay 41 Nov 29, 2022
Official PyTorch implementation of "Camera Distance-aware Top-down Approach for 3D Multi-person Pose Estimation from a Single RGB Image", ICCV 2019

PoseNet of "Camera Distance-aware Top-down Approach for 3D Multi-person Pose Estimation from a Single RGB Image" Introduction This repo is official Py

Gyeongsik Moon 677 Dec 25, 2022
Official PyTorch Implementation of paper "Deep 3D Mask Volume for View Synthesis of Dynamic Scenes", ICCV 2021.

Deep 3D Mask Volume for View Synthesis of Dynamic Scenes Official PyTorch Implementation of paper "Deep 3D Mask Volume for View Synthesis of Dynamic S

Ken Lin 17 Oct 12, 2022
This is the official pytorch implementation for our ICCV 2021 paper "TRAR: Routing the Attention Spans in Transformers for Visual Question Answering" on VQA Task

ERASOR (RA-L'21 with ICRA Option) Official page of "ERASOR: Egocentric Ratio of Pseudo Occupancy-based Dynamic Object Removal for Static 3D Point C

Hyungtae Lim 225 Dec 29, 2022
Official PyTorch implementation of N-ImageNet: Towards Robust, Fine-Grained Object Recognition with Event Cameras (ICCV 2021)

N-ImageNet: Towards Robust, Fine-Grained Object Recognition with Event Cameras Official PyTorch implementation of N-ImageNet: Towards Robust, Fine-Gra

null 32 Dec 26, 2022
Official Pytorch implementation of the paper "Action-Conditioned 3D Human Motion Synthesis with Transformer VAE", ICCV 2021

ACTOR Official Pytorch implementation of the paper "Action-Conditioned 3D Human Motion Synthesis with Transformer VAE", ICCV 2021. Please visit our we

Mathis Petrovich 248 Dec 23, 2022
Official PyTorch code of DeepPanoContext: Panoramic 3D Scene Understanding with Holistic Scene Context Graph and Relation-based Optimization (ICCV 2021 Oral).

DeepPanoContext (DPC) [Project Page (with interactive results)][Paper] DeepPanoContext: Panoramic 3D Scene Understanding with Holistic Scene Context G

Cheng Zhang 66 Nov 16, 2022
Official pytorch code for SSC-GAN: Semi-Supervised Single-Stage Controllable GANs for Conditional Fine-Grained Image Generation(ICCV 2021)

SSC-GAN_repo Pytorch implementation for 'Semi-Supervised Single-Stage Controllable GANs for Conditional Fine-Grained Image Generation'.PDF SSC-GAN:Sem

tyty 4 Aug 28, 2022
This Repo is the official CUDA implementation of ICCV 2019 Oral paper for CARAFE: Content-Aware ReAssembly of FEatures

Introduction This Repo is the official CUDA implementation of ICCV 2019 Oral paper for CARAFE: Content-Aware ReAssembly of FEatures. @inproceedings{Wa

Jiaqi Wang 42 Jan 7, 2023
Official implementation of NPMs: Neural Parametric Models for 3D Deformable Shapes - ICCV 2021

NPMs: Neural Parametric Models Project Page | Paper | ArXiv | Video NPMs: Neural Parametric Models for 3D Deformable Shapes Pablo Palafox, Aljaz Bozic

PabloPalafox 109 Nov 22, 2022
Official implementation of the paper Vision Transformer with Progressive Sampling, ICCV 2021.

Vision Transformer with Progressive Sampling This is the official implementation of the paper Vision Transformer with Progressive Sampling, ICCV 2021.

yuexy 123 Jan 1, 2023
Official implementation of the paper "Unifying Nonlocal Blocks for Neural Networks" (ICCV'21)

Spectral Nonlocal Block Overview Official implementation of the paper: Unifying Nonlocal Blocks for Neural Networks (ICCV'21) Spectral View of Nonloca

null 91 Dec 14, 2022
Official implementation of the ICCV 2021 paper "Conditional DETR for Fast Training Convergence".

The DETR approach applies the transformer encoder and decoder architecture to object detection and achieves promising performance. In this paper, we address its critical issue, slow training convergence, and present a conditional cross-attention mechanism for fast DETR training. Our approach is motivated by the observation that the cross-attention in DETR relies heavily on the content embeddings while the spatial embeddings contribute little, which increases the need for high-quality content embeddings and thus the training difficulty.

null 281 Dec 30, 2022
The official implementation of ICCV paper "Box-Aware Feature Enhancement for Single Object Tracking on Point Clouds".

Box-Aware Tracker (BAT) Pytorch-Lightning implementation of the Box-Aware Tracker. Box-Aware Feature Enhancement for Single Object Tracking on Point C

Kangel Zenn 5 Mar 26, 2022