PyTorch Scatter
This package consists of a small extension library of highly optimized sparse update (scatter and segment) operations for use in PyTorch, which are missing in the main package. Scatter and segment operations can be roughly described as reduce operations based on a given "group-index" tensor. Segment operations require the "group-index" tensor to be sorted, whereas scatter operations have no such requirement.
The package consists of the following operations with reduction types "sum"|"mean"|"min"|"max":
- scatter based on arbitrary indices
- segment_coo based on sorted indices
- segment_csr based on compressed indices via pointers
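To make the three index formats concrete, here is a minimal pure-Python sketch of a sum reduction expressed each way. The helper names (scatter_sum, segment_sum_coo, segment_sum_csr) are illustrative assumptions and not the package's API; the real operations work on tensors, broadcast, and run on CPU and GPU.

```python
def scatter_sum(src, index, dim_size):
    """Sum src[i] into out[index[i]]; index may be in any order."""
    out = [0] * dim_size
    for value, group in zip(src, index):
        out[group] += value
    return out


def segment_sum_coo(src, index, dim_size):
    """Same reduction, but assumes index is sorted (non-decreasing)."""
    assert all(a <= b for a, b in zip(index, index[1:])), "index must be sorted"
    return scatter_sum(src, index, dim_size)


def segment_sum_csr(src, indptr):
    """Group g covers the slice src[indptr[g]:indptr[g + 1]]."""
    return [sum(src[start:end]) for start, end in zip(indptr, indptr[1:])]


src = [1, 2, 3, 4, 5, 6]

# Unsorted group indices: only the scatter form applies.
unsorted_result = scatter_sum(src, [2, 0, 1, 0, 2, 1], dim_size=3)

# Sorted group indices (COO-style) and the equivalent pointers (CSR-style).
sorted_index = [0, 0, 1, 2, 2, 2]
indptr = [0, 2, 3, 6]
coo_result = segment_sum_coo(src, sorted_index, dim_size=3)
csr_result = segment_sum_csr(src, indptr)

print(unsorted_result)
print(coo_result)
print(csr_result)
```

The sorted index and the pointer array describe the same grouping, so the last two results agree; the pointer form simply stores where each group starts and ends instead of one group id per element.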
In addition, we provide the following composite functions which make use of scatter_* operations under the hood: scatter_std, scatter_logsumexp, scatter_softmax and scatter_log_softmax.
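As a rough illustration of how such composites can be built from group-wise max and sum reductions, the following pure-Python sketch computes a per-group log-sum-exp and derives a softmax from it. It is a sketch under stated assumptions, not the library's implementation: the functions below only borrow the names for readability, take plain lists, and the handling of empty groups (left at -inf) is an assumption.

```python
import math


def scatter_logsumexp(src, index, dim_size):
    """Group-wise log(sum(exp(x))), stabilised by subtracting each group's max."""
    group_max = [-math.inf] * dim_size
    for value, group in zip(src, index):
        group_max[group] = max(group_max[group], value)

    group_sum = [0.0] * dim_size
    for value, group in zip(src, index):
        group_sum[group] += math.exp(value - group_max[group])

    # Empty groups stay at -inf here; this convention is an assumption.
    return [m + math.log(s) if s > 0.0 else -math.inf
            for m, s in zip(group_max, group_sum)]


def scatter_log_softmax(src, index, dim_size):
    """Subtract each element's group log-sum-exp from the element."""
    lse = scatter_logsumexp(src, index, dim_size)
    return [value - lse[group] for value, group in zip(src, index)]


def scatter_softmax(src, index, dim_size):
    """Exponentiate the group-wise log-softmax."""
    return [math.exp(v) for v in scatter_log_softmax(src, index, dim_size)]


src = [0.5, 1.0, 2.0, -1.0, 0.0]
index = [0, 1, 0, 1, 1]
probs = scatter_softmax(src, index, dim_size=2)
print(probs)
```

Within each group the resulting values sum to one, which is the property that distinguishes a group-wise softmax from a softmax over the whole input.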
All included operations are broadcastable, work on varying data types, are implemented both for CPU and GPU with corresponding backward implementations, and are fully traceable.
Installation
Binaries
We provide pip wheels for all major OS/PyTorch/CUDA combinations, see here.
PyTorch 1.8.0
To install the binaries for PyTorch 1.8.0, simply run
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.8.0+${CUDA}.html
where ${CUDA} should be replaced by either cpu, cu101, cu102, or cu111 depending on your PyTorch installation.
| | cpu | cu101 | cu102 | cu111 |
|---|---|---|---|---|
| Linux | | | | |
| Windows | | | | |
| macOS | | | | |
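If it is unclear which ${CUDA} tag matches a given installation, the index URL can be assembled from the PyTorch and CUDA version strings. The helper below is a hypothetical convenience (wheel_index_url is not part of torch-scatter) and only performs string formatting; it does not check that wheels actually exist for the resulting combination.

```python
def wheel_index_url(torch_version, cuda_version):
    """Build the -f index URL used in the pip command above.

    torch_version: e.g. "1.8.0" or "1.8.0+cu111" (the format of torch.__version__)
    cuda_version:  e.g. "11.1", or None for a CPU-only build
                   (the format of torch.version.cuda)
    """
    base = torch_version.split("+")[0]  # drop any local suffix such as "+cu111"
    tag = "cpu" if cuda_version is None else "cu" + cuda_version.replace(".", "")
    return f"https://pytorch-geometric.com/whl/torch-{base}+{tag}.html"


print(wheel_index_url("1.8.0+cu111", "11.1"))
print(wheel_index_url("1.8.0", None))
```

With PyTorch installed, the two arguments can be read from torch.__version__ and torch.version.cuda. The resulting tag should still be compared against the supported values listed for that PyTorch release.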
PyTorch 1.7.0/1.7.1
To install the binaries for PyTorch 1.7.0 and 1.7.1, simply run
pip install torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+${CUDA}.html
where ${CUDA} should be replaced by either cpu, cu92, cu101, cu102, or cu110 depending on your PyTorch installation.
| | cpu | cu92 | cu101 | cu102 | cu110 |
|---|---|---|---|---|---|
| Linux | | | | | |
| Windows | | | | | |
| macOS | | | | | |
Note: Binaries of older versions are also provided for PyTorch 1.4.0, PyTorch 1.5.0 and PyTorch 1.6.0 (following the same procedure).
From source
Ensure that at least PyTorch 1.5.0 is installed and verify that cuda/bin and cuda/include are in your $PATH and $CPATH respectively, e.g.:
$ python -c "import torch; print(torch.__version__)"
>>> 1.5.0
$ echo $PATH
>>> /usr/local/cuda/bin:...
$ echo $CPATH
>>> /usr/local/cuda/include:...
Then run:
pip install torch-scatter
When running in a Docker container without an NVIDIA driver, PyTorch needs to evaluate the compute capabilities and may fail. In this case, ensure that the compute capabilities are set via TORCH_CUDA_ARCH_LIST, e.g.:
export TORCH_CUDA_ARCH_LIST="6.0 6.1 7.2+PTX 7.5+PTX"
Example
import torch
from torch_scatter import scatter_max
src = torch.tensor([[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]])
index = torch.tensor([[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]])
out, argmax = scatter_max(src, index, dim=-1)
print(out)
tensor([[0, 0, 4, 3, 2, 0],
[2, 4, 3, 0, 0, 0]])
print(argmax)
tensor([[5, 5, 3, 4, 0, 1],
[1, 4, 3, 5, 5, 5]])
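To see where these numbers come from, the following pure-Python sketch recomputes the same result row by row. It is a reference illustration only, not the library's implementation; the helper scatter_max_1d is hypothetical, and the conventions it uses (empty output slots filled with 0, argmax set to the row length for empty slots) are inferred from the printed output above.

```python
def scatter_max_1d(src, index, dim_size):
    """Return (out, argmax) for one row of the example."""
    out = [None] * dim_size
    argmax = [len(src)] * dim_size  # len(src) marks "no element mapped here"
    for position, (value, group) in enumerate(zip(src, index)):
        if out[group] is None or value > out[group]:
            out[group] = value
            argmax[group] = position
    # Empty output slots are filled with 0, matching the printed result above.
    return [0 if v is None else v for v in out], argmax


src = [[2, 0, 1, 4, 3], [0, 2, 1, 3, 4]]
index = [[4, 5, 4, 2, 3], [0, 0, 2, 2, 1]]
dim_size = max(max(row) for row in index) + 1  # largest index is 5, so 6 slots

results = [scatter_max_1d(s, i, dim_size) for s, i in zip(src, index)]
out = [r[0] for r in results]
argmax = [r[1] for r in results]
print(out)
print(argmax)
```

In the first row, for instance, positions 0 and 2 both map to slot 4 with values 2 and 1, so slot 4 holds 2 with argmax 0, while slots 0 and 1 receive no elements and therefore show 0 with argmax 5.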
Running tests
python setup.py test
C++ API
torch-scatter also offers a C++ API that contains C++ equivalents of the Python operations.
mkdir build
cd build
# Add -DWITH_CUDA=on for CUDA support if needed
cmake ..
make
make install