hxssgaa closed this issue 4 months ago
Use a mesh dim of 1,-1,1,1; that's the best sharding case. Alternatively, write custom sharding methods and use FSDP on every layer, which is easier.
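For anyone unsure what the -1 does in a mesh dim like the one above, here is a minimal sketch of how such a spec could be resolved. It assumes the four axes are dp, fsdp, tp, sp in that order (as described in the LWM sharding doc, as far as I understand it) and that -1 means "use all remaining devices". `resolve_mesh_dims` is a hypothetical helper for illustration, not a function from the LWM codebase, and the device count of 32 is taken from the product of the mesh dims in the original post.

```python
from math import prod


def resolve_mesh_dims(spec: str, n_devices: int) -> tuple[int, ...]:
    """Resolve a comma-separated mesh spec, filling a single -1 from n_devices.

    Hypothetical helper for illustration only; axis order assumed to be
    (dp, fsdp, tp, sp).
    """
    dims = [int(x) for x in spec.split(",")]
    if any(d == 0 or d < -1 for d in dims):
        raise ValueError("each axis must be a positive integer or -1")
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    if -1 in dims:
        # The -1 axis absorbs whatever devices the fixed axes leave over.
        known = prod(d for d in dims if d != -1)
        if n_devices % known != 0:
            raise ValueError(f"{n_devices} devices not divisible by {known}")
        dims[dims.index(-1)] = n_devices // known
    if prod(dims) != n_devices:
        raise ValueError(f"mesh {dims} does not cover {n_devices} devices")
    return tuple(dims)


# With 32 devices, 1,-1,1,1 puts every device on the second (fsdp) axis.
print(resolve_mesh_dims("1,-1,1,1", 32))  # (1, 32, 1, 1)
```

Under these assumptions, 1,-1,1,1 is pure FSDP across all devices, whereas 1,1,1,32 and 1,1,4,8 place the devices on the last two axes instead.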
Sorry for the delay. The vmem OOM error is likely due to using a large chunk size. I tried v4-64 with a chunk size of 512, and it worked well, with fast computation and overlapped communication. For sharding, you can refer to https://github.com/LargeWorldModel/LWM/blob/main/docs/sharding.md
Feel free to reopen this issue if you have further questions.
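On the chunk-size point above, a rough back-of-the-envelope sketch of why a larger chunk can overflow vmem. It assumes the kernel holds at least one float32 attention-score tile of shape (query chunk x key/value chunk) on chip, and that the "16.00M" in the error means 16 MiB. Both are assumptions for illustration; the real kernel's buffers and XLA's tiling will differ, so treat the numbers as an order-of-magnitude guide only. `score_tile_bytes` is a hypothetical helper.

```python
def score_tile_bytes(q_chunk: int, kv_chunk: int, bytes_per_elem: int = 4) -> int:
    """Bytes for one (q_chunk x kv_chunk) score tile; float32 by default."""
    return q_chunk * kv_chunk * bytes_per_elem


# Assumption: the 16.00M limit reported in the error is 16 MiB.
VMEM_BYTES = 16 * 1024 * 1024

for chunk in (512, 1024, 2048, 4096):
    tile = score_tile_bytes(chunk, chunk)
    # Report the tile size and the fraction of the assumed vmem budget it uses.
    print(f"chunk={chunk}: {tile / 2**20:.0f} MiB ({tile / VMEM_BYTES:.2f}x vmem)")
```

The tile grows quadratically with chunk size, so under these assumptions a chunk of 512 needs about 1 MiB per tile while a chunk of 2048 already fills the whole 16 MiB on its own.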
Hi,
I tried to run your script on Cloud TPU v4-64, but it failed with the following error:
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space vmem. Used 59.79M of 16.00M vmem. Exceeded vmem capacity by 43.79M.
I tried mesh dims of 1,1,1,32 and 1,1,4,8; both failed.
Any suggestions on what caused the error? Thanks.