
DENSE: Data-Free One-Shot Federated Learning

This repository is the official PyTorch implementation of:

DENSE: Data-Free One-Shot Federated Learning (NeurIPS 2022).


Requirements

  • This codebase is written for Python 3 (Python 3.6 was used during development).
  • We use PyTorch 1.8.2 with CUDA 11.4.
  • To install the necessary Python packages, run:

pip install -r requirements.txt

  • Logs are uploaded to the wandb server. If you do not have a wandb account, install wandb and run it in offline mode:
pip install wandb
wandb off
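
Alternatively, offline mode can be forced through an environment variable (a standard wandb feature):

export WANDB_MODE=offline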

How to Run the Code?

Local Training

First, we partition the data across $m$ clients, train a model on each client until it converges, and send the trained models to the central server. A sample command is as follows:

python3 loop_df_fl.py --type=pretrain  --iid=0 --lr=0.01 --model=cnn \
--dataset=cifar10 --beta=0.5 --seed=1 --num_users=5 --local_ep=400

Here is an explanation of some parameters,

  • --dataset: name of the dataset (e.g., mnist, cifar10, cifar100, or cinic10).
  • --num_users: the total number of clients (e.g., 5, 20, 100).
  • --batch_size: the batch size used for local training (default: 256).
  • --iid: IID or non-IID partition strategy (0 for non-IID).
  • --beta: concentration parameter beta of the Dirichlet distribution used for non-IID partitioning (default: 0.5); see the sketch after this list.
  • --model: model architecture to be used (e.g., cnn, resnet, or vit).
  • --epochs: the total number of communication rounds (default: 200).
  • --frac: fraction of clients randomly sampled at each round (default: 1).
  • --local_ep: the number of local epochs (default: 400).
  • --lr: the initial learning rate for local training (default: 0.01).
  • --momentum: the momentum for SGD (default: 0.9).
  • --seed: random seed.
  • --adv: scaling factor for the adversarial loss.
  • --bn: scaling factor for the batch-norm (BN) regularization loss.
  • --lr_g: learning rate for the generator.
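
As referenced in the --beta entry above, here is a minimal sketch of Dirichlet-based non-IID partitioning; the function name dirichlet_partition and the exact splitting scheme are illustrative assumptions, not necessarily the repository's implementation:

import numpy as np

def dirichlet_partition(labels, num_users, beta, seed):
    # labels: 1-D array of dataset labels; returns one index list per client.
    rng = np.random.default_rng(seed)
    client_idxs = [[] for _ in range(num_users)]
    for c in range(labels.max() + 1):
        idx_c = np.where(labels == c)[0]
        rng.shuffle(idx_c)
        # Class-c proportions per client from Dir(beta); smaller beta -> more skew.
        proportions = rng.dirichlet(np.full(num_users, beta))
        cuts = (np.cumsum(proportions)[:-1] * len(idx_c)).astype(int)
        for client, shard in enumerate(np.split(idx_c, cuts)):
            client_idxs[client].extend(shard.tolist())
    return client_idxs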

Note that the random seed must be fixed for a fair comparison, since different seeds produce different data distributions across clients; experiments should therefore be repeated over several seeds. Here is an example partition for --seed=1 (cifar10, 5 clients, $\beta=0.5$):

Data statistics: {client 0: {0: 156, 1: 709, 2: 301, 3: 2629, 4: 20, 5: 651, 6: 915, 7: 113, 8: 180, 9: 2133}, \
client 1: {0: 1771, 1: 2695, 2: 1251, 3: 1407, 4: 665, 5: 314, 6: 1419, 7: 3469}, \
client 2: {0: 236, 1: 15, 2: 1715, 3: 76, 4: 1304, 5: 34, 6: 1773, 7: 75, 8: 3289, 9: 2360}, \
client 3: {0: 2809, 1: 575, 2: 157, 3: 853, 4: 2555, 5: 2557, 6: 203, 7: 1213}, \
client 4: {0: 28, 1: 1006, 2: 1576, 3: 35, 4: 456, 5: 1444, 6: 690, 7: 130, 8: 1531, 9: 507}}
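
For reference, each client's local update in the pretraining step above amounts to standard SGD training; a minimal sketch, assuming a plain cross-entropy objective (details such as learning-rate scheduling may differ in the repository):

import torch
import torch.nn.functional as F

def local_train(model, loader, local_ep=400, lr=0.01, momentum=0.9, device="cuda"):
    # Train a client model to convergence; the result is sent to the server.
    model.to(device).train()
    opt = torch.optim.SGD(model.parameters(), lr=lr, momentum=momentum)
    for _ in range(local_ep):
        for x, y in loader:
            x, y = x.to(device), y.to(device)
            opt.zero_grad()
            F.cross_entropy(model(x), y).backward()
            opt.step()
    return model.cpu().state_dict()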

Sample learning curves for local training:

[figure: local training learning curves]

The accuracy of the model ensemble (teacher) and the accuracy after FedAvg:

For each client, Accuracy: 55 / 55 / 59 / 60 / 62
FedAvg Accuracy: 30.9300
Ensemble Accuracy: 71.8100
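
This gap is the motivation for DENSE: with a single communication round under non-IID data, naive parameter averaging (FedAvg) degrades sharply, while averaging the clients' predictions (the ensemble) remains strong. A minimal sketch of both baselines; the function names are illustrative, not the repository's API:

import copy
import torch

def fedavg(state_dicts):
    # One-shot FedAvg: uniform parameter averaging of the client models.
    avg = copy.deepcopy(state_dicts[0])
    for k in avg:
        avg[k] = torch.stack([sd[k].float() for sd in state_dicts]).mean(dim=0)
    return avg

def ensemble_logits(models, x):
    # The teacher: average the local models' logits on input x.
    with torch.no_grad():
        return torch.stack([m(x) for m in models]).mean(dim=0)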

Global Distillation

Then we use the ensemble of local models as the teacher for data-free knowledge distillation (KD). Here is a sample command:

python loop_df_fl.py --type=kd_train --iid=0 --epochs=200 --lr=0.005 --batch_size 256 --synthesis_batch_size=256 \
--g_steps 30 --lr_g 1e-3 --bn 1.0 --oh 1.0 --T 20 --save_dir=run/cifar10 --other=cifar10 --model=cnn --dataset=cifar10 \
--adv=1 --beta=0.5 --seed=1
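
Conceptually, the global (student) model is trained to match the ensemble teacher's softened predictions on synthetic data, with --T as the distillation temperature. Below is a minimal sketch of one distillation step, assuming the standard temperature-scaled KL objective; the repository's exact loss may differ:

import torch.nn.functional as F

def kd_step(student, teacher_logits, x_syn, optimizer, T=20.0):
    # Match the student's softened outputs to the ensemble teacher's.
    s_logits = student(x_syn)
    loss = F.kl_div(F.log_softmax(s_logits / T, dim=1),
                    F.softmax(teacher_logits / T, dim=1),
                    reduction="batchmean") * (T * T)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()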


Note that the hyperparameters here are not carefully tuned; intuitively, the KD accuracy should be close to the performance of the model ensemble.
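
For reference, --oh, --bn, and --adv weight the usual data-free generation objectives: a one-hot (cross-entropy) term that makes the teacher confident on synthetic images, a batch-norm statistics-matching term, and an adversarial term that seeks samples on which student and teacher disagree. A minimal sketch of how these terms might combine (collecting the BN loss via forward hooks is an assumption, not necessarily the repository's implementation):

import torch.nn.functional as F

def generator_loss(x_syn, teacher, student, bn_loss_fn, oh=1.0, bn=1.0, adv=1.0):
    # Forward through the teacher; BN hooks (set up elsewhere) can record feature stats here.
    t_logits = teacher(x_syn)
    s_logits = student(x_syn)
    # One-hot loss: push the teacher toward confident predictions on synthetic images.
    loss_oh = F.cross_entropy(t_logits, t_logits.argmax(dim=1))
    # Adversarial loss: negative KL, so minimizing it maximizes teacher-student disagreement.
    loss_adv = -F.kl_div(F.log_softmax(s_logits, dim=1),
                         F.softmax(t_logits, dim=1), reduction="batchmean")
    # bn_loss_fn() compares the recorded batch statistics to the BN layers' running statistics.
    return oh * loss_oh + bn * bn_loss_fn() + adv * loss_adv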

The synthetic data for one batch (CIFAR10):

[figure: one batch of synthetic CIFAR-10 images]

Citing this work

@inproceedings{zhangdense,
  title={DENSE: Data-Free One-Shot Federated Learning},
  author={Zhang, Jie and Chen, Chen and Li, Bo and Lyu, Lingjuan and Wu, Shuang and Ding, Shouhong and Shen, Chunhua and Wu, Chao},
  booktitle={Advances in Neural Information Processing Systems},
  year={2022}
}


Issues

Code for implementing other baselines

Could you please provide your code for running the other baseline methods? That would be very helpful for others who want to follow this interesting work.

Also, the current code seems to have a bug: in the line

acc, test_loss = test(model, test_loader)

the test_loader is not defined.
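
(A standard CIFAR-10 test loader can be constructed as below; this is a reader's sketch of a possible workaround, not the authors' fix.)

from torch.utils.data import DataLoader
from torchvision import datasets, transforms

test_set = datasets.CIFAR10("./data", train=False, download=True,
                            transform=transforms.ToTensor())
test_loader = DataLoader(test_set, batch_size=256, shuffle=False)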

Synthetic Data Generated

Could you please provide the location/command for generating the synthetic data that was used in the paper for visualization and comparison with the original dataset?

Commands used for the experiments in Table 1

Dear authors,

Could you please provide the commands to rerun the code and reproduce the results in your Table 1 for MNIST, FMNIST, and CIFAR10?

Thanks so much,
Best,
Chuan Xu

Concerns Regarding Dependencies in requirements.txt

Hello, I came across some issues while reproducing your project. First, there is a problem with the dependency version on line 21 of your requirements.txt file; that version does not appear to exist on PyPI. Additionally, I encountered Python 3.6 compatibility issues with the dependencies on lines 20 and 22.

For the issue on line 21, you can check the link at https://pypi.org/project/mkl-random/#history.

Regarding the issues on lines 20 and 22, you can check the links at https://pypi.org/project/mkl-fft/1.3.0/#files and https://pypi.org/project/mkl-service/2.3.0/#files.
