pytorch-ewc's People

Contributors

kuc2477

pytorch-ewc's Issues

labels are shared for every task

Great implementation, many thanks.

I'm new to continual learning and a little confused about its problem setting. I see that you generate differently permuted MNIST data for every task (or session), but it looks like the labels of all 8 tasks are shared (0~9). My understanding is that continual learning is about learning a new task without forgetting old tasks. Should we give new labels (e.g., 11) to a new task? I appreciate your help.

For example, see Figure 2(e) of this paper: https://arxiv.org/pdf/1606.09282.pdf
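The two settings this question contrasts can be sketched on toy data. This is a minimal illustration, not the repository's data pipeline: `make_permuted_tasks` and `offset_labels` are hypothetical helpers, and the small lists stand in for MNIST images. The first builds tasks that share one label set but see differently permuted pixels, as described above; the second shows the alternative the question proposes, where each task gets its own label range.

```python
import random

def make_permuted_tasks(images, labels, num_tasks, seed=0):
    """Build tasks that share labels but see differently permuted pixels."""
    rng = random.Random(seed)
    num_pixels = len(images[0])
    tasks = []
    for _ in range(num_tasks):
        # One fixed pixel permutation per task, applied to every image.
        perm = list(range(num_pixels))
        rng.shuffle(perm)
        permuted = [[img[j] for j in perm] for img in images]
        # Labels are reused unchanged, so every task has the same label set.
        tasks.append((permuted, list(labels)))
    return tasks

def offset_labels(tasks, classes_per_task):
    """Variant with new labels per task: shift each task into its own range."""
    return [
        (xs, [y + t * classes_per_task for y in ys])
        for t, (xs, ys) in enumerate(tasks)
    ]

# Toy stand-in for MNIST: 4 "images" of 6 pixels, labels in {0, 1}.
images = [[0, 1, 2, 3, 4, 5], [5, 4, 3, 2, 1, 0],
          [1, 1, 0, 0, 2, 2], [9, 8, 7, 6, 5, 4]]
labels = [0, 1, 0, 1]

shared = make_permuted_tasks(images, labels, num_tasks=3)
disjoint = offset_labels(shared, classes_per_task=2)
```

In `shared`, every task keeps the labels `[0, 1, 0, 1]` and only the input distribution changes. In `disjoint`, the third task's labels become `[4, 5, 4, 5]`, so the output layer would have to grow with each task.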

Fisher update

I am trying to run EWC on my own dataset with a ResNet-50 model. While updating the Fisher matrix with your function, my code fails with a CUDA out-of-memory error at "log_liklihoods.append(output[:, target])". I read "https://stackoverflow.com/questions/59805901/unable-to-allocate-gpu-memory-when-there-is-enough-of-cached-memory" and worked around the memory problem with 'detach()'. After adding the detach, I get a new error: RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
To get past this, I set "allow_unused=True" in the autograd call. As a result, all my gradients are 0. Why is this happening?
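One plausible cause, offered as a guess because the modified code is not shown: `detach()` cuts the autograd graph, so every parameter upstream of the detached tensor has no path to the value being differentiated. The sketch below reproduces that behaviour on a hypothetical two-parameter example (`w1`, `w2`, and `hidden` are illustrative names, not taken from the repository).

```python
import torch
from torch import autograd

x = torch.tensor([1.0, 2.0, 3.0])
w1 = torch.ones(3, requires_grad=True)   # upstream parameter
w2 = torch.ones(3, requires_grad=True)   # downstream parameter

hidden = w1 * x
# detach() cuts the graph here: nothing upstream of `hidden` is reachable.
loss = (hidden.detach() * w2).sum()

try:
    autograd.grad(loss, [w1, w2], retain_graph=True)
    raised = False
except RuntimeError:
    # w1 has no path to `loss`, so autograd refuses by default.
    raised = True

# allow_unused=True silences the error by returning None for w1.
g1, g2 = autograd.grad(loss, [w1, w2], allow_unused=True)
```

Here `g1` is `None` while `g2` is a real gradient. If the detached tensor sits after all of the model's parameters, every entry comes back as `None`, which looks like "all gradients are 0" once the `None` values are replaced by zeros. Under that reading, `allow_unused=True` hides the symptom but does not restore the gradients.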

Performance improvement

I found that performing forward-backward in a loop is much faster than using autograd.grad with retain_graph=True. Your current code is:

        loglikelihood_grads = zip(*[autograd.grad(
            l, self.parameters(),
            retain_graph=(i < len(loglikelihoods))
        ) for i, l in enumerate(loglikelihoods, 1)])

https://github.com/kuc2477/pytorch-ewc/blob/master/model.py#L75

After I change it to:

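        grads = {}  # accumulated squared gradients, keyed by parameter name
        for batch, label in zip(buffer_data, buffer_label):
            self.zero_grad()
            # cross_entropy is the negative log-likelihood; the sign vanishes once squared
            loglikelihood = F.cross_entropy(self(batch), label)
            loglikelihood.backward()
            for n, p in self.named_parameters():
                n = n.replace('.', '__')
                grads[n] = grads.get(n, 0) + p.grad ** 2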

it runs much faster. Can you please investigate this?

Thank you.
Thanh Tung
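The two approaches in this issue can be compared on a toy model to check that they agree. This is a minimal sketch under stated assumptions: a small `nn.Linear` stands in for the real network, the data is random, and each loop step processes a single sample. It is not a benchmark and makes no claim about speed.

```python
import torch
import torch.nn.functional as F
from torch import autograd, nn

torch.manual_seed(0)
model = nn.Linear(4, 3)            # hypothetical stand-in for the real model
xs = torch.randn(5, 4)
ys = torch.randint(0, 3, (5,))
params = list(model.parameters())

# Approach 1: keep the whole graph alive and call autograd.grad per sample.
log_probs = F.log_softmax(model(xs), dim=1)
loglikelihoods = log_probs[range(len(ys)), ys]
per_sample = [autograd.grad(l, params, retain_graph=True)
              for l in loglikelihoods]
fisher_grad = [torch.stack(gs).pow(2).mean(0) for gs in zip(*per_sample)]

# Approach 2: one forward-backward per sample, accumulating squared gradients.
fisher_loop = [torch.zeros_like(p) for p in params]
for x, y in zip(xs, ys):
    model.zero_grad()
    # cross_entropy is the negative log-likelihood; squaring removes the sign.
    F.cross_entropy(model(x.unsqueeze(0)), y.unsqueeze(0)).backward()
    for acc, p in zip(fisher_loop, params):
        acc += p.grad ** 2 / len(xs)
```

With one sample per step, `fisher_grad` and `fisher_loop` match up to floating-point error. If each step instead processes a batch larger than one, the loop squares the gradient of the batch-mean loss, which is not the mean of the per-sample squared gradients, so the two estimates would differ.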

The size of fisher matrix is wrong

Line 80 of model.py:

        fisher_diagonals = [(g ** 2).mean() for g in loglikelihood_grads]

should be:

        fisher_diagonals = [(g ** 2).mean(dim=0) for g in loglikelihood_grads]
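The shape difference between the two calls can be seen on a dummy tensor. This sketch assumes each `g` stacks per-sample gradients along dimension 0, which the issue implies but does not show; the tensor below is a hypothetical stand-in for the gradients of one 3x4 weight over 5 samples.

```python
import torch

# Hypothetical stand-in: per-sample gradients of a 3x4 weight, 5 samples.
g = torch.arange(60, dtype=torch.float32).reshape(5, 3, 4)

collapsed = (g ** 2).mean()        # averages over samples AND weight entries
diagonal = (g ** 2).mean(dim=0)    # averages over samples only

print(collapsed.shape)  # torch.Size([])
print(diagonal.shape)   # torch.Size([3, 4])
```

Without `dim=0` the result is a single scalar per parameter tensor, so every weight would receive the same importance. With `dim=0` the result has the parameter's own shape, giving one Fisher value per weight.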

May I ask some questions about training?

When we train on the current task, do we use the data from the previous task? EWC needs task A data to compute the Fisher information, so when we train on task B, how can EWC constrain the change in the model's parameters?
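In the usual EWC formulation, task A data is needed only once, right after training on task A, to estimate the Fisher diagonal. What carries over to task B is that estimate plus a snapshot of the task A weights, combined into a quadratic penalty of the form (lambda / 2) * sum_i F_i * (theta_i - theta_A_i)^2. The sketch below illustrates this under stated assumptions: `ewc_penalty`, `anchor`, `fisher`, and the `lamda` value are hypothetical names and choices, the model is a toy `nn.Linear`, and the Fisher values are placeholders. It is not necessarily how this repository structures its training loop.

```python
import torch
import torch.nn.functional as F
from torch import nn

def ewc_penalty(model, fisher, anchor, lamda):
    """Quadratic penalty: (lamda / 2) * sum_i F_i * (theta_i - theta_A_i)^2."""
    total = torch.zeros(())
    for name, p in model.named_parameters():
        total = total + (fisher[name] * (p - anchor[name]) ** 2).sum()
    return 0.5 * lamda * total

torch.manual_seed(0)
model = nn.Linear(4, 3)

# After task A: snapshot the weights and store a Fisher diagonal estimate.
# Ones are a placeholder; a real estimate is computed from task A data.
anchor = {n: p.detach().clone() for n, p in model.named_parameters()}
fisher = {n: torch.ones_like(p) for n, p in model.named_parameters()}

# While training on task B: only task B data and the stored tensors are used.
x_b, y_b = torch.randn(8, 4), torch.randint(0, 3, (8,))
task_loss = F.cross_entropy(model(x_b), y_b)
loss = task_loss + ewc_penalty(model, fisher, anchor, lamda=40.0)
loss.backward()
```

The penalty is zero while the weights equal the task A snapshot and grows as they move away, fastest along directions where the stored Fisher value is large.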
