Equim-chan / Mortal

🚀🀄️ A fast and strong AI for riichi mahjong, powered by Rust and deep reinforcement learning.
https://mortal.ekyu.moe
GNU Affero General Public License v3.0

Add data validation for torch.distributions.Normal() #15

Closed hyskylord closed 2 years ago

hyskylord commented 2 years ago

This PR resolves an occasional ValueError raised during offline training.

wongsingfo commented 2 years ago

Hi hyskylord, I am curious about why this error happens. There seems to be a root cause elsewhere (e.g. a bug in the libriichi library or invalid training data). It would be helpful if you could provide more details on how to reproduce the error.

hyskylord commented 2 years ago

@wongsingfo I think it is caused by some unusual training data, but I cannot provide a reproducible example (it only happens about once every 100k batches). I don't know why I sometimes get mu=NaN, but it is quite reasonable to expect logsig.exp() to contain some zeros if logsig has some sufficiently negative terms (less than -1000 or so).
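The underflow described above can be reproduced in isolation. The following is a minimal sketch, not code from this PR: the tensor values are made up for illustration, and `validate_args=True` is passed explicitly so the check does not depend on the PyTorch version's default.

```python
import torch
from torch.distributions import Normal

# Hypothetical log-standard-deviation outputs; the last one is very negative.
logsig = torch.tensor([0.0, -50.0, -1000.0])
sigma = logsig.exp()  # in float32, exp(-1000) underflows to exactly 0.0

mu = torch.zeros(3)
raised = False
try:
    # With validation on, Normal requires scale to be strictly positive.
    Normal(mu, sigma, validate_args=True)
except ValueError as err:
    raised = True
    print("ValueError:", str(err).splitlines()[0])

print(sigma)
```

Only the last entry collapses to zero here, which is consistent with the error showing up rarely and only for particular batches.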

Equim-chan commented 2 years ago

NaN is caused by numerical instability, which is not uncommon in deep learning, especially when AMP (automatic mixed precision) is on. I prefer letting the error be raised as is, because sometimes it implies the model itself is malformed (the weights have become NaN) and a rollback is required.
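One way to tell whether the weights themselves have gone bad is to scan the parameters for non-finite values once the error has been raised. This is a rough sketch under assumptions: `has_nonfinite_params` is a hypothetical helper, and `nn.Linear` stands in for the real network. It is not part of Mortal.

```python
import torch
from torch import nn

def has_nonfinite_params(model: nn.Module) -> bool:
    """Return True if any parameter contains NaN or +/-inf."""
    return any(not torch.isfinite(p).all().item() for p in model.parameters())

model = nn.Linear(4, 2)  # stand-in for the real network
healthy_before = not has_nonfinite_params(model)

with torch.no_grad():
    model.weight[0, 0] = float("nan")  # simulate a corrupted weight

needs_rollback = has_nonfinite_params(model)
print(healthy_before, needs_rollback)
```

If the check returns True, restoring an earlier checkpoint is probably the only sensible option, since clamping the inputs to `Normal` would just hide the problem.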

As for the second issue, logsig.exp() collapsing to 0, it is better to just add a small constant eps (e.g. 1e-6) to it so that it always stays positive. In fact, such a mechanism is already implemented in Mortal v2, but has not been ported to v1 (this repo) yet.
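The eps approach could look roughly like the sketch below. It is an illustration only and not the Mortal v2 implementation: `make_normal` and `EPS` are hypothetical names, and the input tensors are made up.

```python
import torch
from torch.distributions import Normal

EPS = 1e-6  # the constant suggested above; the best value may differ

def make_normal(mu: torch.Tensor, logsig: torch.Tensor, eps: float = EPS) -> Normal:
    # Adding eps keeps the scale strictly positive even if exp() underflows.
    sigma = logsig.exp() + eps
    return Normal(mu, sigma, validate_args=True)

mu = torch.zeros(3)
logsig = torch.tensor([0.0, -50.0, -1000.0])  # last entry underflows in exp()
dist = make_normal(mu, logsig)
log_prob = dist.log_prob(torch.zeros(3))
print(dist.scale, log_prob)

# A NaN mean is deliberately left alone, so it still raises ValueError.
nan_raised = False
try:
    make_normal(torch.tensor([float("nan")]), torch.tensor([0.0]))
except ValueError:
    nan_raised = True
```

With this arrangement the zero-scale case no longer fails, while a NaN `mu` still surfaces as an error so that a malformed model is not silently trained further.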