MATCHA: Communication-Efficient Decentralized SGD

Code to reproduce the experiments reported in this paper:

Jianyu Wang, Anit Kumar Sahu, Zhouyi Yang, Gauri Joshi, Soummya Kar, "MATCHA: Speeding Up Decentralized SGD via Matching Decomposition Sampling," arXiv preprint arXiv:1905.09435, 2019.

A short version appeared at FL-NeurIPS'19 and received the Distinguished Student Paper Award.

This repo contains implementations of MATCHA and D-PSGD for arbitrary node topologies. You can also use it to develop other decentralized training methods. Please cite the paper if you use this code in your research or projects.

Dependencies and Setup

The code runs on Python 3.5 with PyTorch 1.0.0 and torchvision 0.2.1. Peer-to-peer communication among workers uses the mpi4py sendrecv function.
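As a rough illustration of the sendrecv exchange mentioned above, here is a minimal sketch of two workers swapping parameters and averaging them. It is not the repository's communicator: the helper names pair_partner and average_with_peer are hypothetical, parameters are plain Python lists rather than tensors, and comm is assumed to be an mpi4py communicator such as MPI.COMM_WORLD.

```python
def pair_partner(rank):
    """Hypothetical pairing rule: ranks (0,1), (2,3), ... exchange with each other."""
    return rank ^ 1


def average_with_peer(comm, peer, values):
    """Swap `values` with `peer` in one sendrecv call and return the element-wise average.

    `comm` is assumed to expose mpi4py's pickle-based
    sendrecv(sendobj, dest=..., source=...) method.
    """
    # sendrecv sends and receives in a single call, so two workers that
    # exchange with each other do not deadlock waiting on a blocking send.
    received = comm.sendrecv(values, dest=peer, source=peer)
    return [(mine + theirs) / 2.0 for mine, theirs in zip(values, received)]
```

When launched under mpirun with an even number of processes, each rank could call average_with_peer(MPI.COMM_WORLD, pair_partner(rank), local_values) after importing MPI from mpi4py.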

Training examples

Here is an example of how to use MATCHA to train a neural network.

import util
from graph_manager import FixedProcessor, MatchaProcessor
from communicator import decenCommunicator, ChocoCommunicator, centralizedCommunicator

# Define the base node topology by giving the graph ID
# There are six pre-defined graphs in util.py
base_graph = util.select_graph(args.graphid)

# Preprocess the base topology: 1) decompose it into matchings;
#                               2) get activation probabilities for matchings;
#                               3) compute the mixing weight;
#                               4) generate activation flags for each iteration
# All of this information is stored in GP
GP = MatchaProcessor(base_graph, 
                     commBudget = args.budget,
                     rank = rank,
                     size = size,
                     iterations = args.epoch * num_batches,
                     issubgraph = True)

# Define the communicator
communicator = decenCommunicator(rank, size, GP)

# Start training
for batch_id, (data, label) in enumerate(data_loader):
    # same as serial training
    output = model(data) # forward
    loss = criterion(output, label)
    loss.backward() # backward
    optimizer.step() # gradient step
    optimizer.zero_grad()

    # additional line to average local models at workers
    communicator.communicate(model)
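The preprocessing steps listed in the comments above (decomposing the topology into matchings and drawing activation flags) can be sketched in plain Python. This is a simplified illustration, not the code inside MatchaProcessor: the greedy decomposition, the function names, and the single uniform activation probability are all assumptions made for the example, whereas the paper chooses per-matching probabilities by optimization.

```python
import random


def decompose_into_matchings(edges):
    """Greedily split an edge list into matchings.

    Within one matching no node appears twice, so all of its links
    can communicate in parallel.
    """
    remaining = list(edges)
    matchings = []
    while remaining:
        used_nodes = set()
        matching, leftover = [], []
        for u, v in remaining:
            if u in used_nodes or v in used_nodes:
                leftover.append((u, v))  # conflicts with this matching, try later
            else:
                matching.append((u, v))
                used_nodes.update((u, v))
        matchings.append(matching)
        remaining = leftover
    return matchings


def sample_activation_flags(num_matchings, probability, iterations, seed=0):
    """Draw one 0/1 flag per matching per iteration (uniform probability assumed)."""
    rng = random.Random(seed)
    return [
        [1 if rng.random() < probability else 0 for _ in range(num_matchings)]
        for _ in range(iterations)
    ]


# Toy base topology: a ring of 8 workers.
ring_edges = [(i, (i + 1) % 8) for i in range(8)]
matchings = decompose_into_matchings(ring_edges)
flags = sample_activation_flags(len(matchings), probability=0.5, iterations=5)
```

For the 8-node ring the greedy pass produces two matchings of four links each, and every row of flags says which of them would be active in that iteration.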

To use D-PSGD, change MatchaProcessor to FixedProcessor. Similarly, to use ChocoSGD, change MatchaProcessor to FixedProcessor and decenCommunicator to ChocoCommunicator. To run fully synchronous SGD, use centralizedCommunicator; no graph processor is needed in that case.
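To make the difference between a fixed topology and sampled matchings concrete, here is a toy consensus simulation on an 8-node ring. It is a sketch under stated assumptions rather than the repository's communicators: each worker holds a single float instead of a model, the mixing weight alpha = 0.3 is picked by hand, and mixing_step and simulate are hypothetical names. Setting the activation probability to 1.0 uses every link in every iteration (the fixed-topology case), while 0.5 activates each matching about half the time.

```python
import random


def mixing_step(values, active_edges, alpha):
    """One averaging step: x_i <- x_i - alpha * sum over active neighbors j of (x_i - x_j)."""
    updated = list(values)
    for u, v in active_edges:
        diff = values[u] - values[v]
        updated[u] -= alpha * diff
        updated[v] += alpha * diff  # symmetric update keeps the global mean unchanged
    return updated


def simulate(matchings, values, alpha, probability, iterations, seed=0):
    """Run repeated mixing steps, activating each matching independently per iteration."""
    rng = random.Random(seed)
    for _ in range(iterations):
        active = [
            edge
            for matching in matchings
            if rng.random() < probability
            for edge in matching
        ]
        values = mixing_step(values, active, alpha)
    return values


# Two matchings that together cover the 8-node ring.
matchings = [
    [(0, 1), (2, 3), (4, 5), (6, 7)],
    [(1, 2), (3, 4), (5, 6), (7, 0)],
]
start = [float(i) for i in range(8)]
fixed = simulate(matchings, start, alpha=0.3, probability=1.0, iterations=50)
sampled = simulate(matchings, start, alpha=0.3, probability=0.5, iterations=50)
```

In both runs the mean of the values is preserved and the workers move toward it; the sampled run simply uses fewer links per iteration to get there.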

In addition, before training starts, we need to initialize MPI processes on each worker machine as follows:

from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
size = comm.Get_size()

The script can be run using the following command:

mpirun --hostfile c8 -np 8 python train_mpi.py

Citation

@article{wang2019matcha,
  title={{MATCHA}: Speeding Up Decentralized {SGD} via Matching Decomposition Sampling},
  author={Wang, Jianyu and Sahu, Anit Kumar and Yang, Zhouyi and Joshi, Gauri and Kar, Soummya},
  journal={arXiv preprint arXiv:1905.09435},
  year={2019}
}
