katerakelly / pytorch-maml
PyTorch implementation of MAML: https://arxiv.org/abs/1703.03400
License: MIT License
Hi Kate,
With split='train' in the inner loop, the data loader always outputs the same data regardless of batch_cutoff. I know this matches the setting in the MAML paper, but is there a way to make batch_cutoff work? Many thanks!
Hi Kelly, I read the code and found that the get_data_loader function returns the same batch of samples on every iteration. Is that intended? (Sorry, I'm not familiar with the MAML algorithm.)
Below are the samples returned by the get_data_loader function.
('test train loader: ', ['../data/omniglot/images_evaluation/Kannada/character01/1205_20.png',
'../data/omniglot/images_evaluation/Manipuri/character05/1323_07.png',
'../data/omniglot/images_evaluation/Angelic/character06/0970_19.png',
'../data/omniglot/images_evaluation/ULOG/character14/1611_15.png',
'../data/omniglot/images_evaluation/Keble/character02/1247_11.png',
'../data/omniglot/images_evaluation/Kannada/character05/1209_20.png',
'../data/omniglot/images_evaluation/Sylheti/character05/1507_12.png',
'../data/omniglot/images_evaluation/Manipuri/character15/1333_01.png',
'../data/omniglot/images_evaluation/Kannada/character27/1231_08.png',
'../data/omniglot/images_evaluation/Mongolian/character30/1388_20.png',
'../data/omniglot/images_evaluation/Keble/character07/1252_12.png',
'../data/omniglot/images_evaluation/Tibetan/character05/1560_19.png',
'../data/omniglot/images_evaluation/Glagolitic/character14/1128_04.png',
'../data/omniglot/images_evaluation/Ge_ez/character06/1094_07.png',
'../data/omniglot/images_evaluation/Avesta/character05/1067_17.png',
'../data/omniglot/images_evaluation/Malayalam/character05/1276_16.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character03/1039_18.png',
'../data/omniglot/images_evaluation/Syriac_(Serto)/character14/1493_07.png',
'../data/omniglot/images_evaluation/Keble/character16/1261_09.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character13/1049_15.png'])
('test train loader: ', ['../data/omniglot/images_evaluation/Syriac_(Serto)/character14/1493_07.png',
'../data/omniglot/images_evaluation/Sylheti/character05/1507_12.png',
'../data/omniglot/images_evaluation/Keble/character16/1261_09.png',
'../data/omniglot/images_evaluation/Ge_ez/character06/1094_07.png',
'../data/omniglot/images_evaluation/Tibetan/character05/1560_19.png',
'../data/omniglot/images_evaluation/Manipuri/character05/1323_07.png',
'../data/omniglot/images_evaluation/Kannada/character05/1209_20.png',
'../data/omniglot/images_evaluation/Manipuri/character15/1333_01.png',
'../data/omniglot/images_evaluation/Avesta/character05/1067_17.png',
'../data/omniglot/images_evaluation/ULOG/character14/1611_15.png',
'../data/omniglot/images_evaluation/Glagolitic/character14/1128_04.png',
'../data/omniglot/images_evaluation/Malayalam/character05/1276_16.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character13/1049_15.png',
'../data/omniglot/images_evaluation/Kannada/character01/1205_20.png',
'../data/omniglot/images_evaluation/Keble/character02/1247_11.png',
'../data/omniglot/images_evaluation/Angelic/character06/0970_19.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character03/1039_18.png',
'../data/omniglot/images_evaluation/Mongolian/character30/1388_20.png',
'../data/omniglot/images_evaluation/Kannada/character27/1231_08.png',
'../data/omniglot/images_evaluation/Keble/character07/1252_12.png'])
('test train loader: ', ['../data/omniglot/images_evaluation/Kannada/character05/1209_20.png',
'../data/omniglot/images_evaluation/Malayalam/character05/1276_16.png',
'../data/omniglot/images_evaluation/Manipuri/character05/1323_07.png',
'../data/omniglot/images_evaluation/Tibetan/character05/1560_19.png',
'../data/omniglot/images_evaluation/Kannada/character27/1231_08.png',
'../data/omniglot/images_evaluation/Ge_ez/character06/1094_07.png',
'../data/omniglot/images_evaluation/Keble/character16/1261_09.png',
'../data/omniglot/images_evaluation/Kannada/character01/1205_20.png',
'../data/omniglot/images_evaluation/Glagolitic/character14/1128_04.png',
'../data/omniglot/images_evaluation/Angelic/character06/0970_19.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character03/1039_18.png',
'../data/omniglot/images_evaluation/ULOG/character14/1611_15.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character13/1049_15.png',
'../data/omniglot/images_evaluation/Keble/character02/1247_11.png',
'../data/omniglot/images_evaluation/Sylheti/character05/1507_12.png',
'../data/omniglot/images_evaluation/Avesta/character05/1067_17.png',
'../data/omniglot/images_evaluation/Manipuri/character15/1333_01.png',
'../data/omniglot/images_evaluation/Syriac_(Serto)/character14/1493_07.png',
'../data/omniglot/images_evaluation/Mongolian/character30/1388_20.png',
'../data/omniglot/images_evaluation/Keble/character07/1252_12.png'])
('test train loader: ', ['../data/omniglot/images_evaluation/Keble/character02/1247_11.png',
'../data/omniglot/images_evaluation/ULOG/character14/1611_15.png',
'../data/omniglot/images_evaluation/Tibetan/character05/1560_19.png',
'../data/omniglot/images_evaluation/Kannada/character27/1231_08.png',
'../data/omniglot/images_evaluation/Kannada/character05/1209_20.png',
'../data/omniglot/images_evaluation/Syriac_(Serto)/character14/1493_07.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character03/1039_18.png',
'../data/omniglot/images_evaluation/Avesta/character05/1067_17.png',
'../data/omniglot/images_evaluation/Glagolitic/character14/1128_04.png',
'../data/omniglot/images_evaluation/Malayalam/character05/1276_16.png',
'../data/omniglot/images_evaluation/Mongolian/character30/1388_20.png',
'../data/omniglot/images_evaluation/Manipuri/character15/1333_01.png',
'../data/omniglot/images_evaluation/Manipuri/character05/1323_07.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character13/1049_15.png',
'../data/omniglot/images_evaluation/Sylheti/character05/1507_12.png',
'../data/omniglot/images_evaluation/Kannada/character01/1205_20.png',
'../data/omniglot/images_evaluation/Ge_ez/character06/1094_07.png',
'../data/omniglot/images_evaluation/Keble/character07/1252_12.png',
'../data/omniglot/images_evaluation/Keble/character16/1261_09.png',
'../data/omniglot/images_evaluation/Angelic/character06/0970_19.png'])
('test train loader: ',
Thanks for your nice implementation. I hit the following error when running sh train-omniglot-20way-1shot.sh. Is there a dimension mismatch? I don't know how to fix it, and I'd much appreciate any help.
Traceback (most recent call last):
File "maml.py", line 230, in <module>
main()
File "/usr/local/lib/python3.6/dist-packages/click/core.py", line 829, in __call__
return self.main(*args, **kwargs)
File "/usr/local/lib/python3.6/dist-packages/click/core.py", line 782, in main
rv = self.invoke(ctx)
File "/usr/local/lib/python3.6/dist-packages/click/core.py", line 1066, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/usr/local/lib/python3.6/dist-packages/click/core.py", line 610, in invoke
return callback(*args, **kwargs)
File "maml.py", line 227, in main
learner.train(exp)
File "maml.py", line 151, in train
mt_loss, mt_acc, mv_loss, mv_acc = self.test()
File "maml.py", line 113, in test
tloss, tacc = evaluate(test_net, train_loader)
File "/home/src-plot/score.py", line 30, in evaluate
loss += l.data.cpu().numpy()[0]
IndexError: too many indices for array
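A note on this traceback: the failing line indexes the loss with `[0]` after converting it to a NumPy array. On PyTorch 0.4 and later a loss is a 0-dimensional tensor, so that conversion gives a 0-d array and indexing it raises this IndexError. The following is a minimal sketch of an evaluation loop that reads the scalar with `.item()` instead. The function `evaluate_batches` and the toy linear classifier are made up for illustration and are not the repository's `evaluate`.

```python
import torch
import torch.nn.functional as F

def evaluate_batches(batches, weight):
    """Average cross-entropy loss over batches (hypothetical helper)."""
    total_loss, num_batches = 0.0, 0
    for x, y in batches:
        l = F.cross_entropy(x @ weight, y)  # 0-dim tensor on PyTorch >= 0.4
        total_loss += l.item()              # read the Python float, no [0] indexing
        num_batches += 1
    return total_loss / num_batches

torch.manual_seed(0)
weight = torch.randn(4, 5)  # toy linear classifier: 4 features, 5 classes
batches = [(torch.randn(8, 4), torch.randint(0, 5, (8,))) for _ in range(3)]
avg_loss = evaluate_batches(batches, weight)
print(avg_loss)
```

If this is the cause here, the equivalent change in score.py would be to replace `l.data.cpu().numpy()[0]` with `l.item()`. Other lines written for the older API may need similar changes.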
Hi,
Thanks for sharing the code. I have questions about the implementation for inner loop:
Is there any reason for the special case of i == 0? Can we just use fast_weights for i == 0?
Thanks!
Hi Kate Rakelly,
Thank you for sharing the MAML source code.
Here I have a few questions:
1. Since you have already obtained meta_grads in forward(), why not use them directly to update the parameters of self.net (param - learning_rate * gradient)? It seems you achieve this through a meta_update() function instead, with the comment "We use a dummy forward / backward pass to get the correct grads into self.net" on the line "loss, out = forward_pass(self.net, in_, target)".
I'm confused about why you compute the loss on the val data with respect to the original \theta.
2. In meta_update(), you use register_hook to replace the gradients, but opt.zero_grad and loss.backward are also called to compute gradients.
Do these calls (opt.zero_grad; loss.backward) overwrite the gradients that were replaced through register_hook?
I’m looking forward to your reply.
Do you have the result on miniImageNet?
Given that we are using Batch Normalization layers here, shouldn't we be calling model.eval() before getting accuracy numbers on the test set?
In this line: https://github.com/katerakelly/pytorch-maml/blob/master/src/inner_loop.py#L76
why not use grads = torch.autograd.grad(loss, fast_weights.values()) instead of self.parameters(), which are the original parameters before any inner-update step?
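To make the question concrete, here is a minimal sketch of an inner loop that always differentiates with respect to the current fast weights, with no special case for the first step. It is not the repository's code. The helper `inner_adapt`, the single linear layer, and the hyperparameters are assumptions chosen for illustration. Because the fast weights start out as the meta-parameters themselves, the gradient at step 0 is taken with respect to the original parameters either way.

```python
from collections import OrderedDict

import torch
import torch.nn.functional as F

def inner_adapt(params, x, y, step_size=0.1, num_updates=3):
    """Return adapted fast weights for one task (hypothetical helper)."""
    # At step 0 the fast weights are the meta-parameters themselves.
    fast_weights = OrderedDict(params)
    for _ in range(num_updates):
        logits = F.linear(x, fast_weights["weight"], fast_weights["bias"])
        loss = F.cross_entropy(logits, y)
        # create_graph=True keeps the graph so the outer loss can
        # differentiate through each inner update.
        grads = torch.autograd.grad(
            loss, list(fast_weights.values()), create_graph=True
        )
        fast_weights = OrderedDict(
            (name, w - step_size * g)
            for (name, w), g in zip(fast_weights.items(), grads)
        )
    return fast_weights

torch.manual_seed(0)
params = OrderedDict(
    weight=torch.randn(5, 4, requires_grad=True),
    bias=torch.zeros(5, requires_grad=True),
)
x, y = torch.randn(10, 4), torch.randint(0, 5, (10,))
adapted = inner_adapt(params, x, y)

# Outer loss on the adapted weights, differentiated back to the meta-parameters.
meta_loss = F.cross_entropy(F.linear(x, adapted["weight"], adapted["bias"]), y)
meta_grads = torch.autograd.grad(meta_loss, list(params.values()))
```

In this sketch the meta-gradient still reaches `params`, because every fast weight is built from them by differentiable operations.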
In Algorithm 2 of the paper, we compute the adapted parameters θ' in step 7, and I think this already updates the model for task T_i.
What does the meta-update mean (in step 8)?
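For reference, the two updates in Algorithm 2 can be written as below, as I read the paper. The symbols follow the paper's notation, where α is the inner step size and β is the meta step size. Step 7 produces a temporary, per-task copy θ'_i and leaves θ unchanged. Step 8 only samples fresh data D'_i. The meta-update itself (step 10) changes the shared initialization θ using the losses of the adapted copies on D'_i.

```latex
% Step 7: task-specific adaptation, evaluated on D_i (theta itself is not modified)
\theta_i' = \theta - \alpha \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)

% Step 10: meta-update of the shared initialization, evaluated on each D_i'
\theta \leftarrow \theta - \beta \nabla_\theta \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})
```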
Hi, Amazing repo and implementation :)
Please can you confirm what license this code is released under?
Net_helper file missing?
I have encountered the problem "IndexError: too many indices for array". Is it related to the dataset I use? I just downloaded this file (https://github.com/brendenlake/omniglot/blob/master/python/images_evaluation.zip) and put it under data. The log is as follows:
./train-omniglot-5way-1shot.sh
tee: ../logs/maml-omniglot-5way-1shot-TEST: No such file or directory
exp maml-omniglot-5way-1shot-TEST
dataset omniglot
num_cls 5
num_inst 1
batch 1
m_batch 32
num_updates 15000
num_inner_updates 5
lr 1e-1
meta_lr 1e-3
gpu 0
Setting GPU to 0
init weights
init weights
init weights
Traceback (most recent call last):
File "maml.py", line 230, in <module>
main()
File "/home/maple/programs/miniconda2/lib/python2.7/site-packages/click/core.py", line 722, in __call__
return self.main(*args, **kwargs)
File "/home/maple/programs/miniconda2/lib/python2.7/site-packages/click/core.py", line 697, in main
rv = self.invoke(ctx)
File "/home/maple/programs/miniconda2/lib/python2.7/site-packages/click/core.py", line 895, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/home/maple/programs/miniconda2/lib/python2.7/site-packages/click/core.py", line 535, in invoke
return callback(*args, **kwargs)
File "maml.py", line 227, in main
learner.train(exp)
File "maml.py", line 151, in train
mt_loss, mt_acc, mv_loss, mv_acc = self.test()
File "maml.py", line 113, in test
tloss, tacc = evaluate(test_net, train_loader)
File "/home/maple/pytorch-maml/src/score.py", line 30, in evaluate
loss += l.data.cpu().numpy()[0]
IndexError: too many indices for array
Hi, thanks for the code.
I am trying to test the code on the mini-ImageNet dataset. However, my results are very bad, much worse than the reported ones. Have you tested the implementation on this dataset, and if so, what were your results? If possible, could you upload your scripts for the mini-ImageNet experiments? Your help is much appreciated.
In the original paper, the authors state that MAML needs second-order gradients and Hessian-vector products. Could you explain how you implement this, or does PyTorch do it automatically? Thanks!
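As a small illustration of the second-order term, the sketch below works through a scalar toy problem. It is not taken from the repository. The losses `inner_loss` and `outer_loss` and the values of `alpha` and `theta` are made up. Passing `create_graph=True` to `torch.autograd.grad` makes the inner gradient itself differentiable, so differentiating the outer loss picks up the second-derivative factor automatically. Without it, the inner gradient is treated as a constant, which gives a first-order approximation.

```python
import torch

alpha = 0.1
theta = torch.tensor(1.0, requires_grad=True)

def inner_loss(t):
    return t ** 3  # toy task loss; second derivative is 6 * t

def outer_loss(t):
    return t ** 2  # toy meta-loss evaluated at the adapted parameter

# Full MAML gradient: keep the graph of the inner gradient.
g = torch.autograd.grad(inner_loss(theta), theta, create_graph=True)[0]
theta_prime = theta - alpha * g
full_grad = torch.autograd.grad(outer_loss(theta_prime), theta)[0]

# First-order approximation: the inner gradient is a constant.
g_const = torch.autograd.grad(inner_loss(theta), theta)[0]
theta_prime_fo = theta - alpha * g_const
first_order_grad = torch.autograd.grad(outer_loss(theta_prime_fo), theta)[0]

# By hand: theta' = 1 - 0.1 * 3 = 0.7
# full        = 2 * 0.7 * (1 - 0.1 * 6) = 0.56
# first-order = 2 * 0.7                 = 1.4
print(full_grad.item(), first_order_grad.item())
```

The two values differ by exactly the factor (1 - alpha * 6 * theta), which is the scalar analogue of the Hessian term in the meta-gradient.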
referenced dragen1860/MAML-Pytorch#18
Thanks for your helpful codes.
I want to update the original net's weights using gradients averaged over tasks, computed on each task's query set, and then use opt.step to update it.
Inspired by the optim.SGD source code,
Can your code:
`
hooks = []
for (k,v) in self.net.named_parameters():
    def get_closure():
        key = k
        def replace_grad(grad):
            return gradients[key]
        return replace_grad
    hooks.append(v.register_hook(get_closure()))
self.opt.zero_grad()
loss.backward()
self.opt.step()
for h in hooks:
    h.remove()
`
be replaced by:
for (k,v), (k,g) in zip(self.net.named_parameters(), gradients):
    v.grad.data = g.data
self.opt.step()
self.opt.zero_grad()
or just by:
for (k,v), (k,g) in zip(self.net.named_parameters(), gradients):
    v.data.add_(-meta_lr, g.data)
Thanks for your time!
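One way to try the direct-assignment variant is sketched below, under some assumptions. The small `Linear` net, the SGD optimizer, and the all-ones `gradients` dict are stand-ins for illustration and are not the repository's objects. Two details differ from the snippets above. Since `gradients` is indexed by name in the hook version, it is looked up by name here rather than zipped (iterating a dict yields only its keys). And `p.grad` is assigned as a whole tensor, because `.grad` is None until a backward pass has run, so `v.grad.data = ...` would fail on a fresh parameter.

```python
import torch

net = torch.nn.Linear(3, 2)
opt = torch.optim.SGD(net.parameters(), lr=0.1)

# Hypothetical stand-in for the task-averaged meta-gradients, keyed by name.
gradients = {name: torch.ones_like(p) for name, p in net.named_parameters()}
before = {name: p.detach().clone() for name, p in net.named_parameters()}

opt.zero_grad()
for name, p in net.named_parameters():
    # Assign the precomputed gradient directly; no hooks or dummy backward pass.
    p.grad = gradients[name].detach().clone()
opt.step()

for name, p in net.named_parameters():
    # With plain SGD, each entry moves by exactly lr * grad = 0.1.
    print(name, torch.allclose(p.detach(), before[name] - 0.1))
```

The second proposal, updating `v.data` by hand, amounts to a plain SGD step. It would give different results if `self.opt` is an adaptive optimizer such as Adam, since it skips the optimizer's internal state. On recent PyTorch versions the in-place form is `v.data.add_(g, alpha=-meta_lr)`.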
Hi Kate, thanks for the Pytorch code of MAML!
I have two questions about your implementation (in which I suspect a bug?).
Line 10 of Algorithm 2 in the original paper indicates that the meta-update is performed using each D'_i. To do so with your code, I think the function meta_update needs access to every sampled task, since each task contains its own D'_i in your implementation.
Line 172 in 75907ac
Line 10 also states that the meta-loss is calculated with the adapted parameters.
Line 71 in 75907ac
Am I missing something?
Thanks for your good implementation of MAML. However, I think using state_dict() and load_state_dict() may be much faster than modifying the weights (in omniglot_net.py 43). Can I first deepcopy the net parameters (state_dict()), train with the fast weights (also using an optimizer to update them), and then load the original parameters back to update the meta-learner? Thanks.
Thanks for the code. I ran into the following problem when trying to run your code. I am using PyTorch 0.2.0 and Python 3.5.
tee: ../logs/maml-omniglot-5way-1shot-TEST: No such file or directory
exp maml-omniglot-5way-1shot-TEST
dataset omniglot
num_cls 5
num_inst 1
batch 1
m_batch 32
num_updates 15000
num_inner_updates 5
lr 1e-1
meta_lr 1e-3
gpu 0
Setting GPU to 0
init weights
init weights
init weights
> /zdata/users/kaili/code/pytorch-maml/src/maml.py(102)test()
-> for _ in range(10):
(Pdb) c
-------------------------
Meta train: 0.1178786426782608 1.0
Meta val: 1.139253568649292 0.7
-------------------------
inner step 0
inner step 1
Traceback (most recent call last):
File "maml.py", line 234, in <module>
main()
File "/home/mli/kaili/anaconda2/envs/kai_py35_torch02/lib/python3.5/site-packages/click/core.py", line 722, in __call__
return self.main(*args, **kwargs)
File "/home/mli/kaili/anaconda2/envs/kai_py35_torch02/lib/python3.5/site-packages/click/core.py", line 697, in main
rv = self.invoke(ctx)
File "/home/mli/kaili/anaconda2/envs/kai_py35_torch02/lib/python3.5/site-packages/click/core.py", line 895, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/home/mli/kaili/anaconda2/envs/kai_py35_torch02/lib/python3.5/site-packages/click/core.py", line 535, in invoke
return callback(*args, **kwargs)
File "maml.py", line 231, in main
learner.train(exp)
File "maml.py", line 166, in train
metrics, g = self.fast_net.forward(task)
File "/zdata/users/kaili/code/pytorch-maml/src/inner_loop.py", line 61, in forward
loss, _ = self.forward_pass(in_, target, fast_weights)
File "/zdata/users/kaili/code/pytorch-maml/src/inner_loop.py", line 43, in forward_pass
out = self.net_forward(input_var, weights)
File "/zdata/users/kaili/code/pytorch-maml/src/inner_loop.py", line 36, in net_forward
return super(InnerLoop, self).forward(x, weights)
File "/zdata/users/kaili/code/pytorch-maml/src/omniglot_net.py", line 49, in forward
x = batchnorm(x, weight = weights['features.bn1.weight'], bias = weights['features.bn1.bias'], momentum=1)
File "/zdata/users/kaili/code/pytorch-maml/src/layers.py", line 31, in batchnorm
running_mean = torch.zeros(np.prod(np.array(input.data.size()[1]))).cuda()
TypeError: torch.zeros received an invalid combination of arguments - got (numpy.int64), but expected one of:
* (int ... size)
didn't match because some of the arguments have invalid types: (!numpy.int64!)
* (torch.Size size)
didn't match because some of the arguments have invalid types: (!numpy.int64!)
Any clue how to solve this? Thanks.
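The error message says that `torch.zeros` received a `numpy.int64` where it expected a Python int. A minimal sketch of one workaround is to take the channel count from the tensor's own size and cast it to `int` before allocating the buffers. The helper `make_running_stats` is a made-up name for illustration and is not the code in layers.py.

```python
import torch

def make_running_stats(x):
    """Allocate per-channel running mean and variance buffers (hypothetical helper)."""
    # x.size(1) is the channel dimension; int() guards against NumPy integer types.
    num_features = int(x.size(1))
    running_mean = torch.zeros(num_features)
    running_var = torch.ones(num_features)
    return running_mean, running_var

x = torch.randn(8, 64, 14, 14)  # batch of 8 feature maps with 64 channels
running_mean, running_var = make_running_stats(x)
print(running_mean.shape, running_var.shape)
```

Applied to the failing line, this would mean wrapping the `np.prod(...)` expression in `int(...)`, and likewise for any similar allocation nearby. If the input lives on the GPU, the buffers need to be moved to the same device, as the original line does with `.cuda()`.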