FluxML / DataAugmentation.jl

Flexible data augmentation library for machine and deep learning
https://fluxml.ai/DataAugmentation.jl/dev/
MIT License

Optimize multiple crops #99

Open paulnovo opened 2 weeks ago

paulnovo commented 2 weeks ago

Update CroppedProjectiveTransform to optimize projective transforms followed by multiple crops, not just a single crop. For instance:

Rotate(10) |> CenterCrop((100, 100)) |> RandomCrop((50, 50))

is now optimized to warp directly into the 50x50 region, instead of into the intermediate 100x100 region.
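For intuition, here is a minimal plain-Julia sketch (not the package's actual internals) of why a chain of shrinking crops only ever needs to warp into the final crop's size:

# Illustration only: with shrinking crops, the effective output region of the
# chain is the componentwise minimum of all the crop sizes, i.e. the last one.
crops = [(100, 100), (50, 50)]   # CenterCrop((100, 100)) |> RandomCrop((50, 50))
effective = reduce((a, b) -> min.(a, b), crops)
@assert effective == (50, 50)    # warp directly into 50x50, not 100x100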

One caveat of this implementation is that a crop that enlarges the region after a smaller crop now behaves differently, because extrapolation beyond the smaller crop is skipped. That is, this pipeline now gives different results:

Rotate(10) |> CenterCrop((50, 50)) |> RandomCrop((100, 100))

I wonder if enlarging crops should be disallowed anyway, since offsetcropbounds does not appear to be written with that case in mind (i.e., sz > bounds), unless I am missing something. Thoughts?
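To make the sz > bounds problem concrete, here is an illustration with a made-up helper (valid_offsets is not the real offsetcropbounds, just the same range arithmetic along one axis):

# For a random crop, the valid top-left offsets along one axis
# are 1:(bound - sz + 1) in 1-based indexing.
valid_offsets(sz, bound) = 1:(bound - sz + 1)

valid_offsets(50, 100)   # 1:51  -- plenty of valid placements
valid_offsets(100, 50)   # 1:-49 -- empty range: no valid placement

An enlarging crop therefore has no well-defined offset without extrapolating beyond the available bounds.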


CarloLucibello commented 4 days ago

I wonder if enlarging crops should be disallowed anyway, since offsetcropbounds does not appear to be written with that case in mind (i.e., sz > bounds), unless I am missing something.

How do other frameworks like torchvision handle enlarging the crops?

paulnovo commented 3 days ago

Torchvision throws a ValueError if the crop is larger than the input image:

In [1]: from torchvision.transforms import v2
   ...: from torchvision.io import read_image
   ...: 
   ...: img = read_image('astronaut.jpg')
   ...: print(f"{img.shape = }")
img.shape = torch.Size([3, 512, 512])

In [2]: transform = v2.RandomCrop(size=(100, 100))
   ...: out = transform(img)

In [3]: print(f"{out.shape = }")
out.shape = torch.Size([3, 100, 100])

In [4]: transform = v2.RandomCrop(size=(1000, 1000))
   ...: out = transform(img)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
Cell In[4], line 2
      1 transform = v2.RandomCrop(size=(1000, 1000))
----> 2 out = transform(img)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1553, in Module._wrapped_call_impl(self, *args, **kwargs)
   1551     return self._compiled_call_impl(*args, **kwargs)  # type: ignore[misc]
   1552 else:
-> 1553     return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.11/dist-packages/torch/nn/modules/module.py:1562, in Module._call_impl(self, *args, **kwargs)
   1557 # If we don't have any hooks, we want to skip the rest of the logic in
   1558 # this function, and just call forward.
   1559 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1560         or _global_backward_pre_hooks or _global_backward_hooks
   1561         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1562     return forward_call(*args, **kwargs)
   1564 try:
   1565     result = None

File /usr/local/lib/python3.11/dist-packages/torchvision/transforms/v2/_transform.py:46, in Transform.forward(self, *inputs)
     43 self._check_inputs(flat_inputs)
     45 needs_transform_list = self._needs_transform_list(flat_inputs)
---> 46 params = self._get_params(
     47     [inpt for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list) if needs_transform]
     48 )
     50 flat_outputs = [
     51     self._transform(inpt, params) if needs_transform else inpt
     52     for (inpt, needs_transform) in zip(flat_inputs, needs_transform_list)
     53 ]
     55 return tree_unflatten(flat_outputs, spec)

File /usr/local/lib/python3.11/dist-packages/torchvision/transforms/v2/_geometry.py:870, in RandomCrop._get_params(self, flat_inputs)
    867         padded_width += 2 * diff
    869 if padded_height < cropped_height or padded_width < cropped_width:
--> 870     raise ValueError(
    871         f"Required crop size {(cropped_height, cropped_width)} is larger than "
    872         f"{'padded ' if self.padding is not None else ''}input image size {(padded_height, padded_width)}."
    873     )
    875 # We need a different order here than we have in self.padding since this padding will be parsed again in `F.pad`
    876 padding = [pad_left, pad_top, pad_right, pad_bottom]

ValueError: Required crop size (1000, 1000) is larger than input image size (512, 512).

You can pass pad_if_needed=True to the crop, but that isn't the default behavior.

CarloLucibello commented 2 days ago

We could error as well, but I'll leave the decision to you to choose the most sensible behavior.
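If erroring is the way to go, a guard along these lines could work (checkcropsize is a hypothetical name, mirroring torchvision's check, not existing DataAugmentation.jl API):

# Hypothetical guard: reject crops larger than the region they crop from.
function checkcropsize(sz::NTuple{N,Int}, bounds::NTuple{N,Int}) where N
    all(sz .<= bounds) || throw(ArgumentError(
        "Required crop size $sz is larger than input size $bounds."))
    return sz
end

checkcropsize((100, 100), (512, 512))    # ok
checkcropsize((1000, 1000), (512, 512))  # throws ArgumentError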