bytedeco / javacpp-presets

The missing Java distribution of native C++ libraries

[pytorch] Can't get byte data from torch.ScalarType.Byte tensor #1321

Closed jxtps closed 11 months ago

jxtps commented 1 year ago

I'm using fairly large images and want to avoid the overhead of using floats. However, when I try to read byte data from a torch.ScalarType.Byte tensor I get a weird exception.

Minimal code example to reproduce:

package misc;

import org.bytedeco.javacpp.BytePointer;
import org.bytedeco.javacpp.FloatPointer;
import org.bytedeco.pytorch.Tensor;
import org.bytedeco.pytorch.global.torch;

import java.nio.FloatBuffer;

public class PytorchUint8TestBench {
    public static void main(String[] args) {
        long[] shape = new long[]{1, 3, 20, 20};
        float[] floats = new float[3 * 20 * 20];
        FloatPointer p = new FloatPointer(FloatBuffer.wrap(floats));
        Tensor floatT = org.bytedeco.pytorch.global.torch.from_blob(p, shape);
        Tensor byteT = floatT.to(torch.ScalarType.Byte);
        System.out.println("byteT.dtype().isScalarType(torch.ScalarType.Byte): " + byteT.dtype().isScalarType(torch.ScalarType.Byte));
        BytePointer byteP = byteT.data_ptr_byte(); // <-- Exception here
        byte[] bytes = new byte[3 * 20 * 20];
        byteP.get(bytes);
    }
}

This produces:

byteT.dtype().isScalarType(torch.ScalarType.Byte): true
Exception in thread "main" java.lang.RuntimeException: expected scalar type Char but found Byte
Exception raised from data_ptr at C:\build\aten\src\ATen\core\TensorMethods.cpp:18 (most recent call first):
00007FFAAF449CD200007FFAAF449C70 c10.dll!c10::Error::Error [<unknown file> @ <unknown line number>]
00007FFAAF44975E00007FFAAF449710 c10.dll!c10::detail::torchCheckFail [<unknown file> @ <unknown line number>]
00007FF8804BB95D00007FF8804BB8E0 torch_cpu.dll!at::TensorBase::data_ptr<signed char> [<unknown file> @ <unknown line number>]
00007FFA6E330B5700007FFA6E330AE0 jnitorch.dll!Java_org_bytedeco_pytorch_TensorBase_data_1ptr_1byte [<unknown file> @ <unknown line number>]
0000000003A217B0 <unknown symbol address> !<unknown symbol> [<unknown file> @ <unknown line number>]

    at org.bytedeco.pytorch.TensorBase.data_ptr_byte(Native Method)
    at misc.PytorchUint8TestBench.main(PytorchUint8TestBench.java:18)

Reading the exception it looks like Java_org_bytedeco_pytorch_TensorBase_data_1ptr_1byte is for some reason trying to access TensorBase::data_ptr<signed char> when it should probably be TensorBase::data_ptr<byte>?

Version: 1.10.2 & 1.12.1 on Windows. I haven't tried 1.13.1, as snapshots are tricky in my setup.

jxtps commented 1 year ago

Workaround:

        Pointer pp = byteT.data_ptr(); 
        BytePointer byteP = new BytePointer(pp);

I checked that I got reasonable data back and it seems to work. That also allows viewing the data as e.g. ints, which can be useful when working with images.
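
The "view the same bytes as ints" idea can be sketched without any native code. What follows is a plain java.nio analogy, not the JavaCPP Pointer API: the class name PackedPixelView and the sample values are made up for illustration, and the byte order is set explicitly because the resulting ints depend on it.

```java
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.IntBuffer;

// Hypothetical demo class: reinterprets the same bytes as 32-bit ints.
class PackedPixelView {
    // Wraps the array (no copy) and exposes it as little-endian ints.
    static IntBuffer asInts(byte[] bytes) {
        return ByteBuffer.wrap(bytes).order(ByteOrder.LITTLE_ENDIAN).asIntBuffer();
    }

    public static void main(String[] args) {
        // Two 4-byte pixels; the values are arbitrary sample data.
        byte[] pixels = {0x11, 0x22, 0x33, 0x44, (byte) 0xFF, 0x00, 0x00, 0x00};
        IntBuffer ints = asInts(pixels);
        for (int i = 0; i < ints.limit(); i++) {
            System.out.println(Integer.toHexString(ints.get(i)));
        }
    }
}
```

Because the buffer wraps the array rather than copying it, the int view costs no extra memory for the pixel data itself.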

saudet commented 1 year ago

There are no types named "byte" in C++. I wonder what it's expecting here. @HGuillemet Would you know?

jxtps commented 1 year ago

https://github.com/bytedeco/javacpp-presets/blob/master/pytorch/src/gen/java/org/bytedeco/pytorch/TensorBase.java#L297 reads:

public native @Name("data_ptr<int8_t>") BytePointer data_ptr_byte();

What is int8_t defined as?

HGuillemet commented 1 year ago

Apparently, in PyTorch the scalar type Byte is an unsigned byte (uint8_t) and Char is a signed byte (int8_t), so we need to change the Info in the presets to:

               .put(new Info("at::TensorBase::data_ptr<uint8_t>").javaNames("data_ptr_byte"))
               .put(new Info("at::TensorBase::data_ptr<int8_t>").javaNames("data_ptr_char"))
...
               .put(new Info("at::Tensor::item<uint8_t>").javaNames("item_byte"))
               .put(new Info("at::Tensor::item<int8_t>").javaNames("item_char"))

BTW, we could add something for half floats, too, but since there is no Java type for them and we would need to return a ShortPointer, we might as well leave it as is and have users call the raw data_ptr as @jxtps did above.
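
To make the signedness point concrete, here is a minimal sketch in plain Java with no PyTorch or JavaCPP calls. The class and method names are hypothetical. It shows that the same 8 bits read as -56 through Java's signed byte and as 200 once treated as unsigned, which is the difference between Char (int8_t) and Byte (uint8_t).

```java
// Hypothetical demo class; plain Java, no PyTorch or JavaCPP calls.
class UnsignedByteDemo {
    // Interprets the 8 bits of a Java byte as an unsigned value in 0..255.
    static int asUnsigned(byte b) {
        return Byte.toUnsignedInt(b); // equivalent to b & 0xFF
    }

    public static void main(String[] args) {
        byte raw = (byte) 200;               // bit pattern 0xC8
        System.out.println(raw);             // signed reading, as for Char / int8_t
        System.out.println(asUnsigned(raw)); // unsigned reading, as for Byte / uint8_t
    }
}
```

The widening to int happens one value at a time, so the underlying storage can stay at one byte per element.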

HGuillemet commented 1 year ago

Or we could use the terms of dtype: int8 and uint8.

Also it would be nicer if item_byte returned a java int (or short) equal to the unsigned byte.

saudet commented 1 year ago

Sounds all good! Please open a pull request :)

saudet commented 1 year ago

> Or we could use the terms of dtype: int8 and uint8.

But if you want to do something like that, please be consistent with the C++ API, not the Python API.

jxtps commented 1 year ago

> Also it would be nicer if item_byte returned a java int (or short) equal to the unsigned byte.

No, no, please don't! That explodes the memory consumption, which would (at least for my use-case) completely undo the rationale for using bytes in the first place.

A ShortPointer for FP16 sounds good, I guess it'll be a while until JEP 401 lands.
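
Until then, FP16 values obtained as raw 16-bit shorts have to be decoded by hand. Below is a minimal sketch of IEEE 754 binary16 decoding in plain Java. The class name HalfFloat is hypothetical and nothing here is JavaCPP API. Newer JDKs (20 and later, as far as I know) offer Float.float16ToFloat for the same job.

```java
// Hypothetical helper: decodes IEEE 754 half-precision bits stored in a short.
class HalfFloat {
    static float toFloat(short half) {
        int bits = half & 0xFFFF;
        int sign = bits >>> 15;          // 1 bit
        int exp = (bits >>> 10) & 0x1F;  // 5 bits, bias 15
        int mant = bits & 0x3FF;         // 10 bits

        if (exp == 0) {
            // Zero or subnormal: value = mant * 2^-24 (exact in float).
            float v = mant * (1.0f / (1 << 24));
            return sign == 0 ? v : -v;
        }
        if (exp == 0x1F) {
            // Infinity (mant == 0) or NaN (mant != 0).
            return Float.intBitsToFloat((sign << 31) | 0x7F800000 | (mant << 13));
        }
        // Normal number: rebias the exponent from 15 to 127, widen the mantissa.
        return Float.intBitsToFloat((sign << 31) | ((exp - 15 + 127) << 23) | (mant << 13));
    }

    public static void main(String[] args) {
        short[] samples = {(short) 0x3C00, (short) 0xC000, (short) 0x7BFF};
        for (short s : samples) {
            System.out.println(toFloat(s));
        }
    }
}
```

The storage stays at two bytes per element; the conversion to float happens only for the values actually read.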

HGuillemet commented 1 year ago

item_byte is the function that returns a single value, used when the tensor has 0 dimensions. data_ptr_int8 and data_ptr_uint8 will both return byte pointers, of course.

saudet commented 1 year ago

Or maybe it's OK the way it is now. Java doesn't have types for unsigned bytes, half floats, and whatnot, so in those cases it makes sense to let users deal with the raw Pointer themselves. There are a few indexers to deal with those in JavaCPP though, and I've already added something for them in AbstractTensor: https://github.com/bytedeco/javacpp-presets/blob/master/pytorch/src/main/java/org/bytedeco/pytorch/AbstractTensor.java#L98 @jxtps Are you having any problems when calling Tensor.createIndexer()?

jxtps commented 1 year ago

    byte[] bytes = new byte[128];
    bytes[0] = 127;
    Tensor t = Tensor.create(bytes);
    UByteIndexer i = t.createIndexer();
    int a = i.get(0);
    System.out.println("a: " + a);

produces the expected a: 127, so looks ok?

HGuillemet commented 12 months ago

Is this issue solved, and can it be closed?

jxtps commented 12 months ago

The original exception still happens on "org.bytedeco" % "pytorch-platform" % "2.0.1-1.5.9":

java.lang.RuntimeException: expected scalar type Char but found Byte
Exception raised from data_ptr at C:\actions-runner\_work\pytorch\pytorch\builder\windows\pytorch\build\build\aten\src\ATen\core\TensorMethods.cpp:20 (most recent call first):
00007FFF39AAD24200007FFF39AAD1E0 c10.dll!c10::Error::Error [<unknown file> @ <unknown line number>]
00007FFF39AACE8A00007FFF39AACE30 c10.dll!c10::detail::torchCheckFail [<unknown file> @ <unknown line number>]
00007FFE0C4B80E300007FFE0C4B8000 torch_cpu.dll!at::TensorBase::data_ptr<signed char> [<unknown file> @ <unknown line number>]
00007FFF17D3DC1700007FFF17D3DBA0 jnitorch.dll!Java_org_bytedeco_pytorch_TensorBase_data_1ptr_1byte [<unknown file> @ <unknown line number>]
0000000002BD17B0 <unknown symbol address> !<unknown symbol> [<unknown file> @ <unknown line number>]

We should probably either fix it, or remove the TensorBase.data_ptr_byte()?

HGuillemet commented 12 months ago

It has already been renamed to data_ptr_char() since PR #1360

saudet commented 12 months ago

> It has already been renamed to data_ptr_char() since PR #1360

Why did you rename that? Please don't break backward compatibility just for fun.

saudet commented 12 months ago

I assumed the plan was to add getters for uint8_t as well, but that hasn't been done.

HGuillemet commented 12 months ago

Right, we can add back data_ptr_byte for data_ptr<uint8_t> and item_byte for item<uint8_t>

saudet commented 12 months ago

Yes, at least that would kind of make sense.

HGuillemet commented 12 months ago

A detail to be decided: should item_byte (used when the tensor type is uint8_t) return a byte (signed), a short, or an int? I'd go for int.

jxtps commented 12 months ago

int sounds great. I trust the reintroduced data_ptr_byte won't crash?

HGuillemet commented 12 months ago

Sure, as long as your tensor is ScalarType.Byte. Concerning item<*>: there, * is the type we want the result cast to, whereas in data_ptr<*>, * must match the type of the tensor. We can call t.item_int() even if t contains floats or bytes, so there is no need for an item_byte.

saudet commented 12 months ago

int item_byte() sounds fine though.

HGuillemet commented 12 months ago

It would be the same as int item_int(). item<X> is defined by a macro as item().toX(), where item() returns a Scalar of the same type as the tensor.
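
The two contracts discussed in this thread (data_ptr<T> must match the tensor's type, item<T> converts to the requested type) can be illustrated with a toy model. This is only a sketch in plain Java: ToyTensor, DType, dataPtr and itemInt are invented names that exist in neither PyTorch nor JavaCPP, and a double array stands in for the native storage.

```java
// Toy model only: none of these names exist in PyTorch or JavaCPP.
class ToyTensor {
    enum DType { BYTE, CHAR, FLOAT } // BYTE models uint8, CHAR models int8

    private final DType dtype;
    private final double[] values;   // stand-in for the native storage

    ToyTensor(DType dtype, double... values) {
        this.dtype = dtype;
        this.values = values;
    }

    // Models data_ptr<T>(): the requested type must match the tensor's dtype.
    double[] dataPtr(DType requested) {
        if (requested != dtype) {
            throw new IllegalStateException(
                "expected scalar type " + requested + " but found " + dtype);
        }
        return values;
    }

    // Models item<int>(): the single element is converted, whatever the dtype.
    int itemInt() {
        if (values.length != 1) {
            throw new IllegalStateException("item() needs exactly one element");
        }
        return (int) values[0];
    }

    public static void main(String[] args) {
        ToyTensor t = new ToyTensor(DType.BYTE, 200);
        System.out.println(t.itemInt());         // conversion succeeds for any dtype
        try {
            t.dataPtr(DType.CHAR);               // mismatch, like the original report
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage());
        }
    }
}
```

In this model, requesting CHAR data from a BYTE tensor fails in the same way as the exception in the original report, while itemInt() works regardless of the stored type.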