alexeedm / pytorch-fortran

Pytorch bindings for Fortran
MIT License

Pytorch Fortran bindings

The goal of this code is to give Fortran HPC codes a simple way to use the Pytorch deep learning framework. We want Fortran developers to take advantage of the rich and optimized Torch ecosystem from within their existing codes. The code is still very much a work in progress, and any feedback or bug reports are welcome.

Features

Building

To assist with the build, we provide Docker and HPCCM recipes for a container with all the necessary dependencies installed; see the container folder.

Mount the folder containing the cloned repository into the container, cd into it from the running container, and run ./make_nvhpc.sh, ./make_gcc.sh, or ./make_intel.sh, depending on which compiler you want to use.

GPU support requires the NVIDIA HPC SDK build. The GNU compiler is still maturing its OpenACC implementation and may soon be supported as well. The compiler can be changed by modifying the CMAKE_Fortran_COMPILER CMake flag. Note that we are still testing different compilers, so issues are possible.

Examples

The examples folder contains three samples:

API

Keep in mind that the order of array dimensions differs between Fortran and C/Pytorch. In Fortran the contiguous (fastest-varying) dimension is the first one, while in Pytorch it is the last one. Therefore, for a Fortran input to match what Pytorch expects, the Fortran array dimensions must be listed in the reverse order of the Pytorch input tensor's dimensions: for example, a tensor of shape (N, C, H, W) corresponds to a Fortran array with dimensions (W, H, C, N). The library then matches the dimensions correctly without any data movement.
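Why reversing the dimensions needs no data movement can be checked with a small, self-contained Python sketch (pure standard library; it is an illustration, not part of pytorch-fortran). It computes the linear memory offset of every element of a column-major (Fortran) array of shape (4, 3, 2) and confirms that it equals the offset of the reversed index in a row-major (C/Pytorch) array of shape (2, 3, 4). Indices here are 0-based, whereas Fortran arrays are 1-based by default.

```python
from itertools import product


def col_major_offset(idx, shape):
    """Linear offset in Fortran (column-major) order: first index varies fastest."""
    offset, stride = 0, 1
    for i, n in zip(idx, shape):
        offset += i * stride
        stride *= n
    return offset


def row_major_offset(idx, shape):
    """Linear offset in C/Pytorch (row-major) order: last index varies fastest."""
    offset, stride = 0, 1
    for i, n in zip(reversed(idx), reversed(shape)):
        offset += i * stride
        stride *= n
    return offset


fortran_shape = (4, 3, 2)          # e.g. a Fortran array declared a(4, 3, 2)
torch_shape = fortran_shape[::-1]  # the shape Pytorch sees: (2, 3, 4)

# Every element lands at the same memory offset once the index is reversed,
# so the same buffer can be reinterpreted without copying.
for idx in product(*(range(n) for n in fortran_shape)):
    assert col_major_offset(idx, fortran_shape) == row_major_offset(idx[::-1], torch_shape)

print("layouts agree for all", 4 * 3 * 2, "elements")
```

For instance, the Fortran element a(2, 1, 1) (0-based index (1, 0, 0)) sits at offset 1, which is tensor element [0, 0, 1] on the Pytorch side.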

We are working on documenting the full API. Please refer to the examples for more details. The bindings are provided through the following Fortran classes:

Class torch_tensor

This class is a lightweight Pytorch representation of a Fortran array. It does not own the data and only keeps a pointer to it. Arrays of rank up to 7 and of types real32, real64, int32 and int64 are supported. Members:
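Because torch_tensor only keeps a pointer, it behaves like a non-owning view: writes to the Fortran array are visible through it, which suggests the array has to stay alive and unmoved for as long as the tensor is used. A rough analogy using only Python's standard library is a memoryview over an array.array, sketched below. This is an assumption-laden illustration, not the library's API; note also that memoryview.cast only builds C-ordered views, so it does not model the Fortran dimension reversal.

```python
from array import array

# Flat float32 buffer standing in for a Fortran real32 array ('f' = C float).
data = array("f", [0.0, 1.0, 2.0, 3.0, 4.0, 5.0])

# A 2x3 view over the same memory; memoryview.cast() never copies.
# (Casting to a multi-dimensional shape has to go through a byte format.)
view = memoryview(data).cast("B").cast("f", shape=[2, 3])

data[4] = 42.0     # write through the owner...
print(view[1, 1])  # ...and the view sees it (row-major offset 1*3 + 1 = 4)

# While the view exists, the owner refuses to be resized out from under it.
try:
    data.append(6.0)
    resize_refused = False
except BufferError:
    resize_refused = True
```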

Class torch_tensor_wrap

This class wraps several tensors or scalars so that they can be passed together as input to a Pytorch model. Arrays and scalars must be of type real32, real64, int32 or int64. Members:

Class torch_module

This class represents a traced Pytorch model, typically the result of a torch.jit.trace or torch.jit.script call in your Python script. This class is not thread-safe. For multi-threaded inference, either create a threaded Pytorch model or use one torch_module instance per thread (the latter may be less efficient). Members:
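The "one instance per thread" option can be sketched in Python with threading.local. StandInModule and get_module are hypothetical names standing in for a loaded, non-thread-safe model; nothing here is the pytorch-fortran API. In a Fortran code the same idea might map to, for example, one torch_module per OpenMP thread, which is an assumption rather than something this README prescribes.

```python
import threading


class StandInModule:
    """Hypothetical stand-in for a loaded model that is NOT thread-safe."""

    def __init__(self, path):
        self.path = path
        self.calls = 0  # per-instance mutable state; sharing it would race

    def forward(self, xs):
        self.calls += 1
        return [2.0 * x for x in xs]


_per_thread = threading.local()


def get_module(path="model.pt"):
    # Create the module lazily, once per thread, and never share it.
    mod = getattr(_per_thread, "module", None)
    if mod is None:
        mod = StandInModule(path)
        _per_thread.module = mod
    return mod


def worker(slot, results):
    mod = get_module()
    out = mod.forward([1.0, 2.0])
    out = mod.forward(out)  # second call reuses this thread's own instance
    results[slot] = (mod, out)


results = [None] * 4
threads = [threading.Thread(target=worker, args=(i, results)) for i in range(4)]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

Each thread pays the cost of its own instance, which is why this option may be less efficient than a single model that handles threading internally.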

Class torch_pymodule

This class represents a Pytorch Python script and requires the Python interpreter to be invoked. Only one torch_pymodule can be open at a time because of a limitation of the Python interpreter. Calling this class has higher overhead than torch_module, but unlike torch_module%train, it lets you train your Pytorch model with any optimizer, dropout, etc. This class is intended for online training with a complex pipeline that cannot be expressed in TorchScript. Members:

Changelog

v0.4

v0.3