register norm_mean and norm_std as buffers so they get moved to the proper GPU along with the module (.cuda() puts them on the current device, cuda:0 by default, but with multiple GPUs the tensors they interact with aren't necessarily on that device)
switch normalization so we normalize the labels, not the logits
for targets whose stddev is 0 (i.e. all labels are the same), set std_dev to 1 instead of 1e-8 to reduce the chance of overflow
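A minimal PyTorch sketch of how the three changes above could fit together. It is not the project's actual code: the class name NormalizedRegressionLoss, computing the statistics from a tensor of training labels in the constructor, and the use of MSE as the loss are all assumptions made for illustration. Only the buffer names norm_mean and norm_std and the "stddev of 0 becomes 1" rule come from the notes.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class NormalizedRegressionLoss(nn.Module):
    """Hypothetical module illustrating the changes; not the real implementation."""

    def __init__(self, train_labels: torch.Tensor):
        super().__init__()
        mean = train_labels.mean(dim=0)
        std = train_labels.std(dim=0)
        # Targets where every label is identical have std == 0. Use 1 instead of
        # a tiny epsilon such as 1e-8 so dividing by std cannot blow up.
        std = torch.where(std == 0, torch.ones_like(std), std)
        # Buffers (unlike plain tensor attributes) follow the module through
        # .to(device) / .cuda(), are saved in state_dict(), and are copied to
        # each replica by nn.DataParallel.
        self.register_buffer("norm_mean", mean)
        self.register_buffer("norm_std", std)

    def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        # Normalize the labels, not the logits: the model is trained to
        # predict targets directly in normalized space.
        norm_labels = (labels - self.norm_mean) / self.norm_std
        return F.mse_loss(logits, norm_labels)  # MSE is an assumption

    def denormalize(self, logits: torch.Tensor) -> torch.Tensor:
        # Map normalized predictions back to the original label scale.
        return logits * self.norm_std + self.norm_mean


if __name__ == "__main__":
    labels = torch.tensor([[1.0, 5.0], [3.0, 5.0], [5.0, 5.0]])
    loss_fn = NormalizedRegressionLoss(labels)
    print(loss_fn.norm_std)  # second column is constant, so its std is 1
```

Because norm_mean and norm_std are buffers, wrapping the model in nn.DataParallel or calling .to("cuda:1") places them on the same device as each replica's inputs, which is what avoids the cross-device errors that tensors created with a bare .cuda() can cause.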