dotnet / TorchSharp

A .NET library that provides access to the library that powers PyTorch.
MIT License

Ranges when condition gives a zero length tensor #1187

Closed Sprinzl closed 10 months ago

Sprinzl commented 10 months ago

OK, this is a nice one. You can do something like this, where h is a tensor and target is a double:

h[h > target] = 1

... but this only works if the selection h[h > target] contains at least one element.

Is where any help, or is there an alternative? I ask because this is a very elegant solution. Thanks in advance, Michael
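For readers wondering what the where alternative mentioned above could look like, here is a minimal sketch. It is not from the thread. It assumes that TorchSharp's torch.where(condition, x, y), torch.ones_like and Tensor.allclose behave like their PyTorch counterparts, and it reuses the h > target comparison from the discussion. The names mask, ones and result are made up for illustration.

```csharp
using System;
using TorchSharp;

// Editorial sketch (assumption: torch.where mirrors PyTorch's semantics).
var target = 1.5f;                    // torch.rand draws from [0, 1), so the mask is all false
using var h = torch.rand(10);
using var mask = h > target;          // boolean tensor with the same shape as h
using var ones = torch.ones_like(h);

// Elementwise select: take 1 where the mask is true, otherwise keep h.
// No sub-tensor is gathered, so an all-false mask just yields a copy of h.
using var result = torch.where(mask, ones, h);

// Compare the result with the untouched input.
Console.WriteLine(result.allclose(h));
```

Unlike the masked assignment, this produces a new tensor instead of modifying h in place. With the all-false mask used here, result is expected to equal h.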

NiklasGustafsson commented 10 months ago

Hmmm. It works as I had expected it to do.

Here's my repro:

        [Fact]
        public void Validate_1187()
        {
            var target = 1.5f;
            using var h = torch.rand(10);
            using var expected = h.clone();

            // torch.rand draws from [0, 1), so h > 1.5 selects nothing.
            // The assignment should therefore do nothing, and in particular not blow up.
            h[h > target] = 1.0f;

            Assert.Equal(expected, h);
        }

What was your expected result?

Sprinzl commented 10 months ago

OK Niklas, got the solution:

h[h > target] = h[h > target].copy_(1);

The problem arises when h[h > target] has zero length or is null (I think...).

Thanks Niklas

The goal of the whole script is to get the first non-zero value of each row and then one-hot encode it. It is for labeling stock market data. I have an LLM encoder with next-step-prediction self-supervised training, and I want to fine-tune it with the classes profit - neutral - stop loss. Something like this:

from: 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 1 1 2 0 0 0 0 0 0 0 0 0 1 1 1 2 0 0 0 0 0 0 0 0 0 0 1 1 2 0 0 0 0 0 0 0 0 0 0 1 1 2 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 2 2 1 0 0 0 0 2 2 0 0 0 2 2 2 1 1 0 0 0 2 2 0 0 0 2 2 2 2 2 1 0 0 0 0 0 0 0 0 0 2 2 0 2 2

to: 1 0 0 0 1 0 0 1 0 0 1 0 0 1 0 1 0 0 0 0 1 0 0 1 0 0 1 0 0 1

I use Rhino3d as visualization for tensors.....

[Image: Candles]

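The labeling goal described in this comment (take the first non-zero value of each row, then one-hot encode it) can be sketched with plain arrays, without TorchSharp. This is only an illustration under stated assumptions: the helper FirstNonZeroOneHot is hypothetical, labels are assumed to be 0, 1 and 2, an all-zero row is assumed to map to class 0, and the sample rows are invented rather than taken from the data above.

```csharp
using System;

// Hypothetical helper for illustration; not part of TorchSharp or the posted code.
// Assumption: values in a row are 0, 1 or 2, and an all-zero row maps to class 0.
static int[] FirstNonZeroOneHot(int[] row, int classes)
{
    int label = 0;
    foreach (int value in row)
    {
        // Stop at the first non-zero entry and use it as the class label.
        if (value != 0) { label = value; break; }
    }
    var oneHot = new int[classes];
    oneHot[label] = 1;
    return oneHot;
}

// Invented sample rows.
int[][] rows =
{
    new[] { 0, 0, 1, 1, 2 },   // first non-zero is 1
    new[] { 0, 2, 2, 1, 0 },   // first non-zero is 2
    new[] { 0, 0, 0, 0, 0 },   // no non-zero entry
};

foreach (var row in rows)
{
    Console.WriteLine(string.Join(" ", FirstNonZeroOneHot(row, 3)));
}
// prints:
// 0 1 0
// 0 0 1
// 1 0 0
```

The code posted later in the thread appears to swap the first two columns after the one-hot step; this sketch leaves the columns in label order.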
NiklasGustafsson commented 10 months ago

Is the behavior different from PyTorch? If the assignment is supposed to do what your workaround does, can you give me a simple example in Python that illustrates it?

Sprinzl commented 10 months ago

It is custom-made code, all C#, so I cannot say. I also have a runtime-compiler plugin in the CAD system which compiles a VS project, serializing it first. (So you do not need to restart the program, something that does not work if you attach a DLL in VS Code.) So there can be a whole series of problems. I changed from

h[h > target] = 1.0f;

to

h[h > target] = h[h > target].copy_(1);

All tensors are on CUDA. Here is the spaghetti-monster code for that task:

        // Assumes: using System; using System.Collections.Generic;
        //          using TorchSharp; using static TorchSharp.torch;
        private Tensor LabelEncoder(Tensor bar, double target, int steps)
        {
            // Risk-Reward 1:1
            // bar sequence -- Open|H|L|C|WAP|Volume|Ratio|YearMonthDay|HourMinutesSeconds
            Func<Tensor, int, int, Tensor> hotselect = (x, dim, label) =>
            {
                // Position of a non-zero entry along dim (intended: the first one).
                var max = torch.max(x != 0, dim: dim).indexes;
                // Convert per-row positions into offsets into the flattened tensor.
                max = max + arange(x.numel() / x.size(dim), dtype: ScalarType.Int64, device: x.device) * x.size(dim);
                max = torch.nn.functional.one_hot(x.flatten()[max].to(ScalarType.Int64), label);
                // Swap columns 0 and 1; clone first so the swap does not read through views of max.
                var _n0 = max[.., 0].clone();
                var _n1 = max[.., 1].clone();
                max[.., 0] = _n1;
                max[.., 1] = _n0;
                return max;
            };

            int high = 1;
            int low = 2;
            int close = 3;

            var v = bar.clone();
            v = v.unsqueeze(0);
            v = v.unfold(dimension: 1, size: steps, step: 1)[0, .., .., ..].permute(0, 2, 1); // [1269674x15x9]

            var h = (v[.., 1.., high].permute(1, 0) / v[.., 0, close]).permute(1, 0) - 1.0;
            h[h > target] = h[h > target].copy_(1);
            h[h != 1] = h[h != 1].copy_(0);

            var l = ((v[.., 1.., low].permute(1, 0) / v[.., 0, close]).permute(1, 0) - 1.0) * -1.0;
            l[l > target] = l[l > target].copy_(1);
            l[l != 1] = l[l != 1].copy_(0);

            var mix = torch.zeros_like(h, device: bar.device);
            mix[(l == 1)] = mix[(l == 1)].copy_(1);
            mix[(h == 1)] = mix[(h == 1)].copy_(2);
            var sum = mix.sum(1);

            var idx = hotselect(mix, 1, 3);
            var extend = torch.zeros(new long[] { steps - 1, 3 }, dtype: idx.dtype, device: idx.device);
            extend[.., 1] = 1.0;
            idx = torch.cat(new List<Tensor>() { idx, extend }, dim: 0L);
            return idx;
        }

Sprinzl commented 10 months ago

I cannot reproduce the error... no idea. Sorry. The mistake was somewhere else.