tracel-ai / burn

Burn is a new comprehensive dynamic Deep Learning Framework built using Rust with extreme flexibility, compute efficiency and portability as its primary goals.
https://burn.dev
Apache License 2.0
8.66k stars, 430 forks

GRU: Gated Recurrent Unit #314

Closed. nathanielsimard closed this issue 1 year ago.

nathanielsimard commented 1 year ago

Feature description

GRU Module

Feature motivation

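GRUs are still useful in some applications, for example sequence models where a recurrent cell with fewer parameters than an LSTM is preferred.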

agelas commented 1 year ago

@nathanielsimard @antimora You can assign this issue to me; this should be relatively straightforward since GRUs are a simplification of LSTMs, and I can pound out this issue when I have time.
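For reference, the standard GRU cell equations make the "simplified LSTM" point concrete: three gate-style transforms instead of the LSTM's four, and no separate cell state. They are written here in the convention of PyTorch's `torch.nn.GRU` documentation, with $\sigma$ the logistic sigmoid and $\odot$ elementwise multiplication; how Burn lays out the weights and biases is not specified in this thread.

```latex
\begin{aligned}
r_t &= \sigma\left(W_{ir} x_t + b_{ir} + W_{hr} h_{t-1} + b_{hr}\right) \\
z_t &= \sigma\left(W_{iz} x_t + b_{iz} + W_{hz} h_{t-1} + b_{hz}\right) \\
n_t &= \tanh\left(W_{in} x_t + b_{in} + r_t \odot \left(W_{hn} h_{t-1} + b_{hn}\right)\right) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{t-1}
\end{aligned}
```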

Do you want this in a new folder or can I just put this inside the lstm folder? The GateController can be re-used for GRUs, and like I said before GRUs are technically a simplified variation of an LSTM, so I don't think users would be too surprised to find this under the lstm module.
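The reuse argument can be sketched in plain Rust. `GateController` below is a hypothetical stand-in named after the type mentioned in the thread, not Burn's actual API; it only assumes that a gate is an input transform plus a hidden-state transform. Because the GRU candidate state applies the reset gate to the hidden part alone, this sketch exposes the two transforms separately as well as their sum.

```rust
/// Hypothetical gate controller (illustrative, not Burn's type): one affine
/// transform of the input and one of the previous hidden state.
struct GateController {
    w_x: Vec<Vec<f32>>, // shape [d_hidden][d_input]
    b_x: Vec<f32>,      // shape [d_hidden]
    w_h: Vec<Vec<f32>>, // shape [d_hidden][d_hidden]
    b_h: Vec<f32>,      // shape [d_hidden]
}

/// Computes w * v + b for a row-major weight matrix.
fn affine(w: &[Vec<f32>], b: &[f32], v: &[f32]) -> Vec<f32> {
    w.iter()
        .zip(b)
        .map(|(row, bias)| row.iter().zip(v).map(|(a, x)| *a * *x).sum::<f32>() + *bias)
        .collect()
}

impl GateController {
    fn input_part(&self, x: &[f32]) -> Vec<f32> {
        affine(&self.w_x, &self.b_x, x)
    }

    fn hidden_part(&self, h: &[f32]) -> Vec<f32> {
        affine(&self.w_h, &self.b_h, h)
    }

    /// W_x x + b_x + W_h h + b_h: the pre-activation shared by LSTM and GRU gates.
    fn gate_product(&self, x: &[f32], h: &[f32]) -> Vec<f32> {
        let input = self.input_part(x);
        let hidden = self.hidden_part(h);
        input.iter().zip(&hidden).map(|(a, b)| *a + *b).collect()
    }
}

fn sigmoid(v: f32) -> f32 {
    1.0 / (1.0 + (-v).exp())
}

/// Hypothetical GRU cell built from three gate controllers.
struct GruCell {
    reset: GateController,
    update: GateController,
    candidate: GateController,
}

impl GruCell {
    /// One time step: returns h_t given input x_t and previous state h_{t-1}.
    fn step(&self, x: &[f32], h: &[f32]) -> Vec<f32> {
        let r: Vec<f32> = self.reset.gate_product(x, h).into_iter().map(sigmoid).collect();
        let z: Vec<f32> = self.update.gate_product(x, h).into_iter().map(sigmoid).collect();
        // The reset gate scales only the hidden part of the candidate.
        let nx = self.candidate.input_part(x);
        let nh = self.candidate.hidden_part(h);
        let n: Vec<f32> = nx
            .iter()
            .zip(&nh)
            .zip(&r)
            .map(|((a, b), ri)| (*a + *ri * *b).tanh())
            .collect();
        // Interpolate between the candidate and the previous state.
        n.iter()
            .zip(&z)
            .zip(h)
            .map(|((ni, zi), hi)| (1.0 - *zi) * *ni + *zi * *hi)
            .collect()
    }
}

/// All-zero gate, handy for checking the arithmetic by hand.
fn zero_gate(d_input: usize, d_hidden: usize) -> GateController {
    GateController {
        w_x: vec![vec![0.0; d_input]; d_hidden],
        b_x: vec![0.0; d_hidden],
        w_h: vec![vec![0.0; d_hidden]; d_hidden],
        b_h: vec![0.0; d_hidden],
    }
}

fn main() {
    let cell = GruCell {
        reset: zero_gate(3, 2),
        update: zero_gate(3, 2),
        candidate: zero_gate(3, 2),
    };
    // With zero weights: r = z = 0.5 and n = tanh(0) = 0, so h_t = 0.5 * h_{t-1}.
    let h1 = cell.step(&[1.0, 2.0, 3.0], &[2.0, -4.0]);
    println!("{:?}", h1); // [1.0, -2.0]
}
```

An LSTM only ever needs the summed pre-activation, so keeping `input_part` and `hidden_part` available separately is the one extra requirement a GRU would put on a shared controller in this sketch.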

antimora commented 1 year ago

@agelas I assigned it to you per your request. @nathanielsimard can comment on the design.

nathanielsimard commented 1 year ago

@agelas I think you can rename the folder to rnn and put the file inside. We might also change the visibility of the module:

pub mod rnn;

instead of:

mod rnn;
pub use rnn::*;

That way it matches how the transformer and convolution layers are exported.
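The difference between the two export styles is easiest to see in a single file. This sketch uses inline modules as stand-ins; `nn_pub_mod`, `nn_glob_reexport`, and the empty `Lstm` struct are hypothetical names, not Burn's code. It shows which import path each style produces.

```rust
// Style 1: `pub mod rnn;` keeps `rnn` as part of the public path.
mod nn_pub_mod {
    pub mod rnn {
        pub struct Lstm;
    }
}

// Style 2: `mod rnn; pub use rnn::*;` hides `rnn` and lifts its items up one level.
mod nn_glob_reexport {
    mod rnn {
        pub struct Lstm;
    }
    pub use rnn::*;
}

// Style 1 path includes the submodule name.
use nn_pub_mod::rnn::Lstm as LstmA;
// Style 2 path skips it; `nn_glob_reexport::rnn` itself is private.
use nn_glob_reexport::Lstm as LstmB;

fn main() {
    let _a = LstmA;
    let _b = LstmB;
    println!("both import paths resolve");
}
```

With the second style, `rnn` never appears in user-facing paths, so every recurrent layer would sit directly under `nn`.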

agelas commented 1 year ago

@nathanielsimard I renamed the folder to rnn. As for mod.rs, it looks like the transformer and convolution layers are exported using the latter method, i.e. transformer's mod.rs looks like:

mod decoder;
mod encoder;
mod pwff;

pub use decoder::*;
pub use encoder::*;
pub use pwff::*;

nathanielsimard commented 1 year ago

I think we can re-export all RNN implementations to avoid long imports, similar to the conv and transformer modules.

burn-core/src/nn/rnn/mod.rs

mod lstm;
mod gru;

pub use lstm::*;
pub use gru::*;

So users can import it like this:

use burn::nn::rnn::Gru;
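Putting the two proposals together, a hypothetical single-file mirror of the layout shows that `pub mod rnn;` combined with glob re-exports inside `rnn` yields the short path suggested above. The module name `burn_sketch` and the placeholder `d_hidden` fields are illustrative only, not Burn's real definitions.

```rust
// Hypothetical mirror of the proposed layout:
// nn/mod.rs declares `pub mod rnn;`, and rnn/mod.rs keeps `lstm` and `gru`
// private while glob re-exporting their public items.
mod burn_sketch {
    pub mod nn {
        pub mod rnn {
            mod lstm {
                /// Placeholder standing in for the existing LSTM module.
                pub struct Lstm {
                    pub d_hidden: usize,
                }
            }
            mod gru {
                /// Placeholder standing in for the GRU module proposed in this issue.
                pub struct Gru {
                    pub d_hidden: usize,
                }
            }
            pub use gru::*;
            pub use lstm::*;
        }
    }
}

// One short path per layer, without naming the private `gru`/`lstm` files.
use burn_sketch::nn::rnn::{Gru, Lstm};

fn main() {
    let gru = Gru { d_hidden: 32 };
    let lstm = Lstm { d_hidden: 64 };
    println!("gru: {}, lstm: {}", gru.d_hidden, lstm.d_hidden);
}
```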
antimora commented 1 year ago

Closed via https://github.com/burn-rs/burn/pull/393