SciML / OperatorLearning.jl

No need to train, he's a smooth operator
https://operatorlearning.sciml.ai/dev
MIT License

Proper handling of batches #13

Closed · pzimbrod closed this issue 2 years ago

pzimbrod commented 3 years ago

As of now, the constructor for the Fourier Layer specifically requests the sample size of the problem:

https://github.com/pzimbrod/NeuralOperator.jl/blob/e3facfa214a74d82db4b9d4c5c35dc8d10eeaa7b/src/FourierLayer.jl#L58-L86

This is highly undesirable since it introduces a huge number of training parameters and slows down training considerably. It also limits the applicability of the trained model, which becomes tied to a fixed sample size.
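To make the scaling concrete, here is a minimal sketch (the channel counts and grid size are made up for illustration): a weight allocated per grid point grows linearly with the sample size, while a grid-agnostic weight does not.

```julia
in_ch, out_ch, N = 8, 16, 1024                  # hypothetical channel counts and grid size

W_per_point = randn(Float32, out_ch, in_ch, N)  # parameter count scales with N
W_shared    = randn(Float32, out_ch, in_ch)     # independent of the grid / sample size

length(W_per_point) ÷ length(W_shared)          # 1024× more parameters
```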

A rewrite is needed to overcome this. For now, here is a workaround to perform the matrix multiplication properly:

https://github.com/pzimbrod/NeuralOperator.jl/blob/e3facfa214a74d82db4b9d4c5c35dc8d10eeaa7b/src/FourierLayer.jl#L100-L101 https://github.com/pzimbrod/NeuralOperator.jl/blob/e3facfa214a74d82db4b9d4c5c35dc8d10eeaa7b/src/FourierLayer.jl#L109-L111

pzimbrod commented 3 years ago

Flux.jl itself solves this by overloading the call with x as an AbstractArray and doing some reshaping:

```julia
(a::Dense)(x::AbstractArray) =
  reshape(a(reshape(x, size(x,1), :)), :, size(x)[2:end]...)
```

In this form, that's not possible here: collapsing the trailing dimensions into one big batch would destroy the spatial structure that the Fourier transform operates on. Maybe find a workaround to distinguish higher-dimensional (batched) arrays from the desired input.
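For contrast, a small sketch of the Dense behavior quoted above (the shapes are made up): Dense can flatten all trailing dimensions into a single batch dimension because the multiplication only touches the feature dimension, which is exactly what a Fourier layer cannot do.

```julia
using Flux

d = Dense(8, 16)                 # 8 input features, 16 outputs
x = rand(Float32, 8, 100, 32)    # features × spatial × batch (arbitrary example shape)

# Internally reshaped to 8 × 3200, multiplied, then reshaped back:
y = d(x)
size(y)                          # (16, 100, 32)
```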

pzimbrod commented 3 years ago

NNlib's batched_mul! might be an alternative, as might the @tullio macro. The former can handle batches (i.e. an additional dimension in one of the arrays). However, Tullio.jl produces much faster results, at least with arrays containing real floats. This still needs to be evaluated with complex numbers.
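For reference, a minimal sketch of the two candidates (shapes and variable names are illustrative; this assumes NNlib's method that applies a plain matrix to every batch slice):

```julia
using NNlib, Tullio

W = randn(Float32, 16, 8)        # shared weight, out_ch × in_ch
x = randn(Float32, 8, 100, 32)   # in_ch × grid × batch

# NNlib: the matrix argument is multiplied with every batch slice
y = batched_mul(W, x)            # → 16 × 100 × 32

# Tullio: the same contraction written as an einsum over k
@tullio y2[o, s, b] := W[o, k] * x[k, s, b]

y ≈ y2                           # true
```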

Besides, the bias addition can be taken care of with elementwise broadcasting, i.e. the .+ operator.
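Continuing the sketch above, a per-channel bias then broadcasts along the grid and batch dimensions:

```julia
b = randn(Float32, 16)           # one bias per output channel
y = batched_mul(W, x) .+ b       # b broadcasts over the grid and batch dims
```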

pzimbrod commented 2 years ago

Using batched_mul seems to do the trick along with simply broadcasting the bias addition via .+.

We also need to broadcast the Fourier transforms now when working with pre-allocated plans, though. Currently, NNlib's batched_mul can't handle FFTW plans. Overloading the function could help - or we stick with the regular FFT for now.
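As a sketch of the plan side (the array layout is assumed, and this sidesteps an explicit broadcast): FFTW plans can be created over just the spatial dimension, in which case a single application of the plan already covers every channel and batch slice.

```julia
using FFTW

x = randn(Float32, 8, 100, 32)   # channels × grid × batch (assumed layout)

# Plan the real-input transform along the grid dimension only; all channel
# and batch slices are transformed in one application of the plan:
p = plan_rfft(x, 2)
x_hat = p * x                    # → 8 × 51 × 32, complex-valued
```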