lstm-rnn's Introduction

Recurrent Neural Network with LSTM Cells, in pure Python

A vanilla implementation of a Recurrent Neural Network (RNN) with Long Short-Term Memory (LSTM) cells, written without any ML libraries.

Background

These networks are particularly good for learning long-term dependencies within data, and can be applied to a variety of problems including language modelling, translation and speech recognition.

An LSTM cell has four gates, each computed from the current input and the previous hidden state.
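A common textbook form of the four gates is sketched below. The notation is an assumption rather than something taken from this repository's code: $x_t$ is the input at step $t$, $h_{t-1}$ is the previous hidden state, $[h_{t-1}, x_t]$ is their concatenation, $W_\ast$ and $b_\ast$ are the per-gate weights and biases, and $\sigma$ is the logistic sigmoid.

```latex
\begin{aligned}
f_t &= \sigma\left(W_f\,[h_{t-1}, x_t] + b_f\right) && \text{(forget gate)} \\
i_t &= \sigma\left(W_i\,[h_{t-1}, x_t] + b_i\right) && \text{(input gate)} \\
o_t &= \sigma\left(W_o\,[h_{t-1}, x_t] + b_o\right) && \text{(output gate)} \\
g_t &= \tanh\left(W_g\,[h_{t-1}, x_t] + b_g\right) && \text{(candidate values)}
\end{aligned}
```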

Each gate has its own set of parameters to learn, which makes training vanilla implementations (such as this one) expensive.

The gate outputs are combined into a single cell state value.
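In the standard LSTM formulation the cell state update can be written as follows. The symbols are assumed notation: $f_t$ and $i_t$ are the forget and input gate activations, $g_t$ is the candidate value, $c_{t-1}$ is the previous cell state, and $\odot$ is element-wise multiplication.

```latex
c_t = f_t \odot c_{t-1} + i_t \odot g_t
```

The forget gate scales how much of the old state is kept, and the input gate scales how much of the new candidate is added.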

The cell state is then used to compute the hidden state, which is passed on as in a normal RNN cell. From the outside, an LSTM cell can therefore be treated like any other cell within the network.
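A minimal sketch of one forward step, using only the standard library, may make the data flow concrete. It follows the standard LSTM formulation, not necessarily this repository's code; the names `lstm_step`, `init_params`, `matvec` and the gate keys `'f'`, `'i'`, `'o'`, `'g'` are hypothetical.

```python
import math
import random


def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))


def matvec(W, v):
    """Multiply matrix W (a list of rows) by vector v."""
    return [sum(w * x for w, x in zip(row, v)) for row in W]


def lstm_step(x, prev_h, prev_c, params):
    """One forward step of an LSTM cell (illustrative, not the repo's code).

    params maps each gate key ('f', 'i', 'o', 'g') to a (W, b) pair, where
    W has h_dim rows and h_dim + in_dim columns.
    """
    z = prev_h + x  # concatenate previous hidden state and current input
    pre = {k: [a + b_k for a, b_k in zip(matvec(W, z), b)]
           for k, (W, b) in params.items()}
    f = [sigmoid(v) for v in pre['f']]    # forget gate
    i = [sigmoid(v) for v in pre['i']]    # input gate
    o = [sigmoid(v) for v in pre['o']]    # output gate
    g = [math.tanh(v) for v in pre['g']]  # candidate values
    # Cell state: keep part of the old state, add part of the candidate.
    c = [f_j * pc + i_j * g_j for f_j, pc, i_j, g_j in zip(f, prev_c, i, g)]
    # Hidden state: output-gated, squashed view of the new cell state.
    h = [o_j * math.tanh(c_j) for o_j, c_j in zip(o, c)]
    return h, c


def init_params(in_dim, h_dim, seed=0):
    """Small random weights and zero biases for each of the four gates."""
    rng = random.Random(seed)
    return {k: ([[rng.uniform(-0.1, 0.1) for _ in range(h_dim + in_dim)]
                 for _ in range(h_dim)],
                [0.0] * h_dim)
            for k in 'fiog'}


params = init_params(in_dim=3, h_dim=2)
h, c = lstm_step([1.0, 0.0, 0.0], [0.0, 0.0], [0.0, 0.0], params)
```

Note that the step takes the previous hidden state and the previous cell state as separate arguments; both are carried forward from one time step to the next.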

Training and Initialisation

To initialise the network, create an instance of the class by calling the constructor with the arguments:

rnn = LSTM_RNN(lr, in_dim, h_dim, out_dim)

Where lr is the learning rate, in_dim is the dimension of the input layer, h_dim is the dimension of the hidden layer, and out_dim is the dimension of the output layer. The input and output dimensions should correspond to your training data.

The training data should be encoded as integers and given as two lists: a list of inputs and a corresponding list of targets. The RNN can then be trained by calling the function:

rnn.train(iterations, inputs, targets, seq_len)

Where iterations is the number of iterations to run, inputs and targets are the training data, and seq_len is the length of each batch of data.
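As an illustration of the expected data shape, the sketch below encodes a short string as integers and builds matching input and target lists. The character-level encoding, the shift-by-one targets (next-character prediction) and the chunking by `seq_len` are assumptions made for the example; the README does not specify them.

```python
text = "hello world"

# Map each distinct character to an integer ID (hypothetical encoding).
chars = sorted(set(text))
char_to_ix = {ch: i for i, ch in enumerate(chars)}

encoded = [char_to_ix[ch] for ch in text]

# For next-character prediction, each target is the input shifted by one.
inputs = encoded[:-1]
targets = encoded[1:]

# With this encoding, the vocabulary size is a natural choice for both
# in_dim and out_dim.
vocab_size = len(chars)
seq_len = 5

# Consecutive chunks of seq_len items, as a training loop might consume them.
batches = [(inputs[p:p + seq_len], targets[p:p + seq_len])
           for p in range(0, len(inputs), seq_len)]
```

The resulting `inputs` and `targets` lists could then be passed to `rnn.train(iterations, inputs, targets, seq_len)`.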

Planned Features and Improvements

  • A sampling method to view the output of the network as it is training, using a forward pass.
  • Refactor the code to use a computational graph model.
  • Use a linear sigmoid function to improve the speed.
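The "linear sigmoid" in the last item presumably refers to a piecewise-linear approximation of the logistic function, often called a hard sigmoid. The sketch below uses one common definition, clip(0.2x + 0.5, 0, 1); which variant the author intends is not stated, so the slope and offset are assumptions.

```python
import math


def sigmoid(x):
    """Exact logistic function, for comparison."""
    return 1.0 / (1.0 + math.exp(-x))


def hard_sigmoid(x):
    """Piecewise-linear approximation: clip(0.2 * x + 0.5, 0, 1)."""
    return max(0.0, min(1.0, 0.2 * x + 0.5))


# Compare the two over a few sample points.
for x in (-4.0, -1.0, 0.0, 1.0, 4.0):
    print(f"{x:5.1f}  exact={sigmoid(x):.3f}  hard={hard_sigmoid(x):.3f}")
```

The approximation avoids the call to `exp`, at the cost of some error away from zero and a zero gradient once the input saturates.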

lstm-rnn's People

Contributors

tompntn

lstm-rnn's Issues

Wrong variable?

Hello, I am going through your implementation and was wondering why you are reading the same variable as on the previous line.

prev_c = self.state['h'][t - 1] if t > 0 else first_prev_c

Shouldn't it be:

prev_c = self.state['c'][t - 1] if t > 0 else first_prev_c

Thanks for your response.
