kuprel / min-dalle

min(DALL·E) is a fast, minimal port of DALL·E Mini to PyTorch
MIT License
3.48k stars · 257 forks

High memory usage of torch compared to flax #11

Closed · phhusson closed this 2 years ago

phhusson commented 2 years ago

Hi,

Thanks for this project, you gave me hope to give my AMD GPU some fun :-) (torch is much easier to get to work on it than XLA)

That being said, it doesn't work yet, because the torch variant eats much more RAM than the flax one:

/usr/bin/time -v python image_from_text.py --text='a comfy chair that looks like an avocado' --torch --seed=4
        Maximum resident set size (kbytes): 21.096.396

/usr/bin/time -v python image_from_text.py --text='a comfy chair that looks like an avocado' --seed=4
        Maximum resident set size (kbytes): 4.152.744

Would you have some pointers as to how to optimize that?

haydn-jones commented 2 years ago

IIRC the torch version does not turn off gradient tracking; throwing @torch.no_grad() over the gen function should help a lot.
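A minimal sketch of what this suggestion looks like, using a hypothetical stand-in model rather than min-dalle's actual generation function: the @torch.no_grad() decorator stops PyTorch from recording an autograd graph inside the function, which is the bookkeeping that inflates memory during pure inference.

```python
import torch

# Hypothetical stand-in for the generation model; the real project
# uses its own gen function, not a single Linear layer.
model = torch.nn.Linear(4, 4)

@torch.no_grad()
def generate(x):
    # Inside this function no gradient history is recorded,
    # so intermediate activations can be freed immediately.
    return model(x)

out = generate(torch.randn(1, 4))
print(out.requires_grad)  # → False: the output is detached from autograd
```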

phhusson commented 2 years ago

Maximum resident set size (kbytes): 4013488

Yup, that was it. Thanks. I can run inference with both the mega and non-mega models on my AMD RX 6700 XT (12 GB). Mega infers in 45s, non-mega in 16s.

My local patch for GPU inference is attached (rename .csv to .patch)

min-dalle.csv

kuprel commented 2 years ago

I added torch.set_grad_enabled(False) and it seems to have significantly reduced memory usage with torch
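For illustration, a sketch of this global alternative (assuming a toy model in place of the real one): torch.set_grad_enabled(False) disables gradient tracking process-wide, so every subsequent forward pass skips autograd bookkeeping without needing a decorator on each function.

```python
import torch

# Disable gradient tracking globally for this inference-only process.
torch.set_grad_enabled(False)

# Hypothetical stand-in model; the real project loads DALL·E Mini weights.
model = torch.nn.Linear(4, 4)
y = model(torch.randn(1, 4))
print(y.requires_grad)  # → False: no autograd graph was built
```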