sbrunk / storch

GPU accelerated deep learning and numeric computing for Scala 3.
https://storch.dev
Apache License 2.0

Implement update/setter for Tensor #53

Closed: sbrunk closed this 1 year ago

sbrunk commented 1 year ago

Implement a setter with indexing for tensors, analogous to the index/apply method used to get values or slices out.

```scala
val t = torch.zeros(Seq(2, 2))
// set first row to ones
t(Seq(0)) = 1 // syntactic sugar for t.update(Seq(0), 1)
```
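To make the desugaring concrete, here is a minimal sketch that does not depend on storch. It uses a hypothetical 2-D Grid class backed by a flat array (not the storch Tensor API). The Scala compiler rewrites g(Seq(...)) = v into g.update(Seq(...), v), so any class with a matching update method gets this syntax. The sketch assumes that one index selects a whole row and two indices select a single element.

```scala
// Hypothetical stand-in for a 2-D tensor; not the storch Tensor API.
final class Grid(val rows: Int, val cols: Int):
  private val data = Array.fill(rows * cols)(0.0)

  // Read a single element: g(i, j)
  def apply(i: Int, j: Int): Double = data(i * cols + j)

  // The compiler rewrites `g(Seq(...)) = v` into `g.update(Seq(...), v)`.
  // One index fills a whole row; two indices set a single element.
  def update(indices: Seq[Int], value: Double): Unit = indices match
    case Seq(i)    => for j <- 0 until cols do data(i * cols + j) = value
    case Seq(i, j) => data(i * cols + j) = value
    case _ =>
      throw IllegalArgumentException(s"expected 1 or 2 indices, got ${indices.length}")

val g = Grid(2, 2)
g(Seq(0)) = 1    // same as g.update(Seq(0), 1): first row becomes ones
g(Seq(1, 1)) = 5 // same as g.update(Seq(1, 1), 5)
```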

I'm not super happy that we have to wrap the assignment indices in a Seq, but I haven't found a way to get varargs working with the update method: varargs must be the last parameter, while update has to take the assigned value as its last parameter. Suggestions for improving the syntax are appreciated.
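One possible workaround, shown here only as an illustration and not as what this issue implements, is to give up on varargs and provide fixed-arity update overloads instead. A signature such as def update(indices: Int*, value: Double) is rejected because the repeated parameter is not last, but separate overloads for one and two indices compile fine. The Grid2 class below is hypothetical and models only a 2-D case.

```scala
// Hypothetical 2-D grid illustrating fixed-arity `update` overloads.
// Not allowed: def update(indices: Int*, value: Double)
// because a repeated (varargs) parameter must come last.
final class Grid2(val rows: Int, val cols: Int):
  private val data = Array.fill(rows * cols)(0.0)

  def apply(i: Int, j: Int): Double = data(i * cols + j)

  // `g(i) = v` desugars to `g.update(i, v)`: fill row i
  def update(i: Int, value: Double): Unit =
    for j <- 0 until cols do data(i * cols + j) = value

  // `g(i, j) = v` desugars to `g.update(i, j, v)`: set one element
  def update(i: Int, j: Int, value: Double): Unit =
    data(i * cols + j) = value

val g2 = Grid2(2, 2)
g2(0) = 1.0
g2(1, 1) = 5.0
```

The trade-off is that the number of supported indices is capped by how many overloads are written, which is a poor fit for tensors of arbitrary rank.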

See also https://contributors.scala-lang.org/t/allow-update-method-accept-varargs-of-its-first-parameter/5151

sbrunk commented 1 year ago

We might be able to improve the syntax in the future, for instance with tuples or by using our own assignment operator like t(0,1) := 2. But let's merge what we have now and then iterate.
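A rough sketch of the custom assignment operator idea, assuming apply takes varargs and returns a hypothetical Selection object that carries a := method. Neither Selection nor := exists in storch; the names are illustrative only. Because the assigned value is no longer part of the indexing call, the varargs parameter can stay last.

```scala
// Hypothetical sketch of `t(0, 1) := 2`; not storch's API.
final class Grid3(val rows: Int, val cols: Int):
  private val data = Array.fill(rows * cols)(0.0)

  def get(i: Int, j: Int): Double = data(i * cols + j)

  // Varargs work here because nothing has to follow them;
  // the returned selection knows how to assign into this grid.
  def apply(indices: Int*): Selection = new Selection(indices)

  final class Selection(indices: Seq[Int]):
    // One index fills a row; two indices set a single element.
    def :=(value: Double): Unit = indices match
      case Seq(i)    => for j <- 0 until cols do data(i * cols + j) = value
      case Seq(i, j) => data(i * cols + j) = value
      case _ =>
        throw IllegalArgumentException(s"expected 1 or 2 indices, got ${indices.length}")

val grid = Grid3(2, 2)
grid(0) := 1
grid(1, 1) := 5
```

In this sketch apply returns a selection rather than values, so reading goes through a separate get method. In a tensor library, the view returned by apply could plausibly carry := itself, but that is an assumption, not something this thread settles.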