kuc2477 / pytorch-ewc
Unofficial PyTorch implementation of DeepMind's PNAS 2017 paper "Overcoming Catastrophic Forgetting"
License: MIT License
Great implementation, many thanks.
I'm new to continual learning and a little confused about its problem setting. I see that you generate different permuted MNIST data for every task (or session), but it looks like the labels of all 8 tasks are shared (0-9). My understanding is that continual learning is about learning new tasks without forgetting old ones. Should we assign new labels (e.g., 11) to a new task? I appreciate your help.
For example, check Figure 2(e) of this paper: https://arxiv.org/pdf/1606.09282.pdf
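For context, sharing the label space is intentional in the permuted-MNIST benchmark used by the EWC paper: every task is still 10-way digit classification, and only the pixel permutation changes per task. The class-incremental setting, where each new task introduces new labels, is a different protocol. A minimal sketch of how a permuted-MNIST task is typically built, with illustrative names rather than this repo's API:

import torch
from torchvision import datasets, transforms

def make_permuted_mnist(seed):
    # one fixed pixel permutation per task; labels stay 0-9 for every task
    g = torch.Generator().manual_seed(seed)
    perm = torch.randperm(28 * 28, generator=g)
    tf = transforms.Compose([
        transforms.ToTensor(),
        transforms.Lambda(lambda x: x.view(-1)[perm]),  # permute flattened pixels
    ])
    return datasets.MNIST('./data', train=True, download=True, transform=tf)

tasks = [make_permuted_mnist(seed) for seed in range(8)]  # 8 tasks, shared labels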
I am trying to run EWC on my dataset with a ResNet-50 model. While updating the Fisher matrix with your function, my code runs out of CUDA memory at "log_liklihoods.append(output[:, target])". I read https://stackoverflow.com/questions/59805901/unable-to-allocate-gpu-memory-when-there-is-enough-of-cached-memory and figured I could fix the problem with detach(). After adding detach(), I get a new error: RuntimeError: One of the differentiated Tensors appears to not have been used in the graph. Set allow_unused=True if this is the desired behavior.
To work around this, I set allow_unused=True in autograd.grad. As a result, all my gradients go to 0. Why is this happening?
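What is likely happening: detach() cuts the stored log-likelihoods out of the autograd graph, so autograd.grad no longer finds any path from them to the parameters (hence the "not used in the graph" error), and allow_unused=True then just papers over that by returning None/zero gradients. A memory-friendlier pattern, sketched below under my own assumptions (model and data_loader are illustrative names for your ResNet-50 and Fisher-sample loader, not the repo's API), is to call backward() once per batch and accumulate squared gradients, so no graph ever needs to be retained:

import torch
import torch.nn.functional as F

fisher = {n: torch.zeros_like(p) for n, p in model.named_parameters()}
n_batches = 0
for x, y in data_loader:
    model.zero_grad()
    log_probs = F.log_softmax(model(x), dim=1)
    # mean log-likelihood of the observed labels in this batch
    ll = log_probs[torch.arange(len(y)), y].mean()
    ll.backward()  # the graph for this batch is freed right here
    for n, p in model.named_parameters():
        if p.grad is not None:
            fisher[n] += p.grad.detach() ** 2
    n_batches += 1
fisher = {n: f / n_batches for n, f in fisher.items()}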
I found that performing forward-backward in a loop is much faster than using autograd.grad with retain_graph=True. Your current code is:
loglikelihood_grads = zip(*[autograd.grad(
    l, self.parameters(),
    retain_graph=(i < len(loglikelihoods))
) for i, l in enumerate(loglikelihoods, 1)])
https://github.com/kuc2477/pytorch-ewc/blob/master/model.py#L75
After I changed it to:
for batch, label in zip(buffer_data, buffer_label):
    self.zero_grad()
    loglikelihood = F.cross_entropy(self(batch), label)
    loglikelihood.backward()
    for n, p in self.named_parameters():
        n = n.replace('.', '__')
        grads[n] = grads.get(n, 0) + p.grad ** 2
it runs much faster, presumably because backward() frees each batch's graph immediately, while retain_graph=True keeps every graph alive until the final autograd.grad call. Can you please investigate this?
Thank you.
Thanh Tung
As in the code, we can get the log-likelihood via F.log_softmax for classification. How can I get the log-likelihood for a regression problem?
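For regression, a common choice (a sketch, assuming a Gaussian observation model with fixed variance; pred and target are torch tensors) is the Gaussian log-likelihood, which reduces to a negative squared error plus a constant:

import math

def gaussian_log_likelihood(pred, target, sigma=1.0):
    # log N(target | pred, sigma^2), summed over the output dimensions;
    # with sigma fixed, maximizing this is equivalent to minimizing MSE
    return (-0.5 * math.log(2 * math.pi * sigma ** 2)
            - (target - pred) ** 2 / (2 * sigma ** 2)).sum(dim=-1)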
The computation of the Fisher matrix should be the average of the squared per-sample gradients. But in your implementation, you average the loss first, then use only that averaged loss to compute the gradient and estimate the Fisher matrix.
Is this a bug, or is the math equivalent?
Check the TensorFlow version: https://github.com/ariseff/overcoming-catastrophic
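For what it's worth, the two are not equivalent in general. The Fisher diagonal is E[g_i^2], the average of the squared per-sample gradients, while squaring the gradient of the averaged log-likelihood gives (E[g_i])^2; the difference between them is exactly the variance of the per-sample gradients, E[g_i^2] - (E[g_i])^2. Near a converged solution E[g_i] is close to 0, so the gradient-of-the-mean estimator collapses toward zero even though the true Fisher diagonal generally does not.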
The code at line 80 of model.py:
fisher_diagonals = [(g ** 2).mean() for g in loglikelihood_grads]
should be:
fisher_diagonals = [(g ** 2).mean(dim=0) for g in loglikelihood_grads]
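(Presumably each g here stacks the per-sample gradients for one parameter along dim 0, so .mean() collapses everything to a single scalar per parameter tensor, while .mean(dim=0) averages over the sample dimension only and keeps one Fisher value per parameter element, which is what the diagonal approximation needs.)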
When we train the current task, do we still use data from the previous task? EWC needs task A's data to compute the Fisher information; when we then train task B, how do we use EWC to constrain the changes to the model's parameters?
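For reference, task A's data is needed only once, right after task A finishes, to estimate the Fisher diagonal and snapshot the parameters; training on task B then uses only those stored tensors. A minimal sketch of the penalty (fisher and old_params are illustrative names, not necessarily this repo's API):

import torch.nn.functional as F  # only for the usage line below

def ewc_penalty(model, fisher, old_params, lam=40.0):
    # quadratic pull back toward the task-A solution, weighted per element
    # by the Fisher diagonal estimated on task A
    penalty = 0.0
    for n, p in model.named_parameters():
        n = n.replace('.', '__')
        penalty = penalty + (fisher[n] * (p - old_params[n]) ** 2).sum()
    return (lam / 2.0) * penalty

# while training task B (no task-A data required):
# loss = F.cross_entropy(model(x_b), y_b) + ewc_penalty(model, fisher, old_params)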