Comments (12)

jingweiz commented on July 29, 2024

@AjayTalati
You said you also have an implementation, so I guessed you also used cumprod for computing the allocation weights. For me, the backward pass simply fails with cumprod. Do you maybe have an idea why?
Thanks a lot!
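
A minimal repro of the kind of failure described, assuming a PyTorch version from that era (pre-0.4 Variable API), where autograd for cumprod was not yet implemented:

import torch
from torch.autograd import Variable

x = Variable(torch.rand(3, 4), requires_grad=True)
y = torch.cumprod(x, 1)   # forward runs on the Variable
y.sum().backward()        # reportedly failed at the time: no backward for cumprod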

jingweiz commented on July 29, 2024

Ah, I just found that you have a custom cumprod function. But if I'm reading it correctly, it treats the input only as a tensor, not as a Variable, so all the computation graph on the Variable will be lost after this function, right? Why does this one not need to be in the graph? Maybe there's something obvious that I'm missing?
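
A minimal sketch of the concern being raised here, using the old (pre-0.4) Variable API: rebuilding a Variable from raw tensor data produces a fresh leaf with no history, so gradients cannot flow back through it:

import torch
from torch.autograd import Variable

x = Variable(torch.rand(3, 4), requires_grad=True)
h = x * 2                  # h carries history back to x
h2 = Variable(h.data)      # built from the raw tensor: a fresh leaf, history lost
print(h2.creator is None)  # True: no link back to x (creator was later renamed grad_fn)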

AjayTalati commented on July 29, 2024

Hi @jingweiz

For me, the backward pass simply fails with cumprod.

Can you post your error messages?

ypxie commented on July 29, 2024

If the input is a tensor, it will call torch.cumprod. But if the input is a Variable, it will call the custom cumprod, which is written using autograd torch functions. So it should work just fine.
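
A sketch of the dispatch being described; the names cumprod_any and custom_cumprod are illustrative stand-ins, not the repo's actual code:

import torch
from torch.autograd import Variable

def custom_cumprod(x, dim):
    # differentiable cumprod built from ordinary autograd ops, so the graph is kept
    slices = [x.narrow(dim, 0, 1)]
    for i in range(1, x.size(dim)):
        slices.append(slices[-1] * x.narrow(dim, i, 1))
    return torch.cat(slices, dim)

def cumprod_any(x, dim=1):
    if isinstance(x, Variable):
        return custom_cumprod(x, dim)  # Variable path: stay inside autograd
    return torch.cumprod(x, dim)       # plain tensor path: no graph needed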

jingweiz commented on July 29, 2024

@ypxie
Hey, thanks a lot for the reply! Can you explain further why it is a leaf node? When we feed in a sequence row by row, doesn't the current usage_t depend on usage_{t-1}? I'm currently under the impression that all those intermediate calculations should be carried out on Variables so as to keep the computation history for the complete sequence, and that those Variables are only reset and detached from the graph when a new sequence comes in. Or what is your understanding?
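
The pattern being described, sketched with the old Variable API; update_usage, w, and the sizes here are hypothetical stand-ins for the model's actual usage update:

import torch
from torch.autograd import Variable

batch_size, mem_slots, seq_len = 2, 4, 5                  # assumed sizes
w = Variable(torch.rand(1, mem_slots), requires_grad=True)  # a hypothetical parameter

def update_usage(usage, x_t):  # hypothetical stand-in for the model's usage update
    return usage * 0.9 + x_t * w.expand_as(x_t)

sequence = [Variable(torch.rand(batch_size, mem_slots)) for _ in range(seq_len)]

usage = Variable(torch.zeros(batch_size, mem_slots))
for x_t in sequence:
    usage = update_usage(usage, x_t)  # usage_t depends on usage_{t-1}: history is kept
usage.sum().backward()                # w.grad accumulates contributions from every step

usage = Variable(usage.data)          # detach only when a new sequence starts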

jingweiz commented on July 29, 2024

@AjayTalati
Hey, it's not an error from this repo; it happens when I use the current cumprod from pytorch, and it's because autograd has not been implemented for this op yet: https://discuss.pytorch.org/t/cumprod-exclusive-true-equivalences/2614/8 and https://github.com/pytorch/pytorch/pull/1439

ypxie commented on July 29, 2024

For example, index_mapper is a leaf node, so it doesn't matter that it is constructed online from a numpy array during the process.
For the gradient of cumprod, please take a look at the updated response.
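
A small illustration of the leaf-node point, under the old Variable API; the shape of index_mapper here is an assumption:

import numpy as np
import torch
from torch.autograd import Variable

# a constant built on the fly is a leaf: it has no upstream history, so
# creating it from numpy inside the forward pass breaks nothing upstream
index_mapper = Variable(torch.from_numpy(np.arange(8)))
print(index_mapper.requires_grad)  # False: no gradient needs to flow into it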

jingweiz commented on July 29, 2024

@ypxie
I mean that the allocation_weight depends on the usage, the write_weight depends on the allocation_weight, the memory is updated with the write_weight, and usage_{t} is also calculated from usage_{t-1}. If the usage is somehow dropped from the graph, then all the history computations it carried will be lost, and the gradient would not pass back to the older time steps. In that case the backward would still work, but the earlier part of the graph won't get its gradients, because the connection is lost at the detached usage. What do you think?
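
A quick check of that claim with the old Variable API: backward still runs after a detach, but a parameter used before the detach point gets no gradient (w here is a hypothetical parameter from an earlier step):

import torch
from torch.autograd import Variable

w = Variable(torch.rand(4, 4), requires_grad=True)  # used at an early time step
u1 = w.mv(Variable(torch.ones(4)))                  # step 1: depends on w
u1 = Variable(u1.data, requires_grad=True)          # usage dropped from the graph here
u2 = (u1 * 2).sum()                                 # later steps
u2.backward()                                       # runs fine...
print(w.grad)                                       # ...but w never receives a gradient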

jingweiz commented on July 29, 2024

So my solution is: instead of just using the data from usage and dropping the Variable, I implement a fake_cumprod using existing torch ops, so that the backward still works:

import torch
from torch.autograd import Variable

def fake_cumprod(vb):
    """
    args:
        vb:  [hei x wid] Variable
          -> NOTE: we are lazy here so now it only supports cumprod along wid
    """
    # real_cumprod = torch.cumprod(vb.data, 1)
    vb = vb.unsqueeze(0)
    # mask i keeps the first i+1 columns and fills the rest with 1, so a
    # product along the last dim yields the cumulative product up to column i
    mul_mask_vb = Variable(torch.zeros(vb.size(2), vb.size(1), vb.size(2))).type_as(vb)
    for i in range(vb.size(2)):
        mul_mask_vb[i, :, :i+1] = 1
    add_mask_vb = 1 - mul_mask_vb
    vb = vb.expand_as(mul_mask_vb) * mul_mask_vb + add_mask_vb
    # torch.prod kept the reduced dim in 0.1.x, hence the transpose back
    vb = torch.prod(vb, 2).transpose(0, 2)
    # print(real_cumprod - vb.data) # NOTE: checked, ==0
    return vb
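
A quick sanity check of the kind hinted at by the commented-out lines, assuming the pre-0.4 Variable API (note the extra leading dim of size 1, since torch.prod kept the reduced dim back then):

x = Variable(torch.rand(3, 5), requires_grad=True)
y = fake_cumprod(x)  # [1 x 3 x 5]
print((y.data.squeeze(0) - torch.cumprod(x.data, 1)).abs().max())  # ~0
y.sum().backward()   # backward works: the graph was kept
print(x.grad)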

ypxie commented on July 29, 2024

The custom cumprod should work well with autograd, because it is implemented based on autograd torch functions.

jingweiz commented on July 29, 2024

@ypxie
Hey, thanks for the quick reply:)
What I'm confused about is line 98 in utils.py:

        output = Variable(inputs.data.new(*shape_).fill_(1.0), requires_grad = True)

Here the inputs is detached from the graph, right? According to this: https://discuss.pytorch.org/t/help-clarifying-repackage-hidden-in-word-language-model/226/2?u=apaszke
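
For reference, a minimal illustration of why that line detaches, with the old Variable API: tensor.new() allocates a fresh tensor of the same type, and wrapping it in a new Variable yields a leaf with no link back to inputs:

import torch
from torch.autograd import Variable

inputs = Variable(torch.rand(2, 3), requires_grad=True)
output = Variable(inputs.data.new(2, 3).fill_(1.0), requires_grad=True)
# output is a brand-new leaf: gradients flowing into it never reach inputs
print(output.creator)  # None (the attribute was later renamed grad_fn)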

ypxie commented on July 29, 2024

You can delete this line; it is not used in what follows. I will remove it.
