Upon further investigation, it seems we need to approach this a bit differently.
`Tensor` types

First of all, the arithmetic operations should mainly be available for references to types implementing the `Tensor` trait. It should be possible to read the data contained in tensors and perform those operations without taking ownership of it. In other words:
```rust
fn do_things<T: Tensor>(a: &T, b: &T) -> T {
    let c = a + b;
    // ...
    return c;
}
```
Whether or not it is useful to have those operations available on the (owned) `Tensor` types themselves is another question. The fact of the matter is that having `a: T` and `b: T` (as opposed to `a: &T` and `b: &T`) and doing `a + b` will consume the data contained in both. Therefore we should first make it possible to use the operators with references to `Tensor` types.
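To illustrate the consumption issue with a self-contained sketch (using a hypothetical `Scalar` type as a stand-in for a real tensor):

```rust
use std::ops::Add;

// Hypothetical stand-in for a tensor type; deliberately not `Copy`.
#[derive(Clone, Debug)]
struct Scalar(f64);

// Implementing `Add` for the owned type means the operands are moved.
impl Add for Scalar {
    type Output = Scalar;

    fn add(self, rhs: Scalar) -> Scalar {
        Scalar(self.0 + rhs.0)
    }
}

fn main() {
    let a = Scalar(1.0);
    let b = Scalar(2.0);
    let c = a + b; // `a` and `b` are moved into `add` and consumed here
    // println!("{a:?}"); // error: borrow of moved value: `a`
    println!("{c:?}");
}
```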
It is easy to enforce bounds on the trait itself via supertraits:
```rust
use std::ops::Add;

trait Tensor: Add<Output = Self> {}
```
But this is not what we want (primarily). How can we declare bounds on the reference to a `Tensor` type?
One way is by using Higher-Rank Trait Bounds (HRTB):
```rust
trait Tensor
where
    for<'a> &'a Self: Add<Output = Self>,
{}
```
Essentially, this says that for any given lifetime `'a`, the type implementing `Tensor` must implement `Add<...>` for its reference (`&Self`).
The problem is that this trait bound of `Tensor` is (for some reason) not automatically implied when the compiler reads something like this:

```rust
fn do_stuff<T: Tensor>(x: &T) {}
```

It will complain and tell us that we need to add the `where for<'a> &'a T: Add` bound ourselves.
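For illustration, this is roughly what the compiler forces us to write instead (a self-contained sketch; note the repeated higher-ranked bound on the function):

```rust
use std::ops::Add;

trait Tensor
where
    for<'a> &'a Self: Add<Output = Self>,
{}

// Rejected without the extra bound, even though `Tensor` already requires it:
// fn do_stuff<T: Tensor>(x: &T) {}

// Accepted: the `for<'a>` bound has to be spelled out again.
fn do_stuff<T: Tensor>(x: &T)
where
    for<'a> &'a T: Add<Output = T>,
{
    let _sum = x + x; // the operator is now usable on the reference
}
```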
This is annoying and redundant, and there is the already accepted RFC 2089 (implied bounds) to fix this, but it has still (after 6 years) not been implemented.
The consequence would be that we would have to add those `where for ...` bounds _everywhere_ we use the `Tensor` bound... and for every `std::ops` trait (`Mul`, `Sub`, etc.). This would look absolutely horrendous and is not an option IMO.
Fortunately, it seems that there is another feature that can help us achieve what we want: Trait aliases.
As opposed to implied bounds, this is not just an accepted RFC, but actually fully implemented, just not yet fully documented and stabilized. Consequently, it is only available in Rust's nightly channel behind the feature flag `trait_alias`.
But they seem to save us here, by allowing the following:
```rust
// Notice the equals sign as opposed to the colon:
trait Tensor = Clone + Debug + ... + where for<'a> &'a Self: Add<Output = Self>;
```
As opposed to a sub-trait (i.e. a trait with bounds) of `Clone` and `Debug` and so on, this is just an alias for the combination. It cannot be implemented itself, but we can use it as a bound for type variables like before:
```rust
fn do_stuff<T: Tensor>(x: &T) {}
```
To be able to pass a variable of some type `&T` into `do_stuff`, that type must implement `Clone`, `Debug` etc. and its reference must implement `Add`.
Unlike with the super-trait notation, the compiler will not force us to repeat any of the trait bounds here.
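Putting it all together, a minimal end-to-end sketch could look like this. It assumes the nightly `trait_alias` feature behaves as described above and uses a hypothetical `Scalar` type as a stand-in for a real tensor:

```rust
#![feature(trait_alias)] // nightly only

use std::fmt::Debug;
use std::ops::Add;

trait Tensor = Clone + Debug + where for<'a> &'a Self: Add<Output = Self>;

// Hypothetical stand-in for a real tensor type.
#[derive(Clone, Debug)]
struct Scalar(f64);

// The operator is implemented for references, so no data is consumed.
impl Add for &Scalar {
    type Output = Scalar;

    fn add(self, rhs: Self) -> Scalar {
        Scalar(self.0 + rhs.0)
    }
}

// No `where for<'a> ...` clause has to be repeated on the function.
fn do_stuff<T: Tensor>(a: &T, b: &T) -> T {
    a + b
}

fn main() {
    let (a, b) = (Scalar(1.0), Scalar(2.0));
    let c = do_stuff(&a, &b);
    println!("{a:?} + {b:?} = {c:?}"); // `a` and `b` are still usable here
}
```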
We rename our base trait from `Tensor` to `OwnedTensor` (for example). We define a bunch of trait aliases to use throughout the code wherever those traits are needed. For example:
```rust
trait OwnedTensor: Clone + ... {
    fn dot(&self, rhs: &Self) -> Self;
    // ...
}

trait TensorAdd = OwnedTensor + where for<'a> &'a Self: Add<Output = Self>;
trait TensorSub = OwnedTensor + where for<'a> &'a Self: Sub<Output = Self>;
// ...
trait TensorArithmetic = TensorAdd + TensorSub + ...;
trait Tensor = OwnedTensor + TensorArithmetic + ...;
```
This will allow us more fine-grained control over what types are accepted in what context. For example, generally speaking, it is a good idea to make the parameter types of functions as broad as possible, i.e. if all a function depends on is an `OwnedTensor` type whose reference can be added to another tensor, it is unnecessary to enforce the entire `Tensor` interface for that parameter. Instead, `TensorAdd` should be sufficient.
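As a quick sketch of that guideline (reusing the hypothetical aliases from above), a function such as the following `add_tensors` would only demand the capability it actually uses:

```rust
// Requires only `OwnedTensor` plus addition via references,
// not the full `Tensor` interface:
fn add_tensors<T: TensorAdd>(a: &T, b: &T) -> T {
    a + b
}
```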
The more complex types like `Individual` that use most or all of the tensor capabilities can still use `Tensor` (which combines all of them) as the overarching trait bound for everything they deal with.
The user will only be forced to implement things like `Add` etc. for his own tensor-like type insofar as he intends to use those of our functions that depend on those traits.
By the way, the same probably makes sense for `Deserialize`/`Serialize`. It is probably not necessary to have those as supertraits of `OwnedTensor`, since a lot of things can be done with tensors without needing (de-)serialization.
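If we go down that road, (de-)serialization could get its own alias instead, for example something along these lines (hypothetical name, assuming serde's traits and the same nightly `trait_alias` feature):

```rust
use serde::{de::DeserializeOwned, Serialize};

// Hypothetical: only code that actually persists or loads tensors asks for this.
// `DeserializeOwned` is used to avoid the lifetime parameter of `Deserialize<'de>`.
trait SerializableTensor = OwnedTensor + Serialize + DeserializeOwned;
```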
Background

To overload basic mathematical operators like `+` or `*`, Rust has special traits in its standard library, such as `Add` and `Mul`.

Motivation

Having access to those operators when dealing with types that have the `Tensor` trait would be very convenient and make a lot of code much more readable.

ToDo

The obvious way forward would have been to try and add those operator traits as trait bounds to `Tensor` (and of course implement them accordingly for `NDTensor`). Instead, the way forward is to rely on trait aliases that combine the basic tensor interface with those operator traits that make sense for our purposes. We would then have to implement all of those operator traits for `NDTensor`.
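To make that last step concrete, implementing one of the operator traits for references to `NDTensor` would presumably look roughly like the sketch below. The field layout (a flat `Vec<f64>`) and the element-wise addition are assumptions made purely for illustration; the real `NDTensor` will differ:

```rust
use std::ops::Add;

// Illustrative stand-in only; the actual `NDTensor` is defined elsewhere.
struct NDTensor {
    data: Vec<f64>,
}

impl Add for &NDTensor {
    type Output = NDTensor;

    fn add(self, rhs: Self) -> NDTensor {
        // Element-wise addition without consuming either operand.
        let data = self
            .data
            .iter()
            .zip(&rhs.data)
            .map(|(x, y)| x + y)
            .collect();
        NDTensor { data }
    }
}
```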