Currently, performance for a single solint (6s, 40chan) without bright source prediction is:
```
Total run time:    1203.54s
Initialisation time: 610.28s
Compilation time:     59.41s
Run time:            530.97s
```
That is roughly 100x slower than real time. I think we can and must do better; we require real-time operation. Some part of this must be related to non-optimal JAX compilation. Compare with the isolated performance tests to see what we should be getting.
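As a first check, JAX's AOT API can separate the one-off compilation cost from the steady-state per-solint cost. A minimal sketch, assuming a stand-in `predict` kernel (the function body and shapes here are hypothetical, not our actual predict step):

```python
import time

import jax
import jax.numpy as jnp


def predict(vis, gains):
    # hypothetical stand-in for one solint (6s, 40 chan) of work
    return vis * gains


vis = jnp.ones((40, 2016), dtype=jnp.complex64)    # illustrative (chan, baseline) shapes
gains = jnp.ones((40, 2016), dtype=jnp.complex64)

t0 = time.perf_counter()
compiled = jax.jit(predict).lower(vis, gains).compile()  # pay compile cost once, up front
t1 = time.perf_counter()
out = compiled(vis, gains)
out.block_until_ready()                                  # exclude async dispatch from the timing
t2 = time.perf_counter()
print(f"compile: {t1 - t0:.2f}s, per-solint run: {t2 - t1:.4f}s")
```

If the steady-state number from a micro-benchmark like this is well under 6s, the gap is in initialisation, re-compilation, or data movement rather than the kernel itself.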
Using a single kernel takes ages to compile and I'm not sure there is much performance gain; the sharding machinery also seems to be getting stuck. Sharded predict over T/C is using >500GB despite `scan_dims={'T', 'C'}`. I think this can be simplified, at the expense of breaking the computation up into streaming actor layers rather than a single JAX kernel.
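For reference, a minimal sketch of the streaming alternative, assuming predict can be applied independently per (time, channel) chunk (the `predict_chunk` kernel, its body, and the shapes are hypothetical):

```python
import jax
import jax.numpy as jnp


@jax.jit
def predict_chunk(uvw, freqs):
    # hypothetical per-chunk predict; small enough to compile quickly and hold in memory
    return jnp.einsum('tbd,c->tbc', uvw, freqs)


def streaming_predict(uvw_all, freqs_all, t_chunk=32, c_chunk=8):
    # Loop over T/C chunks in Python instead of one large scan'd kernel:
    # peak memory is bounded by a single chunk's working set, and each chunk
    # could equally be dispatched to a separate actor.
    # Note: ragged final chunks retrigger compilation; pad to fixed chunk sizes in practice.
    rows = []
    for t0 in range(0, uvw_all.shape[0], t_chunk):
        cols = []
        for c0 in range(0, freqs_all.shape[0], c_chunk):
            cols.append(predict_chunk(uvw_all[t0:t0 + t_chunk],
                                      freqs_all[c0:c0 + c_chunk]))
        rows.append(jnp.concatenate(cols, axis=-1))
    return jnp.concatenate(rows, axis=0)
```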
The transition is relatively simple. For each SFM core step we refactor `get_state` to provide the initial state. Then, for each layer, we spin up `N[layer]` identical actors which feed off a single queue, using object refs to pass data. A `__reduce__` method is necessary for pytrees to pass efficiently. As we execute the DAG, each step has a supervisor actor responsible for dishing out work. Serial steps, e.g. simulating a dish, can have a single worker actor responsible for the work. Thus the size of each layer is controllable, and only horizontal layers will take advantage of scale.
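A minimal Ray-style sketch of one such layer, assuming Ray actors and `ray.util.queue.Queue` (the `SolintState`, `LayerWorker`, and `LayerSupervisor` names are hypothetical):

```python
import ray
from ray.util.queue import Queue


class SolintState:
    """Hypothetical pytree-like container; __reduce__ lets it pickle cleanly over Ray."""

    def __init__(self, gains, residuals):
        self.gains = gains
        self.residuals = residuals

    def __reduce__(self):
        return (SolintState, (self.gains, self.residuals))


@ray.remote
class LayerWorker:
    """One of N[layer] identical actors feeding off a single shared queue."""

    def __init__(self, in_queue: Queue, out_queue: Queue):
        self.in_queue = in_queue
        self.out_queue = out_queue

    def run(self):
        while True:
            item = self.in_queue.get()           # block until the supervisor dishes out work
            if item is None:                     # sentinel from the supervisor: stop
                break
            step_idx, state_ref = item
            state = ray.get(state_ref)           # data travels by object ref, not by value
            result = self.process(step_idx, state)
            self.out_queue.put((step_idx, ray.put(result)))  # hand the result ref downstream

    def process(self, step_idx, state):
        # hypothetical per-step work, e.g. a small jitted JAX kernel
        return state


@ray.remote
class LayerSupervisor:
    """Supervisor for one DAG step: owns the worker pool and dishes out work."""

    def __init__(self, n_workers: int):
        self.in_queue = Queue()
        self.out_queue = Queue()
        self.workers = [LayerWorker.remote(self.in_queue, self.out_queue)
                        for _ in range(n_workers)]
        for w in self.workers:
            w.run.remote()                       # each worker loops until it sees the sentinel

    def submit(self, step_idx, state_ref):
        # nest the ref in a tuple so Ray passes it by reference rather than resolving it
        self.in_queue.put((step_idx, state_ref))

    def shutdown(self):
        for _ in self.workers:
            self.in_queue.put(None)
```

A serial step is then just a supervisor with `n_workers=1`, so the width of each horizontal layer becomes a single knob per DAG step.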