kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

read_ckpt getting killed (OOM?) #181

Closed · gaycomputers closed 2 years ago

gaycomputers commented 2 years ago

I suspect that my attempt to run with these params:

params = {
    "layers": 28,
    "d_model": 4096,
    "n_heads": 16,
    "n_vocab": 50400,
    "norm": "layernorm",
    "pe": "rotary",
    "pe_rotary_dims": 64,
    #"early_cast": True,
    "seq": 2048,
    "cores_per_replica": 1,
    "per_replica_batch": 1,
}

is getting killed after printing the following:

read from disk/gcs in 86.9368s
[[0.00210571 7.82013e-05 0.0025177 -0.00157166 -0.000255585 -0.000128746 0.00170898 0.000278473 0.00128174 0.00185394 0.000953674 0.00100708 0.000333786 0.00145721 4.24385e-05 -5.36442e-05]]
[[-0.0585938 -0.0427246 -0.0134277 -0.0698242 -0.0446777 -0.0654297 -0.0334473 -0.0620117 -0.0456543 -0.0415039 -0.0458984 -0.0393066 -0.0246582 -0.0449219 -0.0554199 -0.041748]]
Killed

because it is running out of memory. I am trying to run on CPU, and I only have 32 GB of RAM on this machine. Is it possible to run?

I have another machine with 128 GB if absolutely needed.
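
For context, a quick back-of-the-envelope count from the params above suggests why 32 GB is tight. This is my own arithmetic, not something computed by the repo, and the per-layer breakdown (q/k/v/o attention projections plus a 4x MLP) is my assumption about the GPT-J architecture:

# Rough parameter count for the config above (assumed architecture, see note)
layers, d_model, n_vocab = 28, 4096, 50400
per_layer = 4 * d_model**2 + 2 * d_model * (4 * d_model)  # attn q/k/v/o + MLP in/out
embeddings = 2 * n_vocab * d_model                        # input embedding + LM head
n_params = layers * per_layer + embeddings
print(f"~{n_params / 1e9:.2f}B params, ~{n_params * 4 / 2**30:.0f} GiB at fp32")
# ~6.05B params, ~23 GiB at fp32 -- if unsharding ever holds a second copy
# of the weights in flight, a ~40 GB peak would leave nothing on a 32 GB box.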

gaycomputers commented 2 years ago

If I drop the layer count significantly (to 14), I get a new error:

  1.01562 0.996094 1 0.984375 0.980469 1.00781 1.01562 0.996094]]
Traceback (most recent call last):
  File "main.py", line 116, in <module>
    network.state = read_ckpt(
  File "/usr/local/lib/python3.8/dist-packages/mesh_transformer/checkpoint.py", line 164, in read_ckpt
    unsharded = _unshard(shards, old_flattened)
  File "/usr/local/lib/python3.8/dist-packages/mesh_transformer/checkpoint.py", line 158, in _unshard
    x = reshard(x, old.shape)
  File "/usr/local/lib/python3.8/dist-packages/mesh_transformer/checkpoint.py", line 120, in reshard
    if x.shape[0] * x.shape[2] == old_shape[2]:
IndexError: tuple index out of range

gaycomputers commented 2 years ago

My guess is that the layer count for GPT-J is 28, so obviously unsharding into a 14-layer config wouldn't work. However, is there a way to get GPT-J to work in 32 GB?
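
A minimal sketch of that guess, with hypothetical shapes chosen only to trigger the same check: if the flattened parameter lists from the 28-layer checkpoint no longer line up with a 14-layer config, reshard can be handed a low-rank array where it expects a sharded 3-D one, and the shape test at checkpoint.py:120 then indexes past the end of the shape tuple:

import numpy as np

x = np.zeros((8, 4096))      # hypothetical misaligned entry: only 2 dims
old_shape = (1, 4096, 4096)  # the 3-D shape reshard expects at this slot
try:
    _ = x.shape[0] * x.shape[2] == old_shape[2]  # the check from checkpoint.py:120
except IndexError as e:
    print(e)                 # "tuple index out of range", as in the traceback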

gaycomputers commented 2 years ago

I ended up switching computers, and sure enough memory peaked at around ~40 GB during setup, then settled to under ~15 GB.

I'm not sure there's any way around this at the moment; I'd appreciate it if someone who understands the mesh transformer better could clarify for others.
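
For anyone who wants to confirm the peak on their own machine, the standard library can report it after the load finishes. This is a generic sketch, not something the repo provides, and note that Linux reports ru_maxrss in KiB:

import resource

# ... run read_ckpt(...) here ...

peak_kib = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss  # KiB on Linux
print(f"peak RSS: {peak_kib / 2**20:.1f} GiB")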

kingoflolz commented 2 years ago

Memory efficiency and running on CPU are not goals for this library; you can look into the Hugging Face implementation, which might be more efficient.
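
A hedged sketch of that route, assuming the transformers library and the published EleutherAI/gpt-j-6B checkpoint; its float16 revision roughly halves the weight footprint to ~12 GB, which should fit in 32 GB:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Load the half-precision branch of the checkpoint; low_cpu_mem_usage
# avoids materializing a second full copy of the weights during load.
model = AutoModelForCausalLM.from_pretrained(
    "EleutherAI/gpt-j-6B",
    revision="float16",
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
)
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")

print(f"loaded {sum(p.numel() for p in model.parameters()) / 1e9:.2f}B params")

One caveat: not every torch build supports fp16 matmuls on CPU, so actually generating text may require casting the model to float32, which gives back the memory savings.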