google / fedjax

FedJAX is a JAX-based open source library for Federated Learning simulations that emphasizes ease-of-use in research.
Apache License 2.0

Support for gldv2 and inaturalist datasets #241

Open marcociccone opened 2 years ago

marcociccone commented 2 years ago

I think it would be great to port these datasets from tff to fedjax. I would be happy to make the effort and contribute to the library, but I need a bit of support from the fedjax team 🙂

By looking at the tff codebase (gldv2, inaturalist), it looks like the load_data_from_cache function creates a TFRecord file for each client.

The only concrete classes that I see are SQLiteFederatedData and InMemoryFederatedData, but I don't think they are meant for this use case. What would be the best way to map the clients into a FederatedData? We could replicate something like TFF's FilePerUserClientData.

Thanks!

stheertha commented 2 years ago

Thanks for reaching out. Adding these datasets sounds interesting, and we're happy to support you.

FedJAX uses SQLite files instead of TFRecord files. One way to add support for these datasets would be as follows:

  1. Create a script to download the datasets from TFF and postprocess them with the scripts you mentioned.
  2. Use SQLiteFederatedDataBuilder to build an SQLite dataset file. Note that SQLiteFederatedDataBuilder takes as input client_ids and their corresponding client examples. Hence, as long as we are able to read and parse the client examples from TFF's TFRecords, it can be used in conjunction with SQLiteFederatedDataBuilder (see the sketch below).
  3. Cache the SQLite file for further use.

Some of the utilities in https://github.com/google/fedjax/pull/216 might be useful.
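
For concreteness, here is a rough sketch of steps 1–3. The feature names in the parsing spec and the exact SQLiteFederatedDataBuilder usage (context manager with an `add(client_id, examples)` method, as described above) are assumptions to verify against the cached TFF files and the current fedjax API:

```python
import collections

import numpy as np
import tensorflow as tf
import fedjax

# Assumed feature spec; adjust to the actual schema of the cached TFF TFRecords.
FEATURES = {
    'image/encoded_jpeg': tf.io.FixedLenFeature([], tf.string),
    'class': tf.io.FixedLenFeature([], tf.int64),
}


def read_client_examples(tfrecord_path):
  """Parses one client's TFRecord file into a dict of numpy arrays."""
  images, labels = [], []
  for record in tf.data.TFRecordDataset(tfrecord_path):
    parsed = tf.io.parse_single_example(record, FEATURES)
    images.append(tf.io.decode_jpeg(parsed['image/encoded_jpeg']).numpy())
    labels.append(parsed['class'].numpy())
  # NOTE: stacking only works once all images share a shape (see the discussion below).
  return collections.OrderedDict(x=np.stack(images), y=np.array(labels))


def build_sqlite(client_tfrecords, output_path):
  """client_tfrecords: mapping from client id to its cached TFRecord path."""
  with fedjax.SQLiteFederatedDataBuilder(output_path) as builder:
    for client_id, path in sorted(client_tfrecords.items()):
      builder.add(client_id.encode('utf-8'), read_client_examples(path))
```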

Please let us know if you have any questions.

marcociccone commented 2 years ago

Hi @stheertha and thanks for the support! I'll have a look at the SQLite stuff and get back to you.

I went a different way for the moment (just to play around) by creating a TFRecordFederatedDataset class that extends InMemoryFederatedData. Basically, I read the TFRecords (one per client) and map them to client_data.Examples to see if that would work.

I don't think that's the way to go, mainly because the images have different shapes and I can't create the numpy arrays this way. I'm still looking into it, but do you think that would be an issue with SQLiteFederatedData too?
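
For what it's worth, the failure mode isn't specific to fedjax: per-client examples are stored as fixed-shape numpy arrays, and ragged images can't be stacked until they are padded or resized. A minimal illustration (the sizes are made up):

```python
import numpy as np

# Two images from the same client with different heights/widths.
imgs = [np.zeros((640, 480, 3), np.uint8), np.zeros((800, 600, 3), np.uint8)]

try:
  np.stack(imgs)  # what a fixed-shape, in-memory representation would need
except ValueError as err:
  print('cannot stack ragged images:', err)

# Zero-padding each image onto a common 800x800 canvas makes stacking possible.
padded = [np.pad(im, ((0, 800 - im.shape[0]), (0, 800 - im.shape[1]), (0, 0)))
          for im in imgs]
batch = np.stack(padded)  # shape (2, 800, 800, 3)
```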

kho commented 2 years ago

Hi @marcociccone! I am not very familiar with these two datasets. By "images have different shapes" do you mean images in these datasets are not already transformed into a uniform height/width? How about images belonging to the same client? Can they be different in height/width too?

In JAX in general, we need to keep the set of possible input shapes small to avoid repetitive XLA compilations (each unique input shape configuration requires its own XLA compilation), so padding or some other type of transformation is needed to ensure uniform input shapes.
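
To make that concrete, here is a tiny example of the recompilation behavior (any jitted function will do):

```python
import jax
import jax.numpy as jnp


@jax.jit
def total_intensity(images):
  return jnp.sum(images)


total_intensity(jnp.zeros((8, 224, 224, 3)))  # first call: one XLA compilation
total_intensity(jnp.zeros((8, 224, 224, 3)))  # same shape: reuses the cached program
total_intensity(jnp.zeros((8, 256, 256, 3)))  # new shape: triggers another compilation
```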

The main problem with deciding what to do with images of different shapes is first deciding how models consume them, so that we can choose an appropriate storage format (i.e., either padding or resizing). I am not very knowledgeable about image models. How do they deal with a batch of images that are in different shapes? Based on my limited understanding of a Conv layer, won't different input shapes lead to different output shapes after a Conv layer? What will an output layer do in that case?

marcociccone commented 2 years ago

Hi @kho! By looking at the images in the TFRecord of a randomly sampled client, I see that they have different heights/widths. Images should then be randomly cropped to 299x299 (inaturalist) or 224x224 (gldv2) before being batched together and consumed by the neural network (see Sec. 6.5 of the ECCV 2020 paper proposing these two datasets for more details).

This is a standard data augmentation practice for image datasets, used to increase the variability of the data. See also this input data pipeline for ImageNet as a reference.
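
Roughly, the per-image step of such a pipeline looks like this (the 299x299 size is the inaturalist setting mentioned above; the feature key is just a placeholder):

```python
import tensorflow as tf


def random_crop(example):
  # Assumes the image was already decoded into an HxWx3 uint8 tensor.
  example['image'] = tf.image.random_crop(example['image'], size=[299, 299, 3])
  return example

# For one client's tf.data.Dataset of example dicts, something like:
# train_ds = client_ds.map(random_crop).shuffle(1024).batch(32)
```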

Do you think that doing something like that would be possible with the current fedjax data pipeline?

stheertha commented 2 years ago

Thanks for the clarification. This is supported by the FedJAX pipeline and can be done in two ways:

Option 1: When converting the TFRecords to SQLite via SQLiteFederatedDataBuilder, do some offline processing to make all images the same shape, for example via zero padding. This ensures everything can be stored in numpy format. Then, during training and evaluation, use BatchPreprocessor to do the preprocessing (e.g., random cropping) required for both the train and test datasets.

Option 2: When converting the TFRecords to SQLite via SQLiteFederatedDataBuilder, do the entire preprocessing offline, including random cropping, and then store the processed images, all of which have the shape (299, 299). This way the stored dataset is already preprocessed and can be used directly without additional preprocessing.
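
A rough sketch of the training-time half of Option 1, assuming images were zero-padded to a fixed size when the SQLite file was built and are stored under an 'x' key (whether to hook this in via preprocess_batch or BatchPreprocessor is worth double-checking against the current API):

```python
import numpy as np


def random_crop_batch(examples, crop=224):
  """Randomly crops a batch of zero-padded images stored under the 'x' key."""
  images = examples['x']                     # assumed shape (batch, 800, 800, 3)
  _, height, width, _ = images.shape
  # Same crop offsets for the whole batch, for simplicity.
  top = np.random.randint(0, height - crop + 1)
  left = np.random.randint(0, width - crop + 1)
  cropped = images[:, top:top + crop, left:left + crop, :]
  return {**examples, 'x': cropped}

# Applied at batching time, e.g.:
# train_fd = federated_data.preprocess_batch(random_crop_batch)
```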

marcociccone commented 2 years ago

Thanks for your answer! I think option 1 is the way to go to ensure enough data variability (multiple crops for each image). However, gldv2 image sizes span from 300 to 800 pixels, and some clients have up to 1K images, so zero-padding to the max shape (800x800) and storing the np.array in memory would require around 14 GB. Also, I should check that heavily padded images aren't mostly empty when cropped.

I still need to check the codebase carefully, but what if we create a TfDataClientDataset class that iterates over the tf.data object rather than the np.array? This would be more memory-efficient and would let us take advantage of the tf.data input pipeline.

kho commented 2 years ago

Sorry about the confusion, I didn't know the datasets were this big (I should have read the READMEs more carefully). Could you help me run some quick stats on gldv2? That will help me figure out whether putting everything inside a SQLite database is feasible.

Regarding your proposal of wrapping tf.data: there is a very significant overhead in iterator creation for tf.data.Dataset, which becomes problematic in federated learning since we need to create many dataset iterators during a single experiment. However, the calculus might be different if an individual client is big enough. The stats above will also help in evaluating that.

I also have one question about how people outside Google usually work with such big datasets. Are the files actually stored on local disks, some NFS volume, or some other distributed file system (e.g. GCS or S3)?