Comments (3)
Yes, you can change the batch size at any time, during training or testing. For simplicity, you can use the master version of TensorFlow, where you don't need to provide seq_length, so you can directly use:
def BiRNN(_X, _istate_fw, _istate_bw, _weights, _biases):
    # input shape: (batch_size, n_steps, n_input)
    _X = tf.transpose(_X, [1, 0, 2])  # permute n_steps and batch_size
    # Reshape to prepare input to hidden activation
    _X = tf.reshape(_X, [-1, n_input])  # (n_steps*batch_size, n_input)
    # Linear activation
    _X = tf.matmul(_X, _weights['hidden']) + _biases['hidden']

    # Define lstm cells with tensorflow
    # Forward direction cell
    lstm_fw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # Backward direction cell
    lstm_bw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # Split data because rnn cell needs a list of inputs for the RNN inner loop
    _X = tf.split(0, n_steps, _X)  # n_steps * (batch_size, 2*n_hidden)

    # Get lstm cell output
    outputs = rnn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, _X,
                                    initial_state_fw=_istate_fw,
                                    initial_state_bw=_istate_bw)

    # Linear activation
    # Get inner loop last output
    return tf.matmul(outputs[-1], _weights['out']) + _biases['out']

pred = BiRNN(x, istate_fw, istate_bw, weights, biases)
This function is then independent of batch_size.
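For example, since every placeholder leaves the batch dimension as None, the same graph can be fed batches of different sizes with no rebuild. A minimal sketch, assuming an open tf.Session() named sess, the mnist loader, and cost/optimizer/accuracy ops built on this batch-size-free pred the same way as in the full 0.6.0 script below:

# Train with a batch of 128...
batch_xs, batch_ys = mnist.train.next_batch(128)
sess.run(optimizer, feed_dict={x: batch_xs.reshape((128, n_steps, n_input)),
                               y: batch_ys,
                               istate_fw: np.zeros((128, 2*n_hidden)),
                               istate_bw: np.zeros((128, 2*n_hidden))})

# ...then evaluate with a batch of 256 on the very same graph
test_xs = mnist.test.images[:256].reshape((256, n_steps, n_input))
acc = sess.run(accuracy, feed_dict={x: test_xs,
                                    y: mnist.test.labels[:256],
                                    istate_fw: np.zeros((256, 2*n_hidden)),
                                    istate_bw: np.zeros((256, 2*n_hidden))})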
If you are using the 0.6.0 version and want to change the batch size, you need to modify the function a little and provide batch_size as a placeholder:
# Import MNIST data
import input_data
mnist = input_data.read_data_sets("/mnist/", one_hot=True)
import tensorflow as tf
from tensorflow.python.ops.constant_op import constant
from tensorflow.models.rnn import rnn, rnn_cell
import numpy as np
'''
To classify images using a bidirectional recurrent neural network, we consider every image row as a sequence of pixels.
Because MNIST image shape is 28*28px, we will then handle 28 sequences of 28 steps for every sample.
'''
# Parameters
learning_rate = 0.001
training_iters = 100000
batch_size = tf.placeholder(tf.int32, shape=[1])  # fed as a 1-element list, e.g. [128]
display_step = 10
# Network Parameters
n_input = 28 # MNIST data input (img shape: 28*28)
n_steps = 28 # timesteps
n_hidden = 128 # hidden layer num of features
n_classes = 10 # MNIST total classes (0-9 digits)
# tf Graph input
x = tf.placeholder("float", [None, n_steps, n_input])
# Tensorflow LSTM cell requires 2x n_hidden length (state & cell)
istate_fw = tf.placeholder("float", [None, 2*n_hidden])
istate_bw = tf.placeholder("float", [None, 2*n_hidden])
y = tf.placeholder("float", [None, n_classes])
# Define weights
weights = {
    # Hidden layer weights => 2*n_hidden because of forward + backward cells
    'hidden': tf.Variable(tf.random_normal([n_input, 2*n_hidden])),
    'out': tf.Variable(tf.random_normal([2*n_hidden, n_classes]))
}
biases = {
    'hidden': tf.Variable(tf.random_normal([2*n_hidden])),
    'out': tf.Variable(tf.random_normal([n_classes]))
}
def BiRNN(_X, _istate_fw, _istate_bw, _weights, _biases, _batch_size, _seq_len):
    # BiRNN requires sequence_length to be supplied as an int64 vector of size batch_size
    # Note: Tensorflow 0.6.0 requires the BiRNN sequence_length parameter to be set
    # For a better implementation with the latest version of tensorflow, see the first snippet above
    _seq_len = tf.fill(_batch_size, constant(_seq_len, dtype=tf.int32))
    _seq_len = tf.cast(_seq_len, dtype=tf.int64)  # seq_len needs type int64

    # input shape: (batch_size, n_steps, n_input)
    _X = tf.transpose(_X, [1, 0, 2])  # permute n_steps and batch_size
    # Reshape to prepare input to hidden activation
    _X = tf.reshape(_X, [-1, n_input])  # (n_steps*batch_size, n_input)
    # Linear activation
    _X = tf.matmul(_X, _weights['hidden']) + _biases['hidden']

    # Define lstm cells with tensorflow
    # Forward direction cell
    lstm_fw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # Backward direction cell
    lstm_bw_cell = rnn_cell.BasicLSTMCell(n_hidden, forget_bias=1.0)

    # Split data because rnn cell needs a list of inputs for the RNN inner loop
    _X = tf.split(0, n_steps, _X)  # n_steps * (batch_size, 2*n_hidden)

    # Get lstm cell output
    outputs = rnn.bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, _X,
                                    initial_state_fw=_istate_fw,
                                    initial_state_bw=_istate_bw,
                                    sequence_length=_seq_len)

    # Linear activation
    # Get inner loop last output
    return tf.matmul(outputs[-1], _weights['out']) + _biases['out']
pred = BiRNN(x, istate_fw, istate_bw, weights, biases, batch_size, n_steps)
# Define loss and optimizer
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(pred, y)) # Softmax loss
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost) # Adam Optimizer
# Evaluate model
correct_pred = tf.equal(tf.argmax(pred,1), tf.argmax(y,1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
# Initializing the variables
init = tf.initialize_all_variables()
# Launch the graph
with tf.Session() as sess:
    tr_batch_size = 128
    sess.run(init)
    step = 1
    # Keep training until reach max iterations
    while step * tr_batch_size < training_iters:
        batch_xs, batch_ys = mnist.train.next_batch(tr_batch_size)
        # Reshape data to get 28 seq of 28 elements
        batch_xs = batch_xs.reshape((tr_batch_size, n_steps, n_input))
        # Fit training using batch data
        sess.run(optimizer, feed_dict={x: batch_xs, y: batch_ys,
                                       istate_fw: np.zeros((tr_batch_size, 2*n_hidden)),
                                       istate_bw: np.zeros((tr_batch_size, 2*n_hidden)),
                                       batch_size: [tr_batch_size]})
        if step % display_step == 0:
            # Calculate batch accuracy
            acc = sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys,
                                                istate_fw: np.zeros((tr_batch_size, 2*n_hidden)),
                                                istate_bw: np.zeros((tr_batch_size, 2*n_hidden)),
                                                batch_size: [tr_batch_size]})
            # Calculate batch loss
            loss = sess.run(cost, feed_dict={x: batch_xs, y: batch_ys,
                                             istate_fw: np.zeros((tr_batch_size, 2*n_hidden)),
                                             istate_bw: np.zeros((tr_batch_size, 2*n_hidden)),
                                             batch_size: [tr_batch_size]})
            print "Iter " + str(step*tr_batch_size) + ", Minibatch Loss= " + "{:.6f}".format(loss) + \
                  ", Training Accuracy= " + "{:.5f}".format(acc)
        step += 1
    print "Optimization Finished!"

    # Calculate accuracy for 256 mnist test images
    te_batch_size = 256
    test_data = mnist.test.images[:te_batch_size].reshape((-1, n_steps, n_input))
    test_label = mnist.test.labels[:te_batch_size]
    print "Testing Accuracy:", sess.run(accuracy, feed_dict={x: test_data, y: test_label,
                                                             istate_fw: np.zeros((te_batch_size, 2*n_hidden)),
                                                             istate_bw: np.zeros((te_batch_size, 2*n_hidden)),
                                                             batch_size: [te_batch_size]})
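Note that batch_size is fed as a one-element list ([tr_batch_size], [te_batch_size]) rather than a bare scalar: tf.fill expects its dims argument to be a 1-D vector, and here the placeholder plays the role of that shape vector when building the seq_len tensor. Declaring the placeholder with shape=[1], as above, just makes this explicit.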
Thanks Aymeric. Your suggestions really helped me. I was trying to make batch_size a placeholder but couldn't figure out how it worked. Combined with the early-stop technique for handling variable-length input, I think the bidirectional RNN module is now complete. Thanks for the efforts.
Glad to hear :) I'll close the issue then.