facebookresearch / ParlAI

A framework for training and evaluating AI models on a variety of openly available dialogue datasets.
https://parl.ai
MIT License
10.48k stars · 2.1k forks

Recommended implementation of multi-turn training #2404

Closed jsedoc closed 1 year ago

jsedoc commented 4 years ago

In order to incorporate a model like HRED into ParlAI, we need to get triplets (). I've made a few hacks with a custom loader to do this, but I was wondering if there is a better/recommended way.

Maybe this has already been done and I missed it somewhere in the documentation. If so, I apologize.

stephenroller commented 4 years ago

TorchAgent's History keeps track of all the turns:

https://github.com/facebookresearch/ParlAI/blob/72c304fa7cac16ed19d8bc75a017f17c8073dd2f/parlai/core/torch_agent.py#L280-L301

So you can use self.history.history_vecs to get the tokenized prior utterances, and self.history.history_strings to get the prior utterances as strings. You can override _set_text_vec to do this: https://github.com/facebookresearch/ParlAI/blob/72c304fa7cac16ed19d8bc75a017f17c8073dd2f/parlai/core/torch_agent.py#L1229
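A minimal sketch of that pattern, without depending on ParlAI itself: `FakeHistory` stands in for `parlai.core.torch_agent.History` (the attribute names `history_vecs` / `history_strings` are from the comment above), and the `_set_text_vec` override keeps a *list* of per-turn vectors instead of one flat concatenated vector, which is what a hierarchical (HRED-style) context encoder wants. Names like `HREDStyleAgent` and `turn_vecs` are hypothetical; the real `truncate` argument is a token-count limit, while here it is used as a turn-count limit for simplicity.

```python
class FakeHistory:
    """Stand-in for ParlAI's History, which tracks prior turns."""

    def __init__(self):
        self.history_vecs = []     # one token-id list per prior utterance
        self.history_strings = []  # the same utterances as raw strings

    def add(self, text, vec):
        self.history_strings.append(text)
        self.history_vecs.append(vec)


class HREDStyleAgent:
    """Hypothetical agent sketching an overridden _set_text_vec."""

    def __init__(self):
        self.history = FakeHistory()

    def _set_text_vec(self, obs, history, truncate):
        # Keep per-turn vectors so a hierarchical encoder can consume
        # the dialogue turn by turn, rather than one flattened context.
        obs['turn_vecs'] = list(history.history_vecs)
        if truncate is not None:
            # NOTE: real ParlAI truncation is per-token; truncating by
            # whole turns here keeps the sketch simple.
            obs['turn_vecs'] = obs['turn_vecs'][-truncate:]
        return obs
```

With two prior utterances added to the history, the override would populate `obs['turn_vecs']` with two separate token-id lists.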

Caching the LSTM hidden state is a little more difficult, since the dialogue state may not be what the model actually witnessed. If you do the above and you're dissatisfied with inference-time performance, we can talk about how you'd cache this.

jsedoc commented 4 years ago

Yes, indeed, I've finally gotten to the caching problem. I am currently just copying this to a tensor in a circular-buffer manner. Is there a better way?
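For concreteness, the circular-buffer idea could be sketched with the standard library's `collections.deque`, which drops the oldest entry automatically once full. The class name and methods below are hypothetical, not ParlAI API; the stored items would be hidden-state tensors in practice, and `reset()` handles the episode-boundary invalidation mentioned below.

```python
from collections import deque


class HiddenStateCache:
    """Hypothetical ring-buffer cache for per-turn hidden states."""

    def __init__(self, max_turns):
        # deque with maxlen evicts the oldest entry on overflow,
        # giving circular-buffer behavior without manual index math.
        self.buffer = deque(maxlen=max_turns)

    def push(self, hidden):
        """Store the hidden state produced after the latest turn."""
        self.buffer.append(hidden)

    def latest(self):
        """Return the most recent cached state, or None if empty."""
        return self.buffer[-1] if self.buffer else None

    def reset(self):
        """Invalidate the cache, e.g. at an episode boundary."""
        self.buffer.clear()
```

Pushing five states into a three-slot cache keeps only the three most recent, and `reset()` empties it for the next episode.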

stephenroller commented 4 years ago

No, that seems like a clever way to do it. I've never trained an HRED model, so I don't know how important the cache invalidation is at training time.

A lot of people have been requesting HRED in ParlAI, so I would very much welcome a PR providing an implementation which we could make one of our main agents.

github-actions[bot] commented 4 years ago

This issue has not had activity in 30 days. Marking as stale.