marcociccone opened this issue 2 years ago
Thanks for reaching out. Adding these datasets sounds interesting, and we are happy to support you.
FedJAX uses SQLite files instead of TFRecord files. One way to add support for these datasets would be as follows:
Some of the utilities in https://github.com/google/fedjax/pull/216 might be useful.
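To give a rough idea of the conversion step, here is a minimal sketch that writes each client's examples into one SQLite file, using only the standard `sqlite3` module and NumPy. The table layout and the `write_clients`/`read_client` helpers are hypothetical and are NOT FedJAX's actual on-disk schema. FedJAX's own builder (`SQLiteFederatedDataBuilder`) is the intended tool for this.

```python
import io
import sqlite3
from typing import Dict, Iterable, Tuple

import numpy as np

# Hypothetical schema: one row per client holding all of that client's
# features serialized with np.savez. This is NOT FedJAX's on-disk format.
_SCHEMA = "CREATE TABLE IF NOT EXISTS clients (client_id BLOB PRIMARY KEY, data BLOB)"


def _serialize(features: Dict[str, np.ndarray]) -> bytes:
    buffer = io.BytesIO()
    np.savez(buffer, **features)
    return buffer.getvalue()


def _deserialize(blob: bytes) -> Dict[str, np.ndarray]:
    with np.load(io.BytesIO(blob)) as npz:
        return {name: npz[name] for name in npz.files}


def write_clients(path: str,
                  clients: Iterable[Tuple[bytes, Dict[str, np.ndarray]]]) -> None:
    """Stores (client_id, features) pairs, e.g. parsed from one TFRecord per client."""
    conn = sqlite3.connect(path)
    try:
        with conn:  # commits the transaction on success
            conn.execute(_SCHEMA)
            conn.executemany(
                "INSERT INTO clients VALUES (?, ?)",
                ((client_id, _serialize(features)) for client_id, features in clients))
    finally:
        conn.close()


def read_client(path: str, client_id: bytes) -> Dict[str, np.ndarray]:
    """Loads one client's features back into numpy arrays."""
    conn = sqlite3.connect(path)
    try:
        row = conn.execute("SELECT data FROM clients WHERE client_id = ?",
                           (client_id,)).fetchone()
    finally:
        conn.close()
    if row is None:
        raise KeyError(client_id)
    return _deserialize(row[0])
```

Storing a whole client per row keeps reads to one query per client, which matches how federated training samples data client by client.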
Please let us know if you have any questions.
Hi @stheertha, and thanks for the support! I'll have a look at the SQLite stuff and get back to you.
I went a different way for the moment (just to play around) by creating a TFRecordFederatedDataset class that extends InMemoryFederatedData. Basically, I just read the TFRecords (one per client) and map them to client_data.Examples, just to see if that would work.
I don't think that's the way to go, mainly because images have different shapes, so I can't stack them into numpy arrays this way. I'm still looking into it, but do you think that would also be an issue with SQLiteFederatedData?
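The shape problem can be reproduced with plain NumPy. Images of different sizes cannot be stacked into one dense array, but zero-padding them to a common size first makes stacking possible. The `pad_to` helper below is hypothetical.

```python
import numpy as np


def pad_to(image: np.ndarray, height: int, width: int) -> np.ndarray:
    """Zero-pads an (H, W, C) image at the bottom/right to (height, width, C)."""
    h, w = image.shape[:2]
    if h > height or w > width:
        raise ValueError(f"image {h}x{w} larger than target {height}x{width}")
    return np.pad(image, ((0, height - h), (0, width - w), (0, 0)))


images = [np.ones((300, 400, 3), np.uint8), np.ones((500, 350, 3), np.uint8)]

try:
    np.stack(images)  # ragged shapes: NumPy raises ValueError
except ValueError as err:
    print("stack failed:", err)

max_h = max(im.shape[0] for im in images)
max_w = max(im.shape[1] for im in images)
batch = np.stack([pad_to(im, max_h, max_w) for im in images])
print(batch.shape)  # (2, 500, 400, 3)
```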
Hi @marcociccone! I am not very familiar with these two datasets. By "images have different shapes" do you mean images in these datasets are not already transformed into a uniform height/width? How about images belonging to the same client? Can they be different in height/width too?
In JAX in general, we need to keep the set of possible input shapes small to avoid repeated XLA compilations (each unique input shape configuration requires its own XLA compilation), so padding or some other type of transformation is needed to ensure uniform input shapes.
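To make the "small set of input shapes" point concrete: under `jax.jit`, every distinct input shape means a new trace and compilation. Rounding image sizes up to a few fixed buckets caps that number. The sketch below only counts distinct (H, W) shapes with and without bucketing and does not call JAX. The bucket sizes are arbitrary assumptions.

```python
import bisect
from typing import List, Sequence, Tuple

BUCKETS = (256, 512, 800)  # hypothetical padding targets, in pixels


def bucket(size: int, buckets: Sequence[int] = BUCKETS) -> int:
    """Rounds size up to the smallest bucket that can hold it."""
    i = bisect.bisect_left(buckets, size)
    if i == len(buckets):
        raise ValueError(f"size {size} exceeds largest bucket {buckets[-1]}")
    return buckets[i]


def distinct_shapes(shapes: List[Tuple[int, int]], use_buckets: bool) -> int:
    """Number of distinct (H, W) shapes, i.e. roughly the number of compilations."""
    if use_buckets:
        shapes = [(bucket(h), bucket(w)) for h, w in shapes]
    return len(set(shapes))


shapes = [(300, 400), (500, 350), (640, 480), (300, 401), (799, 600)]
print(distinct_shapes(shapes, use_buckets=False))  # 5
print(distinct_shapes(shapes, use_buckets=True))   # 3
```

The trade-off is wasted compute on padding versus fewer compilations. Coarser buckets mean fewer compilations but more padding.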
The main problem with images of different shapes is first deciding how models consume them, so that we can choose an appropriate storage format (i.e. either padding or resizing). I am not very knowledgeable about image models. How do they deal with a batch of images in different shapes? Based on my limited understanding of a Conv layer, won't different input shapes lead to different output shapes after a Conv layer? What will an output layer do in that case?
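On the Conv question: yes, a convolution's spatial output size depends on its input size, via out = floor((in + 2 * padding - kernel) / stride) + 1. Many image classifiers end with a global pooling layer that averages over the spatial dimensions. As a result, the output layer sees a fixed-size feature vector regardless of input size. A batch still needs a single shape, though. A small NumPy sketch of both facts follows. The 7x7/stride-2/padding-3 layer is just an example configuration.

```python
import numpy as np


def conv_output_size(size: int, kernel: int, stride: int = 1, padding: int = 0) -> int:
    """Spatial output size of a convolution (no dilation)."""
    return (size + 2 * padding - kernel) // stride + 1


# The same 7x7, stride-2, padding-3 layer applied to two input sizes.
print(conv_output_size(224, kernel=7, stride=2, padding=3))  # 112
print(conv_output_size(299, kernel=7, stride=2, padding=3))  # 150


def global_average_pool(features: np.ndarray) -> np.ndarray:
    """Averages (N, H, W, C) feature maps over H and W, giving (N, C)."""
    return features.mean(axis=(1, 2))


a = global_average_pool(np.ones((1, 7, 7, 64)))
b = global_average_pool(np.ones((1, 10, 10, 64)))
print(a.shape, b.shape)  # (1, 64) (1, 64)
```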
Hi @kho! Looking at the images in the TFRecord of a randomly sampled client, I see that they have different heights and widths. Images should then be randomly cropped to 299x299 (inaturalist) or 224x224 (gldv2) before being batched together and consumed by the neural network (see Sec. 6.5 of the ECCV 2020 paper proposing these two datasets for more details).
This is standard data augmentation practice for image datasets, used to increase the variability of the training data. See also this input data pipeline for ImageNet as a reference.
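A minimal NumPy sketch of the per-image random crop described above. It assumes (H, W, C) uint8 images that are at least as large as the crop. It does not show how the paper's pipeline handles smaller images or combines cropping with resizing, and `random_crop` is a hypothetical helper.

```python
import numpy as np


def random_crop(image: np.ndarray, size: int, rng: np.random.Generator) -> np.ndarray:
    """Returns a random size x size crop of an (H, W, C) image with H, W >= size."""
    h, w = image.shape[:2]
    if h < size or w < size:
        raise ValueError(f"image {h}x{w} is smaller than crop {size}")
    top = rng.integers(0, h - size + 1)   # upper bound is exclusive
    left = rng.integers(0, w - size + 1)
    return image[top:top + size, left:left + size]


rng = np.random.default_rng(0)
images = [np.zeros((480, 640, 3), np.uint8), np.zeros((300, 800, 3), np.uint8)]
batch = np.stack([random_crop(im, 299, rng) for im in images])
print(batch.shape)  # (2, 299, 299, 3)
```

Because the crop offset is redrawn every time, each pass over a client's data sees different views of the same images.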
Do you think that doing something like that would be possible with the current fedjax data pipeline?
Thanks for the clarification. This is supported by the FedJAX pipeline and can be done in two ways:
Option 1: When converting the TFRecords to SQLite with SQLiteFederatedDataBuilder, do some offline processing to make all images the same shape, for example via zero padding. This ensures everything can be stored in numpy format. Then, during training and evaluation, use BatchPreprocessor to do the required preprocessing (e.g., random cropping) for both the train and test datasets.
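A minimal sketch of the batch-level function Option 1 would need, written in plain NumPy. It assumes hypothetical features: `'x'`, zero-padded images of shape (N, H_max, W_max, C), plus `'height'` and `'width'` recording each image's original size. With those sizes available, crops are drawn from real pixels rather than padding. `make_random_crop_fn` is hypothetical. Wrapping it in FedJAX's `BatchPreprocessor` is the intended use, but that wiring is not shown here.

```python
from typing import Callable, Dict

import numpy as np

Examples = Dict[str, np.ndarray]  # a batch of features keyed by name


def make_random_crop_fn(size: int, seed: int = 0) -> Callable[[Examples], Examples]:
    """Builds a batch function that randomly crops zero-padded images.

    Assumes H_max, W_max >= size. Crops stay inside each image's original
    (height, width) region whenever that region is at least size x size.
    """
    rng = np.random.default_rng(seed)

    def crop(batch: Examples) -> Examples:
        images, heights, widths = batch['x'], batch['height'], batch['width']
        out = np.empty((len(images), size, size, images.shape[-1]), images.dtype)
        for i, (h, w) in enumerate(zip(heights, widths)):
            top = rng.integers(0, max(int(h) - size, 0) + 1)
            left = rng.integers(0, max(int(w) - size, 0) + 1)
            out[i] = images[i, top:top + size, left:left + size]
        result = dict(batch)  # keep other features such as labels
        result['x'] = out
        return result

    return crop


padded = np.zeros((2, 800, 800, 3), np.uint8)
padded[0, :300, :400] = 1  # real pixels of a 300x400 image
padded[1, :500, :350] = 1  # real pixels of a 500x350 image
batch = {'x': padded, 'height': np.array([300, 500]),
         'width': np.array([400, 350]), 'y': np.array([3, 7])}
cropped = make_random_crop_fn(224)(batch)
print(cropped['x'].shape)  # (2, 224, 224, 3)
```

Recording the original sizes also addresses the worry that crops from heavily padded images could be mostly empty.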
Option 2: When converting the TFRecords to SQLite with SQLiteFederatedDataBuilder, do the entire preprocessing offline, including random cropping, and then store the processed images, all of which have the shape (299, 299). This way the stored dataset is already preprocessed and can be used directly without additional preprocessing.
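For Option 2, the offline step amounts to cropping every image once and storing a dense fixed-shape array per client. A minimal sketch, assuming uint8 images at least as large as the crop; `preprocess_client_offline` is a hypothetical helper. Note that each image then keeps a single fixed crop for the whole experiment.

```python
from typing import Iterable

import numpy as np


def preprocess_client_offline(images: Iterable[np.ndarray], size: int = 299,
                              seed: int = 0) -> np.ndarray:
    """Crops each (H, W, C) uint8 image once and stacks them into (N, size, size, C)."""
    rng = np.random.default_rng(seed)
    crops = []
    for image in images:
        h, w = image.shape[:2]
        top = rng.integers(0, h - size + 1)
        left = rng.integers(0, w - size + 1)
        crops.append(image[top:top + size, left:left + size])
    return np.stack(crops)


client_images = [np.zeros((480, 640, 3), np.uint8) for _ in range(10)]
stored = preprocess_client_offline(client_images)
print(stored.shape, stored.nbytes)  # (10, 299, 299, 3) 2682030
```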
Thanks for your answer! I think option 1 is the way to go to ensure enough data variability (multiple crops for each image). However, gldv2 image sizes range from 300 to 800 pixels, and some clients have up to 1K images, so zero-padding to the max shape (800x800) and storing the np.array in memory would require around 14 GB. Also, I should check that heavily padded images aren't mostly empty when cropped.
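The memory estimate depends heavily on the dtype. The back-of-the-envelope check below uses the numbers quoted above: one client with 1,000 RGB images zero-padded to 800x800. The dtypes are the assumption. At 8 bytes per value (NumPy's default float64), the total is about 14.3 GiB, consistent with the ~14 GB figure. Storing raw uint8 pixels needs about 1.8 GiB.

```python
import numpy as np


def padded_client_bytes(num_images: int, height: int, width: int,
                        channels: int = 3, dtype=np.uint8) -> int:
    """Bytes needed to hold num_images images zero-padded to height x width."""
    return num_images * height * width * channels * np.dtype(dtype).itemsize


GIB = 2 ** 30
for dtype in (np.uint8, np.float32, np.float64):
    size = padded_client_bytes(1000, 800, 800, dtype=dtype)
    print(f"{np.dtype(dtype).name}: {size / GIB:.1f} GiB")
# uint8: 1.8 GiB
# float32: 7.2 GiB
# float64: 14.3 GiB
```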
I still need to check the codebase carefully, but what if we create a TfDataClientDataset class that iterates over the tf.data object rather than the np.array? This would be more efficient in terms of memory and would allow us to take advantage of the tf.data input pipeline.
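A rough, framework-free sketch of the idea: keep only an iterator factory per client and create the iterator on demand. Images are then decoded and cropped batch by batch instead of being held as one padded array. A plain Python generator stands in for a tf.data pipeline here. `LazyClientDataset` and `fake_client_examples` are hypothetical names, not existing FedJAX or TFF APIs.

```python
from typing import Callable, Dict, Iterator, List

import numpy as np

Example = Dict[str, np.ndarray]


def _stack(examples: List[Example]) -> Example:
    return {k: np.stack([ex[k] for ex in examples]) for k in examples[0]}


class LazyClientDataset:
    """Yields batches from a per-client example iterator, created on demand."""

    def __init__(self, make_iterator: Callable[[], Iterator[Example]]):
        self._make_iterator = make_iterator

    def batch(self, batch_size: int) -> Iterator[Example]:
        buffer: List[Example] = []
        for example in self._make_iterator():  # fresh iterator per pass
            buffer.append(example)
            if len(buffer) == batch_size:
                yield _stack(buffer)
                buffer = []
        if buffer:  # final partial batch
            yield _stack(buffer)


def fake_client_examples() -> Iterator[Example]:
    # Stand-in for decoding and cropping images from one client's TFRecord.
    for i in range(5):
        yield {'x': np.full((224, 224, 3), i, np.uint8), 'y': np.array(i)}


ds = LazyClientDataset(fake_client_examples)
print([b['x'].shape for b in ds.batch(2)])
```

The last batch is smaller, which in JAX would mean an extra compilation unless it is padded or dropped. Creating the iterator per pass is also exactly where tf.data's iterator-creation overhead would show up.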
Sorry about the confusion; I didn't know the datasets were this big (I should have read the READMEs more carefully). Could you help me run some quick stats on gldv2? That will help me figure out whether putting everything inside a SQLite database is feasible.
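The per-client statistics that would help answer the feasibility question (number of clients, images per client, image sizes, total unpadded uint8 storage) can be computed with a small helper like the one below. The input format, a mapping from client id to a list of (height, width) pairs, is an assumption. The toy values are made up purely to exercise the code.

```python
from typing import Dict, List, Tuple

import numpy as np


def client_stats(sizes: Dict[str, List[Tuple[int, int]]],
                 channels: int = 3) -> Dict[str, float]:
    """Summary stats from per-client lists of (height, width) image sizes."""
    counts = np.array([len(v) for v in sizes.values()])
    all_hw = np.array([hw for v in sizes.values() for hw in v])
    return {
        'num_clients': len(sizes),
        'num_images': int(counts.sum()),
        'images_per_client_median': float(np.median(counts)),
        'images_per_client_max': int(counts.max()),
        'max_height': int(all_hw[:, 0].max()),
        'max_width': int(all_hw[:, 1].max()),
        # Bytes if every image is stored unpadded as uint8.
        'total_uint8_bytes': int((all_hw[:, 0] * all_hw[:, 1]).sum() * channels),
    }


toy = {'a': [(300, 400), (800, 600)], 'b': [(500, 500)], 'c': [(640, 480)] * 3}
print(client_stats(toy))
```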
Regarding your proposal of wrapping tf.data, there is a very significant overhead in iterator creation in tf.data.Dataset, which becomes problematic in federated learning since we need to create many dataset iterators during a single experiment. However, the calculus might be different if an individual client is big enough. The stats above will also help in evaluating that.
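One way to frame that trade-off is a simple cost model. This is an illustrative assumption, not a measurement of tf.data. Let c be the fixed cost of creating one dataset iterator, t the per-example processing cost, and n_k the number of examples on client k. The fraction of that client's time spent on iterator creation is then:

```latex
f_k = \frac{c}{c + n_k\,t},
\qquad \text{e.g. } n_k\,t = 99\,c \;\Rightarrow\; f_k = \frac{c}{100\,c} = 1\%.
```

The overhead therefore dominates when there are many small clients and becomes negligible once n_k t is much larger than c. This is why the per-client sizes matter for the decision.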
I also have one question about how people outside Google usually work with such big datasets. Are the files actually stored on local disks, some NFS volume, or some other distributed file system (e.g. GCS or S3)?
I think it would be great to port these datasets from tff to fedjax. I would be happy to make the effort and contribute to the library, but I need a bit of support from the fedjax team 🙂
Looking at the TFF codebase (gldv2, inaturalist), it looks like the load_data_from_cache function creates a TFRecord file for each client.
The only concrete classes that I see are SQLiteFederatedData and InMemoryFederatedData, but I don't think they are meant for this use case. What would be the best way to map the clients into a FederatedDataset? We could replicate something like FilePerUserClientData. Thanks!
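For reference, the file-per-client idea boils down to a mapping from client id to file path, with each client's data loaded only when it is requested. The framework-free sketch below uses one `.npz` file per client as a stand-in for one TFRecord per client. `FilePerClientData` and its directory layout are hypothetical, not an existing FedJAX class. A real integration would presumably implement FedJAX's FederatedData interface instead.

```python
import os
from typing import Dict, Iterator, Tuple

import numpy as np


class FilePerClientData:
    """Maps client ids to per-client files and loads a client only on access."""

    def __init__(self, client_paths: Dict[bytes, str]):
        self._client_paths = dict(client_paths)

    @classmethod
    def from_directory(cls, directory: str, suffix: str = '.npz') -> 'FilePerClientData':
        # Hypothetical layout: <directory>/<client_id><suffix>.
        paths = {
            name[:-len(suffix)].encode(): os.path.join(directory, name)
            for name in sorted(os.listdir(directory)) if name.endswith(suffix)
        }
        return cls(paths)

    def num_clients(self) -> int:
        return len(self._client_paths)

    def client_ids(self) -> Iterator[bytes]:
        return iter(sorted(self._client_paths))

    def get_client(self, client_id: bytes) -> Dict[str, np.ndarray]:
        # Only this client's file is read; other clients stay on disk.
        with np.load(self._client_paths[client_id]) as npz:
            return {k: npz[k] for k in npz.files}

    def clients(self) -> Iterator[Tuple[bytes, Dict[str, np.ndarray]]]:
        for client_id in self.client_ids():
            yield client_id, self.get_client(client_id)
```

Only the id-to-path index lives in memory, so the dataset's total size is bounded by disk rather than RAM.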