Cambridge-ICCS / FTorch

A library for directly calling PyTorch ML models from Fortran.
https://cambridge-iccs.github.io/FTorch/
MIT License

Implement batching #157

Open jwallwork23 opened 2 months ago

jwallwork23 commented 2 months ago

When we run examples/1_SimpleNet/simplenet.py, the final thing that's executed is effectively

a = [0.0, 1.0, 2.0, 3.0, 4.0]
model(torch.Tensor(a))

This would also work with batching, e.g.,

a = [0.0, 1.0, 2.0, 3.0, 4.0]
model(torch.Tensor([a, a]))

We should enable batching in the same way on the FTorch side, too.
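For context, the two calls above can be sketched end to end with a stand-in model (the linear layer here is hypothetical, not the actual SimpleNet definition):

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the SimpleNet example model.
model = nn.Linear(5, 5, bias=False)

a = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])

single = model(a)                     # unbatched: shape (5,)
batched = model(torch.stack([a, a]))  # batched:   shape (2, 5)

print(single.shape)   # torch.Size([5])
print(batched.shape)  # torch.Size([2, 5])
```

PyTorch layers such as nn.Linear operate on the last dimension and treat any leading dimensions as batch dimensions, which is why the same model accepts both shapes without modification.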

jwallwork23 commented 1 month ago

To investigate: is the underlying Torch smart enough to handle this, or will we have to loop?
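If looping did turn out to be necessary, the two approaches could be compared like this (a sketch with a hypothetical model, not FTorch code; for a plain linear model the results agree):

```python
import torch
import torch.nn as nn

model = nn.Linear(5, 3, bias=False)
batch = torch.randn(2, 5)

# Batched: one forward pass over the leading (batch) dimension.
out_batched = model(batch)

# Looped: one forward pass per sample, stacked afterwards.
out_looped = torch.stack([model(sample) for sample in batch])

print(torch.allclose(out_batched, out_looped))  # True
```

A performance comparison would time these two variants over realistic batch sizes; the batched call avoids per-sample dispatch overhead.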

jatkinson1000 commented 1 month ago

It would still be good to implement this on the FTorch side, but it's worth noting that, with a little thought and care, batching can also be handled on the PyTorch side (so that the model accepts Fortran arrays of arbitrary size in one dimension).

This is what was done for MiMA here: https://github.com/DataWaveProject/MiMA-machine-learning/blob/ML/src/shared/pytorch/arch_davenet.py, though it is not the easiest code to follow.
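One common pattern for this (a hedged sketch, not the MiMA code itself) is a wrapper module whose forward flattens arbitrary leading dimensions into a single batch dimension before the model call and restores them afterwards:

```python
import torch
import torch.nn as nn

class BatchFlattenWrapper(nn.Module):
    """Hypothetical wrapper: accept inputs with any leading shape,
    flatten to (N, features), run the wrapped model, reshape back."""

    def __init__(self, model, n_features):
        super().__init__()
        self.model = model
        self.n_features = n_features

    def forward(self, x):
        lead_shape = x.shape[:-1]              # arbitrary leading dims
        flat = x.reshape(-1, self.n_features)  # collapse to (N, features)
        out = self.model(flat)
        return out.reshape(*lead_shape, -1)    # restore leading dims

wrapped = BatchFlattenWrapper(nn.Linear(5, 3, bias=False), n_features=5)
print(wrapped(torch.randn(7, 2, 5)).shape)  # torch.Size([7, 2, 3])
```

A wrapper like this lets the Fortran side pass arrays whose size varies in one dimension without the Python side needing to know the batch size in advance.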

jwallwork23 commented 1 month ago

It might still be worth comparing the performance of implementing this on the Fortran side vs the Torch side.