google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0
5.84k stars, 620 forks

Higher memory consumption compared with PyTorch #3766

Open Sun-Xiaohui opened 4 months ago

Sun-Xiaohui commented 4 months ago

Hello, Flax team. While porting PyTorch models to Flax, I found that Flax consumes more GPU memory than PyTorch. For example, a ResNet50 model uses 4 GB of GPU memory in PyTorch but 6 GB in Flax. What causes this difference in memory consumption, and what can I do to reduce memory usage in Flax? Thanks!

Sun-Xiaohui commented 3 months ago

@chiamp Hello, is this the expected behavior, or did I make a mistake? Thanks.

chiamp commented 2 months ago

Hi @Sun-Xiaohui, how are you transferring over the model? Could you provide an example code snippet?
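One possible contributor, not confirmed anywhere in this thread, is JAX's GPU allocator: by default JAX preallocates a large fraction of GPU memory when the first operation runs, so the figure reported by tools such as `nvidia-smi` can overstate what the model actually needs. A minimal sketch for ruling that out before comparing numbers, assuming a GPU build of JAX. The helper name `configure_jax_gpu_memory` is hypothetical; the two `XLA_PYTHON_CLIENT_*` environment variables are the ones described in the JAX GPU memory allocation documentation.

```python
import os


def configure_jax_gpu_memory(preallocate=False, mem_fraction=None):
    """Hypothetical helper: set JAX's GPU allocator env vars.

    Call this BEFORE the first `import jax`; the variables are read when
    the GPU backend is initialized, so setting them later has no effect.
    """
    # "false" makes JAX allocate GPU memory on demand instead of up front.
    os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "true" if preallocate else "false"
    if mem_fraction is not None:
        # Only relevant when preallocation is on: fraction of total GPU
        # memory to reserve, e.g. 0.50.
        os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = f"{mem_fraction:.2f}"
    # Return the resulting allocator settings for logging or inspection.
    return {k: v for k, v in os.environ.items()
            if k.startswith("XLA_PYTHON_CLIENT_")}


settings = configure_jax_gpu_memory(preallocate=False)
print(settings)
# import jax  # import only after the variables are set
```

With preallocation disabled, the memory reported for the Flax run should track actual allocations more closely, which makes the comparison with PyTorch fairer. If the gap remains, the cause lies elsewhere (for example in how the model was ported), and a reproduction snippet like the one requested above is still needed.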