google / flax

Flax is a neural network library for JAX that is designed for flexibility.
https://flax.readthedocs.io
Apache License 2.0

PixelCNN++ example flaky on TPUs, probably due to bfloat16 #458

Closed: j-towns closed this issue 2 years ago

j-towns commented 4 years ago

The PixelCNN++ example should reach a test loss of 2.92, but we're finding that this doesn't happen consistently. @marcvanzee is in the process of doing more rigorous testing, running training many times with different random seeds on different versions of the model, with and without JAX omnistaging enabled. The data we have so far (see below) is limited, but we expect to follow up soon with more detail and hopefully identify whether this problem is real or just bad luck. For now we're focusing on GPU and will look at TPU afterwards.

Training run with

python train.py --batch_size=320 --num_epochs=1800 

Sheet with run statistics: https://docs.google.com/spreadsheets/d/1IDtAKUTHr6MTynKvl5szsepb9_ta-arZU-3yvyj6KQs/edit?usp=sharing

Update by @marcvanzee (Oct 29, 2020)

Performance on GPU is now as expected on Linen. Jamie is investigating TPUs and is still seeing some unexpected behavior, so I will assign the issue to him.

Update by @j-towns (Feb 1, 2020)

I think the discrepancy between GPU and TPU is likely caused by bfloat16 vs float32 in the conv ops. We should test this hypothesis by running on TPU with precision=lax.Precision.HIGHEST, and potentially add a command-line argument that lets the user choose which precision setting to use.
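
For reference, a minimal sketch of what requesting full float32 precision in a Flax conv layer could look like; the `ConvBlock` module here is purely illustrative and not part of the PixelCNN++ example:

```python
import jax
import jax.numpy as jnp
from jax import lax
from flax import linen as nn

class ConvBlock(nn.Module):
  """Illustrative conv layer that requests full float32 precision."""
  features: int = 64

  @nn.compact
  def __call__(self, x):
    # On TPU, the default matmul/conv precision may use reduced (bfloat16)
    # precision for the inputs; Precision.HIGHEST asks XLA for full float32.
    return nn.Conv(
        features=self.features,
        kernel_size=(3, 3),
        precision=lax.Precision.HIGHEST,
    )(x)

x = jnp.ones((1, 32, 32, 3))
variables = ConvBlock().init(jax.random.PRNGKey(0), x)
y = ConvBlock().apply(variables, x)
```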

avital commented 4 years ago

@marcvanzee suggested that this may simply be due to variance. Marc can run the old and new versions 3 times each to test this hypothesis.

marcvanzee commented 4 years ago

Given that @j-towns has already completed one Linen run and two non-Linen runs, I will do three more runs so that we have three runs each. Each run takes about a day, so if all experiments go well I should be able to report results early next week. Depending on those results we can decide what experiments to do next to narrow down the problem.

marcvanzee commented 4 years ago

Update: I did three runs, but I wasn't using different random seeds. @jheek suggested trying this, so I will re-run a few more experiments in the coming week and report back once I have the numbers.

marcvanzee commented 4 years ago

After doing a few more runs in October, it seems the PixelCNN++ example in Linen is now performing just as well as the old one. I've added the statistics for all runs to the sheet linked in the issue description.

To be honest, I am not entirely sure what was causing the problem. Maybe it was some underlying bug in JAX related to omnistaging that has since been fixed, but given that the model is now working as expected, I am closing this issue.

marcvanzee commented 4 years ago

On second thought, @j-towns told me he was still experimenting with TPUs and seeing some strange behavior there. Since he is debugging that now, I will reopen the issue, rename it, and assign it to him.

marcvanzee commented 4 years ago

From @j-towns in personal communication: "i've closed it for now because i basically don't have time to run experiments on TPU. I guess if it's working fine on GPU then any bug that still exists is unlikely to be in Flax but more likely in JAX or XLA. so i'm not too worried about it."

j-towns commented 3 years ago

Based on some other generative modelling work I've been doing on TPU lately, it seems the precision parameter of layers like Conv makes a small but noticeable difference to training stability and test performance. It might be worth finding out whether this affects PixelCNN++, and perhaps adding a command-line argument to enable higher precision.
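
A rough sketch of such a flag, using absl flags as the Flax examples do; the flag name and helper are hypothetical, not something the example currently ships:

```python
from absl import flags
from jax import lax

# Hypothetical --precision flag for train.py (illustrative only).
flags.DEFINE_enum(
    'precision', 'default', ['default', 'high', 'highest'],
    'Matmul/conv precision requested from XLA (mainly relevant on TPU).')

_PRECISIONS = {
    'default': lax.Precision.DEFAULT,
    'high': lax.Precision.HIGH,
    'highest': lax.Precision.HIGHEST,
}

def get_precision() -> lax.Precision:
  """Map the flag value to the lax.Precision enum passed to nn.Conv."""
  return _PRECISIONS[flags.FLAGS.precision]
```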

marcvanzee commented 2 years ago

Closing this since we've decided to drop the PixelCNN++ example. @j-towns please let us know when you move it to a personal repo, then we will make sure to link to it from here!