chromatix-team / chromatix

Differentiable wave optics using JAX! Documentation can be found at https://chromatix.readthedocs.io

Dimension arrangements #142

Closed: roflmaostc closed this 4 months ago

roflmaostc commented 4 months ago

Hi,

For our Julia package WaveOpticsPropagation.jl, I'm trying to change our dimension layout a bit.

It seems like the (x, y) dimensions are not the innermost ones: https://chromatix.readthedocs.io/en/latest/101/#creating-a-field

Doesn't that create a bottleneck, since the FFTs are usually calculated along x and y?

For example, in Julia this makes a factor of 4 difference:

julia> using CUDA,  CUDA.CUFFT, BenchmarkTools

julia> xc = CUDA.rand((128, 128, 128)...);

julia> p = plan_fft(xc, (2,3));

julia> @benchmark CUDA.@sync $p * $xc
BenchmarkTools.Trial: 2681 samples with 1 evaluation.
 Range (min … max):  1.792 ms …  25.102 ms  ┊ GC (min … max): 0.00% … 96.38%
 Time  (median):     1.849 ms               ┊ GC (median):    0.00%
 Time  (mean ± σ):   1.859 ms ± 449.835 μs  ┊ GC (mean ± σ):  0.69% ±  3.01%

                        ▁▄▂▂▃▆▆▅▆█▆▇▅▇▄▅█▃▃▂▃▁                 
  ▂▁▁▂▁▃▂▁▂▂▃▃▃▃▄▄▄▅▅▆██████████████████████████▆▅▅▅▄▄▃▃▃▃▃▃▂ ▅
  1.79 ms         Histogram: frequency by time        1.89 ms <

 Memory estimate: 3.86 KiB, allocs estimate: 157.

julia> p = plan_fft(xc, (1,2));

julia> @benchmark CUDA.@sync $p * $xc
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  414.595 μs …  25.045 ms  ┊ GC (min … max): 0.00% … 97.07%
 Time  (median):     419.495 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   425.350 μs ± 248.092 μs  ┊ GC (mean ± σ):  1.00% ±  2.90%

       ▂▆▇█▇▅▃▂▁▁▁                                               
  ▂▂▃▄▇████████████▇▅▄▃▃▃▃▂▂▂▂▂▂▃▃▃▃▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▂▁▂▁▂▁▂▂▁▂ ▃
  415 μs           Histogram: frequency by time          444 μs <

 Memory estimate: 4.09 KiB, allocs estimate: 161.

I was wondering how I should do the layout in Julia. Because of Julia's column-major memory layout, I would do (x, y, z, polarization, wavelength, batch).

CC: @RainerHeintzmann

roflmaostc commented 4 months ago

Also in Python this seems to make a difference:

In [12]: xc = torch.rand(128, 128, 128).cuda()

In [13]: xc.shape
Out[13]: torch.Size([128, 128, 128])

In [14]: xc.dtype
Out[14]: torch.float32

In [15]: %%timeit
    ...: torch.fft.fftn(xc, dim=(0,1))
    ...: torch.cuda.synchronize()
    ...: 
    ...: 
471 µs ± 201 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [16]: %%timeit
    ...: torch.fft.fftn(xc, dim=(1,2))
    ...: torch.cuda.synchronize()
    ...: 
    ...: 
295 µs ± 277 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

Not sure why torch is faster than Julia here, though (I opened an issue at CUDA.jl).

diptodip commented 4 months ago

Thanks Felix, this is a great catch! It is also sometimes slower in JAX with the order in which we do things:

chromatix arrangement fft2 time for (128, 128, 128): 0.00011737400200217962 +/- 6.774095103994204e-05
inner arrangement fft2 time for (128, 128, 128): 9.083172772079706e-05 +/- 1.7062060562748385e-06
chromatix arrangement fft2 time for (128, 256, 256): 0.00021060981089249253 +/- 2.100722751556226e-05
inner arrangement fft2 time for (128, 256, 256): 0.00021538520231842995 +/- 2.5950248370836355e-06
chromatix arrangement fft2 time for (128, 512, 512): 0.0005954901920631528 +/- 3.403699545818871e-05
inner arrangement fft2 time for (128, 512, 512): 0.0006192663800902665 +/- 2.50930329870112e-06
chromatix arrangement fft2 time for (128, 1024, 1024): 0.002230075397528708 +/- 0.00036002158224289005
inner arrangement fft2 time for (128, 1024, 1024): 0.002260779996868223 +/- 8.880622902105098e-06
chromatix arrangement fft2 time for (128, 2048, 2048): 0.009496355836745352 +/- 0.0018868240660353255
inner arrangement fft2 time for (128, 2048, 2048): 0.009488527581561356 +/- 1.5524316672778253e-05

All times are in seconds averaged over 10 runs on an H100.
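
For reference, here is a minimal sketch of how such a comparison can be timed in JAX. This is not necessarily the exact script behind the numbers above; the shape and the interpretation of "chromatix arrangement" (spatial axes in front of a trailing channel-like axis) vs. "inner arrangement" (spatial axes last) are assumptions.

import time
import jax
import jax.numpy as jnp

# Hypothetical test array; real complex64 data of a size similar to the table above.
x = jax.random.normal(jax.random.PRNGKey(0), (128, 512, 512)).astype(jnp.complex64)

fft_chromatix = jax.jit(lambda a: jnp.fft.fft2(a, axes=(0, 1)))   # spatial axes in front
fft_inner = jax.jit(lambda a: jnp.fft.fft2(a, axes=(-2, -1)))     # spatial axes last

def bench(f, a, repeats=10):
    f(a).block_until_ready()  # warm up / compile once
    times = []
    for _ in range(repeats):
        t0 = time.perf_counter()
        f(a).block_until_ready()
        times.append(time.perf_counter() - t0)
    return sum(times) / len(times)

print("chromatix-like axes:", bench(fft_chromatix, x))
print("inner axes:", bench(fft_inner, x))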

This layout is a choice we made at the beginning of this project (when we didn't have the vectorial dimension) so that our layout would match the default neural network image layout in JAX/TF of (batch, height, width, channels).

I think the difference becomes much less significant for larger sizes and also flip-flops as to which order is faster. Plus, we would likely have to look at what happens when we combine the simulations with neural networks to see how that affects overall timings. So I'm not sure we actually want to change anything.

I'm also not sure about putting the vectorial dimensions on the outside, because that would change how the matrix multiplications operate as well. But we should definitely investigate (tagging @GJBoth).

GJBoth commented 4 months ago

Let me try and explain your observations, and why it's probably not an issue for us. I don't have access to a computer right now and am pretty jet-lagged, but I think I'm correct. Would love to see this code for JAX though!

All of this has to do with memory layout: Python is row-major, so the last axes are the contiguous ones in memory. That's why an FFT over the leading axes (dim=(0, 1) in the PyTorch example) is slower than one over the trailing axes (dim=(1, 2)): transforming the trailing axes reads contiguous memory. This is also why Python frameworks put the spatial axes toward the back rather than the front.
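
As a tiny illustration (not from this thread) of what row-major means in practice: the last axis has the smallest stride, so it is the one that sits contiguously in memory.

import numpy as np

a = np.zeros((128, 256, 512), dtype=np.complex64)  # NumPy defaults to row-major (C order)
print(a.strides)  # (1048576, 4096, 8) bytes: the last axis is contiguous,
                  # so an FFT over the trailing axes streams through memory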

So when there are multiple wavelengths and polarization components, the FFT might indeed be slower, but conversely, operations that broadcast over the polarization and wavelength axes would then be slower instead. For example, all of our polarization stuff would take a hit. The ideal memory layout depends on your specific use case, and we chose this axis layout to stay in line with other Python approaches.

As for why it's probably not an issue for us and the difference should be minor (but do double check): we have the batch axis in front, and that axis is usually much, much bigger than the number of wavelengths.

I'm also wondering whether the JAX compiler doesn't just compile this away.
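
One way to check that (assuming a reasonably recent JAX; the shape below is just a hypothetical (batch, height, width, channels) field) is to dump the XLA-optimized HLO for the transform and compare what the compiler actually emits for each arrangement:

import jax
import jax.numpy as jnp

x = jnp.zeros((16, 128, 128, 3), dtype=jnp.complex64)  # hypothetical (B, H, W, C) field

fft_spatial = jax.jit(lambda a: jnp.fft.fft2(a, axes=(1, 2)))
print(fft_spatial.lower(x).compile().as_text())  # HLO after XLA's optimization passes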

roflmaostc commented 4 months ago

I'm surprised that JAX does not follow the standard convention.

Yes, the observations are fine and consistent (just reversed in Julia, since it's column-major).

Aren't the polarization operations usually just elementwise? Isn't the FFT the bottleneck at the end of the day?

I guess a profiler would help to test some examples...

diptodip commented 4 months ago

The vectorial operations are matrix multiplies at each location, which means we would want those dimensions to be contiguous.
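
To make that concrete, here is a hedged sketch (not chromatix's actual implementation) of applying a per-pixel Jones-style matrix; with the polarization axis last, the small matrix multiply runs over the contiguous axis:

import jax
import jax.numpy as jnp

H, W, P = 512, 512, 3
field = jax.random.normal(jax.random.PRNGKey(0), (H, W, P)).astype(jnp.complex64)  # (y, x, polarization)
jones = jnp.broadcast_to(jnp.eye(P, dtype=jnp.complex64), (H, W, P, P))            # a 3x3 matrix per pixel

out = jnp.einsum("hwij,hwj->hwi", jones, field)  # matrix multiply at every (y, x) location
print(out.shape)  # (512, 512, 3)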

Also, there is not really a standard convention (assuming you are talking about the (batch, height, width, channels) layout); as far as I know, each framework's choice reflects whichever layout was faster with the CUDA kernels available at the time. So PyTorch chose a different layout than TF, and both stick with their own conventions.

diptodip commented 4 months ago

And as you can see our current layout can be faster even for FFT depending on the size.

GJBoth commented 4 months ago

To echo @diptodip, I'm nearly 100% sure both JAX and PyTorch use [H, W, C], as the CNNs operate over the channel axis.

And for polarization: in free space, yes, but in samples the operations often become matrix multiplications because you get mixing.

diptodip commented 4 months ago

PyTorch actually uses channels-first as the default, but in both frameworks you have the option to choose.
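
A small illustration of that option in PyTorch (just a sketch; chromatix itself is JAX): Conv2d takes NCHW input by default, but the underlying memory layout can be switched to channels-last without changing the logical shape:

import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
x = torch.randn(1, 3, 64, 64)                           # NCHW, the default
x_cl = x.contiguous(memory_format=torch.channels_last)  # same shape, NHWC-like strides
print(conv(x).shape, conv(x_cl).shape)                   # both give (1, 8, 64, 64)
print(x.stride(), x_cl.stride())                         # (12288, 4096, 64, 1) vs (12288, 1, 192, 3)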

roflmaostc commented 4 months ago

But in general, you would recommend a 5D or 6D scheme?

For my Julia package, I'm tempted to use (x, y, z, polarization, wavelength, batch).

diptodip commented 4 months ago

We've found it useful to keep all of those dimensions (except z), and that order seems fine to me for Julia (though I don't know how it interacts with other Julia libraries or neural networks, or whether that's a concern for you).

GJBoth commented 4 months ago

Closing for now, feel free to reopen if you have more questions.