ctallec / world-models

Reimplementation of World Models (Ha and Schmidhuber, 2018) in PyTorch
MIT License

MDRNN losses extremely low due to numerical instability? #39

Open · parthjaggi opened this issue 3 years ago

parthjaggi commented 3 years ago

During MDRNN training, the total loss and the GMM loss drop abruptly to extremely large negative values, even with gradient clipping enabled. Was this observed when the repo was originally tested, or is it a result of more recent PyTorch versions? The issue persists with a higher-precision PyTorch configuration as well. Training logs follow; a possible mitigation is sketched after them.

Epoch 0: 2912it [00:18, 158.02it/s, loss=-7490883896016783802368.000000 bce=  0.022669 gmm=-7724973763404514721792.000000 mse=  0.000000]                                                

Epoch 0: 100%|██████████████████████████████| 1936/1936 [00:12<00:00, 157.29it/s, loss=-13451842733942434168832.000000 bce=  0.000828 gmm=-13872212352094781308928.000000 mse=  0.000000]

Epoch 1: 2912it [00:18, 157.59it/s, loss=-16901104332652949798912.000000 bce=  0.000793 gmm=-17429263292277607890944.000000 mse=  0.000000]                                              

Epoch 1: 100%|██████████████████████████████| 1936/1936 [00:12<00:00, 156.85it/s, loss=-19335289690015750160384.000000 bce=  0.000749 gmm=-19939516790304420134912.000000 mse=  0.000000]

Epoch 2: 2912it [00:18, 157.39it/s, loss=-20089711310459944042496.000000 bce=  0.000734 gmm=-20717514125435083948032.000000 mse=  0.000000]                                              

Epoch 2: 100%|███████████████████████████████| 1936/1936 [01:09<00:00, 27.85it/s, loss=-20316329081654105604096.000000 bce=  0.000709 gmm=-20951213785059046719488.000000 mse=  0.000000]
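For context, one plausible explanation: the GMM negative log-likelihood is unbounded below, so if any predicted component standard deviation collapses toward zero while a target sits near that component's mean, the log-density (and hence the negative loss) explodes. Gradient clipping bounds the update size but not the loss value itself. Below is a minimal sketch of a guarded loss, assuming the MDRNN parameterizes scales as exp(logsigma) as models/mdrnn.py appears to; the clamp bounds and the gmm_loss_clamped name are hypothetical, not part of the repo.

```python
import torch
from torch.distributions.normal import Normal

# Hypothetical bounds on log sigma; tune for your latent scale.
LOG_SIGMA_MIN, LOG_SIGMA_MAX = -5.0, 2.0

def gmm_loss_clamped(batch, mus, logsigmas, logpi):
    """Negative GMM log-likelihood with a floor on component scales.

    batch:     (seq_len, bs, fs)           target latents
    mus:       (seq_len, bs, n_gauss, fs)  component means
    logsigmas: (seq_len, bs, n_gauss, fs)  raw log standard deviations
    logpi:     (seq_len, bs, n_gauss)      log mixture weights
    """
    # Clamping logsigma keeps sigma bounded away from 0, so the
    # log-density cannot blow up and drive the loss toward -inf.
    logsigmas = logsigmas.clamp(LOG_SIGMA_MIN, LOG_SIGMA_MAX)
    sigmas = torch.exp(logsigmas)

    # Broadcast targets over the mixture dimension.
    batch = batch.unsqueeze(-2)  # (seq_len, bs, 1, fs)
    log_probs = Normal(mus, sigmas).log_prob(batch).sum(dim=-1)
    log_probs = logpi + log_probs
    return -torch.logsumexp(log_probs, dim=-1).mean()
```

With sigma floored at exp(-5) ≈ 6.7e-3, the per-dimension log-density is capped near -log(sigma) - 0.5·log(2π) ≈ 4.1, so the loss can no longer run off toward values like -1e22 as in the logs above.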