mlverse / torch

R Interface to Torch
https://torch.mlverse.org

torch_split() does not work #1159

Closed: MaximilianPi closed this issue 1 month ago

MaximilianPi commented 1 month ago

Hi @dfalbel,

As in PyTorch, torch_split() should split a tensor along a specified dimension into N splits, right? Here is what I get:

> A = torch_ones(10, 2)
> A$split(dim = 1L, split_size = 10L)
[[1]]
torch_tensor
 1  1
 1  1
 1  1
 1  1
 1  1
 1  1
 1  1
 1  1
 1  1
 1  1
[ CPUFloatType{10,2} ]

I expected the output to be a list of 10 elements. Am I doing something wrong, or is this a bug?

dfalbel commented 1 month ago

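split_size actually tells torch how many elements each split should contain along dim, not how many splits to make. In your case that is 10, which is also the size of the dimension you are splitting, so you get back a single split holding the whole tensor.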

If you instead do A$split(dim = 1, split_size = 2), you should get five tensors of two rows each:

[[1]]
torch_tensor
 1  1
 1  1
[ CPUFloatType{2,2} ]

[[2]]
torch_tensor
 1  1
 1  1
[ CPUFloatType{2,2} ]

[[3]]
torch_tensor
 1  1
 1  1
[ CPUFloatType{2,2} ]

[[4]]
torch_tensor
 1  1
 1  1
[ CPUFloatType{2,2} ]

[[5]]
torch_tensor
 1  1
 1  1
[ CPUFloatType{2,2} ]
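To make the arithmetic concrete, here is a small base-R sketch (no torch needed) of which piece sizes split() yields for a given split_size. split_sizes() is a hypothetical helper written only for illustration, not a torch function, and it assumes torch follows PyTorch's behaviour, where every piece has split_size elements except possibly a smaller last one.

```r
# Hypothetical helper, NOT part of torch: sizes of the pieces that
# split(split_size = s) would produce along a dimension of length n,
# assuming PyTorch semantics (all pieces have s elements, except a
# possibly smaller last piece holding the remainder).
split_sizes <- function(n, split_size) {
  stopifnot(n >= 1, split_size >= 1)
  n_full <- n %/% split_size       # number of full-size pieces
  remainder <- n %% split_size     # elements left over for a last piece
  sizes <- rep(split_size, n_full)
  if (remainder > 0) sizes <- c(sizes, remainder)
  sizes
}

split_sizes(10, 10)  # [1] 10                   (the case in the question)
split_sizes(10, 2)   # [1] 2 2 2 2 2            (the output above)
split_sizes(10, 1)   # [1] 1 1 1 1 1 1 1 1 1 1  (ten single-row pieces)
split_sizes(10, 3)   # [1] 3 3 3 1              (smaller last piece)
```

So the ten single-row tensors asked for in the question would come from split_size = 1, i.e. A$split(dim = 1L, split_size = 1L). If you would rather state the number of pieces than their size, torch_chunk(A, chunks = 10, dim = 1) is the closer analogue.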

MaximilianPi commented 1 month ago

Ah, thanks! My fault.