
torch-imle

Concise and self-contained PyTorch library implementing the I-MLE gradient estimator proposed in our NeurIPS 2021 paper Implicit MLE: Backpropagating Through Discrete Exponential Family Distributions.

This repository contains a library for turning any black-box combinatorial solver into a differentiable layer. All code for reproducing the experiments in the NeurIPS paper is available in the official NEC Laboratories Europe repository.

Overview

Implicit MLE (I-MLE) makes it possible to include discrete combinatorial optimization algorithms, such as Dijkstra's algorithm or integer linear program (ILP) solvers, in standard deep learning architectures. The core idea of I-MLE is that it defines an implicit maximum likelihood objective whose gradients are used to update upstream parameters of the model. Every instance of I-MLE requires two ingredients:

  1. A method to approximately sample from a complex and intractable distribution induced by the combinatorial solver over the space of solutions, where optimal solutions have the highest probability mass. For this, we use Perturb-and-MAP (aka the Gumbel-max trick) and propose a novel family of noise perturbations tailored to the problem at hand.
  2. A method to compute a surrogate empirical distribution: vanilla MLE reduces the KL divergence between the current distribution and the empirical distribution. Since, in our setting, we do not have access to an empirical distribution, we have to design surrogate empirical distributions. Here we propose two families of surrogate distributions which are widely applicable and work well in practice. A minimal sketch of how the two ingredients combine is given right after this list.
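To make the two ingredients concrete, here is a minimal, illustrative sketch of how they combine into a gradient estimate. It is not the library's implementation (the imle() wrapper shown in the Code section below packages this logic for you); map_solver, sample_noise, and lam are hypothetical placeholders for the black-box MAP solver, the noise sampler, and the target perturbation strength.

import torch

def imle_grad(theta: torch.Tensor, dL_dy: torch.Tensor,
              map_solver, sample_noise, lam: float = 10.0) -> torch.Tensor:
    # Perturb-and-MAP: perturb the weights and solve to draw an approximate sample.
    eps = sample_noise(theta.shape)
    y_hat = map_solver(theta + eps)
    # Surrogate target distribution: move the weights in the direction that lowers the loss.
    theta_prime = theta - lam * dL_dy
    y_target = map_solver(theta_prime + eps)
    # I-MLE gradient estimate w.r.t. theta (up to a constant factor).
    return y_hat - y_target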

Example

For example, let's consider a map from a simple game where the task is to find the shortest path from the top-left to the bottom-right corner. Darker areas have a higher cost, and brighter areas have a lower cost. In the middle, you can see what happens when we use the proposed sum-of-gamma noise distribution to sample paths. On the right, you can see the resulting marginal probabilities for every tile (the probability of each tile being part of a sampled path).
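Those per-tile marginals can be estimated by averaging many perturbed solutions. The following is an illustrative sketch, not library code: shortest_path is a hypothetical solver that maps a 2D cost tensor to a 0/1 path tensor, and standard Gumbel noise stands in for the sum-of-gamma perturbations used in the animation.

import torch

def path_marginals(costs: torch.Tensor, nb_samples: int = 100, temperature: float = 1.0) -> torch.Tensor:
    gumbel = torch.distributions.Gumbel(0.0, 1.0)
    paths = []
    for _ in range(nb_samples):
        eps = gumbel.sample(costs.shape) * temperature  # perturb the tile costs
        paths.append(shortest_path(costs + eps))        # solve on the perturbed map
    return torch.stack(paths).mean(dim=0)               # fraction of samples using each tile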

Gradients and Learning

Let us assume that the optimal shortest path is the one on the left. Starting from random weights, the model can learn to produce weights that yield the optimal shortest path via gradient descent, by minimising the Hamming loss between the produced path and the gold path. Here we show the paths produced during training (middle) and the corresponding map weights (right); a schematic training loop is sketched after the animations.

Input noise temperature set to 0.0, and target noise temperature set to 0.0:

Input noise temperature set to 1.0, and target noise temperature set to 1.0:

Input noise temperature set to 2.0, and target noise temperature set to 2.0:

Input noise temperature set to 5.0, and target noise temperature set to 5.0:

Input noise temperature set to 5.0, and target noise temperature set to 0.0:

All animations were generated by this script.
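Schematically, the learning setup behind these animations boils down to a loop like the one below. This is a hypothetical sketch, not the experiment script: imle_solver is the wrapped solver built in the Code section below, y_star is the gold path as a 0/1 tensor, and the map size is arbitrary.

import torch

weights = torch.nn.Parameter(torch.zeros(1, 12, 12))  # learnable map weights (batch of one)
optimizer = torch.optim.Adam([weights], lr=0.05)

for step in range(500):
    optimizer.zero_grad()
    y = imle_solver(weights)                                 # paths sampled through the I-MLE layer
    loss = (y * (1.0 - y_star) + (1.0 - y) * y_star).sum()  # Hamming loss against the gold path
    loss.backward()                                          # I-MLE provides the gradients w.r.t. the weights
    optimizer.step()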

Code

Using this library is extremely easy -- see this example as a reference. Assuming we have a method that implements a black-box combinatorial solver such as Dijkstra's algorithm:

import numpy as np

import torch
from torch import Tensor

def torch_solver(weights_batch: Tensor) -> Tensor:
    # `solver` is the external black-box combinatorial solver (e.g. a Dijkstra routine
    # operating on NumPy arrays); it is applied instance by instance, outside autograd.
    weights_batch = weights_batch.detach().cpu().numpy()
    y_batch = np.asarray([solver(w) for w in list(weights_batch)])
    return torch.tensor(y_batch, requires_grad=False)
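Here, solver stands for any function that maps a cost array to a discrete solution. As a hypothetical, self-contained stand-in (not part of the library), a 4-connected grid shortest-path solver could look like this:

import heapq
import numpy as np

def solver(costs: np.ndarray) -> np.ndarray:
    """Hypothetical black-box solver: cheapest 4-connected path from the top-left to the
    bottom-right cell of a 2D cost grid, returned as a 0/1 indicator grid.
    Assumes non-negative costs (Dijkstra's algorithm)."""
    h, w = costs.shape
    dist = np.full((h, w), np.inf)
    prev = {}
    dist[0, 0] = costs[0, 0]
    queue = [(dist[0, 0], (0, 0))]
    while queue:
        d, (i, j) = heapq.heappop(queue)
        if d > dist[i, j]:
            continue
        for ni, nj in ((i + 1, j), (i - 1, j), (i, j + 1), (i, j - 1)):
            if 0 <= ni < h and 0 <= nj < w and d + costs[ni, nj] < dist[ni, nj]:
                dist[ni, nj] = d + costs[ni, nj]
                prev[(ni, nj)] = (i, j)
                heapq.heappush(queue, (dist[ni, nj], (ni, nj)))
    path = np.zeros_like(costs)
    node = (h - 1, w - 1)
    while node != (0, 0):            # walk back from the goal to the start
        path[node] = 1.0
        node = prev[node]
    path[0, 0] = 1.0
    return path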

We can obtain the corresponding distribution and gradients in this way:

from imle.wrapper import imle
from imle.target import TargetDistribution
from imle.noise import SumOfGammaNoiseDistribution

target_distribution = TargetDistribution(alpha=0.0, beta=10.0)
noise_distribution = SumOfGammaNoiseDistribution(k=k, nb_iterations=100)

imle_solver = imle(torch_solver,
                   target_distribution=target_distribution,
                   noise_distribution=noise_distribution,
                   nb_samples=10,
                   input_noise_temperature=input_noise_temperature,
                   target_noise_temperature=target_noise_temperature)

Or, alternatively, using a simple function decorator:

@imle(target_distribution=target_distribution,
      noise_distribution=noise_distribution,
      nb_samples=10,
      input_noise_temperature=input_noise_temperature,
      target_noise_temperature=target_noise_temperature)
def imle_solver(weights_batch: Tensor) -> Tensor:
    return torch_solver(weights_batch)
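Continuing from the snippets above, the wrapped solver behaves like an ordinary differentiable PyTorch function. A hypothetical usage sketch (the map size and the downstream loss are placeholders):

weights_batch = torch.randn(1, 12, 12, requires_grad=True)  # upstream, learnable weights
y_batch = imle_solver(weights_batch)                         # discrete solutions in the forward pass
loss = y_batch.sum()                                         # stand-in for a real downstream loss
loss.backward()                                              # I-MLE supplies the gradients
print(weights_batch.grad.shape)                              # matches the shape of weights_batch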

Papers using I-MLE

Reference

@inproceedings{niepert21imle,
  author    = {Mathias Niepert and
               Pasquale Minervini and
               Luca Franceschi},
  title     = {Implicit {MLE}: Backpropagating Through Discrete Exponential Family
               Distributions},
  booktitle = {Advances in Neural Information Processing Systems ({NeurIPS})},
  year      = {2021}
}


torch-imle's Issues

Differentiation doesn't behave correctly

I'm working on some discrete combinatorial optimization problems for which I-MLE seems highly suitable. However, in some instances it does not behave as anticipated.
Here is a simplified example:

Definition of Solver

import torch
from torch import Tensor, nn

from imle.wrapper import imle
from imle.target import TargetDistribution
from imle.noise import SumOfGammaNoiseDistribution

@imle(target_distribution=TargetDistribution(alpha=1.0, beta=1.0),
      noise_distribution=SumOfGammaNoiseDistribution(k=10, nb_iterations=100),
      input_noise_temperature=10.0,
      target_noise_temperature=10.0)
def solver(weights_batch: Tensor) -> Tensor:
    """
    weights_batch is a tensor of float values in the range -1 to 1
    """
    return (weights_batch > 0).float()

Definition of Parameters, Optimizer, and Labels

param = nn.Parameter(torch.zeros(8, requires_grad=True))
optimizer = torch.optim.AdamW([param], lr=0.1)
y1 = torch.tensor([0., 0., 0., 0., 1., 1., 1., 1.])
y2 = torch.tensor([1., 1., 1., 1., 0., 0., 0., 0.])

Train with y1 as labels

for i in range(100):
    w = torch.tanh(param)
    pred = solver(w)
    loss = (pred - y1).square().sum()
    loss.backward()
    optimizer.step()
param

Expected Result:

  • Values in param[:4] should all turn negative after training.
  • Values in param[4:] should all turn positive after training.

Actual Result: as expected

Parameter containing:
tensor([ -9.9164, -10.1681,  -9.8578,  -9.9074,  10.4283,   9.9732,  10.2848,
         10.2281], requires_grad=True)

Train with y2 as labels

Continue training with negated labels without resetting the parameters.

for i in range(100):
    w = torch.tanh(param)
    pred = solver(w)
    loss = (pred - y2).square().sum()
    loss.backward()
    optimizer.step()
param

Expected Result:

  • The values in param[:4] should start to rise and eventually turn positive.
  • The values in param[4:] should start to fall and eventually turn negative.

Actual Result:

  • param[:4] continues to fall, while param[4:] continues to rise.
Parameter containing:
tensor([-18.7772, -19.0104, -18.5812, -18.6452,  19.3325,  18.8648,  19.1460,
         19.0642], requires_grad=True)

I've attempted this with various hyperparameter configurations, but the outcome remains the same. I'm uncertain as to where the potential problems might lie. Could you provide some insights?
