openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Better error message for "Involuntary full rematerialization" #7148

Open hawkinsp opened 10 months ago

hawkinsp commented 10 months ago

@GallagherCommaJack wrote in https://github.com/google/jax/issues/18591:

I often get error messages that look like this:

[spmd] Involuntary full rematerialization. The compiler was not able to go from sharding {devices=[256,1]<=[256]} to {devices=[1,64,4]<=[256] last_tile_dim_replicate} without doing a full rematerialization of the tensor. You probably want to enrich the sharding annotations to prevent this from happening.

It's really hard to debug this, though, because the error tells me nothing about where in my code it is triggered.

kvablack commented 9 months ago

+1 to this issue. What does the devices=[1,64,4]<=[256] last_tile_dim_replicate syntax even mean?
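As I understand XLA's HloSharding text format, `devices=[1,64,4]<=[256]` is a tile-assignment array of shape [1,64,4] filled with device ids 0..255 in row-major order (the `<=[256]` "iota" shorthand). `last_tile_dim_replicate` marks the last dimension (size 4) as a replication factor rather than a tensor dimension, so the rank-2 tensor is cut into 1x64 tiles along dimension 1 and each tile is held by 4 devices. By the same reading, `{devices=[256,1]<=[256]}` splits dimension 0 across 256 devices. The sketch below decodes this plain form under those assumptions; `decode_iota_sharding` is a hypothetical helper, not an XLA or JAX API, and it does not handle transposed iota lists or other sharding variants.

```python
import math
import re
from itertools import product

# Hypothetical helper: handles only the "{devices=[...]<=[N]}" form seen in this
# thread, optionally followed by last_tile_dim_replicate. Transposed iota lists
# and other sharding variants are not supported.
_PATTERN = re.compile(
    r"\{devices=\[([\d,]+)\]<=\[(\d+)\]( last_tile_dim_replicate)?\}"
)


def decode_iota_sharding(spec):
    """Return (tile_shape, {device id: data tile index}) for a simple iota sharding."""
    m = _PATTERN.fullmatch(spec.strip())
    if m is None:
        raise ValueError(f"unsupported sharding string: {spec!r}")
    dims = [int(d) for d in m.group(1).split(",")]
    num_devices = int(m.group(2))
    replicate_last = m.group(3) is not None
    if math.prod(dims) != num_devices:
        raise ValueError("tile assignment shape does not match device count")
    layout = {}
    # Assumption: "<=[N]" lays out device ids 0..N-1 in row-major order over
    # `dims`; itertools.product also walks indices in row-major order.
    for device, index in enumerate(product(*(range(d) for d in dims))):
        # With last_tile_dim_replicate, the last index only picks a replica,
        # so the data tile is identified by the leading indices.
        layout[device] = index[:-1] if replicate_last else index
    tile_shape = dims[:-1] if replicate_last else dims
    return tile_shape, layout


src_shape, src = decode_iota_sharding("{devices=[256,1]<=[256]}")
dst_shape, dst = decode_iota_sharding(
    "{devices=[1,64,4]<=[256] last_tile_dim_replicate}"
)
print(src_shape, src[5])  # [256, 1] (5, 0)
print(dst_shape, dst[5])  # [1, 64] (0, 1)
# Devices 0-3 are the four replicas of tile (0, 0) in the target sharding.
print([d for d, t in dst.items() if t == (0, 0)])  # [0, 1, 2, 3]
```

Read this way, the source sharding gives each device a slice of dimension 0, while the target wants slices of dimension 1 replicated four ways. Every device therefore needs data it does not hold, which is the kind of layout change the partitioner reports as a full rematerialization when it cannot find a cheaper path.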

ShivamPR21 commented 7 months ago

I also got a similar error message, but I am not sure what it is about:

E0215 01:27:54.605426 2560490 spmd_partitioner.cc:589] [spmd] Involuntary full rematerialization. The compiler was not able to go from sharding {devices=[1,2,1]<=[2]} to {maximal device=0} without doing a full rematerialization of the tensor. You probably want to enrich the sharding annotations to prevent this from happening.
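Under the same reading of the notation, `{devices=[1,2,1]<=[2]}` splits the middle dimension of a rank-3 tensor across devices 0 and 1, and `{maximal device=0}` places the whole tensor on device 0. Reaching the target therefore requires device 0 to obtain device 1's half. The toy model below is not XLA's algorithm. It uses a hypothetical `elements_to_receive` helper and an assumed 8-element split axis to count how many elements each device must receive. It does this both for a direct move and for a fallback that first replicates the full tensor on every device, which is roughly what "full rematerialization" appears to refer to.

```python
def elements_to_receive(src, dst):
    """For each device, count elements it must hold under `dst`
    but does not already hold under `src`.

    `src` and `dst` map device id -> set of flat element indices it holds.
    """
    return {
        dev: len(dst.get(dev, set()) - src.get(dev, set()))
        for dev in sorted(set(src) | set(dst))
    }


# Toy tensor with 8 elements along the split (middle) dimension.
n = 8
# {devices=[1,2,1]<=[2]}: the middle dimension is split in half across devices 0 and 1.
src = {0: set(range(0, n // 2)), 1: set(range(n // 2, n))}
# {maximal device=0}: device 0 holds the whole tensor; device 1 holds nothing.
dst = {0: set(range(n))}
# Assumed fallback: materialize the full tensor on every device first.
replicated = {dev: set(range(n)) for dev in src}

print(elements_to_receive(src, dst))         # {0: 4, 1: 0}
print(elements_to_receive(src, replicated))  # {0: 4, 1: 4}
```

In this model, any move to a single-device ("maximal") sharding forces that device to gather the full tensor. Something in the program is probably requesting that placement, rather than the annotations merely being incomplete.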

@hawkinsp, any idea what's happening here? 🤔

Joshuaalbert commented 1 month ago

I get this when using things like jax.debug.print on sharded arrays.