Torch-gather

A mini library that implements several useful functions binding to PyTorch in C++.

What does gather do? Why do we need it?

When dealing with variable-length sequences, a common approach is to pad them to the maximum length, which introduces considerable redundancy and wastes computation and memory when the lengths vary. gather removes the padding so that computation is carried out only on the valid elements.

Install

python setup.py install

Docs

Note that all input tensors must be on a CUDA device.

  • gather.gathercat(x_padded:torch.FloatTensor, lx:torch.IntTensor)

    Return a concatenation of the given padded tensor x_padded according to its lengths lx.

    Input:

    x_padded (torch.float): padded tensor of size (N, L, V), where L=max(lx).

    lx (torch.int): lengths of size (N, ).

    Return:

    x_gather (torch.float): the gathered tensor without padding, of size (lx[0]+lx[1]+...+lx[N-1], V).

    Example:

    >>> import torch
    >>> from gather import gathercat
    >>> lx = torch.randint(3, 20, (5, ), dtype=torch.int32, device='cuda')
    >>> x_padded = torch.randn((5, lx.max(), 64), device='cuda')
    >>> x_padded.size(), lx.size()
    (torch.Size([5, 19, 64]), torch.Size([5]))
    >>> x_gather = gathercat(x_padded, lx)
    >>> x_gather.size()
    torch.Size([81, 64])
    # another example, with V=1
    >>> x_padded = torch.tensor([[1., 2., 3.],[1.,2.,0.]], device='cuda').unsqueeze(2)
    >>> lx = torch.tensor([3,2], dtype=torch.int32, device='cuda')
    >>> x_padded
    tensor([[[1.],
            [2.],
            [3.]],
    
            [[1.],
            [2.],
            [0.]]], device='cuda:0')
    >>> lx
    tensor([3, 2], device='cuda:0', dtype=torch.int32)
    >>> gathercat(x_padded, lx)
    tensor([[1.],
            [2.],
            [3.],
            [1.],
            [2.]], device='cuda:0')

    This function is easy to implement with native PyTorch functions such as torch.cat(); however, gathercat() is customized for this task and is more efficient. A pure-PyTorch reference is sketched below.
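
    For reference, an equivalent pure-PyTorch sketch (an illustration only, not part of this library) simply slices off the padding of each sequence and concatenates the results:

    >>> import torch
    >>> def gathercat_ref(x_padded, lx):
    ...     # slice each sequence to its true length, then concatenate along dim 0
    ...     return torch.cat([x_padded[i, :int(lx[i])] for i in range(lx.size(0))], dim=0)
    ...
    >>> x_padded = torch.tensor([[1., 2., 3.], [1., 2., 0.]], device='cuda').unsqueeze(2)
    >>> lx = torch.tensor([3, 2], dtype=torch.int32, device='cuda')
    >>> gathercat_ref(x_padded, lx)
    tensor([[1.],
            [2.],
            [3.],
            [1.],
            [2.]], device='cuda:0')

    The Python-level loop above is exactly the overhead that the C++ extension avoids.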

  • gather.gathersum(xs:torch.FloatTensor, ys:torch.FloatTensor, lx:torch.IntTensor, ly:torch.IntTensor)

    Return a sequence-matched broadcast sum of the paired gathered tensors xs and ys. For a pair of sequences xs_i and ys_i taken from xs and ys, gathersum() broadcasts them so that they can be added up. In plain Python/PyTorch, the broadcast step can be understood as (xs_i.unsqueeze(1) + ys_i.unsqueeze(0)).reshape(-1, V).

    Input:

    xs (torch.float): gathered tensor of size (ST, V), where ST=sum(lx).

    ys (torch.float): gathered tensor of size (SU, V), where SU=sum(ly).

    lx (torch.int): lengths of size (N, ). lx[i] denotes the length of the i-th sequence in xs.

    ly (torch.int): lengths of size (N, ). ly[i] denotes the length of the i-th sequence in ys.

    Return:

    gathered_sum (torch.float): the gathered sequence-matched sum of size (lx[0]*ly[0]+lx[1]*ly[1]+...+lx[N-1]*ly[N-1], V).

    Example:

    >>> import torch
    >>> from gather import gathersum
    >>> N, T, U, V = 5, 4, 4, 3
    >>> lx = torch.randint(1, T, (N, ), dtype=torch.int32, device='cuda')
    >>> ly = torch.randint(1, U, (N, ), dtype=torch.int32, device='cuda')
    >>> xs = torch.randn((lx.sum(), V), device='cuda')
    >>> ys = torch.randn((ly.sum(), V), device='cuda')
    >>> xs.size(), ys.size(), lx.size(), ly.size()
    (torch.Size([11, 3]), torch.Size([10, 3]), torch.Size([5]), torch.Size([5]))
    >>> gathered_sum = gathersum(xs, ys, lx, ly)
    >>> gathered_sum.size()
    torch.Size([20, 3])
    # let's see how the size 20 comes out
    >>> lx.tolist(), ly.tolist()
    ([2, 2, 1, 3, 3], [3, 1, 3, 1, 2])
    # still unclear? Uh, how about this?
    >>> (lx * ly).sum().item()
    20

    This function may look like it is doing something odd; please refer to the discussion page for a concrete usage example. An equivalent pure-PyTorch sketch is given below.
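
    Continuing the session above, the following sketch (an illustration only, not part of this library) splits xs and ys back into per-sequence chunks, broadcast-adds each pair, and concatenates the flattened results:

    >>> def gathersum_ref(xs, ys, lx, ly):
    ...     xs_chunks = torch.split(xs, lx.tolist(), dim=0)   # per-sequence (lx[i], V) chunks
    ...     ys_chunks = torch.split(ys, ly.tolist(), dim=0)   # per-sequence (ly[i], V) chunks
    ...     V = xs.size(1)
    ...     # (lx[i], 1, V) + (1, ly[i], V) -> (lx[i], ly[i], V), then flatten to (lx[i]*ly[i], V)
    ...     return torch.cat([(x_i.unsqueeze(1) + y_i.unsqueeze(0)).reshape(-1, V)
    ...                       for x_i, y_i in zip(xs_chunks, ys_chunks)], dim=0)
    ...
    >>> gathersum_ref(xs, ys, lx, ly).size()
    torch.Size([20, 3])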

Reference

  • The PyTorch binding follows 1ytic/warp-rnnt.

  • For the specific usage of these functions, please refer to this discussion.
