huggingface / candle

Minimalist ML framework for Rust
Apache License 2.0

Bidirectional LSTM? #1040

Open clarkmcc opened 9 months ago

clarkmcc commented 9 months ago

I'm in the process of porting some TensorFlow models over to Candle, and my model has a bidirectional LSTM layer. Is this something that Candle can or will support in the future?

tf.keras.layers.Bidirectional(
    tf.keras.layers.LSTM(
        self.parameters.lstm_units,
        return_sequences=True)
)
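For context on what the wrapper does: Keras's `Bidirectional` runs the wrapped LSTM once left-to-right and once right-to-left over the time axis and, by default, concatenates the two outputs at each timestep, so with `return_sequences=True` the per-step feature size is `2 * lstm_units`. A minimal pure-Rust sketch of that wiring, with a toy recurrence `s' = 0.5 * s + x` standing in for the LSTM cell (no candle involved, just the data flow):

```rust
// Toy stand-in for an RNN cell: scan the sequence, carrying a scalar state.
fn scan(xs: &[f32]) -> Vec<f32> {
    let mut s = 0.0f32;
    xs.iter()
        .map(|&x| {
            s = 0.5 * s + x;
            s
        })
        .collect()
}

fn main() {
    let xs = [3.0f32, 1.0, 4.0, 1.0];

    // Forward pass over the sequence as-is.
    let fwd = scan(&xs);

    // Backward pass over the reversed sequence; reverse its outputs so that
    // index t again refers to timestep t of the original sequence.
    let rev_input: Vec<f32> = xs.iter().rev().copied().collect();
    let mut bwd = scan(&rev_input);
    bwd.reverse();

    // Concatenate per timestep: [fwd_t, bwd_t]. This is why the feature
    // dimension doubles in the bidirectional case.
    let merged: Vec<[f32; 2]> = fwd.iter().zip(bwd.iter()).map(|(&f, &b)| [f, b]).collect();
    assert_eq!(merged.len(), xs.len());
    println!("{merged:?}");
}
```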
LaurentMazare commented 9 months ago

We could add support for this, but it seems pretty easy to do in user code with the current version, e.g. with something like:

    // Setup assumed by this snippet: a device and a fresh VarBuilder.
    // Uses candle::{DType, Device, Tensor} and candle_nn::{VarBuilder, VarMap};
    // the candle_nn::RNN trait must be in scope for zero_state/step below.
    let cpu = &Device::Cpu;
    let varmap = VarMap::new();
    let vb = VarBuilder::from_varmap(&varmap, DType::F32, cpu);

    let inp_sequence = [3f32, 1., 4., 1., 5., 9., 2.];

    // Create one LSTM for the forward pass and one for the backward pass,
    // each mapping 2 input features to 3 hidden units
    let lstm = candle_nn::lstm(2, 3, Default::default(), vb.pp("forward"))?;
    let lstm_rev = candle_nn::lstm(2, 3, Default::default(), vb.pp("backward"))?;

    // Apply the forward lstm and collect the per-step results in states
    let mut states = vec![lstm.zero_state(1)?];
    for &inp in inp_sequence.iter() {
        let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?;
        let state = lstm.step(&inp, states.last().unwrap())?;
        states.push(state)
    }

    // Apply the backward lstm to the reversed sequence and collect the
    // per-step results in states_rev
    let mut states_rev = vec![lstm_rev.zero_state(1)?];
    for &inp in inp_sequence.iter().rev() {
        let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?;
        let state = lstm_rev.step(&inp, states_rev.last().unwrap())?;
        states_rev.push(state)
    }

    // Merge the two directions: reversing states_rev pairs each forward
    // state with the backward state for the same position in the sequence
    let states = states
        .into_iter()
        .zip(states_rev.into_iter().rev())
        .collect::<Vec<_>>();
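One way to convince yourself the final `zip` lines up correctly: the forward state after `t` steps has consumed `inputs[0..t]`, and reversing `states_rev` pairs it with the backward state that has consumed `inputs[t..]`, so each pair together covers the whole sequence. A small sketch that models each state as the list of input indices it has consumed (plain Rust, no candle):

```rust
fn main() {
    let n = 7; // length of the input sequence, as in the snippet above

    // states[t] = indices consumed by the forward pass after t steps
    let states: Vec<Vec<usize>> = (0..=n).map(|t| (0..t).collect()).collect();
    // states_rev[t] = indices consumed by the backward pass after t steps
    let states_rev: Vec<Vec<usize>> = (0..=n).map(|t| (n - t..n).collect()).collect();

    // Same merge as in the lstm snippet.
    let pairs: Vec<(Vec<usize>, Vec<usize>)> =
        states.into_iter().zip(states_rev.into_iter().rev()).collect();

    for (t, (f, b)) in pairs.iter().enumerate() {
        // Forward has seen 0..t, backward has seen t..n: a clean split at t.
        assert_eq!(f.as_slice(), (0..t).collect::<Vec<_>>().as_slice());
        assert_eq!(b.as_slice(), (t..n).collect::<Vec<_>>().as_slice());
    }
    println!("each of the {} pairs covers the full sequence", pairs.len());
}
```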
clarkmcc commented 9 months ago

Thanks, I can follow your example for now! I mainly suggested it because I don't know enough about the under-the-hood mechanics to re-implement it on my own.