haoliuhl / ringattention

Transformers with Arbitrarily Large Context
Apache License 2.0

vmem OOM on TPU #11

Closed hxssgaa closed 4 months ago

hxssgaa commented 10 months ago

Hi,

I tried to run your script on a Cloud TPU v4-64, but it failed with the following error:

jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space vmem. Used 59.79M of 16.00M vmem. Exceeded vmem capacity by 43.79M.

I tried mesh dims of 1,1,1,32 and 1,1,4,8; both failed.

Any suggestions on what caused the error? Thanks.

erfanzar commented 10 months ago

Use 1,-1,1,1; that's the best sharding case. Alternatively, write custom sharding methods and use FSDP on every layer, which is easier.
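As a rough illustration of how a mesh spec like 1,-1,1,1 might be realized in JAX (a sketch, not this repo's actual setup code: the `make_mesh` helper and the `dp`/`fsdp`/`tp`/`sp` axis names are assumptions; a -1 entry is inferred from the device count, numpy-reshape style):

```python
import numpy as np
import jax
from jax.sharding import Mesh

def make_mesh(dims, axis_names=("dp", "fsdp", "tp", "sp")):
    # Hypothetical helper: arrange all visible devices into a 4D grid.
    # A -1 dim is inferred from the total device count, like numpy reshape,
    # so (1, -1, 1, 1) puts every device on the second (FSDP-style) axis.
    devices = np.array(jax.devices())
    device_grid = devices.reshape([int(d) for d in dims])
    return Mesh(device_grid, axis_names)

mesh = make_mesh((1, -1, 1, 1))
print(dict(mesh.shape))  # on a v4-64 this would be {'dp': 1, 'fsdp': 32, 'tp': 1, 'sp': 1}
```

On a single-process CPU run the inferred axis collapses to 1; on a real TPU slice it becomes the device count along that axis.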

haoliuhl commented 4 months ago

Sorry for the delay. The vmem OOM error is most likely caused by using a large chunk size. I tried v4-64 with chunk size 512; it worked well, with fast computation and overlapped communication. Regarding sharding, you can refer to https://github.com/LargeWorldModel/LWM/blob/main/docs/sharding.md
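To see why chunk size drives vmem usage, here is a back-of-envelope estimate (not the library's exact accounting; the head dim, dtype, and the set of buffers counted are assumptions). The attention-score block grows quadratically with chunk size, so halving the chunk roughly quarters that term:

```python
def block_vmem_bytes(chunk, head_dim=128, dtype_bytes=4):
    # Illustrative per-block footprint for blockwise attention:
    # q, k, v chunks are each (chunk x head_dim); the score block for one
    # query-chunk/key-chunk pair is (chunk x chunk) and dominates for
    # large chunks, since it scales quadratically with chunk size.
    qkv = 3 * chunk * head_dim
    scores = chunk * chunk
    return dtype_bytes * (qkv + scores)

for c in (512, 2048, 8192):
    print(f"chunk={c}: ~{block_vmem_bytes(c) / 2**20:.1f} MiB")
```

Under these assumptions, chunk 512 stays under ~2 MiB per block, while chunk 8192 needs hundreds of MiB, far beyond the ~16 MiB of vmem reported in the error above.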

Feel free to reopen this issue if there are questions.