What Does it Take to Train Deep Learning Models On-Device?

Photo by Fort George G. Meade Public Affairs Office

Over the past few weeks, a few different people have asked me about the state of model training on phones and embedded devices. The good news is that it’s definitely possible; I know of multiple examples of teams doing this successfully. The bad news is that our tools don’t yet make it easy.

Back in 2014 I released some code and a video showing how you could do simple transfer learning on a phone. At that point I was using a support vector machine to handle the on-device training of the final layer, but there was no fundamental reason I couldn’t have used back-propagation instead; it was just easier to use a technique I knew. I’m sure there are other examples even earlier than that too.

One area where on-device learning with neural networks has long been common is speech recognition. Even today, your phone will often ask you to say the wake word (for example “OK Google” or “Hey Siri”) a few times, so it can learn your pronunciation. Typically this won’t involve a complete retraining of the network using back-propagation, but something that’s more like transfer learning on either the feature extraction process or the final layer. Apple’s implementation uses a supplemental network to compare utterances to your references in what they term speaker space, so it’s pretty different from how we normally think of training. I still consider these cases valid examples of on-device learning, but I often use the term “personalization” to capture the limited extent of the network changes.

There are situations where the kind of full-network back-propagation familiar to deep learning practitioners building models in Python happens on devices too. A good published example comes from the Google Keyboard application. The paper focuses on the novel aspects of sending anonymized model updates over the network, but those updates come from a traditional process of running forward and backward passes on the phone.

So, corporate teams are building applications that use learning on-device, but how can you do the same thing? The key thing to realize is that there are two different stages to training a model. The first is building the back-propagation machinery that computes the gradients and applies the weight updates. The second is actually feeding examples through the forward pass, and then passing the resulting gradients back through the backward pass to update the weights.
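
To make those two stages concrete, here’s a minimal sketch using the TensorFlow 1.x graph-mode Python API, with a made-up two-class model and random stand-in data (the names “input”, “label”, “train”, and “init” are my own choices for illustration, not anything standard). The optimizer call is the first stage; the training loop is the second:

```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x graph-mode APIs

# Stage one: define the forward pass and a loss, then let the optimizer
# build the back-propagation machinery as extra ops in the same graph.
x = tf.placeholder(tf.float32, [None, 4], name="input")
y = tf.placeholder(tf.float32, [None, 2], name="label")
logits = tf.layers.dense(x, 2, name="logits")
loss = tf.losses.softmax_cross_entropy(onehot_labels=y, logits=logits)
train_op = tf.train.GradientDescentOptimizer(0.01).minimize(loss, name="train")
init_op = tf.variables_initializer(tf.global_variables(), name="init")

# Stage two: feed examples through the forward pass and apply the
# gradient updates by running the training op.
with tf.Session() as sess:
    sess.run(init_op)
    for _ in range(100):
        batch_x = np.random.rand(8, 4).astype(np.float32)  # stand-in data
        batch_y = np.eye(2)[np.random.randint(0, 2, 8)].astype(np.float32)
        sess.run(train_op, feed_dict={x: batch_x, y: batch_y})
```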

Most libraries, including TensorFlow, handle the automatic differentiation you need to build the back-propagation operations in the Python layer. This code looks at the model you’ve defined in terms of inference operations (convolution layers feeding into activation functions, and so on) and builds a complementary set of operations for the back-propagation, including a loss function. The important part is that once the graph is built, at least in TensorFlow, you don’t need the Python code any more. Running the backward pass just means running the sub-graph of operations that was created by the differentiation stage.
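
Continuing the hypothetical sketch above, you can see that the differentiation stage has simply added more nodes to the graph, and you can serialize the whole thing; none of the Python that built it is needed afterwards:

```python
# The gradient ops created by the differentiation stage are ordinary nodes
# in the graph (by default their names start with "gradients/").
graph_def = tf.get_default_graph().as_graph_def()
print([n.name for n in graph_def.node if n.name.startswith("gradients/")])

# Serialize the GraphDef, forward and backward ops together. From here on,
# the original Python model-building code isn't needed.
tf.train.write_graph(graph_def, "/tmp", "train_graph.pb", as_text=False)
```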

In TensorFlow terms, this means after you’ve created the backward pass you can save off the GraphDef containing both the forward and backward operations. If you then load that in C++, either on-device or on a server, theoretically you can also run gradient updates from labeled examples using Session::Run(). I say theoretically because we don’t have documentation on this process, and it involves a lot of unfamiliar steps that can be hard to get right in C++, like initializing variables, loading and saving checkpoints, and handling losses. This is a gap I’m hoping we can fill, but there are internal teams who’ve been able to get past these issues, so at least there’s proof that it’s possible. I’m sure there must be examples of external teams doing similar work, so if you do have code to share, in TensorFlow or any other framework, feel free to add a comment or reply on Twitter. I’d love to share your examples!
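
To give a flavor of what that looks like, here’s a rough sketch that reuses the hypothetical op and tensor names from the graph saved above and leaves out checkpoint handling. It’s written against the Python Session API for brevity, since the C++ Session::Run() calls mirror these almost one-to-one:

```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x

# Load the saved GraphDef into a fresh process, with none of the original
# model-building Python around; this is the situation C++ code is in.
graph_def = tf.GraphDef()
with tf.gfile.GFile("/tmp/train_graph.pb", "rb") as f:
    graph_def.ParseFromString(f.read())

graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name="")

with tf.Session(graph=graph) as sess:
    # Variables have to be explicitly initialized (or restored from a
    # checkpoint) before any training step will run.
    sess.run("init")
    batch_x = np.random.rand(8, 4).astype(np.float32)
    batch_y = np.eye(2)[np.random.randint(0, 2, 8)].astype(np.float32)
    # Fetching the training op by name runs the forward pass, the backward
    # pass, and the weight update in a single call.
    sess.run("train", feed_dict={"input:0": batch_x, "label:0": batch_y})
```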
