ryanccarelli / ssljax

self-supervised learning in jax
GNU General Public License v2.0

Generalising Augmentation Distributions #76

Open aranku opened 2 years ago

Right now, the probability of applying an augmentation is a property of the Augmentation itself, i.e.:

class Augmentation:
    """
    An augmentation applies a function to data.

    Args:
        prob (float): Probability that augmentation will be executed.
    """

    def __init__(self, prob=1.0):
        assert isinstance(prob, float), "prob must be of type float"
        self.prob = prob

Here prob is the probability that the augmentation is applied at all. This allows for pipelines like:

RandomFlip(0.5) → GaussianBlur(0.5) → ....
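
For reference, the current per-augmentation behaviour could be sketched in JAX roughly as follows. This is a minimal sketch, not ssljax's actual implementation: apply_with_prob, flip, blur and chained_pipeline are hypothetical names, and blur is a crude box filter standing in for a real Gaussian blur.

```python
import jax
import jax.numpy as jnp


def apply_with_prob(rng, x, fn, prob):
    """Apply fn to x with probability prob, otherwise return x unchanged."""
    # bernoulli returns a boolean scalar that is True with probability prob.
    do_apply = jax.random.bernoulli(rng, prob)
    # lax.cond keeps the apply-or-skip decision traceable under jit.
    return jax.lax.cond(do_apply, fn, lambda v: v, x)


def flip(x):
    # Hypothetical stand-in for RandomFlip: reverse the last axis.
    return jnp.flip(x, axis=-1)


def blur(x):
    # Hypothetical stand-in for GaussianBlur: 3-tap box filter on a 1-D signal.
    kernel = jnp.ones(3) / 3.0
    return jnp.convolve(x, kernel, mode="same")


def chained_pipeline(rng, x):
    # One independent key per stage, so the two decisions are uncorrelated.
    k_flip, k_blur = jax.random.split(rng)
    x = apply_with_prob(k_flip, x, flip, 0.5)
    x = apply_with_prob(k_blur, x, blur, 0.5)
    return x


x = jnp.arange(4.0)
out = chained_pipeline(jax.random.PRNGKey(0), x)
```

Each stage makes its own apply-or-skip decision, so this style cannot express "apply exactly one of these".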

However, this framework is not flexible enough to allow for distributions over multiple augmentations:

sample one ( RandomFlip(0.5), GaussianBlur(0.5) ) → ....

This issue proposes refactoring Augmentation into a deterministic transform, with the probability becoming a property of a new AugmentationDistribution class:

import jax


class AugmentationDistribution:
    """
    A categorical distribution of augmentations to apply to input

    Args:
        probs (List[float]): List of probabilities
        augs (List[Augmentation]): List of augmentations to be sampled
    """
    def __init__(self, probs, augs):
        self.probs = jax.numpy.array(probs)
        self.augs = augs

This will allow us to express the distribution over multiple augmentations mentioned above via:

AugmentationDistribution([0.5,0.5],[Flip,GaussianBlur])
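
The old "apply with probability p" behaviour would then be a special case: a two-way distribution over the augmentation and an identity transform. A minimal sketch, assuming the class gains an apply(rng, x) method (not part of the class definition above) and using hypothetical flip and identity functions as stand-ins:

```python
import jax
import jax.numpy as jnp


class AugmentationDistribution:
    def __init__(self, probs, augs):
        self.probs = jnp.array(probs)
        self.augs = augs

    def apply(self, rng, x):
        # Assumed method: sample one index, then dispatch to that augmentation.
        idx = jax.random.choice(rng, len(self.augs), p=self.probs)
        return jax.lax.switch(idx, self.augs, x)


def identity(x):
    return x


def flip(x):
    # Hypothetical stand-in for Flip: reverse the last axis.
    return jnp.flip(x, axis=-1)


# Equivalent in spirit to the old RandomFlip(0.5).
maybe_flip = AugmentationDistribution([0.5, 0.5], [flip, identity])

x = jnp.arange(4.0)
out = maybe_flip.apply(jax.random.PRNGKey(0), x)
```

Since lax.switch needs every branch to return the same shape and dtype, the identity branch only works when the augmentation itself preserves them.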

A proof of concept for the above refactor is shown below:

import jax

# Root key; AugmentationDistribution.apply splits it at every stage.
key = jax.random.PRNGKey(0)

class Augmentation:
    def __init__(self, add):
        self.add = add

    def __call__(self, x):
        return self.add+x

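class AugmentationDistribution:
    def __init__(self, probs, augs):
        self.probs = jax.numpy.array(probs)
        self.augs = augs

    def apply(self, rng, x):
        # Split so the caller gets a fresh key back for the next stage.
        key, subkey = jax.random.split(rng)
        # Sample the index of the augmentation to apply (a traced value under jit).
        sampledIndex = jax.random.choice(subkey, len(self.augs), p=self.probs)
        # lax.switch requires every branch to return the same shape and dtype.
        x = jax.lax.switch(sampledIndex, self.augs, x)
        return x, key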

identity = Augmentation(0.0)
addOne = Augmentation(1.0)
addHalf = Augmentation(0.5)

pipeline = [
    AugmentationDistribution([0.2, 0.8], [addOne, addHalf]),
    AugmentationDistribution([0.2, 0.8], [addOne, identity]),
]

@jax.jit
def applyAugPipeline(rng, x):
    for augDist in pipeline:
        x,rng = augDist.apply(rng,x)
    return x

print(applyAugPipeline(key,0))
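
As a sanity check, the sampled frequencies can be compared with what the two stages imply: starting from 0, the result should be 0.5 with probability 0.8 × 0.8 = 0.64, 1.0 with 0.2 × 0.8 = 0.16, 1.5 with 0.8 × 0.2 = 0.16, and 2.0 with 0.2 × 0.2 = 0.04. The sketch below restates the proof of concept with plain functions so that it can be vmapped over many keys; make_add, PIPELINE and apply_pipeline are hypothetical names used only for this check.

```python
import jax
import jax.numpy as jnp


def make_add(c):
    # Deterministic "augmentation": add a constant.
    return lambda x: x + c


# Same probabilities and addends as the proof of concept.
PIPELINE = [
    (jnp.array([0.2, 0.8]), [make_add(1.0), make_add(0.5)]),
    (jnp.array([0.2, 0.8]), [make_add(1.0), make_add(0.0)]),
]


def apply_pipeline(rng, x):
    for probs, augs in PIPELINE:
        # Fresh subkey per stage so the two draws are independent.
        rng, subkey = jax.random.split(rng)
        idx = jax.random.choice(subkey, len(augs), p=probs)
        x = jax.lax.switch(idx, augs, x)
    return x


# One key per sample; the input x is shared across the batch.
keys = jax.random.split(jax.random.PRNGKey(0), 10_000)
batched = jax.jit(jax.vmap(apply_pipeline, in_axes=(0, None)))
samples = batched(keys, jnp.zeros(()))

for value in (0.5, 1.0, 1.5, 2.0):
    print(value, float(jnp.mean(samples == value)))
```

One design point to keep in mind: as far as I understand, when the index is batched under vmap, lax.switch is turned into a select, so every branch is computed for every example. That is harmless for cheap transforms but may matter for expensive augmentations.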