Federated Learning: Machine Learning on Decentralized Data
https://www.youtube.com/watch?v=89BGjQYA0uE
1. Decentralized data
Edge devices
can data live at the edge?
On-device inference offers
- improved latency
- works offline
- better battery life
- privacy advantages
But what about analytics? How can a model learn from the data on the edge devices?
⇒ federated learning
Gboard: mobile keyboard
Gboard machine learning
Models are essential for:
- tap typing
- gesture typing
- auto-corrections
- predictions
- voice to text
- and more…
all on-device for latency and reliability
Gboard data: more, better, decentralized
on-device cache of local interactions: touch points, typed text, context, and more.
Used exclusively for federated learning and computation ⇒
Federated Learning: Collaborative Machine Learning without Centralized Training Data
Privacy Principles - like ephemeral reports
- Only-in-aggregation: an engineer may only access combined device reports
- Ephemeral reports: never persist per-device reports
- Focused collection: devices report only what is needed for this computation
- Don't memorize individuals' data ⇒ Differential Privacy: the statistical science of learning common patterns in a dataset without memorizing individual examples
Differential privacy complements federated learning: use noise to obscure an individual’s impact on the learned model.
Privacy Technology
- On-device datasets ⇒ keep raw data local. Expire old data. Encrypt at rest.
- Federated aggregation ⇒ combine reports from multiple devices.
- Secure aggregation: compute a (vector) sum of encrypted device reports, e.g. Gboard word counts.
  It's a practical protocol with
  - security guarantees
  - communication efficiency
  - dropout tolerance
- Federated model averaging: many steps of gradient descent on each device
- Differentially private model averaging (sketched below)
  - devices "clip" their updates if they are too large
  - the server adds noise when combining updates
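A minimal NumPy sketch of the clip-and-noise step just described (the clipping norm, noise multiplier, and toy updates are illustrative assumptions, not values from the talk):

```python
import numpy as np

def clip_update(update, clip_norm=1.0):
    """Device side: scale the update so its L2 norm is at most clip_norm."""
    norm = np.linalg.norm(update)
    return update * min(1.0, clip_norm / max(norm, 1e-12))

def dp_average(updates, clip_norm=1.0, noise_multiplier=0.5, seed=0):
    """Server side: average the clipped updates, then add Gaussian noise scaled
    to one device's maximum possible influence (clip_norm / number of devices)."""
    rng = np.random.default_rng(seed)
    clipped = [clip_update(u, clip_norm) for u in updates]
    mean = np.mean(clipped, axis=0)
    noise = rng.normal(0.0, noise_multiplier * clip_norm / len(updates), size=mean.shape)
    return mean + noise

# Toy per-device updates; the oversized second update gets clipped before averaging.
updates = [np.array([0.2, -0.1]), np.array([3.0, 3.0]), np.array([-0.3, 0.4])]
print(dp_average(updates))
```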
2. Federated computation: Map-reduce for the next generation
Federated computation: only-in-aggregation
on-device datasets
Federated computation: Federated aggregation
- The engineer sets a threshold on the server.
- The threshold is broadcast to the available devices.
- Each device compares the threshold to its local temperature data, producing a 1 or a 0 depending on whether the reading exceeds the threshold.
- The 0s and 1s are then aggregated (in this case, with a federated mean).
- Repeat: each round involves different devices.
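A plain-Python sketch of one such round (the helper names and toy readings are illustrative; in the real system the comparison runs on the devices themselves, and devices are sampled from whichever happen to be available):

```python
import random

def device_report(local_temperature, threshold):
    """On-device step: compare the private reading to the broadcast threshold."""
    return 1.0 if local_temperature > threshold else 0.0

def run_round(device_temperatures, threshold, num_selected=3, seed=0):
    """Server step: broadcast the threshold, collect 0/1 reports, take the mean."""
    rng = random.Random(seed)
    selected = rng.sample(device_temperatures, num_selected)  # different devices each round
    reports = [device_report(t, threshold) for t in selected]
    return sum(reports) / len(reports)  # fraction of selected devices over the threshold

# Toy data: each entry is one device's private temperature reading.
print(run_round([20.1, 18.0, 25.2, 21.0, 23.4], threshold=22.0))
```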
Federated computation challenges
What’s different from datacenter MapReduce
- limited communication
- intermittent compute node availability
- intermittent data availability
- massively distributed ⇒ e.g. sample magnitudes: 1B devices, 10M available, 1,000 selected, ~50 drop out
- privacy preserving aggregation
Round completion rate by hours (US)
- rounds complete faster when more devices available
- device availability changes over the course of a day ⇒ dynamic data availability
Federated computation: secure aggregation, focused collection
Gboard: word counts
relative typing frequency of common words
How to compute the frequency of the words ⇒ focused collection
- Count the target words in device_data (this runs on each device):
      for word in device_data:
          if word in ["hello", "world"]:
              device_update[word] += 1
- Broadcast this counting computation to the available devices, where it runs locally.
- Each device computes its local counts; the server sums them across devices.
- On the server, the engineer sees only the combined counts of all participating devices.
- Repeat.
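A toy sketch of how secure aggregation lets the server learn only the summed word counts (illustrative only: it assumes pairwise random masks that cancel in the sum; the real protocol also uses key agreement, secret sharing, and dropout handling):

```python
import random

WORDS = ["hello", "world"]
MASK_RANGE = 2**20  # work modulo a large number so masks hide the true counts

def local_counts(device_text):
    """Focused collection: each device counts only the target words."""
    return [sum(1 for w in device_text if w == target) for target in WORDS]

def masked_reports(all_counts):
    """Each pair of devices shares a random mask; one adds it, the other subtracts it,
    so the masks cancel in the sum while individual reports look random to the server."""
    n = len(all_counts)
    reports = [list(c) for c in all_counts]
    for i in range(n):
        for j in range(i + 1, n):
            for k in range(len(WORDS)):
                mask = random.randrange(MASK_RANGE)
                reports[i][k] = (reports[i][k] + mask) % MASK_RANGE
                reports[j][k] = (reports[j][k] - mask) % MASK_RANGE
    return reports

device_texts = [["hello", "there"], ["hello", "world", "hello"], ["world"]]
masked = masked_reports([local_counts(t) for t in device_texts])
total = [sum(col) % MASK_RANGE for col in zip(*masked)]  # the server only sees this sum
print(dict(zip(WORDS, total)))  # {'hello': 3, 'world': 2}
```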
3. Federated learning
FL: machine learning on decentralized data
Privacy Technologies - like federated model averaging
Model engineer workflow
- Train & evaluate on cloud data / proxy data (similar to the device data)
- Main training loop on the decentralized data
- Final model validation steps
- Deploy the model to devices for on-device inference
Federated model averaging
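Since the slide only shows the animation, here is a minimal NumPy sketch of the federated averaging idea for a toy linear model (all names and data are illustrative assumptions): each device runs several local gradient steps, and the server takes a weighted average of the resulting models.

```python
import numpy as np

def local_update(weights, x, y, lr=0.1, local_steps=5):
    """Each device takes several gradient-descent steps on its own data."""
    w = weights.copy()
    for _ in range(local_steps):
        grad = 2.0 * x.T @ (x @ w - y) / len(y)  # gradient of the mean squared error
        w = w - lr * grad
    return w

def federated_averaging(global_w, client_data, rounds=20):
    """Server loop: send the model out, collect local updates, average them back."""
    for _ in range(rounds):
        updates = [local_update(global_w, x, y) for x, y in client_data]
        sizes = np.array([len(y) for _, y in client_data], dtype=float)
        global_w = np.average(updates, axis=0, weights=sizes)  # weight by example count
    return global_w

# Toy decentralized data: each client holds a small private (x, y) dataset
# drawn from the same underlying linear model.
rng = np.random.default_rng(0)
true_w = np.array([2.0, -1.0])
clients = []
for n in (5, 8, 12):
    x = rng.normal(size=(n, 2))
    clients.append((x, x @ true_w + 0.01 * rng.normal(size=n)))

print(federated_averaging(np.zeros(2), clients))  # approaches [2, -1]
```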
FL vs. in-datacenter learning
federated computation plus:
- unbalanced data
- self-correlated (non-IID) data
- variable data availability
When does it apply?
most appropriate when:
- on-device data is more relevant than server-side proxy data
- on-device data is privacy sensitive or large
- labels can be inferred naturally from user interaction
Gboard: language modeling
- predict the next word based on typed text so far
- powers the predictions strip
Decentralized data better represents what users really type.
Federated Learning for Mobile Keyboard Prediction
Applied Federated Learning: Improving Google Keyboard Query Suggestions
Federated Learning Of Out-Of-Vocabulary Words
Towards Federated Learning at Scale: System Design
4. TensorFlow Federated
experiment with federated technologies in simulation
What’s in the box
Federated learning (FL) API
- Implementations of federated training/evaluation
- Can be applied to existing TF models/data (see the sketch below)
Federated Core (FC) API
- Allows for expressing new federated algorithms
Local runtime for simulations
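A hedged sketch of the FL API on a toy Keras model, assuming an older TFF release where `tff.learning.build_federated_averaging_process` is still available (newer releases moved this functionality under `tff.learning.algorithms`); the model, data, and hyperparameters are illustrative, not from the talk:

```python
import collections
import tensorflow as tf
import tensorflow_federated as tff

# Toy federated data: a list of tf.data.Datasets, one per simulated client.
def make_client_dataset(seed):
    x = tf.random.stateless_normal((20, 2), seed=(seed, 0))
    y = tf.cast(tf.reduce_sum(x, axis=1, keepdims=True) > 0, tf.float32)
    return tf.data.Dataset.from_tensor_slices(
        collections.OrderedDict(x=x, y=y)).batch(5)

client_data = [make_client_dataset(i) for i in range(3)]

def model_fn():
    keras_model = tf.keras.Sequential(
        [tf.keras.layers.Dense(1, activation='sigmoid', input_shape=(2,))])
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=client_data[0].element_spec,
        loss=tf.keras.losses.BinaryCrossentropy())

process = tff.learning.build_federated_averaging_process(
    model_fn,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.1))

state = process.initialize()
for round_num in range(5):
    state, metrics = process.next(state, client_data)  # metrics: per-round training stats
```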
Federated computation in TFF
- federated "op"
- FederatedType

```python
import tensorflow as tf
import tensorflow_federated as tff

# An abstract specification of a simple distributed system
READINGS_TYPE = tff.FederatedType(tf.float32, tff.CLIENTS)

@tff.federated_computation(READINGS_TYPE)
def get_average_temperature(sensor_readings):
    return tff.federated_mean(sensor_readings)
```
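As a quick check (exact invocation details vary slightly across TFF versions), the computation can be run in the local simulation runtime by passing a plain Python list for the CLIENTS-placed argument:

```python
get_average_temperature([68.5, 70.3, 69.8])  # ≈ 69.53 in local simulation
```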
- tff.federated_broadcast: broadcast the threshold to the devices
- tff.federated_map: map each reading to 1 or 0 depending on whether it exceeds the threshold
- tff.federated_mean: aggregate the results back to the server
```python
THRESHOLD_TYPE = tff.FederatedType(tf.float32, tff.SERVER, all_equal=True)

@tff.federated_computation(READINGS_TYPE, THRESHOLD_TYPE)
def get_fraction_over_threshold(readings, threshold):

    @tff.tf_computation(tf.float32, tf.float32)
    def _is_over_as_float(val, threshold):
        return tf.cast(val > threshold, tf.float32)  # 1.0 if over the threshold, else 0.0

    return tff.federated_mean(
        tff.federated_map(_is_over_as_float,
                          [readings, tff.federated_broadcast(threshold)]))
```
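Again in local simulation (illustrative values), with three client readings and a server-side threshold:

```python
get_fraction_over_threshold([18.0, 23.5, 25.0], 22.0)  # ≈ 0.667
```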
Federated learning and corgis