pytorch-batch-attention-seq2seq's Introduction

A fast, batched Bi-RNN(GRU) encoder & attention decoder implementation in PyTorch

Usage: Please refer to the official PyTorch tutorial on attention-RNN machine translation. This implementation differs from it in two ways: it handles batched inputs, and it implements a slightly different attention mechanism.
The illustrations linked below show how the two mechanisms differ at the formula level.

Illustration of the attention mechanism in the PyTorch tutorial:
http://pytorch.org/tutorials/_images/decoder-network.png
Official PyTorch seq2seq machine translation tutorial:
http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.html
Illustration of Bahdanau attention:
http://images2015.cnblogs.com/blog/670089/201610/670089-20161012111504671-910168246.png

The attention decoder in the PyTorch tutorial feeds "word_embedding" into the computation of the attention weights, whereas in the original paper the weights are computed from "encoder_outputs". This repository implements the attention decoder as described in the original paper.
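To make the difference concrete, here is a minimal sketch of Bahdanau-style additive attention that scores the decoder state against "encoder_outputs" for a whole batch at once. It is an illustration, not the repository's actual code: the class name AdditiveAttention, the layer names, and the optional src_lengths mask are assumptions, and it assumes the bidirectional encoder outputs have already been reduced to hidden_size features.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class AdditiveAttention(nn.Module):
    """Illustrative additive attention over encoder outputs (hypothetical names)."""

    def __init__(self, hidden_size):
        super().__init__()
        # Scores a concatenated (decoder state, encoder output) pair.
        self.attn = nn.Linear(hidden_size * 2, hidden_size)
        self.v = nn.Linear(hidden_size, 1, bias=False)

    def forward(self, hidden, encoder_outputs, src_lengths=None):
        # hidden:          (batch, hidden_size), current decoder state
        # encoder_outputs: (batch, src_len, hidden_size)
        src_len = encoder_outputs.size(1)
        # Repeat the decoder state along the source axis so every
        # position is scored in a single batched call.
        query = hidden.unsqueeze(1).expand(-1, src_len, -1)
        energy = torch.tanh(self.attn(torch.cat([query, encoder_outputs], dim=2)))
        scores = self.v(energy).squeeze(2)  # (batch, src_len)
        if src_lengths is not None:
            # Padded positions must not receive attention mass.
            positions = torch.arange(src_len, device=scores.device).unsqueeze(0)
            mask = positions >= src_lengths.unsqueeze(1)
            scores = scores.masked_fill(mask, float("-inf"))
        weights = F.softmax(scores, dim=1)
        # Weighted sum of encoder outputs: (batch, hidden_size)
        context = torch.bmm(weights.unsqueeze(1), encoder_outputs).squeeze(1)
        return context, weights
```

The returned context vector would typically be concatenated with the current word embedding and fed to the decoder GRU; that wiring is left out here to keep the sketch short.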

Update: a dynamic encoder has been added; inputs in a batch no longer need to be sorted by length.
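One way to build an encoder that accepts unsorted batches is sketched below. It relies on the enforce_sorted=False option of pack_padded_sequence, which is available in recent PyTorch releases; the repository's own dynamic encoder may handle sorting differently (for example by sorting and restoring the order by hand). The class name DynamicEncoder, the use of index 0 for padding, and the choice to sum the two directions are assumptions made for illustration.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence


class DynamicEncoder(nn.Module):
    """Illustrative Bi-GRU encoder for variable-length, unsorted batches."""

    def __init__(self, vocab_size, embed_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Assumption: token id 0 is the padding symbol.
        self.embedding = nn.Embedding(vocab_size, embed_size, padding_idx=0)
        self.gru = nn.GRU(embed_size, hidden_size,
                          batch_first=True, bidirectional=True)

    def forward(self, input_ids, lengths):
        # input_ids: (batch, max_len) padded token ids
        # lengths:   (batch,) true lengths, in any order
        embedded = self.embedding(input_ids)
        # enforce_sorted=False lets PyTorch sort internally and
        # restore the original batch order afterwards.
        packed = pack_padded_sequence(embedded, lengths.cpu(),
                                      batch_first=True, enforce_sorted=False)
        packed_out, hidden = self.gru(packed)
        outputs, _ = pad_packed_sequence(packed_out, batch_first=True)
        # Sum forward and backward directions: (batch, max_len, hidden_size)
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden
```

Outputs at padded positions come back as zeros, so a length mask is still needed when computing attention weights over them.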


Speed up with batched tensor manipulation

PyTorch supports reading and assigning individual tensor elements, but doing so is slow, especially on a GPU. In one tutorial (https://github.com/spro/practical-pytorch), attention values are assigned element by element; this is entirely correct (and follows intuitively from the formulas in the paper), but it was slow on our GPU. We therefore re-implemented the computation with truly batched tensor operations, which gave a speedup of more than 10x.
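The contrast between the two styles can be sketched as follows. For brevity the example uses a plain dot-product score rather than the additive score from the paper, and both function names are hypothetical; the point is only that the nested loop and the single torch.bmm call produce the same tensor.

```python
import torch


def scores_elementwise(hidden, encoder_outputs):
    """Fill the score matrix one element at a time (slow, many tiny ops)."""
    batch, src_len, _ = encoder_outputs.shape
    scores = torch.zeros(batch, src_len, device=hidden.device)
    for b in range(batch):
        for t in range(src_len):
            scores[b, t] = torch.dot(hidden[b], encoder_outputs[b, t])
    return scores


def scores_batched(hidden, encoder_outputs):
    """Compute all scores with one batched matrix multiply."""
    # (batch, src_len, hidden) x (batch, hidden, 1) -> (batch, src_len, 1)
    return torch.bmm(encoder_outputs, hidden.unsqueeze(2)).squeeze(2)


hidden = torch.randn(4, 16)                # (batch, hidden_size)
encoder_outputs = torch.randn(4, 10, 16)   # (batch, src_len, hidden_size)
same = torch.allclose(scores_elementwise(hidden, encoder_outputs),
                      scores_batched(hidden, encoder_outputs), atol=1e-5)
```

The loop version issues batch * src_len separate small operations, while the batched version issues one; the actual speedup depends on hardware and tensor sizes.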

This code works well on personal projects.

pytorch-batch-attention-seq2seq's People

Contributors

aucson
