microsoft / mup

maximal update parametrization (µP)
https://arxiv.org/abs/2203.03466
MIT License

fix: dtype for newer torch versions #33

Closed. zanussbaum closed this 1 year ago.

zanussbaum commented 1 year ago

Since torch 1.10, torch.mean tries to infer the output dtype if none is supplied. However, if you call it on an input of type long (e.g., a tensor of token ids), you get the following error:

RuntimeError: mean(): could not infer output dtype. Input dtype must be either a floating point or complex dtype. Got: Long
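A minimal sketch of the failure and one possible workaround follows. It does not show which line of mup the PR changes. The helper name `safe_mean` is hypothetical, and casting integer inputs to `float32` is an assumed choice. torch.mean also documents a `dtype` keyword that casts the input before reducing, which is another way to apply the same fix.

```python
import torch


def safe_mean(t: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: average a tensor, casting non-float dtypes to float32 first.

    torch.mean only accepts floating point or complex inputs, so integer
    tensors such as token ids are cast before reducing.
    """
    if not (t.is_floating_point() or t.is_complex()):
        t = t.to(torch.float32)
    return t.mean()


# Token ids are stored as torch.int64 (Long), the dtype named in the error.
token_ids = torch.tensor([[101, 2054, 2003], [101, 7592, 102]])

# Calling mean() directly on the Long tensor reproduces the RuntimeError.
try:
    token_ids.mean()
    mean_failed = False
except RuntimeError as err:
    mean_failed = True
    print(f"RuntimeError: {err}")

# Casting first avoids the error and returns a float32 result.
avg = safe_mean(token_ids)
print(avg.dtype)  # torch.float32
```

Floating point inputs pass through unchanged, so their original precision (for example float64) is preserved. Only integer and boolean tensors are promoted.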

zanussbaum commented 1 year ago

@microsoft-github-policy-service agree