kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

Running on Colab TPU only gives random words and nonsensical outputs #216

Closed: JohnnyRacer closed this issue 2 years ago

JohnnyRacer commented 2 years ago

I cannot get the model to generate anything like what the demo showed. The model only generates incoherent gibberish. No errors were raised at runtime, which I found puzzling given how bad the text output was. The notebook was modified slightly to make installing the dependencies easier. Could this be some strange OOM error? Colab usually gives me a TPU v2, which I've heard has insufficient RAM to run this. Any help would be much appreciated.

Link to the notebook with outputs:

https://colab.research.google.com/drive/1GFF56OtZlSoRFaZF35BBwoNgCfzjE-4m#scrollTo=nvlAK6RbCJYg

vfbd commented 2 years ago

You commented out the read_ckpt_lowmem line that actually loads the model checkpoint into memory, so it's using a randomly initialized model that will generate gibberish.
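
Since a skipped load raises no error, one way to catch it early is to fingerprint the parameters right after initialization and again just before generation. The sketch below is only an illustration of that idea in plain Python: `fingerprint` and `assert_checkpoint_loaded` are hypothetical helpers, not part of mesh-transformer-jax, and the parameter dictionaries are toy stand-ins for real weights.

```python
import hashlib
import struct


def fingerprint(params):
    """Return a short hash of a dict mapping parameter names to lists of floats."""
    h = hashlib.sha256()
    for name in sorted(params):
        h.update(name.encode("utf-8"))
        for value in params[name]:
            # Pack each value as a little-endian double so the hash is deterministic.
            h.update(struct.pack("<d", float(value)))
    return h.hexdigest()[:16]


def assert_checkpoint_loaded(initial_params, current_params):
    """Raise if the parameters are still identical to their random initialization."""
    if fingerprint(initial_params) == fingerprint(current_params):
        raise RuntimeError(
            "parameters unchanged since init; was the checkpoint load step skipped?"
        )


# Toy stand-ins for real weights (hypothetical values, not from any checkpoint).
random_init = {"layer0/w": [0.01, -0.02, 0.03]}
loaded = {"layer0/w": [0.75, -1.20, 0.40]}

# Passes silently because the values differ from the initialization.
assert_checkpoint_loaded(random_init, loaded)
```

Calling `assert_checkpoint_loaded(random_init, random_init)` raises a `RuntimeError`, which is the situation in the linked notebook: generation ran on weights that were never replaced by the checkpoint.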

JohnnyRacer commented 2 years ago

Thanks for the quick reply. This was exactly the issue.