deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

can't support tensor array input #1391

Closed: yjmwolf closed this 2 years ago

yjmwolf commented 2 years ago

I found that DJL doesn't seem to support a forward signature like this:

forward(torch.ESMM self, Tensor x_cate, Tensor[] x_seq, Tensor x_numeric)

I can't express the input (Tensor x_cate, Tensor[] x_seq, Tensor x_numeric) as an NDList. An NDList is only a flat collection of NDArrays, while this case needs something like (NDArray, NDList, NDArray), and I can't find a way to build that. Could you suggest a workaround or consider an enhancement? Thank you very much.

frankfliu commented 2 years ago

PyTorch uses IValue internally. IValue is PyTorch-specific, so we don't expose it at the DJL API level. If your model has a nested structure or non-tensor data types, you have to use IValue directly; see: https://github.com/deepjavalibrary/djl/blob/master/engines/pytorch/pytorch-engine/src/test/java/ai/djl/pytorch/integration/IValueTest.java#L181-L193

For your model, you don't actually need to use IValue. If you give the NDArrays names, DJL can convert them to IValue for you; see: https://github.com/deepjavalibrary/djl/blob/master/engines/pytorch/pytorch-engine/src/main/java/ai/djl/pytorch/jni/IValueUtils.java#L76

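import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;

NDManager manager = NDManager.newBaseManager();
NDList list = new NDList();

// x_cate: passed as a single, unnamed tensor
NDArray x_cate = manager.create(...);
list.add(x_cate);

// x_seq (Tensor[]): N arrays all named "x_seq[]" are grouped into one list argument
for (int i = 0; i < N; ++i) {
    NDArray array = manager.create(...);
    array.setName("x_seq[]");
    list.add(array);
}

// x_numeric: passed as a single, unnamed tensor
NDArray x_numeric = manager.create(...);
list.add(x_numeric);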
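To make the effect of the "x_seq[]" name concrete, here is a simplified, stdlib-only model of the grouping rule described in the reply: each unnamed array becomes its own argument, and arrays sharing a name that ends in "[]" are collected, in order, into one list argument. The class GroupInputsSketch, its Input type, the groupArguments method, and the string stand-ins for tensors are all hypothetical. This is not DJL's IValueUtils code, which also handles other naming conventions that this sketch ignores.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class GroupInputsSketch {

    /** Hypothetical stand-in for an NDArray: an optional name plus a label for its data. */
    public static final class Input {
        final String name; // null for an unnamed array
        final String data; // placeholder for the tensor contents

        public Input(String name, String data) {
            this.name = name;
            this.data = data;
        }
    }

    /**
     * Simplified model of the naming rule: each unnamed input becomes its own argument,
     * and inputs sharing a name ending in "[]" are collected, in order, into one list
     * argument placed where the first of them appeared.
     */
    public static List<Object> groupArguments(List<Input> inputs) {
        List<Object> arguments = new ArrayList<>();
        Map<String, List<String>> lists = new HashMap<>();
        for (Input in : inputs) {
            if (in.name != null && in.name.endsWith("[]")) {
                List<String> group = lists.get(in.name);
                if (group == null) {
                    // First array with this name: reserve one argument slot for the whole list.
                    group = new ArrayList<>();
                    lists.put(in.name, group);
                    arguments.add(group);
                }
                group.add(in.data);
            } else {
                // In this sketch, any other array is passed as a single tensor argument.
                arguments.add(in.data);
            }
        }
        return arguments;
    }

    public static void main(String[] args) {
        List<Input> inputs = new ArrayList<>();
        inputs.add(new Input(null, "x_cate"));
        for (int i = 0; i < 3; ++i) {
            inputs.add(new Input("x_seq[]", "x_seq_" + i));
        }
        inputs.add(new Input(null, "x_numeric"));
        // prints [x_cate, [x_seq_0, x_seq_1, x_seq_2], x_numeric]
        System.out.println(groupArguments(inputs));
    }
}
```

The flat list of five arrays collapses into three arguments, matching forward(x_cate, x_seq, x_numeric), which is why a flat NDList with named arrays is enough for this model.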

yjmwolf commented 2 years ago

Thanks for the reply, it works!