Joshuaalbert / DSA2000-Cal

DSA-2000 Calibration and Forward Modelling
https://www.deepsynoptic.org/overview
MIT License

SFM to ray #129

Open Joshuaalbert opened 9 hours ago

Joshuaalbert commented 9 hours ago

Using a single kernel takes ages to compile and I'm not sure there is much performance gain; I also think the sharding machinery is getting stuck. Sharded predict over T/C is using >500GB despite scan_dims={'T', 'C'}. I think this can be simplified, at the expense of breaking the computation up into streaming actor layers rather than a single JAX kernel.

The transition is relatively simple. For each SFM core step we refactor get_state to provide the initial state; then for each layer we spin up N[layer] identical actors which feed off a single queue, which uses object refs to pass data. A __reduce__ method is necessary for pytrees to pass efficiently.
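A minimal stdlib sketch of that pattern (threads and a `queue.Queue` standing in for Ray actors and object refs; `LayerState` and the doubling "compute" are hypothetical stand-ins, not the actual SFM step):

```python
import pickle
import queue
import threading

class LayerState:
    """Hypothetical pytree-like container passed between layers.
    __reduce__ keeps serialization cheap: only the payload is pickled,
    never cached/derived attributes."""
    def __init__(self, data):
        self.data = data
        self._cache = None  # derived state we don't want to ship

    def __reduce__(self):
        return (LayerState, (self.data,))

def worker(in_q, out):
    """One of N[layer] identical actors: pull work items off a single
    shared queue until a sentinel arrives."""
    while True:
        item = in_q.get()
        if item is None:  # sentinel: layer is drained
            break
        state = pickle.loads(item)  # stands in for object-ref transfer
        out.append(state.data * 2)  # stand-in for the layer's compute

in_q = queue.Queue()
out = []
n_workers = 4  # N[layer]
threads = [threading.Thread(target=worker, args=(in_q, out))
           for _ in range(n_workers)]
for t in threads:
    t.start()
for x in range(8):
    in_q.put(pickle.dumps(LayerState(x)))
for _ in threads:
    in_q.put(None)  # one sentinel per worker
for t in threads:
    t.join()
print(sorted(out))  # → [0, 2, 4, 6, 8, 10, 12, 14]
```

The key property is that every item is consumed exactly once by whichever worker is free, so throughput scales with N[layer] for horizontally parallel steps.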

Thus, as we execute the DAG, each step has a supervisor actor responsible for dishing out work. Serial steps, e.g. simulating a dish, can have a single worker actor responsible for the work. The size of each layer is therefore controllable; only horizontal layers will take advantage of scale.
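A toy sketch of how a supervisor could size layers per DAG step (the step names and sizing table below are hypothetical, just to show serial steps getting one worker and horizontal steps getting many):

```python
# Hypothetical per-step layer sizes: serial steps get one worker actor,
# horizontally parallel steps get N[layer] workers.
LAYER_SIZES = {
    "simulate_dish": 1,  # serial step: single worker
    "predict": 8,        # horizontal step: scales with workers
}

def dispatch(step, work_items):
    """Supervisor for one DAG step: partition work round-robin over
    the step's worker pool (stand-in for queue-fed actors)."""
    n = LAYER_SIZES[step]
    pools = [[] for _ in range(n)]
    for i, item in enumerate(work_items):
        pools[i % n].append(item)
    return pools

print(dispatch("simulate_dish", [0, 1, 2, 3]))  # → [[0, 1, 2, 3]]
print(len(dispatch("predict", list(range(16)))))  # → 8
```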

Joshuaalbert commented 8 hours ago

Current performance for a single solint (6s, 40 chan) without bright source prediction is:

Total run time: 1203.54s
Initialisation time: 610.28s
Compilation time: 59.41s
Run time: 530.97s

That is roughly 100x slower than real time. I think we can, and must, do better: we require real-time operation. Some part of this must be related to non-optimal JAX compilation. Combine with isolated performance tests to see what we should be getting.
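For reference, the slowdown factors implied by those numbers (a quick check, assuming the solint covers 6 s of wall-clock data):

```python
solint_s = 6.0      # data span of one solution interval
total_s = 1203.54   # total run time from the timings above
run_s = 530.97      # run time excluding init and compile

print(f"run-time slowdown:   {run_s / solint_s:.0f}x")    # ~88x
print(f"end-to-end slowdown: {total_s / solint_s:.0f}x")  # ~201x
```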