pytorch-maml's Issues

batch_cutoff not working for split = 'train'

Hi Kate,

Under split='train' in the inner loop, the data loader always outputs the same data regardless of the batch_cutoff value. I know this matches the setting in the MAML paper, but I am wondering whether there is a way to make batch_cutoff work. Many thanks!

get_data_loader function gets same train samples

Hi Kelly, I read the code and found that the get_data_loader function returns the same batch of samples every iteration. Is that intended? (Sorry, I'm not familiar with the MAML algorithm.)

Below are the samples returned by the get_data_loader function.

('test train loader: ', ['../data/omniglot/images_evaluation/Kannada/character01/1205_20.png',
'../data/omniglot/images_evaluation/Manipuri/character05/1323_07.png',
'../data/omniglot/images_evaluation/Angelic/character06/0970_19.png',
'../data/omniglot/images_evaluation/ULOG/character14/1611_15.png',
'../data/omniglot/images_evaluation/Keble/character02/1247_11.png',
'../data/omniglot/images_evaluation/Kannada/character05/1209_20.png',
'../data/omniglot/images_evaluation/Sylheti/character05/1507_12.png',
'../data/omniglot/images_evaluation/Manipuri/character15/1333_01.png',
'../data/omniglot/images_evaluation/Kannada/character27/1231_08.png',
'../data/omniglot/images_evaluation/Mongolian/character30/1388_20.png',
'../data/omniglot/images_evaluation/Keble/character07/1252_12.png',
'../data/omniglot/images_evaluation/Tibetan/character05/1560_19.png',
'../data/omniglot/images_evaluation/Glagolitic/character14/1128_04.png',
'../data/omniglot/images_evaluation/Ge_ez/character06/1094_07.png',
'../data/omniglot/images_evaluation/Avesta/character05/1067_17.png',
'../data/omniglot/images_evaluation/Malayalam/character05/1276_16.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character03/1039_18.png',
'../data/omniglot/images_evaluation/Syriac_(Serto)/character14/1493_07.png',
'../data/omniglot/images_evaluation/Keble/character16/1261_09.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character13/1049_15.png'])
('test train loader: ', ['../data/omniglot/images_evaluation/Syriac_(Serto)/character14/1493_07.png',
'../data/omniglot/images_evaluation/Sylheti/character05/1507_12.png',
'../data/omniglot/images_evaluation/Keble/character16/1261_09.png',
'../data/omniglot/images_evaluation/Ge_ez/character06/1094_07.png',
'../data/omniglot/images_evaluation/Tibetan/character05/1560_19.png',
'../data/omniglot/images_evaluation/Manipuri/character05/1323_07.png',
'../data/omniglot/images_evaluation/Kannada/character05/1209_20.png',
'../data/omniglot/images_evaluation/Manipuri/character15/1333_01.png',
'../data/omniglot/images_evaluation/Avesta/character05/1067_17.png',
'../data/omniglot/images_evaluation/ULOG/character14/1611_15.png',
'../data/omniglot/images_evaluation/Glagolitic/character14/1128_04.png',
'../data/omniglot/images_evaluation/Malayalam/character05/1276_16.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character13/1049_15.png',
'../data/omniglot/images_evaluation/Kannada/character01/1205_20.png',
'../data/omniglot/images_evaluation/Keble/character02/1247_11.png',
'../data/omniglot/images_evaluation/Angelic/character06/0970_19.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character03/1039_18.png',
'../data/omniglot/images_evaluation/Mongolian/character30/1388_20.png',
'../data/omniglot/images_evaluation/Kannada/character27/1231_08.png',
'../data/omniglot/images_evaluation/Keble/character07/1252_12.png'])
('test train loader: ', ['../data/omniglot/images_evaluation/Kannada/character05/1209_20.png',
'../data/omniglot/images_evaluation/Malayalam/character05/1276_16.png',
'../data/omniglot/images_evaluation/Manipuri/character05/1323_07.png',
'../data/omniglot/images_evaluation/Tibetan/character05/1560_19.png',
'../data/omniglot/images_evaluation/Kannada/character27/1231_08.png',
'../data/omniglot/images_evaluation/Ge_ez/character06/1094_07.png',
'../data/omniglot/images_evaluation/Keble/character16/1261_09.png',
'../data/omniglot/images_evaluation/Kannada/character01/1205_20.png',
'../data/omniglot/images_evaluation/Glagolitic/character14/1128_04.png',
'../data/omniglot/images_evaluation/Angelic/character06/0970_19.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character03/1039_18.png',
'../data/omniglot/images_evaluation/ULOG/character14/1611_15.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character13/1049_15.png',
'../data/omniglot/images_evaluation/Keble/character02/1247_11.png',
'../data/omniglot/images_evaluation/Sylheti/character05/1507_12.png',
'../data/omniglot/images_evaluation/Avesta/character05/1067_17.png',
'../data/omniglot/images_evaluation/Manipuri/character15/1333_01.png',
'../data/omniglot/images_evaluation/Syriac_(Serto)/character14/1493_07.png',
'../data/omniglot/images_evaluation/Mongolian/character30/1388_20.png',
'../data/omniglot/images_evaluation/Keble/character07/1252_12.png'])
('test train loader: ', ['../data/omniglot/images_evaluation/Keble/character02/1247_11.png',
'../data/omniglot/images_evaluation/ULOG/character14/1611_15.png',
'../data/omniglot/images_evaluation/Tibetan/character05/1560_19.png',
'../data/omniglot/images_evaluation/Kannada/character27/1231_08.png',
'../data/omniglot/images_evaluation/Kannada/character05/1209_20.png',
'../data/omniglot/images_evaluation/Syriac_(Serto)/character14/1493_07.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character03/1039_18.png',
'../data/omniglot/images_evaluation/Avesta/character05/1067_17.png',
'../data/omniglot/images_evaluation/Glagolitic/character14/1128_04.png',
'../data/omniglot/images_evaluation/Malayalam/character05/1276_16.png',
'../data/omniglot/images_evaluation/Mongolian/character30/1388_20.png',
'../data/omniglot/images_evaluation/Manipuri/character15/1333_01.png',
'../data/omniglot/images_evaluation/Manipuri/character05/1323_07.png',
'../data/omniglot/images_evaluation/Aurek-Besh/character13/1049_15.png',
'../data/omniglot/images_evaluation/Sylheti/character05/1507_12.png',
'../data/omniglot/images_evaluation/Kannada/character01/1205_20.png',
'../data/omniglot/images_evaluation/Ge_ez/character06/1094_07.png',
'../data/omniglot/images_evaluation/Keble/character07/1252_12.png',
'../data/omniglot/images_evaluation/Keble/character16/1261_09.png',
'../data/omniglot/images_evaluation/Angelic/character06/0970_19.png'])
('test train loader: ',

Maybe dimension mistake

Thanks for your nice implementation, but I hit this error when I run sh train-omniglot-20way-1shot.sh. Is there a dimension mismatch somewhere? I don't know how to fix it; any help is much appreciated.
Traceback (most recent call last):
  File "maml.py", line 230, in <module>
    main()
  File "/usr/local/lib/python3.6/dist-packages/click/core.py", line 829, in __call__
    return self.main(*args, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/click/core.py", line 782, in main
    rv = self.invoke(ctx)
  File "/usr/local/lib/python3.6/dist-packages/click/core.py", line 1066, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/usr/local/lib/python3.6/dist-packages/click/core.py", line 610, in invoke
    return callback(*args, **kwargs)
  File "maml.py", line 227, in main
    learner.train(exp)
  File "maml.py", line 151, in train
    mt_loss, mt_acc, mv_loss, mv_acc = self.test()
  File "maml.py", line 113, in test
    tloss, tacc = evaluate(test_net, train_loader)
  File "/home/src-plot/score.py", line 30, in evaluate
    loss += l.data.cpu().numpy()[0]
IndexError: too many indices for array

meta_update function

Hi Kate Rakelly,
Thank you for sharing the MAML source code.
Here I have a few questions:
1. Since you have already obtained the meta_grads in the forward() function, why not use them directly to update the parameters of self.net (param - learning_rate * gradient)? It seems you achieve this instead by defining a meta_update() function, annotated "We use a dummy forward / backward pass to get the correct grads into self.net", which runs the line "loss, out = forward_pass(self.net, in_, target)". I'm confused why you compute the loss on the validation data with respect to the original \theta.
2. In the meta_update() function you use register_hook to change the gradients, but opt.zero_grad and loss.backward are also used to compute gradients. Do opt.zero_grad and loss.backward overwrite the gradients previously replaced via register_hook?
I’m looking forward to your reply.
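
A minimal standalone demo of the mechanics in question 2 (my own sketch, not repo code): a hook registered with register_hook fires during backward, and the tensor it returns replaces the gradient before it is accumulated into .grad. So opt.zero_grad() and loss.backward() do not overwrite the replacement; they are what trigger it.

```python
import torch

p = torch.nn.Parameter(torch.ones(3))
replacement = torch.full((3,), 7.0)

# The hook runs during backward; returning a tensor swaps in a new gradient.
h = p.register_hook(lambda grad: replacement)

loss = (p * 2).sum()  # "dummy" forward pass; the true gradient would be all 2s
loss.backward()
print(p.grad)         # tensor([7., 7., 7.]) -- the replaced gradient landed in .grad

h.remove()
```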

Model.eval()?

Given that we are using Batch Normalization layers here, shouldn't we be calling model.eval() before getting accuracy numbers on the test set?
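
For reference, the standard PyTorch pattern being asked about (a generic sketch in current PyTorch, not this repo's code; the model here is a placeholder):

```python
import torch
import torch.nn as nn

test_net = nn.Sequential(nn.Linear(8, 8), nn.BatchNorm1d(8))  # placeholder model

# Switch BatchNorm (and Dropout) to inference behaviour: running statistics
# are used instead of per-batch statistics while scoring.
test_net.eval()
with torch.no_grad():
    out = test_net(torch.randn(4, 8))  # evaluation forward passes go here

test_net.train()  # restore training behaviour afterwards
```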

License?

Hi, amazing repo and implementation :)

Please can you confirm what license this code is released under?

How to solve "IndexError: too many indices for array"?

I have encountered the problem "IndexError: too many indices for array". Is it related to the dataset I am using? I just downloaded this file (https://github.com/brendenlake/omniglot/blob/master/python/images_evaluation.zip) and put it under data. The log is as follows:

./train-omniglot-5way-1shot.sh
tee: ../logs/maml-omniglot-5way-1shot-TEST: No such file or directory
exp maml-omniglot-5way-1shot-TEST
dataset omniglot
num_cls 5
num_inst 1
batch 1
m_batch 32
num_updates 15000
num_inner_updates 5
lr 1e-1
meta_lr 1e-3
gpu 0
Setting GPU to 0
init weights
init weights
init weights
Traceback (most recent call last):
  File "maml.py", line 230, in <module>
    main()
  File "/home/maple/programs/miniconda2/lib/python2.7/site-packages/click/core.py", line 722, in __call__
    return self.main(*args, **kwargs)
  File "/home/maple/programs/miniconda2/lib/python2.7/site-packages/click/core.py", line 697, in main
    rv = self.invoke(ctx)
  File "/home/maple/programs/miniconda2/lib/python2.7/site-packages/click/core.py", line 895, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/maple/programs/miniconda2/lib/python2.7/site-packages/click/core.py", line 535, in invoke
    return callback(*args, **kwargs)
  File "maml.py", line 227, in main
    learner.train(exp)
  File "maml.py", line 151, in train
    mt_loss, mt_acc, mv_loss, mv_acc = self.test()
  File "maml.py", line 113, in test
    tloss, tacc = evaluate(test_net, train_loader)
  File "/home/maple/pytorch-maml/src/score.py", line 30, in evaluate
    loss += l.data.cpu().numpy()[0]
IndexError: too many indices for array
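
A likely cause (my assumption from PyTorch's version history, not something confirmed in this repo): from PyTorch 0.4 onward, a scalar loss converts to a 0-dimensional NumPy array, and indexing a 0-d array with [0] raises exactly this IndexError. A minimal patch to the offending line would be:

```python
# score.py, line 30 -- hypothetical patch for PyTorch >= 0.4, where the
# scalar loss becomes a 0-d array that cannot be indexed with [0]:
loss += l.item()  # was: loss += l.data.cpu().numpy()[0]
```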

result on mini-imagenet dataset

Hi, thanks for the code.

I am trying to test the code on the mini-ImageNet dataset. However, my results are very bad, much worse than the reported ones. I wonder whether you have tested the implementation on that dataset and what results you got. If possible, could you upload your scripts for the mini-ImageNet experiments? Your help is much appreciated.

Hessian-vector products

In the original paper, the authors state that MAML requires second derivatives, i.e. Hessian-vector products. Could you explain how you implement this, or does PyTorch just handle it automatically? Thanks!
Referenced: dragen1860/MAML-Pytorch#18
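
For what it's worth, a minimal standalone sketch (not this repo's code) of how PyTorch produces the second-order term: take the inner-loop gradient with create_graph=True, build the fast weights from it, and backpropagate the meta-loss; the Hessian-vector product then falls out of the double backward automatically.

```python
import torch

w = torch.randn(5, requires_grad=True)
x = torch.randn(5)

inner_loss = ((w * x).sum() - 1.0) ** 2

# Keep the gradient differentiable so a second backward pass can run through it.
(g,) = torch.autograd.grad(inner_loss, w, create_graph=True)

fast_w = w - 0.1 * g  # one inner SGD step; fast_w stays connected to w
meta_loss = ((fast_w * x).sum() - 1.0) ** 2

# Backprop through the update: d(fast_w)/dw = I - 0.1 * H, so this line
# implicitly computes a Hessian-vector product.
meta_loss.backward()
print(w.grad)
```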

why need use hook?

Thanks for your helpful codes.

I want to update the original net's weights using the task-averaged gradients on each task's query set, and then call opt.step to apply the update. Inspired by the optim.SGD source code: could this part of your code

```python
hooks = []
for (k, v) in self.net.named_parameters():
    def get_closure():
        key = k
        def replace_grad(grad):
            return gradients[key]
        return replace_grad
    hooks.append(v.register_hook(get_closure()))

# Compute grads for current step, replace with summed gradients as defined by hook
self.opt.zero_grad()
loss.backward()

# Update the net parameters with the accumulated gradient according to optimizer
self.opt.step()

# Remove the hooks before next training phase
for h in hooks:
    h.remove()
```

be replaced by:

```python
for (k, v), (k, g) in zip(self.net.named_parameters(), gradients):
    v.grad.data = g.data
self.opt.step()
self.opt.zero_grad()
```

or just by:

```python
for (k, v), (k, g) in zip(self.net.named_parameters(), gradients):
    v.data.add_(-meta_lr, g.data)
```

Thanks for your time!

meta_update with a single task and meta-loss calculated with current weight?

Hi Kate, thanks for the Pytorch code of MAML!

I have two questions (in which I suspect a bug?) about your implementation.

[Screenshot of Algorithm 2 from the MAML paper; line 10 is the meta-update \theta <- \theta - \beta * \nabla_\theta \sum_{T_i ~ p(T)} L_{T_i}(f_{\theta'_i}).]
Line 10 of Algorithm 2 in the original paper indicates that the meta-update is performed using every D'_i. To do so with your code, I think the meta_update function needs access to every sampled task, since each task contains its own D'_i in your implementation.

self.meta_update(task, grads)

However, it seems that you perform meta_update with a single task, so only the D'_i of that one specific task is used.

Line 10 also states that meta-loss is calculated with adapted parameters.

loss, out = forward_pass(self.net, in_, target)

You seem to calculate the meta-loss with self.net, which I think holds the original parameters \theta instead of the adapted parameters \theta'_i.

Am I missing something?
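
For concreteness, a toy sketch of my reading of line 10 (all names made up; this is not the repo's code): the query-set losses are computed with each task's adapted parameters and summed, and the sum is differentiated with respect to the original \theta.

```python
import torch

def query_loss_after_adaptation(theta, support, query, alpha=0.1):
    # Inner update on D_i (support), kept differentiable w.r.t. theta.
    inner_loss = ((theta * support).sum() - 1.0) ** 2
    (g,) = torch.autograd.grad(inner_loss, theta, create_graph=True)
    theta_prime = theta - alpha * g
    # Loss on D'_i (query) under the *adapted* parameters theta'_i.
    return ((theta_prime * query).sum() - 1.0) ** 2

theta = torch.randn(4, requires_grad=True)
tasks = [(torch.randn(4), torch.randn(4)) for _ in range(3)]  # (D_i, D'_i) pairs

meta_loss = sum(query_loss_after_adaptation(theta, s, q) for s, q in tasks)
meta_loss.backward()  # line 10: one meta-gradient step uses every task's D'_i
```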

About the Model parameters updating in OmniglotNet Class

Thanks for your good implementation of MAML. However, I think using state_dict() and load_state_dict() might be much faster than modifying the weights manually (omniglot_net.py, line 43). Could I first deepcopy the net parameters (state_dict()), train the fast weights with a regular optimizer, and then load the original parameters back to update the meta-learner? Thanks.
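
A minimal sketch of the pattern this question proposes (hypothetical; net stands in for the model, and the inner loop is elided):

```python
import copy
import torch.nn as nn

net = nn.Linear(4, 4)  # placeholder model

# Snapshot the original parameters theta.
saved_state = copy.deepcopy(net.state_dict())

# ... adapt `net` with an ordinary optimizer on the support set ...

# Restore theta before the meta-update.
net.load_state_dict(saved_state)
```

One caveat worth noting: parameters updated through an ordinary optimizer are detached from the autograd graph, so this pattern only supports a first-order approximation; threading fast weights through a functional forward pass, as this repo does, keeps the inner updates differentiable.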

errors when loading weights

Thanks for the code. I ran into the following problem when trying to run it, using PyTorch 0.2.0 and Python 3.5.

tee: ../logs/maml-omniglot-5way-1shot-TEST: No such file or directory
exp maml-omniglot-5way-1shot-TEST
dataset omniglot
num_cls 5
num_inst 1
batch 1
m_batch 32
num_updates 15000
num_inner_updates 5
lr 1e-1
meta_lr 1e-3
gpu 0
Setting GPU to 0
init weights
init weights
init weights
> /zdata/users/kaili/code/pytorch-maml/src/maml.py(102)test()
-> for _ in range(10):
(Pdb) c
-------------------------
Meta train: 0.1178786426782608 1.0
Meta val: 1.139253568649292 0.7
-------------------------
inner step 0
inner step 1
Traceback (most recent call last):
  File "maml.py", line 234, in <module>
    main()
  File "/home/mli/kaili/anaconda2/envs/kai_py35_torch02/lib/python3.5/site-packages/click/core.py", line 722, in __call__
    return self.main(*args, **kwargs)
  File "/home/mli/kaili/anaconda2/envs/kai_py35_torch02/lib/python3.5/site-packages/click/core.py", line 697, in main
    rv = self.invoke(ctx)
  File "/home/mli/kaili/anaconda2/envs/kai_py35_torch02/lib/python3.5/site-packages/click/core.py", line 895, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/home/mli/kaili/anaconda2/envs/kai_py35_torch02/lib/python3.5/site-packages/click/core.py", line 535, in invoke
    return callback(*args, **kwargs)
  File "maml.py", line 231, in main
    learner.train(exp)
  File "maml.py", line 166, in train
    metrics, g = self.fast_net.forward(task)
  File "/zdata/users/kaili/code/pytorch-maml/src/inner_loop.py", line 61, in forward
    loss, _ = self.forward_pass(in_, target, fast_weights)
  File "/zdata/users/kaili/code/pytorch-maml/src/inner_loop.py", line 43, in forward_pass
    out = self.net_forward(input_var, weights)
  File "/zdata/users/kaili/code/pytorch-maml/src/inner_loop.py", line 36, in net_forward
    return super(InnerLoop, self).forward(x, weights)
  File "/zdata/users/kaili/code/pytorch-maml/src/omniglot_net.py", line 49, in forward
    x = batchnorm(x, weight = weights['features.bn1.weight'], bias = weights['features.bn1.bias'], momentum=1)
  File "/zdata/users/kaili/code/pytorch-maml/src/layers.py", line 31, in batchnorm
    running_mean = torch.zeros(np.prod(np.array(input.data.size()[1]))).cuda()
TypeError: torch.zeros received an invalid combination of arguments - got (numpy.int64), but expected one of:
 * (int ... size)
      didn't match because some of the arguments have invalid types: (!numpy.int64!)
 * (torch.Size size)
      didn't match because some of the arguments have invalid types: (!numpy.int64!)

Any clue how to solve this? Thanks.
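
A plausible fix (my assumption from the error message, not a confirmed patch): np.prod returns a numpy.int64, which torch.zeros in this PyTorch version rejects; casting to a plain Python int matches the expected signature.

```python
# layers.py, line 31 -- hypothetical patch: cast the numpy.int64 to int
running_mean = torch.zeros(int(np.prod(np.array(input.data.size()[1])))).cuda()
```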
