tensorflow-maml's Introduction

Reproduction of MAML using TensorFlow 2.0.

This reproduction is heavily influenced by the PyTorch reproduction by Adrien Lucas Ecoffet, available at Paper repro: Deep Metalearning using “MAML” and “Reptile”.

[Figures: MAML | Neural Net]

MAML paper

https://arxiv.org/abs/1703.03400

Model-Agnostic Meta-Learning for Fast Adaptation of Deep Networks
Chelsea Finn, Pieter Abbeel, Sergey Levine

We propose an algorithm for meta-learning that is model-agnostic, in the sense that it is compatible with any model trained with gradient descent and applicable to a variety of different learning problems, including classification, regression, and reinforcement learning. The goal of meta-learning is to train a model on a variety of learning tasks, such that it can solve new learning tasks using only a small number of training samples. In our approach, the parameters of the model are explicitly trained such that a small number of gradient steps with a small amount of training data from a new task will produce good generalization performance on that task. In effect, our method trains the model to be easy to fine-tune. We demonstrate that this approach leads to state-of-the-art performance on two few-shot image classification benchmarks, produces good results on few-shot regression, and accelerates fine-tuning for policy gradient reinforcement learning with neural network policies.
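For reference, the two updates the abstract describes can be written in the paper's notation. This is a sketch for the single-inner-step case, where $\alpha$ is the task-level step size, $\beta$ the meta step size, $f_\theta$ the model with parameters $\theta$, and $\mathcal{L}_{\mathcal{T}_i}$ the loss on a task $\mathcal{T}_i$ sampled from the task distribution $p(\mathcal{T})$:

```latex
\begin{align}
\theta_i' &= \theta - \alpha \, \nabla_\theta \mathcal{L}_{\mathcal{T}_i}(f_\theta)
  && \text{(inner update, one per task)} \\
\theta &\leftarrow \theta - \beta \, \nabla_\theta
  \sum_{\mathcal{T}_i \sim p(\mathcal{T})} \mathcal{L}_{\mathcal{T}_i}(f_{\theta_i'})
  && \text{(meta-update over the sampled tasks)}
\end{align}
```

The meta-update differentiates through the inner update, because each $\theta_i'$ depends on $\theta$.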



tensorflow-maml's People

Contributors

hereismari, roschly


tensorflow-maml's Issues

I want to change the model to train image classification.

Here we go. I have some problems.
First, I see that every forward pass of the model uses the same input "x". I don't understand why: why not update the copied model with input x1 and update the MAML model with input x2? Does the code match the paper? Maybe I am wrong.
Second, to train image classification we need to build a data set. Say the data set consists of 50 pictures, with every 10 pictures from the same class, so we have one task. However, I don't know how to define a batch. Somebody told me I can bundle 4 tasks, which means training on 4 tasks at the same time. I don't even know how to feed this into my network or how it trains.
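One common way to organise such data is episodic sampling: each task gets a support set (for the inner update) and a disjoint query set (for the meta-update), and a "batch" is simply a list of tasks. The sketch below assumes the 50 pictures are grouped as 5 classes of 10; `sample_task`, `sample_meta_batch`, and the string image ids are hypothetical names used only for illustration, not part of this repository.

```python
import random

def sample_task(images_by_class, n_way, k_shot, k_query, rng):
    """Build one task: pick n_way classes, then split each class's images
    into a support set (inner update) and a query set (meta-update)."""
    classes = rng.sample(sorted(images_by_class), n_way)
    support, query = [], []
    for label, cls in enumerate(classes):
        picked = rng.sample(images_by_class[cls], k_shot + k_query)
        support += [(img, label) for img in picked[:k_shot]]
        query += [(img, label) for img in picked[k_shot:]]
    return support, query

def sample_meta_batch(images_by_class, num_tasks, n_way, k_shot, k_query, seed=0):
    """A meta-batch is a list of independently sampled tasks."""
    rng = random.Random(seed)
    return [sample_task(images_by_class, n_way, k_shot, k_query, rng)
            for _ in range(num_tasks)]

# Toy stand-in for the 50-picture data set: 5 classes, 10 image ids each.
toy_data = {f"class_{c}": [f"img_{c}_{i}" for i in range(10)] for c in range(5)}

# Four tasks trained "at the same time": 5-way, 1-shot, 5 query images per class.
meta_batch = sample_meta_batch(toy_data, num_tasks=4, n_way=5, k_shot=1, k_query=5)
```

With real images, the ids would be replaced by image tensors, and each task's support and query lists would be stacked into arrays before being fed to the network.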

Save weight

Thanks for sharing your code.
I have a question: how do I save the MAML weights? I tried saving the MAML object using pickle, but it didn't work.
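One possible workaround is to save the weight arrays instead of the model object. The sketch below assumes the weights are available as a list of NumPy arrays, which is what Keras's `model.get_weights()` returns; `save_weight_list` and `load_weight_list` are hypothetical helper names, and the two small arrays stand in for real model weights.

```python
import os
import tempfile

import numpy as np

def save_weight_list(path, weights):
    """Save a list of arrays (e.g. from model.get_weights()) to one .npz file."""
    np.savez(path, *weights)  # positional arrays are stored as arr_0, arr_1, ...

def load_weight_list(path):
    """Load the arrays back in their original order."""
    with np.load(path) as data:
        return [data[f"arr_{i}"] for i in range(len(data.files))]

# Stand-in weights: a kernel and a bias for one small layer.
weights = [np.ones((1, 40), dtype=np.float32), np.zeros(40, dtype=np.float32)]
path = os.path.join(tempfile.mkdtemp(), "maml_weights.npz")
save_weight_list(path, weights)
restored = load_weight_list(path)
```

With a Keras model this would be used as `save_weight_list(path, model.get_weights())` and later `model.set_weights(load_weight_list(path))`, after one forward pass has created the model's variables. Keras also offers `model.save_weights` and `model.load_weights`, whose accepted file names vary between versions.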

How to apply several SGD steps within the inner loop?

Hi @mari-linhares , thanks for the repo!
We are building on your code to implement a slightly more general version of MAML that includes a batch of tasks within the inner loop and several steps of gradient descent with respect to the parameters of each task. However, we are stuck on how to add several steps of SGD to your code using TensorFlow 2.0. Do you have any idea how to do that?
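One way to approach this is to write the forward pass as a function of an explicit parameter list, so the adapted parameters can stay ordinary tensors that an outer `tf.GradientTape` differentiates through. The sketch below is not the repository's code: the tiny MLP, `adapt`, `meta_step`, and the toy sine tasks are assumptions made for illustration, and the meta-update is plain SGD rather than a Keras optimizer.

```python
import tensorflow as tf

def forward(params, x):
    """Functional forward pass of a tiny MLP; params = [w1, b1, w2, b2]."""
    w1, b1, w2, b2 = params
    hidden = tf.nn.relu(tf.matmul(x, w1) + b1)
    return tf.matmul(hidden, w2) + b2

def mse(params, x, y):
    return tf.reduce_mean(tf.square(forward(params, x) - y))

def adapt(params, x, y, lr_inner, num_steps):
    """Run num_steps of gradient descent, keeping every step differentiable."""
    fast = list(params)
    for _ in range(num_steps):
        with tf.GradientTape() as inner_tape:
            inner_tape.watch(fast)  # after the first step these are plain tensors
            loss = mse(fast, x, y)
        grads = inner_tape.gradient(loss, fast)
        fast = [p - lr_inner * g for p, g in zip(fast, grads)]
    return fast

def meta_step(params, tasks, lr_inner=0.01, lr_outer=0.001, num_steps=3):
    """tasks: list of (x_support, y_support, x_query, y_query) tuples."""
    with tf.GradientTape() as outer_tape:
        query_losses = []
        for xs, ys, xq, yq in tasks:
            fast = adapt(params, xs, ys, lr_inner, num_steps)
            query_losses.append(mse(fast, xq, yq))
        meta_loss = tf.add_n(query_losses)  # sum over the batch of tasks
    grads = outer_tape.gradient(meta_loss, params)
    for p, g in zip(params, grads):
        p.assign_sub(lr_outer * g)  # plain SGD meta-update (an assumption)
    return meta_loss, grads

def toy_task(amplitude):
    """Hypothetical sine task with separate support and query samples."""
    xs = tf.random.uniform([10, 1], -5.0, 5.0)
    xq = tf.random.uniform([10, 1], -5.0, 5.0)
    return xs, amplitude * tf.sin(xs), xq, amplitude * tf.sin(xq)

tf.random.set_seed(0)
params = [tf.Variable(tf.random.normal([1, 16], stddev=0.5)),
          tf.Variable(tf.zeros([16])),
          tf.Variable(tf.random.normal([16, 1], stddev=0.5)),
          tf.Variable(tf.zeros([1]))]
tasks = [toy_task(a) for a in (0.5, 1.0, 2.0)]
meta_loss, grads = meta_step(params, tasks)
```

Because the inner tapes run inside the outer tape, the outer gradient includes the second-order terms from all inner steps. Wrapping each inner gradient in `tf.stop_gradient` would give a first-order approximation instead.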

There might be an error in the train_maml function.

The loss used for calculating the gradient to perform the meta update is only from one task. However, it should be the sum of all sampled tasks according to the original paper. Please look at step 8 in the code (below). The test_loss is inside the loop for i, t in enumerate(random.sample(dataset, len(dataset))), indicating the test_loss is only for one task.

# Step 2: instead of checking for convergence, we train for a number
# of epochs
for _ in range(epochs):
    total_loss = 0
    losses = []
    start = time.time()
    # Step 3 and 4
    for i, t in enumerate(random.sample(dataset, len(dataset))):
        x, y = np_to_tensor(t.batch())  # <-- highlighted by the reporter
        model.forward(x)  # run forward pass to initialize weights
        with tf.GradientTape() as test_tape:
            # test_tape.watch(model.trainable_variables)
            # Step 5
            with tf.GradientTape() as train_tape:
                train_loss, _ = compute_loss(model, x, y)
            # Step 6
            gradients = train_tape.gradient(train_loss, model.trainable_variables)
            k = 0
            model_copy = copy_model(model, x)
            for j in range(len(model_copy.layers)):
                model_copy.layers[j].kernel = tf.subtract(model.layers[j].kernel,
                            tf.multiply(lr_inner, gradients[k]))
                model_copy.layers[j].bias = tf.subtract(model.layers[j].bias,
                            tf.multiply(lr_inner, gradients[k+1]))
                k += 2
            # Step 8
            test_loss, logits = compute_loss(model_copy, x, y)  # <-- highlighted by the reporter
        # Step 8
        gradients = test_tape.gradient(test_loss, model.trainable_variables)  # <-- highlighted by the reporter
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
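The difference the reporter describes can be seen on a toy problem. In the sketch below, each "task" contributes a hypothetical meta-gradient function (the gradient of 0.5 * (theta - c)^2, chosen only for illustration). `sequential_updates` mirrors the quoted loop, which applies one meta-update per task, while `summed_update` applies a single update with the gradients summed over all sampled tasks.

```python
def sequential_updates(theta, task_grad_fns, lr):
    """Mirror of the quoted loop: one meta-update per task (meta-batch size 1)."""
    for grad_fn in task_grad_fns:
        theta = theta - lr * grad_fn(theta)
    return theta

def summed_update(theta, task_grad_fns, lr):
    """Single meta-update using the gradient summed over all sampled tasks."""
    total_grad = sum(grad_fn(theta) for grad_fn in task_grad_fns)
    return theta - lr * total_grad

def make_grad_fn(center):
    """Hypothetical per-task meta-gradient: d/dtheta of 0.5 * (theta - center)**2."""
    return lambda theta: theta - center

task_grad_fns = [make_grad_fn(1.0), make_grad_fn(3.0)]
theta_sequential = sequential_updates(0.0, task_grad_fns, lr=0.1)
theta_summed = summed_update(0.0, task_grad_fns, lr=0.1)
```

The two results differ because, in the sequential version, the second task's gradient is evaluated at parameters that the first task has already changed. In the quoted TensorFlow code, the summed variant would correspond to collecting each task's `test_tape.gradient(...)` result, adding them element-wise, and calling `optimizer.apply_gradients` once after the task loop.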

When calculating test loss

Hello, thanks for the code, it helped a lot!

I have some questions about calculating the test loss in MAML.

When your code calculates the test loss in train_maml() for a specific sine function in the inner loop, are you using the same data to calculate both the training and the test loss? Shouldn't you sample new test data from the same sine function to calculate the test error, as in the paper?

Thanks
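A minimal sketch of what sampling fresh test data could look like is shown below. It assumes a task object that stores the sine parameters and draws new inputs on every call; the `SineTask` class and its `sample` method are hypothetical and may differ from the repository's own task class. The ranges follow the paper's regression setup (amplitude in [0.1, 5.0], phase in [0, pi], inputs in [-5, 5]); the form `sin(x - phase)` is an assumption.

```python
import numpy as np

class SineTask:
    """One regression task: a sine wave with fixed amplitude and phase."""

    def __init__(self, rng):
        self.rng = rng
        self.amplitude = rng.uniform(0.1, 5.0)
        self.phase = rng.uniform(0.0, np.pi)

    def sample(self, k):
        """Draw k fresh (x, y) points from this task's sine function."""
        x = self.rng.uniform(-5.0, 5.0, size=(k, 1))
        y = self.amplitude * np.sin(x - self.phase)
        return x, y

rng = np.random.default_rng(0)
task = SineTask(rng)
x_support, y_support = task.sample(10)  # used for the inner (training) update
x_query, y_query = task.sample(10)      # new points, used for the test loss
```

Both batches come from the same function, so the test loss measures how well the adapted model generalises within the task rather than how well it fits the points it was just trained on.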
