sbrunk / storch

GPU accelerated deep learning and numeric computing for Scala 3.
https://storch.dev
Apache License 2.0

Implement "pico" GPT example #51

Open hmf opened 1 year ago

hmf commented 1 year ago

This is a response to the request in issue https://github.com/sbrunk/storch/issues/44.

It is an attempt to port the "pico" GPT example from Karpathy's "Let's build GPT: from scratch, in code, spelled out" to storch.

sbrunk commented 10 months ago

@hmf I'm trying to reproduce your results, but training diverges much faster in my case. The train loss gets down to 2.3, but then it starts going up again. At some point, the losses even become NaN.

[info] step 8000: train loss 2.3154821, val loss 2.3707209, mem 942,0 MiB @ 00 00:08:13.783, mean 00 00:00:00.061
...
[info] step 40500: train loss NaN, val loss NaN, mem 944,2 MiB @ 00 00:41:43.711, mean 00 00:00:00.061

Here's how I ran it: I checked out the branch, disabled weight initialization, and trained with a learning rate of 1e-4:

// modules.foreach(init_weights)
// ...

train(model, 1e-4, 67000)

Any idea what could be different?
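When a run ends in NaN like this, it can help to stop as soon as divergence shows up instead of continuing for tens of thousands of steps. The sketch below is a hypothetical helper in plain Scala (`firstDivergence` is not part of storch or of the example branch). It scans the evaluated losses and returns the index of the first evaluation that is non-finite, or that has stayed more than `tolerance` above the best loss for `patience` consecutive evaluations. Both thresholds are arbitrary assumptions.

```scala
// Hypothetical divergence guard, not part of storch or the pico example.
// Returns the index of the first evaluation at which training should stop.
def firstDivergence(losses: Seq[Float], patience: Int = 3, tolerance: Float = 0.1f): Option[Int] =
  var best = Float.PositiveInfinity
  var worseInARow = 0
  var result: Option[Int] = None
  val it = losses.iterator.zipWithIndex
  while result.isEmpty && it.hasNext do
    val (loss, idx) = it.next()
    if !java.lang.Float.isFinite(loss) then
      result = Some(idx) // NaN or infinite loss: stop immediately
    else if loss < best then
      best = loss // new best loss, reset the counter
      worseInARow = 0
    else if loss > best + tolerance then
      worseInARow += 1 // clearly worse than the best loss so far
      if worseInARow >= patience then result = Some(idx)
    else
      worseInARow = 0 // still within tolerance of the best loss
  result

// Illustrative losses, not results from this thread.
println(firstDivergence(Seq(2.5f, 2.4f, 2.31f, 2.35f, Float.NaN))) // Some(4)
```

In a training loop, such a check would run after each evaluation and break out of the loop once it returns `Some`.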

hmf commented 10 months ago

@sbrunk I am rerunning the test to confirm that everything is OK. Unfortunately, I had to recreate the dev container. I have also merged your latest changes from main, and I am now running with weight initialization disabled.

As soon as I get some results, I will report back.

> Any idea what could be different?

At the moment, the only thing I can think of is that the OS libraries (including NVIDIA's drivers and libraries) may be different. I am assuming you are using Ubuntu Linux. Below is a list of my setup:

storch_os.txt

We could also try setting a fixed random seed so that runs can be replicated.
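A fixed seed makes random draws repeatable. As an illustration in plain Scala (not storch's API; the thread itself later uses `torch.manualSeed` for the library's generator), the hypothetical `batchOffsets` helper below samples batch start offsets, assuming uniformly random training windows similar to the pico example's batch sampling. Two generators created with the same seed produce identical batches. Note that this only fixes the random draws; it does not by itself make GPU arithmetic deterministic.

```scala
import scala.util.Random

// Hypothetical helper, not storch's API: draws the start offsets of one batch
// of training windows, uniformly from [0, dataLen - blockSize).
def batchOffsets(rng: Random, dataLen: Int, blockSize: Int, batchSize: Int): Seq[Int] =
  Seq.fill(batchSize)(rng.nextInt(dataLen - blockSize))

// Draws `steps` consecutive batches from a generator seeded with `seed`.
// The sizes below are illustrative values, not the example's settings.
def firstBatches(seed: Long, steps: Int): Seq[Seq[Int]] =
  val rng = new Random(seed)
  Seq.fill(steps)(batchOffsets(rng, dataLen = 1_000_000, blockSize = 256, batchSize = 4))

println(firstBatches(1337L, steps = 3) == firstBatches(1337L, steps = 3)) // true
```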

hmf commented 10 months ago

@sbrunk I reran with PyTorch 2.1.2 and got:

[info] step 41500: train loss 1.6008276, val loss 1.7987113, @ 00 01:25:21.304, mean 00 00:00:00.123

Here is the full output: tmp.txt

I have added:

 torch.manualSeed(1337)

and am running it again, so that we can compare results using the same seed.

I can't reproduce the GPU issue at the moment, but so far I have only been able to try on an RTX 4090, which uses the Ada architecture, while the 3090 uses Ampere.

I have noticed that your compute time per iteration is about half of mine (0.061 s vs. 0.123 s). Nice 8-)

EDIT: This run diverged abruptly and ended with NaN losses. Can you check whether you get the same output? Here is the output I get:

tmp.txt

sbrunk commented 10 months ago

@hmf I did a run on commit d8d75b7b773972191dd0a48e374f9531bd2407d2, and the result is much closer to yours than before: train-gpt.v2.txt.

There are still numeric differences, but I guess they could be due to slightly different hardware or drivers.
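One plausible source of such differences is that floating-point addition is not associative, so different GPUs, drivers, or kernels that reduce the same values in a different order can produce slightly different sums. The plain Scala snippet below (independent of storch) sums the same four `Float` values in two orders and gets different results.

```scala
// Four Float values whose exact sum is 2.
// Near 1e8, adjacent Float values are 8 apart, so 1e8f + 1f rounds back to 1e8f.
val xs: Seq[Float] = Seq(1.0e8f, 1.0f, -1.0e8f, 1.0f)

// Left to right: ((1e8 + 1) - 1e8) + 1 evaluates to (1e8 - 1e8) + 1 = 1
val leftToRight = xs.foldLeft(0.0f)(_ + _)

// Same values, large terms first: ((1e8 - 1e8) + 1) + 1 = 2
val largeFirst = Seq(xs(0), xs(2), xs(1), xs(3)).foldLeft(0.0f)(_ + _)

println(s"left to right: $leftToRight, large terms first: $largeFirst")
// left to right: 1.0, large terms first: 2.0
```

In a long training run, tiny discrepancies like this can compound from step to step, which is consistent with runs that start from the same loss and drift apart later.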

hmf commented 10 months ago

@sbrunk I am somewhat skeptical of the results. The first loss is indeed the same value (barring precision errors), so I would expect the following values to be the same too, since the data is identical and should be loaded in the same order. This also does not bode well for unit tests, which could be added to confirm your hypothesis.

Having said that, I am not satisfied with the results. I am considering a ViT implementation with known SOTA performance that we could compare against. That case is easier to test than NLP because the preprocessing should be simpler. Just a thought.
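For the unit tests mentioned above, exact equality of losses across machines is probably too strict. A minimal sketch, assuming the losses of two runs are available as `Seq[Float]`, compares two loss curves step by step within a relative tolerance. `firstMismatch` and the `1e-3` default are hypothetical choices, not existing storch test utilities.

```scala
// Hypothetical helper, not a storch test utility: compares two loss curves
// step by step and returns the first index where they differ by more than
// `relTol` relative to the larger magnitude.
// Non-finite values (NaN, infinities) never count as a match.
def firstMismatch(a: Seq[Float], b: Seq[Float], relTol: Double = 1e-3): Option[Int] =
  require(a.length == b.length, "loss curves must have the same length")
  a.indices.find { i =>
    val x = a(i).toDouble
    val y = b(i).toDouble
    !java.lang.Double.isFinite(x) || !java.lang.Double.isFinite(y) ||
      math.abs(x - y) > relTol * math.max(math.abs(x), math.abs(y))
  }

// Illustrative curves, not actual results from this thread.
println(firstMismatch(Seq(4.2f, 2.5f, 2.3f), Seq(4.2f, 2.5001f, 2.3f))) // None
println(firstMismatch(Seq(4.2f, 2.5f, 2.3f), Seq(4.2f, 2.6f, 2.3f)))    // Some(1)
```

A test could then assert that a stored reference curve and a fresh run agree, for example `assert(firstMismatch(reference, current).isEmpty)`, where `reference` and `current` are hypothetical loss sequences.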

sbrunk commented 10 months ago

> I am somewhat skeptical of the results. The first loss is indeed the same value (barring precision errors), so I would expect the following values to be the same too, since the data is identical and should be loaded in the same order. This also does not bode well for unit tests, which could be added to confirm your hypothesis.

That's true. It does seem, though, that I get reproducible results when I run it twice on the same hardware/setup. I'll have access to yet another GPU type next week and will try it there too for comparison.

> Having said that, I am not satisfied with the results. I am considering a ViT implementation with known SOTA performance that we could compare against. That case is easier to test than NLP because the preprocessing should be simpler. Just a thought.

That's a great idea. If you give it a try, let me know if I can help in any way.