tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0

RFC: Adding Forward Trait #667

Closed ArvidHammarlund closed 1 year ago

ArvidHammarlund commented 1 year ago

RFC: Adding Forward Trait

To make all modules of the Burn crate implement a Forward trait.

Feature motivation

This would allow a more functional, and thus more "rusty", implementation of the forward pass of user-defined modules, akin to the following:


let seq: [&dyn Forward<B, D>; 4] = [
      &self.fc1,
      &self.dropout,
      &self.activation,
      &self.fc2,
];

let res = seq.into_iter().fold(input.clone(), |res, x| x.forward(res));
res + input

A downside is that it imposes a more sequential, and hence more restrictive, framework. However, I believe it would encourage more modular code: parts that, for example, use a residual connection across several layers or require reshaping could be refactored into their own modules, since such operations can easily be done at or between the beginning and end of the forward pass without being hindered by the suggested framework.

In the end, I think the sequential nature and the increased modularisation would make the forward passes in the Burn ecosystem more concise and quicker for others to understand.

(Optional) Suggest a Solution

To prevent breaking changes, the Forward trait could simply call the existing forward method until the next major release.
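
As a very rough sketch of that shim (hypothetical code, assuming nn::Linear keeps its current inherent forward method and that the trait carries the backend and dimension generics):

use burn::nn;
use burn::tensor::{backend::Backend, Tensor};

// Hypothetical trait shape for the proposal.
pub trait Forward<B: Backend, const D: usize> {
    fn forward(&self, input: Tensor<B, D>) -> Tensor<B, D>;
}

// The trait impl simply delegates to the existing inherent method,
// so current call sites keep working until the next major release.
impl<B: Backend> Forward<B, 2> for nn::Linear<B> {
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        // The fully qualified path resolves to the existing inherent forward.
        nn::Linear::forward(self, input)
    }
}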

antimora commented 1 year ago

Just thinking out loud:

Just a side note: we need to check whether modules can work with dyn, which requires the trait to be object-safe.

Question: How is this different from the Module Trait defined here, which is also used by every module? link

Also, some forward methods take multiple inputs, and even if the input is single, the types are different (e.g., Tensor<B, 3>, Tensor<B, 2>). So to make it work universally, there should be two additional traits (ForwardArgs and ForwardOutputs), and the users need to implement the conversion.
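
Purely as an illustration of that last point (hypothetical, not an actual Burn API), a universal trait would end up looking roughly like this, with the ForwardArgs/ForwardOutputs conversions left to the user:

use burn::tensor::backend::Backend;

// Sketch only: the argument and output types become part of the trait.
pub trait Forward<B: Backend> {
    // e.g. Args = (Tensor<B, 3>, Tensor<B, 2>) for a module with two inputs,
    // which is where ForwardArgs/ForwardOutputs conversion traits would come in.
    type Args;
    type Output;

    fn forward(&self, args: Self::Args) -> Self::Output;
}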

ArvidHammarlund commented 1 year ago

Regarding the question: the Module trait has several methods that return Self, so, as far as I understand, it wouldn't be possible to make it object safe; and creating a separate trait for just the forward method would avoid forcing all Modules to have a forward method - I'm not sure if that is already the case, though.

Regarding the side note: I think a trait can be made object safe even though it returns objects with generic types, like Tensor<_>, if the generic parameters are set on the trait signature rather than on the method signature; a downside to this is that the iterator chain can only handle Tensors with a fixed number of dimensions.
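
A minimal sketch of what I mean (reusing the hypothetical trait shape from the RFC, with the generics on the trait itself):

use burn::tensor::{backend::Backend, Tensor};

// No generic methods and Self only appears as the receiver, so this is object safe.
pub trait Forward<B: Backend, const D: usize> {
    fn forward(&self, input: Tensor<B, D>) -> Tensor<B, D>;
}

// Folding over trait objects works, but every layer in the chain is pinned to D = 2.
fn forward_chain<B: Backend>(layers: &[&dyn Forward<B, 2>], input: Tensor<B, 2>) -> Tensor<B, 2> {
    layers.iter().fold(input, |x, layer| layer.forward(x))
}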

Additional thoughts: adopting dynamic dispatch could cause a performance drop, since it prevents some compiler optimisations - though I have no knowledge of how applicable that is to Burn code, or of how the GPU architecture affects the issue - but I don't see the downside of letting some users opt in to it, if they so wish, to enable subjectively (it might just be me who prefers the functional style) more succinct code.

ArvidHammarlund commented 1 year ago

> Also, some forward methods take multiple inputs, and even if the input is single, the types are different (e.g., Tensor<B, 3>, Tensor<B, 2>). So to make it work universally, there should be two additional traits (ForwardArgs and ForwardOutputs), and the users need to implement the conversion.

I think you could make a separation between the Modules that are base building blocks, like Linear, GELU, or LayerNorm, and the larger architectures, like Transformer or LSTM, when it comes to forward methods that take more than one input argument; a Forward trait would mainly be adopted for the base building blocks, and would thus make the larger architectures easier to understand and to create on your own. (Embeddings would remain problematic, though, as they require Int rather than Float tensors.)

Regarding Tensors with different numbers of dimensions, my initial thought was that there would be one iterator chain for, say, the convolution layers and another for a possible fully connected part, so that this wouldn't be a problem - or at least not frequent enough to stop the user from choosing one main dimension count for the majority of layers. If one layer required something different, the change of dimensions could be confined to that layer's forward pass only, i.e. convert the input immediately and go back to the original number of dimensions before returning; a rough illustration of that follows below.
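
As a rough illustration of confining the dimension change to one layer's forward pass (hypothetical wrapper module, assuming the usual dims()/reshape() tensor methods):

use burn::module::Module;
use burn::nn;
use burn::tensor::{backend::Backend, Tensor};

// The 3D -> 2D -> 3D conversion is confined to this module's forward pass,
// so an outer sequential chain can stay at a single dimension count.
#[derive(Module, Debug)]
pub struct FlattenedLinear<B: Backend> {
    linear: nn::Linear<B>,
}

impl<B: Backend> FlattenedLinear<B> {
    pub fn forward(&self, input: Tensor<B, 3>) -> Tensor<B, 3> {
        let [batch, seq, features] = input.dims();
        // Convert the input immediately...
        let x = self.linear.forward(input.reshape([batch * seq, features]));
        // ...and go back to the original number of dimensions before returning.
        let [_, d_out] = x.dims();
        x.reshape([batch, seq, d_out])
    }
}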

Another possible solution would be to use Tensors with a high dimension count and set several of the dimensions to 1 when lower-dimensional Tensors are required; there might be an overhead to this, though, and it could prevent the type system from catching bugs and mistakes.

nathanielsimard commented 1 year ago

I don't think we should add a Forward trait to Burn. I have thought deeply about it, and this will not be flexible enough to enable all network architectures. Modules can have multiple forward passes with any arguments and do anything, really. That's the point of dynamic graphs.

I acknowledge that we may want to provide a way of implementing sequential modules similar to Keras: https://keras.io/guides/sequential_model/. I don't think we should force the implementation of a trait on all neural network building blocks to provide this API. This is mostly about removing verbosity, so we could use codegen (macros, attributes) to perform such tasks.

#[burn::sequential(input = "Tensor<B, 4>", output = "Tensor<B, 4>")]
pub struct MySequentialModule<B: Backend> {
    fc1: nn::Linear<B>,
    dropout: nn::Dropout<B>,
    fc2: nn::Linear<B>,
    relu: nn::ReLU,
}

This is just a rough idea of an API that would generate the correct forward method. However, I'm still unsure if this is a desirable API since it would create different ways of creating modules in the library, potentially making it harder to understand and harder to debug.

You are free to add any trait/codegen system on top of Burn modules in your crates. By testing abstractions, we could potentially adopt one that helps without losing flexibility.

ArvidHammarlund commented 1 year ago

I wasn't thinking about procedural macros, but that does indeed seem to have potential for a higher level of abstraction.

As a side note, I've come up with a workaround for those who read this and want a sequential-ish forward pass:

// Assuming `self.linear` is a collection (e.g. a Vec) of Linear layers:
// each step applies linear -> dropout -> activation, threading the result along.
let x = self.linear.iter().fold(input, |res, e| {
    let res = e.forward(res);
    let res = self.dropout.forward(res);
    self.activation.forward(res)
});