Custom TensorFlow2 implementations of the forward and backward computation of the soft-DTW algorithm in batch mode.

Overview

Batch Soft-DTW (Dynamic Time Warping) in TensorFlow2, including forward and backward computation

Custom TensorFlow2 implementations of the forward and backward computation of the soft-DTW (Dynamic Time Warping) algorithm in batch mode, as proposed in the paper "Soft-DTW: a Differentiable Loss Function for Time-Series".
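For reference, the forward recursion and the backward (gradient) recursion described in the soft-DTW paper can be summarized as below. The symbols (cost matrix Delta, accumulated-cost matrix r, gradient matrix e) and the boundary conditions are this summary's own notation, not identifiers taken from the repository; terms whose indices fall outside the n x m grid are treated as zero in the backward pass.

```latex
% Soft-min with smoothing parameter \gamma > 0
\min{}^{\gamma}(a_1,\dots,a_k) = -\gamma \log \sum_{l=1}^{k} \exp(-a_l/\gamma)

% Forward pass for x = (x_1,\dots,x_n), y = (y_1,\dots,y_m), with \Delta_{i,j} = \delta(x_i, y_j)
r_{0,0} = 0, \qquad r_{i,0} = r_{0,j} = +\infty \quad (i, j \ge 1)
r_{i,j} = \Delta_{i,j} + \min{}^{\gamma}\left(r_{i-1,j-1},\ r_{i-1,j},\ r_{i,j-1}\right),
\qquad \mathrm{dtw}_{\gamma}(x, y) = r_{n,m}

% Backward pass: e_{i,j} = \partial r_{n,m} / \partial r_{i,j}, starting from e_{n,m} = 1
e_{i,j} = e_{i+1,j}\,\exp\!\big((r_{i+1,j} - r_{i,j} - \Delta_{i+1,j})/\gamma\big)
        + e_{i,j+1}\,\exp\!\big((r_{i,j+1} - r_{i,j} - \Delta_{i,j+1})/\gamma\big)
        + e_{i+1,j+1}\,\exp\!\big((r_{i+1,j+1} - r_{i,j} - \Delta_{i+1,j+1})/\gamma\big)
\frac{\partial\,\mathrm{dtw}_{\gamma}(x, y)}{\partial \Delta_{i,j}} = e_{i,j}
```

Gradients with respect to the input sequences then follow from the chain rule through the cost function delta.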

I have implemented two versions of soft-DTW: one follows the original paper, and the other follows the Parallel Tacotron 2 paper (with a warp penalty). For the latter version, I derived the backward-computation equations myself.
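The sketch below is a loop-based, pure-Python illustration of both versions; it is not the repository's TensorFlow code, and every function name in it is hypothetical. It assumes 1-D input sequences, an L1 frame cost, and that the warp penalty `warp` is added to the horizontal and vertical moves of the recursion (assumed here to match the Parallel Tacotron 2 formulation); with `warp=0.0` it reduces to the original soft-DTW. Differentiating the soft-min shows how the penalty enters the backward pass: each penalized successor term simply gets an extra `-warp` in its exponent.

```python
import math
from typing import List, Sequence, Tuple

Matrix = List[List[float]]


def softmin(values: Sequence[float], gamma: float) -> float:
    """Stable soft-min: -gamma * log(sum(exp(-v / gamma))); +inf entries are ignored."""
    finite = [v for v in values if v != math.inf]
    if not finite:
        return math.inf
    lo = min(finite)
    return lo - gamma * math.log(sum(math.exp(-(v - lo) / gamma) for v in finite))


def l1_cost(x: Sequence[float], y: Sequence[float]) -> Matrix:
    """Pairwise L1 cost matrix for two 1-D sequences (assumed metric)."""
    return [[abs(a - b) for b in y] for a in x]


def forward(cost: Matrix, gamma: float, warp: float = 0.0) -> Matrix:
    """Accumulated-cost matrix r of shape (n+1, m+1); r[n][m] is the soft-DTW value.

    Horizontal and vertical moves pay the (assumed) warp penalty;
    warp=0 gives the original soft-DTW recursion.
    """
    n, m = len(cost), len(cost[0])
    r = [[math.inf] * (m + 1) for _ in range(n + 1)]
    r[0][0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            r[i][j] = cost[i - 1][j - 1] + softmin(
                (r[i - 1][j - 1], r[i - 1][j] + warp, r[i][j - 1] + warp), gamma)
    return r


def backward(cost: Matrix, r: Matrix, gamma: float, warp: float = 0.0) -> Matrix:
    """Gradient of r[n][m] with respect to every cost[i][j] (same shape as cost)."""
    n, m = len(cost), len(cost[0])
    e = [[0.0] * (m + 1) for _ in range(n + 1)]
    e[n][m] = 1.0
    for i in range(n, 0, -1):
        for j in range(m, 0, -1):
            if i == n and j == m:
                continue
            total = 0.0
            if i < n:  # vertical move (i, j) -> (i+1, j), penalized
                total += e[i + 1][j] * math.exp(
                    (r[i + 1][j] - cost[i][j - 1] - r[i][j] - warp) / gamma)
            if j < m:  # horizontal move (i, j) -> (i, j+1), penalized
                total += e[i][j + 1] * math.exp(
                    (r[i][j + 1] - cost[i - 1][j] - r[i][j] - warp) / gamma)
            if i < n and j < m:  # diagonal move (i, j) -> (i+1, j+1), no penalty
                total += e[i + 1][j + 1] * math.exp(
                    (r[i + 1][j + 1] - cost[i][j] - r[i][j]) / gamma)
            e[i][j] = total
    # Drop the unused row 0 and column 0 so the result aligns with cost.
    return [row[1:m + 1] for row in e[1:n + 1]]


def soft_dtw_batch(xs: Sequence[Sequence[float]], ys: Sequence[Sequence[float]],
                   gamma: float, warp: float = 0.0) -> Tuple[List[float], List[Matrix]]:
    """Plain loop over a batch of (x, y) pairs; returns values and cost-matrix gradients."""
    values, grads = [], []
    for x, y in zip(xs, ys):
        cost = l1_cost(x, y)
        r = forward(cost, gamma, warp)
        values.append(r[len(x)][len(y)])
        grads.append(backward(cost, r, gamma, warp))
    return values, grads
```

For example, `soft_dtw_batch([[0.0, 1.0, 2.0]], [[0.0, 2.0]], gamma=0.01)` returns a value slightly below 1.0, the hard DTW distance for that pair, because the soft-min never exceeds the hard min. The returned gradients are with respect to the cost matrix; for the L1 cost, gradients with respect to `x[i]` follow by summing `grad[i][j] * sign(x[i] - y[j])` over `j`. A batched TensorFlow version would vectorize the batch dimension instead of looping over pairs.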

If you have questions about the code or suggestions for improvement, feel free to open an issue!


Comments
  • How to use it as a custom loss function in tf?

    Thank you very much for sharing this! Could you help me figure out how to implement soft-DTW as a custom loss in a deep learning training framework (in tf)? I hope to replace the RMSE/MAE metrics with DTW. Thank you in advance.

    opened by wangshuo1994
  • Questions about your code

    Running your code causes two problems.

    In the line soft_dtw_distance = batch_soft_dtw(a, b, gamma=0.01, metric="L1"), the code does not run because no warp value is given.

    I passed a warp value and ran the code, but got the following error message: '('custom_gradient function expected to return', 2, 'gradients but returned', 1, 'instead.')'

    I tested the code with tf 2.3 and tf_nightly 2.5 and got the same error.

    opened by JunetaeKim
  • Thank you :D

    @zzw922cn Hi, I'm the creator of TensorFlowTTS (https://github.com/TensorSpeech/TensorFlowTTS). I plan to reproduce Parallel Tacotron 2 and will use this implementation :D. Thank you so much! It's not easy to implement soft_dtw in TensorFlow, since tf.Tensor is still an immutable object and we need to use tf.TensorArray as a workaround.

    opened by dathudeptrai