pytorch / vision

Datasets, Transforms and Models specific to Computer Vision
https://pytorch.org/vision
BSD 3-Clause "New" or "Revised" License

[RFC] Make `transforms.functional` methods differentiable w.r.t. their parameters #5157


vfdev-5 commented 2 years ago

This is an RFC that continues the discussion in #5000 by @ain-soph and PRs #4995 and #5110 on updating the functional tensor methods in F.* to accept learnable parameters (tensors with requires_grad=True) and propagate gradients through them.

For the motivation and the context, please see https://github.com/pytorch/vision/issues/5000

Proposal

Torchvision transformations can work on PIL images and torch Tensors, and accept scalars or lists of scalars as parameters. For example,

import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

x = torch.rand(1, 3, 32, 32)
alpha = 45
center = [1, 2]
out = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR, center=center)
# out is a tensor

The proposal is to make it possible to learn parameters like alpha and center with gradient descent:

x = torch.rand(1, 3, 32, 32)
- alpha = 45
+ alpha = torch.tensor(45.0, requires_grad=True)
- center = [1, 2]
+ center = torch.tensor([1.0, 2.0], requires_grad=True)
out = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR, center=center)
# out is a tensor that requires grad
assert out.requires_grad

# parameters can have grads:
out.mean().backward()  # some dummy criterion
assert alpha.grad is not None
assert center.grad is not None
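
For reference, gradients can already flow to a rotation angle today if the affine matrix is built with tensor ops and applied via torch.nn.functional.affine_grid / grid_sample; this is essentially what the tensor path of F.rotate would need to do internally. A minimal, self-contained sketch (the helper name differentiable_rotate is made up for illustration):

import math
import torch
import torch.nn.functional as Fnn

def differentiable_rotate(img: torch.Tensor, angle: torch.Tensor) -> torch.Tensor:
    # Build the 2x3 affine matrix with tensor ops so that autograd
    # tracks how the output depends on `angle` (given in degrees).
    theta = angle * math.pi / 180.0
    cos, sin, zero = torch.cos(theta), torch.sin(theta), torch.zeros_like(theta)
    mat = torch.stack([
        torch.stack([cos, -sin, zero]),
        torch.stack([sin, cos, zero]),
    ]).unsqueeze(0)  # shape (1, 2, 3)
    grid = Fnn.affine_grid(mat, list(img.shape), align_corners=False)
    return Fnn.grid_sample(img, grid, mode="bilinear", align_corners=False)

x = torch.rand(1, 3, 32, 32)
alpha = torch.tensor(45.0, requires_grad=True)
differentiable_rotate(x, alpha).mean().backward()
assert alpha.grad is not None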

while also keeping the previous API working (no BC-breaking changes).

Implementation

In terms of API, it would require updates like:

def rotate(
    img: Tensor,
-   angle: float,
+   angle: Union[float, int, Tensor],
    interpolation: InterpolationMode = InterpolationMode.NEAREST,
    expand: bool = False,
-   center: Optional[List[int]] = None,
+   center: Optional[Union[List[int], Tuple[int, int], Tensor]] = None,
    fill: Optional[List[float]] = None,
    resample: Optional[int] = None,
) -> Tensor:

Note: we need to keep the transforms scriptable with torch.jit, so we are also limited by what torch.jit.script supports (simply adding Union[float, Tensor] does not always work and can break compatibility with the stable version).
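
One possible pattern (just a sketch, assuming a PyTorch version where torch.jit.script supports Union with isinstance refinement, i.e. 1.10+; the helper name _as_float_tensor is hypothetical) is to normalize scalar parameters to tensors at the top of each op, so downstream code has a single tensor code path:

import torch
from torch import Tensor
from typing import Union

@torch.jit.script
def _as_float_tensor(value: Union[float, Tensor]) -> Tensor:
    # Keep an existing tensor (and its autograd history) untouched;
    # wrap plain Python scalars so downstream code sees only tensors.
    if isinstance(value, Tensor):
        return value
    return torch.tensor(value)

Older PyTorch versions without Union support would need overload-based workarounds instead, which is part of the constraint the note above describes.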

In terms of implementation, we have to ensure that gradients propagate from the output back to the parameter tensors: the internal computations (e.g. building the affine transformation matrix) must be done with tensor ops rather than Python scalar math.

Example with the affine and rotate ops: #5110

Transforms to update

Please comment here if I'm missing any op that we could add to the list.

cc @vfdev-5 @datumbox

ain-soph commented 2 years ago

I think gaussian_blur's kernel_size, and the parameters of posterize and solarize, are not mathematically differentiable. Maybe we can just ignore them?
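
For illustration, here is a simplified float variant of solarize (not torchvision's actual implementation): the threshold enters only through a comparison, so it never joins the autograd graph, and its gradient would be zero almost everywhere even in principle:

import torch

img = torch.rand(1, 3, 8, 8)
threshold = torch.tensor(0.5, requires_grad=True)
# The comparison produces a boolean mask, which carries no gradient,
# so `threshold` never appears in the autograd graph of the output.
out = torch.where(img >= threshold, 1.0 - img, img)
assert not out.requires_grad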