Time For configure functions and sharding them (ms) : 2012.7429962158203
Action : Sharding Passed Parameters
Model Contain 1.100048384 Billion Parameters
0%| | 0/12000 [00:00<?, ?it/s]jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/***/research/EasyDeL/1.py", line 101, in <module>
output = trainer.train(flax.core.FrozenDict({"params": params}))
File "/home/***/research/EasyDeL/.venv/lib/python3.10/site-packages/EasyDel/trainer/causal_language_model_trainer.py", line 708, in train
sharded_state, loss, accuracy = self.sharded_train_step_fn(
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 21.90G of 7.48G hbm. Exceeded hbm capacity by 14.42G.
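The numbers in the error line up with a back-of-the-envelope estimate. A minimal sketch, assuming float32 weights and an Adam-style optimizer (this breakdown is an illustration, not taken from EasyDeL internals):

```python
# Rough HBM estimate for the 1.1B-parameter model reported in the log above.
# Assumption: float32 everywhere and an Adam-style optimizer that keeps
# two extra moment buffers per parameter.
PARAMS = 1.100048384e9   # parameter count printed by the trainer
BYTES_FP32 = 4

weights_gb = PARAMS * BYTES_FP32 / 1e9    # model weights
grads_gb = weights_gb                     # gradients
optimizer_gb = 2 * weights_gb             # Adam first/second moments

total_gb = weights_gb + grads_gb + optimizer_gb
print(f"weights: {weights_gb:.1f} GB, state total: {total_gb:.1f} GB")
# weights: 4.4 GB, state total: 17.6 GB
```

Roughly 17.6 GB of training state, plus activations, is consistent with the reported 21.90G, and far exceeds the ~8 GB of HBM on a single TPU v2 core, so the job cannot fit without sharding the state across devices, lower precision, or gradient checkpointing.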
Where are you running this code?
It provides you with a TPU v2, and you are getting out-of-memory errors.
You can use Kaggle for a free TPU v3, which is much more powerful.
To Reproduce: use a TPU v2-8 with the VM architecture.