google / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

An uninformative error on TPU #20812

Open sash-a opened 2 months ago

sash-a commented 2 months ago

Description

Hi there, I'm getting a strange and uninformative error when running on TPU that I don't get when running locally. I was hoping someone here could explain it :thinking:

jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: layout minor_to_major field contains 0 elements, but shape is rank 4: {}; shape: element_type: PRED dimensions: 10 dimensions: 64 dimensions: 5 dimensions: 5 layout { minor_to_major: 3 minor_to_major: 2 minor_to_major: 0 minor_to_major: 1 tail_padding_alignment_in_elements: 1 } is_dynamic_dimension: false is_dynamic_dimension: false is_dynamic_dimension: false is_dynamic_dimension: false

Any help with what this error could mean would be appreciated. I've restarted the TPU and tried different JAX versions, but they all give the same error.

System info (python version, jaxlib version, accelerator, etc.)

TPU v4-8, Python 3.10, JAX 0.4.24, 0.4.25, and 0.4.26

superbobry commented 2 months ago

Is there a way for us to reproduce this error?

sash-a commented 2 months ago

Unfortunately I don't have a minimal example; I've only encountered it in a large code base. Any idea what the error could be pointing me to? The traceback just ends where I call into a pmapped function.
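Without a minimal example, one way to narrow this down might be to look for the array the error describes. The message names a rank-4 shape with element type PRED (XLA's name for boolean) and dimensions 10, 64, 5, 5. The sketch below walks a pytree of arguments and reports any leaf with that shape and dtype. The helper `find_matching_leaves` and the `args` dictionary are hypothetical stand-ins, not part of the code base in question. It also assumes the offending array is an input to the pmapped function rather than an intermediate value, and that the shape in the error is the full shape rather than a per-device one.

```python
import jax
import numpy as np


def find_matching_leaves(tree, shape, dtype):
    """Return (path, leaf) pairs whose shape and dtype match the given ones."""
    matches = []
    # tree_flatten_with_path returns ([(key_path, leaf), ...], treedef).
    for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:
        if not hasattr(leaf, "shape"):
            continue  # skip non-array leaves such as Python scalars
        if tuple(leaf.shape) == tuple(shape) and leaf.dtype == np.dtype(dtype):
            matches.append((jax.tree_util.keystr(path), leaf))
    return matches


# Hypothetical stand-in for the real arguments passed to the pmapped function.
args = {
    "obs": np.zeros((10, 64, 5, 5), dtype=np.float32),
    "mask": np.zeros((10, 64, 5, 5), dtype=bool),  # PRED corresponds to bool
    "step": np.zeros((10,), dtype=np.int32),
}

matches = find_matching_leaves(args, shape=(10, 64, 5, 5), dtype=bool)
paths = [path for path, _ in matches]
print(paths)
```

Calling the helper on the real arguments just before the pmapped call would show whether a boolean array of that shape is being passed in and under which key. If nothing matches, the array is more likely produced inside the function.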