A couple of months ago we (Kevin Scaman, Francis Bach, Yin Tat Lee, Laurent Massoulie and myself) uploaded a new paper on distributed convex optimization. We came up with a pretty clean picture for the optimal oracle complexity of this setting, which I will describe below. I should note that there are hundreds of papers on this topic, but the point of the post is to show our simple cute calculations and not to survey the immense literature on distributed optimization; see the paper itself for a number of pointers to other recent works.
Distributed optimization setting
Let $G=(V,E)$ be an undirected graph on $n$ vertices ($|V|=n$) with diameter $\Delta$. We will think of the nodes in $G$ as the computing units. To each vertex $v \in V$ there is an associated convex function $f_v : \mathbb{R}^d \to \mathbb{R}$. For machine learning applications one can think that each computing unit has access to a “private” dataset, and $f_v(x)$ represents the fit of the model corresponding to $x \in \mathbb{R}^d$ on this dataset (say measured on least squares loss, or logistic loss for example). The goal will be to find in a distributed way the optimal “consensus” point:
$$x^* \in \mathop{\mathrm{argmin}}_{x \in \mathbb{R}^d} \sum_{v \in V} f_v(x) .$$
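The distributed processing protocol is as follows: asynchronously/in parallel, each node $v$ can (i) compute a (local) gradient $\nabla f_v(x)$ in time $1$, and (ii) communicate a vector in $\mathbb{R}^d$ to its neighbors in $G$ in time $\tau$. We denote by $x_{v,t}$ the local model (essentially its guess for $x^*$) of node $v$ at time $t$. We aim to characterize the smallest time $T_{\epsilon}$ such that one can guarantee that all nodes $v$ satisfy
$$\bar{f}(x_{v,T_{\epsilon}}) - \bar{f}(x^*) \leq \epsilon ,$$
where $\bar{f} = \frac{1}{n} \sum_{v \in V} f_v$.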
We focus on the case where $\bar{f}$ is $\beta$-smooth and $\alpha$-strongly convex ($\kappa = \beta/\alpha$ is the condition number), which is arguably the most challenging case since one expects linear convergence (i.e., the scaling of $T_{\epsilon}$ in $\epsilon$ should be $\log(1/\epsilon)$), which a priori makes the interaction of optimization error and communication error potentially delicate (one key finding is that in fact it is not delicate!). Also, having in mind applications outside of large-scale machine learning (such as “federated” learning), we will make no assumptions about how the functions at different vertices relate to each other.
A trivial answer
Recall that Nesterov’s accelerated gradient descent solves the serial problem in time $O(\sqrt{\kappa} \log(1/\epsilon))$. Trivially one can distribute a step of Nesterov’s accelerated gradient descent in time $O(1 + \Delta \tau)$ (simply designate a master node at the beginning, and everybody sends its local gradient to the master node in time $\Delta \tau$). Thus we arrive at the upper bound
$$T_{\epsilon} = O\left(\sqrt{\kappa} \, (1 + \Delta \tau) \log(1/\epsilon)\right)$$
using a trivial (centralized) algorithm. We now show (slightly informally, see the paper for proper definitions) that this is in fact optimal!
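To make the centralized bound concrete, here is a minimal Python sketch that computes the diameter of a small graph by breadth-first search and evaluates $\sqrt{\kappa}\,(1+\Delta\tau)\log(1/\epsilon)$ with all constants dropped. The function names (`diameter`, `centralized_time`) and the example graph are illustrative assumptions, not something taken from the paper.

```python
from collections import deque
from math import log, sqrt


def diameter(adj):
    """Diameter of a connected undirected graph given as {node: [neighbors]}."""
    def eccentricity(src):
        # Plain breadth-first search from src.
        dist = {src: 0}
        queue = deque([src])
        while queue:
            u = queue.popleft()
            for w in adj[u]:
                if w not in dist:
                    dist[w] = dist[u] + 1
                    queue.append(w)
        return max(dist.values())
    return max(eccentricity(v) for v in adj)


def centralized_time(adj, kappa, tau, eps):
    """sqrt(kappa) * (1 + Delta * tau) * log(1/eps), constants dropped."""
    return sqrt(kappa) * (1 + diameter(adj) * tau) * log(1 / eps)


# Hypothetical example: a path on 5 nodes, so the diameter is 4.
path = {i: [j for j in (i - 1, i + 1) if 0 <= j < 5] for i in range(5)}
```

For instance, `centralized_time(path, kappa=100, tau=0.5, eps=1e-3)` charges each accelerated gradient step a cost of $1 + 4 \cdot 0.5 = 3$ time units.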
First let us recall the lower bound proof in the serial case (see for example Theorem 3.15 here). The idea is to introduce the function
$$f(x) = \frac{\alpha (\kappa - 1)}{8} \left( x^{\top} A x - 2 x_1 \right) + \frac{\alpha}{2} \|x\|^2 ,$$
where $A$ is the Laplacian of the path graph on $\mathbb{N}$, or in other words
$$x^{\top} A x = \sum_{i \geq 1} (x_i - x_{i+1})^2 .$$
First it is easy to see that this function is indeed $\beta$-smooth and $\alpha$-strongly convex (since $0 \preceq A \preceq 4 I$). The key point is that, for any algorithm starting at $x_0 = 0$ and such that each iterate stays in the linear span of the previously computed gradients (a very natural assumption), one has
$$x_{t,i} = 0 \quad \text{for all } i > t .$$
In words, each gradient calculation “discovers” a new edge of the path graph involved in the definition of $A$. Concluding the serial proof is then just a matter of brute force calculations.
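The "one new coordinate per gradient" phenomenon is easy to check numerically. The sketch below assumes the hard function has the standard form $f(x) = \frac{\alpha(\kappa-1)}{8}(x^{\top} A x - 2x_1) + \frac{\alpha}{2}\|x\|^2$ with $x^{\top} A x = \sum_i (x_i - x_{i+1})^2$, truncated to finitely many coordinates, and it runs plain gradient descent from the origin (a simple stand-in for any span-restricted method). The helper names are hypothetical.

```python
def hard_gradient(x, alpha=1.0, kappa=100.0):
    """Gradient of the (truncated) hard function at x, a list of floats."""
    d = len(x)
    c = alpha * (kappa - 1) / 4
    grad = []
    for i in range(d):
        # (A x)_i for the path-graph Laplacian: sum over incident edges.
        ax = 0.0
        if i > 0:
            ax += x[i] - x[i - 1]
        if i < d - 1:
            ax += x[i] - x[i + 1]
        linear = 1.0 if i == 0 else 0.0  # comes from the -2 x_1 term
        grad.append(c * (ax - linear) + alpha * x[i])
    return grad


def support_sizes(steps, d=20, alpha=1.0, kappa=100.0):
    """Largest nonzero coordinate (1-indexed) after each gradient step."""
    x = [0.0] * d
    eta = 1.0 / (alpha * kappa)  # step size 1/beta
    sizes = []
    for _ in range(steps):
        g = hard_gradient(x, alpha, kappa)
        x = [xi - eta * gi for xi, gi in zip(x, g)]
        sizes.append(max((i + 1 for i, xi in enumerate(x) if xi != 0.0), default=0))
    return sizes
```

Calling `support_sizes(10)` shows the support growing by exactly one coordinate per step, which is the mechanism the lower bound exploits.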
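Now let us move to the distributed setting, and consider two vertices $u$ and $v$ that realize the diameter of $G$. The idea goes as follows: let $A_1$ (respectively $A_2$) be the Laplacian of even edges of the path graph on $\mathbb{N}$ (respectively the odd edges), that is
$$x^{\top} A_1 x = \sum_{i \geq 1} (x_{2i} - x_{2i+1})^2 , \qquad x^{\top} A_2 x = \sum_{i \geq 1} (x_{2i-1} - x_{2i})^2 .$$
Now define $f_u(x) = \frac{n \alpha (\kappa - 1)}{8} \left( x^{\top} A_2 x - 2 x_1 \right) + \frac{\alpha}{2} \|x\|^2$, $f_v(x) = \frac{n \alpha (\kappa - 1)}{8} x^{\top} A_1 x + \frac{\alpha}{2} \|x\|^2$, and $f_w(x) = \frac{\alpha}{2} \|x\|^2$ for any $w \notin \{u, v\}$, so that the average $\bar{f}$ is exactly the serial hard function. The key observation is that node $u$ does not “know” about the even edges until it receives a message from $v$ and vice versa. Thus it is fairly easy to show that in this case one has:
$$x_{w,t,i} = 0 \quad \text{for all } w \in V \text{ and all } i \geq \frac{t}{1 + \Delta \tau} + 2 ,$$
which effectively amounts to a slowdown by a factor $1 + \Delta \tau$ compared to the serial case and proves the lower bound $\Omega\left(\sqrt{\kappa} \, (1 + \Delta \tau) \log(1/\epsilon)\right)$.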
Not so fast!
One can say that the algorithm proposed above defeats a bit the purpose of the distributed setting. Indeed the centralized communication protocol it relies on is not robust to various real-life issues such as machine failures, time-varying graphs, edges with different latency, etc. An elegant and practical solution is to restrict communication to be gossip-like. That is, local computations now have to be communicated via matrix multiplication with a walk matrix $W \in \mathbb{R}^{n \times n}$ which we define as satisfying the following three conditions: (i) $W_{u,v} \neq 0 \Rightarrow u = v \text{ or } (u,v) \in E$, (ii) $\mathrm{ker}(W) = \mathrm{span}(1, \ldots, 1)$, and (iii) $W \succeq 0$. Let us briefly discuss these conditions: (i) simply means that if $x \in \mathbb{R}^n$ represents real values stored at the vertices, then $W x$ can be calculated with a distributed communication protocol; (ii) says that if there is consensus (that is all vertices hold the same value) then no communication occurs with this matrix multiplication; and (iii) will turn out to be natural in just a moment for our algorithm based on duality. A prominent example of a walk matrix would be the (normalized) Laplacian of $G$.
We denote by $\gamma$ the inverse condition number of $W$ on $(1, \ldots, 1)^{\perp}$ (that is the ratio of the smallest non-zero eigenvalue of $W$ to its largest eigenvalue), also known as the spectral gap of $G$ when $W$ is the Laplacian. Notice that $\gamma$ naturally controls the number of gossip steps to reach consensus, in the sense that gossip steps correspond to gradient descent steps on $x \mapsto x^{\top} W x$, which will converge in $O(\gamma^{-1} \log(1/\epsilon))$ steps. Doing an “accelerated gossip” (also known as Chebyshev gossiping) one could thus hope to essentially replace the diameter $\Delta$ by $1/\sqrt{\gamma}$. Notice that this is hopeful thinking because in the centralized model $\Delta$ steps get you to an exact consensus, while in the gossip model one only reaches an $\epsilon$-approximate consensus and errors might compound. In fact with a bit of graph theory one can immediately see that simply replacing $\Delta$ by $1/\sqrt{\gamma}$ is too good to be true: there are graphs (namely expanders) where $\Delta$ is of order $\log(n)$ while $1/\sqrt{\gamma}$ is of order of a constant, and thus an upper bound of the form (say) $O\left(\sqrt{\kappa} \, (1 + \tau/\sqrt{\gamma}) \log(1/\epsilon)\right)$ would violate our previous lower bound by a factor $\log(n)$.
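As a small illustration of gossip as gradient descent, the following Python sketch builds the (unnormalized) graph Laplacian as the walk matrix and repeatedly applies $x \leftarrow x - \eta W x$. The step size $\eta = 1/(2 d_{\max})$ is an assumption chosen here because the largest Laplacian eigenvalue is at most twice the maximum degree; the function names and the 6-cycle example are likewise illustrative.

```python
def laplacian(adj):
    """Dense graph Laplacian (list of rows) for {node: [neighbors]}."""
    nodes = sorted(adj)
    idx = {v: k for k, v in enumerate(nodes)}
    n = len(nodes)
    W = [[0.0] * n for _ in range(n)]
    for v in nodes:
        W[idx[v]][idx[v]] = float(len(adj[v]))
        for w in adj[v]:
            W[idx[v]][idx[w]] = -1.0
    return W


def gossip(W, x, eta, steps):
    """Repeat x <- x - eta * W x; each step only mixes neighboring values."""
    for _ in range(steps):
        Wx = [sum(wij * xj for wij, xj in zip(row, x)) for row in W]
        x = [xi - eta * g for xi, g in zip(x, Wx)]
    return x


# Hypothetical example: a cycle on 6 nodes, every node has degree 2.
cycle = {i: [(i - 1) % 6, (i + 1) % 6] for i in range(6)}
W = laplacian(cycle)
eta = 1.0 / (2 * max(len(nbrs) for nbrs in cycle.values()))
values = gossip(W, [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], eta, steps=200)
```

Because the all-ones vector is in the kernel of $W$, the average of the stored values is preserved at every step, and the iterates approach that average at a rate governed by the spectral gap.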
To save the day we will make extra assumptions, namely that each local function $f_v$ has condition number $\kappa$ and that in addition to computing local gradients the vertices can also compute local gradients of the Fenchel dual functions $f_v^*$. The latter assumption can be removed at the expense of extra logarithmic factors but we will ignore this point here (see the paper for some hints as well as further discussion on this point). For the former assumption we note that the lower bound proof given above completely breaks under this assumption. However one can save the construction for some specific graphs (finding the correct generalization to arbitrary graphs is one of our open problems). For example imagine a line graph, and cluster the vertices into three groups, the first third, the middle, and the last third. Then one could distribute the even part of the Laplacian on $\mathbb{N}$ in the first group, and the odd part on the last group, as well as distribute the Euclidean norm evenly among all vertices. This construction verifies that each vertex function has condition number $O(\kappa)$ and furthermore the rest of the argument still goes through. Interestingly in this case one also has $\Delta = \Theta(1/\sqrt{\gamma})$ and thus this proves that for the line graph one has $T_{\epsilon} = \Omega\left(\sqrt{\kappa} \, (1 + \tau/\sqrt{\gamma}) \log(1/\epsilon)\right)$ for gossip algorithms. We will now show a matching upper bound (which holds for arbitrary graphs).
Dual algorithm
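For $X \in \mathbb{R}^{d \times n}$ (which we think of as a set of column vectors, one for each vertex $v \in V$), denote $X_v$ for the $v^{th}$ column and let $F(X) = \sum_{v \in V} f_v(X_v)$. We are interested in minimizing $F$ under the constraint that all columns are equal, which can be written as $X \sqrt{W} = 0$. By definition of the Fenchel dual $F^*$ and a simple change of variable one has:
$$\min_{X \in \mathbb{R}^{d \times n} : X \sqrt{W} = 0} F(X) = \max_{\Lambda \in \mathbb{R}^{d \times n}} - F^*(\Lambda \sqrt{W}) .$$
Next observe that gradient ascent on $\Lambda \mapsto - F^*(\Lambda \sqrt{W})$ can be written as $\Lambda_{t+1} = \Lambda_t - \eta \nabla F^*(\Lambda_t \sqrt{W}) \sqrt{W}$ and with the notation $Y_t = \Lambda_t \sqrt{W}$ this is simply $Y_{t+1} = Y_t - \eta \nabla F^*(Y_t) W$. Crucially $\nabla F^*(Y_t) W$ exactly corresponds to gossiping the local conjugate gradients (which are also the local models) $X_{t,v} = \nabla f_v^*(Y_{t,v})$. In other words we only have to understand the condition number of the function $\Lambda \mapsto F^*(\Lambda \sqrt{W})$. The beauty of all of this is that this condition number is precisely $\kappa / \gamma$ (i.e. it naturally combines the condition number of the vertex functions with the “condition number” of the graph). Thus by accelerating gradient ascent we arrive at a time complexity of $O\left(\sqrt{\kappa/\gamma} \, (1 + \tau) \log(1/\epsilon)\right)$ (recall that a gossip step takes time $\tau$). We call the corresponding algorithm SSDA (Single-Step Dual Accelerated). One can improve it slightly in the case of low communication cost by doing multiple rounds of communication between two gradient computations (essentially replacing $W$ by $P_K(W)$ for a suitable polynomial $P_K$ of degree $K \approx 1/\sqrt{\gamma}$). We call the corresponding algorithm MSDA (Multi-Step Dual Accelerated) and it attains the optimal (in the worst case over graphs) complexity of $O\left(\sqrt{\kappa} \, (1 + \tau/\sqrt{\gamma}) \log(1/\epsilon)\right)$.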
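To see the dual update in action, here is a minimal Python sketch of the plain (non-accelerated) iteration $Y_{t+1} = Y_t - \eta \nabla F^*(Y_t) W$ in dimension $d = 1$. It is not SSDA or MSDA themselves, since it omits the acceleration. Everything specific is an assumption made for illustration: the local functions are the quadratics $f_v(x) = \frac{a_v}{2}(x - b_v)^2$, for which $\nabla f_v^*(y) = b_v + y/a_v$; the walk matrix is the unnormalized Laplacian; and the step size is $\eta = \min_v a_v / (2 d_{\max})$.

```python
def dual_ascent(adj, a, b, steps=2000):
    """Plain dual gradient ascent for f_v(x) = a_v / 2 * (x - b_v)^2.

    adj: {node: [neighbors]}, a and b: {node: float}. Returns local models.
    """
    nodes = sorted(adj)
    # Conservative step size: (min strong convexity) / (bound on lambda_max(W)).
    eta = min(a[v] for v in nodes) / (2 * max(len(adj[v]) for v in nodes))
    y = {v: 0.0 for v in nodes}  # dual variables Y_t, one per node
    for _ in range(steps):
        # Local models are the conjugate gradients: x_v = grad f_v^*(y_v).
        x = {v: b[v] + y[v] / a[v] for v in nodes}
        # One gossip round with the Laplacian: (W x)_v = sum_w (x_v - x_w).
        wx = {v: sum(x[v] - x[w] for w in adj[v]) for v in nodes}
        y = {v: y[v] - eta * wx[v] for v in nodes}
    return {v: b[v] + y[v] / a[v] for v in nodes}


# Hypothetical example: a path on 5 nodes with heterogeneous quadratics.
path = {i: [j for j in (i - 1, i + 1) if 0 <= j < 5] for i in range(5)}
a = {0: 1.0, 1: 2.0, 2: 3.0, 3: 1.0, 4: 2.0}
b = {0: 0.0, 1: 1.0, 2: 2.0, 3: 3.0, 4: 4.0}
models = dual_ascent(path, a, b)
```

Each node only ever touches its own dual variable and its neighbors' current models, and the local models agree in the limit on the minimizer of $\sum_v f_v$, which for these quadratics is the weighted average $\sum_v a_v b_v / \sum_v a_v$.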