pabloferz / DLPack.jl

Julia interface for dlpack
MIT License

Check contiguity when wrapping a tensor to a Julia array? #29

Open songjhaha opened 2 years ago

songjhaha commented 2 years ago

A small example:

julia> v = rand(3, 2)
3×2 Matrix{Float64}:
 0.071368    0.486031
 0.00750569  0.53865
 0.416978    0.316323

julia> sub_v = @views v[2:3,:] # which is a StridedArray
2×2 view(::Matrix{Float64}, 2:3, :) with eltype Float64:
 0.00750569  0.53865
 0.416978    0.316323

julia> py_sub_v = DLPack.share(sub_v, torch.from_dlpack)
Python Tensor:
tensor([[0.0075, 0.4170],
        [0.5387, 0.3163]], dtype=torch.float64)

julia> DLPack.wrap(py_sub_v, torch.to_dlpack) # get wrong data because py_sub_v is not C or F contiguous
2×2 Matrix{Float64}:
 0.00750569  0.486031
 0.416978    0.53865

julia> py_sub_v.numpy().flags
Python flagsobj:
  C_CONTIGUOUS : False
  F_CONTIGUOUS : False
  OWNDATA : False
  WRITEABLE : True
  ALIGNED : True
  WRITEBACKIFCOPY : False
  UPDATEIFCOPY : False

This situation is not very common in normal usage, though.
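
For reference, a possible workaround on the caller's side until this is addressed. This is only a minimal sketch, assuming PythonCall is the Python interop layer; `wrap_checked` is a hypothetical helper, not part of DLPack.jl's API:

using PythonCall, DLPack

torch = pyimport("torch")

function wrap_checked(tensor)
    # torch.Tensor.is_contiguous() reports C-contiguity; if the tensor is not
    # C-contiguous, .contiguous() makes a densely laid-out copy before wrapping.
    # (Conservative: an F-contiguous tensor would also get copied here.)
    if !pyconvert(Bool, tensor.is_contiguous())
        tensor = tensor.contiguous()
    end
    return DLPack.wrap(tensor, torch.to_dlpack)
end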

songjhaha commented 2 years ago

@pabloferz could you have a look at this?

gdalle commented 1 week ago

@pabloferz I'm currently wondering how to combine JAX with DifferentiationInterface, and your package seems like the way to go! Should I always implement contiguity checks before every transfer, or is there an easier solution to this issue?

pabloferz commented 1 week ago

@gdalle DLPack itself does not prevent handling non-contiguous arrays (as long as they are strided). That said, JAX unfortunately does check for contiguity (at least it did some time ago; I'm not sure whether that has changed) and errors if the array isn't contiguous, so for JAX in particular I believe you need to make sure arrays are contiguous before transferring.
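
A rough sketch of what such a contiguity guard could look like on the Julia side before sharing; `iscolmajor_contiguous` and `ensure_contiguous` are just illustrative helpers, not part of DLPack.jl or DifferentiationInterface:

# An array can be shared as-is only when its strides describe a dense
# column-major layout; otherwise copy it first. This assumes `A` supports
# `strides` (i.e. it is a strided array, which DLPack requires anyway).
function iscolmajor_contiguous(A::AbstractArray)
    expected = 1
    for (s, d) in zip(strides(A), size(A))
        s == expected || return false
        expected *= d
    end
    return true
end

ensure_contiguous(A::AbstractArray) = iscolmajor_contiguous(A) ? A : collect(A)

# Hypothetical usage before a transfer to JAX (assuming PythonCall):
# jax = pyimport("jax")
# x = DLPack.share(ensure_contiguous(sub_v), jax.dlpack.from_dlpack)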