# Torch-gather
A mini library that implements several useful functions for PyTorch, with bindings written in C++.
## What does gather do? Why do we need it?
When dealing with batches of variable-length sequences, a common approach is to pad every sequence to the maximum length. As lengths vary, this introduces considerable redundancy and wastes both computation and memory. `gather` simply removes the padding, so subsequent computation runs only on the real data.
## Install

```bash
python setup.py install
```
## Docs

Note that all input tensors must be on a CUDA device.
- `gather.gathercat(x_padded: torch.FloatTensor, lx: torch.IntTensor)`

  Returns a concatenation of the given padded tensor `x_padded` according to its lengths `lx`.

  Input:

  - `x_padded` (`torch.float`): padded tensor of size `(N, L, V)`, where `L = max(lx)`.
  - `lx` (`torch.int`): lengths of size `(N, )`.

  Return:

  - `x_gather` (`torch.float`): the gathered tensor without padding, of size `(lx[0]+lx[1]+...+lx[N-1], V)`.

  Example:
  ```python
  >>> import torch
  >>> from gather import gathercat
  >>> lx = torch.randint(3, 20, (5, ), dtype=torch.int32, device='cuda')
  >>> x_padded = torch.randn((5, lx.max(), 64), device='cuda')
  >>> x_padded.size(), lx.size()
  (torch.Size([5, 19, 64]), torch.Size([5]))
  >>> x_gather = gathercat(x_padded, lx)
  >>> x_gather.size()
  torch.Size([81, 64])
  # another example, with V=1
  >>> x_padded = torch.tensor([[1., 2., 3.], [1., 2., 0.]], device='cuda').unsqueeze(2)
  >>> lx = torch.tensor([3, 2], dtype=torch.int32, device='cuda')
  >>> x_padded
  tensor([[[1.],
           [2.],
           [3.]],

          [[1.],
           [2.],
           [0.]]], device='cuda:0')
  >>> lx
  tensor([3, 2], device='cuda:0', dtype=torch.int32)
  >>> gathercat(x_padded, lx)
  tensor([[1.],
          [2.],
          [3.],
          [1.],
          [2.]], device='cuda:0')
  ```
  This function is easy to implement with built-in PyTorch ops like `torch.cat()`; `gathercat()`, however, is customized for this task and is more efficient. A pure-PyTorch equivalent is sketched below.
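  For reference, a plain-PyTorch equivalent might look like the following. This is a minimal sketch for checking shapes and values; the `gathercat_reference` name is ours and not part of the library.

  ```python
  import torch

  def gathercat_reference(x_padded: torch.Tensor, lx: torch.Tensor) -> torch.Tensor:
      """Sketch of gathercat(): drop each sequence's padding, then concatenate.

      x_padded: (N, L, V) padded batch; lx: (N, ) int lengths.
      Returns a tensor of size (sum(lx), V).
      """
      return torch.cat([x_padded[i, :lx[i]] for i in range(lx.size(0))], dim=0)
  ```

  The Python loop slices one sequence at a time, which is exactly the per-sequence overhead the fused CUDA kernel avoids.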
- `gather.gathersum(xs: torch.FloatTensor, ys: torch.FloatTensor, lx: torch.IntTensor, ly: torch.IntTensor)`

  Returns a sequence-matched broadcast sum of the paired gathered tensors `xs` and `ys`. For a pair of sequences `xs_i` and `ys_i` from `xs` and `ys`, `gathersum()` broadcasts them so that they can be added up. In plain PyTorch, the broadcast step can be understood as `(xs_i.unsqueeze(1) + ys_i.unsqueeze(0)).reshape(-1, V)` (see the sketch after the example below).

  Input:
  - `xs` (`torch.float`): gathered tensor of size `(ST, V)`, where `ST = sum(lx)`.
  - `ys` (`torch.float`): gathered tensor of size `(SU, V)`, where `SU = sum(ly)`.
  - `lx` (`torch.int`): lengths of size `(N, )`; `lx[i]` denotes the length of the $i$-th sequence in `xs`.
  - `ly` (`torch.int`): lengths of size `(N, )`; `ly[i]` denotes the length of the $i$-th sequence in `ys`.

  Return:

  - `gathered_sum` (`torch.float`): the gathered sequence-matched sum, of size `(lx[0]*ly[0]+lx[1]*ly[1]+...+lx[N-1]*ly[N-1], V)`.
  Example:

  ```python
  >>> import torch
  >>> from gather import gathersum
  >>> N, T, U, V = 5, 4, 4, 3
  >>> lx = torch.randint(1, T, (N, ), dtype=torch.int32, device='cuda')
  >>> ly = torch.randint(1, U, (N, ), dtype=torch.int32, device='cuda')
  >>> xs = torch.randn((lx.sum(), V), device='cuda')
  >>> ys = torch.randn((ly.sum(), V), device='cuda')
  >>> xs.size(), ys.size(), lx.size(), ly.size()
  (torch.Size([11, 3]), torch.Size([10, 3]), torch.Size([5]), torch.Size([5]))
  >>> gathered_sum = gathersum(xs, ys, lx, ly)
  >>> gathered_sum.size()
  torch.Size([20, 3])
  # let's see how the size 20 comes out
  >>> lx.tolist(), ly.tolist()
  ([2, 2, 1, 3, 3], [3, 1, 3, 1, 2])
  # still unclear? how about this?
  >>> (lx * ly).sum().item()
  20
  ```
  This function may look like it does something odd at first; please refer to the discussion page for a concrete usage example.
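  For reference, a plain-PyTorch equivalent might look like the following. This is a minimal sketch; the `gathersum_reference` name is ours, and the row ordering within each pair follows the broadcast formula above, which is our assumption rather than the library's documented layout.

  ```python
  import torch

  def gathersum_reference(xs: torch.Tensor, ys: torch.Tensor,
                          lx: torch.Tensor, ly: torch.Tensor) -> torch.Tensor:
      """Sketch of gathersum(): per-pair broadcast sum over gathered tensors.

      xs: (sum(lx), V); ys: (sum(ly), V); lx, ly: (N, ) int lengths.
      Returns a tensor of size (sum(lx * ly), V).
      """
      out, px, py = [], 0, 0
      for t, u in zip(lx.tolist(), ly.tolist()):
          xs_i = xs[px:px + t]   # (T_i, V)
          ys_i = ys[py:py + u]   # (U_i, V)
          # (T_i, 1, V) + (1, U_i, V) -> (T_i, U_i, V), then flatten to (T_i*U_i, V)
          out.append((xs_i.unsqueeze(1) + ys_i.unsqueeze(0)).reshape(-1, xs.size(1)))
          px, py = px + t, py + u
      return torch.cat(out, dim=0)
  ```

  This mirrors the kind of joint-network sum used in RNN-T models (cf. the warp-rnnt reference below), where padding-free gathering avoids materializing the full `(N, T, U, V)` tensor.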
## Reference
- The PyTorch binding code refers to [1ytic/warp-rnnt](https://github.com/1ytic/warp-rnnt).
- For the specific usage of these functions, please refer to this discussion.