davoclavo opened this issue 1 year ago.
Quick update:
Adding `.clone()` seems to fix the issue. I wonder what other operations might require the same fix. I will open a PR soon.
I also wonder what the implications of calling `.clone()` are in terms of memory usage or other computational costs.
```diff
--- a/core/src/main/scala/torch/ops/IndexingSlicingJoiningOps.scala
+++ b/core/src/main/scala/torch/ops/IndexingSlicingJoiningOps.scala
@@ -1014,7 +1014,7 @@ private[torch] trait IndexingSlicingJoiningOps {
       case i: Int => torchNative.split(input.native, i.toLong, dim.toLong)
       case s: Seq[Int] => torchNative.split(input.native, s.map(_.toLong).toArray, dim.toLong)
     }
-    (0L until result.size()).map(i => Tensor(result.get(i)))
+    (0L until result.size()).map(i => Tensor(result.get(i)).clone())
   }
   /** Returns a tensor with all specified dimensions of `input` of size 1 removed.
```
It might have something to do with the fact that `split` returns views. From https://pytorch.org/docs/stable/generated/torch.split.html:

> Splits the tensor into chunks. Each chunk is a view of the original tensor.

It's just a guess for now, but it would explain why `clone()` makes a difference.
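To make the view/copy distinction and the memory question concrete, here is a minimal sketch in plain PyTorch (Python), not storch, reusing the sizes from the example in this issue. It assumes PyTorch 2.x for `untyped_storage()`. It shows that both chunks point into the input's storage, and that `clone()` allocates a fresh buffer sized to the chunk only.

```python
import torch

# Same shape as the example in this issue: a 1-D int64 tensor split into two chunks.
data = torch.arange(0, 1_000_000)
a, b = torch.split(data, 600_000)

# Each chunk is a view: it reads from the input's storage instead of owning a copy.
assert a.untyped_storage().data_ptr() == data.untyped_storage().data_ptr()
assert b.untyped_storage().data_ptr() == data.untyped_storage().data_ptr()

# The second chunk simply starts 600_000 elements into the shared buffer.
assert b.storage_offset() == 600_000

# clone() allocates a new buffer and copies only this chunk's elements into it.
c = a.clone()
assert c.untyped_storage().data_ptr() != data.untyped_storage().data_ptr()
assert torch.equal(c, a)

# Extra memory for the copy: number of elements times bytes per element (int64 -> 8).
extra_bytes = c.numel() * c.element_size()
print(extra_bytes)  # 4800000
```

So cloning every chunk roughly doubles the memory held for the split data while the original tensor is still alive, plus the time to copy it.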
I ran into this while implementing tensor printing, which needs to convert tensor values to buffers and crashed on non-contiguous values, because the memory layout of a view can be non-contiguous.
In this case the view should be contiguous, so it's not exactly the same issue, but it could still be related to the tensor being a view.
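A small PyTorch (Python) sketch of when a `split` view is contiguous and when it is not, assuming a row-major 2-D tensor: splitting along dim 0 keeps whole rows, while splitting along dim 1 produces chunks whose rows are separated by skipped elements.

```python
import torch

# A 3x4 tensor stored row-major: strides are (4, 1).
x = torch.arange(12).reshape(3, 4)

# Splitting along dim 0 keeps whole rows, so each chunk is still contiguous.
top, bottom = torch.split(x, 2, dim=0)
assert top.is_contiguous() and bottom.is_contiguous()

# Splitting along dim 1 keeps partial rows; consecutive rows are 4 elements apart
# in memory but each chunk only has 2 columns, so it is not contiguous.
left, right = torch.split(x, 2, dim=1)
assert not left.is_contiguous()
print(left.stride())  # (4, 1)

# contiguous() copies a non-contiguous view into a densely packed buffer ...
packed = left.contiguous()
assert packed.is_contiguous() and packed.stride() == (2, 1)

# ... and does not copy when the tensor is already contiguous.
assert top.contiguous().data_ptr() == top.data_ptr()
```

A 1-D split like the one in this issue falls in the first case, which matches the observation that the view here should already be contiguous.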
Interestingly, your example works on my machine (I tried it in Ammonite too):
```scala
object Split extends App {
  val data = torch.arange(0L, 1_000_000L)
  val Seq(a, b) = torch.split(data, 600_000)
  println(a)
  println(b)
  val x = a
  println(x)
}
```
Perhaps we also need to understand why the Python implementation calls `split_with_sizes` in certain cases (when `split_size` is not an integer). We might need to do something similar.
```python
def split(self, split_size, dim=0):
    r"""See :func:`torch.split`"""
    if has_torch_function_unary(self):
        return handle_torch_function(
            Tensor.split, (self,), self, split_size, dim=dim
        )
    if isinstance(split_size, Tensor):
        try:
            split_size = int(split_size)
        except ValueError:
            pass

    if isinstance(split_size, (int, torch.SymInt)):
        return torch._VF.split(self, split_size, dim)  # type: ignore[attr-defined]
    else:
        return torch._VF.split_with_sizes(self, split_size, dim)
```
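The branch above decides which native op computes the chunk sizes. Below is a plain-Python sketch of the sizes each branch produces, following the behavior described in the `torch.split` documentation (equal chunks with a smaller last chunk for an integer, explicit sizes for a list). `chunk_sizes` is a hypothetical helper for illustration only; it calls neither PyTorch nor storch, and only accepts positive integer sizes.

```python
from typing import List, Sequence, Union


def chunk_sizes(dim_size: int, split_size: Union[int, Sequence[int]]) -> List[int]:
    """Hypothetical helper: chunk sizes along one dim for a torch.split-style call."""
    if isinstance(split_size, int):
        # Integer case (torch._VF.split): equal chunks, the last one may be smaller.
        if split_size <= 0:
            raise ValueError("this sketch only accepts a positive split_size")
        num_splits = max(-(-dim_size // split_size), 1)  # ceiling division, at least one chunk
        sizes = [split_size] * num_splits
        sizes[-1] = dim_size - split_size * (num_splits - 1)
        return sizes
    # Sequence case (torch._VF.split_with_sizes): sizes are explicit and must cover the dim.
    sizes = list(split_size)
    if sum(sizes) != dim_size:
        raise ValueError(f"split sizes {sizes} do not add up to {dim_size}")
    return sizes


print(chunk_sizes(1_000_000, 600_000))  # [600000, 400000]
print(chunk_sizes(10, [3, 7]))          # [3, 7]
```

The Scala code in the diff already distinguishes `Int` from `Seq[Int]` but passes both to `torchNative.split`; whether that matters for this bug is still an open question in this thread.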
Hello!
I stumbled upon a fatal error while using `torch.split` and then reassigning one of the resulting tensors. I'm not sure how to even start debugging this, but I am documenting it here in case someone knows how to investigate it further. Here is a way to replicate the error.
Trying other variations: any operation performed on any of the chunks returned by `tensor.split` causes this panic, even `a + 1`.
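As a baseline for comparison, here is a sketch of the same sequence in plain PyTorch (Python) with a smaller tensor: split, reassign, then operate on a chunk. It assumes PyTorch 2.x and says nothing about storch internals; it only shows the expected semantics: an out-of-place op such as `a + 1` allocates a new result, while an in-place op writes through the view into the original tensor.

```python
import torch

data = torch.arange(0, 10)
a, b = torch.split(data, 6)
x = a  # reassignment just adds another reference to the same view

# Out-of-place op on a view: the result gets its own storage; the input is untouched.
y = x + 1
assert y.untyped_storage().data_ptr() != data.untyped_storage().data_ptr()
assert y.tolist() == [1, 2, 3, 4, 5, 6]
assert data[0].item() == 0

# In-place op on a view: it modifies the storage shared with the original tensor.
x.add_(1)
assert data[:6].tolist() == [1, 2, 3, 4, 5, 6]
assert b.tolist() == [6, 7, 8, 9]  # the other chunk is unaffected
```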