dotnet / TorchSharp

A .NET library that provides access to the library that powers PyTorch.
MIT License

Fixing Tensor.backward's function signature #1376

Closed JamesG9802 closed 1 month ago

JamesG9802 commented 2 months ago

Fixes #692

TL;DR: Tensor.backward has a different parameter order than PyTorch's, and it also swaps retain_graph and create_graph in its internal call to torch.autograd.backward.

See https://pytorch.org/docs/stable/generated/torch.Tensor.backward.html for backward's function signature: Tensor.backward(gradient=None, retain_graph=None, create_graph=False, inputs=None)

The current TorchSharp version's function signature is: Tensor.backward(grad_tensors=null, create_graph=false, retain_graph=false, inputs=null)

Note the difference in the ordering of retain_graph and create_graph. Tensor.backward is just a wrapper around torch.autograd.backward, whose signature is: autograd.backward(tensors, grad_tensors=null, retain_graph=null, create_graph=false, inputs=null)

This means that calling Tensor.backward(retain_graph: true) in TorchSharp actually behaves like Tensor.backward(create_graph=True) in PyTorch, and likewise Tensor.backward(create_graph: true) actually behaves like Tensor.backward(retain_graph=True).
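The mix-up can be reproduced with a small self-contained sketch (plain Python, with hypothetical stand-in functions for the TorchSharp wrapper and the underlying autograd call): a wrapper whose parameter order differs from its callee's, combined with positional forwarding, silently swaps the two flags.

```python
def autograd_backward(grad_tensors=None, retain_graph=None, create_graph=False, inputs=None):
    """Stand-in for torch.autograd.backward: reports which flags it received."""
    return {"retain_graph": retain_graph, "create_graph": create_graph}

def tensor_backward_buggy(grad_tensors=None, create_graph=False, retain_graph=False, inputs=None):
    """Stand-in for the old TorchSharp Tensor.backward: its parameter order
    differs from the callee's, and it forwards its arguments positionally,
    so the two boolean flags land in each other's slots."""
    return autograd_backward(grad_tensors, create_graph, retain_graph, inputs)

# The caller asks to retain the graph...
result = tensor_backward_buggy(retain_graph=True)
# ...but the inner call received create_graph=True instead.
print(result)  # → {'retain_graph': False, 'create_graph': True}
```

The same swap happens in the other direction: passing create_graph=True reaches the inner call as retain_graph=True.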

The proposed fix is breaking and would change the Tensor.backward function signature to match PyTorch. However, this went unnoticed for about two years anyway, and imo retain_graph should actually mean retain_graph (and likewise for create_graph) 🙂.
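In the same plain-Python sketch style (stand-in names, not the actual TorchSharp code), the fix reorders the wrapper's parameters to match PyTorch and forwards each flag by keyword, so retain_graph really means retain_graph. It also shows why the change is breaking only for positional callers: the second boolean slot changes meaning.

```python
def autograd_backward(grad_tensors=None, retain_graph=None, create_graph=False, inputs=None):
    # Stand-in for torch.autograd.backward: reports which flags it received.
    return {"retain_graph": retain_graph, "create_graph": create_graph}

def tensor_backward_fixed(grad_tensors=None, retain_graph=None, create_graph=False, inputs=None):
    # Parameter order now matches PyTorch's Tensor.backward, and each flag is
    # forwarded by keyword so it cannot land in the wrong slot.
    return autograd_backward(grad_tensors=grad_tensors,
                             retain_graph=retain_graph,
                             create_graph=create_graph,
                             inputs=inputs)

# Named arguments now mean what they say:
print(tensor_backward_fixed(retain_graph=True))
# → {'retain_graph': True, 'create_graph': False}

# Breaking for positional callers: the second boolean argument used to be
# create_graph and is now retain_graph.
print(tensor_backward_fixed(None, True))
# → {'retain_graph': True, 'create_graph': False}
```

Callers who always pass the flags by name are unaffected; only positional call sites need to be updated.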

JamesG9802 commented 1 month ago

> Please add an entry in the RELEASENOTES.md file, and note that this is a breaking change (if you are passing the parameters by position rather than name).

No problem, I added the change to the release notes to NuGet Version 0.103.1, though I am not sure if that's the correct place to put it.

Let me know if there are any problems and I'll fix it.

JamesG9802 commented 1 month ago

@dotnet-policy-service agree