deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

NDArray.set(NDArray index, Number value) failing with int64 index array on gpu #1773

Closed. demq closed this issue 2 years ago.

demq commented 2 years ago

Calling `NDArray.set(NDArray index, Number value)` with an int64 index array on a GPU, using the PyTorch engine:

```java
NDManager manager = ctx.getNDManager();
NDArray start_logits = list.get(0);
long[] bad_tokens_mask = new long[128]; // long[] becomes an int64 (DataType.INT64) NDArray
NDArray nd_bad_tokens_mask = manager.create(bad_tokens_mask);
start_logits.set(nd_bad_tokens_mask, -10000.);
```

fails with:

```
Exception in thread "main" ai.djl.translate.TranslateException: ai.djl.engine.EngineException: expected mask dtype to be Bool but got Long
Caused by: ai.djl.engine.EngineException: expected mask dtype to be Bool but got Long
    at ai.djl.pytorch.jni.PyTorchLibrary.torchMaskedPut(Native Method)
    at ai.djl.pytorch.jni.JniUtils.booleanMaskSet(JniUtils.java:416)
    at ai.djl.pytorch.engine.PtNDArrayIndexer.set(PtNDArrayIndexer.java:82)
    at ai.djl.ndarray.index.NDArrayIndexer.set(NDArrayIndexer.java:157)
    at ai.djl.ndarray.NDArray.set(NDArray.java:469)
    at ai.djl.ndarray.NDArray.set(NDArray.java:490)
```

The CUDA PyTorch implementation requires a boolean mask. The workaround is to create the index as an NDArray of type `DataType.BOOLEAN`:

```java
boolean[] bad_tokens_mask = new boolean[tokenTypes.size()];
```
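The reporter's `bad_tokens_mask` is a 0/1 flag array stored as `long[]`, so the workaround amounts to turning nonzero entries into `true`. A minimal plain-Java sketch of that conversion, and of what a boolean-mask set then does (every position whose flag is `true` receives the value), is below. `MaskUtil`, `toBooleanMask` and `maskedSet` are hypothetical names and no DJL classes are used. Inside DJL, the same conversion should be possible on the index itself with `nd_bad_tokens_mask.toType(DataType.BOOLEAN, false)`, though that is an assumption to verify against your DJL version.

```java
import java.util.Arrays;

/** Hypothetical helpers illustrating boolean-mask assignment on plain Java arrays (not DJL). */
final class MaskUtil {
    private MaskUtil() {}

    /** Converts an integer flag array to a boolean mask: nonzero entries become true. */
    static boolean[] toBooleanMask(long[] mask) {
        boolean[] out = new boolean[mask.length];
        for (int i = 0; i < mask.length; i++) {
            out[i] = mask[i] != 0;
        }
        return out;
    }

    /** Writes value into every position of target where mask is true. */
    static void maskedSet(float[] target, boolean[] mask, float value) {
        if (target.length != mask.length) {
            throw new IllegalArgumentException(
                    "mask length " + mask.length + " != target length " + target.length);
        }
        for (int i = 0; i < target.length; i++) {
            if (mask[i]) {
                target[i] = value;
            }
        }
    }

    public static void main(String[] args) {
        float[] logits = {0.5f, 1.5f, 2.5f, 3.5f};
        long[] badTokens = {0, 1, 0, 1}; // 1 marks a token to suppress
        maskedSet(logits, toBooleanMask(badTokens), -10000f);
        System.out.println(Arrays.toString(logits)); // prints [0.5, -10000.0, 2.5, -10000.0]
    }
}
```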

This behavior is not documented, and the type check (or the translation to a boolean index) would best be done in the `NDArray.set()` method itself.
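A hypothetical sketch of the translation suggested here: a `set` that inspects the index's element type first, treating a boolean index as a mask and an integer index as a list of positions, and rejecting anything else with a clear message. It models plain Java arrays, not DJL's `NDArray`, and `DispatchingSet` is an illustrative name only.

```java
import java.util.Arrays;

/** Hypothetical sketch: a set() that checks the index's element type before choosing a strategy. */
final class DispatchingSet {
    private DispatchingSet() {}

    /**
     * Sets the positions of target selected by index to value.
     * A boolean[] index is treated as a mask; a long[] index as a list of positions.
     */
    static void set(float[] target, Object index, float value) {
        if (index instanceof boolean[]) {
            boolean[] mask = (boolean[]) index;
            if (mask.length != target.length) {
                throw new IllegalArgumentException("mask length must equal target length");
            }
            for (int i = 0; i < mask.length; i++) {
                if (mask[i]) {
                    target[i] = value;
                }
            }
        } else if (index instanceof long[]) {
            for (long pos : (long[]) index) {
                if (pos < 0 || pos >= target.length) {
                    throw new IndexOutOfBoundsException("index " + pos);
                }
                target[(int) pos] = value;
            }
        } else {
            throw new IllegalArgumentException("unsupported index type: "
                    + (index == null ? "null" : index.getClass().getSimpleName()));
        }
    }

    public static void main(String[] args) {
        float[] a = {1f, 2f, 3f, 4f};
        set(a, new boolean[] {true, false, false, true}, 0f);
        float[] b = {1f, 2f, 3f, 4f};
        set(b, new long[] {0, 3}, 0f);
        System.out.println(Arrays.equals(a, b)); // prints true: both select positions 0 and 3
    }
}
```

The sketch also exposes an ambiguity: a `long[]` of 0/1 flags, like the one in the original report, would be read here as positions 0 and 1 rather than as a mask. That is one argument for rejecting non-boolean indices outright instead of silently translating them.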

lanking520 commented 2 years ago

@KexinFeng Can you help take a look?

KexinFeng commented 2 years ago

@demq From here, it looks like `default void set(NDArray index, Number value)` only accepts a boolean index. I have added the data type check.

For your case, have you tried setting with an `NDIndex`, i.e. `start_logits.set(new NDIndex("{}", nd_bad_tokens_mask), -10000.);`? For example:

```java
// set by index array: rows 0 and 1, columns 0 and 1
NDArray original = manager.arange(1, 10).reshape(3, 3);
NDArray index = manager.create(new long[] {0, 1}, new Shape(2));
original.set(new NDIndex("{}, :{}", index, 2), 666);
NDArray expected =
        manager.create(new int[] {666, 666, 3, 666, 666, 6, 7, 8, 9}, new Shape(3, 3));
Assert.assertEquals(original, expected);
```

This is available in version 0.18.0, which has just been released.
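As a reading aid for `original.set(new NDIndex("{}, :{}", index, 2), 666)` in the example above: the first `{}` is filled by the int64 `index` array and selects rows 0 and 1, while `:{}` with `2` becomes the slice `:2`, i.e. columns 0 and 1. The plain-Java model below assumes that reading of the NDIndex syntax and a row-major 3x3 layout; the helper name is hypothetical and there is no bounds checking. It reproduces the test's `expected` array.

```java
import java.util.Arrays;

/**
 * Plain-Java model (no DJL) of the "{}, :{}" selection: the rows listed in rowIndex,
 * and columns 0 up to (but excluding) colStop.
 */
final class IndexSetModel {
    private IndexSetModel() {}

    static void setRowsAndLeadingCols(int[] data, int cols, long[] rowIndex, int colStop, int value) {
        for (long row : rowIndex) {
            for (int c = 0; c < colStop; c++) {
                data[(int) row * cols + c] = value; // row-major offset of (row, c)
            }
        }
    }

    public static void main(String[] args) {
        int[] original = {1, 2, 3, 4, 5, 6, 7, 8, 9}; // arange(1, 10).reshape(3, 3)
        setRowsAndLeadingCols(original, 3, new long[] {0, 1}, 2, 666);
        int[] expected = {666, 666, 3, 666, 666, 6, 7, 8, 9};
        System.out.println(Arrays.equals(original, expected)); // prints true
    }
}
```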

demq commented 2 years ago

I have pulled the latest snapshot for 0.18.0, and the example that you have created works fine.

When I try to implement the same thing in `QATranslator::processOutput(TranslatorContext ctx, NDList list)`:

```java
NDArray index = manager.create(new boolean[128], new Shape(128));
startLogits.set(new NDIndex("{}", index), 999.f);
```

I get:

```
Caused by: ai.djl.engine.EngineException: Index put requires the source and destination dtypes match, got Float for the destination and Double for the source.
    at ai.djl.pytorch.jni.PyTorchLibrary.torchIndexAdvPut(Native Method)
    at ai.djl.pytorch.jni.JniUtils.indexAdvPut(JniUtils.java:473)
    at ai.djl.pytorch.engine.PtNDArrayIndexer.set(PtNDArrayIndexer.java:85)
    at ai.djl.ndarray.NDArray.set(NDArray.java:470)
    at ai.djl.ndarray.NDArray.set(NDArray.java:491)
```

If I use

```java
startLogits.set(index, 999.f);
```

everything works fine.

Does `NDArray.set(NDIndex ind, Number num)` somehow cast `num` to a double?

KexinFeng commented 2 years ago

It's good to hear that the new set-with-`NDIndex` feature solves your issue!

About the second question, the discrepancy between the two ways of calling `set` is again due to the feature update. `startLogits.set(index, 999.f);` still uses the old code path, which calls `ai.djl.pytorch.jni.PyTorchLibrary.torchMaskedPut`. We are updating it to the new feature too; see this PR. Once that lands, the behaviour will be aligned with the PyTorch engine, which requires users to cast `num` to double.
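The error message states the constraint directly: an index put needs the source and destination dtypes to match, and here the destination is Float while the value arrived as Double. One way a caller-facing API could satisfy it is to coerce the `Number` to the destination's element type before the put. A plain-Java sketch of that coercion follows; the nested `DType` enum is a hypothetical stand-in for DJL's `DataType`, and this is not how DJL itself implements `set`.

```java
/** Hypothetical sketch: coerce a Number to the destination's element type before writing it. */
final class ValueCoercion {
    private ValueCoercion() {}

    /** Hypothetical stand-in for an engine's element types (not DJL's DataType). */
    enum DType { FLOAT32, FLOAT64, INT32, INT64 }

    /** Returns value boxed as the Java type matching dest, so source and destination agree. */
    static Number coerceTo(Number value, DType dest) {
        switch (dest) {
            case FLOAT32:
                return value.floatValue();
            case FLOAT64:
                return value.doubleValue();
            case INT32:
                return value.intValue();
            case INT64:
                return value.longValue();
            default:
                throw new IllegalArgumentException("unsupported dtype " + dest);
        }
    }

    public static void main(String[] args) {
        Number num = 999.f;                      // autoboxed as java.lang.Float
        Number widened = num.doubleValue();      // what a double-only code path would pass on
        Number matched = coerceTo(num, DType.FLOAT32);
        System.out.println(widened.getClass().getSimpleName()); // prints Double
        System.out.println(matched.getClass().getSimpleName()); // prints Float
    }
}
```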