roflmaostc closed this issue 4 months ago.
Also in Python this seems to make a difference:
In [12]: xc = torch.rand(128, 128, 128).cuda()
In [13]: xc.shape
Out[13]: torch.Size([128, 128, 128])
In [14]: xc.dtype
Out[14]: torch.float32
In [15]: %%timeit
    ...: torch.fft.fftn(xc, dim=(0,1))
    ...: torch.cuda.synchronize()
471 µs ± 201 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)

In [16]: %%timeit
    ...: torch.fft.fftn(xc, dim=(1,2))
    ...: torch.cuda.synchronize()
295 µs ± 277 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Not sure why torch is faster than Julia, though (I opened an issue at CUDA.jl).
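For a quick CPU-only check of the same effect, here is a minimal sketch that uses NumPy instead of PyTorch on CUDA (an assumption; absolute times will not match the GPU numbers above, and the gap on CPU may differ). It times a 2D FFT over the two outer axes versus the two inner axes of a 128³ array and confirms that both compute the same transform up to an axis permutation. The helper name `time_fft2` is hypothetical, and the array is float64 rather than the float32 used above.

```python
import timeit

import numpy as np


def time_fft2(x, axes, repeats=5):
    """Hypothetical helper: mean wall-clock seconds of np.fft.fftn over `axes`."""
    return timeit.timeit(lambda: np.fft.fftn(x, axes=axes), number=repeats) / repeats


rng = np.random.default_rng(0)
x = rng.random((128, 128, 128))  # C order: axis 2 is contiguous

# FFT over the two outer (strided) axes.
outer = np.fft.fftn(x, axes=(0, 1))
# Move the untouched axis to the front so the transformed axes become the
# two inner (contiguous) ones, transform, then move it back.
inner = np.moveaxis(np.fft.fftn(np.moveaxis(x, 2, 0), axes=(1, 2)), 0, 2)
assert np.allclose(outer, inner)

print(f"axes=(0, 1): {time_fft2(x, (0, 1)) * 1e3:.2f} ms")
print(f"axes=(1, 2): {time_fft2(x, (1, 2)) * 1e3:.2f} ms")
```

Note that `np.moveaxis` only returns a strided view; the FFT itself still reads the original memory, so this checks equivalence of the math rather than producing a contiguous copy.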
Thanks Felix, this is a great catch! With the axis order we use, FFTs are also sometimes slower in JAX:
fft2 time (ms)
Array shape         chromatix arrangement   inner arrangement
(128, 128, 128)     0.117 ± 0.068           0.091 ± 0.002
(128, 256, 256)     0.211 ± 0.021           0.215 ± 0.003
(128, 512, 512)     0.595 ± 0.034           0.619 ± 0.003
(128, 1024, 1024)   2.230 ± 0.360           2.261 ± 0.009
(128, 2048, 2048)   9.496 ± 1.887           9.489 ± 0.016
All times are in milliseconds, averaged over 10 runs on an H100.
We chose this layout at the beginning of this project (before we had the vectorial dimension) so that it would match the default neural network image layout in JAX/TF, (batch, height, width, channels).
I think the difference becomes much less significant for larger sizes, and which order is faster also flip-flops between sizes. We would also have to look at what happens when we combine the simulations with neural networks, to see how that affects overall timings. So I'm not sure we want to change anything, actually.
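To put numbers on the flip-flop point, a small sketch that divides the mean chromatix-arrangement time by the mean inner-arrangement time for each shape. The values are the H100 means reported in this thread, in seconds, rounded to four significant figures.

```python
# Mean fft2 times in seconds from the H100 measurements in this thread,
# keyed by array shape: (chromatix arrangement, inner arrangement).
timings = {
    (128, 128, 128): (1.174e-4, 9.083e-5),
    (128, 256, 256): (2.106e-4, 2.154e-4),
    (128, 512, 512): (5.955e-4, 6.193e-4),
    (128, 1024, 1024): (2.230e-3, 2.261e-3),
    (128, 2048, 2048): (9.496e-3, 9.489e-3),
}

# Ratio > 1: chromatix arrangement is slower; ratio < 1: it is faster.
ratios = {shape: chromatix / inner for shape, (chromatix, inner) in timings.items()}

for shape, r in ratios.items():
    print(f"{shape}: {r:.3f}")
```

The ratio is about 1.29 for (128, 128, 128), but for the larger shapes it stays within about 4% of 1 and crosses back above 1 at (128, 2048, 2048).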
I'm also not sure about putting the vectorial dimensions on the outside, because that would change how the matrix multiplications operate as well. But we should definitely investigate (tagging @GJBoth).
Let me try and explain your observations, and why it’s probably not an issue for us. I don’t have access to a computer right now and am pretty jet lagged, but I think I’m correct. Would love to see this code for Jax though!
All this has to do with memory layout. NumPy, PyTorch and JAX store arrays in row-major (C) order by default, so the last axis is contiguous in memory. That’s why using (1,2) as FFT axes is slower than (2,3): the (2,3) planes are contiguous, so memory is loaded faster. It’s also why the preferred axis order in Python is the reverse of Julia’s, which is column-major.
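A minimal NumPy sketch of the memory-order point (NumPy stands in for PyTorch/JAX here; all three default to row-major, C order). `strides` gives the number of bytes to step along each axis, so the axis with the smallest stride is the contiguous one. Column-major (Fortran) order, which Julia uses, reverses this. The shape is an arbitrary example.

```python
import numpy as np

shape = (4, 8, 16)

# Row-major (C order): the LAST axis is contiguous, 4 bytes per float32 step.
c = np.zeros(shape, dtype=np.float32)
print(c.strides)  # (512, 64, 4)

# Column-major (Fortran order, as in Julia): the FIRST axis is contiguous.
f = np.zeros(shape, dtype=np.float32, order="F")
print(f.strides)  # (4, 16, 128)
```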
So with multiple wavelengths and polarizations the FFTs might indeed be slower, but conversely, operations that broadcast over the polarization and wavelength axes would be slower with the spatial axes innermost; for example, all our polarization code would take a hit. The ideal memory layout depends on your specific use case, and we chose this axis layout to stay in line with other Python approaches.
As for why it’s probably not an issue for us and the difference will be minor (but do double-check): we have the batch axis in front, and this axis is usually much bigger than the number of wavelengths.
I’m also wondering whether the JAX compiler optimizes this away.
I'm surprised that jax does not follow the standard convention.
Yes, the observations are fine and consistent (in Julia just reversed).
Aren't the polarization operations usually just elementwise? Isn't the FFT the bottleneck, at the end of the day?
I guess a profiler would help to test some examples...
The vectorial operations are matrix multiplies at each location, which means we would want those dimensions to be contiguous.
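A sketch of what a matrix multiply at each location looks like with the polarization axis last (contiguous), as in a channels-last layout. The grid size, the 3-component field, and the names `field` and `jones` are illustrative assumptions, not Chromatix's API. The same operation with polarization first is included for comparison: it gives identical numbers, but the components of one pixel are then H*W elements apart in memory.

```python
import numpy as np

rng = np.random.default_rng(0)
H, W, P = 64, 64, 3  # hypothetical grid size and number of polarization components

# Polarization last: each pixel's 3-vector and 3x3 matrix are contiguous.
field = rng.standard_normal((H, W, P)) + 1j * rng.standard_normal((H, W, P))
jones = rng.standard_normal((H, W, P, P)) + 1j * rng.standard_normal((H, W, P, P))

# out[h, w, i] = sum_j jones[h, w, i, j] * field[h, w, j]
out_last = np.einsum("hwij,hwj->hwi", jones, field)
# Equivalent batched matmul: treat the field as a column vector per pixel.
out_matmul = (jones @ field[..., None])[..., 0]

# Polarization first: same math, but one pixel's components are H*W apart.
field_first = np.ascontiguousarray(np.moveaxis(field, -1, 0))  # (P, H, W)
jones_first = np.ascontiguousarray(np.moveaxis(jones, (-2, -1), (0, 1)))  # (P, P, H, W)
out_first = np.einsum("ijhw,jhw->ihw", jones_first, field_first)

# Byte step between polarization components: 16 (adjacent) vs H*W*16.
print(field.strides[-1], field_first.strides[0])
```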
Also, there isn't really a standard convention (assuming you mean the (batch, height, width, channels) layout); as far as I know, each framework chose whichever layout was faster with the CUDA kernels available at the time. So PyTorch chose a different layout than TF, and both stick with their own conventions.
And as you can see our current layout can be faster even for FFT depending on the size.
To echo @diptodip, I’m nearly 100% sure both Jax and PyTorch use [H, W, C] as the CNNs operate over the channel axis.
And as for polarization: in free space, yes, but in samples the operations often become matrix multiplications because the polarization components mix.
PyTorch actually uses channels first as the default but in both you have the option to choose.
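Switching between the two conventions is an axis permutation. A minimal NumPy sketch (the batch of random images is illustrative): `transpose` returns a strided view, so the data is only laid out channels-first in memory after an explicit copy, for example with `np.ascontiguousarray`.

```python
import numpy as np

# A hypothetical batch of images in channels-last layout (N, H, W, C), as in TF/JAX.
nhwc = np.random.default_rng(0).random((2, 32, 32, 3), dtype=np.float32)

# Channels-first (N, C, H, W), the PyTorch default, as a view: no data is moved.
nchw_view = nhwc.transpose(0, 3, 1, 2)
print(nchw_view.shape, nchw_view.flags["C_CONTIGUOUS"])  # (2, 3, 32, 32) False

# Copy so that the memory itself is channels-first (each H x W plane contiguous).
nchw = np.ascontiguousarray(nchw_view)
print(nchw.flags["C_CONTIGUOUS"], np.shares_memory(nhwc, nchw))  # True False
```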
But in general, would you recommend a 5D or 6D scheme?
For my Julia package I'm tempted to use (x, y, z, polarization, wavelength, batch).
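The "just reversed" relationship between Julia and Python layouts can be checked from Python: a column-major array (as Julia stores it) with axes (x, y, z, polarization, wavelength, batch) occupies memory exactly like a row-major array with the axes reversed. The sizes below are arbitrary placeholders.

```python
import numpy as np

# Hypothetical sizes for (x, y, z, polarization, wavelength, batch).
julia_shape = (16, 16, 1, 3, 2, 4)

# Column-major storage, as Julia would hold the array: x is contiguous.
julia_like = np.zeros(julia_shape, dtype=np.complex64, order="F")

# Transposing reverses the axes without copying and yields a C-order array
# (batch, wavelength, polarization, z, y, x), with x still contiguous.
python_like = julia_like.T
print(python_like.shape, python_like.flags["C_CONTIGUOUS"])  # (4, 2, 3, 1, 16, 16) True
print(np.shares_memory(julia_like, python_like))  # True
```

In both views an FFT over x and y runs over the contiguous block of memory: the first two axes in Julia, the last two in Python.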
We've found keeping all those dimensions to be useful (except z), and that order seems fine to me for Julia (though I don't know what happens with other Julia libraries or neural networks/if that's a concern for you).
Closing for now, feel free to reopen if you have more questions.
Hi,
for our Julia package WaveOpticsPropagation.jl I'm trying to change our dimension layout a bit.
It seems like the (x,y) dimensions are not the innermost ones: https://chromatix.readthedocs.io/en/latest/101/#creating-a-field
Doesn't that create a bottleneck because usually the FFTs are calculated along x,y?
For example, in Julia this makes a factor of 4 difference:
I was wondering how I should do the layout in Julia. Because of the different memory layout, I would use
(x, y, z, polarizations, wavelength, batch).
CC: @RainerHeintzmann