Is there an existing issue for this?
- [X] I have searched the existing issues
Bug summary
`tio.inference.GridSampler` and `GridAggregator` do not allow the model output to be smaller than the input.
I was going to submit this as a feature request before making a PR; however, I realised that tio actually supports this depending on `patch_overlap` and `overlap_mode`, so I believe this should be treated as a bug.
Code for reproduction
# This is not an MWE but a test named `test_inference_smaller.py`
from torch.utils.data import DataLoader
from torchio import DATA
from torchio import LOCATION
from torchio.data.inference import GridAggregator
from torchio.data.inference import GridSampler
from ...utils import TorchioTestCase


class TestInference(TorchioTestCase):
    """Tests for `inference` module."""

    def test_inference_no_padding(self):
        self.try_inference(None)

    def test_inference_padding(self):
        self.try_inference(3)

    def try_inference(self, padding_mode):
        for mode in ["crop", "average", "hann"]:
            for n in 17, 27:
                patch_size = 10, 15, n
                patch_overlap = 0, 0, 0  # <-- this is important and differs from the usual test
                batch_size = 6
                grid_sampler = GridSampler(
                    self.sample_subject,
                    patch_size,
                    patch_overlap,
                    padding_mode=padding_mode,
                )
                aggregator = GridAggregator(grid_sampler, overlap_mode=mode)
                patch_loader = DataLoader(grid_sampler, batch_size=batch_size)
                for patches_batch in patch_loader:
                    input_tensor = patches_batch['t1'][DATA]
                    locations = patches_batch[LOCATION]
                    logits = model(input_tensor)  # some model
                    outputs = logits
                    # Simulate a model whose output is smaller than its input
                    # by cropping one voxel from each side of every spatial axis
                    i_ini, j_ini, k_ini = 1, 1, 1
                    i_fin = patch_size[0] - 1
                    j_fin = patch_size[1] - 1
                    k_fin = patch_size[2] - 1
                    outputs = outputs[
                        :,
                        :,
                        i_ini:i_fin,
                        j_ini:j_fin,
                        k_ini:k_fin,
                    ]
                    aggregator.add_batch(outputs, locations)
                output = aggregator.get_output_tensor()
                assert (output == -5).all()
                assert output.shape == self.sample_subject.t1.shape


def model(tensor):
    # Dummy "model": overwrite every voxel with a constant
    tensor[:] = -5
    return tensor
Actual outcome
This raises a `RuntimeError` if `patch_overlap` is smaller than the difference between the input and output sizes and the overlap mode is anything but `crop`.
Below is the output of running `pytest tests/data/inference/test_inference_smaller.py`:
Error messages
==================================================================================================== FAILURES =====================================================================================================
_____________________________________________________________________________________ TestInference.test_inference_no_padding _____________________________________________________________________________________
self = <tests.data.inference.test_inference_smaller.TestInference testMethod=test_inference_no_padding>
def test_inference_no_padding(self):
> self.try_inference(None)
test_inference_smaller.py:13:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
test_inference_smaller.py:47: in try_inference
aggregator.add_batch(outputs, locations)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <torchio.data.inference.aggregator.GridAggregator object at 0x7f8353643bb0>
batch_tensor = tensor([[[[[-5, -5, -5, ..., -5, -5, -5],
[-5, -5, -5, ..., -5, -5, -5],
[-5, -5, -5, ..., -5..., -5, -5],
[-5, -5, -5, ..., -5, -5, -5],
[-5, -5, -5, ..., -5, -5, -5]]]]], dtype=torch.int32)
locations = array([[ 0, 0, 0, 10, 15, 17],
[ 0, 0, 13, 10, 15, 30],
[ 0, 5, 0, 10, 20, 17],
[ 0, 5, 13, 10, 20, 30]])
def add_batch(
self,
batch_tensor: torch.Tensor,
locations: torch.Tensor,
) -> None:
"""Add batch processed by a CNN to the output prediction volume.
Args:
batch_tensor: 5D tensor, typically the output of a convolutional
neural network, e.g. ``batch['image'][torchio.DATA]``.
locations: 2D tensor with shape :math:`(B, 6)` representing the
patch indices in the original image. They are typically
extracted using ``batch[torchio.LOCATION]``.
"""
batch = batch_tensor.cpu()
locations = locations.cpu().numpy()
patch_sizes = locations[:, 3:] - locations[:, :3]
# There should be only one patch size
assert len(np.unique(patch_sizes, axis=0)) == 1
input_spatial_shape = tuple(batch.shape[-3:])
target_spatial_shape = tuple(patch_sizes[0])
if input_spatial_shape != target_spatial_shape:
message = (
f'The shape of the input batch, {input_spatial_shape},'
' does not match the shape of the target location,'
f' which is {target_spatial_shape}'
)
> raise RuntimeError(message)
E RuntimeError: The shape of the input batch, (8, 13, 15), does not match the shape of the target location, which is (10, 15, 17)
../../../src/torchio/data/inference/aggregator.py:153: RuntimeError
______________________________________________________________________________________ TestInference.test_inference_padding _______________________________________________________________________________________
self = <tests.data.inference.test_inference_smaller.TestInference testMethod=test_inference_padding>
def test_inference_padding(self):
> self.try_inference(3)
test_inference_smaller.py:16:
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
test_inference_smaller.py:47: in try_inference
aggregator.add_batch(outputs, locations)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
self = <torchio.data.inference.aggregator.GridAggregator object at 0x7f835149ca90>
batch_tensor = tensor([[[[[-5, -5, -5, ..., -5, -5, -5],
[-5, -5, -5, ..., -5, -5, -5],
[-5, -5, -5, ..., -5..., -5, -5],
[-5, -5, -5, ..., -5, -5, -5],
[-5, -5, -5, ..., -5, -5, -5]]]]], dtype=torch.int32)
locations = array([[ 0, 0, 0, 10, 15, 17],
[ 0, 0, 13, 10, 15, 30],
[ 0, 5, 0, 10, 20, 17],
[ 0, 5, 13, 10, 20, 30]])
def add_batch(
self,
batch_tensor: torch.Tensor,
locations: torch.Tensor,
) -> None:
"""Add batch processed by a CNN to the output prediction volume.
Args:
batch_tensor: 5D tensor, typically the output of a convolutional
neural network, e.g. ``batch['image'][torchio.DATA]``.
locations: 2D tensor with shape :math:`(B, 6)` representing the
patch indices in the original image. They are typically
extracted using ``batch[torchio.LOCATION]``.
"""
batch = batch_tensor.cpu()
locations = locations.cpu().numpy()
patch_sizes = locations[:, 3:] - locations[:, :3]
# There should be only one patch size
assert len(np.unique(patch_sizes, axis=0)) == 1
input_spatial_shape = tuple(batch.shape[-3:])
target_spatial_shape = tuple(patch_sizes[0])
if input_spatial_shape != target_spatial_shape:
message = (
f'The shape of the input batch, {input_spatial_shape},'
' does not match the shape of the target location,'
f' which is {target_spatial_shape}'
)
> raise RuntimeError(message)
E RuntimeError: The shape of the input batch, (8, 13, 15), does not match the shape of the target location, which is (10, 15, 17)
../../../src/torchio/data/inference/aggregator.py:153: RuntimeError
================================================================================================ warnings summary =================================================================================================
test_inference_smaller.py: 16 warnings
/home/wahab/miniconda3/envs/torchioenv/lib/python3.10/site-packages/SimpleITK/extra.py:183: DeprecationWarning: Converting `np.character` to a dtype is deprecated. The current result is `np.dtype(np.str_)` which is not strictly correct. Note that `np.character` is generally deprecated and 'S1' should be used.
_np_sitk = {np.dtype(np.character): sitkUInt8,
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
============================================================================================= short test summary info =============================================================================================
FAILED test_inference_smaller.py::TestInference::test_inference_no_padding - RuntimeError: The shape of the input batch, (8, 13, 15), does not match the shape of the target location, which is (10, 15, 17)
FAILED test_inference_smaller.py::TestInference::test_inference_padding - RuntimeError: The shape of the input batch, (8, 13, 15), does not match the shape of the target location, which is (10, 15, 17)
========================================================================================= 2 failed, 16 warnings in 0.97s =========================================================================================
Expected outcome
I believe tio should be able to handle model outputs that are smaller than the inputs. My model predictions are terrible even with averaging or Hann windowing. Unfortunately, most popular model libraries (such as the great MONAI) only provide models whose output size matches the input size. In my application, however, it is crucial to let the model see a larger input ROI than the semantic label output it produces, by using unpadded ("valid") convolutions, as this gives context for the prediction. The original U-Net paper uses unpadded convolutions, which is why its outputs are smaller than its inputs.
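To make the size relationship concrete, here is a small sketch of the standard output-size arithmetic for a stack of stride-1, dilation-1, unpadded 3D convolutions (in PyTorch, `torch.nn.Conv3d` with its default `padding=0`). The helper name `valid_conv_output_shape` is hypothetical and not part of TorchIO or MONAI; it only shows that each layer with kernel size k removes k - 1 voxels along every axis.

```python
def valid_conv_output_shape(input_shape, kernel_sizes):
    """Spatial shape after stride-1, dilation-1, unpadded convolutions.

    Each layer with kernel size k removes k - 1 voxels along every axis,
    i.e. (k - 1) // 2 voxels on each side when k is odd.
    """
    shape = list(input_shape)
    for k in kernel_sizes:
        shape = [size - (k - 1) for size in shape]
        if min(shape) < 1:
            raise ValueError(
                f"Input {tuple(input_shape)} is too small for kernels {kernel_sizes}"
            )
    return tuple(shape)


# A single 3x3x3 unpadded convolution maps the sampled patch size to the
# batch shape reported in the RuntimeError above
print(valid_conv_output_shape((10, 15, 17), [3]))  # (8, 13, 15)
```

The margin lost by the convolutions is exactly the extra context each output voxel sees, which is why the output cannot match the input patch size.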
I am going to make a PR tomorrow with a fix for this. My planned changes touch only the aggregator: this can be fixed with changes to `GridAggregator` alone, and the sampler can be left as it is:
- [x] Check whether the aggregator input is smaller than the sampler output in `GridAggregator.add_batch()` before comparing it to the location patch size
- [x] Create a variable in the aggregator called `patch_diffs`, which is the difference between `input_spatial_shape` and `target_spatial_shape`
- [x] Change each dimension of `self.patch_overlap` to `patch_diffs` if it is smaller
- [ ] ~Edit each location before cropping by adding half the diffs to `i_ini` etc. and subtracting half the diffs from `i_fin`~
- [x] Write a new unit test (let me know if this can be improved)
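The checked items above can be sketched outside TorchIO with plain NumPy. This is a hypothetical illustration, not the code in the PR: `compute_patch_diffs`, `effective_overlap` and `centred_locations` are made-up names, and the last one assumes the smaller output is centred in the input patch (the idea behind the struck-out item).

```python
import numpy as np


def compute_patch_diffs(locations, output_spatial_shape):
    """Per-axis difference between the sampled patch size and the model output size."""
    patch_sizes = locations[:, 3:] - locations[:, :3]
    # Patches from one GridSampler share a single size, as add_batch() asserts
    assert len(np.unique(patch_sizes, axis=0)) == 1
    patch_diffs = patch_sizes[0] - np.asarray(output_spatial_shape)
    if (patch_diffs < 0).any():
        raise RuntimeError("The model output is larger than the input patch")
    return patch_diffs


def effective_overlap(patch_overlap, patch_diffs):
    # Raise each overlap dimension to patch_diffs wherever it is smaller
    return np.maximum(np.asarray(patch_overlap), patch_diffs)


def centred_locations(locations, patch_diffs):
    # Assumes the output is centred in the input patch: drop diff // 2 voxels
    # at the start of each axis and the remainder at the end
    before = patch_diffs // 2
    after = patch_diffs - before
    shifted = locations.copy()
    shifted[:, :3] += before
    shifted[:, 3:] -= after
    return shifted


# Locations from the traceback and the (8, 13, 15) output shape it reports
locations = np.array([
    [0, 0, 0, 10, 15, 17],
    [0, 0, 13, 10, 15, 30],
    [0, 5, 0, 10, 20, 17],
    [0, 5, 13, 10, 20, 30],
])
diffs = compute_patch_diffs(locations, (8, 13, 15))
print(diffs.tolist())                                   # [2, 2, 2]
print(effective_overlap((0, 0, 0), diffs).tolist())     # [2, 2, 2]
print(centred_locations(locations, diffs)[0].tolist())  # [1, 1, 1, 9, 14, 16]
```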
If you see an issue with this happening behind the scenes, should a `model_output_size` argument be added to `GridAggregator` or `GridSampler`? Or should the aggregator raise a warning if it detects a smaller output behind the scenes?
This is a bit confusing even in the code, as the model's output is the aggregator's input. I've tried to be clear here; let me know if I haven't.
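For the warning option, a check that warns instead of failing could look roughly like the sketch below. `check_output_shape` is a hypothetical helper, not part of TorchIO; it uses only the standard-library `warnings` module, warns when the model output is smaller than the sampled patch, and still raises when the output is larger.

```python
import warnings


def check_output_shape(output_spatial_shape, patch_spatial_shape):
    """Return the per-axis size difference, warning if the output is smaller."""
    diffs = tuple(p - o for p, o in zip(patch_spatial_shape, output_spatial_shape))
    if any(d < 0 for d in diffs):
        raise RuntimeError(
            f"The model output {tuple(output_spatial_shape)} is larger than"
            f" the patch {tuple(patch_spatial_shape)}"
        )
    if any(d > 0 for d in diffs):
        # Hypothetical policy: accept the smaller output but tell the user
        warnings.warn(
            f"The model output {tuple(output_spatial_shape)} is smaller than the"
            f" patch {tuple(patch_spatial_shape)}; assuming it is centred in the patch",
            stacklevel=2,
        )
    return diffs


# Emits a UserWarning, then prints (2, 2, 2)
print(check_output_shape((8, 13, 15), (10, 15, 17)))
```

An explicit `model_output_size` argument would make the behaviour visible up front instead of inferring it per batch, at the cost of one more parameter to keep in sync with the model.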
System info
Platform: Linux-5.4.0-131-generic-x86_64-with-glibc2.27
TorchIO: 0.18.86
PyTorch: 1.13.0+cu117
SimpleITK: 2.2.0 (ITK 5.3)
NumPy: 1.23.4
Python: 3.10.8 (main, Nov 4 2022, 13:48:29) [GCC 11.2.0]