allenai / allennlp

An open-source NLP research library, built on PyTorch.
http://www.allennlp.org
Apache License 2.0

How to obtain the final hidden states for all layers of a seq2vec encoder? #2411

Closed bryant1410 closed 4 years ago

bryant1410 commented 5 years ago

This is something I think cannot currently be done with AllenNLP, so this may turn into a feature request. Or maybe I just don't know how to do it. It's about getting the final hidden states for all layers of a seq2vec encoder.

I'm trying to implement a model like the one from TGIF-QA. The problem is that if I encode the video with a 2-layer LSTM through a Seq2VecEncoder (such as PytorchSeq2VecWrapper), I don't know how I can obtain the final hidden states of all layers instead of just the one from the last layer. I need them so I can pass them as the initial hidden states of another seq2vec encoder. The architecture is the following:

[Figure: TGIF-QA architecture]

This can't be done with a seq2seq encoder such as PytorchSeq2SeqWrapper either, because it only returns the output sequence of the last layer.

From my understanding, this cannot be done with AllenNLP out of the box. I'd be happy to send a patch, but what would be the best way to do it? I believe this should be provided as a seq2vec model (because conceptually that's what it is), but I'm reluctant to modify PytorchSeq2VecWrapper because it would be a breaking change. Maybe add a flag to it to enable this behavior? Or another class?

bryant1410 commented 5 years ago

Also (maybe this should be another issue?), when using an LSTM in particular (instead of other RNN types), I'd also want to get the final cell state of each layer (not just the hidden state) and feed it in as the initial cell state of the following encoder. What would be the best way to introduce this change?
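For concreteness, here is a minimal plain-PyTorch sketch (the shapes and names are just illustrative) of the kind of chaining I have in mind:

```python
import torch

# Two 2-layer LSTMs; the second one is seeded with the first one's final states.
video_encoder = torch.nn.LSTM(input_size=300, hidden_size=256, num_layers=2, batch_first=True)
question_encoder = torch.nn.LSTM(input_size=300, hidden_size=256, num_layers=2, batch_first=True)

video_frames = torch.randn(8, 40, 300)     # (batch_size, num_frames, feature_dim)
question_tokens = torch.randn(8, 12, 300)  # (batch_size, num_tokens, embedding_dim)

# h_n and c_n have shape (num_layers, batch_size, hidden_size), i.e. one final
# hidden/cell state per layer; PytorchSeq2VecWrapper essentially only gives you
# the top layer's final hidden state (h_n[-1] for a unidirectional LSTM).
_, (h_n, c_n) = video_encoder(video_frames)

# Feed them in as the initial states of the next encoder.
question_outputs, _ = question_encoder(question_tokens, (h_n, c_n))
```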

matt-gardner commented 5 years ago

We can't have this functionality in the Seq2VecEncoder and Seq2SeqEncoder APIs, because not all encoders that match the API have layers like you're describing. That's the whole point of the abstraction - anything that conforms to the abstraction is interchangeable as far as the model is concerned.

It sounds like you should just be using a pytorch LSTM or whatever encoder you want directly, instead of trying to use our abstractions. It's just fine to not use our abstractions when they don't fit what you want to do - just use whatever pieces work for you.

bryant1410 commented 5 years ago

> We can't have this functionality in the Seq2VecEncoder and Seq2SeqEncoder APIs, because not all encoders that match the API have layers like you're describing. That's the whole point of the abstraction - anything that conforms to the abstraction is interchangeable as far as the model is concerned.

Oh, don't all implementations have a concept of "return the final hidden state for each layer"? Could you please point out which ones don't?

I agree that what you said applies to my follow-up about obtaining the final cell state of LSTMs, although that could also be seen as part of the "final hidden state".

> It sounds like you should just be using a pytorch LSTM or whatever encoder you want directly, instead of trying to use our abstractions. It's just fine to not use our abstractions when they don't fit what you want to do - just use whatever pieces work for you.

Yeah, I agree that AllenNLP cannot support every corner case and that it's okay not to use its abstractions. But isn't it fairly common to want to chain RNN encoders this way? That's why I think it should be provided. Then again, maybe it's not that common; I don't know.

Thanks!

matt-gardner commented 5 years ago

Not all of these encoders even have layers, or hidden states - the CnnEncoder and the BagOfEmbeddings encoder, for instance.

It's not that "AllenNLP doesn't support this use case" - it does, because you can just use a regular LSTM. It's that "this use case doesn't fit the Seq2VecEncoder or the Seq2SeqEncoder abstractions". That's fine. The abstractions are useful precisely because they do just one thing; an abstraction that tries to do everything quickly becomes meaningless. It's possible that there's another abstraction hidden here that might be useful to add, but at a glance it looks like only RNNs fit this abstraction, and that abstraction already exists: torch.nn.modules.RNNBase.

Maybe your issue is that you can't build a pytorch LSTM from_params, so taking all of its options in a constructor is annoying? Yes, I can see that that would be a problem. I think the solution there is to provide some kind of wrapper that adds from_params to a pytorch LSTM, though, not to add some new abstraction or modify one of our existing ones.
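Something like the following rough sketch is what I mean by a wrapper (the LstmWrapper name is made up, and this is not an existing class in the library); its only job is to give torch.nn.LSTM an explicit, type-annotated constructor so that FromParams can build it from a config:

```python
from allennlp.common import FromParams
import torch


class LstmWrapper(torch.nn.LSTM, FromParams):
    """Hypothetical wrapper that just exposes an introspectable constructor
    for ``torch.nn.LSTM`` so it can be built with ``from_params``."""

    def __init__(self,
                 input_size: int,
                 hidden_size: int,
                 num_layers: int = 1,
                 bias: bool = True,
                 dropout: float = 0.0,
                 bidirectional: bool = False) -> None:
        # batch_first is hardcoded here because AllenNLP tensors are batch-first.
        super().__init__(input_size=input_size,
                         hidden_size=hidden_size,
                         num_layers=num_layers,
                         bias=bias,
                         batch_first=True,
                         dropout=dropout,
                         bidirectional=bidirectional)
```

A model constructor could then take an LstmWrapper argument and have it constructed from the JSON configuration in the usual way, without touching the Seq2VecEncoder abstraction at all.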

bryant1410 commented 5 years ago

> It's not that "AllenNLP doesn't support this use case" - it does, because you can just use a regular LSTM. It's that "this use case doesn't fit the Seq2VecEncoder or the Seq2SeqEncoder abstractions". That's fine. The abstractions are useful precisely because they do just one thing; an abstraction that tries to do everything quickly becomes meaningless. It's possible that there's another abstraction hidden here that might be useful to add, but at a glance it looks like only RNNs fit this abstraction, and that abstraction already exists: torch.nn.modules.RNNBase.

Then, do you think it's worth it to add a function to RNNBase to obtain the final hidden state for each layer (and maybe also cell state?)? Or is it too application-specific for AllenNLP?

> Maybe your issue is that you can't build a pytorch LSTM from_params, so taking all of its options in a constructor is annoying? Yes, I can see that that would be a problem. I think the solution there is to provide some kind of wrapper that adds from_params to a pytorch LSTM, though, not to add some new abstraction or modify one of our existing ones.

I haven't run into this yet (I haven't even tried). Are you suggesting adding a PyTorch LSTM wrapper to AllenNLP, or just to my project?

matt-gardner commented 5 years ago

I don't know enough of what exactly you're doing to suggest specific solutions, other than to say that it looks like the already-existing LSTM will do what you want, and you shouldn't need anything else.

If you just need a function to get the final hidden state from an encoder, you can use this: https://github.com/allenai/allennlp/blob/c1aace7e3e07ecb6d4e7f806b2157f29a71cfab9/allennlp/nn/util.py#L172-L188

If that doesn't work because you have multiple layers, you can collapse the layers and the batch dimension, or split the layers and call this on each one, or something similar.
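For example, a minimal sketch of calling that helper on the output of a seq2seq encoder (the shapes here are just illustrative):

```python
import torch
from allennlp.nn.util import get_final_encoder_states

# encoder_outputs: (batch_size, num_timesteps, encoding_dim), e.g. what a
# PytorchSeq2SeqWrapper returns; mask: (batch_size, num_timesteps).
encoder_outputs = torch.randn(8, 40, 256)
mask = torch.ones(8, 40, dtype=torch.long)

# Picks out the vector at each sequence's last unmasked timestep,
# giving a (batch_size, encoding_dim) tensor.
final_states = get_final_encoder_states(encoder_outputs, mask)
```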

If none of this solves your issue, you're going to need to be more specific about what exactly you're looking for.

bryant1410 commented 5 years ago

Ok, thank you @matt-gardner. I'll work on this and get back to this issue, so I'll leave it open in the meantime if you don't mind.

bryant1410 commented 5 years ago

@matt-gardner could you give me feedback on the following patch regarding this? https://github.com/allenai/allennlp/commit/1d9464b11dc7c597a7a13b4cb36cfa54e55eb2d7

This would allow chaining RNNs using the last hidden state of every layer, including LSTMs (which return both the hidden state and the cell state/memory).

DeNeutoy commented 4 years ago

Closing. The PR ran into some issues and it's unclear whether this is a common use case at all; I think this would be better handled outside of the core library.