google / trax

Trax — Deep Learning with Clear Code and Speed
Apache License 2.0

[BUG] Incorrect readings of cross-entropy loss in Trax 1.3.7 #1328

Open qiwen98 opened 3 years ago

qiwen98 commented 3 years ago

Description

The cross-entropy loss in Trax 1.3.7 gives strange readings (screenshot 1). Is this normal?

The same code, run in Trax 1.3.6, gave this result (screenshot 2).

The loss layer and evaluation metrics used are both tl.CrossEntropyLoss().

JEF1056 commented 3 years ago

Looks like tl.CrossEntropyLoss() is deprecated: https://trax-ml.readthedocs.io/en/latest/trax.layers.html#trax.layers.metrics.CrossEntropyLoss. However, tl.WeightedCategoryCrossEntropy() might also have a memory leak (at least on TPU), so maybe hold off on switching to Trax 1.3.7 for now?
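
One way to judge whether a reading is "strange", independent of the Trax version, is to compare it against a hand-computed reference. The sketch below is plain Python and does not call Trax. The function name reference_cross_entropy is hypothetical, and it assumes the loss layer receives unnormalized logits, integer class targets, and optional per-example weights, which may not match exactly how either Trax layer reduces over a batch.

```python
import math


def reference_cross_entropy(logits, targets, weights=None):
    """Weighted mean of -log softmax(logits)[target] over a batch.

    Hypothetical helper for sanity checks; not part of Trax.
    logits:  list of rows, one row of unnormalized scores per example
    targets: list of integer class ids, one per example
    weights: optional list of per-example weights (defaults to all 1.0)
    """
    if weights is None:
        weights = [1.0] * len(targets)
    total = 0.0
    weight_sum = 0.0
    for row, target, w in zip(logits, targets, weights):
        # Log-sum-exp with the row maximum subtracted for numerical stability.
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))
        total += w * (log_z - row[target])
        weight_sum += w
    if weight_sum == 0.0:
        raise ValueError("weights sum to zero")
    return total / weight_sum


# A model that predicts uniformly over 4 classes should score ln(4).
uniform_loss = reference_cross_entropy([[0.0] * 4, [0.0] * 4], [1, 3])

# A confident, correct prediction should score close to zero.
confident_loss = reference_cross_entropy([[20.0, 0.0, 0.0, 0.0]], [0])
```

If the values reported during training differ from such a reference by a large constant factor, or sit far above ln(vocab_size) at the start of training, that would point to a change in how the loss is reduced or weighted rather than to the model itself.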