PyTorch Metric Learning

Documentation

View the documentation here

Benefits of this library

  1. Ease of use
    • Add metric learning to your application with just 2 lines of code in your training loop.
    • Mine pairs and triplets with a single function call.
  2. Flexibility
    • Mix and match losses, miners, and trainers in ways that other libraries don't allow.

Installation:

Conda:

conda install pytorch-metric-learning -c metric-learning

Pip:

pip install pytorch-metric-learning

Benchmark results

See powerful-benchmarker to view benchmark results and to use the benchmarking tool.

Library contents

Base Classes, Mixins, and Wrappers:

Overview

Let’s try the vanilla triplet margin loss. In all examples, embeddings is assumed to be of size (N, embedding_size), and labels is of size (N).

from pytorch_metric_learning import losses
loss_func = losses.TripletMarginLoss(margin=0.1)
loss = loss_func(embeddings, labels)
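
To make the idea concrete, here is a minimal plain-Python sketch of what a triplet margin loss computes on a labeled batch. It is not the library's implementation: the helper names `euclidean` and `triplet_margin_loss` are hypothetical, and the choices of Euclidean distance and a plain mean over every valid triplet are assumptions made for illustration.

```python
import math
from itertools import permutations


def euclidean(u, v):
    """Euclidean distance between two equal-length vectors."""
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_margin_loss(embeddings, labels, margin=0.1):
    """Mean of max(0, d(a, p) - d(a, n) + margin) over all valid triplets.

    A triplet (a, p, n) is valid when a and p share a label, a != p,
    and n has a different label. Averaging over all triplets is an
    assumption of this sketch.
    """
    per_triplet = []
    indices = range(len(labels))
    for a, p in permutations(indices, 2):
        if labels[a] != labels[p]:
            continue
        for n in indices:
            if labels[n] == labels[a]:
                continue
            d_ap = euclidean(embeddings[a], embeddings[p])
            d_an = euclidean(embeddings[a], embeddings[n])
            per_triplet.append(max(0.0, d_ap - d_an + margin))
    return sum(per_triplet) / len(per_triplet) if per_triplet else 0.0


# Two well-separated classes: every negative is far from its anchor.
embeddings = [[0.0, 0.0], [0.0, 1.0], [3.0, 0.0], [3.0, 1.0]]
labels = [0, 0, 1, 1]
loss = triplet_margin_loss(embeddings, labels, margin=0.1)
```

With these toy embeddings every negative is already more than `margin` farther from the anchor than the positive, so no triplet contributes to the loss. Raising the margin makes some triplets violate the constraint, and the loss becomes positive.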

Loss functions typically come with a variety of parameters. For example, with the TripletMarginLoss, you can control how many triplets per sample to use in each batch. You can also use all possible triplets within each batch:

loss_func = losses.TripletMarginLoss(triplets_per_anchor="all")

Sometimes it can help to add a mining function:

from pytorch_metric_learning import miners, losses
miner = miners.MultiSimilarityMiner(epsilon=0.1)
loss_func = losses.TripletMarginLoss(margin=0.1)
hard_pairs = miner(embeddings, labels)
loss = loss_func(embeddings, labels, hard_pairs)

In the above code, the miner finds positive and negative pairs that it thinks are particularly difficult. Note that even though the TripletMarginLoss operates on triplets, it’s still possible to pass in pairs. This is because the library automatically converts pairs to triplets and triplets to pairs, when necessary.
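
One plausible way to perform such a conversion is sketched below in plain Python. This is not the library's code: the function names are hypothetical, and representing pairs and triplets as lists of index tuples is an assumption made for illustration. The idea is that a positive pair and a negative pair sharing the same anchor form a triplet, and a triplet splits back into one positive pair and one negative pair.

```python
def pairs_to_triplets(pos_pairs, neg_pairs):
    """Join positive pairs (a, p) with negative pairs (a, n) on the anchor a."""
    negatives_by_anchor = {}
    for a, n in neg_pairs:
        negatives_by_anchor.setdefault(a, []).append(n)
    return [
        (a, p, n)
        for a, p in pos_pairs
        for n in negatives_by_anchor.get(a, [])
    ]


def triplets_to_pairs(triplets):
    """Split triplets (a, p, n) into unique positive and negative pairs."""
    pos_pairs = sorted({(a, p) for a, p, _ in triplets})
    neg_pairs = sorted({(a, n) for a, _, n in triplets})
    return pos_pairs, neg_pairs


# Hypothetical miner output, given as index pairs into the batch.
pos_pairs = [(0, 1), (2, 3)]
neg_pairs = [(0, 2), (0, 3), (1, 2)]

triplets = pairs_to_triplets(pos_pairs, neg_pairs)
round_trip = triplets_to_pairs(triplets)
```

In this sketch, a positive pair whose anchor has no mined negative (here, anchor 2) produces no triplet, so the conversion can drop pairs.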

In general, all loss functions take in embeddings and labels, with an optional indices_tuple argument (i.e. the output of a miner):

# From BaseMetricLossFunction (signature only; body omitted)
def forward(self, embeddings, labels, indices_tuple=None): ...

And (almost) all mining functions take in embeddings and labels:

# From BaseMiner (signature only; body omitted)
def forward(self, embeddings, labels): ...

For more complex approaches, like deep adversarial metric learning, use one of the trainers.

To check the accuracy of your model, use one of the testers. Which tester should you use? Almost certainly GlobalEmbeddingSpaceTester, because it does what most metric-learning papers do.
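
As a rough illustration of this kind of evaluation, the sketch below computes precision at 1: each embedding is treated as a query, its nearest neighbor among all other embeddings is found, and the prediction counts as correct when the two labels match. This is an assumption about what such an evaluation typically involves, not the tester's actual code, and the function name `precision_at_1` is hypothetical.

```python
import math


def precision_at_1(embeddings, labels):
    """Fraction of points whose nearest other point has the same label.

    Assumes at least two embeddings and Euclidean distance.
    """
    correct = 0
    for i, query in enumerate(embeddings):
        best_j, best_dist = None, math.inf
        for j, candidate in enumerate(embeddings):
            if j == i:
                continue  # a point may not retrieve itself
            dist = math.dist(query, candidate)
            if dist < best_dist:
                best_j, best_dist = j, dist
        if labels[best_j] == labels[i]:
            correct += 1
    return correct / len(embeddings)


# The last point has label 1 but sits closer to the label-0 cluster.
embeddings = [[0.0, 0.0], [0.0, 1.0], [3.0, 0.0], [3.0, 1.0], [1.4, 0.5]]
labels = [0, 0, 1, 1, 1]
score = precision_at_1(embeddings, labels)
```

Four of the five points retrieve a neighbor with the same label, so the one poorly embedded point lowers the score.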

Also check out the example scripts. Each one shows how to set up models, optimizers, losses, etc. for a particular trainer.

To learn more about all of the above, see the documentation.

Acknowledgements

Facebook AI

Thank you to Ser-Nam Lim at Facebook AI, and my research advisor, Professor Serge Belongie. This project began during my internship at Facebook AI where I received valuable feedback from Ser-Nam, and his team of computer vision and machine learning engineers and research scientists. In particular, thanks to Ashish Shah and Austin Reiter for reviewing my code during its early stages of development.

Open-source repos

This library contains code that has been adapted and modified from the following great open-source repos:

Citing this library

If you'd like to cite pytorch-metric-learning in your paper, you can use this bibtex:

@misc{Musgrave2019,
  author = {Musgrave, Kevin and Lim, Ser-Nam and Belongie, Serge},
  title = {PyTorch Metric Learning},
  year = {2019},
  publisher = {GitHub},
  journal = {GitHub repository},
  howpublished = {\url{https://github.com/KevinMusgrave/pytorch-metric-learning}},
}
