
Comments (6)

whm1006 commented on June 18, 2024

Also, on tf version 1.0.0, the shape of the input X seems wrong:
according to the documentation for tf.nn.dynamic_rnn(cell, inputs), inputs should have shape [batch_size, time_steps, ...],
but the X produced by the generate function has shape [batch_size, 1, 10]. Shouldn't time_steps=10 be in the second dimension?
X's shape should be [batch_size, 10, 1], right?
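The shape difference can be checked with a small numpy sketch (the sine sequence and its length are illustrative assumptions, not the tutorial's exact data):

```python
import numpy as np

TIMESTEPS = 10
seq = np.sin(np.linspace(0, 10, 100))  # stand-in input sequence

# Original generate function: each sample wraps one window of 10 values
# as a single "time step" of 10 features -> shape (N, 1, 10)
X_orig = np.array([[seq[i:i + TIMESTEPS]]
                   for i in range(len(seq) - TIMESTEPS - 1)])

# Proposed fix: each sample is 10 time steps of 1 feature -> shape (N, 10, 1)
X_fixed = np.array([[[x] for x in seq[i:i + TIMESTEPS]]
                    for i in range(len(seq) - TIMESTEPS - 1)])

print(X_orig.shape)   # (89, 1, 10)
print(X_fixed.shape)  # (89, 10, 1)
```

With the original layout, dynamic_rnn sees a sequence of length 1 with 10 input features per step; with the fixed layout it sees 10 steps of a scalar feature, which matches the [batch_size, time_steps, ...] contract.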

from tensorflow-tutorial.

Hailprob commented on June 18, 2024

@wanghaoming1006 Hi, I have the same question. Did you manage to fix it later?

whm1006 commented on June 18, 2024

@Hailprob Not yet; I'll share once I understand it.

forqzy commented on June 18, 2024

@wanghaoming1006
@Hailprob
This one took me a long time to figure out as well. Referring to another example, I rewrote it in Keras; I hope it helps:
https://github.com/forqzy/keras-LSTM-predict-study

whm1006 commented on June 18, 2024

@forqzy Thanks for sharing.

whm1006 commented on June 18, 2024

On the HIDDEN_SIZE question: after switching machines, changing HIDDEN_SIZE no longer raises an error, so it was probably a configuration issue.

On the X shape question in version 1.0.0: following Mourad's original post, I believe the input shape should be [batch_size, 10, 1]. After changing the input X in the source code to [batch_size, 10, 1], I get the same prediction results:

import numpy as np
import tensorflow as tf

def generate_data(seq):
    X = []
    y = []
    for i in range(len(seq) - TIMESTEPS - 1):
        # changed here: each sample is TIMESTEPS steps of 1 feature,
        # so X gets shape [batch_size, 10, 1]
        X.append([[x] for x in seq[i: i + TIMESTEPS]])
        y.append([seq[i + TIMESTEPS]])
    return np.array(X, dtype=np.float32), np.array(y, dtype=np.float32)

def lstm_model(X, y):
    # build a distinct cell per layer; reusing one cell object makes the
    # layers share weights (and raises an error in TF >= 1.1)
    cell = tf.contrib.rnn.MultiRNNCell(
        [tf.contrib.rnn.BasicLSTMCell(HIDDEN_SIZE, state_is_tuple=True)
         for _ in range(NUM_LAYERS)])

    output, _ = tf.nn.dynamic_rnn(cell, X, dtype=tf.float32)
    output = tf.unstack(output, axis=1)[-1]  # changed here: keep only the last time step's output

    # linear regression through a fully connected layer with no activation,
    # then squeeze the result down to a 1-D array
    predictions = tf.contrib.layers.fully_connected(output, 1, None)

    # reshape predictions and labels to the same flat shape
    labels = tf.reshape(y, [-1])
    predictions = tf.reshape(predictions, [-1])

    loss = tf.losses.mean_squared_error(labels, predictions)

    train_op = tf.contrib.layers.optimize_loss(
        loss, tf.contrib.framework.get_global_step(),
        optimizer="Adagrad", learning_rate=0.1)

    return predictions, loss, train_op

Corrections welcome if I've misunderstood anything.
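As a sanity check on the unstack step above, here is a small numpy sketch (batch and size values are arbitrary) of what `tf.unstack(output, axis=1)[-1]` selects from dynamic_rnn's `[batch, time, hidden]` output tensor:

```python
import numpy as np

BATCH, TIMESTEPS, HIDDEN_SIZE = 4, 10, 30
# stand-in for the dynamic_rnn output tensor: (batch, time, hidden)
output = np.random.rand(BATCH, TIMESTEPS, HIDDEN_SIZE)

# tf.unstack(output, axis=1) yields TIMESTEPS tensors of shape (batch, hidden);
# taking [-1] keeps only the final time step, i.e. output[:, -1, :]
last = output[:, -1, :]
print(last.shape)  # (4, 30)
```

Only the last step's hidden state is fed to the fully connected layer, since the regression target is the single value following the window.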
