
Comments (4)

aymericdamien commented on April 28, 2024

Yes, BasicLSTMCell requires the state parameters to be provided together, for computational efficiency. During the calculation they are split apart:

# Parameters of gates are concatenated into one multiply for efficiency.
c, h = tf.split(1, 2, state)

Note: this is specific to the LSTM cell; for a simple RNN or GRU cell you don't need to provide 2 * num_units, just num_units.
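
As a concrete illustration of the shapes involved, here is a minimal NumPy sketch of the initial states you would feed each kind of cell (batch_size and n_hidden are illustrative, not taken from the example; this only mirrors the array shapes, not the TF graph):

import numpy as np

batch_size, n_hidden = 128, 32

# BasicLSTMCell keeps c and h concatenated, so its state is 2 * n_hidden wide
lstm_init_state = np.zeros((batch_size, 2 * n_hidden), dtype=np.float32)

# A GRU or simple RNN cell has a single hidden vector, so just n_hidden
gru_init_state = np.zeros((batch_size, n_hidden), dtype=np.float32)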

For more info, you can have a look at how BasicLSTMCell is defined in rnn_cell.py:

class BasicLSTMCell(RNNCell):
  """Basic LSTM recurrent network cell.

  The implementation is based on: http://arxiv.org/pdf/1409.2329v5.pdf.

  It does not allow cell clipping, a projection layer, and does not
  use peep-hole connections: it is the basic baseline.

  Biases of the forget gate are initialized by default to 1 in order to reduce
  the scale of forgetting in the beginning of the training.
  """

  def __init__(self, num_units, forget_bias=1.0):
    self._num_units = num_units
    self._forget_bias = forget_bias

  @property
  def input_size(self):
    return self._num_units

  @property
  def output_size(self):
    return self._num_units

  @property
  def state_size(self):
    return 2 * self._num_units

  def __call__(self, inputs, state, scope=None):
    """Long short-term memory cell (LSTM)."""
    with tf.variable_scope(scope or type(self).__name__):  # "BasicLSTMCell"
      # Parameters of gates are concatenated into one multiply for efficiency.
      c, h = tf.split(1, 2, state)
      concat = linear.linear([inputs, h], 4 * self._num_units, True)

      # i = input_gate, j = new_input, f = forget_gate, o = output_gate
      i, j, f, o = tf.split(1, 4, concat)

      new_c = c * tf.sigmoid(f + self._forget_bias) + tf.sigmoid(i) * tf.tanh(j)
      new_h = tf.tanh(new_c) * tf.sigmoid(o)

    return new_h, tf.concat(1, [new_c, new_h])
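
To see what the two tf.split calls above are doing, here is a rough NumPy sketch of the same slicing (shapes are illustrative; this mirrors the array layout only, not the TF graph):

import numpy as np

batch_size, num_units = 4, 8

# the state carries c and h side by side: shape (batch_size, 2 * num_units)
state = np.random.randn(batch_size, 2 * num_units).astype(np.float32)
c, h = np.split(state, 2, axis=1)           # same layout as tf.split(1, 2, state)

# one big matmul produces all four gate pre-activations at once
concat = np.random.randn(batch_size, 4 * num_units).astype(np.float32)
i, j, f, o = np.split(concat, 4, axis=1)    # same layout as tf.split(1, 4, concat)

print(c.shape, h.shape, i.shape)            # (4, 8) (4, 8) (4, 8)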


yanghoonkim commented on April 28, 2024

I have one more question: why do you do some pre-processing on the input data batch?

# input shape: (batch_size, n_steps, n_input)
_X = tf.transpose(_X, [1, 0, 2])  # permute n_steps and batch_size
# Reshape to prepare input to hidden activation
_X = tf.reshape(_X, [-1, n_input])  # (n_steps*batch_size, n_input)
# Linear activation
_X = tf.matmul(_X, _weights['hidden']) + _biases['hidden']

# Define a lstm cell with tensorflow
lstm_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
# Split data because rnn cell needs a list of inputs for the RNN inner loop
_X = tf.split(0, n_steps, _X)  # n_steps * (batch_size, n_hidden)

In the example provided by Google (see ptb_word_lm.py in tensorflow/models/rnn/ptb/), the input is just:

self._input_data = tf.placeholder(tf.int32, [batch_size, num_steps])

and then

with tf.device("/cpu:0"):
  embedding = tf.get_variable("embedding", [vocab_size, size])
  inputs = tf.nn.embedding_lookup(embedding, self._input_data)

they do nothing but the embedding lookup.

Can you tell me the difference?

Your help would be appreciated.


aymericdamien commented on April 28, 2024

In this case, we are classifying the MNIST dataset using an RNN, by treating every row of image pixels as one sequence step. Since MNIST images are 28x28, every image becomes a sequence of 28 steps with 28 inputs per step.
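
To make the shape bookkeeping in the snippet you quoted concrete, here is a NumPy sketch of the same transpose/reshape/split (sizes are the MNIST defaults from the example; this only mirrors the shapes, not the actual TF graph):

import numpy as np

batch_size, n_steps, n_input, n_hidden = 128, 28, 28, 64

# a batch of images, one pixel row per time step: (batch_size, n_steps, n_input)
X = np.random.rand(batch_size, n_steps, n_input).astype(np.float32)

X = X.transpose(1, 0, 2)              # like tf.transpose(_X, [1, 0, 2]) -> (n_steps, batch_size, n_input)
X = X.reshape(-1, n_input)            # like tf.reshape(_X, [-1, n_input]) -> (n_steps*batch_size, n_input)

W = np.random.randn(n_input, n_hidden).astype(np.float32)
b = np.zeros(n_hidden, dtype=np.float32)
X = X.dot(W) + b                      # hidden projection -> (n_steps*batch_size, n_hidden)

steps = np.split(X, n_steps, axis=0)  # like tf.split(0, n_steps, _X)
print(len(steps), steps[0].shape)     # 28 (128, 64)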

Here we do not need an embedding, because we can already "compare" pixels directly (every sequence element represents a gray-scale pixel intensity). But in the TensorFlow PTB example, they need an embedding, because their sequence is a list of integers (word indices in a dictionary), and you can't really "compare" those directly (they are just IDs). In practice, embeddings can greatly improve performance in NLP.

The word IDs will be embedded into a dense representation (see the Vector Representations Tutorial) before feeding to the LSTM. This allows the model to efficiently represent the knowledge about particular words. It is also easy to write:

Extracted from https://www.tensorflow.org/versions/master/tutorials/recurrent/index.html#lstm
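
The code snippet that follows that sentence in the tutorial is not reproduced above; as a rough NumPy illustration, an embedding lookup such as tf.nn.embedding_lookup is just row indexing into a trainable matrix (names and sizes here are illustrative, not from the PTB example):

import numpy as np

vocab_size, embed_size = 10000, 200

# one dense, trainable vector per word ID
embedding = np.random.randn(vocab_size, embed_size).astype(np.float32)

word_ids = np.array([[12, 7, 431], [5, 5, 9]])  # (batch_size, num_steps) of integer IDs
inputs = embedding[word_ids]                    # (batch_size, num_steps, embed_size)
print(inputs.shape)                             # (2, 3, 200)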


yanghoonkim commented on April 28, 2024

Thank you for your prompt reply.

