Open pwithams opened 2 months ago
I checked the HLO when using `dls=jnp.ones(shape=(10000, 3))`, and it does indeed look like some very large tensors are being generated by your program (1 x 10000 x 71 x 75 x 71 x 3 ~= 40GB):
```
ENTRY main.152 {
constant.27 = f32[] constant(1)
broadcast.28 = f32[1,71]{1,0} broadcast(constant.27), dimensions={}
iota.29 = s32[71]{0} iota(), iota_dimension=0
...
constant.15 = f32[3,3]{1,0} constant({ { 1, 0, 0 }, { 0, 1, 0 }, { 0, 0, -1 } })
dot.95 = f32[1,71,75,71,3]{4,3,2,1,0} dot(scatter.89, constant.15), lhs_contracting_dims={4}, rhs_contracting_dims={0}
reshape.96 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} reshape(dot.95)
broadcast.97 = f32[1,1,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.96), dimensions={0,1,2,3,4,5}
reshape.98 = f32[1,71,75,71,3]{4,3,2,1,0} reshape(broadcast.97)
broadcast.99 = f32[1,10000,71,75,71,3]{5,4,3,2,1,0} broadcast(reshape.98), dimensions={0,2,3,4,5}
subtract.100 = f32[1,10000,71,75,71,3]{5,4,3,2,1,0} subtract(broadcast.99, broadcast.17)
multiply.132 = f32[1,10000,71,75,71,3]{5,4,3,2,1,0} multiply(subtract.100, subtract.100)
divide.133 = f32[1,10000,71,75,71,3]{5,4,3,2,1,0} divide(multiply.132, broadcast.4)
reduce.138 = f32[1,10000,71,75,71]{4,3,2,1,0} reduce(divide.133, constant.25), dimensions={5}, to_apply=region_3.134
multiply.139 = f32[1,10000,71,75,71]{4,3,2,1,0} multiply(reduce.138, broadcast.2)
exponential.140 = f32[1,10000,71,75,71]{4,3,2,1,0} exponential(multiply.139)
add.141 = f32[1,10000,71,75,71]{4,3,2,1,0} add(exponential.131, exponential.140)
multiply.146 = f32[1,10000,71,75,71]{4,3,2,1,0} multiply(broadcast.145, add.141)
ROOT reduce.151 = f32[10000,71]{1,0} reduce(multiply.146, constant.25), dimensions={0,2,3}, to_apply=region_4.147
}
```
After commenting out the two lines containing `jnp.exp`, these large tensors are not materialized:
```
...
constant.12 = f32[] constant(1)
reduce.27 = f32[1,71]{1,0} reduce(broadcast.6, constant.12), dimensions={2}, to_apply=region_0.23
constant.1 = f32[] constant(15.7496099)
broadcast.2 = f32[1,71]{1,0} broadcast(constant.1), dimensions={}
multiply.28 = f32[1,71]{1,0} multiply(reduce.27, broadcast.2)
reshape.29 = f32[1,1,71]{2,1,0} reshape(multiply.28)
broadcast.34 = f32[1,1,71]{2,1,0} broadcast(reshape.29), dimensions={0,1,2}
reshape.35 = f32[1,71]{1,0} reshape(broadcast.34)
broadcast.36 = f32[1,71,71]{2,1,0} broadcast(reshape.35), dimensions={0,2}
divide.37 = f32[1,71,71]{2,1,0} divide(broadcast.33, broadcast.36)
broadcast.38 = f32[1,10000,71,75,71]{4,3,2,1,0} broadcast(divide.37), dimensions={0,2,4}
constant.11 = f32[] constant(0)
ROOT reduce.43 = f32[10000,71]{1,0} reduce(broadcast.38, constant.11), dimensions={0,2,3}, to_apply=region_1.39
}
```
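For scale, a quick back-of-the-envelope on the f32[1,10000,71,75,71,3] intermediate from the first dump (whether a buffer of that size actually gets materialized depends on how XLA ends up fusing the surrounding ops, so this is only a rough sketch of the worst case):

```python
from math import prod

# shape of the large f32 intermediate in the first HLO dump
shape = (1, 10_000, 71, 75, 71, 3)
print(prod(shape) * 4 / 1e9)  # 4 bytes per f32 element -> ~45 GB
```

Even after the reduce over the trailing axis of size 3, the remaining f32[1,10000,71,75,71] tensors are still roughly 15GB each.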
I'm not sure why this code runs on JAX <0.4.14... it's possible some optimizations are being done differently. You can inspect the compiled code yourself using `run.lower().compiler_ir(dialect='hlo').as_hlo_text()` (for >=0.4.30) or `jax.xla_computation(run)().as_hlo_text()` (for <0.4.30).
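For example, a minimal sketch of putting those together (the toy `f` below is just a stand-in for your jitted `run` function):

```python
import jax
import jax.numpy as jnp

@jax.jit
def f(x):
    # stand-in for the `run` function in the script
    return jnp.exp(-(x ** 2)).sum(axis=0)

lowered = f.lower(jnp.ones((8, 3)))

# StableHLO text (what lowered.as_text() prints on recent versions)
print(lowered.as_text())

# pre-optimization HLO text, comparable to what
# jax.xla_computation(...)().as_hlo_text() produced on older versions
print(lowered.compiler_ir(dialect='hlo').as_hlo_text())
```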
Thanks for the response. I'm starting to think it is some change in OpenXLA or lower in the stack that is responsible rather than JAX itself. A few questions:

- Is there a difference between `func.lower().as_text()` and `run.lower().compiler_ir(dialect='hlo').as_hlo_text()`?
- Does this seem like a bug, or just an old edge case that no longer works, do you think?

When using `dls=jnp.ones(shape=(1590, 3))` the program ran successfully and pprof reported ~500kB of memory usage, but increasing to `dls=jnp.ones(shape=(1600, 3))` fails trying to allocate ~5GB, which seems like strange behavior.
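A sketch of one way to narrow this down further, by comparing the post-optimization HLO at the two shapes (this assumes the jitted entry point takes `dls` as an argument; `run_fn` is a placeholder, not the actual function from the script):

```python
import jax
import jax.numpy as jnp

def dump_optimized_hlo(fn, dls, path):
    """Write the HLO after XLA's optimization passes for fn(dls) to a file."""
    compiled = jax.jit(fn).lower(dls).compile()
    with open(path, "w") as f:
        f.write(compiled.as_text())

# run_fn is a placeholder for the function defined in the script below
# dump_optimized_hlo(run_fn, jnp.ones(shape=(1590, 3)), "hlo_1590.txt")
# dump_optimized_hlo(run_fn, jnp.ones(shape=(1600, 3)), "hlo_1600.txt")
```

Diffing the two dumps should show whether the large broadcast intermediates only appear at the bigger shape.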
Description
Overview
The script below works when using an NVIDIA GPU with JAX version 0.4.14, but after upgrading to 0.4.31 (and trying a few other versions in between) it is triggering the following error:
```
E0910 20:24:00.097739 38257 pjrt_stream_executor_client.cc:3067] Execution of replica 0 failed: RESOURCE_EXHAUSTED: Out of memory while trying to allocate X bytes
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate X bytes.
```
where the value of `X` ranges from ~5GB (e.g. 4843897104) to 20GB+ depending on the shape of the `dls` variable (set to 3540 in the script below).

- jax<=0.4.14 - no error
- jax>0.4.14 - error

Not sure if this is a bug or if there is some code/syntax in the example below that is no longer supported in versions > 0.4.14 and is responsible for this behavior.
Allocation vs. pprof usage
The GPU has 6GB of memory, and after some trial and error it appears that setting the `dls` variable to a shape of 1590 succeeds and uses only ~500kB of memory according to pprof (following https://jax.readthedocs.io/en/latest/device_memory_profiling.html), but a shape of 1600 gives an error trying to allocate ~5GB. If pprof is in fact showing GPU memory usage, this could suggest memory is being allocated but not used.
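For reference, a sketch of how the pprof numbers are collected (following the device memory profiling doc linked above; `run_fn` is a placeholder for the function in the script below):

```python
import jax
import jax.numpy as jnp

dls = jnp.ones(shape=(1590, 3))
# result = run_fn(dls)         # placeholder for the script's jitted function
# result.block_until_ready()   # make sure execution finished before snapshotting

# write a pprof-format snapshot of the buffers currently live on the device
jax.profiler.save_device_memory_profile("memory.prof")
# inspect with e.g.: pprof --web memory.prof
```

Note that this snapshot only records buffers that are live at the moment it is taken, not the peak working set during execution, so it can be much smaller than whatever the allocator tried to reserve mid-computation.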
jnp.exp removal
Trial and error also showed that removing the `jnp.exp` calls inside the function `m` seems to resolve the issue. For example, the script below with `dls` shape set to 10000 fails trying to allocate ~30GB, but removing the `jnp.exp` calls succeeds and shows only ~2MB of usage according to pprof.
Script
System info (python version, jaxlib version, accelerator, etc.)
Pip versions:
Output of `jax.print_environment_info()` (it is running inside a container based on nvidia/cuda:12.3.2-base-ubuntu22.04):
Pip versions for the latest version that does not show the error (v0.4.14):