Closed: demq closed this issue 2 years ago.
@KexinFeng Can you help take a look?
@demq
From here, it looks like `default void set(NDArray index, Number value)` only takes a boolean index. I have added the datatype check.
For your case, have you tried setting with NDIndex?
```java
start_logits.set(new NDIndex("{}", nd_bad_tokens_mask), -10000.);
```
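For example:
```java
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.index.NDIndex;
import ai.djl.ndarray.types.Shape;

// Assumes an NDManager named `manager`, e.g. NDManager.newBaseManager().
// Set by index array: rows {0, 1} (from `index`) and columns [0, 2).
NDArray original = manager.arange(1, 10).reshape(3, 3);
NDArray index = manager.create(new long[] {0, 1}, new Shape(2));
original.set(new NDIndex("{}, :{}", index, 2), 666);
NDArray expected =
        manager.create(new int[] {666, 666, 3, 666, 666, 6, 7, 8, 9}, new Shape(3, 3));
Assert.assertEquals(original, expected);
```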
This is available in version 0.18.0, which was just released.
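The `"{}, :{}"` index in that example selects the rows listed in `index` (0 and 1) and the columns `[0, 2)`. As a plain-Java model of the same assignment, assuming a row-major `int[][]` stands in for the `NDArray` (the class and method names are hypothetical, not DJL API), the following reproduces the `expected` array:

```java
import java.util.Arrays;

/** Plain-Java model of original.set(new NDIndex("{}, :{}", index, 2), value) on a 2-D array. */
final class IndexArraySetSketch {
    // Writes `value` at every (row, col) where row is listed in rowIndex and col < colStop.
    static int[][] setRowsAndLeadingCols(int[][] a, long[] rowIndex, int colStop, int value) {
        for (long r : rowIndex) {
            for (int c = 0; c < colStop; c++) {
                a[(int) r][c] = value;
            }
        }
        return a;
    }

    public static void main(String[] args) {
        // Same contents as manager.arange(1, 10).reshape(3, 3)
        int[][] original = {{1, 2, 3}, {4, 5, 6}, {7, 8, 9}};
        setRowsAndLeadingCols(original, new long[] {0, 1}, 2, 666);
        System.out.println(Arrays.deepToString(original));
    }
}
```

Each listed row is handled independently, so a repeated row index would simply rewrite the same cells with the same value.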
I have pulled the latest snapshot for 0.18.0, and the example that you have created works fine.
When I try to implement the same thing in `QATranslator::processOutput(TranslatorContext ctx, NDList list)`:
```java
NDArray index = manager.create(new boolean[128], new Shape(128));
startLogits.set(new NDIndex("{}", index), 999.f);
```
I get:
```
Caused by: ai.djl.engine.EngineException: Index put requires the source and destination dtypes match, got Float for the destination and Double for the source.
    at ai.djl.pytorch.jni.PyTorchLibrary.torchIndexAdvPut(Native Method)
    at ai.djl.pytorch.jni.JniUtils.indexAdvPut(JniUtils.java:473)
    at ai.djl.pytorch.engine.PtNDArrayIndexer.set(PtNDArrayIndexer.java:85)
    at ai.djl.ndarray.NDArray.set(NDArray.java:470)
    at ai.djl.ndarray.NDArray.set(NDArray.java:491)
```
If I use `startLogits.set(index, 999.f);` instead, everything works fine.
Does `NDArray.set(NDIndex ind, Number num)` somehow cast `num` to double?
It's good to hear that the new `set` feature with `NDIndex` solves your issue!

About the second question: the discrepancy between the two ways of calling `set` is again due to the feature update. `startLogits.set(index, 999.f);` still uses the old implementation, which calls `ai.djl.pytorch.jni.PyTorchLibrary.torchMaskedPut`. We are updating it to the new implementation too. See this PR. This way, the behaviour will be aligned with the PyTorch engine, which requires users to cast `num` to double.
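The error message above states PyTorch's rule that the source and destination dtypes of an index put must match. To make that rule concrete, here is a library-free Java model of a boolean-mask put, conceptually the operation `torchMaskedPut` performs. It is a sketch under stated assumptions, not DJL's or PyTorch's implementation: the destination is assumed to be a `float[]` (matching the Float destination in the error), and `MaskedPutSketch`/`maskedPut` are hypothetical names. The scalar is converted to the destination element type once, before writing, so a `Double` value never meets a float destination.

```java
import java.util.Arrays;

/**
 * Hypothetical plain-Java model of a boolean-mask put on a float array.
 * The scalar is converted to the destination element type (float) first,
 * mirroring the "source and destination dtypes match" rule.
 */
final class MaskedPutSketch {
    static void maskedPut(float[] dest, boolean[] mask, Number value) {
        if (mask.length != dest.length) {
            throw new IllegalArgumentException(
                    "mask length " + mask.length + " != array length " + dest.length);
        }
        // Convert once to the destination dtype: a Double 999.0 becomes float 999.0f.
        float v = value.floatValue();
        for (int i = 0; i < dest.length; i++) {
            if (mask[i]) {
                dest[i] = v;
            }
        }
    }

    public static void main(String[] args) {
        float[] startLogits = {0.1f, 0.2f, 0.3f, 0.4f};
        boolean[] mask = {false, true, false, true};
        maskedPut(startLogits, mask, -10000.0); // a Double argument is narrowed to float
        System.out.println(Arrays.toString(startLogits));
    }
}
```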
Calling `NDArray.set(NDArray index, Number value)` with `index` being an int64 array on a GPU with the PyTorch engine fails with:
The CUDA PyTorch implementation requires a boolean index. The fix is to create `index` as an NDArray of type `DataType.BOOLEAN`:
```java
boolean[] bad_tokens_mask = new boolean[tokenTypes.size()];
```
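As a sketch of how such a mask might be filled in, assuming `tokenTypes` is a `List<Integer>` and that the caller supplies the rule for what counts as a bad token (the issue does not say, so the `isBad` predicate and the `BadTokenMaskSketch` class are hypothetical):

```java
import java.util.List;
import java.util.function.IntPredicate;

/** Builds a boolean "bad token" mask with one entry per token. */
final class BadTokenMaskSketch {
    // isBad is a hypothetical, caller-supplied rule; true marks a position to overwrite.
    static boolean[] buildMask(List<Integer> tokenTypes, IntPredicate isBad) {
        boolean[] badTokensMask = new boolean[tokenTypes.size()];
        for (int i = 0; i < tokenTypes.size(); i++) {
            badTokensMask[i] = isBad.test(tokenTypes.get(i));
        }
        return badTokensMask;
    }
}
```

The resulting array can then be wrapped as a boolean `NDArray` in the same way as the `manager.create(new boolean[128], new Shape(128))` call earlier in the thread.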
This behavior is not documented, and the type check/translation to a boolean index would best be done in the `NDArray.set()` method itself.
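The suggested type check and translation could look roughly like the following plain-Java sketch. It is not DJL code: `IndexToMaskSketch`, `toMask`, and the use of `boolean[]`/`long[]` in place of boolean and int64 `NDArray`s are assumptions for illustration. A boolean index is length-checked and passed through, an int64 index is translated into an equivalent mask, and any other type is rejected.

```java
/**
 * Hypothetical type check plus translation: accept a boolean mask as-is,
 * turn an int64 index array into an equivalent mask, reject anything else.
 */
final class IndexToMaskSketch {
    static boolean[] toMask(Object index, int length) {
        if (index instanceof boolean[]) {
            boolean[] mask = (boolean[]) index;
            if (mask.length != length) {
                throw new IllegalArgumentException(
                        "mask length " + mask.length + " != array length " + length);
            }
            return mask.clone();
        }
        if (index instanceof long[]) {
            boolean[] mask = new boolean[length];
            for (long i : (long[]) index) {
                if (i < 0 || i >= length) {
                    throw new IndexOutOfBoundsException(
                            "index " + i + " out of range for length " + length);
                }
                mask[(int) i] = true; // mark each listed position
            }
            return mask;
        }
        throw new IllegalArgumentException("unsupported index type: "
                + (index == null ? "null" : index.getClass().getSimpleName()));
    }
}
```

For setting a single scalar, the mask is equivalent to the index list: duplicate or unordered indices would write the same value to the same positions either way.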