Closed mj023 closed 2 days ago
Great explanation! From my perspective, you won't have to go into such detail -- it is not that we are talking about a library that is in use across thousands of weird environments here. Everyone can update their environments without problems.
This PR changes the dependencies for the pixi environments to require jax >= 0.4.34 and jaxlib >= 0.4.34. The reason is that #77 introduced memory allocation issues when using older JAX and jaxlib versions. For larger models, this made them impossible to solve with limited GPU memory.
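As a quick sanity check that an environment satisfies the new lower bound, something like the following sketch could be used. The helper names (`parse_release`, `meets_minimum`, `check_installed`) are hypothetical and not part of LCM. The parsing only looks at the leading numeric release segment, so pre-release suffixes are ignored.

```python
import re
from importlib import metadata

# Minimum version required by this PR for both jax and jaxlib.
MIN_VERSION = (0, 4, 34)


def parse_release(version: str) -> tuple[int, ...]:
    """Extract the leading numeric release segment, e.g. '0.4.34' -> (0, 4, 34)."""
    match = re.match(r"\d+(?:\.\d+)*", version)
    if match is None:
        raise ValueError(f"Cannot parse version string: {version!r}")
    return tuple(int(part) for part in match.group().split("."))


def meets_minimum(version: str, minimum: tuple[int, ...] = MIN_VERSION) -> bool:
    """Return True if the version is at least the required minimum."""
    return parse_release(version) >= minimum


def check_installed(packages=("jax", "jaxlib")) -> dict:
    """Map each package to True/False, or None if it is not installed."""
    result = {}
    for name in packages:
        try:
            result[name] = meets_minimum(metadata.version(name))
        except metadata.PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    print(check_installed())
```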
Problem description

Since #77, the solve_model() function returned by LCM is jitted by default. After this change, one would run into problems when solving a model on the GPU. JAX would throw an error because the program tried to allocate huge arrays, with the same dimensions as the whole state-choice space, in GPU memory. For larger models these arrays could be multiple TB in size. High memory usage had never been a problem before, and considering the algorithm used for LCM, there should be no reason to store these arrays.
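To give a sense of scale, here is a back-of-the-envelope estimate of the size of one array spanning the full state-choice space. The grid sizes are made up for illustration and do not come from any specific LCM model. The 4 bytes per element assume JAX's default float32.

```python
import math


def state_choice_bytes(grid_sizes, bytes_per_element=4):
    """Bytes needed for one dense array over the full state-choice space."""
    return math.prod(grid_sizes) * bytes_per_element


# Hypothetical model: four state grids and two choice grids.
grid_sizes = (500, 500, 100, 100, 50, 2)

n_bytes = state_choice_bytes(grid_sizes)
print(f"{n_bytes / 1e12:.1f} TB")  # 1.0 TB
```

If only the reduced result (for example, the maximum over choices) is stored, the required memory shrinks by the size of the reduced dimensions, which is why materializing the full array should not be necessary.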
Reasons for Memory Allocation Issues

When a function is jitted with JAX, a computation graph is created and then passed to a compiler for further optimization. The memory allocation issues probably stem from this optimization step. It is possible to visualize the computation graph before and after the compiler optimization. Below you can see the optimized computation graph with older JAX and jaxlib versions. For some reason, the compiler splits the fusion into two parts instead of creating one big fusion with a reduce operator as the root. The arrays are passed as parameters from one fusion to the other and are therefore stored in GPU memory.
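For anyone who wants to inspect the graph themselves, the following is a minimal sketch using JAX's ahead-of-time lowering API. The function `max_over_choices` is a toy stand-in, not the LCM code. It builds a (states x choices) array and reduces over the choice axis, which is the kind of pattern where a single fusion with a reduce root would avoid materializing the full array.

```python
import jax
import jax.numpy as jnp


def max_over_choices(states, choices):
    # Broadcast to a (n_states, n_choices) array, then reduce over choices.
    values = jnp.log1p(states[:, None]) - 0.5 * choices[None, :] ** 2
    return values.max(axis=1)


states = jnp.linspace(1.0, 10.0, 1_000)
choices = jnp.linspace(0.0, 1.0, 500)

# Lower the jitted function without running it.
lowered = jax.jit(max_over_choices).lower(states, choices)
unoptimized_text = lowered.as_text()  # graph before XLA optimization

# Compile and retrieve the optimized HLO, including any fusions.
compiled = lowered.compile()
optimized_text = compiled.as_text()

print(optimized_text[:2000])
```

In the optimized text, one can check whether the broadcast and the reduce end up in one fusion or in two. For a graphical dump, setting the environment variable `XLA_FLAGS=--xla_dump_to=<some directory>` before starting Python should write the HLO modules to disk. The exact output depends on the backend and the jaxlib version.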