pysdtw

PyTorch implementation of the Soft-DTW algorithm with support for both CPU and CUDA hardware.

This repository started as a fork of pytorch-softdtw-cuda, but now exists as a stand-alone package with several improvements:

  • availability on pypi
  • code organisation as a package
  • improved API with type declarations
  • support for time series of arbitrary lengths on CUDA
  • support for packed sequences (see the sketch after this list)
  • fixes for Sakoe-Chiba bands
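
Packed-sequence support is sketched below under an explicit assumption: that SoftDTW accepts torch.nn.utils.rnn.PackedSequence objects built with the standard PyTorch utilities. This is an illustration only, not the documented interface; refer to the tests directory for the exact call pattern.

import torch
from torch.nn.utils.rnn import pack_sequence
import pysdtw

# two batches of variable-length series, packed with standard PyTorch utilities
xs = [torch.rand(5, 7, requires_grad=True), torch.rand(3, 7, requires_grad=True)]
ys = [torch.rand(9, 7), torch.rand(6, 7)]
X = pack_sequence(xs, enforce_sorted=False)
Y = pack_sequence(ys, enforce_sorted=False)

# assumption: SoftDTW accepts PackedSequence inputs directly (check the tests for the exact API)
sdtw = pysdtw.SoftDTW(gamma=1.0, use_cuda=False)
res = sdtw(X, Y)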

Installation

This package is available on PyPI and depends on pytorch and numba.

Install with:

pip install pysdtw
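
After installing, a quick sanity check (a minimal sketch using only standard PyTorch calls) confirms that the package imports and reports whether a CUDA device is visible; the CPU backend works without a GPU, while the CUDA backend needs one:

import torch
import pysdtw  # just checking that the import succeeds

# the CPU backend needs only pytorch and numba; the CUDA backend additionally needs a visible GPU
print("CUDA available:", torch.cuda.is_available())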

Usage

Below is a small snippet showing how to compute the soft-DTW between two batched tensors and obtain its gradient with respect to one of the inputs:

import torch
import pysdtw

device = torch.device('cuda')

# the input data includes a batch dimension
X = torch.rand((10, 5, 7), device=device, requires_grad=True)
Y = torch.rand((10, 9, 7), device=device)

# optionally choose a pairwise distance function
fun = pysdtw.distance.pairwise_l2_squared

# create the SoftDTW distance function
sdtw = pysdtw.SoftDTW(gamma=1.0, dist_func=fun, use_cuda=True)

# soft-DTW discrepancy, approaches DTW as gamma -> 0
res = sdtw(X, Y)

# define a loss whose gradient can be backpropagated
loss = res.sum()
loss.backward()

# X.grad now contains the gradient of the loss with respect to X
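
The dist_func argument can be any pairwise distance you define yourself. The sketch below is an illustration under a convention inferred from the snippet above, namely that the function maps inputs of shape (B, N, D) and (B, M, D) to a (B, N, M) cost matrix; pairwise_l1 is a hypothetical helper, not part of pysdtw:

import torch
import pysdtw

def pairwise_l1(x, y):
    # x: (B, N, D), y: (B, M, D) -> pairwise cost matrix of shape (B, N, M)
    return (x.unsqueeze(2) - y.unsqueeze(1)).abs().sum(dim=-1)

# use the custom distance on the CPU backend
sdtw_l1 = pysdtw.SoftDTW(gamma=1.0, dist_func=pairwise_l1, use_cuda=False)

A = torch.rand((4, 5, 7), requires_grad=True)
B = torch.rand((4, 9, 7))
sdtw_l1(A, B).sum().backward()  # gradients flow back through the custom distance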

You can also have a look at the code in the tests directory. Different test suites ensure that pysdtw behaves similarly to pytorch-softdtw-cuda by Maghoumi and soft-dtw by Blondel. The tests also include some comparative performance measurements. Run the tests with python -m unittest from the repository root.

Acknowledgements

Supported by the ELEMENT project (ANR-18-CE33-0002) and the ARCOL project (ANR-19-CE33-0001) of the French National Research Agency (ANR).

Issues

Gradients are gone when moving the code to CUDA

Hi,

First of all, thank you so much for this amazing implementation!
I am trying to use your code (the example code), but I am getting an error when I move everything to CUDA.

import torch
import pysdtw

device = torch.device("cuda")

# the input data includes a batch dimension
X = torch.rand((10, 5, 7), requires_grad=True).to(device)
Y = torch.rand((10, 9, 7)).to(device)

# optionally choose a pairwise distance function
fun = pysdtw.distance.pairwise_l2_squared

# create the SoftDTW distance function
sdtw = pysdtw.SoftDTW(gamma=1.0, dist_func=fun, use_cuda=True)

# soft-DTW discrepancy, approaches DTW as gamma -> 0
res = sdtw(X, Y)

# define a loss, which gradient can be backpropagated
loss = res.sum()
loss.backward()

# X.grad now contains the gradient with respect to the loss

If I print X.grad, the result is empty and I get the following warning message:

/usr/local/lib/python3.7/dist-packages/torch/_tensor.py:1083: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the .grad field to be populated for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations. (Triggered internally at  aten/src/ATen/core/TensorBody.h:477.)
  return self._grad

I'm running the code using Google Colab. Any idea why this is happening? Again thank you so much!
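
The warning above points at the cause: calling .to(device) on a tensor created with requires_grad=True returns a new, non-leaf tensor, so backward() does not populate its .grad attribute. A minimal sketch of the two standard remedies (plain PyTorch, independent of pysdtw):

import torch

device = torch.device("cuda")

# remedy 1: create the leaf tensor directly on the target device (as in the README example)
X = torch.rand((10, 5, 7), device=device, requires_grad=True)

# remedy 2: keep the .to(device) chain, but explicitly retain the gradient on the non-leaf tensor
X2 = torch.rand((10, 5, 7), requires_grad=True).to(device)
X2.retain_grad()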

Error when using SoftDTW

Hi,

Thank you for releasing the code. I'm having trouble using it, though. When using the simple l2_distance I get the following error (error screenshot not included). And when changing dist_func to l2_exact I get another error (screenshot not included).

Any idea why?

Pytorch 1.8.2+cu111
numba 0.56.4
