Dear author,
Challenge and solution
This repository has implemented Tensor Parallel, which distributes the computation workload evenly across the nodes and achieves nearly linear speedup in inference time. However, the communication workload is not distributed: the transfer time grows with the number of workers. This can be solved by changing all-reduce to ring-all-reduce, which distributes the data-transfer workload across every worker.
Let me briefly introduce the concepts of all-reduce and ring-all-reduce.
All-reduce
The current implementation uses a master-worker architecture, where every worker exchanges its full data with a single master. It can be changed into a ring of workers.
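For concreteness, the master-worker pattern can be sketched as a single-process simulation. This is only an illustration under my own assumptions; the function name and data layout are mine, not this repository's API:

```python
# Single-process sketch of master-worker all-reduce (illustrative only,
# not this repository's actual code).
def master_worker_allreduce(data):
    """data[w] is worker w's vector of length h; worker 0 plays the master."""
    # The master gathers a length-h vector from every worker and sums them
    # element-wise, then broadcasts the sum back to everyone.
    # O(h * P) traffic therefore concentrates on the master.
    total = [sum(column) for column in zip(*data)]
    return [list(total) for _ in range(len(data))]
```

With three workers holding `[1, 2]`, `[3, 4]`, `[5, 6]`, every worker ends up with `[9, 12]`, but all 3 vectors flowed through the master.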
Ring-all-reduce
The ring-all-reduce algorithm is divided into two stages.
stage 1
First, arrange the P workers on a ring and divide each worker's data into P parts. In your case, the hidden_dim is divided into P parts.
Next, consider the k-th worker: it sends its k-th part to the next worker and receives the (k-1)-th part from the previous worker.
The worker then adds the received (k-1)-th part to its own (k-1)-th part, and in the next cycle sends the accumulated part on to the next worker.
After P-1 cycles, each worker holds the fully reduced version of one of the P parts.
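The first stage (often called reduce-scatter) can be sketched as a single-process simulation. The names and data layout here are my own illustrative assumptions, not the repository's code:

```python
# Stage 1 (reduce-scatter) on a ring, simulated in one process.
# data[w][c] is worker w's c-th chunk (a list of numbers).
def ring_reduce_scatter(data):
    P = len(data)
    for step in range(P - 1):
        for k in range(P):
            c = (k - step) % P    # chunk that worker k forwards this cycle
            dst = (k + 1) % P     # next worker on the ring
            # The receiver adds the incoming chunk to its own copy.
            data[dst][c] = [a + b for a, b in zip(data[dst][c], data[k][c])]
    # After P-1 cycles, worker k holds the fully reduced chunk (k + 1) % P.
    return data
```

For example, with P=3 workers whose chunks are `[[0],[1],[2]]`, `[[10],[11],[12]]`, `[[20],[21],[22]]`, worker 0 ends up holding the reduced chunk 1 (`[33]`), worker 1 the reduced chunk 2 (`[36]`), and worker 2 the reduced chunk 0 (`[30]`).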
stage 2
In the second stage, each worker sends its fully reduced part to the next worker, and each worker overwrites the corresponding part of its own data with the part it receives. After another P-1 cycles, every worker holds a full copy of the final result, identical to the result of all-reduce.
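The second stage (an all-gather) can be sketched the same way. Again this is a single-process illustration under my own data layout, which assumes the stage-1 outcome that worker k holds the fully reduced chunk (k + 1) % P:

```python
# Stage 2 (all-gather) on a ring, simulated in one process.
# Entry condition: data[w][(w + 1) % P] is worker w's fully reduced chunk.
def ring_allgather(data):
    P = len(data)
    for step in range(P - 1):
        for k in range(P):
            c = (k + 1 - step) % P    # reduced chunk worker k forwards
            # The receiver simply overwrites its stale copy of chunk c.
            data[(k + 1) % P][c] = list(data[k][c])
    return data
```

Starting from the P=3 state above (reduced chunks `[33]`, `[36]`, `[30]` at workers 0, 1, 2), after P-1 = 2 cycles every worker holds the full result `[[30], [33], [36]]`.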
Assuming each worker's data is a vector of length hidden_dim = h, the amount of data sent (or received) by each worker is 2(P-1)*h/P, which is almost independent of the number of workers P:
When P=1, the transfer data is 0;
When P=2, the transfer data is h;
When P=4, the transfer data is 1.5*h, less than the current 3*h;
When P=8, the transfer data is 1.75*h, much less than the current 7*h;
When P→∞, the transfer data approaches 2*h, of course far less than the current (P-1)*h.
Summary
Ring all-reduce avoids the problem of the master having to handle O(h*P) data in the master-worker architecture, which can become a network bottleneck when the number of devices grows to 8 or more.
Best Regards.
For your reference: Optimization of Collective Communication Operations in MPICH.pdf