sdtw-jax

Soft-DTW loss (with warp penalty) in JAX.
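
As a sketch of what is being computed, the recurrence below is read off the `sdtw` source further down, so treat it as an interpretation rather than the author's own statement. The symbols are assumptions made for illustration: $c_{i,j}$ is the frame cost, $w$ is `warp_penalty`, $\tau$ is `temperature`, and $N$, $M$ are the two sequence lengths. The code substitutes the finite constant `INFINITY` for $\infty$.

```latex
\begin{aligned}
c_{i,j} &= \frac{1}{D} \sum_{d=1}^{D} \lvert a_{i,d} - b_{j,d} \rvert \\
\operatorname{softmin}_{\tau}(v_1, v_2, v_3) &= -\tau \log \sum_{k=1}^{3} e^{-v_k / \tau} \\
r_{i,j} &= c_{i,j} + \operatorname{softmin}_{\tau}\bigl(r_{i-1,j-1},\; r_{i-1,j} + w,\; r_{i,j-1} + w\bigr) \\
r_{0,0} &= 0, \qquad r_{i,0} = r_{0,j} = \infty \quad (i, j \ge 1) \\
\text{loss} &= r_{N,M}
\end{aligned}
```

Only the two warping steps pay the penalty $w$; the diagonal step is free, so well-aligned sequences are favoured.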

Usage:

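import jax
import jax.numpy as jnp
# `batched_sdtw` is defined in the source below.
# A batch of 4 sequences, each with 800 frames of 80 features.
x = jax.random.normal(jax.random.PRNGKey(42), (4, 800, 80))
y = jnp.roll(x, 4, axis=1)  # The same sequences shifted by 4 frames in time.
batched_sdtw(x, y, warp_penalty=1.0, temperature=0.01)
# DeviceArray([16.947954, 16.809141, 16.411541, 17.066374], dtype=float32)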

Source:

# Reference: https://arxiv.org/abs/2006.03575

import jax
import jax.numpy as jnp
import numpy as np
from functools import partial


def soft_minimum(values, temperature):
    """Smooth minimum over a list of equally shaped arrays (elementwise)."""
    values = jnp.stack(values, axis=-1)
    return -temperature * jax.nn.logsumexp(-values / temperature, axis=-1)


def skew_matrix(x):
    """Skew a matrix so that the diagonals become the rows."""
    height, width = x.shape
    rows = np.arange(height + width - 1)[:, None]
    cols = np.arange(width)[None, :]
    # Shift each column j down by j steps; out-of-range rows are clipped.
    ids = np.clip(rows - cols, 0, height - 1).astype(np.int32)
    x = jnp.take_along_axis(x, ids, axis=0)
    return x


def kernel_dist(kernel, xs, ys):
    """
    Returns:
    A 2d array `a` such that `a[i, j] = kernel(xs[i], ys[j])`.
    """
    return jax.vmap(lambda x: jax.vmap(lambda y: kernel(x, y))(ys))(xs)


@partial(jax.jit, static_argnums=(2, 3, 4))
def sdtw(a, b, warp_penalty=1.0, temperature=0.01, INFINITY=1e8):
    """Soft-DTW cost between `a` of shape (N, D) and `b` of shape (M, D)."""
    N, D1 = a.shape
    M, D2 = b.shape
    assert D1 == D2
    dist_fn = lambda x, y: jnp.mean(jnp.abs(x - y), axis=-1)
    cost = kernel_dist(dist_fn, a, b)
    size = cost.shape[-1]
    path_cost = INFINITY * jnp.ones((size + 1,))
    path_cost_prev = INFINITY * jnp.ones((size,))
    path_cost_prev = jnp.pad(path_cost_prev, (1, 0), constant_values=0.0)
    cost = skew_matrix(cost)

    def scan_fn(prev, inputs):
        path_cost_prev, path_cost = prev
        cost_i = inputs
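        penalty_cost = path_cost + warp_penalty
        # Predecessors of each cell: the diagonal step (no penalty), then the
        # two warping steps along either sequence (each pays `warp_penalty`).
        directions = [path_cost_prev[:-1], penalty_cost[1:], penalty_cost[:-1]]
        path_cost_next = cost_i + soft_minimum(directions, temperature)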
        path_cost_next = jnp.pad(path_cost_next, (1, 0), constant_values=INFINITY)
        path_cost, path_cost_prev = path_cost_next, path_cost
        return (path_cost_prev, path_cost), None

    (path_cost_prev, path_cost), _ = jax.lax.scan(
        scan_fn, (path_cost_prev, path_cost), cost
    )
    return path_cost[-1]


@partial(jax.jit, static_argnums=(2, 3, 4))
def batched_sdtw(a, b, warp_penalty=1.0, temperature=0.01, INFINITY=1e8):
    dtw = partial(
        sdtw,
        warp_penalty=warp_penalty,
        temperature=temperature,
        INFINITY=INFINITY,
    )
    return jax.vmap(dtw)(a, b)
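
For checking the skewed scan above, a plain NumPy reference can be handy. The sketch below is not part of the original repository: `soft_min` and `sdtw_reference` are hypothetical names, and the recurrence is my reading of `scan_fn` (diagonal step free, the two warping steps each pay `warp_penalty`). It uses a true `inf` on the boundary and float64, so it should agree with `sdtw` only up to float32 precision and the finite `INFINITY` substitution.

```python
import numpy as np


def soft_min(values, temperature):
    """Smooth minimum: -T * log(sum(exp(-v / T))), computed stably."""
    a = -np.asarray(values, dtype=np.float64) / temperature
    m = a.max()  # Finite as long as at least one value is finite.
    return -temperature * (m + np.log(np.exp(a - m).sum()))


def sdtw_reference(a, b, warp_penalty=1.0, temperature=0.01):
    """O(N * M) loop version of the soft-DTW recurrence (illustrative only)."""
    n, m = len(a), len(b)
    # Frame cost: mean absolute difference over the feature axis.
    cost = np.abs(a[:, None, :] - b[None, :, :]).mean(axis=-1)
    # r[i, j] is the soft cost of aligning a[:i] with b[:j].
    r = np.full((n + 1, m + 1), np.inf)
    r[0, 0] = 0.0
    for i in range(1, n + 1):
        for j in range(1, m + 1):
            r[i, j] = cost[i - 1, j - 1] + soft_min(
                [
                    r[i - 1, j - 1],               # diagonal step, no penalty
                    r[i - 1, j] + warp_penalty,    # advance in `a` only
                    r[i, j - 1] + warp_penalty,    # advance in `b` only
                ],
                temperature,
            )
    return r[n, m]


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    seq = rng.normal(size=(5, 3))
    # Identical sequences align along the diagonal, so the cost is close to 0.
    print(sdtw_reference(seq, seq))
```

The double loop is far too slow for training, which is the motivation for the skew-and-scan formulation in the source; the reference is only meant for small test inputs.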
