This is an RFC that continues the discussion in #5000 by @ain-soph and PRs #4995 and #5110 on updating the functional tensor methods (`F.*`) to accept learnable parameters (tensors with `requires_grad=True`) and to propagate the gradient. For the motivation and the context, please see https://github.com/pytorch/vision/issues/5000

Proposal
Torchvision transforms work on both PIL images and torch Tensors and accept scalars or lists of scalars as parameters. For example:
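```python
import torch
import torchvision.transforms.functional as F
from torchvision.transforms import InterpolationMode

x = torch.rand(1, 3, 32, 32)
alpha = 45
center = [1, 2]
out = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR, center=center)
# out is a tensor
```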
The proposal is to make parameters like `alpha` and `center` learnable with gradient descent:
```diff
 x = torch.rand(1, 3, 32, 32)
-alpha = 45
+alpha = torch.tensor(45.0, requires_grad=True)
-center = [1, 2]
+center = torch.tensor([1.0, 2.0], requires_grad=True)
 out = F.rotate(x, alpha, interpolation=InterpolationMode.BILINEAR, center=center)
 # out is a tensor that requires grad
 assert out.requires_grad
 # parameters can have grads:
 out.mean().backward()  # some dummy criterion
 assert alpha.grad is not None
 assert center.grad is not None
```
At the same time, the previous API must keep working (no BC-breaking changes).
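To show that such gradients can actually flow through a rotation, here is a minimal sketch built from `torch.nn.functional.grid_sample`. The helper `rotate_learnable` is hypothetical: it is neither torchvision's `F.rotate` nor the code from #5110. It assumes an `(N, C, H, W)` float image, a 0-dim `angle` tensor in degrees and a `center` tensor `[cx, cy]` in pixel coordinates, and its rotation direction and pixel-center convention are not guaranteed to match torchvision's.

```python
import math

import torch
import torch.nn.functional as nnf


def rotate_learnable(img: torch.Tensor, angle: torch.Tensor, center: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper (not torchvision API): rotate `img` by `angle` degrees
    # about `center` using only differentiable ops, so grads reach angle/center.
    # img: float tensor (N, C, H, W); angle: 0-dim tensor; center: tensor [cx, cy]
    # in pixel coordinates with the origin at the top-left corner.
    if not img.is_floating_point():
        raise TypeError(f"expected a float image to propagate grads, got {img.dtype}")
    n, _, h, w = img.shape
    theta = angle * (math.pi / 180.0)
    cos_t, sin_t = torch.cos(theta), torch.sin(theta)

    # Pixel coordinates (x, y) of every output location, each of shape (H, W).
    ys, xs = torch.meshgrid(
        torch.arange(h, dtype=img.dtype, device=img.device),
        torch.arange(w, dtype=img.dtype, device=img.device),
        indexing="ij",
    )
    # Inverse mapping: for each output pixel, compute the source pixel it samples.
    dx, dy = xs - center[0], ys - center[1]
    src_x = cos_t * dx - sin_t * dy + center[0]
    src_y = sin_t * dx + cos_t * dy + center[1]

    # Normalize to [-1, 1] as grid_sample expects with align_corners=False.
    grid_x = (2.0 * src_x + 1.0) / w - 1.0
    grid_y = (2.0 * src_y + 1.0) / h - 1.0
    grid = torch.stack((grid_x, grid_y), dim=-1).unsqueeze(0).repeat(n, 1, 1, 1)
    return nnf.grid_sample(img, grid, mode="bilinear", padding_mode="zeros", align_corners=False)


x = torch.rand(1, 3, 32, 32)
alpha = torch.tensor(45.0, requires_grad=True)
center = torch.tensor([1.0, 2.0], requires_grad=True)
out = rotate_learnable(x, alpha, center)
out.mean().backward()  # some dummy criterion
print(alpha.grad, center.grad)
```

Every op on the tensor path (`cos`/`sin`, arithmetic, bilinear `grid_sample`) is differentiable with respect to the sampling grid, which is why both `alpha` and `center` receive gradients here.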
Implementation

Note: transforms must remain scriptable with `torch.jit.script`, so we are also limited by what TorchScript supports (simply annotating a parameter as `Union[float, Tensor]` does not always work and can break compatibility with the stable version).
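One possible pattern is to normalize a scalar-or-tensor parameter inside the scripted function. This is a sketch that assumes a PyTorch version whose TorchScript supports `Union` types with `isinstance` refinement (1.10 or later, to our knowledge), so it does not address the compatibility concern with older stable versions. The helper `as_param_tensor` is hypothetical.

```python
from typing import Union

import torch


@torch.jit.script
def as_param_tensor(value: Union[float, torch.Tensor]) -> torch.Tensor:
    # Hypothetical helper: accept either a Python float (previous API) or a
    # tensor (learnable parameter) and always return a tensor.
    # isinstance() refines the Union type inside the scripted function.
    if isinstance(value, torch.Tensor):
        out = value
    else:
        out = torch.tensor(value)
    return out


angle_fixed = as_param_tensor(45.0)  # plain float, as in the current API
angle_learned = as_param_tensor(torch.tensor(45.0, requires_grad=True))
print(angle_fixed, angle_learned.requires_grad)
```

The tensor branch returns the input unchanged, so a learnable parameter keeps its autograd history.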
In terms of implementation, we have to ensure that:
- methods with updated parameters still support all previous data types
- methods are torch jit scriptable
- methods verify that the input image is a float tensor (gradients cannot propagate otherwise)
- methods propagate grads for tensor inputs, which holds only if all internal ops of the tensor branch propagate grads
- only floating-point parameters can accept values as Tensors
  - for example, `rotate` with a learnable floating-point angle
  - IMO, we can't make the (integer) output size learnable in the resize op (please correct me if there is a way)
  - certain integer parameters can be promoted to float, e.g. `translate` in `affine`
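A minimal eager-mode sketch of how the dtype-related items of this checklist could be enforced at the top of a functional op. The helper `check_learnable_args` and its error messages are hypothetical; it assumes parameters arrive either as Python numbers (previous API) or as tensors, and it is not written to be scriptable.

```python
from typing import Dict, Union

import torch


def check_learnable_args(img: torch.Tensor, params: Dict[str, Union[int, float, torch.Tensor]]) -> None:
    # Hypothetical validation mirroring the checklist: tensor-valued parameters
    # must be floating point, and parameters that require grad need a float image.
    for name, value in params.items():
        if not isinstance(value, torch.Tensor):
            continue  # plain Python scalars keep the previous behavior
        if not value.is_floating_point():
            raise TypeError(f"{name}: only floating parameters can be tensors, got {value.dtype}")
        if value.requires_grad and not img.is_floating_point():
            raise TypeError(f"{name}: gradients need a float image, got {img.dtype}")


img = torch.rand(1, 3, 32, 32)
check_learnable_args(img, {"angle": torch.tensor(45.0, requires_grad=True), "fill": 0})
```

Integer tensors cannot require grad, and casting a float tensor to an integer dtype (e.g. with `.long()`) returns a result detached from the autograd graph. This is why an integer parameter such as an output size cannot be learned directly.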
Example with affine and rotate ops: #5110
Transforms to update
Please comment here if I'm missing any op that we could add into the list.
cc @vfdev-5 @datumbox