Some extra context on the integer add in aten:
import torch
from math import sqrt, tanh, erf, exp

def vgelu_pytorch(x: torch.Tensor):
    return 0.5 * x * (1 + torch.erf(x / sqrt(2)))

torch.jit.script(vgelu_pytorch).graph
In the extracted TorchScript, note %4 as an int:
graph(%x.1 : Tensor):
  %22 : float = prim::Constant[value=1.4142135623730951]()
  %1 : float = prim::Constant[value=0.5]() # <ipython-input-2-e1a969ca3175>:2:11
  %4 : int = prim::Constant[value=1]() # <ipython-input-2-e1a969ca3175>:2:22
  %3 : Tensor = aten::mul(%x.1, %1) # <string>:3:9
  %8 : Tensor = aten::div(%x.1, %22) # <ipython-input-2-e1a969ca3175>:2:36
  %9 : Tensor = aten::erf(%8) # <ipython-input-2-e1a969ca3175>:2:26
  %11 : Tensor = aten::add(%9, %4, %4) # <string>:5:9
  %12 : Tensor = aten::mul(%3, %11) # <ipython-input-2-e1a969ca3175>:2:11
  return (%12)
So C++ is going to merrily cast the int to a float implicitly in many circumstances, but I'm not sure whether the same is true throughout Knossos. Any thoughts @toelli-msft?
I'm not sure why Python is sending through a 1 as int rather than 1.0 float.
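For what it's worth, a quick eager-mode check (just a sketch, nothing Knossos-specific) suggests the int literal is promoted to the tensor's float dtype anyway:

import torch
from math import sqrt

x = torch.randn(4)
# In eager mode the Python int 1 is promoted to the tensor's float dtype,
# so the int and float spellings agree numerically.
a = 0.5 * x * (1 + torch.erf(x / sqrt(2)))
b = 0.5 * x * (1.0 + torch.erf(x / sqrt(2)))
assert torch.allclose(a, b)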
> I'm not sure why Python is sending through a 1 as int rather than 1.0 float.

Does it help if you change

def vgelu_pytorch(x: torch.Tensor):
    return 0.5 * x * (1 + torch.erf(x / sqrt(2)))

into

def vgelu_pytorch(x: torch.Tensor):
    return 0.5 * x * (1.0 + torch.erf(x / sqrt(2)))

?
> I'm not sure why Python is sending through a 1 as int rather than 1.0 float.

Seems correct to me. It's up to the optimizer (i.e. us) to optimize that away if helpful.
> I'm not sure why Python is sending through a 1 as int rather than 1.0 float.
>
> Does it help if you change
>
> def vgelu_pytorch(x: torch.Tensor): return 0.5 * x * (1 + torch.erf(x / sqrt(2)))
>
> into
>
> def vgelu_pytorch(x: torch.Tensor): return 0.5 * x * (1.0 + torch.erf(x / sqrt(2)))
>
> ?
Still passing an int 1 for alpha:

def vgelu_pytorch(x: torch.Tensor):
    return 0.5 * x * (1 + torch.erf(x / sqrt(2)))

def vgelu_pytorch_floaty(x: torch.Tensor):
    return 0.5 * x * (1.0 + torch.erf(x / sqrt(2)))

print(torch.jit.script(vgelu_pytorch).graph)
print(torch.jit.script(vgelu_pytorch_floaty).graph)
graph(%x.1 : Tensor):
  %22 : float = prim::Constant[value=1.4142135623730951]()
  %1 : float = prim::Constant[value=0.5]() #

graph(%x.1 : Tensor):
  %10 : int = prim::Constant[value=1]()
  %22 : float = prim::Constant[value=1.4142135623730951]()
  %1 : float = prim::Constant[value=0.5]() #
> I'm not sure why Python is sending through a 1 as int rather than 1.0 float.
>
> Seems correct to me. It's up to the optimizer (i.e. us) to optimize that away if helpful.

Agreed, it's up to us / the optimiser to change that if desired. I'm still curious whether this was a deliberate decision, e.g. so the case is easy to detect by testing against an int 1, or whether it gets coerced anyway.
@cgravill Agreed, and I have searched around for the place where this decision (alpha : int) is made, but with no success.
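One place to look from the Python side might be the registered operator schemas themselves; a hedged sketch, assuming the undocumented internal helper torch._C._jit_get_schemas_for_operator is available in the installed build:

import torch

# List the registered overloads of aten::add; the Tensor overload is
# declared with a default "Scalar alpha=1", which would explain where the
# int constant comes from.
for schema in torch._C._jit_get_schemas_for_operator("aten::add"):
    print(schema)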
Top level: I believe the new signature add(Tensor, Float, int) is valid and expected in aten, so this PR is good. The aten signature is add(Tensor, Scalar, Scalar), and ints are convertible to scalars. @nunoplopes will know for sure.
> Top level: I believe the new signature add(Tensor, Float, int) is valid and expected in aten, so this PR is good. The aten signature is add(Tensor, Scalar, Scalar), and ints are convertible to scalars. @nunoplopes will know for sure.
add(Tensor, Scalar, Scalar) is a valid overload, yes. And scalar can be any of bool, long, double, complexdouble.
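For completeness, a quick eager-mode check (a sketch, assuming a recent PyTorch build) that the int and float spellings of alpha resolve to the same Scalar parameter:

import torch

x = torch.ones(3)
y = torch.full((3,), 2.0)
# alpha is a Scalar in the aten schema, so an int 1 and a float 1.0 are
# both accepted and give the same result.
print(torch.add(x, y, alpha=1))    # int alpha
print(torch.add(x, y, alpha=1.0))  # float alpha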
~Also fixed the existing aten::add definition.~
Fixes #912