tysam-code / hlb-CIFAR10

Train to 94% on CIFAR-10 in <6.3 seconds on a single A100. Or ~95.79% in ~110 seconds (or less!)
https://twitter.com/hi_tysam
Apache License 2.0

Out of memory with 5GB VRAM #2

Open · 99991 opened this issue 1 year ago

99991 commented 1 year ago
--------------------------------------------------------------------------------------------------------
|  epoch  |  train_loss  |  val_loss  |  train_acc  |  val_acc  |  ema_val_acc  |  total_time_seconds  |
--------------------------------------------------------------------------------------------------------
Traceback (most recent call last):
  File "main.py", line 621, in <module>
    main()
  File "main.py", line 540, in main
    for epoch_step, (inputs, targets) in enumerate(get_batches(data, key='train', batchsize=batchsize)):
  File "main.py", line 428, in get_batches
    images = batch_crop(data_dict[key]['images'], 32) # TODO: hardcoded image size for now?
  File "main.py", line 390, in batch_crop
    cropped_batch = torch.masked_select(inputs, crop_mask_batch).view(inputs.shape[0], inputs.shape[1], crop_size, crop_size)
torch.cuda.OutOfMemoryError: CUDA out of memory. Tried to allocate 4.58 GiB (GPU 0; 5.81 GiB total capacity; 835.23 MiB already allocated; 2.35 GiB free; 1.24 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

I have not looked at the code too closely, but it might be possible to shave off a few MB when preparing batches.
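(For what it's worth, one hypothetical way to trim memory here would be to crop by slicing per offset group instead of building a full boolean mask for torch.masked_select. A rough sketch with made-up names, not the actual main.py code:)

```python
import torch

def batch_crop_lowmem(images, crop_size):
    # Hypothetical alternative to a masked_select-based batch crop: slice each
    # group of images that shares a random crop offset, so we never materialize
    # an N x C x H x W boolean mask alongside the output tensor.
    n, c, h, w = images.shape
    dy = torch.randint(0, h - crop_size + 1, (n,), device=images.device)
    dx = torch.randint(0, w - crop_size + 1, (n,), device=images.device)
    out = torch.empty(n, c, crop_size, crop_size, dtype=images.dtype, device=images.device)
    # Only (h - crop_size + 1) * (w - crop_size + 1) distinct offsets exist, so the
    # loop is short and each slice touches far less memory than a full mask.
    for y in range(h - crop_size + 1):
        for x in range(w - crop_size + 1):
            sel = (dy == y) & (dx == x)
            if sel.any():
                out[sel] = images[sel, :, y:y + crop_size, x:x + crop_size]
    return out
```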

Thank you for this comment by the way.

https://github.com/tysam-code/hlb-CIFAR10/blob/132829f191c00e71a178c3c995ab6c0302ec66e5/main.py#L523

I totally forgot to add torch.cuda.synchronize(), but it is finally fixed: https://github.com/99991/cifar10-fast-simple. Fortunately, it did not make much of a difference. I now get 14.3 seconds with my code vs 15.7 seconds with your code. Perhaps there is something during batch preparation which makes a difference?
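(For anyone else timing these scripts: CUDA calls are asynchronous, so the pattern below is roughly what both repos need. `train_one_epoch()` is a hypothetical stand-in for whatever is being timed:)

```python
import time
import torch

# CUDA kernels are launched asynchronously, so wall-clock timing only measures
# the GPU work correctly if we wait for all queued kernels to finish first.
torch.cuda.synchronize()
start = time.perf_counter()
train_one_epoch()  # hypothetical stand-in for the code being timed
torch.cuda.synchronize()
print(f"took {time.perf_counter() - start:.3f} s")
```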

tysam-code commented 1 year ago

Hi hi hiya Thomas! :D <3 <3 <3 <3 :D 🐧 🐧

Thanks so much for reaching out. I fan-reacted a bit when I saw that, so cool this little project got on your radar (and I'm a bit curious how, too; the word of mouth I've done so far is pretty small! At least, so I think!). I really appreciated your project and used it as a reference as I was building my own code, especially at the beginning.

Thanks for catching the OOM problem, I'd missed that (I've been on A100s due to using Colab Pro, and I'd sort of left that check by the wayside). Having this model be performant on small GPUs is of course really important. I will try to take a look at that, as this project is after all trying to make things as accessible as possible. I think there is some room for memory optimization too, probably in the data prep section as you noted; I was pretty darn fast and loose there. I do store some higher-precision gradients at the end, but that's pretty small, so I'm not too worried about it.

As far as the timing goes, now that was a serious rabbit hole, and I'm more confused coming out the other side than I was going in. On Colab, even on the terminal, I'm getting ~18.1 seconds on my code and ~17.2 seconds on yours. Not bad for the flexibility on this end, but clearly there's room for improvement; 0.9 seconds is a lot. However, this is on an A100, and that's what's confusing to me, as you seem to be getting much faster results both ways. I think with your code you were reporting faster results on an A100 as well. This is with the default Colab PyTorch installs and the SXM4 model. Something is resulting in a drastically faster speed on your end, so maybe it's something like the pytorch/cuda versions installed. I've attached a pipdeptree dump just in case for public reference: colab_deps.txt

As for the relative speed difference, I went ahead and pulled out the pytorch profiler for both code files. The code is pretty simple and I'm attaching the version that I hacked into your train.py file for reference (in case you ever want to use it -- it's quirky but pretty darn good for what it does):
train_fast_simple_cifar10_with_profiling.py.txt

I couldn't find anything upon immediate inspection (I probably would with a more detailed look, it's just a bit of a slog), but I did notice something interesting when looking at the dumps:

Screen Shot 2023-01-11 at 8 32 46 PM

So if you look, the "repeat" and batchnorm operations take up almost the same amount of time as two full-on conv2d operations! I'm assuming the gaps are just the time it takes for the Python interpreter, this is not scripted after all. In any case -- this is quite a lot, and it looks like those repeats -- from the GhostBatchNorm (i.e. GhostNorm) -- are pretty big impediments to training speed. Here's something worse to add to it all:

Screen Shot 2023-01-08 at 6 57 07 PM

So this is not good at all! Not only are we not using our tensor cores (though in a different screen -- we do have a great GPU usage percentile!), but we're spending most of our compute time (compared to other individual operations) just switching between nchw and nhwc. This is major no bueno!
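(For context on where those 'repeat' ops come from: GhostBatchNorm is usually written roughly like the generic, cifar10-fast-style sketch below, not necessarily main.py verbatim. The per-split weight/bias repeats and the reshapes are exactly what shows up in the profile:)

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class GhostBatchNorm(nn.BatchNorm2d):
    # Normalizes each of `num_splits` virtual sub-batches independently, which is
    # where GhostNorm's extra regularization noise comes from.
    def __init__(self, num_features, num_splits=16, **kw):
        super().__init__(num_features, **kw)
        self.num_splits = num_splits
        self.register_buffer('running_mean', torch.zeros(num_features * num_splits))
        self.register_buffer('running_var', torch.ones(num_features * num_splits))

    def forward(self, x):
        n, c, h, w = x.shape
        if self.training:
            # The .repeat() calls and the fold of the splits into the channel dim
            # are the extra work the profiler flags next to the convs.
            return F.batch_norm(
                x.view(-1, c * self.num_splits, h, w),
                self.running_mean, self.running_var,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(n, c, h, w)
        # (A fuller version would average the per-split running stats before eval.)
        return F.batch_norm(x, self.running_mean[:c], self.running_var[:c],
                            self.weight, self.bias, False, self.momentum, self.eps)
```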

So now we're on a bit of a rabbit trail for this thread (which still definitely needs to be addressed but may take a while), but take a look at this. If we use pytorch's experimental 'channels last' data format:

Screen Shot 2023-01-11 at 8 38 38 PM Screen Shot 2023-01-11 at 8 38 53 PM

We should be able to go much faster. Unfortunately, GhostNorm requires a '.contiguous()' call because of this, which sort of nerfs all of the potential speed gains one would get. So, since the earlier 'repeat' calls are probably (hypothetically) a large waste of compute for their gains, if we just remove that call during training and add a bit of noise to try to compensate for the loss of the special regularization that only GhostNorm seems to provide...:

Screen Shot 2023-01-11 at 8 40 46 PM

To:

Screen Shot 2023-01-11 at 8 44 55 PM

Wow. A huge accuracy drop, but the speed difference is also pretty darn massive. Without any noise we go even faster (if even more practically unusable from a 94% accuracy standpoint):

Screen Shot 2023-01-08 at 8 07 22 PM

From 19.1 seconds (on this setup) to 12.78 is a ~33% reduction in training time, which is pretty darn massive. Of course, I think we have to re-fill the accuracy gap by trying to replicate a similar kind of regularization, which honestly should be feasible to simulate with the sample mean and sample variance rules given some running statistics of the batches over time (something a conveniently-fused kernel should already provide). Working that back up is the hard part, but I think we could leverage it to great effect, and on your end, with your code and hardware, it might actually be possible to break the 10 second mark (!!!!) if we can get a similar ~19.1 -> ~12.78 second improvement on this particular version of the solution. It depends on your interest in diving/delving deep into cracking this nut vs maintaining the software.
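(A minimal sketch of the channels-last switch described above. `net` and `inputs` are stand-in names for whatever make_net() and the dataloader produce:)

```python
import torch

# Opt the model weights and the activations into NHWC ("channels last") layout so
# the tensor-core conv kernels can run without layout conversions.
net = net.to(memory_format=torch.channels_last)
inputs = inputs.to(memory_format=torch.channels_last)

# Any op that demands NCHW-contiguous memory (e.g. the .contiguous() inside the
# GhostNorm path) silently transposes back and forth, which is where the
# nchw <-> nhwc time in the profile comes from.
```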

In any case, massive rabbit trail but it came up in looking at the more minor (but just as important if we're going to get to ~<2 seconds training time in ~2 years or so) speed gains that you were talking about. Thank you so much for reporting that, it's much appreciated and hopefully I'm able to mirror back something useful to you in turn with some of the stuff here. Happy to help with any/all of this, this is one part passion project and one part living resume for the meantime. I'm happy to help with any tangents (pytorch profiler, etc) that you might have.

Oh, and I'm kicking the tires on getting an hlb-Pile (or some such similar LLM training dealio) up and going at some point. Right now the rough goal is whatever the best baseline can train to in ~30 minutes or so, and then set that as the target val loss and speed up from there. If you want to dork around on that together, let me know!

Again, thanks for reaching out and I know this was a, er, rather exuberant and long response! I really appreciate the feedback again -- means more than I could articulate. I'll keep taking a look at that speed stuff to see what we can come up with, and thanks for the lead that led to these (potential) first massive speed gains! :D

tysam-code commented 1 year ago

Howly Freaking Carp I figured it out.

Batchnorm momentum = 0.4, no noise at all (!!!), every_n_steps 5 -> 1, ema_epochs 2 -> 4.

Screen Shot 2023-01-11 at 9 03 55 PM

I think this is because loosening the batch-to-batch momentum creates noise just like GhostNorm did, but it gets smoothed out when distilled over into the EMA, which is what we use for validation anyways (!!!!). At least, that was the half-formed intuition I was following while also just twiddling knobs semi-randomly on instinct.

Then we also get to use the channels-last (nhwc)-friendly kernels, since we're not having to do any layout weirdness between the convolution and the batchnorm. Sometimes this opens the door to fused kernels too, a huge jump in speed (!!!!! :O).
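(A minimal sketch of the swap being described, with an assumed helper name: plain BatchNorm2d with a loosened momentum instead of GhostNorm, so the channels-last path stays contiguous:)

```python
import torch.nn as nn

def make_norm(num_channels):
    # Plain BatchNorm2d instead of GhostNorm: a higher momentum (default is 0.1)
    # lets the running stats wander batch to batch, supplying GhostNorm-like noise,
    # while the EMA'd weights used for validation smooth it back out.
    return nn.BatchNorm2d(num_channels, momentum=0.4)
```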

As you can see, I am very excited right now. If you (or anyone reading this) wants to reproduce these results, the code used to achieve this is attached (doctored to be a .txt due to github weirdness....).

cifar_10_new_single_gpu_world_record_12.91_seconds.py.txt (edit: unstable/inconsistent, see next comment/post)

I'm in shock that this worked. How did this work? Why does this work? It's never this easy. What. The. Freaking. Heck.

Now I definitely owe you a non-water beverage of your choice. Thanks again. Dang. What the heck. What. Thanks again. What.

What. How. Why. I need to just send this comment or else I'm going to keep expressing my confusion into it. ~33%. Dear freaking goodness.

tysam-code commented 1 year ago

A more stable version that satisfies the 94% mark consistently (still more tuning to be done with genetic search, but we're clearing the bar now. New single GPU world record now for sure! :D <3)

A number of different params twiddled around here to get it to work more consistently, but it works! :D <3 <3 <3 <3 :D 🐧🐧

runs: 25, mean: 0.940364, standard deviation: 0.00111762

Source: cifar_10_new_single_gpu_world_record_12.91_seconds_more_stable.py.txt

99991 commented 1 year ago

Thanks so much for reaching out. I fan-reacted a bit when I saw that

No need, I am just some guy 😄

so cool this little project got on your radar (and I'm a bit curious how, too; the word of mouth I've done so far is pretty small! At least, so I think!)

Your project has more stars than mine, so clearly it is more popular! I think I found it on /r/machinelearning if I remember correctly.

maybe it's something like the pytorch/cuda versions

That is quite possible. Currently, our server uses CUDA 11.2 and torch==1.11.0+cu113, which was the only version that did not segfault when the server was installed. Unfortunately, I do not have root access, so I cannot test other versions.

Source: cifar_10_new_single_gpu_world_record_12.91_seconds_more_stable.py.txt

I get an accuracy over 94 % for 14 of 25 runs with an average runtime of 13.28 seconds! Amazing job! :tada: :penguin:

Chillee commented 1 year ago

So if you look, the "repeat" and batchnorm operations take up almost the same amount of time as two full-on conv2d operations! I'm assuming the gaps are just the time it takes for the Python interpreter, this is not scripted after all. In any case -- this is quite a lot, and it looks like those repeats -- from the GhostBatchNorm (i.e. GhostNorm) -- are pretty big impediments to training speed. Here's something worse to add to it all:

When looking at the profiler, you need to be looking at the "GPU" stream and not the "CPU" stream. GPU operations are asynchronous, so how long the CPU components take is irrelevant as long as the GPU is "running ahead" of the CPU.
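(A minimal sketch of doing that with torch.profiler, where `train_step()` is a hypothetical stand-in for one training step: sort by CUDA time so the asynchronous GPU kernels, not the Python/CPU dispatch, dominate the view:)

```python
import torch
from torch.profiler import profile, ProfilerActivity

# Record both streams, but rank operators by time actually spent on the GPU.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    train_step()  # hypothetical stand-in for one training step
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
```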

tysam-code commented 1 year ago

Oh, gotcha, many thanks for sharing this! I really appreciate it. :) :D And many thanks for your interest and for checking the repo out! It's still surreal and I'm just on autopilot mode at this point (and working on more speed refinements just to 'unwind' a bit!).

Also... oh no! You caught my embarrassing mistake. I was hoping nobody would notice; I only realized it myself maybe a week or two ago, and I have used that profiler off and on to much success for a year or two. Quite embarrassing indeed, but I would much rather you highlight it. I always have lots of room to grow, and I really appreciate it. :D

tysam-code commented 1 year ago

A brief update on this issue: I made some progress in the last update, but not enough, I think, to post on its own here. I think the main memory issue now is just that index_select in the dataloaders. The whitening operation should be gucci now that we've chunked the eigenvalue/eigenvector calculation process and just average them at the end. It seems to work well enough, at least.
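(Roughly the idea, as a hedged sketch with assumed names: accumulate the patch covariance chunk by chunk, average at the end, and run a single eigendecomposition rather than holding all the patch statistics at once. Not necessarily the exact main.py code:)

```python
import torch

def chunked_whitening_eig(patches, num_chunks=8):
    # `patches` is assumed to be an (N, d) matrix of already-centered image patches.
    d = patches.shape[1]
    cov = torch.zeros(d, d, dtype=torch.float64, device=patches.device)
    for chunk in patches.double().chunk(num_chunks):
        cov += chunk.T @ chunk          # accumulate per-chunk statistics
    cov /= patches.shape[0] - 1         # average at the end
    eigenvalues, eigenvectors = torch.linalg.eigh(cov)
    return eigenvalues, eigenvectors    # used to build the whitening conv filters
```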

But that index_select seems to be a bear when I cap memory at... I think 5 GB and 6 GB, using the 'set_per_process_memory_fraction' command for debugging purposes. So that's not quite enough to close this yet, and I don't have any good, clean solutions in mind as of yet... ;'(((((
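(For anyone reproducing that debugging setup, a minimal sketch of capping the allocator to simulate a ~5 GB card on a bigger GPU:)

```python
import torch

total_gb = torch.cuda.get_device_properties(0).total_memory / 2**30
torch.cuda.set_per_process_memory_fraction(5.0 / total_gb, device=0)  # cap at ~5 GB
```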

bonlime commented 1 year ago

hey! found your project and your results are beyond imagination! I hadn't thought we could already be this fast on CIFAR

looking at your profiling above, it seems you may benefit from jit.scripting the model, or from torch.compile once torch 2.0 becomes stable enough, to avoid rebuilding the graph every time. or maybe you have already experimented with it?

tysam-code commented 1 year ago

Hello @bonlime! Great to have feedback, thank you for the compliment. I watch a lot of the cloning numbers and such with curiosity wondering what people are doing with it and/or think about it (and what the hang-ups are, lol). So thank you so much! It's pretty crazy, and I think we got some more juice to squeeze out of the performance lemon of this network! :D

As far as scripting goes -- yes, in this most recent update you can actually JIT without major issue so long as you turn the EMA off, which is okay for short runs as there's actually a bug in how the Lookahead optimizer (via the EMA) currently works. I get a pretty decent boost in performance, just with net = torch.jit.script(net) right after the net gets created via the make_net() function! :D

There's a lot of slowness from kernel launch overheads right now, and of course the little elementwise operations add up far too quickly for comfort. I'm looking forward to torch.compile once 2.0 is out (unfortunately Colab does not seem to properly support it yet, but once they upgrade some of their GPU drivers and such, hopefully we'll be in business for that! :D)

One downside is that JIT adds a lot of complexity and rigidity to the code, and torch.compile is likely to increase the bug surface by a good bit as well. So as this is in a few senses a researcher's super-rapid-development toolkit first, and world record second (oddly enough), I'm going to try to delay that as much as possible, since it means that every decision I make after that point will somewhat require one of those two operations in the loop. If need be, we can sort of save it until the final stretch or until absolutely necessary, just to keep the pace of development high.

I truly believe that the equation 1./(time_to_develop_a_new_feature * time_to_test_a_new_feature) is a really good loss function to optimize in research. Even though the space is getting narrower here, I'm able to find more possibilities more quickly because the code runs much faster, which offsets some of the difficulty of development a bit. It honestly sorta reminds me of https://www.youtube.com/watch?v=ffUnNaQTfZE, though there's clearly some limiting factor here, lol.

That's really funny you ask about the JIT, though; I just tried it tonight after seeing George Hotz using this as a benchmark network for the tinygrad library that he's working on. You should try it! ;D If you have any questions or need any help, feel free to ping/DM me on twitter at twitter.com/hi_tysam

Also, I should update this thread proper on the memory side -- we've bumped our minimum from 6 to 6.5 GB, I'm assuming because of the batchsize increase, but I think that's still reasonable. There's probably still some good ways to clean that up a bit! Maybe I'll post some more profiling stuff here sooner/later :D

bonlime commented 1 year ago

thanks for such a large response. I clearly see your concerns about torch.compile introducing some new bugs, but if you make it a flag, then it would be pretty easy to turn on/off and make sure it doesn't break anything. at least this is the way I would implement it. also I would say jit.script doesn't bring that much complexity; it just requires type-annotating everything, which would only make the code better and is great in any case.
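(A sketch of the flag idea; the flag names and defaults here are assumptions, not the repo's actual config, and make_net() is the constructor mentioned above:)

```python
import torch

USE_TORCH_COMPILE = False  # flip on once a torch >= 2.0 runtime is available
USE_JIT_SCRIPT = False

net = make_net()  # assumed constructor from main.py
if USE_TORCH_COMPILE:
    net = torch.compile(net)
elif USE_JIT_SCRIPT:
    net = torch.jit.script(net)
```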

another suggestion: it seems you're currently using a slower version of the optimizer. you could try passing foreach=True to merge all the kernel launches. it would increase the memory requirement a little bit, but imo this is worth it
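(A sketch of what that looks like; the SGD hyperparameters are placeholders, not the repo's tuned values:)

```python
import torch

# The multi-tensor ("foreach") path batches the per-parameter update kernels
# into a few fused launches instead of one launch per parameter tensor.
opt = torch.optim.SGD(net.parameters(), lr=0.1, momentum=0.85, nesterov=True,
                      weight_decay=5e-4, foreach=True)
```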

tysam-code commented 1 year ago

Oh gosh, I should take a look at the foreach option. Thank you for saying that! Last time I turned it on, I got a huge accuracy drop and I wasn't sure why, but the speed improvements were pretty crazy. I think I'll probably take a look at this next. Fingers crossed!

Thanks for bringing that up, I'd played with it once upon a time, and then it faded into the sands of time. Lots of great thoughts, much appreciated! :D