USGS-R / river-dl

Deep learning model for predicting environmental variables on river systems

Proposed change to weight initialization #139

Closed. SimonTopp closed this issue 2 years ago.

SimonTopp commented 2 years ago

Currently we initialize the weights with a normal distribution centered at 0 and a std of 0.02.

https://github.com/USGS-R/river-dl/blob/2b068418159fe6fcbcf9c79271872a8d45091345/river_dl/RGCN.py#L36-L39

In experimenting with the GW loss function, I found that this initialization limits the amount that these weights can update over the course of training (orange-ish distributions below, where x is weight value, y is epoch, and z is density). This limited updating of the weights can lead to vanishing gradients. I tried switching our initialization to Xavier initialization and found that it leads to a wider distribution of weights and more variance in the weights over time (blue plots below).

[Figure: density distributions of weight values by epoch for the old (orange) and Xavier (blue) initializations]
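For illustration, the swap amounts to something like this (a rough sketch, not the exact RGCN.py code; the layer sizes are placeholders):

```python
import tensorflow as tf

# Placeholder layer sizes for illustration only; RGCN.py uses its own
# input and hidden-state dimensions.
n_in, n_out = 20, 64

# Current scheme: normal distribution centered at 0 with a fixed std of 0.02,
# independent of layer size.
old_init = tf.random_normal_initializer(mean=0.0, stddev=0.02)
W_old = tf.Variable(old_init(shape=(n_in, n_out)), name="W_old")

# Proposed scheme: Xavier/Glorot initialization, which scales the variance by
# fan-in and fan-out (std = sqrt(2 / (fan_in + fan_out)) for the normal
# variant), so the starting weights get a wider, layer-size-aware spread.
xavier_init = tf.keras.initializers.GlorotNormal()
W_new = tf.Variable(xavier_init(shape=(n_in, n_out)), name="W_new")

# The Xavier weights start with a noticeably larger spread than std = 0.02.
print(float(tf.math.reduce_std(W_old)), float(tf.math.reduce_std(W_new)))
```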

Similarly, with the new initialization we see a consistent increase in model performance throughout training (blue = new initialization, orange = old initialization).

[Figure: model performance by epoch for the new (blue) and old (orange) initializations]

What do you all think? Based on this it makes sense to me to update the default initialization scheme.

jzwart commented 2 years ago

seems like a good improvement to me. Also pretty figures 😍

jsadler2 commented 2 years ago

wow. awesome. so my takeaway is that a more diverse set of starting weights results in a more accurate model. is that your interpretation too?

also - is the plot above the training loss or the testing loss?

jsadler2 commented 2 years ago

also also - I agree with Jake. very nice figs. are those ggplot?

jdiaz4302 commented 2 years ago

I also think it'd be a good idea to swap to this. I wouldn't expect any issues because it is a standard alternative, and it appears to solve the problem it was designed to solve.

jdiaz4302 commented 2 years ago

Somewhat of an aside, but one thing I notice from your loss-vs-epoch curves is that they don't have the prominent plateau that is stereotypical of these curves. There could definitely be good reasons for this (e.g., early stopping to avoid overfitting like you mentioned in #135, or you've experimented and seen that the plateau generally occurs around 100 epochs so you're avoiding further runtime), but if that was simply overlooked, there could be some further performance gains to see 😃
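If it was overlooked, something like this would be one way to watch for the plateau (a rough sketch assuming a standard Keras model.fit() loop; the actual training code and the early-stopping work in #135 may look different):

```python
import tensorflow as tf

# Stop training once validation loss stops improving, rather than picking a
# fixed epoch count ahead of time. The variable names here are hypothetical.
early_stop = tf.keras.callbacks.EarlyStopping(
    monitor="val_loss",          # watch the validation loss
    patience=10,                 # tolerate 10 epochs without improvement
    restore_best_weights=True,   # roll back to the best epoch's weights
)

# model.fit(x_train, y_train, validation_data=(x_val, y_val),
#           epochs=300, callbacks=[early_stop])
```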

SimonTopp commented 2 years ago

Sounds good. I'll plan to incorporate this into my next PR, and @jdiaz4302, you're totally right. We typically train for >100 epochs, but I don't think we've done a good job of going in and actually looking at when our training plateaus. Definitely an area where we should be more aware, and something that I think #135 will make easier to take into account.

SimonTopp commented 2 years ago

Changed in #141