Comments (7)
The FedProx update is defined here: https://github.com/litian96/FedProx/blob/master/flearn/optimizer/pgd.py#L29-L36
where `var` is the local model you are currently optimizing, and `vstar` is the current global model, which serves as the starting point for solving the local subproblem.
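For readers following along, here is a minimal NumPy sketch of that update rule (the function name and hyperparameter values are illustrative only; the actual implementation is the TF1 `_apply_dense` linked above):

```python
import numpy as np

# One perturbed-gradient-descent step, mirroring pgd.py's
# var_update = state_ops.assign_sub(var, lr_t*(grad + mu_t*(var - vstar))).
def fedprox_step(var, vstar, grad, lr=0.01, mu=0.1):
    """var: local model; vstar: global model at round start; grad: gradient of F_k."""
    return var - lr * (grad + mu * (var - vstar))
```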
Thank you, it makes sense now!
Hi Tian,
Thanks for the previous reply, but there is one place I still can't figure out. The paper shows that the objective function optimizes F_k(w) and the proximal term (\mu/2)||w - w^t||^2 simultaneously, but the code calculates the loss and ||w - w^t|| separately. I found issues #19, #18, and #6 quite useful, but they are not exactly what I need.
Take the MNIST task as an example. The code first computes the gradients of the cross-entropy loss via `grads_and_vars = optimizer.compute_gradients(loss)`; this part does not consider ||w - w^t||. The proximal term only enters when `train_op = optimizer.apply_gradients(grads_and_vars, global_step=tf.train.get_global_step())` triggers `var_update = state_ops.assign_sub(var, lr_t*(grad + mu_t*(var-vstar)))`:
def create_model(self, optimizer):
    """Model function for Logistic Regression."""
    features = tf.placeholder(tf.float32, shape=[None, 784], name='features')
    labels = tf.placeholder(tf.int64, shape=[None, ], name='labels')
    logits = tf.layers.dense(
        inputs=features, units=self.num_classes,
        kernel_regularizer=tf.contrib.layers.l2_regularizer(0.001))
    predictions = {
        "classes": tf.argmax(input=logits, axis=1),
        "probabilities": tf.nn.softmax(logits, name="softmax_tensor")
    }
    loss = tf.losses.sparse_softmax_cross_entropy(labels=labels, logits=logits)
    # Gradients of the cross-entropy loss only -- no proximal term yet.
    grads_and_vars = optimizer.compute_gradients(loss)
    grads, _ = zip(*grads_and_vars)
    # apply_gradients dispatches to the optimizer's _apply_dense, where
    # PerturbedGradientDescent adds mu_t * (var - vstar) to each gradient.
    train_op = optimizer.apply_gradients(grads_and_vars, global_step=tf.train.get_global_step())
    eval_metric_ops = tf.count_nonzero(tf.equal(labels, predictions["classes"]))
    return features, labels, train_op, grads, eval_metric_ops, loss
Therefore, it seems the code first calculates the gradient of F_k alone, and then adds the proximal correction to get the final update. I also debugged the code to check what calls pgd's `_apply_dense()`; it turns out this method is only invoked by `optimizer.apply_gradients()`, not by `compute_gradients()`.
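To make that two-stage flow concrete, here is a small NumPy stand-in (not the repo's TF code; the toy least-squares loss, function names, and constants are invented for illustration) that separates the plain loss gradient from the proximal correction in the same way:

```python
import numpy as np

def compute_gradients(w, X, y):
    # Stage 1: gradient of F_k only (a toy least-squares loss here),
    # mirroring optimizer.compute_gradients(loss) -- no proximal term.
    return X.T @ (X @ w - y) / len(y)

def apply_gradients(var, grad, vstar, lr=0.1, mu=0.01):
    # Stage 2: the proximal correction mu*(var - vstar) enters only here,
    # as in PerturbedGradientDescent._apply_dense.
    return var - lr * (grad + mu * (var - vstar))

rng = np.random.default_rng(0)
X, y = rng.normal(size=(32, 5)), rng.normal(size=32)
vstar = np.zeros(5)   # global model at the start of the round
var = vstar.copy()    # local model being optimized
for _ in range(100):
    var = apply_gradients(var, compute_gradients(var, X, y), vstar)
```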
Thank you in advance :)
You are exactly right: we first need to obtain the gradient \nabla F_k(w_local) + \mu (w_local - w_global), and then apply it to update w_local.
What is your question?
My question is that `compute_gradients()` only calculates \nabla F_k(w_local). How can we find argmin_w h_k without considering the gradient of the proximal term, \mu (w_local - w_global)?
I know we account for \mu (w_local - w_global) in `apply_gradients()`, but it is not obvious to me that this proximal correction necessarily leads to argmin h_k.
The gradient of h_k is \nabla F_k(w_local) + \mu (w_local - w_global), which is exactly what we apply to update w_local (specifically, `grad + mu_t * (var - vstar)` in the code).
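A quick way to see this: for h_k(w) = F_k(w) + (\mu/2)||w - w^t||^2, we have \nabla h_k(w) = \nabla F_k(w) + \mu (w - w^t), so the update is plain gradient descent on h_k. A small numerical check (toy quadratic F_k; all names and values are invented for illustration):

```python
import numpy as np

mu = 0.5
rng = np.random.default_rng(0)
A, b = rng.normal(size=(10, 4)), rng.normal(size=10)
w, wstar = rng.normal(size=4), rng.normal(size=4)

def h(w):
    # h_k(w) = F_k(w) + (mu/2)*||w - w*||^2, with F_k a toy quadratic.
    return 0.5 * np.sum((A @ w - b) ** 2) + 0.5 * mu * np.sum((w - wstar) ** 2)

# Analytic gradient: grad F_k(w) + mu*(w - w*), i.e. grad + mu_t*(var - vstar).
analytic = A.T @ (A @ w - b) + mu * (w - wstar)

# Central-difference gradient of h_k agrees with it.
eps = 1e-6
numeric = np.array([(h(w + eps * e) - h(w - eps * e)) / (2 * eps)
                    for e in np.eye(4)])
assert np.allclose(analytic, numeric, atol=1e-4)
```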
Ohhhh, I see. Really appreciate the clarification :)