wangkuiyi / gotorch

A Go idiomatic binding to the C++ core of PyTorch
MIT License

`NewFunctional` only accepts functions with `func(torch.Tensor) torch.Tensor` type #117

Open QiJune opened 4 years ago

QiJune commented 4 years ago

`NewFunctional` only accepts functions of type `func(torch.Tensor) torch.Tensor`.

For example, we can write:

```go
nn.NewFunctional(torch.Tanh)
```
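To make the constraint concrete, here is a minimal, self-contained sketch. It does not use gotorch. `Tensor`, `Functional`, `NewFunctional`, `Forward` and `Tanh` are hypothetical stand-ins that assume `NewFunctional` simply stores a `func(Tensor) Tensor` and applies it to the input. Any function with exactly that shape, such as `Tanh`, can be passed in directly.

```go
package main

import (
	"fmt"
	"math"
)

// Tensor is a hypothetical stand-in for torch.Tensor: a flat slice of values.
type Tensor []float64

// Functional is a hypothetical stand-in for the module nn.NewFunctional returns.
// It stores a single-argument tensor function and applies it in Forward.
type Functional struct {
	f func(Tensor) Tensor
}

// NewFunctional mirrors the constraint from the issue: only func(Tensor) Tensor
// values are accepted.
func NewFunctional(f func(Tensor) Tensor) *Functional {
	return &Functional{f: f}
}

// Forward applies the wrapped function to its input.
func (m *Functional) Forward(x Tensor) Tensor {
	return m.f(x)
}

// Tanh already has the required shape, so it can be passed without adaptation.
func Tanh(t Tensor) Tensor {
	out := make(Tensor, len(t))
	for i, v := range t {
		out[i] = math.Tanh(v)
	}
	return out
}

func main() {
	m := NewFunctional(Tanh)
	fmt.Println(m.Forward(Tensor{0, 1}))
}
```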

However, `LeakyRelu` takes two parameters, the input tensor and the negative slope:

```go
func LeakyRelu(t Tensor, negativeSlope float64) Tensor {
	return t.LeakyRelu(negativeSlope)
}
```

As a result, we cannot write the following directly:

```go
nn.NewFunctional(torch.LeakyRelu(0.2))
```

Maybe we should borrow more features from functional programming languages, such as currying in Haskell.

With currying, `torch.LeakyRelu(0.2)` would return a function of type `func(torch.Tensor) torch.Tensor`, which would then work with `NewFunctional`.
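A minimal sketch of what a curried `LeakyRelu` could look like, assuming the slope is taken first and a closure over it is returned. `Tensor` and `NewFunctional` are hypothetical stand-ins rather than gotorch's types, and the element-wise loop only imitates leaky ReLU on a flat slice.

```go
package main

import "fmt"

// Tensor is a hypothetical stand-in for torch.Tensor.
type Tensor []float64

// NewFunctional is a hypothetical stand-in that accepts only func(Tensor) Tensor.
func NewFunctional(f func(Tensor) Tensor) func(Tensor) Tensor { return f }

// LeakyRelu is a curried variant: it takes the slope first and returns
// a func(Tensor) Tensor that closes over it.
func LeakyRelu(negativeSlope float64) func(Tensor) Tensor {
	return func(t Tensor) Tensor {
		out := make(Tensor, len(t))
		for i, v := range t {
			if v < 0 {
				v *= negativeSlope // scale negative inputs by the slope
			}
			out[i] = v
		}
		return out
	}
}

func main() {
	act := NewFunctional(LeakyRelu(0.2))
	fmt.Println(act(Tensor{-1, 2}))
}
```

One trade-off of this design is that it replaces the existing `LeakyRelu(t, negativeSlope)` signature, so callers that apply it directly to a tensor would need to write `LeakyRelu(0.2)(t)` instead.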

There is also a project, maxsz/curry, that provides a way to support currying in Go.
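Since Go 1.18, a generic helper can also curry an existing two-argument function without changing its signature. The sketch below is a hand-written illustration and does not reproduce the API of maxsz/curry. `Curry2`, `Flip`, `Tensor` and `leakyRelu` are hypothetical names, and `Flip` is needed only because the slope is the second parameter of `LeakyRelu`.

```go
package main

import "fmt"

// Curry2 converts a two-argument function into a chain of one-argument functions.
func Curry2[A, B, R any](f func(A, B) R) func(A) func(B) R {
	return func(a A) func(B) R {
		return func(b B) R { return f(a, b) }
	}
}

// Flip swaps the argument order so the second parameter can be supplied first.
func Flip[A, B, R any](f func(A, B) R) func(B, A) R {
	return func(b B, a A) R { return f(a, b) }
}

// Tensor is a hypothetical stand-in for torch.Tensor.
type Tensor []float64

// leakyRelu keeps the (tensor, slope) parameter order used by torch.LeakyRelu.
func leakyRelu(t Tensor, negativeSlope float64) Tensor {
	out := make(Tensor, len(t))
	for i, v := range t {
		if v < 0 {
			v *= negativeSlope
		}
		out[i] = v
	}
	return out
}

func main() {
	// Type inference yields func(float64) func(Tensor) Tensor; applying 0.2
	// leaves a func(Tensor) Tensor, the shape NewFunctional expects.
	act := Curry2(Flip(leakyRelu))(0.2)
	fmt.Println(act(Tensor{-1, 2}))
}
```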

shendiaomo commented 4 years ago

This needs more discussion. For now, we can use an anonymous function (a Go function literal) for this purpose:

```go
nn.NewFunctional(func(in torch.Tensor) torch.Tensor {
	return torch.LeakyRelu(in, 0.2)
})
```

With the lambda syntax sugar of Go+, this would be more concise to write.
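The function-literal workaround keeps the two-argument `LeakyRelu` unchanged, and because the literal is a closure it can capture a variable slope rather than only a constant. The sketch below uses hypothetical stand-ins (`Tensor`, `NewFunctional`) instead of gotorch and builds one activation per slope in a loop.

```go
package main

import "fmt"

// Tensor is a hypothetical stand-in for torch.Tensor.
type Tensor []float64

// NewFunctional is a hypothetical stand-in that accepts only func(Tensor) Tensor.
func NewFunctional(f func(Tensor) Tensor) func(Tensor) Tensor { return f }

// LeakyRelu keeps the original two-argument signature from the issue.
func LeakyRelu(t Tensor, negativeSlope float64) Tensor {
	out := make(Tensor, len(t))
	for i, v := range t {
		if v < 0 {
			v *= negativeSlope
		}
		out[i] = v
	}
	return out
}

func main() {
	var acts []func(Tensor) Tensor
	for _, slope := range []float64{0.1, 0.2} {
		slope := slope // per-iteration copy; required for correct capture before Go 1.22
		acts = append(acts, NewFunctional(func(in Tensor) Tensor {
			return LeakyRelu(in, slope)
		}))
	}
	for _, act := range acts {
		fmt.Println(act(Tensor{-1, 2}))
	}
}
```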