MaximilianPi closed this issue 1 month ago
`split_size` actually tells torch how many elements there should be in each split. In your case that is 10, which is also the size of the dimension you are splitting, so the result is a single split.
If you do `A$split(dim = 1, split_size = 2)`, you should get:
[[1]]
torch_tensor
1 1
1 1
[ CPUFloatType{2,2} ]
[[2]]
torch_tensor
1 1
1 1
[ CPUFloatType{2,2} ]
[[3]]
torch_tensor
1 1
1 1
[ CPUFloatType{2,2} ]
[[4]]
torch_tensor
1 1
1 1
[ CPUFloatType{2,2} ]
[[5]]
torch_tensor
1 1
1 1
[ CPUFloatType{2,2} ]
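A minimal sketch of the difference, assuming `A` is a 10 x 2 tensor of ones (the thread does not show how `A` was created, but this shape matches the output above). It also shows `torch_chunk()`, which takes the number of chunks rather than the chunk size and may be closer to what the question expected:

```r
library(torch)

# Assumption: A is a 10 x 2 tensor of ones, consistent with the 2 x 2 pieces printed above.
A <- torch_ones(10, 2)

# split_size is the number of elements per piece along `dim`, not the number of pieces.
by_two <- A$split(split_size = 2, dim = 1)
length(by_two)   # 5 pieces, each of shape 2 x 2

# To get one piece per row, ask for pieces of size 1.
by_row <- A$split(split_size = 1, dim = 1)
length(by_row)   # 10 pieces, each of shape 1 x 2

# torch_chunk() takes the desired number of pieces instead.
by_chunk <- torch_chunk(A, chunks = 10, dim = 1)
length(by_chunk) # 10 pieces
```

If the dimension is not evenly divisible by `split_size`, the last piece is smaller than the others.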
Ah, thanks! My fault
Hi @dfalbel
As in PyTorch, torch's split should split a tensor along a specified dimension into N splits, right?
The output should be a list of 10 elements. Am I doing something wrong or is this a bug?