pfeatherstone opened this issue 1 year ago
@pfeatherstone ahh, yea, i can look into that
care to share what you are seeing on your dataset with this approach?
I haven't been able to train my models yet even with normal transformers using larger context lengths (my weird TTS + STT system). CTC loss isn't converging at all, so I haven't attempted a proper run with the RMT architecture in the STT model. But I'm setting it up with RMT while debugging the other one. I will let you know if I find success. I'm worried that training 2 transformers in tandem simply doesn't work, either because of stupidly slow convergence, too low a batch size, or other reasons... Don't know. I've been looking at shifted tokens, scale_norm and other tricks to help with convergence, but I'm not having any luck. I'm tempted to try RWKV as they claim really fast convergence. Either way, I'm going to need something like RMT in the end so I can have a well-defined streaming architecture on the STT side.
oh got it, makes sense
Gave this a go. It turns out that `torch.jit.trace()` doesn't accept `None` in `example_inputs`, so we cannot trace with `mems` not `None` and expect it to work when `None`, or vice versa. My workaround is to pass `mems=torch.zeros(B, num_memory_tokens, dim)` in the first pass, which means you're attending to `self.read_memory_emb` ONLY in the first pass. Don't know if that's allowed.
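Roughly, this is the shape of the workaround I mean. The toy module below (`ToyMemoryModel`, its forward signature, and the shapes) is just a made-up stand-in for illustration, not the actual RMT code:

```python
import torch
import torch.nn as nn

class ToyMemoryModel(nn.Module):
    # hypothetical stand-in with an RMT-like signature: (x, mems) -> (out, new_mems)
    def __init__(self, dim=64, num_memory_tokens=4):
        super().__init__()
        self.read_memory_emb = nn.Parameter(torch.zeros(num_memory_tokens, dim))
        self.proj = nn.Linear(dim, dim)

    def forward(self, x, mems):
        # mems must be a Tensor here - torch.jit.trace() cannot take None in example_inputs
        mems = mems + self.read_memory_emb
        h = torch.cat((mems, x), dim=1)
        h = self.proj(h)
        new_mems, out = h[:, :mems.shape[1]], h[:, mems.shape[1]:]
        return out, new_mems

B, N, dim, num_memory_tokens = 2, 16, 64, 4
model = ToyMemoryModel(dim, num_memory_tokens)

# workaround: zero-filled memories on the very first pass instead of None
mems = torch.zeros(B, num_memory_tokens, dim)
x = torch.randn(B, N, dim)

traced = torch.jit.trace(model, (x, mems))

# subsequent passes feed the returned memories back in
out, mems = traced(x, mems)
out, mems = traced(x, mems)
```

So the first segment attends to zeros plus `read_memory_emb`, and every later segment reuses the memories returned by the previous call, which keeps the traced graph identical across passes.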
It would be great if the above worked.