dotnet / TorchSharp

A .NET library that provides access to the library that powers PyTorch.
MIT License

slice bug #1399

Closed: whuanle closed this issue 2 weeks ago

whuanle commented 2 weeks ago

TorchSharp can't index a tensor the way PyTorch does.

Python:

import torch

arr = torch.asarray([[1, 2, 3], [4, 5, 6]])

print(arr[0,2])
print(arr[1,2])

result:

tensor(3)
tensor(6)

C#:

using TorchSharp;

var arr = torch.from_array(new[,] { { 1, 2, 3 }, { 4, 5, 6 } });

var a = arr[0, 2];
a.print(style: TensorStringStyle.Numpy);
var b = arr[1, 2];
b.print(style: TensorStringStyle.Numpy);

result:

[Screenshot of the printed output]

But if you select a whole layer (a single index, e.g. one row), the output is normal. [Screenshot]

yueyinqiu commented 2 weeks ago

It seems the problem is in print, not in slicing:

[Screenshot]
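One way to check this without going through the formatter is to read the indexed elements back as .NET values. A minimal sketch, assuming a project that references the TorchSharp package; it reuses the tensor from the report, and `Tensor.dim()` and `Tensor.item<T>()` are existing TorchSharp members, while the local names are only for illustration:

```csharp
using System;
using TorchSharp;

// Same 2x3 Int32 tensor as in the report.
var arr = torch.from_array(new[,] { { 1, 2, 3 }, { 4, 5, 6 } });

// One index per dimension selects a single element as a 0-D (scalar) tensor.
var a = arr[0, 2];
var b = arr[1, 2];

// Inspect the result directly instead of calling print().
Console.WriteLine(a.dim());       // number of dimensions of the indexed result
Console.WriteLine(a.item<int>()); // element value as a .NET int
Console.WriteLine(b.item<int>());
```

If indexing behaves like PyTorch, this prints 0, 3 and 6: each result is a 0-D tensor holding the expected element, which points at the printing path for scalar tensors rather than at the indexer.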

yueyinqiu commented 2 weeks ago

This may be related to #1250.

yueyinqiu commented 2 weeks ago

I found why the problem wasn't caught at that time. The unit test Validate_1250 creates its tensor with torch.zeros(0). However, that is a 1-D tensor with zero elements, not a scalar.

The unit test should look like this instead:

        [Fact]
        public void Validate_1250()
        {
            var scalar = torch.zeros(Array.Empty<long>());
            Assert.Equal("0", scalar.npstr());
            Assert.Equal("[], type = Float32, device = cpu, value = 0", scalar.cstr());
        }
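The difference the corrected test relies on can be checked directly. A small sketch, assuming a project that references TorchSharp, comparing the shape, dimension count and element count of the two tensors (`shape`, `dim()` and `numel()` are existing `Tensor` members):

```csharp
using System;
using TorchSharp;

// A single size argument of 0: one dimension whose length is 0, so no elements.
var empty1d = torch.zeros(0);

// An empty size array: zero dimensions, i.e. a scalar holding exactly one element.
var scalar = torch.zeros(Array.Empty<long>());

Console.WriteLine($"zeros(0):  shape=[{string.Join(", ", empty1d.shape)}], dim={empty1d.dim()}, numel={empty1d.numel()}");
Console.WriteLine($"zeros([]): shape=[{string.Join(", ", scalar.shape)}], dim={scalar.dim()}, numel={scalar.numel()}");
```

The first tensor reports shape [0] and dim 1, the second shape [] and dim 0, which is why only the second one exercises the scalar formatting path that this issue hits.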

By the way, the expected value for cstr above is simply the current output. It looks a bit strange when we compare these three:

using TorchSharp;

var zero = torch.tensor(0f);
var one = torch.tensor(new float[] { 0, 0 });
var two = torch.tensor(new float[,] { { 0, 0 }, { 0, 0 } });

Console.WriteLine(zero.cstr());
// [], type = Float32, device = cpu, value = 0

Console.WriteLine(one.cstr());
// [2], type = Float32, device = cpu, value = float [] {0f, 0f}

Console.WriteLine(two.cstr());
/*
[2x2], type = Float32, device = cpu, value =
float [,] {
 {0f, 0f},
 {0f, 0f}
}
*/

I would expect something like [], type = Float32, device = cpu, value = float 0f.

If we want to follow the principle that "you should be able to copy the output and paste it into your code", maybe it should be float (0f), to match the pattern of new xxxxxx.