Open marcelo-alvarez opened 1 year ago
jax_cone.py
PCM = per cell memory | Line No. | PCM (in bytes) increment | PCM (in bytes) total | Increment (MB) | Total memory (MB) | Code |
---|---|---|---|---|---|---|
73 | - | - | 96 | 96 | skymap = np.zeros((npix,)) |
|
106 | 4 | 4 | 1728 | 1824 | grid_sx = read_displacement(sxfile) |
|
107 | 4 | 8 | 1728 | 3552 | grid_sy = read_displacement(syfile) |
|
108 | 4 | 12 | 1728 | 5280 | grid_sz = read_displacement(szfile) |
|
117 | 12 | 24 | 5184 | 10464 | grid_qx, grid_qy, grid_qz = lagrange_mesh(xaxis, yaxis, zaxis, translation, lattice_size_in_Mpc) |
|
121 | 4 | 28 | 1728 | 12192 | lagrange_grid = jax.vmap(comoving_q, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qx, grid_qy, grid_qz, translation, lattice_size_in_Mpc) |
|
124 | 4 | 32 | 1728 | 13920 | redshift_grid = jax.vmap(cosmo_wsp.comoving_distance2z)(lagrange_grid) |
|
128 | 8 | 40 | 3456 | 17376 | ipix_grid = jhp.vec2pix(nside, grid_qz, grid_qy, grid_qx) |
|
132 | 4 | 44 | 1728 | 19104 | kernel_sphere = jnp.where((lagrange_grid >= chimin) & (lagrange_grid <= chimax), jax.vmap(lensing_kernel_F)(lagrange_grid, redshift_grid), 0.) |
|
137 | -12 | 32 | -5184 | 13920 | del kernel_sphere, ipix_grid |
|
141 | 4 | 36 | 1728 | 15648 | growth_grid = jax.vmap(cosmo_wsp.growth_factor_D)(redshift_grid) |
|
146 | 4 | 40 | 1728 | 17376 | grid_Xx = jax.vmap(euclid_i, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qx, grid_sx, growth_grid, lattice_size_in_Mpc, translation[0]) |
|
147 | -4 | 36 | -1728 | 15648 | del grid_qx |
|
151 | 4 | 40 | 1728 | 17376 | grid_Xy = jax.vmap(euclid_i, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qy, grid_sy, growth_grid, lattice_size_in_Mpc, translation[1]) |
|
152 | -4 | 36 | -1728 | 15648 | del grid_qy |
|
156 | 4 | 40 | 1728 | 17376 | grid_Xz = jax.vmap(euclid_i, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qz, grid_sz, growth_grid, lattice_size_in_Mpc, translation[2]) |
|
157 | -8 | 32 | -3456 | 13920 | del grid_qz, growth_grid |
|
163 | 8 | 40 | 3456 | 17376 | ipix_grid = jhp.vec2pix(nside, grid_Xz, grid_Xy, grid_Xx) |
|
164 | -12 | 28 | -5184 | 12192 | del grid_Xx, grid_Xy, grid_Xz |
|
168 | 4 | 32 | 1728 | 13920 | kernel_sphere = jnp.where((lagrange_grid >= chimin) & (lagrange_grid <= chimax), jax.vmap(lensing_kernel_F)(lagrange_grid, redshift_grid), 0.) |
|
169 | -8 | 24 | -3456 | 10464 | del lagrange_grid, redshift_grid |
|
174 | -12 | 12 | -5184 | 5280 | del ipix_grid, kernel_sphere |
|
178 | -12 | 0 | -5184 | 96 | if store_displacements: del grid_sx, grid_sy, grid_sz |
memory profiler
Filename: jax_cone.py
Line # Mem usage Increment Occurrences Line Contents
=============================================================
41 280.0 MiB 280.0 MiB 1 @profile(stream=memfile)
42 def light_cone():
43 # Lattice spacing (a_latt in Websky parlance) in Mpc
44 280.0 MiB 0.0 MiB 1 lattice_size_in_Mpc = L_box / grid_nside # in Mpc; 7700 Mpc box length for websky 6144 cube
45
46 # comoving distance to last scattering in Mpc
47 280.0 MiB 0.0 MiB 1 comov_lastscatter = comov_lastscatter_Gpc * (cons.giga / cons.mega) # in Mpc
48
49 # minimum and maximum radii of projection
50 315.2 MiB 35.2 MiB 1 chimin = cosmo_wsp.comoving_distance(zmin)
51 315.2 MiB 0.0 MiB 1 chimax = cosmo_wsp.comoving_distance(zmax)
52
53 315.2 MiB 0.0 MiB 1 print("chimin, chimax: ",chimin,chimax)
54 # NSIDE of HEALPix map
55 315.2 MiB 0.0 MiB 1 nside = 1024
56 315.2 MiB 0.0 MiB 1 npix = hp.nside2npix(nside)
57 315.2 MiB 0.0 MiB 1 solidang_pix = 4*np.pi / npix
58
59 # Effectively \Delta chi, comoving distance interval spacing for LoS integral
60 315.2 MiB 0.0 MiB 1 geometric_factor = lattice_size_in_Mpc**3. / solidang_pix
61
62
63 # Setup axes for the slab grid
64 316.6 MiB 1.4 MiB 1 xaxis = jnp.arange(0, grid_nside, dtype=jnp.int16)
65 316.6 MiB 0.0 MiB 1 yaxis = jnp.arange(0, grid_nside, dtype=jnp.int16)
66 316.6 MiB 0.0 MiB 1 zaxis = jnp.arange(0, grid_nside, dtype=jnp.int16)
67
68 # Setup meshgrid for the slab
69 # grid_qx, grid_qy, grid_qz = jnp.meshgrid(xaxis, yaxis, zaxis, indexing='ij') # 6 : 6
70
71 # del xaxis, yaxis, zaxis
72
73 316.6 MiB 0.0 MiB 1 skymap = np.zeros((npix,))
74 316.6 MiB 0.0 MiB 1 shift_param = grid_nside
75 316.6 MiB 0.0 MiB 1 origin_shift = [(0,0,0)]#, (-shift_param,0,0), (0,-shift_param,0), (-shift_param,-shift_param,0),
76 #(0,0,-shift_param), (-shift_param,0,-shift_param), (0,-shift_param,-shift_param), (-shift_param,-shift_param,-shift_param)]
77 316.6 MiB 0.0 MiB 1 t2 = time() ; print("Initial setup took", t2-t1, "s ")
78
79 # Lagrangian comoving distance grid for the slab
80 5500.6 MiB 0.0 MiB 2 @partial(jax.jit, static_argnames=['trans_vec', 'Dgrid_in_Mpc'])
81 316.6 MiB 0.0 MiB 1 def lagrange_mesh(x_axis, y_axis, z_axis, trans_vec, Dgrid_in_Mpc):
82 5500.6 MiB 0.0 MiB 1 return jnp.meshgrid( jnp.float32((x_axis + 0.5 + trans_vec[0]) * Dgrid_in_Mpc), jnp.float32((y_axis + 0.5 + trans_vec[1]) * Dgrid_in_Mpc), jnp.float32((z_axis + 0.5 + trans_vec[2]) * Dgrid_in_Mpc), indexing='ij')
83
84 10686.5 MiB 0.0 MiB 2 @partial(jax.jit, static_argnames=['trans_vec', 'Dgrid_in_Mpc'])
85 316.6 MiB 0.0 MiB 1 def comoving_q(x_i, y_i, z_i, trans_vec, Dgrid_in_Mpc):
86 10686.5 MiB 0.0 MiB 1 return jnp.sqrt(x_i**2. + y_i**2. + z_i**2.).astype(jnp.float32)
87 # return jnp.sqrt((x_i + 0.5 + trans_vec[0])**2. + (y_i + 0.5 + trans_vec[1])**2. + (z_i + 0.5 + trans_vec[2])**2.) * Dgrid_in_Mpc
88
89 17423.0 MiB 0.0 MiB 2 @partial(jax.jit, static_argnames=['Dgrid_in_Mpc', 'trans'])
90 316.6 MiB 0.0 MiB 1 def euclid_i(q_i, s_i, growth_i, Dgrid_in_Mpc, trans):
91 17423.0 MiB 0.0 MiB 1 return (q_i + growth_i * s_i).astype(jnp.float32)
92 # return q_i * Dgrid_in_Mpc + growth_i * s_i + 0.5 + trans*Dgrid_in_Mpc
93
94 18906.5 MiB 1299.1 MiB 2 @jax.jit
95 316.6 MiB 0.0 MiB 1 def lensing_kernel_F(comov_q_i, redshift_i):
96 18906.5 MiB 0.0 MiB 1 return (geometric_factor * (3./2.) * cosmo_wsp.params['Omega_m'] * (cosmo_wsp.params['h'] * 100. * cons.kilo / cons.c )**2. * (1 + redshift_i) * (1. - (comov_q_i/comov_lastscatter)) / comov_q_i).astype(jnp.float32)
97
98 3772.6 MiB 0.0 MiB 4 def read_displacement(filename):
99 5500.6 MiB 5184.0 MiB 3 return jnp.asarray(np.fromfile(filename, count=grid_nside * grid_nside * grid_nside, dtype=jnp.float32).reshape((grid_nside, grid_nside, grid_nside)), dtype=jnp.float32)
100
101 316.6 MiB 0.0 MiB 1 t3 = time() ; print("Jit compilation took", t3-t2, "s ")
102
103
104 316.6 MiB 0.0 MiB 1 store_displacements=True
105 316.6 MiB 0.0 MiB 1 if store_displacements:
106 2044.6 MiB 0.0 MiB 1 grid_sx = read_displacement(sxfile) # 4 : 10
107 3772.6 MiB 0.0 MiB 1 grid_sy = read_displacement(syfile) # 4 : 14
108 5500.6 MiB 0.0 MiB 1 grid_sz = read_displacement(szfile) # 4 : 18
109
110 5500.6 MiB 0.0 MiB 1 t4 = time() ; print("I/O took", t4-t3, "s ")
111
112 7055.2 MiB 0.0 MiB 2 for translation in origin_shift:
113
114 5500.6 MiB 0.0 MiB 1 t4 = time()
115 5500.6 MiB 0.0 MiB 1 print(translation)
116
117 10686.4 MiB 5185.8 MiB 1 grid_qx, grid_qy, grid_qz = lagrange_mesh(xaxis, yaxis, zaxis, translation, lattice_size_in_Mpc)
118
119 10686.4 MiB 0.0 MiB 1 t5 = time() ; print("Largrangian meshgrid took", t5 - t4, "s ")
120
121 12417.5 MiB 1731.1 MiB 1 lagrange_grid = jax.vmap(comoving_q, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qx, grid_qy, grid_qz, translation, lattice_size_in_Mpc) # 4 : 22
122
123 12417.5 MiB 0.0 MiB 1 t6 = time() ; print("Lagrangian comoving distance grid took", t6 - t5, "s ")
124 14149.2 MiB 1731.6 MiB 1 redshift_grid = jax.vmap(cosmo_wsp.comoving_distance2z)(lagrange_grid) # 4 : 26
125
126 14149.2 MiB 0.0 MiB 1 t7 = time() ; print("Redshift took", t7-t6, "s ")
127 # Compute healpix pixel grid from Lagrangian x, y, z values
128 17607.3 MiB 3458.2 MiB 1 ipix_grid = jhp.vec2pix(nside, grid_qz, grid_qy, grid_qx) # 8 : 38
129
130 17607.3 MiB 0.0 MiB 1 t8 = time() ; print("HPX pixel grid (Lagrangian) took", t8-t7, "s ")
131
132 20634.8 MiB 1728.4 MiB 1 kernel_sphere = jnp.where((lagrange_grid >= chimin) & (lagrange_grid <= chimax), jax.vmap(lensing_kernel_F)(lagrange_grid, redshift_grid), 0.) # 4 : 42
133
134 20634.8 MiB 0.0 MiB 1 t9 = time() ; print("Kernel grid (Lagrangian) took", t9-t8, "s ")
135
136 20879.0 MiB 244.2 MiB 1 skymap += np.asarray(jnp.histogram(ipix_grid, bins=npix, range=(-0.5,npix-0.5), weights=-kernel_sphere, density=False)[0])
137 15695.0 MiB -5184.0 MiB 1 del kernel_sphere, ipix_grid # -20 : 22
138
139 15695.0 MiB 0.0 MiB 1 t10 = time() ; print("Project to healpix (Lagrangian) took", t10-t9, "s ")
140
141 17423.0 MiB 1728.0 MiB 1 growth_grid = jax.vmap(cosmo_wsp.growth_factor_D)(redshift_grid) # 4 : 30
142
143 17423.0 MiB 0.0 MiB 1 t11 = time() ; print("Growth took", t11-t10, "s ")
144
145 17423.0 MiB 0.0 MiB 1 if not store_displacements: grid_sx = read_displacement(sxfile)
146 19151.2 MiB 1728.2 MiB 1 grid_Xx = jax.vmap(euclid_i, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qx, grid_sx, growth_grid, lattice_size_in_Mpc, translation[0]) # 4 : 34
147 17423.2 MiB -1728.0 MiB 1 del grid_qx
148 17423.2 MiB 0.0 MiB 1 if not store_displacements: del grid_sx
149
150 17423.2 MiB 0.0 MiB 1 if not store_displacements: grid_sy = read_displacement(syfile)
151 19151.2 MiB 1728.0 MiB 1 grid_Xy = jax.vmap(euclid_i, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qy, grid_sy, growth_grid, lattice_size_in_Mpc, translation[1]) # 4 : 38
152 17423.2 MiB -1728.0 MiB 1 del grid_qy
153 17423.2 MiB 0.0 MiB 1 if not store_displacements: del grid_sy
154
155 17423.2 MiB 0.0 MiB 1 if not store_displacements: grid_sz = read_displacement(szfile)
156 19151.2 MiB 1728.0 MiB 1 grid_Xz = jax.vmap(euclid_i, in_axes=(0, 0, 0, None, None), out_axes=0)(grid_qz, grid_sz, growth_grid, lattice_size_in_Mpc, translation[2]) # 4 : 42
157 15695.2 MiB -3456.0 MiB 1 del grid_qz, growth_grid
158 15695.2 MiB 0.0 MiB 1 if not store_displacements: del grid_sz
159
160 15695.2 MiB 0.0 MiB 1 t12 = time() ; print("Displacements took", t12-t11, "s ")
161
162 # Compute healpix pixel grid from Euclidean x, y, z values
163 19151.2 MiB 3456.0 MiB 1 ipix_grid = jhp.vec2pix(nside, grid_Xz, grid_Xy, grid_Xx) # 8 : 50
164 13967.2 MiB -5184.0 MiB 1 del grid_Xx, grid_Xy, grid_Xz # -12 : 38
165
166 13967.2 MiB 0.0 MiB 1 t13 = time() ; print("HPX pixel grid (Eulerian) took", t13-t12, "s ")
167
168 15695.2 MiB 1728.0 MiB 1 kernel_sphere = jnp.where((lagrange_grid >= chimin) & (lagrange_grid <= chimax), jax.vmap(lensing_kernel_F)(lagrange_grid, redshift_grid), 0.) # 4 : 42
169 12239.2 MiB -3456.0 MiB 1 del lagrange_grid, redshift_grid
170
171 12239.2 MiB 0.0 MiB 1 t14 = time() ; print("Kernel grid (Eulerian) took", t14-t13, "s ")
172
173 12239.2 MiB 0.0 MiB 1 skymap += np.asarray(jnp.histogram(ipix_grid, bins=npix, range=(-0.5,npix-0.5), weights=kernel_sphere, density=False)[0])
174 7055.2 MiB -5184.0 MiB 1 del ipix_grid, kernel_sphere # -12 : 30
175
176 7055.2 MiB 0.0 MiB 1 t15 = time() ; print("Project to healpix (Eulerian) took", t15-t14, "s ")
177
178 7055.2 MiB 0.0 MiB 1 if store_displacements: del grid_sx, grid_sy, grid_sz
179
180 7055.2 MiB 0.0 MiB 1 return skymap
The line-by-line memory profile of jax_cone.py
matches the expected allocated memory to within 2 GB; the difference comes from the overhead of some (JIT-compiled) functions.
Runtime memory usage by JAX code:
Runtime memory usage by Numpy-Healpy code:
The added memory usage comes only from the JAX XLA compiler's preallocation of memory. Only two JAX-based functions seem to cause this overhead: jax-healpix vec2pix and jax-numpy histogram. JAX preallocates 90% of memory at the start of the code and does not deallocate it, instead reusing it as needed. This adds a significant memory overhead, which is about twice the expected memory usage. We may consider XLA environment variables to change this behaviour.
We need to profile the line-by-line memory usage of the light_cone codes (both MPI and JAX) and ensure that the memory usage matches the calculation. When running the MPI implementation for the full 6144 cube on Perlmutter, there was a definite memory leak. We need to check for memory spikes during function calls and ensure the memory management is reasonable for future additions like 2LPT and multiple kernels.
(transferred from lptmap created by @1cosmologist: marcelo-alvarez/lptmap#5)