mfajnberg / tensorevo

Define trait aliases combining `Tensor` with operator overloading #5

Closed daniil-berg closed 11 months ago

daniil-berg commented 1 year ago

Background

To overload basic mathematical operators like + or *, Rust has special traits in its standard library, such as Add and Mul.
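As a minimal sketch of how this overloading works, the following implements Add for a small made-up type. The Point struct and the add_points function are purely illustrative and not part of tensorevo.

```rust
use std::ops::Add;

// Hypothetical example type; not part of tensorevo.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Point {
    x: f64,
    y: f64,
}

// Implementing `Add` is what makes `p + q` compile for `Point` values.
impl Add for Point {
    type Output = Point;

    fn add(self, rhs: Point) -> Point {
        Point {
            x: self.x + rhs.x,
            y: self.y + rhs.y,
        }
    }
}

// `p + q` is sugar for `Add::add(p, q)`.
fn add_points(p: Point, q: Point) -> Point {
    p + q
}
```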

Motivation

Having access to those operators when dealing with types that have the Tensor trait would be very convenient and make a lot of code much more readable.

ToDo

The obvious approach would be to add those operator traits as trait bounds to Tensor (and of course implement them accordingly for NDTensor), but as the follow-up comment explains, this does not cover references to tensors.

The way forward is to rely on trait aliases that combine the basic tensor interface with those operator traits that make sense for our purposes. We would then have to implement all of those operator traits for NDTensor.

daniil-berg commented 11 months ago

Upon further investigation, it seems we need to approach this a bit differently.

Arithmetic operations primarily for borrowed Tensor types

First of all, the arithmetic operations should mainly be available for references to types implementing the Tensor trait. It should be possible to read the data contained in tensors to do those operations without taking ownership of it. In other words:

fn do_things<T: Tensor>(a: &T, b: &T) -> T {
    let c = a + b;
    ...
    return c;
}

Whether or not it is useful to have those operations available on the (owned) Tensor types themselves is another question. The fact of the matter is that having a: T and b: T (as opposed to a: &T and b: &T) and doing a + b will consume the data contained in both.

Therefore we should first make it possible to use the operators with references to Tensor types.
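To illustrate the difference between the two forms, here is a small sketch with a made-up Grid type standing in for a tensor (it is not NDTensor, and the element-wise addition simply assumes both operands have the same length). Add is implemented once for owned values and once for references.

```rust
use std::ops::Add;

// Hypothetical stand-in for a tensor type; it owns heap data, so it is not `Copy`.
#[derive(Debug, Clone, PartialEq)]
struct Grid(Vec<f64>);

// Owned + owned: both operands are moved into `add`.
impl Add for Grid {
    type Output = Grid;

    fn add(self, rhs: Grid) -> Grid {
        Grid(self.0.iter().zip(&rhs.0).map(|(l, r)| l + r).collect())
    }
}

// Reference + reference: the operands are only read.
impl<'a> Add for &'a Grid {
    type Output = Grid;

    fn add(self, rhs: &'a Grid) -> Grid {
        Grid(self.0.iter().zip(&rhs.0).map(|(l, r)| l + r).collect())
    }
}

fn demo() -> (Grid, Grid) {
    let a = Grid(vec![1.0, 2.0]);
    let b = Grid(vec![3.0, 4.0]);
    let by_ref = &a + &b; // `a` and `b` remain usable afterwards
    let by_value = a + b; // `a` and `b` are moved here
    // Using `a` or `b` past this point would be a compile error.
    (by_ref, by_value)
}
```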

Defining bounds for the reference type is hard

It is easy to enforce bounds on the trait itself via supertraits:

use std::ops::Add;

trait Tensor: Add<Output = Self> {}

But this is not what we want (primarily). How can we declare bounds on the reference to a Tensor type?

One way is by using Higher-Rank Trait Bounds (HRTB):

trait Tensor
where for<'a> &'a Self: Add<Output = Self>
{}

Essentially saying that for any given lifetime 'a the type implementing Tensor must implement Add<...> for its reference (&Self).

The problem is that this bound on Tensor is not automatically implied (only supertrait bounds on Self itself are) when the compiler reads something like this:

fn do_stuff<T: Tensor>(x: &T) {}

It will complain and tell us that we need to add the where for<'a> &'a T: Add bound to the function as well.

This is annoying and redundant, and the already accepted RFC 2089 (implied bounds) is meant to fix it. But it has still (after 6 years) not been implemented.

The consequence would be that we would have to add those where for ... bounds everywhere we use the Tensor bound, and for every std::ops trait (Mul, Sub, etc.). This would look absolutely horrendous and is not an option IMO.
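For reference, here is a sketch on stable Rust of what that repetition looks like. The Scalar type and the add_refs function are made up for the example, and the Sized supertrait is added here only to keep the sketch self-contained.

```rust
use std::ops::Add;

// The HRTB lives in a `where` clause, so it is checked for implementors
// but not implied at use sites.
trait Tensor: Sized
where
    for<'a> &'a Self: Add<Output = Self>,
{
}

// Hypothetical implementor; not part of tensorevo.
#[derive(Debug, Clone, PartialEq)]
struct Scalar(f64);

impl<'a> Add for &'a Scalar {
    type Output = Scalar;

    fn add(self, rhs: &'a Scalar) -> Scalar {
        Scalar(self.0 + rhs.0)
    }
}

impl Tensor for Scalar {}

// Without the repeated `where` clause below, the compiler rejects the
// `T: Tensor` bound on this function.
fn add_refs<T: Tensor>(a: &T, b: &T) -> T
where
    for<'a> &'a T: Add<Output = T>,
{
    a + b
}
```

Every additional operator (Sub, Mul, ...) would add another line to that where clause on every generic function.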

Trait aliases

Fortunately, it seems that there is another feature that can help us achieve what we want: Trait aliases.

As opposed to implied bounds, this is not just an accepted RFC, but actually implemented, just not yet fully documented and stabilized. Consequently, it is only available in Rust's nightly channel behind the feature flag trait_alias.

But they seem to save us here, by allowing the following:

// Notice the equals sign as opposed to the colon.
// The `where` clause follows the bound list directly, without a `+` before it:
trait Tensor = Clone + Debug + ... where for<'a> &'a Self: Add<Output = Self>;

As opposed to a sub-trait (i.e. trait with bounds) of Clone and Debug and so on, this is just an alias for the combination. It cannot be implemented itself, but we can use it as a bound for type variables like before:

fn do_stuff<T: Tensor>(x: &T) {}

To be able to pass a variable of some type &T into do_stuff, that type must implement Clone, Debug etc. and its reference must implement Add.

Unlike with the where clause on a regular trait shown earlier, the compiler will not force us to repeat any of the trait bounds here.

Proposed interface change

We rename our base trait from Tensor to OwnedTensor (for example).

We define a bunch of trait aliases to use throughout the code wherever those traits are needed. For example:

trait OwnedTensor: Clone + ... {
    fn dot(&self, rhs: &Self) -> Self;

    ...
}

trait TensorAdd = OwnedTensor where for<'a> &'a Self: Add<Output = Self>;
trait TensorSub = OwnedTensor where for<'a> &'a Self: Sub<Output = Self>;
...
trait TensorArithmetic = TensorAdd + TensorSub + ...;

trait Tensor = OwnedTensor + TensorArithmetic + ...;

This will allow us more fine-grained control over what types are accepted in what context. For example, generally speaking, it is a good idea to make parameter types of functions as broad as possible, i.e. if all a function depends on is an OwnedTensor type whose reference can be added to another tensor reference, it is unnecessary to enforce the entire Tensor interface for that parameter. Instead TensorAdd should be sufficient.
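A rough sketch of that idea, written for stable Rust: since trait aliases need nightly, the bound that TensorAdd would stand for is spelled out by hand here. The empty OwnedTensor stand-in, the sum_refs function and the AddOnly type are all hypothetical.

```rust
use std::ops::Add;

// Stand-in for the proposed base trait; the real one would have methods like `dot`.
trait OwnedTensor: Clone {}

// Needs only "reference + reference". Under the proposal, the whole
// `where` clause would shrink to `T: TensorAdd`.
fn sum_refs<T>(items: &[T]) -> Option<T>
where
    T: OwnedTensor,
    for<'a> &'a T: Add<Output = T>,
{
    let (first, rest) = items.split_first()?;
    Some(rest.iter().fold(first.clone(), |acc, item| &acc + item))
}

// Hypothetical type that supports addition but not subtraction.
#[derive(Debug, Clone, PartialEq)]
struct AddOnly(i64);

impl OwnedTensor for AddOnly {}

impl<'a> Add for &'a AddOnly {
    type Output = AddOnly;

    fn add(self, rhs: &'a AddOnly) -> AddOnly {
        AddOnly(self.0 + rhs.0)
    }
}
```

AddOnly never implements Sub, yet it is accepted by sum_refs, because the function only asks for what it actually uses.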

The more complex types like Individual that use most or all of the tensor capabilities can still use Tensor (which combines all of them) as the overarching trait bound for everything they deal with.

Users will only be forced to implement things like Add etc. for their own tensor-like types insofar as they intend to use those of our functions that depend on those traits.

By the way, the same probably makes sense for Deserialize/Serialize. It is probably not necessary to have those as supertraits of OwnedTensor, since a lot can be done with tensors without needing (de-)serialization.