google-research / federated

A collection of Google research projects related to Federated Learning and Federated Analytics.
Apache License 2.0

StackOverflow NWP centralized dataset consumes > 64 GB of RAM #29

Closed alshedivat closed 3 years ago

alshedivat commented 3 years ago

My machine runs out of RAM when trying to run a centralized baseline on StackOverflow NWP. The following code reproduces the issue (the process gets killed after RAM is exhausted on a machine with 64 GB of RAM):

>>> from utils.datasets import stackoverflow_word_prediction
>>> datasets = stackoverflow_word_prediction.get_centralized_datasets(vocab_size=10000, max_sequence_length=20)

I believe the high memory usage comes from this line, which calls create_tf_dataset_from_all_clients(); that function builds the centralized dataset from tensor slices, so the whole dataset is kept in memory.

Is it possible to somehow create the centralized SO NWP dataset without keeping everything in RAM?
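For illustration, here is a rough, untested sketch of the kind of streaming approach I mean (assuming raw_train is the raw StackOverflow ClientData returned by tff.simulation.datasets.stackoverflow.load_data(); the helper below is hypothetical and not part of this repo). It would be slow, since it iterates clients in Python, but memory stays bounded to one client's examples at a time:

import tensorflow as tf

def make_streaming_centralized_dataset(raw_train):
  # Yield examples one client at a time so that only a single client's data
  # is ever materialized in memory.
  def example_generator():
    for client_id in raw_train.client_ids:
      for example in raw_train.create_tf_dataset_for_client(client_id):
        yield example
  return tf.data.Dataset.from_generator(
      example_generator,
      output_signature=raw_train.element_type_structure)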

wennanzhu commented 3 years ago

Hi alshedivat,

A possible solution suggested by @ZacharyGarrett: you can first select a subset of clients and build the centralized dataset from only those clients, e.g.:

subset_client_data = client_data.from_clients_and_fn(
    client_ids=np.random.choice(raw_train.client_ids, num_validation_examples),
    create_tf_dataset_for_client_fn=client_data.create_tf_dataset_for_client_fn,
).create_tf_dataset_from_all_clients()

This issue might also be related to the TFF and TF versions. Would you share which versions you are using? Also, are you using the code at the head of this repository or an earlier commit?

alshedivat commented 3 years ago

@wennanzhu, thanks for the code snippet! Let me try and see if it can (partially) solve my problem. In the end, though, I'd still like to be able to train a centralized model on the full StackOverflow dataset, not just on a subset of clients.

I'm using the following versions of this repo, TFF, and TF:

By the way, I just checked and I'm getting the same RAM overflow even when running a federated (not centralized!) experiment from within optimization/:

$ bazel run main:federated_trainer -- --task=stackoverflow_nwp --total_rounds=100 --client_optimizer=sgd --client_learning_rate=0.1 --client_batch_size=20 --server_optimizer=sgd --server_learning_rate=1.0 --clients_per_round=10 --client_epochs_per_round=1 --experiment_name=sonwp_fedavg_experiment

wennanzhu commented 3 years ago

Hi @alshedivat , we have reproduced this issue and it seems to be a bug. Thanks for reporting it! We will let you know when it's fixed.

ZacharyGarrett commented 3 years ago

Added a test in TFF to reproduce the slowness in https://github.com/tensorflow/federated/pull/1261

zcharles8 commented 3 years ago

Hi all. This issue seems to be caused by a non-performant implementation of tff.simulation.datasets.ClientData.create_tf_dataset_from_all_clients. A fix at that level is something I'm working on; in the meantime, I submitted a hot fix for SqlClientData and FilePerUserClientData (the types of ClientData used by the pre-canned datasets in tff.simulation), see commit https://github.com/tensorflow/federated/commit/de97c52ca835b5557c0c1c9dbd768f742a2739f5.
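To sketch the general idea (this is not the actual TFF change, just an illustration of avoiding from_tensor_slices over the full corpus): keep only the client ids in memory and let tf.data build each client's dataset lazily, e.g. via flat_map, provided the per-client construction can be expressed as a serializable tf.data pipeline. The function names below are hypothetical:

import tensorflow as tf

def dataset_from_all_clients(client_ids, serializable_dataset_fn):
  # serializable_dataset_fn maps a scalar string client-id tensor to that
  # client's tf.data.Dataset; it must be traceable by tf.data (no Python state).
  client_id_ds = tf.data.Dataset.from_tensor_slices(client_ids)
  # flat_map streams one client's examples at a time rather than stacking
  # every example of every client into memory up front.
  return client_id_ds.flat_map(serializable_dataset_fn)

Presumably this style of lazy construction is easier for file- and SQL-backed ClientData implementations, which is why the hot fix starts there.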

Please let me know if this does not resolve the issue; I'll leave this open in the meantime. For posterity's sake: this turned out to be a bug in TFF, not in federated_research, but I think this is a useful place for the discussion regardless.

alshedivat commented 3 years ago

@zcharles8, thanks for the fix! I can confirm that both centralized and federated training on StackOverflow no longer run into memory issues after upgrading TFF to tensorflow-federated-0.18.0.dev20210407. RAM consumption is now stable at around 3-5 GB.
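In case it helps anyone else verifying the fix, here is a quick way to eyeball RAM usage during dataset construction (assuming the optional psutil package is installed; it is not a dependency of this repo):

import psutil
from utils.datasets import stackoverflow_word_prediction

def rss_gb():
  # Resident set size of the current process, in gigabytes.
  return psutil.Process().memory_info().rss / 1e9

print(f'RSS before: {rss_gb():.1f} GB')
datasets = stackoverflow_word_prediction.get_centralized_datasets(
    vocab_size=10000, max_sequence_length=20)
print(f'RSS after building the centralized datasets: {rss_gb():.1f} GB')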

zcharles8 commented 3 years ago

Thanks @alshedivat. I'll close this bug for now, but please re-open if something seems amiss.