mlverse / torch

R Interface to Torch
https://torch.mlverse.org

torch_arange() inconsistent with PyTorch #1138

Closed: lawremi closed this issue 4 months ago

lawremi commented 7 months ago

There are at least a couple of inconsistencies between torch_arange() and PyTorch's torch.arange().

In PyTorch:

  1. Passing a start value without an end value generates the range [0, start); for example, torch.arange(5) gives 0, 1, 2, 3, 4. That is convenient in the same way that seq(x) is convenient in R.
  2. Passing integer values for all of start, end and step (so step should default to 1L in R, not 1) yields an int64 tensor rather than a float tensor, unless dtype overrides it. In principle, R torch could be even smarter and infer that arguments are meant to be integers (as seq() does), but making this work for explicit integers such as 5L would be a great start.
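A minimal base-R sketch of the two proposed rules, assuming R torch's `end` is inclusive (which the `torch_arange(0L, 4L)` example below relies on). `arange_spec()` is a hypothetical helper, not part of the torch package: it only computes the values and the dtype name that the proposed `torch_arange()` would produce, without creating a tensor.

```r
# Hypothetical helper, NOT part of the torch package: returns the values and
# the dtype name torch_arange() would give under the rules proposed here.
# Assumes R torch's `end` is inclusive, unlike PyTorch's exclusive `end`.
arange_spec <- function(start, end, step = 1L) {
  if (missing(end)) {
    # Rule 1: a lone argument n means [0, n), i.e. 0, 1, ..., n - 1.
    # With an inclusive end, that is the range 0..(n - 1).
    end <- start - 1L
    start <- 0L
  }
  # Rule 2: int64 only when every argument is an explicit R integer;
  # otherwise fall back to the default floating-point dtype (float32).
  all_integer <- is.integer(start) && is.integer(end) && is.integer(step)
  dtype <- if (all_integer) "int64" else "float32"
  values <- seq(start, end, by = step)
  list(values = values, dtype = dtype)
}

str(arange_spec(5L))      # values 0, 1, 2, 3, 4 with dtype "int64"
str(arange_spec(0L, 4L))  # same values and dtype as the call above
str(arange_spec(0, 4))    # same values, but dtype "float32"
```

Keeping `step = 1L` as the default matters: with `step = 1`, `is.integer(step)` is FALSE, so every call would fall back to float32, which is the behaviour this issue asks to change.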

In other words, it would be nice if these were TRUE:

```r
identical(torch_arange(5L), torch_arange(0L, 4L))
identical(torch_arange(0L, 4L), torch_arange(0, 4, dtype = torch_int64()))
```