bruhwiler closed this issue 1 year ago
Which part of rslaser code in particular? Propagation?
Yes, propagation. If you want to be specific, then let's consider propagation types 'n0n2_lct' and 'gain_calc'.
I'm thinking at a high level about the choice between PyTorch vs Numba vs some other technology.
@jeinstei let me know about https://www.taichi-lang.org, which is an alternative to Numba.
@jeinstei -- do you have time to consider Rob's request in the previous comment?
I think they are all options. If the code is structured to decouple the kernels from the front-end, we could probably even experiment with a couple of different options. For one-offs, I've been proposing arrayfire, and cupy and rapids are sort of a standard non-language approach. Those are my two cents. And they are generally quick and easy to implement.
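A minimal sketch of what decoupling a kernel from the front-end could look like, assuming the kernel only needs NumPy-compatible array calls. The names `phase_kernel` and `get_backend` are hypothetical and are not part of rslaser; the kernel itself is a placeholder FFT round trip, not the actual LCT.

```python
import numpy as np


def phase_kernel(field, phase, xp=np):
    """Placeholder kernel: apply a phase factor in the Fourier domain.

    `xp` is the array module (NumPy by default). Any module exposing the
    same fft/exp calls, such as CuPy, could be passed in instead.
    """
    spectrum = xp.fft.fft2(field)
    return xp.fft.ifft2(spectrum * xp.exp(1j * phase))


def get_backend(name="numpy"):
    """Hypothetical front-end switch; CuPy is imported only if requested."""
    if name == "cupy":
        import cupy  # assumption: installed and a GPU is available
        return cupy
    return np


xp = get_backend("numpy")
field = xp.ones((32, 32), dtype=xp.complex128)
phase = xp.zeros((32, 32))
out = phase_kernel(field, phase, xp=xp)
```

With this structure the serial kernel body stays untouched when switching array libraries; only the `get_backend` call changes. Whether the real propagators can be written purely in terms of such calls is an open question for this exercise.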
Note that we are only considering single node / single GPU for this exercise. Multiple GPUs and multiple nodes are out of scope.
The first step should be to create a test script that invokes the n0n2_lct propagator, with large values for nx and ny. @gurhar1133 -- please coordinate with @k-wolfinger to create the test script.
The second step should be to profile the test script on a single CPU.
Then we can do comparison runs on a GPU via the technologies suggested in previous comments.
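For the CPU profiling step, something along these lines could work, using the standard library's cProfile. The `run_case` function is a hypothetical stand-in workload (repeated 2D FFT round trips); the real test script would call the n0n2_lct propagator there instead.

```python
import cProfile
import io
import pstats

import numpy as np


def run_case(nx=128, ny=128, n_steps=10):
    """Hypothetical stand-in for the propagation test script."""
    field = np.ones((ny, nx), dtype=np.complex128)
    for _ in range(n_steps):
        field = np.fft.ifft2(np.fft.fft2(field))
    return field


# Profile only the region of interest.
profiler = cProfile.Profile()
profiler.enable()
result = run_case()
profiler.disable()

# Collect the ten most expensive entries, sorted by cumulative time.
stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream).sort_stats("cumulative")
stats.print_stats(10)
report = stream.getvalue()
print(report)
```

Varying `nx`, `ny` and `n_steps` in the stand-in shows how the report shifts toward the FFT calls as the mesh grows, which is the kind of baseline the GPU comparison runs would be measured against.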
To be clear, I'm not trying to get to production code -- that will wait until Phase 2b, if it gets funded. What I want is a short-term effort to reasonably justify the selection of a technology to be used in the future.
Performance/speed is an important criterion, of course. However, a technology that requires minimal changes to the serial source code would also be attractive.
@bruhwiler @k-wolfinger Do we have any "reference simulations" or configurations that we can agree upon to use for profiling? I'm thinking ones that would highlight specific simulation issues, such as:
The fundamental calculation is a laser pulse propagating through a Ti:Sapphire crystal with thermal effects that lead to weak focusing, followed by a drift, where the pulse converges slightly.
The laser pulse can be represented by N slices, where N=10 is a good default, but we can reasonably increase this to 100 or larger. Each laser pulse slice is represented by a wavefront, consisting of Re(Ex), Re(Ey), Im(Ex) and Im(Ey) on a 2D Cartesian mesh. The crystal is represented by M slices, where again M=10 is a good default, but we can reasonably increase this to 100 or larger.
The Cartesian wavefront meshes can be 32 x 32 as a default, but we can imagine that higher resolution simulations will be important in the future, with mesh sizes of 1024 x 1024. When propagating each wavefront through each crystal slice, the linear canonical transform (LCT) developed by Boaz/Dan/Ilya is used, which includes 1 or more FFTs.
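To get a feel for the problem sizes described above, here is a rough sketch of the data layout and the scaling. It assumes each field component is stored as float64 and that the pulse is held as one dense array; the names `make_pulse`, `pulse_bytes` and `lct_calls` are hypothetical and do not reflect how rslaser actually stores wavefronts.

```python
import numpy as np

# Component order follows the description above.
COMPONENTS = ("re_ex", "re_ey", "im_ex", "im_ey")


def make_pulse(n_slices=10, nx=32, ny=32):
    """Pulse as an array of shape (n_slices, 4, ny, nx), assuming float64."""
    return np.zeros((n_slices, len(COMPONENTS), ny, nx), dtype=np.float64)


def pulse_bytes(n_slices, nx, ny, bytes_per_value=8):
    """Memory needed for the field data alone, under the same assumption."""
    return n_slices * len(COMPONENTS) * nx * ny * bytes_per_value


def lct_calls(n_pulse_slices, n_crystal_slices):
    """One LCT per wavefront per crystal slice; each LCT has 1 or more FFTs."""
    return n_pulse_slices * n_crystal_slices


small = make_pulse()
print(small.shape, small.nbytes)
print(pulse_bytes(100, 1024, 1024) / 1e9, "GB")
print(lct_calls(10, 10), lct_calls(100, 100))
```

Under these assumptions the default case (N=10, 32 x 32) holds 327,680 bytes of field data and needs 100 LCT applications, while N=100 and M=100 on a 1024 x 1024 mesh holds about 3.36 GB and needs 10,000 LCT applications. That is the range a single-GPU benchmark would need to span.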
In terms of varying the complexity and intensity of the floating point operations, there are two primary axes:
I'm going to be "that guy" and ask -- do we have specific commands to run for benchmarking yet? Maybe one config for each of the 4 levels of complexity?
I haven't actually run rslaser yet, so I might be missing something about the practical usage.
We have been mostly exercising the library via notebooks. There are some tests that exercise the library from Python.
@k-wolfinger has agreed to create a plain Python example that executes the scenario described above. It will be clear from her code how to vary the mesh dimensions and the number of slices.
@gurhar1133 will then create a few examples from Kathryn's code and put them in the repo, so that he can begin doing some profiling on the CPU, and we will have concrete code to consider for the comparative analysis of the various GPU technologies.
For the n0n2_lct propagator there are several calls to the SRW code: srwlib.SRWLWfr once and srwutil.calc_int_from_wfr four times.
Given that our algorithms are still dependent on SRW, we cannot proceed with this activity as originally planned. We do plan to remove this dependence on SRW in the future.
Don't think about any code that invokes SRW. This issue is just about Python.