Cambridge-ICCS / FTorch

A library for directly calling PyTorch ML models from Fortran.
https://cambridge-iccs.github.io/FTorch/
MIT License

Adjust the API in anticipation of future overloading #143

Closed: jatkinson1000 closed this 4 days ago

jatkinson1000 commented 5 days ago

Change the API to make assignment of data to a torch_tensor a subroutine.

This is in anticipation of the AutoGrad work led by @jwallwork23 in #139. That work may take a while to complete, but we need to stabilise the API in advance of the summer school and as we move towards a first release and JOSS.

This makes the changes to the API that are required for this work without changing anything underneath.

jatkinson1000 commented 5 days ago

One thought @jwallwork23 and @TomMelt:

After this most things will be subroutines to be called from Fortran.

Do we want to update the module loading to be:

call torch_module_load(torch_model, filename, device_type)

rather than:

torch_model = torch_module_load(filename, device_type)

?

And do we want to make it torch_model_load rather than torch_module_load? In the past I have had a couple of comments from people who misunderstood what 'module' means in this context (a neural net subclassed from PyTorch's nn.Module).

dorchard commented 5 days ago

I like the naming of torch_model_load - it might indeed be more clear.

jwallwork23 commented 5 days ago

Thanks very much for this @jatkinson1000! I was going to suggest doing this separately, too.

> Do we want to update the module loading to be:
>
> call torch_module_load(torch_model, filename, device_type)
>
> rather than:
>
> torch_model = torch_module_load(filename, device_type)

Once the assignment operator is overloaded for tensors, any function that returns a tensor will implicitly do an additional copy, so it's probably best to turn any functions that return tensors into subroutines.
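As a rough illustration of that extra copy, here is a minimal, self-contained sketch. It does not use FTorch: `mock_tensor`, `make_tensor_fn`, `make_tensor_sub` and `copy_count` are all hypothetical names invented for this example, and the real tensor type and its overloaded assignment may behave differently.

```fortran
module mock_tensor_mod
  implicit none
  private
  public :: mock_tensor, make_tensor_fn, make_tensor_sub, copy_count
  public :: assignment(=)

  ! Hypothetical stand-in for a tensor; NOT the FTorch torch_tensor type.
  type :: mock_tensor
    real, allocatable :: data(:)
  end type mock_tensor

  ! Counts how often the overloaded assignment is invoked.
  integer :: copy_count = 0

  interface assignment(=)
    module procedure mock_tensor_assign
  end interface

contains

  ! Overloaded assignment: performs a deep copy and records that it ran.
  subroutine mock_tensor_assign(lhs, rhs)
    type(mock_tensor), intent(out) :: lhs
    type(mock_tensor), intent(in) :: rhs
    copy_count = copy_count + 1
    if (allocated(rhs%data)) then
      allocate(lhs%data(size(rhs%data)))
      lhs%data(:) = rhs%data(:)
    end if
  end subroutine mock_tensor_assign

  ! Function form: the result has to be assigned to the caller's variable.
  function make_tensor_fn(n) result(t)
    integer, intent(in) :: n
    type(mock_tensor) :: t
    allocate(t%data(n))
    t%data = 1.0
  end function make_tensor_fn

  ! Subroutine form: fills the caller's variable in place.
  subroutine make_tensor_sub(t, n)
    type(mock_tensor), intent(out) :: t
    integer, intent(in) :: n
    allocate(t%data(n))
    t%data = 1.0
  end subroutine make_tensor_sub

end module mock_tensor_mod

program copy_demo
  use mock_tensor_mod
  implicit none
  type(mock_tensor) :: a, b

  a = make_tensor_fn(4)        ! goes through the overloaded assignment
  print *, "copies after function call:  ", copy_count

  call make_tensor_sub(b, 4)   ! no assignment, so the count is unchanged
  print *, "copies after subroutine call:", copy_count
end program copy_demo
```

In this sketch the counter reaches 1 after the function call and stays at 1 after the subroutine call, which is the asymmetry that motivates converting tensor-returning functions into subroutines.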

> And do we want to make it torch_model_load() rather than torch_module_load?

I agree with @dorchard that torch_model_load is clearer.

jatkinson1000 commented 4 days ago

The latest commits refer to the model as a 'model' rather than a module (note that it still has to be referred to as a module when binding to C++ as this is what the Torch API uses).

They also make model loading a subroutine rather than a function. I have left the get-device routine as a function, as this feels more appropriate in this instance.

jatkinson1000 commented 4 days ago

Thanks @dorchard, I have addressed those in a rebase to keep the history clean. Ready for a re-review, I think.