microsoft / knossos-ksc

Compiler with automatic differentiation

Added 1D broadcasting add, div, and erf #914

Closed: ryotatomioka closed this 3 years ago

ryotatomioka commented 3 years ago

~~Also fixed the existing aten::add definition.~~

Fixes #912
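
For context, here is a minimal reference sketch (my own illustration, not code from this PR) of the 1D broadcasting behaviour the new `add`, `div`, and `erf` definitions are meant to match, using PyTorch itself as the oracle:

```python
import math
import torch

# Reference behaviour the Knossos-side definitions should reproduce
# (illustrative only; the tensor values are made up for the example).
x = torch.tensor([0.5, -1.0, 2.0])

assert torch.allclose(torch.add(x, 1.0), x + 1.0)   # scalar broadcast over a 1D tensor
assert torch.allclose(torch.div(x, 2.0), x / 2.0)   # scalar broadcast division
assert torch.allclose(torch.erf(x),
                      torch.tensor([math.erf(v) for v in x.tolist()]))  # elementwise erf
```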

cgravill commented 3 years ago

Some extra context on the integer add in aten:

import torch
from math import sqrt, tanh, erf, exp
def vgelu_pytorch(x: torch.Tensor):
    return 0.5 * x * (1 + torch.erf(x / sqrt(2)))

torch.jit.script(vgelu_pytorch).graph

In the extracted TorchScript graph, note that `%4` is an int:

graph(%x.1 : Tensor):
  %22 : float = prim::Constant[value=1.4142135623730951]()
  %1 : float = prim::Constant[value=0.5]() # <ipython-input-2-e1a969ca3175>:2:11
  %4 : int = prim::Constant[value=1]() # <ipython-input-2-e1a969ca3175>:2:22
  %3 : Tensor = aten::mul(%x.1, %1) # <string>:3:9
  %8 : Tensor = aten::div(%x.1, %22) # <ipython-input-2-e1a969ca3175>:2:36
  %9 : Tensor = aten::erf(%8) # <ipython-input-2-e1a969ca3175>:2:26
  %11 : Tensor = aten::add(%9, %4, %4) # <string>:5:9
  %12 : Tensor = aten::mul(%3, %11) # <ipython-input-2-e1a969ca3175>:2:11
  return (%12)

so C++ will merrily cast the int to a float implicitly in many circumstances, but I'm not sure the same is true throughout Knossos. Any thoughts @toelli-msft?

I'm not sure why Python is sending through a 1 as int rather than 1.0 float.
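
For anyone following along: `aten::add(self, other, alpha)` computes `self + alpha * other`, so in `%11` above both `other` and `alpha` are the int constant `%4 = 1`. A quick check of my own (not project code) that the int arguments get promoted to the tensor's float dtype anyway:

```python
import torch

# aten::add(self, other, alpha) computes self + alpha * other; here both
# `other` and `alpha` are the int constant 1, so %11 is erf(x / sqrt(2)) + 1.
t = torch.randn(4)
assert torch.allclose(torch.add(t, 1, alpha=1), t + 1)
# The int arguments are promoted to the tensor's floating dtype.
assert torch.add(t, 1, alpha=1).dtype == torch.float32
```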

acl33 commented 3 years ago

> I'm not sure why Python is sending through a 1 as int rather than 1.0 float.

Does it help if you change

def vgelu_pytorch(x: torch.Tensor):
    return 0.5 * x * (1 + torch.erf(x / sqrt(2)))

into

def vgelu_pytorch(x: torch.Tensor):
    return 0.5 * x * (1.0 + torch.erf(x / sqrt(2)))

?

awf commented 3 years ago

> I'm not sure why Python is sending through a 1 as int rather than 1.0 float.

Seems correct to me. It's up to the optimizer (i.e. us) to optimize that away if helpful.
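
For what it's worth, folding the int constant to a float would be a numerical no-op for float tensors, so an optimiser is free to rewrite it. A quick sanity check (mine, not project code):

```python
import torch

# For a float tensor, add with an int `other`/`alpha` of 1 and add with a
# float 1.0 give bit-identical results, so rewriting the int constant to a
# float constant is a safe graph optimisation.
x = torch.randn(8)
assert torch.equal(torch.add(x, 1, alpha=1), torch.add(x, 1.0))
```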

cgravill commented 3 years ago

> I'm not sure why Python is sending through a 1 as int rather than 1.0 float.
>
> Does it help if you change
>
> def vgelu_pytorch(x: torch.Tensor):
>     return 0.5 * x * (1 + torch.erf(x / sqrt(2)))
>
> into
>
> def vgelu_pytorch(x: torch.Tensor):
>     return 0.5 * x * (1.0 + torch.erf(x / sqrt(2)))
>
> ?

Still passing an int 1 for alpha:

def vgelu_pytorch(x: torch.Tensor):
    return 0.5 * x * (1 + torch.erf(x / sqrt(2)))

def vgelu_pytorch_floaty(x: torch.Tensor):
    return 0.5 * x * (1.0 + torch.erf(x / sqrt(2)))

print(torch.jit.script(vgelu_pytorch).graph)

print(torch.jit.script(vgelu_pytorch_floaty).graph)

graph(%x.1 : Tensor):
  %22 : float = prim::Constant[value=1.4142135623730951]()
  %1 : float = prim::Constant[value=0.5]() # :2:11
  %4 : int = prim::Constant[value=1]() # :2:22
  %3 : Tensor = aten::mul(%x.1, %1) # :3:9
  %8 : Tensor = aten::div(%x.1, %22) # :2:36
  %9 : Tensor = aten::erf(%8) # :2:26
  %11 : Tensor = aten::add(%9, %4, %4) # :5:9
  %12 : Tensor = aten::mul(%3, %11) # :2:11
  return (%12)

graph(%x.1 : Tensor):
  %10 : int = prim::Constant[value=1]()
  %22 : float = prim::Constant[value=1.4142135623730951]()
  %1 : float = prim::Constant[value=0.5]() # :5:11
  %4 : float = prim::Constant[value=1.]() # :5:22
  %3 : Tensor = aten::mul(%x.1, %1) # :3:9
  %8 : Tensor = aten::div(%x.1, %22) # :5:38
  %9 : Tensor = aten::erf(%8) # :5:28
  %11 : Tensor = aten::add(%9, %4, %10) # :5:9
  %12 : Tensor = aten::mul(%3, %11) # :5:11
  return (%12)

cgravill commented 3 years ago

> > I'm not sure why Python is sending through a 1 as int rather than 1.0 float.
>
> Seems correct to me. It's up to the optimizer (i.e. us) to optimize that away if helpful.

Agreed, it's up to us / an optimiser to change it if desired. I'm still curious whether it was a deliberate decision, e.g. so this case can be detected easily by testing against an int 1, or whether it gets coerced anyway.

ryotatomioka commented 3 years ago

@cgravill Agreed. I have searched around for the place where the decision `alpha : int` is made, but with no success.
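
One place the int may come from (my guess, not confirmed here): the registered operator schema declares the default as `Scalar alpha=1`, and TorchScript materialises that literal as an int constant. A best-effort way to inspect the schemas, using the internal `torch._C._jit_get_all_schemas` API:

```python
import torch

# List every registered overload of aten::add and print its schema; the
# `Scalar alpha=1` default is where the int literal appears.
# _jit_get_all_schemas is an internal API, so treat this as a sketch.
add_schemas = [s for s in torch._C._jit_get_all_schemas() if s.name == "aten::add"]
for schema in add_schemas:
    print(schema)
# e.g. aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> (Tensor)
```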

awf commented 3 years ago

Top level: I believe the new signature `add(Tensor, Float, int)` is valid and expected in aten, so this PR is good.

The aten signature is `add(Tensor, Scalar, Scalar)`, and ints are convertible to scalars:

https://github.com/pytorch/pytorch/blob/5503a4ac6e0728b74e0945935740c20c89bf2016/c10/core/Scalar.h#L18

@nunoplopes will know for sure

nunoplopes commented 3 years ago

> Top level: I believe the new signature `add(Tensor, Float, int)` is valid and expected in aten, so this PR is good.
>
> The aten signature is `add(Tensor, Scalar, Scalar)`, and ints are convertible to scalars:
>
> https://github.com/pytorch/pytorch/blob/5503a4ac6e0728b74e0945935740c20c89bf2016/c10/core/Scalar.h#L18
>
> @nunoplopes will know for sure

`add(Tensor, Scalar, Scalar)` is a valid overload, yes. And a Scalar can be any of bool, long, double, or complex double.
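
To illustrate (my own hedged example, not from this thread), the various Python scalar types all route through that `Scalar` wrapper:

```python
import torch

x = torch.zeros(3)

# Each call below hits an add(Tensor, Scalar, Scalar)-style overload; bool,
# int (long) and float (double) values are all wrapped as Scalar. Complex is
# also a valid Scalar, but adding it requires a complex tensor dtype.
print(torch.add(x, True))        # bool scalar
print(torch.add(x, 1))           # long scalar
print(torch.add(x, 1.0))         # double scalar
print(torch.add(x, 2, alpha=3))  # both `other` and `alpha` as int scalars -> 6.0
```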