exgalsky / xgfield

Generation of mocks from a field representation of LSS on the observer's past light cone.
GNU General Public License v3.0

Profile memory usage and track memory leaks #3

Open marcelo-alvarez opened 1 year ago

marcelo-alvarez commented 1 year ago

We need to profile the line-by-line memory usage of the light_cone codes (both MPI and JAX) and ensure that the measured usage matches the expected calculation. When running the MPI implementation for the full 6144 cube on Perlmutter, there was a definite memory leak. We need to check for memory spikes during function calls and ensure that memory management is reasonable for future additions such as 2LPT and multiple kernels.
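
One way to screen a function for leaks before running the full profile is to call it repeatedly and check whether retained memory keeps growing. The sketch below uses the standard-library `tracemalloc` module as a stand-in for the profiler used later in this thread. It only sees allocations made through Python's allocator, so it would not capture JAX device buffers or MPI-side allocations. The helper `traced_growth` and the two toy functions are hypothetical names for illustration and are not part of xgfield.

```python
import tracemalloc

_retained = []  # module-level list that the leaky example keeps appending to


def leaky_step():
    """Toy function that leaks: each call retains a new 100 kB buffer."""
    _retained.append(bytearray(100_000))


def clean_step():
    """Toy function that frees its buffer when it returns."""
    buf = bytearray(100_000)
    return len(buf)


def traced_growth(func, repeats=10):
    """Return the net bytes of traced allocations retained over `repeats` calls."""
    tracemalloc.start()
    func()  # warm-up call so one-time allocations are not counted as growth
    before, _ = tracemalloc.get_traced_memory()
    for _ in range(repeats):
        func()
    after, _ = tracemalloc.get_traced_memory()
    tracemalloc.stop()
    return after - before


leaky_growth = traced_growth(leaky_step)
clean_growth = traced_growth(clean_step)
print(f"leaky_step retained {leaky_growth} bytes, clean_step retained {clean_growth} bytes")
```

Growth that scales with the number of calls points to a leak, while a roughly constant value suggests the function releases what it allocates.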

(transferred from lptmap created by @1cosmologist: marcelo-alvarez/lptmap#5)

marcelo-alvarez commented 1 year ago

Expected memory profile for jax_cone.py

PCM = per cell memory

| Line No. | PCM increment (bytes) | PCM total (bytes) | Increment (MB) | Total memory (MB) | Code |
|---|---|---|---|---|---|
| 73 | - | - | 96 | 96 | `skymap = np.zeros((npix,))` |
| 106 | 4 | 4 | 1728 | 1824 | `grid_sx = read_displacement(sxfile)` |
| 107 | 4 | 8 | 1728 | 3552 | `grid_sy = read_displacement(syfile)` |
| 108 | 4 | 12 | 1728 | 5280 | `grid_sz = read_displacement(szfile)` |
| 117 | 12 | 24 | 5184 | 10464 | `grid_qx, grid_qy, grid_qz = lagrange_mesh(xaxis, yaxis, zaxis, translation, lattice_size_in_Mpc)` |
| 121 | 4 | 28 | 1728 | 12192 | `lagrange_grid = jax.vmap(comoving_q, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qx, grid_qy, grid_qz, translation, lattice_size_in_Mpc)` |
| 124 | 4 | 32 | 1728 | 13920 | `redshift_grid = jax.vmap(cosmo_wsp.comoving_distance2z)(lagrange_grid)` |
| 128 | 8 | 40 | 3456 | 17376 | `ipix_grid = jhp.vec2pix(nside, grid_qz, grid_qy, grid_qx)` |
| 132 | 4 | 44 | 1728 | 19104 | `kernel_sphere = jnp.where((lagrange_grid >= chimin) & (lagrange_grid <= chimax), jax.vmap(lensing_kernel_F)(lagrange_grid, redshift_grid), 0.)` |
| 137 | -12 | 32 | -5184 | 13920 | `del kernel_sphere, ipix_grid` |
| 141 | 4 | 36 | 1728 | 15648 | `growth_grid = jax.vmap(cosmo_wsp.growth_factor_D)(redshift_grid)` |
| 146 | 4 | 40 | 1728 | 17376 | `grid_Xx = jax.vmap(euclid_i, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qx, grid_sx, growth_grid, lattice_size_in_Mpc, translation[0])` |
| 147 | -4 | 36 | -1728 | 15648 | `del grid_qx` |
| 151 | 4 | 40 | 1728 | 17376 | `grid_Xy = jax.vmap(euclid_i, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qy, grid_sy, growth_grid, lattice_size_in_Mpc, translation[1])` |
| 152 | -4 | 36 | -1728 | 15648 | `del grid_qy` |
| 156 | 4 | 40 | 1728 | 17376 | `grid_Xz = jax.vmap(euclid_i, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qz, grid_sz, growth_grid, lattice_size_in_Mpc, translation[2])` |
| 157 | -8 | 32 | -3456 | 13920 | `del grid_qz, growth_grid` |
| 163 | 8 | 40 | 3456 | 17376 | `ipix_grid = jhp.vec2pix(nside, grid_Xz, grid_Xy, grid_Xx)` |
| 164 | -12 | 28 | -5184 | 12192 | `del grid_Xx, grid_Xy, grid_Xz` |
| 168 | 4 | 32 | 1728 | 13920 | `kernel_sphere = jnp.where((lagrange_grid >= chimin) & (lagrange_grid <= chimax), jax.vmap(lensing_kernel_F)(lagrange_grid, redshift_grid), 0.)` |
| 169 | -8 | 24 | -3456 | 10464 | `del lagrange_grid, redshift_grid` |
| 174 | -12 | 12 | -5184 | 5280 | `del ipix_grid, kernel_sphere` |
| 178 | -12 | 0 | -5184 | 96 | `if store_displacements: del grid_sx, grid_sy, grid_sz` |

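
The expected totals can be reproduced with a few lines of arithmetic. The sketch below replays the per-cell byte increments and tracks the running and peak totals. It assumes a 768^3 slab and a float64 sky map at nside = 1024; neither is stated explicitly in this comment, but both are inferred from the 1728 MB cost of a 4-byte array and the 96 MB sky map. Line numbers follow the memory_profiler listing for jax_cone.py.

```python
# Assumed sizes, inferred from the table rather than stated in the source.
GRID_NSIDE = 768   # cells per side of the slab (assumption)
NSIDE = 1024       # HEALPix nside of the sky map
MIB = 1024 ** 2

n_cells = GRID_NSIDE ** 3
npix = 12 * NSIDE ** 2  # number of HEALPix pixels for a given nside


def grid_mib(bytes_per_cell):
    """Memory in MiB of a full-grid array with the given bytes per cell."""
    return bytes_per_cell * n_cells / MIB


skymap_mib = 8 * npix / MIB  # np.zeros defaults to float64, 8 bytes per pixel

# (line number, per-cell byte increment) pairs copied from the expected profile
increments = [
    (106, 4), (107, 4), (108, 4), (117, 12), (121, 4), (124, 4), (128, 8),
    (132, 4), (137, -12), (141, 4), (146, 4), (147, -4), (151, 4), (152, -4),
    (156, 4), (157, -8), (163, 8), (164, -12), (168, 4), (169, -8),
    (174, -12), (178, -12),
]

pcm_total = 0
total_mib = skymap_mib
peak_mib, peak_line = total_mib, 73
for line_no, delta in increments:
    pcm_total += delta
    total_mib += grid_mib(delta)
    if total_mib > peak_mib:
        peak_mib, peak_line = total_mib, line_no

print(f"peak {peak_mib:.0f} MiB at line {peak_line}, final {total_mib:.0f} MiB")
```

Under these assumptions the peak falls on line 132, where the first `kernel_sphere` is allocated, and the final total returns to the size of the sky map alone.
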
marcelo-alvarez commented 1 year ago

Line-by-line memory usage from memory profiler

Filename: jax_cone.py

Line #    Mem usage    Increment  Occurrences   Line Contents
=============================================================
    41    280.0 MiB    280.0 MiB           1   @profile(stream=memfile)
    42                                         def light_cone():
    43                                             # Lattice spacing (a_latt in Websky parlance) in Mpc
    44    280.0 MiB      0.0 MiB           1       lattice_size_in_Mpc = L_box / grid_nside  # in Mpc; 7700 Mpc box length for websky 6144 cube  
    45                                         
    46                                             # comoving distance to last scattering in Mpc
    47    280.0 MiB      0.0 MiB           1       comov_lastscatter = comov_lastscatter_Gpc * (cons.giga / cons.mega) # in Mpc
    48                                         
    49                                             # minimum and maximum radii of projection
    50    315.2 MiB     35.2 MiB           1       chimin = cosmo_wsp.comoving_distance(zmin)
    51    315.2 MiB      0.0 MiB           1       chimax = cosmo_wsp.comoving_distance(zmax)
    52                                         
    53    315.2 MiB      0.0 MiB           1       print("chimin, chimax: ",chimin,chimax)
    54                                             # NSIDE of HEALPix map 
    55    315.2 MiB      0.0 MiB           1       nside = 1024
    56    315.2 MiB      0.0 MiB           1       npix = hp.nside2npix(nside)
    57    315.2 MiB      0.0 MiB           1       solidang_pix = 4*np.pi / npix
    58                                         
    59                                             # Effectively \Delta chi, comoving distance interval spacing for LoS integral
    60    315.2 MiB      0.0 MiB           1       geometric_factor = lattice_size_in_Mpc**3. / solidang_pix
    61                                         
    62                                         
    63                                             # Setup axes for the slab grid
    64    316.6 MiB      1.4 MiB           1       xaxis = jnp.arange(0, grid_nside, dtype=jnp.int16)
    65    316.6 MiB      0.0 MiB           1       yaxis = jnp.arange(0, grid_nside, dtype=jnp.int16)
    66    316.6 MiB      0.0 MiB           1       zaxis = jnp.arange(0, grid_nside, dtype=jnp.int16)
    67                                         
    68                                             # Setup meshgrid for the slab 
    69                                             # grid_qx, grid_qy, grid_qz = jnp.meshgrid(xaxis, yaxis, zaxis, indexing='ij')        # 6 : 6
    70                                         
    71                                             # del xaxis, yaxis, zaxis
    72                                         
    73    316.6 MiB      0.0 MiB           1       skymap = np.zeros((npix,))    
    74    316.6 MiB      0.0 MiB           1       shift_param = grid_nside
    75    316.6 MiB      0.0 MiB           1       origin_shift = [(0,0,0)]#, (-shift_param,0,0), (0,-shift_param,0), (-shift_param,-shift_param,0),
    76                                                             #(0,0,-shift_param), (-shift_param,0,-shift_param), (0,-shift_param,-shift_param), (-shift_param,-shift_param,-shift_param)]
    77    316.6 MiB      0.0 MiB           1       t2 = time() ; print("Initial setup took", t2-t1, "s ")
    78                                         
    79                                             # Lagrangian comoving distance grid for the slab
    80   5500.6 MiB      0.0 MiB           2       @partial(jax.jit, static_argnames=['trans_vec', 'Dgrid_in_Mpc'])
    81    316.6 MiB      0.0 MiB           1       def lagrange_mesh(x_axis, y_axis, z_axis, trans_vec, Dgrid_in_Mpc):
    82   5500.6 MiB      0.0 MiB           1           return jnp.meshgrid( jnp.float32((x_axis + 0.5 + trans_vec[0]) * Dgrid_in_Mpc), jnp.float32((y_axis + 0.5 + trans_vec[1]) * Dgrid_in_Mpc), jnp.float32((z_axis + 0.5 + trans_vec[2]) * Dgrid_in_Mpc), indexing='ij')
    83                                         
    84  10686.5 MiB      0.0 MiB           2       @partial(jax.jit, static_argnames=['trans_vec', 'Dgrid_in_Mpc'])
    85    316.6 MiB      0.0 MiB           1       def comoving_q(x_i, y_i, z_i, trans_vec, Dgrid_in_Mpc):
    86  10686.5 MiB      0.0 MiB           1           return jnp.sqrt(x_i**2. + y_i**2. + z_i**2.).astype(jnp.float32)
    87                                                 # return jnp.sqrt((x_i + 0.5 + trans_vec[0])**2. + (y_i + 0.5 + trans_vec[1])**2. + (z_i + 0.5 + trans_vec[2])**2.) * Dgrid_in_Mpc
    88                                         
    89  17423.0 MiB      0.0 MiB           2       @partial(jax.jit, static_argnames=['Dgrid_in_Mpc', 'trans'])
    90    316.6 MiB      0.0 MiB           1       def euclid_i(q_i, s_i, growth_i, Dgrid_in_Mpc, trans):
    91  17423.0 MiB      0.0 MiB           1           return (q_i + growth_i * s_i).astype(jnp.float32)
    92                                                 # return q_i * Dgrid_in_Mpc + growth_i * s_i + 0.5 + trans*Dgrid_in_Mpc
    93                                         
    94  18906.5 MiB   1299.1 MiB           2       @jax.jit
    95    316.6 MiB      0.0 MiB           1       def lensing_kernel_F(comov_q_i, redshift_i):
    96  18906.5 MiB      0.0 MiB           1           return (geometric_factor * (3./2.) * cosmo_wsp.params['Omega_m'] * (cosmo_wsp.params['h'] * 100. * cons.kilo / cons.c )**2. * (1 + redshift_i) * (1. - (comov_q_i/comov_lastscatter)) / comov_q_i).astype(jnp.float32)
    97                                         
    98   3772.6 MiB      0.0 MiB           4       def read_displacement(filename):
    99   5500.6 MiB   5184.0 MiB           3           return jnp.asarray(np.fromfile(filename, count=grid_nside * grid_nside * grid_nside, dtype=jnp.float32).reshape((grid_nside, grid_nside, grid_nside)), dtype=jnp.float32)
   100                                         
   101    316.6 MiB      0.0 MiB           1       t3 = time() ; print("Jit compilation took", t3-t2, "s ")
   102                                         
   103                                         
   104    316.6 MiB      0.0 MiB           1       store_displacements=True
   105    316.6 MiB      0.0 MiB           1       if store_displacements:
   106   2044.6 MiB      0.0 MiB           1           grid_sx = read_displacement(sxfile)     # 4 : 10
   107   3772.6 MiB      0.0 MiB           1           grid_sy = read_displacement(syfile)     # 4 : 14
   108   5500.6 MiB      0.0 MiB           1           grid_sz = read_displacement(szfile)     # 4 : 18
   109                                         
   110   5500.6 MiB      0.0 MiB           1       t4 = time() ; print("I/O took", t4-t3, "s ")
   111                                         
   112   7055.2 MiB      0.0 MiB           2       for translation in origin_shift:
   113                                         
   114   5500.6 MiB      0.0 MiB           1           t4 = time()
   115   5500.6 MiB      0.0 MiB           1           print(translation)
   116                                         
   117  10686.4 MiB   5185.8 MiB           1           grid_qx, grid_qy, grid_qz = lagrange_mesh(xaxis, yaxis, zaxis, translation, lattice_size_in_Mpc)
   118                                         
   119  10686.4 MiB      0.0 MiB           1           t5 = time() ; print("Largrangian meshgrid took", t5 - t4, "s ")
   120                                         
   121  12417.5 MiB   1731.1 MiB           1           lagrange_grid = jax.vmap(comoving_q, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qx, grid_qy, grid_qz, translation, lattice_size_in_Mpc)    # 4 : 22
   122                                         
   123  12417.5 MiB      0.0 MiB           1           t6 = time() ; print("Lagrangian comoving distance grid took", t6 - t5, "s ")
   124  14149.2 MiB   1731.6 MiB           1           redshift_grid = jax.vmap(cosmo_wsp.comoving_distance2z)(lagrange_grid)      # 4 : 26
   125                                         
   126  14149.2 MiB      0.0 MiB           1           t7 = time() ; print("Redshift took", t7-t6, "s ")
   127                                                 # Compute healpix pixel grid from Lagrangian x, y, z values
   128  17607.3 MiB   3458.2 MiB           1           ipix_grid = jhp.vec2pix(nside, grid_qz, grid_qy, grid_qx)    # 8 : 38
   129                                         
   130  17607.3 MiB      0.0 MiB           1           t8 = time() ; print("HPX pixel grid (Lagrangian) took", t8-t7, "s ")
   131                                         
   132  20634.8 MiB   1728.4 MiB           1           kernel_sphere = jnp.where((lagrange_grid >= chimin) & (lagrange_grid <= chimax), jax.vmap(lensing_kernel_F)(lagrange_grid, redshift_grid), 0.)      # 4 : 42
   133                                         
   134  20634.8 MiB      0.0 MiB           1           t9 = time() ; print("Kernel grid (Lagrangian) took", t9-t8, "s ")
   135                                         
   136  20879.0 MiB    244.2 MiB           1           skymap += np.asarray(jnp.histogram(ipix_grid, bins=npix, range=(-0.5,npix-0.5), weights=-kernel_sphere, density=False)[0])      
   137  15695.0 MiB  -5184.0 MiB           1           del kernel_sphere, ipix_grid         # -20 : 22
   138                                         
   139  15695.0 MiB      0.0 MiB           1           t10 = time() ; print("Project to healpix (Lagrangian) took", t10-t9, "s ")
   140                                         
   141  17423.0 MiB   1728.0 MiB           1           growth_grid = jax.vmap(cosmo_wsp.growth_factor_D)(redshift_grid)        #   4 : 30
   142                                         
   143  17423.0 MiB      0.0 MiB           1           t11 = time() ; print("Growth took", t11-t10, "s ")
   144                                         
   145  17423.0 MiB      0.0 MiB           1           if not store_displacements: grid_sx = read_displacement(sxfile)
   146  19151.2 MiB   1728.2 MiB           1           grid_Xx = jax.vmap(euclid_i, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qx, grid_sx, growth_grid, lattice_size_in_Mpc, translation[0])     # 4 : 34
   147  17423.2 MiB  -1728.0 MiB           1           del grid_qx
   148  17423.2 MiB      0.0 MiB           1           if not store_displacements: del grid_sx
   149                                         
   150  17423.2 MiB      0.0 MiB           1           if not store_displacements: grid_sy = read_displacement(syfile)
   151  19151.2 MiB   1728.0 MiB           1           grid_Xy = jax.vmap(euclid_i, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qy, grid_sy, growth_grid, lattice_size_in_Mpc, translation[1])     # 4 : 38
   152  17423.2 MiB  -1728.0 MiB           1           del grid_qy
   153  17423.2 MiB      0.0 MiB           1           if not store_displacements: del grid_sy
   154                                         
   155  17423.2 MiB      0.0 MiB           1           if not store_displacements: grid_sz = read_displacement(szfile)
   156  19151.2 MiB   1728.0 MiB           1           grid_Xz = jax.vmap(euclid_i, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qz, grid_sz, growth_grid, lattice_size_in_Mpc, translation[2])     # 4 : 42
   157  15695.2 MiB  -3456.0 MiB           1           del grid_qz, growth_grid
   158  15695.2 MiB      0.0 MiB           1           if not store_displacements: del grid_sz
   159                                         
   160  15695.2 MiB      0.0 MiB           1           t12 = time() ; print("Displacements took", t12-t11, "s ")
   161                                         
   162                                                 # Compute healpix pixel grid from Euclidean x, y, z values
   163  19151.2 MiB   3456.0 MiB           1           ipix_grid = jhp.vec2pix(nside, grid_Xz, grid_Xy, grid_Xx)   # 8 : 50 
   164  13967.2 MiB  -5184.0 MiB           1           del grid_Xx, grid_Xy, grid_Xz               # -12 : 38
   165                                         
   166  13967.2 MiB      0.0 MiB           1           t13 = time() ; print("HPX pixel grid (Eulerian) took", t13-t12, "s ")
   167                                         
   168  15695.2 MiB   1728.0 MiB           1           kernel_sphere = jnp.where((lagrange_grid >= chimin) & (lagrange_grid <= chimax), jax.vmap(lensing_kernel_F)(lagrange_grid, redshift_grid), 0.)      # 4 : 42
   169  12239.2 MiB  -3456.0 MiB           1           del lagrange_grid, redshift_grid
   170                                         
   171  12239.2 MiB      0.0 MiB           1           t14 = time() ; print("Kernel grid (Eulerian) took", t14-t13, "s ")
   172                                         
   173  12239.2 MiB      0.0 MiB           1           skymap += np.asarray(jnp.histogram(ipix_grid, bins=npix, range=(-0.5,npix-0.5), weights=kernel_sphere, density=False)[0])
   174   7055.2 MiB  -5184.0 MiB           1           del ipix_grid, kernel_sphere           # -12 : 30
   175                                         
   176   7055.2 MiB      0.0 MiB           1           t15 = time() ; print("Project to healpix (Eulerian) took", t15-t14, "s ")
   177                                         
   178   7055.2 MiB      0.0 MiB           1       if store_displacements: del grid_sx, grid_sy, grid_sz
   179                                         
   180   7055.2 MiB      0.0 MiB           1       return skymap
marcelo-alvarez commented 1 year ago

The line-by-line memory profile for jax_cone.py matches the expected allocation to within about 2 GB. The difference comes from overhead in some of the JIT-compiled functions.

marcelo-alvarez commented 1 year ago

Runtime memory usage by JAX code: jax_cone

Runtime memory usage by Numpy-Healpy code: py_cone

The additional memory usage comes only from memory preallocation by the JAX XLA compiler. Only two JAX-based functions appear to cause this overhead: jax-healpix vec2pix and jax.numpy histogram. JAX preallocates 90% of memory at the start of the run and does not deallocate it, instead reusing it as needed. This adds a significant overhead, roughly doubling the expected memory usage. We may consider XLA environment variables to change this behaviour.
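
As a starting point for that experiment, the sketch below sets the allocator-related environment variables described in the JAX GPU memory allocation documentation. The variables must be set before `jax` is imported, so the import is left commented out here. The helper name `configure_jax_memory` is hypothetical, and whether these settings remove the overhead seen in the plots above has not been tested.

```python
import os


def configure_jax_memory(preallocate=False, mem_fraction=None, allocator=None):
    """Set XLA client memory options; call this before importing jax."""
    # Disable (or enable) the up-front preallocation of device memory.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        # Fraction of device memory to preallocate when preallocation is enabled.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
    if allocator is not None:
        # "platform" allocates and frees on demand; slower, but keeps usage minimal.
        os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = allocator


configure_jax_memory(preallocate=False)
# import jax  # import only after the variables above are set
```

Turning preallocation off makes the reported usage track live arrays more closely, at the cost of possible fragmentation and slower allocation.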