deepjavalibrary / djl

An Engine-Agnostic Deep Learning Framework in Java
https://djl.ai
Apache License 2.0

How can I reorder tensors in place like `x[:,:,:,[2,0,1]]` in PyTorch? #2647

Closed: i10416 closed this issue 1 year ago

i10416 commented 1 year ago

Given `x` of shape (1, 10, 10, 3), in PyTorch I can reorder the slices along a specific axis with, for example, `x[:,:,:,[2,0,1]]` (and assign the result back to replace `x` in place). How can I achieve the same result using the DJL NDArray API?

import torch

batch_size = 1
height = 10
width = 10
channel = 3
x = torch.randn((batch_size, height, width, channel))
x
tensor([[[[ 1.2450,  0.2395, -1.4496],
          [-1.4505,  0.8022, -0.8087],
          [-0.8357, -0.5123,  1.1846],
          [-1.1332,  0.0763,  1.2089],
          [-1.0103, -2.2320,  0.1810],
          [ 0.7712,  0.5609, -0.2574],
          [ 0.3336,  0.4204, -0.4664],
          [ 1.8834,  0.3339, -1.4987],
          [-1.5052,  0.1414,  2.9350],
          [ 0.3335,  0.3214,  1.6047]],

         [[-0.3892,  0.4478, -0.4097],
          [ 1.2167, -1.5380, -0.1554],
          [-2.2246,  0.2458, -0.3464],
          [-1.2612,  0.4891, -1.4027],
          [ 1.6989, -0.1904, -1.4988],
          [ 1.2409,  0.8922,  1.4012],
....
x[:,:,:,[2,0,1]]
tensor([[[[-1.4496,  1.2450,  0.2395],
          [-0.8087, -1.4505,  0.8022],
          [ 1.1846, -0.8357, -0.5123],
          [ 1.2089, -1.1332,  0.0763],
          [ 0.1810, -1.0103, -2.2320],
          [-0.2574,  0.7712,  0.5609],
          [-0.4664,  0.3336,  0.4204],
          [-1.4987,  1.8834,  0.3339],
          [ 2.9350, -1.5052,  0.1414],
          [ 1.6047,  0.3335,  0.3214]],

         [[-0.4097, -0.3892,  0.4478],
          [-0.1554,  1.2167, -1.5380],
          [-0.3464, -2.2246,  0.2458],
          [-1.4027, -1.2612,  0.4891],
          [-1.4988,  1.6989, -0.1904],
          [ 1.4012,  1.2409,  0.8922],
          [ 0.7958,  0.1829,  0.7539],
          [-0.1230,  0.8494, -1.2449],
....
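In PyTorch, `x[:,:,:,[2,0,1]]` is a gather along the last axis: output channel `j` is input channel `perm[j]`, and because it is advanced indexing the result is a new tensor, not a view. The following is a minimal pure-Python model of that rule on nested lists (no PyTorch or DJL involved; the helper name `reorder_last_axis` is hypothetical), checked against the first pixel printed above.

```python
from typing import Sequence


def reorder_last_axis(x, perm: Sequence[int]):
    """Return a new nested list whose innermost lists are gathered by perm.

    Models x[..., perm]: out[..., j] == x[..., perm[j]]. The input is not modified.
    """
    if isinstance(x[0], list):
        # Not yet at the innermost (channel) axis: recurse into each sub-list.
        return [reorder_last_axis(sub, perm) for sub in x]
    # Innermost axis: pick channels in the order given by perm.
    return [x[p] for p in perm]


# First pixel of the tensor printed above, channels in original order.
pixel = [1.2450, 0.2395, -1.4496]
print(reorder_last_axis(pixel, [2, 0, 1]))  # [-1.4496, 1.245, 0.2395]

# A (1, 1, 2, 3) example: every pixel is reordered the same way.
x = [[[[1, 2, 3], [4, 5, 6]]]]
print(reorder_last_axis(x, [2, 0, 1]))  # [[[[3, 1, 2], [6, 4, 5]]]]
```

The printed first row matches the PyTorch output above: `[-1.4496, 1.2450, 0.2395]`.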

For now, I use the following code.

//> using scala "3.3.0"
//> using dep "ai.djl:api:0.22.1"
//> using dep "ai.djl:basicdataset:0.22.1"
//> using dep "org.slf4j:slf4j-simple:2.0.7"
//> using dep "ai.djl.pytorch:pytorch-engine:0.22.1"

import ai.djl.*
import ai.djl.ndarray.*
import ai.djl.ndarray.types.*
import ai.djl.ndarray.index.NDIndex

import scala.util.chaining.*
import scala.jdk.CollectionConverters.*

val mg = NDManager.newBaseManager()
val sample = mg.randomNormal(new Shape(1, 10, 10, 3))

// Copy each channel out first: the set() calls below overwrite sample,
// and get() may return data that still shares memory with it.
val a0 = sample.get(NDIndex(":,:,:,0"))
val a0clone = mg.create(a0.getShape())
a0.copyTo(a0clone)
val a1 = sample.get(NDIndex(":,:,:,1"))
val a1clone = mg.create(a1.getShape())
a1.copyTo(a1clone)
val a2 = sample.get(NDIndex(":,:,:,2"))
val a2clone = mg.create(a2.getShape())
a2.copyTo(a2clone)

// Write the copies back in the new order (same result as x[:,:,:,[2,0,1]]).
sample.set(NDIndex(":,:,:,0"), a2clone)
sample.set(NDIndex(":,:,:,1"), a0clone)
sample.set(NDIndex(":,:,:,2"), a1clone)
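The clones matter because the writes overlap the reads: if channel 0 is overwritten first, the values channel 1 still needs are gone (and if `get` returns a view of `sample`, which may depend on the engine, the un-cloned `a0` would change too). A minimal pure-Python model of that hazard on a single pixel, with a plain list standing in for the last axis (the function names are hypothetical, not DJL API):

```python
def reorder_naive(pixel, perm):
    # Writes in place while still reading from the same list: buggy.
    for j, p in enumerate(perm):
        pixel[j] = pixel[p]
    return pixel


def reorder_with_snapshot(pixel, perm):
    # Copy first (like a0clone/a1clone/a2clone), then write back in place.
    snapshot = list(pixel)
    for j, p in enumerate(perm):
        pixel[j] = snapshot[p]
    return pixel


print(reorder_naive([10, 20, 30], [2, 0, 1]))          # [30, 30, 30]
print(reorder_with_snapshot([10, 20, 30], [2, 0, 1]))  # [30, 10, 20]
```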
KexinFeng commented 1 year ago

We already have full support for PyTorch-style indexing: #1719 and #1755. See the demos therein.

For your use case, it can be done like this:

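import ai.djl.ndarray.NDManager
import ai.djl.ndarray.index.NDIndex
import ai.djl.ndarray.types.Shape

val mg = NDManager.newBaseManager()
val sample = mg.randomNormal(new Shape(1, 10, 10, 3))
// int64 index array holding the new channel order, as in x[:,:,:,[2,0,1]]
val indexArray = mg.create(Array(2L, 0L, 1L))
// "{}" is filled by the NDArray argument (advanced indexing on the last axis)
val newArray = sample.get(new NDIndex(":, :, :, {}", indexArray))
sample.set(new NDIndex(":, :, :, :"), newArray) // if needed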
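The trailing `set(...)` is only needed when code elsewhere still holds a reference to `sample`: `newArray` is a separate array, and writing it back over the full slice updates `sample` itself. A pure-Python analogy (lists, not NDArrays; `gather_last_axis` is a hypothetical helper) of rebinding versus writing back in place:

```python
def gather_last_axis(rows, perm):
    # Returns a new list of rows; the input is not modified.
    return [[row[p] for p in perm] for row in rows]


perm = [2, 0, 1]

# Rebinding: the name points at a new object, the alias still sees old data.
sample = [[10, 20, 30], [40, 50, 60]]
alias = sample
sample = gather_last_axis(sample, perm)
print(alias)  # [[10, 20, 30], [40, 50, 60]]

# Writing back in place (analogous to sample.set(":, :, :, :", newArray)).
sample = [[10, 20, 30], [40, 50, 60]]
alias = sample
sample[:] = gather_last_axis(sample, perm)
print(alias)  # [[30, 10, 20], [60, 40, 50]]
```

The right-hand side is fully computed before the slice assignment runs, so the in-place write never reads data it has already overwritten.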
i10416 commented 1 year ago

Ah, I didn't know about those PRs. Thanks a lot!

i10416 commented 1 year ago

By the way, is there a correspondence table between the PyTorch tensor API and the DJL NDArray API, similar to the NumPy vs. Breeze one (see https://github.com/scalanlp/breeze/wiki/Linear-Algebra-Cheat-Sheet)?

If there isn't, would it be helpful to write such a comparison as a wiki page or a document?

zachgk commented 1 year ago

@i10416 We don't have a table like that, but I could see it being useful! If you are interested in writing one, you can open a PR with a Markdown document and we can add it to our docs near http://docs.djl.ai/master/engines/pytorch/index.html