LaurentMazare / tch-rs

Rust bindings for the C++ API of PyTorch.
Apache License 2.0

Scalar multiplication in element-wise Tensor methods #133

Open danieldk opened 4 years ago

danieldk commented 4 years ago

I am implementing the AdamW optimizer for a project. The implementation could benefit from in-place computations, in particular add_, addcmul_, and addcdiv_. However, the tensor being added, the product in addcmul_, and the quotient in addcdiv_ all need to be scaled before the addition. Unfortunately, the corresponding Tensor methods do not take a scalar; instead, the C++ API's default of 1.0 is used. As a result, some temporaries are needed.
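For reference, here is a minimal sketch of one AdamW step written against plain f64 slices rather than tch tensors, so it runs without libtorch. The helpers add_scaled_, addcmul_scaled_, addcdiv_scaled_, and scale_ are hypothetical; they mirror the scaled forms PyTorch documents (add with alpha, addcmul and addcdiv with value), which is the scalar this issue asks to expose. The update order follows the decoupled-weight-decay formulation used by PyTorch's AdamW; treat it as an illustration, not as tch's API.

```rust
/// Hypothetical helper: dst += alpha * other (like PyTorch `add_(other, alpha)`).
fn add_scaled_(dst: &mut [f64], other: &[f64], alpha: f64) {
    for (d, o) in dst.iter_mut().zip(other) {
        *d += alpha * o;
    }
}

/// Hypothetical helper: dst += value * t1 * t2 (like `addcmul_(t1, t2, value)`).
fn addcmul_scaled_(dst: &mut [f64], t1: &[f64], t2: &[f64], value: f64) {
    for ((d, a), b) in dst.iter_mut().zip(t1).zip(t2) {
        *d += value * a * b;
    }
}

/// Hypothetical helper: dst += value * t1 / t2 (like `addcdiv_(t1, t2, value)`).
fn addcdiv_scaled_(dst: &mut [f64], t1: &[f64], t2: &[f64], value: f64) {
    for ((d, a), b) in dst.iter_mut().zip(t1).zip(t2) {
        *d += value * a / b;
    }
}

/// Hypothetical helper: dst *= s (like `mul_` with a scalar).
fn scale_(dst: &mut [f64], s: f64) {
    for d in dst.iter_mut() {
        *d *= s;
    }
}

struct AdamW {
    lr: f64,
    beta1: f64,
    beta2: f64,
    eps: f64,
    weight_decay: f64,
    t: i32, // number of steps taken so far
}

impl AdamW {
    /// One update of parameters `p` given gradient `g`;
    /// `m` and `v` are the first and second moment buffers.
    fn step(&mut self, p: &mut [f64], g: &[f64], m: &mut [f64], v: &mut [f64]) {
        self.t += 1;
        // Decoupled weight decay: p *= 1 - lr * weight_decay.
        scale_(p, 1.0 - self.lr * self.weight_decay);
        // m = beta1 * m + (1 - beta1) * g, without a temporary for the scaled g.
        scale_(m, self.beta1);
        add_scaled_(m, g, 1.0 - self.beta1);
        // v = beta2 * v + (1 - beta2) * g * g, without a temporary for g * g.
        scale_(v, self.beta2);
        addcmul_scaled_(v, g, g, 1.0 - self.beta2);
        // Bias corrections for the zero-initialized moments.
        let bc1 = 1.0 - self.beta1.powi(self.t);
        let bc2_sqrt = (1.0 - self.beta2.powi(self.t)).sqrt();
        let eps = self.eps;
        // denom = sqrt(v) / sqrt(bc2) + eps; this buffer is needed regardless.
        let denom: Vec<f64> = v.iter().map(|x| x.sqrt() / bc2_sqrt + eps).collect();
        // p -= (lr / bc1) * m / denom, folded into one scaled addcdiv.
        addcdiv_scaled_(p, m, &denom, -self.lr / bc1);
    }
}
```

On the first step the bias correction makes the update roughly lr times the sign of the gradient, which is a quick sanity check for the sketch.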

It would be nice if this scaling scalar were also exposed as an argument.

tiberiusferreira commented 4 years ago

It looks like you need to change the default parameters, as I do in https://github.com/LaurentMazare/tch-rs/issues/132; is that right? If so, take a look there to see why they are not exposed.

danieldk commented 4 years ago

> Looks like you need to change the default params, like me in #132, is that right? If that is so, take a look there to see why they are not exposed.

Indeed, it's a default parameter in C++, and the reasoning there is sensible. In my case, the slowdown is about 14% on a Quadro RTX 5000 (comparing regular Adam with AdamW, which uses fewer in-place operations). Not dramatic, but it looked like low-hanging fruit ;).
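Without the scalar argument, the usual workaround is to scale one operand into a temporary and then call the unit-scale in-place operation. The sketch below, again on plain f64 slices with hypothetical helper names, computes the second-moment update v = beta2 * v + (1 - beta2) * g * g both ways: through a temporary plus a unit-scale addcmul_, and as a single fused pass. Both give the same values; the first costs one extra allocation and one extra pass over the data per call, which is the kind of overhead discussed here.

```rust
/// Hypothetical unit-scale op: dst += t1 * t2 (like `addcmul_` with value fixed at 1.0).
fn addcmul_(dst: &mut [f64], t1: &[f64], t2: &[f64]) {
    for ((d, a), b) in dst.iter_mut().zip(t1).zip(t2) {
        *d += a * b;
    }
}

/// Workaround: materialize (1 - beta2) * g, then use the unit-scale op.
fn second_moment_with_temporary(v: &mut [f64], g: &[f64], beta2: f64) {
    for x in v.iter_mut() {
        *x *= beta2;
    }
    // Extra allocation that a `value` argument would make unnecessary.
    let scaled_g: Vec<f64> = g.iter().map(|x| (1.0 - beta2) * x).collect();
    addcmul_(v, &scaled_g, g);
}

/// Fused version: the scale is folded into the multiply-add, no temporary.
fn second_moment_fused(v: &mut [f64], g: &[f64], beta2: f64) {
    for (x, gi) in v.iter_mut().zip(g) {
        *x = beta2 * *x + (1.0 - beta2) * gi * gi;
    }
}
```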

LaurentMazare commented 4 years ago

Handling default parameters is indeed a bit messy, as detailed in the other issue. It would be easy to expose some of these manually, but that would be on the hacky side. I imagine the 14% figure covers only the optimizer bookkeeping and does not include gradient computation; is that right? (I would expect gradient computation to be far slower for most use cases, but maybe that's not the case.)

danieldk commented 4 years ago

> I imagine that the 14% figure is just for the optimizer book-keeping which doesn't include gradient computation, is this right? (I would imagine gradient computation to be far slower for most use cases but maybe that's not the case)

Indeed, the 14% is for the optimizer step only; I just use Tensor::backward to compute gradients.
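To put the figure in context: if the 14% slowdown applies only to the optimizer step, its effect on a full training step depends on the fraction f of step time spent in the optimizer. Here f is a hypothetical quantity that was not measured in this thread, and T_old is the original time per training step:

$$
\frac{T_{\text{new}}}{T_{\text{old}}} = \frac{(1 - f)\,T_{\text{old}} + 1.14\,f\,T_{\text{old}}}{T_{\text{old}}} = 1 + 0.14\,f
$$

For example, with f = 0.1 the end-to-end slowdown would be about 1.4%.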