
Comments (7)

litian96 commented on June 16, 2024

The update rule of FedProx is defined here: https://github.com/litian96/FedProx/blob/master/flearn/optimizer/pgd.py#L29-L36
where var is the current local model you are optimizing over, and vstar is the current global model, which is the starting point for solving the local subproblem.
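
In pseudocode that update is w <- w - lr * (\nabla F_k(w) + \mu (w - w^t)). A minimal NumPy sketch of the same step (toy numbers of my own, not the repo's TensorFlow code; pgd.py performs this inside _apply_dense):

    import numpy as np

    def fedprox_step(var, vstar, grad, lr, mu):
        # Mirrors pgd.py: var <- var - lr * (grad + mu * (var - vstar))
        return var - lr * (grad + mu * (var - vstar))

    var = np.array([1.0, -2.0])    # current local model
    vstar = np.array([0.5, 0.5])   # global model at the start of the round
    grad = np.array([0.2, -0.1])   # gradient of F_k at var (made-up values)
    print(fedprox_step(var, vstar, grad, lr=0.1, mu=0.01))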


tianboh commented on June 16, 2024

Thank you, it makes sense now!


tianboh commented on June 16, 2024

Hi Tian,

Thanks for the previous reply, but there is one place I still can't figure out. The paper says the objective minimizes F_k(w) and the proximal term ||w - w^t||^2 jointly, but the code computes the loss and ||w - w^t||^2 separately. I found issues #19, #18 and #6 quite useful, but not exactly what I need.

Take the mnist task as an example. The code first calculates the cross-entropy loss in
grads_and_vars = optimizer.compute_gradients(loss); this part does not consider ||w - w^t||^2.
The proximal term only comes in through var_update = state_ops.assign_sub(var, lr_t*(grad + mu_t*(var-vstar))), which runs when
train_op = optimizer.apply_gradients(grads_and_vars, global_step=tf.train.get_global_step()) is executed.

    def create_model(self, optimizer):
        """Model function for Logistic Regression."""
        features = tf.placeholder(tf.float32, shape=[None, 784], name='features')
        labels = tf.placeholder(tf.int64, shape=[None, ], name='labels')
        logits = tf.layers.dense(inputs=features, units=self.num_classes,
                                 kernel_regularizer=tf.contrib.layers.l2_regularizer(0.001))
        predictions = {
            "classes": tf.argmax(input=logits, axis=1),
            "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
        }
        # Plain cross-entropy loss: the proximal term is not part of this graph.
        loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
        # compute_gradients() therefore returns only \nabla F_k.
        grads_and_vars = optimizer.compute_gradients(loss)
        grads, _ = zip(*grads_and_vars)
        # apply_gradients() dispatches to the optimizer's _apply_dense(), where
        # pgd.py adds mu_t * (var - vstar) before updating the variables.
        train_op = optimizer.apply_gradients(grads_and_vars, global_step=tf.train.get_global_step())
        eval_metric_ops = tf.count_nonzero(tf.equal(labels, predictions["classes"]))
        return features, labels, train_op, grads, eval_metric_ops, loss

Therefore, it seems that the code calculates the gradient from F_k's perspective only, and then adds the proximal part at update time to get the final gradient. I also debugged the code to check who calls pgd's _apply_dense(), and it shows that this method is only invoked by optimizer.apply_gradients(), not by compute_gradients().
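
To double-check my understanding, here is a plain-Python mock of that two-phase flow (toy code I wrote, not the repo's TensorFlow graph):

    import numpy as np

    MU, LR = 1.0, 0.1
    vstar = np.zeros(2)            # global model w^t

    def compute_gradients(w):
        # Phase 1: only grad F_k; toy F_k(w) = 0.5 * ||w - 1||^2, so grad = w - 1.
        return w - 1.0

    def apply_gradients(w, grad):
        # Phase 2: _apply_dense() adds the proximal term before the update,
        # so the step actually taken is a gradient step on h_k, not on F_k.
        return w - LR * (grad + MU * (w - vstar))

    w = np.zeros(2)
    for _ in range(300):
        w = apply_gradients(w, compute_gradients(w))
    print(w)  # ~[0.5, 0.5], the argmin of h_k: (w - 1) + MU * (w - 0) = 0 at w = 0.5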

Thank you in advance :)


litian96 commented on June 16, 2024

You are exactly right: we first obtain the combined gradient \nabla F_k (w_local) + \mu (w_local - w_global), and then apply it to update w_local.
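
Written out, the local subproblem from the paper and its gradient are:

    h_k(w; w^t) = F_k(w) + \frac{\mu}{2} \|w - w^t\|^2
    \nabla h_k(w; w^t) = \nabla F_k(w) + \mu (w - w^t)

so a gradient step using the combined quantity is exactly a gradient step on h_k.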
What is your question?


tianboh commented on June 16, 2024

My question is that compute_gradients() only calculates \nabla F_k(w_local). How can we find argmin_w h_k without considering the gradient of the proximal term, \mu (w_local - w_global)?

I know we consider \mu (w_local - w_global) in apply_gradients(), but this proximal term does not necessarily lead to argmin h_k.


litian96 commented on June 16, 2024

The gradient of h_k is \nabla F_k (w_local) + \mu (w_local - w_global), which is what we apply to update w_local (specifically, grad + mu_t * (var - vstar)).
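
A quick numerical sanity check (a toy F_k I made up, not code from the repo) that grad + mu_t * (var - vstar) is indeed the gradient of h_k:

    import numpy as np

    mu = 0.1
    vstar = np.array([0.3, -0.7])                  # global model w^t

    def F_k(w):                                    # toy local loss
        return np.sum((w - 1.0) ** 2)

    def h_k(w):                                    # local subproblem objective
        return F_k(w) + 0.5 * mu * np.sum((w - vstar) ** 2)

    w = np.array([0.2, 0.5])
    analytic = 2.0 * (w - 1.0) + mu * (w - vstar)  # grad F_k + mu * (w - w^t)

    eps = 1e-6                                     # central finite differences
    numeric = np.array([(h_k(w + eps * e) - h_k(w - eps * e)) / (2 * eps)
                        for e in np.eye(2)])
    print(np.allclose(analytic, numeric))          # True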


tianboh commented on June 16, 2024

Ohhhh, I see. I really appreciate the clarification :)

