Since torch 1.10, torch.mean tries to infer the output dtype when none is supplied. However, calling it on an input of type long (e.g. a tensor of token ids) raises the following error:
RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Long
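A minimal sketch of the failure and two possible workarounds, assuming a small made-up tensor named `token_ids` (hypothetical, not taken from the original code). Either cast the input to a floating point dtype before reducing, or pass `dtype` to `torch.mean` explicitly:

```python
import torch

# Hypothetical stand-in for a batch of token ids; integer tensors default to torch.int64 (long).
token_ids = torch.tensor([[1, 2, 3, 4], [5, 6, 7, 8]])

# Calling mean() directly on a long tensor is expected to raise the RuntimeError quoted above.
try:
    torch.mean(token_ids)
except RuntimeError as err:
    print(f"mean() on a long tensor failed: {err}")

# Workaround 1: cast to a floating point dtype first.
mean_via_cast = token_ids.float().mean()

# Workaround 2: let torch.mean cast the input by passing dtype explicitly.
mean_via_dtype = torch.mean(token_ids, dtype=torch.float32)

print(mean_via_cast, mean_via_dtype)
```

Both variants return a float32 scalar tensor. Which one fits better depends on whether the float copy of the input is needed elsewhere.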