Issue (status: Open) — opened by clarkmcc 9 months ago
We could add support for this but it seems pretty easy to do it in the user code with the current version, e.g. with something like:
// Example: emulating a bidirectional LSTM with two unidirectional LSTMs,
// one run over the input sequence forwards and one over it in reverse.
let inp_sequence = [3f32, 1., 4., 1., 5., 9., 2.];
// Create one lstm for the forward pass and one for the backward pass
// (input size 2, hidden size 3; separate weights via the var-builder prefixes).
let lstm = candle_nn::lstm(2, 3, Default::default(), vb.pp("forward"))?;
let lstm_rev = candle_nn::lstm(2, 3, Default::default(), vb.pp("backward"))?;
// Apply the forward lstm and collect the results in states.
// states[0] is the zero state; states[i] is the state after i inputs.
let mut states = vec![lstm.zero_state(1)?];
for &inp in inp_sequence.iter() {
    // Build a (1, 2) input tensor: [value, value * 0.5].
    let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?;
    // last() already yields a reference, so no extra `&` is needed here.
    let state = lstm.step(&inp, states.last().unwrap())?;
    states.push(state);
}
// Apply the backward lstm and collect the results in states_rev.
// Fixed: seed with the *backward* lstm's own zero state — using the forward
// lstm's zero state only worked because both have the same hidden size.
let mut states_rev = vec![lstm_rev.zero_state(1)?];
for &inp in inp_sequence.iter().rev() {
    let inp = Tensor::new(&[[inp, inp * 0.5]], cpu)?;
    let state = lstm_rev.step(&inp, states_rev.last().unwrap())?;
    states_rev.push(state);
}
// Merge the results together: reversing states_rev aligns the backward
// states with the forward states so each pair refers to the same time step.
let states = states
    .into_iter()
    .zip(states_rev.into_iter().rev())
    .collect::<Vec<_>>();
Thanks, I can follow your example for now! I mainly suggested it because I don't know enough about the under-the-hood mechanics to re-implement it on my own.
I'm in the process of porting some Tensorflow models over to Candle and I have a bidirectional LSTM layer in my model. Is this something that Candle can/will support in the future?