
tensorflow-multi-dimensional-lstm's Introduction

Multi Dimensional Recurrent Networks

TensorFlow implementation of the model described in Alex Graves' paper https://arxiv.org/pdf/0705.2011.pdf.



What is MD LSTM?

Essentially an LSTM that is multi-dimensional: one that can operate, for example, on a 2D grid instead of a 1D sequence. Here's a figure describing the way it works:


[Figure: example of a 2D LSTM architecture]
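
To make the recurrence concrete, here is a minimal NumPy sketch of a single 2D LSTM step. The weight layout is a simplifying assumption and not the exact cell used in this repository; following the paper, there is one forget gate per dimension:

import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def md_lstm_step(x, h_left, c_left, h_up, c_up, W):
    # One 2D LSTM step at grid cell (i, j): the cell sees its own input
    # plus the hidden/cell states of BOTH its left and top neighbours.
    z = np.concatenate([x, h_left, h_up])
    i = sigmoid(W['i'] @ z)    # input gate
    f1 = sigmoid(W['f1'] @ z)  # forget gate for the left neighbour
    f2 = sigmoid(W['f2'] @ z)  # forget gate for the top neighbour
    o = sigmoid(W['o'] @ z)    # output gate
    g = np.tanh(W['g'] @ z)    # candidate cell state
    c = f1 * c_left + f2 * c_up + i * g  # two incoming cell states
    h = o * np.tanh(c)
    return h, c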

How to get started?

git clone git@github.com:philipperemy/tensorflow-multi-dimensional-lstm.git
cd tensorflow-multi-dimensional-lstm

# create a new virtual python environment
virtualenv -p python3 venv
source venv/bin/activate
pip install -r requirements.txt

# usage: trainer.py [-h] --model_type {MD_LSTM,HORIZONTAL_SD_LSTM,SNAKE_SD_LSTM}
python trainer.py --model_type MD_LSTM
python trainer.py --model_type HORIZONTAL_SD_LSTM
python trainer.py --model_type SNAKE_SD_LSTM

Random Diagonal Task

The random diagonal task consists of initializing a matrix with values very close to 0, except for two entries that are set to 1. These two values lie on a straight line parallel to the diagonal of the matrix. The goal is to predict where these two values are. Here are some examples:

____________
|          |
|x         |
| x        |
|          |
|__________|


____________
|          |
|          |
|     x    |
|      x   |
|__________|

____________
|          |
| x        |
|  x       |
|          |
|__________|
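
A minimal sketch of how such an input could be generated (the helper name and epsilon value are illustrative; the actual generator in the repository may differ):

import numpy as np

def generate_random_diagonal_example(h=8, w=8, eps=1e-3):
    # Near-zero background; exact zeros are avoided because they can
    # produce NaN gradients (see Limitations below).
    grid = np.full((h, w), eps, dtype=np.float32)
    i = np.random.randint(0, h - 1)  # row of the first x
    j = np.random.randint(0, w - 1)  # column of the first x
    grid[i, j] = 1.0
    grid[i + 1, j + 1] = 1.0         # second x, one step down the diagonal
    return grid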

A model is considered successful on this task if it can correctly predict the second x (predicting the first x is impossible).

  • A simple recurrent model scanning a single row or column cannot predict either x location. This model is called HORIZONTAL_SD_LSTM. It should perform the worst.
  • If the matrix is flattened into one single vector, the location of the first x still cannot be predicted. However, a recurrent model should learn that the second x always comes width+1 steps after the first x (this model is SNAKE_SD_LSTM; see the check after this list).
  • When predicting the location of the second x, an MD recurrent model has a full view of the top-left corner of the matrix. It should therefore learn that when the first x is at the bottom right of its window, the second x comes next on the diagonal axis. Of course, the location of the first x still cannot be predicted with this MD model either.
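
The width+1 claim from the second bullet can be verified with a plain row-major flattening, reusing the hypothetical generator sketched above:

import numpy as np

grid = generate_random_diagonal_example(h=8, w=8)
seq = grid.reshape(-1)  # row-major scan: (i, j) -> i * 8 + j

# The two x always land exactly width + 1 = 9 steps apart in the sequence.
first, second = np.flatnonzero(seq == 1.0)
assert second - first == 8 + 1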

After training on this task with 8x8 matrices, the losses look like this:

[Figure: overall loss on the random diagonal task (loss applied to all elements of the input)]

[Figure: loss on the random diagonal task (loss applied only at the location of the second x)]

No surprise that the MD LSTM performs best here: it has direct connections between the grid cell containing the first x and the one containing the second x (2 steps). The snake LSTM has width+1 = 9 steps between the two x. As expected, the horizontal LSTM does not learn anything apart from outputting values very close to 0.


[Figure: MD LSTM predictions (left) and ground truth (right) before training; predictions are all random]


[Figure: MD LSTM predictions (left) and ground truth (right) after training. As expected, the MD LSTM predicts only the second x and not the first one, which means the task is solved correctly]

Limitations

  • I could test it successfully with 32x32 matrices, but the implementation is far from being well optimized.
  • This implementation can become numerically unstable quite easily.
  • Inputs should be non-zero; otherwise some gradients are NaN. Consider adding a small epsilon to the inputs (see the sketch after this list).
  • It is hard to use from Keras; this implementation is in pure TensorFlow.
  • It runs on a GPU, but the code is not optimized at all, so CPU and GPU are roughly equally fast.
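
A minimal sketch of the epsilon workaround mentioned above (the function name and epsilon value are illustrative):

import numpy as np

def stabilise_inputs(inputs, eps=1e-3):
    # Shift all values away from exact zero so that no gradient becomes NaN.
    return inputs + eps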

Contributions

Welcome!

Special Thanks

  • A big thank you to Mosnoi Ion who provided the first skeleton of this MD LSTM.

tensorflow-multi-dimensional-lstm's People

Contributors

joankaradimov, philipperemy


tensorflow-multi-dimensional-lstm's Issues

Diagonal iteration instead of naive

Hello,

First of all, thanks for your work; it is to my knowledge the first public attempt at defining a multi-dimensional while_loop.

If I understood your implementation correctly, each cell of the grid is updated sequentially, from left to right and then top to bottom. But as shown by the illustration in the README, this process is parallelizable: each cell on the diagonal orthogonal to the propagation can be updated independently. Do you have any lead on performing this processing? I tried to work on it but could not figure out a solution; TensorFlow really isn't easy to use in this case, because the number of cell states to compute at a time is a function of h and w.

Thanks,
Quentin
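
For reference, the anti-diagonal ("wavefront") ordering this issue asks about can be sketched as below; all cells inside one wavefront could in principle be updated in parallel. This only illustrates the iteration order, not a TensorFlow implementation:

h, w = 4, 6
for d in range(h + w - 1):
    # Cells with i + j == d depend only on cells from earlier wavefronts.
    wavefront = [(i, d - i) for i in range(max(0, d - w + 1), min(h, d + 1))]
    print(d, wavefront)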

Multi-dimensional vs Convolutional lstm

Hello, I have been researching networks like the one in this paper and I am confused about the difference between

  1. a multi-dimensional LSTM, as implemented here or in https://arxiv.org/pdf/1611.05664.pdf, where 2D LSTMs are part of a larger network of CNNs,

and

  2. convolutional 2D LSTMs as described in https://arxiv.org/abs/1506.04214v1, the source of the Keras layer ConvLSTM2D.

What confuses me is that the network described in the paper in 1) integrates 2D LSTMs into a CNN architecture, yet this seems to be different in some way from ConvLSTM2D. Could you explain how your project is different from 2), or how they are not?

Thank you

Multi Dimensional LSTM return size

In the comment section of the function def multi_dimensional_rnn_while_loop(rnn_size, input_data, sh, dims=None, scope_n="layer1"), you mention that the returned tensor shape is [batch, h/sh[0], w/sh[1], channels x sh[0] x sh[1]], but when actually returning from that function, the tensor shape is [batch, h/sh[0], w/sh[1], rnn_size]. Can you please explain why this is so?
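
For what it's worth, a plausible reading of the shapes involved (the numbers below are hypothetical, and the docstring may simply be stale):

# Assuming sh = (2, 2), channels = 1 and rnn_size = 256:
# input           : [batch, h, w, 1]
# blocks folded   : [batch, h/2, w/2, 1 x 2 x 2]  <- what the docstring describes
# after the LSTM  : [batch, h/2, w/2, 256]        <- what is actually returned,
#                   since each cell projects its input block to rnn_size units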
