shindavid / AlphaZeroArcade


Stochastic weight averaging #28

Open shindavid opened 1 year ago

shindavid commented 1 year ago

The KataGo paper has the following:

> Every roughly 250,000 training samples, a snapshot of the weights is saved, and every four snapshots, a new candidate neural net is produced by taking an exponential moving average of snapshots with decay = 0.75 (averaging four snapshots of lookback)

Implement this and run experiments that demonstrate its value. If the experiments are inconclusive, ask David Wu.
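
For reference, a minimal sketch of what that snapshot EMA could look like, operating on saved `state_dict`s. The function name and the decision to copy (rather than average) integer buffers are my own choices; the paper only specifies the decay and the number of snapshots.

```python
import copy

import torch


def ema_of_snapshots(snapshots, decay=0.75):
    """Combine consecutive weight snapshots into one candidate net.

    `snapshots` is a list of state_dicts ordered oldest-to-newest; decay=0.75
    with four snapshots mirrors the scheme quoted above.
    """
    avg = copy.deepcopy(snapshots[0])
    for snap in snapshots[1:]:
        for key in avg:
            if torch.is_floating_point(avg[key]):
                # Standard EMA step: new_avg = d * old_avg + (1 - d) * newest.
                avg[key] = decay * avg[key] + (1.0 - decay) * snap[key]
            else:
                # Integer buffers (e.g. BatchNorm's num_batches_tracked) are
                # taken from the newest snapshot rather than averaged.
                avg[key] = snap[key].clone()
    return avg
```

The candidate net would then be produced by loading the result, e.g. `candidate.load_state_dict(ema_of_snapshots(last_four_snapshots))`.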

shindavid commented 10 months ago

I experimented a bit with SWA, using an EMA with a constant learning rate rather than the exact KataGo methodology. My work is in the `swa` branch.

On a per-generation basis, this did not make learning worse, but it wasn't clearly better either. Measured by overall runtime, however, it was clearly worse: each generation took significantly longer because I had to make an extra pass through the dataset to recalibrate the network's batch-normalization layers (via `torch.optim.swa_utils.update_bn()`). Without the `update_bn()` call, network quality was quite clearly worse.
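
For concreteness, here is a self-contained sketch of the kind of loop I mean, with a toy network standing in for the real one. The network, data, decay value, and number of generations are all illustrative.

```python
import torch
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel, update_bn
from torch.utils.data import DataLoader, TensorDataset

# Toy network with a batch-norm layer, standing in for the real net.
model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 1))
opt = torch.optim.SGD(model.parameters(), lr=1e-2)
loss_fn = nn.MSELoss()

data = TensorDataset(torch.randn(256, 8), torch.randn(256, 1))
loader = DataLoader(data, batch_size=32)

decay = 0.75  # illustrative; the swa branch used a constant-learning-rate EMA
ema_model = AveragedModel(
    model,
    avg_fn=lambda ema_p, p, num_averaged: decay * ema_p + (1.0 - decay) * p,
)

for generation in range(4):
    for x, y in loader:
        opt.zero_grad()
        loss_fn(model(x), y).backward()
        opt.step()
    # Fold the current weights into the running EMA once per generation.
    ema_model.update_parameters(model)

# The extra pass described above: recompute the batch-norm running statistics
# so they match the averaged weights instead of the stale copies taken when
# ema_model was constructed.  Skipping this clearly hurt network quality.
update_bn(loader, ema_model)
```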

I think I can avoid the separate `update_bn()` call by maintaining the batch-normalization layer stats "online" during the main forward pass. However, the lack of clear improvement even on a per-generation basis discouraged me from going further down this path.
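
One possible shortcut, if we stay within `torch.optim.swa_utils`: recent PyTorch versions accept `use_buffers=True` on `AveragedModel`, which folds the batch-norm buffers into the same averaging step as the weights and, per the PyTorch docs, removes the need for a separate `update_bn()` pass. It averages the buffers rather than re-estimating them during the main forward pass, so it is not exactly the "online" scheme above, and I haven't checked how a custom `avg_fn` interacts with integer buffers like `num_batches_tracked`, so treat this as a starting point only.

```python
import torch.nn as nn
from torch.optim.swa_utils import AveragedModel

# Any network with batch-norm layers; same toy model as above.
model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 1))

decay = 0.75  # illustrative value

# use_buffers=True averages the BN running statistics (buffers) alongside the
# weights on every update_parameters() call, avoiding the extra dataset pass.
ema_model = AveragedModel(
    model,
    avg_fn=lambda ema_p, p, num_averaged: decay * ema_p + (1.0 - decay) * p,
    use_buffers=True,
)
```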

It's quite possible that one of the following is true:

  1. I have a conceptual misunderstanding.
  2. I have a bug.
  3. Experimenting further with parameters or with non-constant learning rate schedules would work (see the sketch below).
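
On point 3: `torch.optim.swa_utils` also ships `SWALR`, a scheduler that anneals the learning rate toward a fixed SWA learning rate and then holds it there, intended to be stepped once averaging has begun. A minimal sketch, with arbitrary `swa_lr` and `anneal_epochs` values:

```python
import torch
import torch.nn as nn
from torch.optim.swa_utils import SWALR

model = nn.Linear(8, 1)  # stand-in for the real network
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Anneal the LR toward swa_lr over anneal_epochs, then hold it constant.
swa_scheduler = SWALR(optimizer, swa_lr=0.05, anneal_epochs=5, anneal_strategy="cos")

for epoch in range(10):
    # ... usual training steps on `model`, then fold its weights into the EMA ...
    swa_scheduler.step()
```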