Right now, the probability of applying an augmentation is a property of the `Augmentation` itself, i.e.:
```python
class Augmentation:
    """
    An augmentation applies a function to data.

    Args:
        prob (float): Probability that augmentation will be executed.
    """

    def __init__(self, prob=1.0):
        assert isinstance(prob, float), "prob must be of type float"
        self.prob = prob
```
The `prob` attribute sets the probability that the augmentation is applied at all. This allows for pipelines like:
RandomFlip(0.5) → GaussianBlur(0.5) → ...
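To make the current behavior concrete, here is a minimal JAX sketch of this design. It assumes a hypothetical split into a deterministic `apply(x)` and a `__call__(key, x)` that flips a Bernoulli(`prob`) coin; `RandomFlip` and `GaussianBlur` are toy stand-ins (the blur uses a 3-tap binomial kernel rather than a true Gaussian), and `run_pipeline` is a hypothetical helper, not part of the project:

```python
import jax
import jax.numpy as jnp


class Augmentation:
    """Current design (sketch): each augmentation owns its probability."""

    def __init__(self, prob=1.0):
        assert isinstance(prob, float), "prob must be of type float"
        self.prob = prob

    def apply(self, x):
        # Hypothetical: the deterministic part of the transform; subclasses override it.
        raise NotImplementedError

    def __call__(self, key, x):
        # Independent Bernoulli(prob) coin flip for this augmentation.
        # jnp.where evaluates both branches, so this also traces under jax.jit.
        do_apply = jax.random.uniform(key) < self.prob
        return jnp.where(do_apply, self.apply(x), x)


class RandomFlip(Augmentation):
    def apply(self, x):
        # Toy stand-in: flip an (H, W) image along its width axis.
        return jnp.flip(x, axis=-1)


class GaussianBlur(Augmentation):
    def apply(self, x):
        # Toy stand-in: separable 3-tap binomial kernel [1/4, 1/2, 1/4];
        # edge padding keeps the output shape equal to the input shape.
        k = (0.25, 0.5, 0.25)
        p = jnp.pad(x, ((1, 1), (0, 0)), mode="edge")
        x = k[0] * p[:-2] + k[1] * p[1:-1] + k[2] * p[2:]
        p = jnp.pad(x, ((0, 0), (1, 1)), mode="edge")
        return k[0] * p[:, :-2] + k[1] * p[:, 1:-1] + k[2] * p[:, 2:]


def run_pipeline(key, x, augs):
    # RandomFlip(0.5) -> GaussianBlur(0.5) -> ...: one independent coin flip
    # per stage, so any subset of stages may fire; "exactly one" is not expressible.
    for k, aug in zip(jax.random.split(key, len(augs)), augs):
        x = aug(k, x)
    return x
```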
However, this design is not flexible enough to express a distribution over multiple augmentations, such as:
sample one of (RandomFlip(0.5), GaussianBlur(0.5)) → ...
This issue proposes refactoring `Augmentation` into a deterministic transform and moving its probability into a separate `AugmentationDistribution` class:
```python
import jax.numpy as jnp


class AugmentationDistribution:
    """
    A categorical distribution of augmentations to apply to input.

    Args:
        probs (List[float]): List of probabilities, one per augmentation.
        augs (List[Augmentation]): List of augmentations to be sampled.
    """

    def __init__(self, probs, augs):
        self.probs = jnp.array(probs)
        self.augs = augs
```
This will allow us to express the distribution over multiple augmentations mentioned above via:
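A minimal JAX sketch of how this could look, assuming `Augmentation` subclasses become plain deterministic callables `x -> x` and that `AugmentationDistribution` gains a hypothetical `__call__(key, x)` method. `RandomFlip`, `GaussianBlur`, and `Identity` here are illustrative stand-ins, not the project's actual classes:

```python
import jax
import jax.numpy as jnp


class Augmentation:
    """Proposed design (sketch): a deterministic transform with no probability."""

    def __call__(self, x):
        raise NotImplementedError


class Identity(Augmentation):
    # Hypothetical helper: makes "apply nothing" one of the categories.
    def __call__(self, x):
        return x


class RandomFlip(Augmentation):
    # Hypothetical deterministic version: always flips along the last axis.
    def __call__(self, x):
        return jnp.flip(x, axis=-1)


class GaussianBlur(Augmentation):
    # Hypothetical stand-in: separable 3-tap binomial kernel (a crude Gaussian),
    # edge-padded so the output shape and dtype match a float input.
    def __call__(self, x):
        k = (0.25, 0.5, 0.25)
        p = jnp.pad(x, ((1, 1), (0, 0)), mode="edge")
        x = k[0] * p[:-2] + k[1] * p[1:-1] + k[2] * p[2:]
        p = jnp.pad(x, ((0, 0), (1, 1)), mode="edge")
        return k[0] * p[:, :-2] + k[1] * p[:, 1:-1] + k[2] * p[:, 2:]


class AugmentationDistribution:
    """A categorical distribution of augmentations to apply to input."""

    def __init__(self, probs, augs):
        assert len(probs) == len(augs), "need one probability per augmentation"
        self.probs = jnp.array(probs)
        self.augs = augs

    def __call__(self, key, x):
        # Sample one index from Categorical(probs); log(0) = -inf is never chosen.
        idx = jax.random.categorical(key, jnp.log(self.probs))
        # Dispatch to the sampled augmentation; every branch must return
        # the same shape and dtype for lax.switch to trace.
        return jax.lax.switch(idx, self.augs, x)


# "sample one of (RandomFlip, GaussianBlur)" with equal probability
dist = AugmentationDistribution([0.5, 0.5], [RandomFlip(), GaussianBlur()])
y = dist(jax.random.PRNGKey(0), jnp.arange(16.0).reshape(4, 4))
```

Sampling the index with `jax.random.categorical` and dispatching with `jax.lax.switch` keeps the choice traceable, so `dist` can be wrapped in `jax.jit`. The hypothetical `Identity` entry also shows that the current per-augmentation probability becomes a special case: `RandomFlip(0.5)` corresponds to `AugmentationDistribution([0.5, 0.5], [RandomFlip(), Identity()])`.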
A proof of concept for the above refactor is shown below.