jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

bfloat16/float32 memory requirements seem off #3302

Open btanner opened 4 years ago

btanner commented 4 years ago

I was getting out-of-memory errors in my project at a point where some pmapped(vmapped(vmapped(code))) was converting a large array from bfloat16 -> float32. The code below reproduces a different specific error message, but the weirdness here feels like it might be related:
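
For context, the shape of the failing code is roughly this (a simplified sketch with made-up names, not the actual project code):

import jax
import jax.numpy as jnp

def inner(x):
    # Upcast a large bfloat16 chunk to float32 before using it.
    return jnp.asarray(x, jnp.float32) * 2.0

# The conversion runs under nested transforms like this:
f = jax.pmap(jax.vmap(jax.vmap(inner)))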

Running in Colab on a fresh TPUv2 runtime.

Setup:

import jax.numpy as jnp
import numpy as np

Ok Case

# Ok
np_points = np.ones((2**31,), np.float64)
jax_32_points = jnp.asarray(np_points, jnp.float32)
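
For scale, the byte arithmetic (just element count times itemsize, nothing JAX-specific):

n = 2**31
print(n * 8 / 2**30)  # float64 host array: 16.0 GiB
print(n * 4 / 2**30)  # float32 copy:        8.0 GiB
print(n * 2 / 2**30)  # bfloat16 copy:       4.0 GiB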

Allocation Error when using np.float32

Restart runtime

# Not ok, can't allocate the jax array. This seems weird, since it works with float64.
np_points = np.ones((2**31,), np.float32)
jax_32_points = jnp.asarray(np_points, jnp.float32)

google3/third_party/py/jax/numpy/lax_numpy.py in asarray(a, dtype, order)
   2143 def asarray(a, dtype=None, order=None):
   2144   lax._check_user_dtype_supported(dtype, "asarray")
-> 2145   return array(a, dtype=dtype, copy=False, order=order)
   2146 
   2147 

google3/third_party/py/jax/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin)
   2113       out = lax.convert_element_type(object, dtype)
   2114     else:
-> 2115       out = device_put(object)
   2116   elif isscalar(object):
   2117     out = lax.reshape(object, ())

google3/third_party/py/jax/api.py in device_put(x, device)
   1601     A copy of ``x`` that resides on ``device``.
   1602   """
-> 1603   return tree_map(lambda y: xla.device_put_p.bind(y, device=device), x)
   1604 
   1605 

google3/third_party/py/jax/tree_util.py in tree_map(f, tree)
    159   """
    160   leaves, treedef = pytree.flatten(tree)
--> 161   return treedef.unflatten(map(f, leaves))
    162 
    163 def tree_multimap(f, tree, *rest):

google3/third_party/py/jax/api.py in <lambda>(y)
   1601     A copy of ``x`` that resides on ``device``.
   1602   """
-> 1603   return tree_map(lambda y: xla.device_put_p.bind(y, device=device), x)
   1604 
   1605 

google3/third_party/py/jax/core.py in bind(self, *args, **kwargs)
    209     top_trace = find_top_trace(args)
    210     if top_trace is None:
--> 211       return self.impl(*args, **kwargs)
    212 
    213     tracers = map(top_trace.full_raise, args)

google3/third_party/py/jax/interpreters/xla.py in _device_put_impl(x, device)
   1075         f"Argument '{x}' of type {type(x)} is not a valid JAX type") from err
   1076   handler = aval_to_result_handler(device, a)  # type: ignore[arg-type]
-> 1077   return handler(device_put(x, device))
   1078 
   1079 device_put_p = core.Primitive('device_put')

google3/third_party/py/jax/interpreters/xla.py in device_put(x, device)
    112   x = canonicalize_dtype(x)
    113   try:
--> 114     return device_put_handlers[type(x)](x, device)
    115   except KeyError as err:
    116     raise TypeError(f"No device_put handler for type: {type(x)}") from err

google3/third_party/py/jax/interpreters/xla.py in _device_put_array(x, device)
    118 def _device_put_array(x, device: Optional[Device]):
    119   backend = xb.get_device_backend(device)
--> 120   return backend.buffer_from_pyval(x, device)
    121 
    122 def _device_put_scalar(x, device):

RuntimeError: Resource exhausted: Failed to allocate request for 8.00GiB (8589934592B) on device ordinal 0
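
For what it's worth, the failed request is exactly the size of the full float32 array:

assert 2**31 * 4 == 8589934592  # 8.00GiB: every float32 element, all at once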

Crash/Dump when using bfloat16

Restart runtime

# Not ok, crashes with same specs whether np.float64 or np.float32.
np_points = np.ones((2**31,), np.float32)
jax_b16_points = jnp.asarray(np_points, jnp.bfloat16)

----> 4 jax_b16_points = jnp.asarray(np_points, jnp.bfloat16)

google3/third_party/py/jax/numpy/lax_numpy.py in asarray(a, dtype, order)
   2143 def asarray(a, dtype=None, order=None):
   2144   lax._check_user_dtype_supported(dtype, "asarray")
-> 2145   return array(a, dtype=dtype, copy=False, order=order)
   2146 
   2147 

google3/third_party/py/jax/numpy/lax_numpy.py in array(object, dtype, copy, order, ndmin)
   2111   if isinstance(object, ndarray):
   2112     if dtype and _dtype(object) != dtypes.canonicalize_dtype(dtype):
-> 2113       out = lax.convert_element_type(object, dtype)
   2114     else:
   2115       out = device_put(object)

google3/third_party/py/jax/lax/lax.py in convert_element_type(operand, new_dtype)
    383     warnings.warn(msg, onp.ComplexWarning, stacklevel=2)
    384   return convert_element_type_p.bind(
--> 385       operand, new_dtype=new_dtype, old_dtype=old_dtype)
    386 
    387 def bitcast_convert_type(operand: Array, new_dtype: DType) -> Array:

google3/third_party/py/jax/core.py in bind(self, *args, **kwargs)
    209     top_trace = find_top_trace(args)
    210     if top_trace is None:
--> 211       return self.impl(*args, **kwargs)
    212 
    213     tracers = map(top_trace.full_raise, args)

google3/third_party/py/jax/interpreters/xla.py in apply_primitive(prim, *args, **params)
    215 def apply_primitive(prim, *args, **params):
    216   """Impl rule that compiles and runs a single primitive 'prim' using XLA."""
--> 217   compiled_fun = xla_primitive_callable(prim, *map(arg_spec, args), **params)
    218   return compiled_fun(*args)
    219 

google3/third_party/py/jax/interpreters/xla.py in xla_primitive_callable(prim, *arg_specs, **params)
    252       device_assignment=device and (device.id,))
    253   options.parameter_is_tupled_arguments = tuple_args
--> 254   compiled = backend.compile(built_c, compile_options=options)
    255   if nreps == 1:
    256     return partial(_execute_compiled_primitive, prim, compiled, handle_result)

RuntimeError: Resource exhausted: Ran out of memory in memory space hbm. Used 8.00G of 7.48G hbm. Exceeded hbm capacity by 530.01M.

Total hbm usage >= 8.52G:
    reserved        530.00M 
    program            8.0K 
    arguments         8.00G (100.0% utilization)

Output size 4.00G (100.0% utilization); shares 0B with arguments.

Program hbm requirement 8.0K:
    reserved           4.0K
    global             4.0K

  Largest program allocations in hbm:

  1. Size: 4.0K
     XLA label: profiler
     Allocation type: reserved
     ==========================

  2. Size: 4.0K
     Shape: u32[8,128]{1,0}
     Unpadded size: 4.0K
     XLA label: constant literal
     Allocation type: global
     ==========================
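
If I'm reading the dump right, the float32 input has to sit in hbm as an argument while the bfloat16 output is materialized, and the argument plus the reserved space alone already exceeds capacity:

hbm      = 7.48  # GiB of hbm reported available
args     = 8.00  # GiB: the float32 input, held as an argument
reserved = 0.52  # GiB: the ~530M reserved allocation
output   = 4.00  # GiB: the bfloat16 result, sharing 0B with arguments
print(args + reserved)           # 8.52 -> already over the 7.48 budget
print(args + reserved + output)  # 12.52 if the output is counted too
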
mattjj commented 4 years ago

Thanks for the question!

Would you be able to expand those 5 frames from the stack trace and provide the full trace? Seeing deeper stacks can be helpful in figuring out what's going on!

btanner commented 4 years ago

Sure thing! Updated.

BlueskyFR commented 2 years ago

@mattjj any thoughts on this?