hadaev8 closed this issue 3 years ago.
Hi,
This is indeed a complicated issue to solve. Is training using the recurrent model very slow?
The batch and recurrent models are completely compatible in terms of weights, so it might be possible to create a utility function that replaces the parameters of one with the other's. This way you can use either the batch model or the recurrent model, and the gradients will accumulate into the same parameters. Something like the following:
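```python
# proposed helper (hypothetical at this point in the thread)
from fast_transformers.recurrent import make_recurrent_mirror

batch_transformer = builder.get()
recurrent_transformer = make_recurrent_mirror(batch_transformer)

for batch in dataloader:
    optimizer.zero_grad()
    # each call is assumed to compute a loss and call .backward();
    # both models write their gradients into the same shared parameters
    teacher_forcing(batch_transformer, batch)
    professor_forcing(recurrent_transformer, batch)
    optimizer.step()
```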
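A minimal, self-contained sketch of the shared-parameter idea, assuming two plain `nn.Linear` layers as hypothetical stand-ins for the batch and recurrent transformers. The second module reuses the first one's `Parameter` objects, so the gradients of both losses are summed into the same `.grad` and a single optimizer step updates both views.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for the batch and recurrent models: the second
# module reuses (not copies) the first one's Parameter objects.
batch_model = nn.Linear(4, 4)
recurrent_model = nn.Linear(4, 4)
recurrent_model.weight = batch_model.weight
recurrent_model.bias = batch_model.bias

# One optimizer over the shared parameters is enough.
optimizer = torch.optim.SGD(batch_model.parameters(), lr=0.1)
x = torch.randn(3, 4)

# Gradient of the second loss alone, for comparison (does not touch .grad).
(grad_b_alone,) = torch.autograd.grad(
    recurrent_model(x).abs().mean(), recurrent_model.weight)

optimizer.zero_grad()
batch_model(x).pow(2).mean().backward()       # e.g. a teacher-forcing loss
grad_a = batch_model.weight.grad.clone()
recurrent_model(x).abs().mean().backward()    # e.g. a free-running loss
total_grad = batch_model.weight.grad.clone()  # both losses summed in one .grad
optimizer.step()                              # one update of the shared weights

print(torch.allclose(total_grad, grad_a + grad_b_alone))  # True
```

With the real models the same effect comes from rebinding every parameter of one module to the other's, which is what the nested-`setattr` helper later in this thread does.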
Let me know what you think.
Angelos
> Is training using the recurrent model very slow?

Yes, it should be slow; still, I think I can achieve better model quality with it.
I'm also using my own layer type (so I can use ReZero, a different activation function, memory attention, or something else).
So I will also need my own modified versions of the RNN-like layers, which is a bit inconvenient.
Well, I tried it like this:
```python
self.transformer_decoder = TransformerDecoder([
    TransformerDecoderLayer(
        self_attention=AttentionLayer(FullAttention(), hidden, nhead),
        cross_attention=AttentionLayer(FullAttention(), hidden, nhead),
        d_model=hidden)
    for _ in range(1)])
self.transformer_decoder_rnn = RecurrentTransformerDecoder([
    RecurrentTransformerDecoderLayer(
        self_attention=RecurrentAttentionLayer(RecurrentFullAttention(), hidden, nhead),
        cross_attention=RecurrentAttentionLayer(RecurrentCrossFullAttention(), hidden, nhead),
        d_model=hidden)
    for _ in range(1)])
for name, param in self.transformer_decoder_rnn.named_parameters():
    # note: this only rebinds the local name `param`, and getattr() does not
    # resolve dotted parameter names (e.g. ones starting with "layers.0.")
    param = getattr(self.transformer_decoder, name)
```
But it's not really good at handling `ModuleDict`.
Maybe I'm missing a simpler solution?
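A small illustration of why a loop like the one above does not share anything, using hypothetical `nn.Sequential` modules `src` and `dst`: `named_parameters()` yields dotted names such as `"0.weight"`, which plain `getattr()` cannot resolve, and even a successful lookup assigned to the loop variable `param` would only rebind a local name.

```python
import torch.nn as nn

# Hypothetical modules with identical structure.
src = nn.Sequential(nn.Linear(2, 2))
dst = nn.Sequential(nn.Linear(2, 2))

failures = []
for name, param in dst.named_parameters():  # names are dotted: "0.weight", "0.bias"
    try:
        param = getattr(src, name)          # getattr() does not follow the dots
    except AttributeError:
        failures.append(name)
    # even if the lookup worked, `param = ...` only rebinds a local variable

print(failures)                        # ['0.weight', '0.bias']
print(dst[0].weight is src[0].weight)  # False: dst still owns its own tensors
```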
You would need a custom `setattr` function that can handle nested parameters. The following simple recursive solution should work fine, and I have pushed it to master in `fast_transformers.utils`.
```python
def make_mirror(src_module, dst_module):
    """Sets the parameters of src_module to dst_module so that they share the
    same parameters.

    Most notable use case is to make a recurrent transformer mirror of a batch
    transformer for fast inference.

    Arguments
    ---------
        src_module: Module to take the parameters from
        dst_module: Module to set the parameters to

    Returns
    -------
        None, it changes dst_module in place
    """
    def setattr_recursive(mod, key, value):
        key, *next_key = key.split(".", maxsplit=1)
        if not next_key:
            setattr(mod, key, value)
        else:
            setattr_recursive(getattr(mod, key), next_key[0], value)

    for name, param in src_module.named_parameters():
        setattr_recursive(dst_module, name, param)
```
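A usage sketch for `make_mirror`, assuming a hypothetical nested structure (an `nn.ModuleList` of `nn.ModuleDict`s) in place of the real decoders. The function body is copied from above so the block runs on its own; after mirroring, every parameter object is shared, and a backward pass through the mirror fills the original module's `.grad`.

```python
import torch
import torch.nn as nn


def make_mirror(src_module, dst_module):
    # Same logic as above: follow each dotted parameter name (e.g.
    # "1.ffn.weight") down to its submodule and set the shared Parameter there.
    def setattr_recursive(mod, key, value):
        key, *next_key = key.split(".", maxsplit=1)
        if not next_key:
            setattr(mod, key, value)
        else:
            setattr_recursive(getattr(mod, key), next_key[0], value)

    for name, param in src_module.named_parameters():
        setattr_recursive(dst_module, name, param)


def build():
    # Hypothetical nested structure mixing ModuleList and ModuleDict.
    return nn.ModuleList([
        nn.ModuleDict({"attn": nn.Linear(8, 8), "ffn": nn.Linear(8, 8)})
        for _ in range(2)
    ])


batch_like, recurrent_like = build(), build()
make_mirror(batch_like, recurrent_like)

shared = all(p is q for p, q in zip(batch_like.parameters(),
                                    recurrent_like.parameters()))
print(shared)  # True

# Gradients computed through the mirror land on the original's parameters.
recurrent_like[1]["ffn"](torch.randn(1, 8)).sum().backward()
print(batch_like[1]["ffn"].weight.grad is not None)  # True
```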
Let me know if you encounter any problems.
@angeloskath Seems like it works perfectly, thanks.
@angeloskath Could you advise: does this kind of sampling increase memory usage? Training now seems to use around 8x more memory. I wonder whether it is my mistake or whether that's expected.
It increases memory significantly because it has to save all the intermediate states in order to compute gradients. In batch mode, our kernel computes the states on the fly, which requires approximately `query_dimensions` times less memory.
Obviously this applies only to the attention part, so it does not mean 32 or 64 times higher memory overall, but ~8x seems about right.
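A rough back-of-the-envelope sketch of where that factor comes from, under the simplifying assumption that recurrent training keeps one `d_query x d_value` state matrix per head and per step for backpropagation, while the batch kernel keeps roughly one `d_value` vector per head and per step. The ratio is then `d_query`, i.e. `query_dimensions`; the sequence length, head count and sizes below are illustrative only.

```python
def attention_state_bytes(seq_len, n_heads, d_query, d_value,
                          batch=1, bytes_per_float=4):
    # Recurrent training (assumption): autograd keeps every per-step state,
    # one d_query x d_value matrix per head.
    recurrent = batch * seq_len * n_heads * d_query * d_value * bytes_per_float
    # Batch kernel (assumption): states are recomputed on the fly, so only
    # about one d_value-sized vector per head and step is kept.
    batched = batch * seq_len * n_heads * d_value * bytes_per_float
    return recurrent, batched


rec, bat = attention_state_bytes(seq_len=1000, n_heads=8, d_query=64, d_value=64)
print(rec / 2**20, bat / 2**20, rec // bat)  # 125.0 1.953125 64
```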
@angeloskath Is there anything I can do about it? Would checkpointing help?
@angeloskath How much faster should recurrent sampling be for 1k steps? I got around a 20% gain.
Checkpointing would not work out of the box right now (see #21) but when it gets fixed you could perhaps try it.
Recurrent sampling is significantly faster for inference. To be honest, I have not used it to train with teacher forcing. A 20% gain with respect to what? Running the batch model 1k times to produce 1k outputs? Otherwise, how would you do teacher forcing without using the recurrent model?
@angeloskath I just tested recurrent sampling for inference, with one sample per batch.
This version runs in 4 seconds:
```python
with torch.no_grad():
    max_len = 1000
    mels = []
    frame = torch.zeros((1, hparams.n_mel_channels), device=memory.device)
    state = None
    for i in range(max_len):
        frame = model.prenet(frame)
        frame = model.pos_decoder(frame, i)
        frame, state = model.transformer_decoder_rnn(
            frame, memory, memory_length_mask=encoder_len_mask, state=state)
        frame = model.proj(frame)
        mels.append(frame)
    mels = torch.stack(mels, dim=1)
```
while this one takes 5 seconds:
```python
with torch.no_grad():
    max_len = 1000
    trg_tensor = torch.zeros(1, 1, hparams.n_mel_channels).to(device)
    for i in range(max_len):
        decoder_mask = TriangularCausalMask(trg_tensor.size(1), device=device)
        decoder_len_mask = LengthMask(trg_tensor.new_full(
            (trg_tensor.shape[0],), trg_tensor.shape[1], dtype=torch.int64))
        output = model.proj(model.transformer_decoder(
            model.pos_decoder(model.prenet(trg_tensor)), memory,
            x_mask=decoder_mask, x_length_mask=decoder_len_mask,
            memory_length_mask=encoder_len_mask))
        out_token = output[:, -1:]
        trg_tensor = torch.cat([trg_tensor, out_token], dim=1)
```
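A quick count that may help put the comparison in context, assuming `max_len = 1000` as in both snippets: the second loop re-runs the prenet, positional encoding and decoder over the entire growing prefix on every step, while the recurrent loop processes one new frame per step. At batch size 1 the GPU is likely far from saturated, so per-call overhead probably dominates both loops, which would explain why the measured gap is only about 20% rather than anything close to this ratio.

```python
def frames_processed(max_len):
    # Recurrent loop (first snippet): each step feeds exactly one new frame.
    recurrent = max_len
    # Batch loop (second snippet): step i re-encodes a prefix of length i + 1.
    batch = sum(i + 1 for i in range(max_len))
    return recurrent, batch


print(frames_processed(1000))  # (1000, 500500)
```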
For training, I'm trying to do it like this: I have the teacher-forcing outputs and the self-run outputs, zero-pad them according to the sample lengths, and then apply a discriminator model to them. In theory this should improve the stability (and maybe other qualities) of the model at inference, but with a batch size of 4 I can't get anywhere. 😢 (`max_len` is the maximum sample length in the batch, i.e. `x.size(1)`.)
```python
def decoder_self_run(self, memory, memory_length_mask, max_len):
    mels = []
    frame = torch.zeros(
        (memory.size(0), self.n_mel_channels), device=memory.device)
    state = None
    for i in range(max_len):
        frame = self.prenet(frame)
        frame = self.pos_decoder(frame, i)
        frame, state = self.transformer_decoder_rnn(
            frame, memory, memory_length_mask=memory_length_mask, state=state)
        frame = self.proj(frame)
        mels.append(frame)
    return torch.stack(mels, dim=1)
```
Also, my epoch time goes from 3 minutes to 6 hours.
Hi again,
My suggestion to you would be the following. Instead of performing the forward/backward pass with the recurrent model, I suggest generating an output from the recurrent model and then using it to train the batch model. This way you do an extra forward pass, but memory use is significantly reduced and the forward pass can be properly parallelized.
Your epoch time would simply grow by the inference time needed to generate the batch to train on. Moreover, since it can be generated directly on the GPU, it can be pipelined, so it should be well optimized.
Cheers, Angelos
@angeloskath Sorry, I don't understand what you mean.
I want to train the model in the same mode as it works at inference: it should produce something similar to the target tensor from the encoder memory alone.
Like in these papers: https://arxiv.org/pdf/1610.09038.pdf and https://arxiv.org/pdf/1904.04775.pdf
I run the parallel teacher-forcing pass on the target tensor plus a recurrent self-run pass, and compare the self-run tensor with the teacher-forcing outputs, like this:
```python
teacher_forcing_out, recurrent_out = model(inputs)
loss_mse = mse(teacher_forcing_out, y)
# instead of real vs. fake as in a standard GAN, I compare parallel vs. recurrent outputs
gan_loss = gan(teacher_forcing_out, recurrent_out)
# ... rest of the training step
```
Do you mean doing it like this?
```python
teacher_forcing_out = model(inputs)
loss_mse = mse(teacher_forcing_out, y)
loss_mse.backward()
opt.step()
recurrent_out = model(inputs)
gan_loss = gan(teacher_forcing_out, recurrent_out)
# ... rest of the training step
```
I still have to run backward on the self-run outputs, and I think I need gradients on the `teacher_forcing_out` tensor to train the model. Okay, maybe I don't really need gradients on `teacher_forcing_out`. But I don't understand: how can I train the self-run mode without a backward pass on the self-run outputs?
Ah, while writing this I think I understood what you meant. I should generate the recurrent outputs inside a no-grad scope and feed them in place of the targets in teacher-forcing mode, then use the GAN to compare the teacher-forcing outputs on the targets with the teacher-forcing outputs on the model's own recurrent outputs, right?
😁 Yep, exactly. This way you do one extra recurrent forward pass, but it should be your best option.
@angeloskath Thanks, I will try it; let's see how it develops.
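A minimal sketch of the agreed recipe, with a hypothetical one-layer `ToyDecoder` (frame t depends only on frame t-1) standing in for the transformer, and a plain MSE term standing in for the GAN discriminator: generate under `torch.no_grad()`, then run the parallel teacher-forcing pass twice, once on the shifted targets and once on the shifted self-run frames. In this toy, the second parallel pass reproduces the free-running outputs (same weights, same inputs), but now with a cheap, parallel graph to backpropagate through.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)


class ToyDecoder(nn.Module):
    """Hypothetical stand-in decoder: predicts frame t from frame t-1 only."""

    def __init__(self, dim):
        super().__init__()
        self.step = nn.Linear(dim, dim)

    def teacher_forcing(self, inputs):
        # Parallel pass over (batch, len, dim) inputs that are already shifted.
        return self.step(inputs)

    def generate(self, batch, length, dim):
        # Free-running recurrent pass: each output becomes the next input.
        frame = torch.zeros(batch, dim)
        outs = []
        for _ in range(length):
            frame = self.step(frame)
            outs.append(frame)
        return torch.stack(outs, dim=1)


def shift_right(x):
    # Prepend a zero "go" frame on the time axis and drop the last frame.
    return F.pad(x, (0, 0, 1, 0))[:, :-1]


model = ToyDecoder(dim=4)
opt = torch.optim.Adam(model.parameters(), lr=1e-3)
y = torch.randn(2, 5, 4)  # target frames

# 1) The model's own outputs, generated without building a graph.
with torch.no_grad():
    self_run = model.generate(batch=2, length=5, dim=4)

# 2) Two parallel teacher-forcing passes: on targets and on self-run frames.
tf_on_target = model.teacher_forcing(shift_right(y))
tf_on_self_run = model.teacher_forcing(shift_right(self_run))

# 3) MSE to the target plus a placeholder comparison term (assumption: this
#    MSE stands in for the GAN discriminator used in the thread).
loss = F.mse_loss(tf_on_target, y) + F.mse_loss(tf_on_self_run, tf_on_target.detach())
opt.zero_grad()
loss.backward()
opt.step()
```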
@hadaev8 Shall I assume that we found the best way to train on the networks' own outputs and consider the issue closed? Did you experience any further problems with it?
@hadaev8 No, not really. One epoch takes 63 minutes (vs. around 4 for one parallel pass). The GAN discriminator should be very fast, so I guess the time goes into the recurrent computation, but I did not check exactly.
Unfortunately, it seems this approach doesn't work in general. I started from trained model weights (so the self-run pass should give something meaningful) with a new optimizer and a GAN discriminator. This is the difference between teacher forcing and self-run. The encoder-decoder attention failed for some reason. I'm going to try some other training setups before giving up on this.
Also, something to consider: here they discuss a new checkpointing approach for training recurrent transformers. I'm not skilled enough to implement it, but maybe you will find it interesting. 😁 https://openreview.net/forum?id=_adSMszz_g9
In general, the LSTM TTS model (Tacotron 2) outperforms the transformer one. The transformer model is not robust at inference and fails quickly; it seems the decoder simply overfits to the teacher-forcing inputs. Of course, a lot of hyperparameters and tricks remain untested, so I guess I will make it work eventually.
I am experimenting a bit with this scenario, and indeed the recurrent forward pass significantly underutilizes the GPU and is slow. For instance, the recurrent forward pass takes almost exactly the same time on the GPU with a batch size of 32 as with a batch size of 8.
So one solution would be to generate sequences every n batches. The samples will be a bit stale, but your epoch time should go down to approximately 10 minutes if you generate every 30-50 batches.
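A sketch of the "generate every n batches" schedule, assuming a hypothetical `StaleSampleCache` helper and a fake generator that only records when it runs; in a real pipeline `generate_fn` would be the no-grad recurrent generation.

```python
class StaleSampleCache:
    """Hypothetical helper: regenerate free-running samples every `every`
    batches and reuse the cached (slightly stale) ones in between."""

    def __init__(self, generate_fn, every):
        self.generate_fn = generate_fn  # e.g. a no-grad recurrent generation call
        self.every = every
        self.step = 0
        self.samples = None

    def get(self, batch):
        if self.step % self.every == 0:
            self.samples = self.generate_fn(batch)
        self.step += 1
        return self.samples


calls = []


def fake_generate(batch_idx):
    # Stand-in generator that records the batch index it ran on.
    calls.append(batch_idx)
    return f"samples@{batch_idx}"


cache = StaleSampleCache(fake_generate, every=30)
for batch_idx in range(100):
    samples = cache.get(batch_idx)

print(calls, samples)  # [0, 30, 60, 90] samples@90
```

The expensive recurrent pass runs 4 times instead of 100, at the price of training on samples that may be up to `every - 1` optimizer steps old.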
Let me know what you think.
@angeloskath Honestly, I'm not sure how generating only once per 50 batches would affect training. Intuitively it seems like just noise: once every 50 batches the net gets "strange" inputs instead of the normal ones, produces bad output, gets a penalty from the discriminator, and then carries on as normal for the next 50 batches. It sounds like a lookahead approach.
I guess trying never hurts.
But I think checkpointing makes more sense, because the model would learn how this "strange" output is produced and how to fix it. I tried vanilla PyTorch checkpointing on a linear-attention encoder for images, and it seems to work as intended (2x batch size and 2x step time). I only need to make a wrapper for the masks.
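One possible shape for such a wrapper, sketched with `torch.utils.checkpoint.checkpoint`: the tensor input goes through `checkpoint()`, while keyword arguments such as mask objects are captured in a closure (the older, reentrant variant of `checkpoint()` does not take keyword arguments). `ToyLengthMask` and `ToyLayer` are hypothetical stand-ins for the library's mask and layer classes; the final check confirms outputs and gradients match the non-checkpointed layer.

```python
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint


class ToyLengthMask:
    """Hypothetical non-tensor mask object, standing in for a real length mask."""

    def __init__(self, lengths, max_len):
        # True for valid positions, False for padding: shape (batch, max_len)
        self.bool_matrix = torch.arange(max_len)[None, :] < lengths[:, None]


class ToyLayer(nn.Module):
    """Hypothetical layer that receives its mask as a keyword argument."""

    def __init__(self, dim):
        super().__init__()
        self.ff = nn.Linear(dim, dim)

    def forward(self, x, length_mask=None):
        out = self.ff(x)
        if length_mask is not None:
            out = out * length_mask.bool_matrix[..., None]  # zero padded steps
        return out


class Checkpointed(nn.Module):
    """Runs `layer` under activation checkpointing; keyword arguments such as
    masks are captured by the closure instead of being passed to checkpoint()."""

    def __init__(self, layer):
        super().__init__()
        self.layer = layer

    def forward(self, x, **kwargs):
        return checkpoint(lambda t: self.layer(t, **kwargs), x, use_reentrant=False)


torch.manual_seed(0)
layer = ToyLayer(4)
wrapped = Checkpointed(layer)
x = torch.randn(2, 3, 4, requires_grad=True)
mask = ToyLengthMask(torch.tensor([3, 1]), max_len=3)

out_ref = layer(x, length_mask=mask)
out_ref.sum().backward()
ref_grad = layer.ff.weight.grad.clone()
layer.zero_grad()

out_ckpt = wrapped(x, length_mask=mask)
out_ckpt.sum().backward()
print(torch.allclose(out_ref, out_ckpt),
      torch.allclose(ref_grad, layer.ff.weight.grad))  # True True
```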
I will also run an experiment with this idea, e.g. reshaping my target from (batch, len, feature) to (batch, len/factor, feature*factor). Leaving aside the linear projection from the sample dimension to the transformer dimension, I should get 2x more memory (2x batch size, faster epochs). A shorter sample length should also help with future recurrent experiments.
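A sketch of that reshape, assuming 80-dimensional mel frames and zero-padding of the time axis when its length is not divisible by `factor` (both assumptions). Because the padded tensor is contiguous, `reshape` concatenates each group of `factor` consecutive frames into one wider frame, and the inverse restores the original exactly.

```python
import torch
import torch.nn.functional as F


def fold_frames(x, factor):
    """(batch, len, feat) -> (batch, ceil(len / factor), feat * factor)."""
    batch, length, feat = x.shape
    pad = (-length) % factor
    x = F.pad(x, (0, 0, 0, pad))  # zero-pad the time axis at the end
    return x.reshape(batch, (length + pad) // factor, feat * factor), pad


def unfold_frames(y, factor, pad):
    """Inverse of fold_frames: split wide frames and drop the padding."""
    batch, steps, feat = y.shape
    x = y.reshape(batch, steps * factor, feat // factor)
    return x[:, : x.shape[1] - pad]


mel = torch.randn(2, 7, 80)
folded, pad = fold_frames(mel, factor=2)
print(folded.shape)                                    # torch.Size([2, 4, 160])
print(torch.equal(unfold_frames(folded, 2, pad), mel))  # True
```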
With the above tricks, I was able to get an 80-minute epoch with a batch size of 16. Far from ideal, but better than the previous attempt.
@angeloskath Could you advise why I don't get much speedup from recurrent sampling?
I assume that the sequence length is not large enough, or the batch size is too small, or both. You can certainly micro-benchmark this outside of your main pipeline to check what is best in your particular situation. The following code benchmarks the per-sample forward-pass time.
```python
import torch
from fast_transformers.builders import TransformerEncoderBuilder, \
    RecurrentEncoderBuilder
from fast_transformers.masking import TriangularCausalMask


# d_model = n_heads * query_dimensions = 8 * 64; times are ms per sample
@torch.no_grad()
def measure_batch(transformer, batch_size, sequence_length):
    x = torch.randn(batch_size, sequence_length, 8*64).cuda()
    m = TriangularCausalMask(sequence_length, device="cuda")
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    transformer(x, attn_mask=m)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / batch_size


@torch.no_grad()
def measure_recurrent(transformer, batch_size, sequence_length):
    x = torch.randn(batch_size, 8*64).cuda()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    state = None
    for i in range(sequence_length):  # one recurrent step per position
        x, state = transformer(x, state)
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / batch_size


t1 = TransformerEncoderBuilder.from_kwargs(
    n_layers=4,
    n_heads=8,
    query_dimensions=64,
    attention_type="causal-linear"
).get()
t2 = TransformerEncoderBuilder.from_kwargs(
    n_layers=4,
    n_heads=8,
    query_dimensions=64,
    attention_type="full"
).get()
t3 = RecurrentEncoderBuilder.from_kwargs(
    n_layers=4,
    n_heads=8,
    query_dimensions=64,
    attention_type="causal-linear"
).get()
t4 = RecurrentEncoderBuilder.from_kwargs(
    n_layers=4,
    n_heads=8,
    query_dimensions=64,
    attention_type="full"
).get()
t1.cuda()
t2.cuda()
t3.cuda()
t4.cuda()

print("Batch Causal-Linear", measure_batch(t1, 8, 2000))
print("Batch Full", measure_batch(t2, 8, 2000))
print("Recurrent Causal-Linear", measure_recurrent(t3, 128, 2000))
print("Recurrent Full", measure_recurrent(t4, 128, 2000))
```
On my RTX 2060 Super I get the following results (milliseconds per sample):

```
Batch Causal-Linear 19.364452362060547
Batch Full 30.340200424194336
Recurrent Causal-Linear 35.718772888183594
Recurrent Full 374.2028503417969
```
You could probably get even bigger benefits by increasing the batch size further, but the bottom line is that for a batch size large enough to make the GPU actually do some work, causal linear is >10x faster than full softmax any day of the week. As for the batch forward, its performance is indeed even better, as expected. There is no way you will ever reach the performance of the batch forward version; however, if you wanted to do recurrent generation using the batch version, you would need to call it `sequence_length` times, which in this case is 2000! So the recurrent linear is around 1000x faster than the batch linear and >10x faster than the recurrent full.
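The arithmetic behind those ratios, using the per-sample timings above. Like the estimate in the text, it charges each of the 2000 batch calls a full-length forward, so it is a coarse estimate rather than a measurement of batch-mode generation.

```python
# Per-sample timings (ms) from the RTX 2060 Super run above.
batch_linear = 19.364452362060547      # one parallel forward over 2000 steps
recurrent_linear = 35.718772888183594  # 2000 recurrent causal-linear steps
recurrent_full = 374.2028503417969     # 2000 recurrent full-softmax steps

seq_len = 2000
# Generating with the batch model needs one forward call per new step;
# each call is costed here at the full length (coarse estimate).
batch_generation = seq_len * batch_linear

print(round(batch_generation / recurrent_linear))   # 1084
print(round(recurrent_full / recurrent_linear, 1))  # 10.5
```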
Btw, if the configuration is similar to the above, you can generate 128 samples to use for the next 16 batches at a cost of less than a forward/backward batch pass. So the generated samples won't be that stale and will probably be useful for training.
Cheers, Angelos
@angeloskath It seems linear attention uses much less memory even in recurrent mode, and checkpointing doesn't seem to improve memory usage. To fully exploit the advantage of linear recurrence, it seems I would need a batch size of 512+, which looks impossible to achieve.
I will run training with a linear decoder and report results; I guess it should handle a batch size of 32. This is my notebook with a setup close to the real model. Somewhat unexpectedly, in training I can fit a batch size of at most 52, while at inference 4096 is fine. https://colab.research.google.com/drive/1K7XaUTADfaq4XGeyx7tNffPExuJdHfXr?usp=sharing
@angeloskath Could you advise why I get an OOM error at the same batch size in the test notebook, both with and without checkpointing, while with my real model checkpointing works as intended (I can train with 2x the batch size)? Linear attention in the decoder unlocks a batch size of 16 and 4x faster training (compared to a batch size of 6). Still, the model produces garbage in self-run mode.
@angeloskath Why does full attention need 4x the memory of linear attention?
Hi @hadaev8,
It is expected, as the memory requirements of full attention scale as O(N^2) with respect to the sequence length, while linear attention scales as O(N). So the 4x figure holds for a specific input; for a larger input the difference should be larger.
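A rough sketch of how that gap widens with sequence length, under the simplifying assumption that full attention stores an N x N score matrix per head, while linear attention stores a fixed `d_query x d_value` summary per head plus one `d_value` vector per position. Doubling N quadruples the first and roughly doubles the second; head count and dimensions are illustrative.

```python
def attention_floats(seq_len, n_heads=8, d_query=64, d_value=64):
    # Full softmax attention (assumption): one N x N score matrix per head.
    full = n_heads * seq_len * seq_len
    # Linear attention (assumption): a d_query x d_value summary per head
    # plus one d_value vector per position.
    linear = n_heads * (d_query * d_value + seq_len * d_value)
    return full, linear


for n in (500, 1000, 2000):
    full, linear = attention_floats(n)
    print(n, full, linear, round(full / linear, 1))
```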
As an aside, I will close this issue as it is not about recurrent sampling any more. Feel free to open another issue if you want.
Cheers, Angelos
@angeloskath Should there be any difference in memory between the recurrent and the parallel pass? In my tests, the parallel pass with linear and with full attention maxed out at the same batch size (for sequence lengths around 500).
In general, I want to have a teacher-forcing pass and a self-generated (free-running generative) pass, a.k.a. professor forcing.
For now, it looks like I need to merge `FullAttention`, `RecurrentFullAttention` and `RecurrentCrossFullAttention` into one class and use it with a flag like `recurrent=True`, and do the same for the layer and encoder/decoder classes. That seems inconvenient. Am I right, or is there a better way?