Open hawkinsp opened 10 months ago
+1 to this issue. What does the devices=[1,64,4]<=[256] last_tile_dim_replicate syntax even mean?
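Not an authoritative answer, but as far as I understand XLA's HLO sharding text format, devices=[1,64,4]<=[256] says "take device ids 0..255 in order and reshape them into a [1,64,4] tile grid", and last_tile_dim_replicate says the last grid dimension (the 4) is replication rather than a split of the tensor. So a rank-2 tensor would be unsplit on dim 0, split 64 ways on dim 1, with each shard copied onto 4 devices. Here is a minimal NumPy sketch of that reading. decode_tile_assignment is a hypothetical helper written for this comment, not a JAX or XLA API, and it only handles the simple form without a T(...) transpose suffix:

```python
import re
import numpy as np

def decode_tile_assignment(sharding: str):
    """Decode the simple 'devices=[...]<=[...]' form of an HLO sharding string.

    Hypothetical helper for illustration only (not part of JAX or XLA).
    Returns (device_grid, last_dim_is_replication).
    """
    m = re.search(r"devices=\[([\d,]+)\]<=\[([\d,]+)\]", sharding)
    if m is None:
        raise ValueError(f"unsupported sharding string: {sharding!r}")
    if sharding[m.end():].startswith("T("):
        # The transposed iota form is deliberately not handled in this sketch.
        raise ValueError("transposed device order is not handled here")
    tile_dims = [int(x) for x in m.group(1).split(",")]
    reshape_dims = [int(x) for x in m.group(2).split(",")]
    # Assumption: device ids are 0..N-1 in row-major order, reshaped to the tile grid.
    num_devices = int(np.prod(reshape_dims))
    device_grid = np.arange(num_devices).reshape(tile_dims)
    return device_grid, "last_tile_dim_replicate" in sharding

grid, replicated = decode_tile_assignment(
    "devices=[1,64,4]<=[256] last_tile_dim_replicate"
)
print(grid.shape)   # tile grid shape
print(replicated)   # whether the last grid dim is replication
print(grid[0, 1])   # device ids that would hold the second shard along dim 1
```

Under the same reading, {devices=[1,2,1]<=[2]} would be a rank-3 tensor split 2 ways along its middle dimension across devices 0 and 1, and I believe {maximal device=0} means the whole tensor lives on device 0.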
I also got a similar error message, but I am not sure what it is about:
E0215 01:27:54.605426 2560490 spmd_partitioner.cc:589] [spmd] Involuntary full rematerialization. The compiler was not able to go from sharding {devices=[1,2,1]<=[2]} to {maximal device=0} without doing a full rematerialization of the tensor. You probably want to enrich the sharding annotations to prevent this from happening.
@hawkinsp any idea what's happening here 🤔
For me, I get this when using things like jax.debug.print on sharded arrays.
@GallagherCommaJack wrote in https://github.com/google/jax/issues/18591:
I often get error messages that look like this:
It's really hard to debug this, though, as the error tells me nothing about where in my code it's getting triggered.