Hello, Flax team. When I tried to port PyTorch models to Flax, I found that Flax consumes more memory than PyTorch. For example, a ResNet50 model uses 4 GB of GPU memory in PyTorch but 6 GB in Flax. What causes this difference in memory consumption between PyTorch and Flax? Or, what can I do to reduce memory usage in Flax? Thanks!