google-research / big_vision

Official codebase used to develop Vision Transformer, SigLIP, MLP-Mixer, LiT and more.
Apache License 2.0

Contrastive Input Pipeline #75

Closed: philippe-eecs closed this issue 8 months ago

philippe-eecs commented 8 months ago

Hi, this is an amazing codebase for big-vision tasks!

I wanted to re-implement MoCo, but I was unsure how to modify the input pipeline and data augmentation so that independent random augmentations are applied to the same image. Is there a simple way to do this?

My current implementation applies the augmentations after the batch of images leaves the input pipeline (with no augmentation inside the pipeline), but this requires me to write new data augmentation functions in JAX, which isn't ideal. Do you have any ideas on how to integrate this into the input pipeline?
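For context, the workaround described above (augmenting in JAX after batching) might look roughly like this minimal sketch. It assumes NHWC image batches and uses a random horizontal flip as a stand-in for real augmentations; random_flip and two_views are hypothetical names. The point it illustrates is that every augmentation has to be rewritten in JAX, with one PRNG key per image and per view to keep the two views independent.

```python
import jax
import jax.numpy as jnp


def random_flip(key, image):
  """Flips one HWC image left-right with probability 0.5 (hypothetical helper)."""
  do_flip = jax.random.bernoulli(key, 0.5)
  return jnp.where(do_flip, image[:, ::-1, :], image)


def two_views(key, images):
  """Returns two independently flipped copies of an NHWC batch."""
  key1, key2 = jax.random.split(key)
  n = images.shape[0]
  # One key per image and per view, so every flip decision is independent.
  keys1 = jax.random.split(key1, n)
  keys2 = jax.random.split(key2, n)
  view1 = jax.vmap(random_flip)(keys1, images)
  view2 = jax.vmap(random_flip)(keys2, images)
  return view1, view2
```

Every additional augmentation (crops, color jitter, blur) would need a similar JAX port, which is what motivates moving the work into the tf.data pipeline instead.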

Any help would be appreciated and let me know if you need more information.

akolesnikoff commented 8 months ago

I think the easiest solution is to hack input_pipeline a bit and duplicate each image inside the input pipeline. The flat_map operation can be used for this; see the relevant question on Stack Overflow: https://stackoverflow.com/questions/61754089/how-to-repeat-a-tf-data-dataset
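A minimal sketch of what flat_map does here, using a toy tf.data.Dataset.range in place of real images (the toy dataset is an assumption for illustration): each element is expanded into a tiny dataset holding two copies of it, and flat_map concatenates those small datasets in order.

```python
import tensorflow as tf


def duplicate_element(x):
  """Turns one element into a dataset that yields it twice in a row."""
  return tf.data.Dataset.from_tensors(x).repeat(2)


ds = tf.data.Dataset.range(3)
# flat_map concatenates the per-element datasets, so each element
# appears twice back to back: 0, 0, 1, 1, 2, 2.
doubled = ds.flat_map(duplicate_element)
print([int(v) for v in doubled.as_numpy_iterator()])
```

The same function works unchanged when elements are feature dicts such as {"image": ..., "label": ...}, because from_tensors accepts nested structures.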

philippe-eecs commented 8 months ago

```python
def duplicate_element(x):
  return tf.data.Dataset.from_tensors(x).repeat(2)
```

This worked. Thanks for the idea!
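The fix above might be dropped into a generic tf.data pipeline roughly as follows. This is a sketch under assumptions, not big_vision's actual input_pipeline: the dict features, the helpers make_pipeline, augment and split_views, and the crop-and-flip augmentation are all hypothetical stand-ins. What matters is the ordering (shuffle, then duplicate, then augment) and an even batch size, so the two views of an image stay adjacent in the same batch.

```python
import tensorflow as tf


def duplicate_element(x):
  # The fix from this thread: emit every element twice in a row.
  return tf.data.Dataset.from_tensors(x).repeat(2)


def augment(example):
  # Hypothetical augmentation: random crop + random flip. Each call draws
  # fresh randomness, so the two copies get different augmentations.
  image = tf.image.random_crop(example["image"], size=[24, 24, 3])
  image = tf.image.random_flip_left_right(image)
  return {"image": image, "label": example["label"]}


def make_pipeline(images, labels, batch_size=8):
  ds = tf.data.Dataset.from_tensor_slices({"image": images, "label": labels})
  # Shuffle BEFORE duplicating; shuffling afterwards would separate the copies.
  ds = ds.shuffle(int(images.shape[0]))
  ds = ds.flat_map(duplicate_element)
  ds = ds.map(augment)
  # batch_size must be even so that no pair is split across two batches.
  return ds.batch(batch_size)


def split_views(batch):
  # Even rows are the first view, odd rows the second view of the same image.
  return batch["image"][0::2], batch["image"][1::2]
```

Because the copies stay adjacent, slicing a batch with strides of 2 recovers two aligned view tensors, which can then be passed to the two branches of a contrastive loss.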

Test on CIFAR-10, first 64 "samples":

[Image: screenshot of the first 64 samples]