Closed visionscaper closed 4 years ago
I made a copy of my test notebook, made on Google Colab, publicly available so you can run it yourself: chatbot-training-test-pytorch-xla-10012020.ipynb. To run it, make a copy of it first.
The code of interest is in the forward method of the EncoderRNN.
Hey @visionscaper: thanks for reporting this issue! Let me take a look next week when I'm in the office?
Hi @mruberry, yeah sure! Looking forward to hearing your feedback.
For your reference, I also added the CUDA (GPU) version of the Colab notebook: chatbot-training-test-pytorch-gpu-10012020.ipynb
Sorry for not getting back to you sooner, @visionscaper, I got sick over the weekend :(
There's definitely an issue here. I'm seeing some different behavior, however:
1) with pack_padded_sequence, it throws an error: size mismatch, m1: [4000 x 2000], m2: [1000 x 1000]
2) without pack_padded_sequence, I start with a comparable loss (10.109) and see no improvement on the second epoch. Training on the TPU is also MUCH slower than training with CUDA.
Does that fit your experience? Also, can you advise on the second issue?
I'll have to take a closer look at your network to understand the performance issue.
Hi @mruberry,
No worries, I hope you feel better now. Thanks for looking into the issue!
Yes, your observations match my experience. If you leave the TPU training run going a bit longer, you will even notice that the loss is increasing.
To use pack_padded_sequence in the GPU notebook, (un)comment the code in the forward method of the EncoderRNN as follows:
# Pack padded batch of sequences for RNN module
packed = nn.utils.rnn.pack_padded_sequence(embedded, input_lengths, enforce_sorted=self.enforce_sorted)
# packed = nn.utils.rnn.pack_padded_sequence(embedded.cpu(), input_lengths.cpu(), enforce_sorted=self.enforce_sorted)
# # packed = packed.to(input_seq.device)
# packed = packed.cuda()
# Forward pass through GRU
self.gru.flatten_parameters()
outputs, hidden = self.gru(packed, hidden)
# outputs, hidden = self.gru(embedded, hidden)
# Unpack padding
outputs, _ = nn.utils.rnn.pad_packed_sequence(outputs, total_length=total_length)
# Sum bidirectional GRU outputs
outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
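The snippet above is an excerpt of the notebook's forward method. For anyone who wants to try the pack → GRU → unpack → sum-directions pattern in isolation, here is a minimal, self-contained sketch that runs on CPU. MinimalEncoderRNN, its constructor arguments and the toy batch are assumptions for illustration, not the notebook's EncoderRNN; flatten_parameters() is left out since it only matters for cuDNN.

```python
import torch
import torch.nn as nn


class MinimalEncoderRNN(nn.Module):
    # Hypothetical stand-in for the notebook's EncoderRNN; not the original code.
    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        # Bidirectional GRU on time-major input: (seq_len, batch, features).
        self.gru = nn.GRU(hidden_size, hidden_size, bidirectional=True)

    def forward(self, input_seq, input_lengths, total_length, hidden=None):
        embedded = self.embedding(input_seq)
        # Pack so the GRU skips each sequence's padded tail; the lengths
        # must be a CPU tensor (or a list) for pack_padded_sequence.
        packed = nn.utils.rnn.pack_padded_sequence(
            embedded, input_lengths.cpu(), enforce_sorted=False)
        outputs, hidden = self.gru(packed, hidden)
        # Unpack to a zero-padded (total_length, batch, 2 * hidden_size) tensor.
        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, total_length=total_length)
        # Sum the forward and backward directions.
        outputs = outputs[:, :, :self.hidden_size] + outputs[:, :, self.hidden_size:]
        return outputs, hidden


torch.manual_seed(0)
encoder = MinimalEncoderRNN(vocab_size=20, hidden_size=8)
input_seq = torch.randint(1, 20, (5, 3))   # (seq_len=5, batch=3)
input_lengths = torch.tensor([5, 3, 2])
outputs, hidden = encoder(input_seq, input_lengths, total_length=5)
print(outputs.shape, hidden.shape)
```

Positions past each sequence's length come back as zeros, because pad_packed_sequence fills them with its padding_value (0 by default).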
Looking forward to new updates!
Thanks @visionscaper! Trying now.
Update: Cool. It appears to be working as expected.
I don't think you need the self.gru.flatten_parameters(), either?
It appears there is a bug when using pack_padded_sequence with PyTorch/XLA. Passing in a PyTorch/XLA tensor as the input causes a failure in the constructor that it shouldn't. I have confirmed your finding that on CUDA this works.
I'm going to try to replicate and fix this locally. I think this regressed, but I'm not sure what would have caused it to do so. Let's continue to use this issue to track the problem.
There are additional performance issues with the network on TPU. I'm thinking we should produce a series of Colab notebooks, like the current samples https://github.com/pytorch/xla/tree/master/contrib/colab, demonstrating how to tweak networks to maximize performance on TPUs.
Hi @mruberry,
Thanks for looking into this. I think there are three issues:
1) a bug using pack_padded_sequence with PyTorch/XLA, as verified by you
2) a performance issue, in terms of computation duration per batch
3) a performance issue, in terms of the loss not decreasing, or even increasing.
Agreed?
As mentioned earlier, I used contrib/colab/resnet18-training-xrt-1-15.ipynb as a template; further, I optimised the training as follows:
What else could I do?
One thing to note, though, is that I use the masked_select method to calculate the loss; see the following line in the Colab notebooks:
loss = masked_loss(per_sample_loss, output_mask)
This is in the _evaluate_loss method of the ChatbotTrainer class.
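The notebook's masked_loss itself isn't reproduced in this thread. Assuming it averages the values picked out by masked_select (it is described later in the thread as an average of the selected components), a minimal version could look like the sketch below; masked_loss_select is a hypothetical name. The property that matters for XLA is that the length of the masked_select result depends on the contents of the mask, i.e. it is a data-dependent shape.

```python
import torch


def masked_loss_select(per_sample_loss, mask):
    # Hypothetical masked average: keep only the positions where mask is True.
    # masked_select returns a 1-D tensor with one entry per True in mask,
    # so its length depends on the data (a dynamic shape).
    selected = per_sample_loss.masked_select(mask)
    return selected.mean()


per_sample_loss = torch.tensor([[1.0, 2.0, 3.0],
                                [4.0, 5.0, 6.0]])
mask = torch.tensor([[True, True, False],
                     [True, False, False]])
print(per_sample_loss.masked_select(mask))        # the 3 selected values: 1, 2, 4
print(masked_loss_select(per_sample_loss, mask))  # (1 + 2 + 4) / 3
```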
In issue #1509, @dlibenzi mentions here that this method is in experimental mode. Would it help performance issue 3) (and maybe 2)) if we set the proposed environment variable?
export XLA_EXPERIMENTAL="nonzero:masked_select"
Lastly, keep in mind that I'm using two optimisers, one for the encoder and one for the decoder. What effect does that have?
I have no cycles to follow the pad/unpad thing ATM.
WRT masked_select, what are you going to do with the masked loss? A sum? If you are going to do a reduce operation, you should be fine setting the environment variable. Can you post the metrics you get here, so we have an idea of the hits we get in our hooks?
Hi @dlibenzi,
Thanks for replying. No worries, I really would like this to work, but I'm not in a big hurry.
loss, the output of the masked_loss method, is what I backprop from; in an N-worker (N-core) situation this should be reduced (summed) into a total loss over the N sub-batches.
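If the per-core losses are combined into one masked average, one subtlety is that different cores can see different numbers of valid (unpadded) tokens. The CPU-only sketch below simulates N cores with a Python list; the helper names are hypothetical. On TPU, the two sums would presumably come from a SUM all-reduce (in torch_xla, for example xm.all_reduce with xm.REDUCE_SUM; check the API of your version), which is an assumption about how it could be wired, not what the notebook does. Note also that xm.optimizer_step already all-reduces the gradients across cores.

```python
import torch


def masked_sum_and_count(per_sample_loss, mask):
    # Per-core partial results: sum of the valid losses and number of valid positions.
    weights = mask.to(per_sample_loss.dtype)
    return (per_sample_loss * weights).sum(), weights.sum()


def global_masked_mean(core_batches):
    # Reduce sums and counts separately (what a SUM all-reduce of both
    # tensors would give every core), then divide once.
    partials = [masked_sum_and_count(loss, mask) for loss, mask in core_batches]
    total_sum = torch.stack([s for s, _ in partials]).sum()
    total_count = torch.stack([c for _, c in partials]).sum()
    return total_sum / total_count


def mean_of_core_means(core_batches):
    # For comparison: averaging each core's own masked mean weights every
    # core equally, regardless of how many valid tokens it saw.
    means = [s / c for s, c in (masked_sum_and_count(l, m) for l, m in core_batches)]
    return torch.stack(means).mean()


# Two simulated "cores" with different numbers of valid tokens.
core_batches = [
    (torch.tensor([2.0, 2.0, 2.0, 2.0]), torch.tensor([True, True, True, True])),
    (torch.tensor([8.0, 1.0, 1.0, 1.0]), torch.tensor([True, False, False, False])),
]
print(global_masked_mean(core_batches))   # (8 + 8) / (4 + 1) = 3.2
print(mean_of_core_means(core_batches))   # (2 + 8) / 2 = 5.0
```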
Below I show the metrics (loss and batch time) for three situations:
1) set os.environ['XLA_EXPERIMENTAL'] = "nonzero:masked_select"
2) not set os.environ['XLA_EXPERIMENTAL'] = "nonzero:masked_select"
3) results on GPU
Some observations about the results below:
XLA_EXPERIMENTAL is set,
Case 1
When I set the XLA_EXPERIMENTAL environment variable, like so ...
os.environ['XLA_EXPERIMENTAL'] = "nonzero:masked_select"
... in a single TPU core configuration (so no reduction as in a multi-core situation), I get:
INFO: train-reference-chatbot::setup_datasets : Setup data sets (rank = 0)...
INFO: train-reference-chatbot::setup_datasets : Loading training set ...
INFO: train-reference-chatbot::load_sentence_pair_data :
Loading sentence pair file : /tmp/training-1578511965-cmdc-sentence-pairs-with-voc-max-len-40-min-word-occurance-3.pickle
DEBUG: train-reference-chatbot::setup_datasets : Number of sentence pairs in training set: 125689
INFO: model_data_generation.py::create_sentence_pairs_collate_fn : Using fixed sequence lengths of 40 tokens.
INFO: train-reference-chatbot::build_model : Building model ...
INFO: train-reference-chatbot::train_chatbot : Transfering model to xla:1 ...
INFO: train-reference-chatbot::train_chatbot : Building optimizers ...
INFO: train-reference-chatbot::train_chatbot : Prepare training ...
INFO: train-reference-chatbot::train_chatbot : Adding progress logger ...
DEBUG: TrainingManager::_assess_num_batches_per_epoch : Number of batches per epoch : 1256
INFO: train-reference-chatbot::train_chatbot : Start training ...
WARNING: utils.py::get_value_at : Key path duration.mean.batch not found in given data
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer/pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:1334: UserWarning: This overload of add_ is deprecated:
add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
add_(Tensor other, Number alpha)
/pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:1550: UserWarning: This overload of addcmul_ is deprecated:
addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
addcmul_(Tensor tensor1, Tensor tensor2, Number value)
/pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:1480: UserWarning: This overload of addcdiv_ is deprecated:
addcdiv_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
addcdiv_(Tensor tensor1, Tensor tensor2, Number value)
<decoder_optimizer!
Epoch 0/49 - ETA: [UNKNOWN] Batch 0/1255 Average batch training time [UNKNOWN]
Batch:
training: loss 2.953.
Moving average:
training: loss 2.953.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 6 days, 20:32:27 Batch 1/1255 Average batch training time 471989ms
Batch:
training: loss 3.155.
Moving average:
training: loss 3.054.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 4 days, 6:45:19 Batch 2/1255 Average batch training time 294991ms
Batch:
training: loss 2.853.
Moving average:
training: loss 2.987.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 2 days, 22:29:57 Batch 3/1255 Average batch training time 202551ms
Batch:
training: loss 2.838.
Moving average:
training: loss 2.950.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 2 days, 6:35:50 Batch 4/1255 Average batch training time 156988ms
Batch:
training: loss 3.003.
Moving average:
training: loss 2.960.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 1 day, 20:48:41 Batch 5/1255 Average batch training time 128953ms
Batch:
training: loss 3.127.
Moving average:
training: loss 2.988.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 1 day, 14:27:58 Batch 6/1255 Average batch training time 110782ms
Batch:
training: loss 2.922.
Moving average:
training: loss 2.979.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 1 day, 9:52:35 Batch 7/1255 Average batch training time 97642ms
Batch:
training: loss 2.972.
Moving average:
training: loss 2.978.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 1 day, 6:28:10 Batch 8/1255 Average batch training time 87892ms
Batch:
training: loss 2.976.
Moving average:
training: loss 2.978.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 1 day, 3:49:52 Batch 9/1255 Average batch training time 80346ms
Batch:
training: loss 3.057.
Moving average:
training: loss 2.986.
Case 2
When I comment out setting the XLA_EXPERIMENTAL environment variable ...
# os.environ['XLA_EXPERIMENTAL'] = "nonzero:masked_select"
... we get:
INFO: train-reference-chatbot::setup_datasets : Setup data sets (rank = 0)...
INFO: train-reference-chatbot::setup_datasets : Loading training set ...
INFO: train-reference-chatbot::load_sentence_pair_data :
Loading sentence pair file : /tmp/training-1578511965-cmdc-sentence-pairs-with-voc-max-len-40-min-word-occurance-3.pickle
DEBUG: train-reference-chatbot::setup_datasets : Number of sentence pairs in training set: 125689
INFO: model_data_generation.py::create_sentence_pairs_collate_fn : Using fixed sequence lengths of 40 tokens.
INFO: train-reference-chatbot::build_model : Building model ...
INFO: train-reference-chatbot::train_chatbot : Transfering model to xla:1 ...
INFO: train-reference-chatbot::train_chatbot : Building optimizers ...
INFO: train-reference-chatbot::train_chatbot : Prepare training ...
INFO: train-reference-chatbot::train_chatbot : Adding progress logger ...
DEBUG: TrainingManager::_assess_num_batches_per_epoch : Number of batches per epoch : 1256
INFO: train-reference-chatbot::train_chatbot : Start training ...
WARNING: utils.py::get_value_at : Key path duration.mean.batch not found in given data
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer/pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:1334: UserWarning: This overload of add_ is deprecated:
add_(Number alpha, Tensor other)
Consider using one of the following signatures instead:
add_(Tensor other, Number alpha)
/pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:1550: UserWarning: This overload of addcmul_ is deprecated:
addcmul_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
addcmul_(Tensor tensor1, Tensor tensor2, Number value)
/pytorch/torch/csrc/autograd/generated/python_variable_methods.cpp:1480: UserWarning: This overload of addcdiv_ is deprecated:
addcdiv_(Number value, Tensor tensor1, Tensor tensor2)
Consider using one of the following signatures instead:
addcdiv_(Tensor tensor1, Tensor tensor2, Number value)
<decoder_optimizer!
Epoch 0/49 - ETA: [UNKNOWN] Batch 0/1255 Average batch training time [UNKNOWN]
Batch:
training: loss 10.105.
Moving average:
training: loss 10.105.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 6 days, 20:24:04 Batch 1/1255 Average batch training time 471588ms
Batch:
training: loss 10.101.
Moving average:
training: loss 10.103.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 4 days, 6:55:15 Batch 2/1255 Average batch training time 295466ms
Batch:
training: loss 10.090.
Moving average:
training: loss 10.099.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 3 days, 10:40:47 Batch 3/1255 Average batch training time 237547ms
Batch:
training: loss 10.086.
Moving average:
training: loss 10.095.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 3 days, 0:19:53 Batch 4/1255 Average batch training time 207981ms
Batch:
training: loss 10.102.
Moving average:
training: loss 10.097.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 2 days, 18:06:27 Batch 5/1255 Average batch training time 190237ms
Batch:
training: loss 10.112.
Moving average:
training: loss 10.099.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 2 days, 14:08:32 Batch 6/1255 Average batch training time 178969ms
Batch:
training: loss 10.199.
Moving average:
training: loss 10.113.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 2 days, 11:16:47 Batch 7/1255 Average batch training time 170862ms
Batch:
training: loss 10.459.
Moving average:
training: loss 10.157.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 2 days, 9:02:35 Batch 8/1255 Average batch training time 164547ms
Batch:
training: loss 10.923.
Moving average:
training: loss 10.242.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 2 days, 7:24:31 Batch 9/1255 Average batch training time 159960ms
Batch:
training: loss 11.444.
Moving average:
training: loss 10.362.
Case 3
Using the GPU Colab notebook, we get:
INFO: train-reference-chatbot::setup_datasets : Setup data sets ...
INFO: train-reference-chatbot::setup_datasets : Loading training set ...
INFO: train-reference-chatbot::load_sentence_pair_data :
Loading sentence pair file : /tmp/training-1578511965-cmdc-sentence-pairs-with-voc-max-len-40-min-word-occurance-3.pickle
DEBUG: train-reference-chatbot::setup_datasets : Number of sentence pairs in training set: 125689
INFO: model_data_generation.py::create_sentence_pairs_collate_fn : Using fixed sequence lengths of 40 tokens.
INFO: train-reference-chatbot::build_model : Building model ...
INFO: train-reference-chatbot::train_chatbot : Transfering model to cuda ...
INFO: train-reference-chatbot::train_chatbot : Building optimizers ...
INFO: train-reference-chatbot::train_chatbot : Prepare training ...
INFO: train-reference-chatbot::train_chatbot : Adding progress logger ...
DEBUG: TrainingManager::_assess_num_batches_per_epoch : Number of batches per epoch : 1256
INFO: train-reference-chatbot::train_chatbot : Start training ...
WARNING: utils.py::get_value_at : Key path duration.mean.batch not found in given data
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: [UNKNOWN] Batch 0/1255 Average batch training time [UNKNOWN]
Batch:
training: loss 10.099.
Moving average:
training: loss 10.099.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 0:23:35 Batch 1/1255 Average batch training time 1127ms
Batch:
training: loss 10.064.
Moving average:
training: loss 10.082.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 0:20:47 Batch 2/1255 Average batch training time 994ms
Batch:
training: loss 10.015.
Moving average:
training: loss 10.059.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 0:19:54 Batch 3/1255 Average batch training time 953ms
Batch:
training: loss 9.942.
Moving average:
training: loss 10.030.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 0:19:23 Batch 4/1255 Average batch training time 928ms
Batch:
training: loss 9.798.
Moving average:
training: loss 9.984.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 0:19:04 Batch 5/1255 Average batch training time 914ms
Batch:
training: loss 9.584.
Moving average:
training: loss 9.917.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 0:18:55 Batch 6/1255 Average batch training time 907ms
Batch:
training: loss 9.442.
Moving average:
training: loss 9.849.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 0:18:43 Batch 7/1255 Average batch training time 899ms
Batch:
training: loss 9.240.
Moving average:
training: loss 9.773.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 0:18:34 Batch 8/1255 Average batch training time 892ms
Batch:
training: loss 8.961.
Moving average:
training: loss 9.683.
E>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!<!<encoder_optimizer<decoder_optimizer!
Epoch 0/49 - ETA: 0:18:30 Batch 9/1255 Average batch training time 890ms
Batch:
training: loss 8.660.
Moving average:
training: loss 9.580.
So the wrong loss could be because the size() of the masked_select() result is the size of the padded tensor (an artifact of the way XLA implements dynamic shapes).
Are the "masked selected" loss components positive?
If yes, can you make the loss a sum() of the selected components?
If not positive, square+sum.
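The point about the padded result can be made concrete on CPU. The sketch assumes that the padded masked_select result holds the selected values followed by zeros up to the input size; the actual padding layout in PyTorch/XLA isn't stated in this thread. Under that assumption, a mean() over the padded result divides by the padded size and comes out too small, while a sum() divided by mask.sum() is unaffected by the padding.

```python
import torch

per_sample_loss = torch.tensor([3.0, 5.0, 4.0, 6.0])
mask = torch.tensor([True, False, True, False])

# Eager PyTorch: masked_select returns only the 2 selected values.
selected = per_sample_loss.masked_select(mask)

# Simulated padded result (assumed layout, for illustration only): the
# selected values followed by zeros up to the input size.
padded = torch.zeros_like(per_sample_loss)
padded[: selected.numel()] = selected

true_mean = selected.mean()              # (3 + 4) / 2 = 3.5
padded_mean = padded.mean()              # (3 + 4) / 4 = 1.75, too small
robust_mean = padded.sum() / mask.sum()  # 7 / 2 = 3.5, independent of the padding

print(true_mean.item(), padded_mean.item(), robust_mean.item())
```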
Can you print the metrics for 4..5 steps?
import torch_xla.debug.metrics as met
print(met.metrics_report())
Or, I think it might be possible to use debug_run.py on Colab?
Hi @dlibenzi,
I will get back to your questions later this week. But because masked_select might be an issue, I created a version of masked_loss that does not use masked_select, like so:
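def masked_loss_tpu(per_sample_loss, mask, average_loss=True):
    # Zero out padded positions instead of selecting the valid ones,
    # so all tensors keep a static shape.
    loss = per_sample_loss * mask
    if average_loss:
        # Average over the number of valid (unmasked) positions.
        num_samples = mask.sum()
        return loss.sum() / num_samples
    else:
        return loss.sum()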
You can find the code in the shared Colab Notebook.
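To check that masked_loss_tpu computes the same value as a masked_select-based average, a quick CPU comparison on random data could look like this. The shapes are arbitrary and masked_loss_select is a hypothetical reference implementation. The multiplication relies on float * bool type promotion, which recent PyTorch versions support; it also assumes at least one valid position, since mask.sum() == 0 would produce NaN.

```python
import torch


def masked_loss_tpu(per_sample_loss, mask, average_loss=True):
    # Static-shape masked loss, as in the function above.
    loss = per_sample_loss * mask
    if average_loss:
        return loss.sum() / mask.sum()
    return loss.sum()


def masked_loss_select(per_sample_loss, mask):
    # Hypothetical reference: average of the masked_select result (dynamic shape).
    return per_sample_loss.masked_select(mask).mean()


torch.manual_seed(0)
per_sample_loss = torch.rand(40, 8) * 10   # e.g. 40 time steps x 8 sequences
mask = torch.rand(40, 8) > 0.5             # random BoolTensor mask

a = masked_loss_tpu(per_sample_loss, mask)
b = masked_loss_select(per_sample_loss, mask)
print(a.item(), b.item(), torch.allclose(a, b))
```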
When applying this in the _evaluate_loss method, as follows:
# loss = masked_loss(per_sample_loss, output_mask)
loss = masked_loss_tpu(per_sample_loss, output_mask)
... I observe the following:
I also applied masked_loss_tpu in the CUDA version of the notebook; there are no issues at all, and the loss decreases fast.
So my conclusions are now the following:
1) By not using masked_select in masked_loss_tpu, there are no more variable-sized tensors.
2) The remaining issue, apart from the pack_padded_sequence bug, is that the loss doesn't decrease in the TPU version of the notebook.
@dlibenzi, @mruberry, do you agree with these conclusions? What else could cause the loss not to decrease?
Could it be that, because the output_mask is a BoolTensor, an issue arises on the TPU?
I aim to provide logs of met.metrics_report() and debug_run.py by the end of the week.
Thanks for your time,
-- Freddy
EDIT: PS:
Here are answers to some of your questions, @dlibenzi, that I can answer quickly:
Are the "masked selected" loss components positive?
Yes, they are. They are the negative log of softmax outputs.
If yes, can you make loss being a sum() of the selected components?
It was always a sum of the selected components, an average to be exact.
The issues with the average and the loss being smaller, is the fact that the average ends up fetching the padded size of the tensor, not the selected one. So the loss is smaller. In theory the version using masked_select should be as fast as the one with your manual loss.
Seeing the metrics would really help debug this.
A Bool mask tensor should be no problem on TPU.
Hi @dlibenzi,
The issues with the average and the loss being smaller, is the fact that the average ends up fetching the padded size of the tensor, not the selected one.
Got it
In theory the version using masked_select should be as fast as the one with your manual loss.
What I noticed is that when using torch_xla==nightly it is not, but with xrt==1.15.0 it is as fast.
I will get the debug metrics.
It should be the contrary. There was no masked_select lowering in 1.15, so it would go to PyTorch on the CPU.
But it could look slower because, at that point, the initial 2..3 steps might be recompiling a fully fused graph.
Metrics would shed light on this.
@visionscaper Update: we have confirmed and reproduced the padded sequence issue and are working on a fix. Thanks again for reporting this when you saw it!
@mruberry This is great to hear, looking forward to the fix!
@dlibenzi I'm still a bit busy, so it will take a while before I can provide the debugging information you requested. I still really want to solve the issue of the flat/increasing loss, so keep me posted :)
@visionscaper
Sorry for the lengthy silence. While we understand some "fixes", we're actually still trying to decide which fix to implement. One question you may be able to help us with: can you pack as part of your network's input pipeline that runs on CPU?
Hi @mruberry,
I'm currently very busy myself, so at the moment I can't do this experiment. But when I do have time again, I will, and will also provide @dlibenzi the debugging information he requested.
This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.
Hey guys, any update on this?
Hi @SeanNaren, later this week I plan to look into the requests for information from @mruberry and @dlibenzi, to resolve the issue(s).
Hi @dlibenzi,
Here is the debug metrics report you asked for. Some notes:
The version of the script that I ran with the debug metrics: https://colab.research.google.com/drive/1EnckRFcwlgazYaniKj5Cb-kYVosIvR4h
I set os.environ['XLA_EXPERIMENTAL'] = "nonzero:masked_select"
torch_xla==nightly was used (torch-xla-0.8+742abdf)
The use of pack_padded_sequence was disabled (commented out)
The original masked_loss method, which makes use of masked_select, was used
I set num_cores = 8
<encoder_optimizer<decoder_optimizer!
EEEEEEEE>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>!>!>>><<>>>>>>>>>>>>>>>>>>>>!><>>>!>!><>><!!<<!<!!!!<encoder_optimizer<decoder_optimizer!
<encoder_optimizer<encoder_optimizer<encoder_optimizer<decoder_optimizer<decoder_optimizer<decoder_optimizer!!
!
!
!!<encoder_optimizer<decoder_optimizer!
!Epoch 0/49 - ETA: 0:15:02 Batch 4/156 Average batch training time 5897ms
Batch:
<encoder_optimizer<encoder_optimizertraining: loss 10.115.
<decoder_optimizer
Moving average:
<decoder_optimizer!
training: loss 10.110.
!
INFO:MetricDebugCallback:XLA METRIC DEBUG REPORT AFTER ITER 4 :
Metric: CompileTime
TotalSamples: 4
Accumulator: 01m24s786ms232.820us
ValueRate: 850ms866.619us / second
Rate: 0.0405731 / second
Percentiles: 1%=182ms438.501us; 5%=182ms438.501us; 10%=182ms438.501us; 20%=182ms438.501us; 50%=37s996ms562.953us; 80%=37s195ms497.482us; 90%=37s195ms497.482us; 95%=37s195ms497.482us; 99%=37s195ms497.482us
Metric: DeviceLockWait
TotalSamples: 10
Accumulator: 057.769us
ValueRate: 000.525us / second
Rate: 0.0908913 / second
Percentiles: 1%=003.851us; 5%=003.851us; 10%=003.881us; 20%=003.965us; 50%=005.213us; 80%=006.536us; 90%=013.439us; 95%=013.439us; 99%=013.439us
Metric: ExecuteTime
TotalSamples: 10
Accumulator: 02s192ms378.585us
ValueRate: 020ms963.435us / second
Rate: 0.0910583 / second
Percentiles: 1%=007ms454.432us; 5%=007ms454.432us; 10%=063ms783.770us; 20%=063ms990.141us; 50%=095ms253.534us; 80%=489ms412.396us; 90%=507ms302.893us; 95%=507ms302.893us; 99%=507ms302.893us
Metric: InboundData
TotalSamples: 5
Accumulator: 20.00B
ValueRate: 0.21B / second
Rate: 0.0520066 / second
Percentiles: 1%=4.00B; 5%=4.00B; 10%=4.00B; 20%=4.00B; 50%=4.00B; 80%=4.00B; 90%=4.00B; 95%=4.00B; 99%=4.00B
Metric: InputOutputAliasCount
TotalSamples: 3
Accumulator: 283.00
ValueRate: 4.62 / second
Rate: 0.0490036 / second
Percentiles: 1%=31.00; 5%=31.00; 10%=31.00; 20%=31.00; 50%=126.00; 80%=126.00; 90%=126.00; 95%=126.00; 99%=126.00
Metric: IrValueTensorToXlaData
TotalSamples: 41
Accumulator: 03s927ms218.080us
ValueRate: 027ms912.683us / second
Rate: 0.376952 / second
Percentiles: 1%=001ms178.381us; 5%=001ms437.651us; 10%=001ms482.974us; 20%=002ms715.400us; 50%=010ms790.589us; 80%=047ms999.016us; 90%=256ms886.358us; 95%=338ms546.556us; 99%=693ms898.982us
Metric: OutboundData
TotalSamples: 68
Accumulator: 359.21MB
ValueRate: 3.20MB / second
Rate: 0.60572 / second
Percentiles: 1%=4.00B; 5%=4.00B; 10%=4.00B; 20%=4.00B; 50%=11.72KB; 80%=11.44MB; 90%=11.44MB; 95%=22.89MB; 99%=93.22MB
Metric: ReleaseDataHandlesTime
TotalSamples: 29
Accumulator: 02s597ms026.593us
ValueRate: 015ms540.917us / second
Rate: 0.264045 / second
Percentiles: 1%=919.706us; 5%=001ms080.230us; 10%=001ms084.947us; 20%=001ms150.036us; 50%=003ms379.675us; 80%=104ms359.718us; 90%=260ms388.905us; 95%=318ms914.630us; 99%=352ms802.249us
Metric: TensorsGraphSize
TotalSamples: 10
Accumulator: 199958.00
ValueRate: 1821.68 / second
Rate: 0.0911033 / second
Percentiles: 1%=157.00; 5%=157.00; 10%=10923.00; 20%=10923.00; 50%=10923.00; 80%=36327.00; 90%=36327.00; 95%=36327.00; 99%=36327.00
Metric: TransferFromServerTime
TotalSamples: 5
Accumulator: 016ms569.720us
ValueRate: 161.946us / second
Rate: 0.0520066 / second
Percentiles: 1%=002ms785.641us; 5%=002ms785.641us; 10%=002ms785.641us; 20%=002ms931.456us; 50%=002ms045.930us; 80%=005ms052.955us; 90%=005ms052.955us; 95%=005ms052.955us; 99%=005ms052.955us
Metric: TransferToServerTime
TotalSamples: 68
Accumulator: 03s180ms267.558us
ValueRate: 028ms470.838us / second
Rate: 0.608759 / second
Percentiles: 1%=001ms167.563us; 5%=001ms430.797us; 10%=001ms456.100us; 20%=002ms714.290us; 50%=004ms490.703us; 80%=045ms532.906us; 90%=072ms201.776us; 95%=326ms679.986us; 99%=693ms831.711us
Metric: TransferToServerTransformTime
TotalSamples: 68
Accumulator: 384ms454.563us
ValueRate: 003ms424.585us / second
Rate: 0.60572 / second
Percentiles: 1%=049.410us; 5%=054.285us; 10%=062.781us; 20%=075.903us; 50%=303.726us; 80%=005ms067.250us; 90%=010ms602.037us; 95%=036ms033.393us; 99%=131ms702.529us
Counter: CachedCompile
Value: 6
Counter: CreateCompileHandles
Value: 4
Counter: CreateDataHandles
Value: 653
Counter: CreateXlaTensor
Value: 90904
Counter: DestroyDataHandles
Value: 501
Counter: DestroyXlaTensor
Value: 90746
Counter: MarkStep
Value: 5
Counter: ReleaseDataHandles
Value: 501
Counter: UncachedCompile
Value: 4
Counter: XRTAllocateFromTensor_Empty
Value: 45
Counter: XrtCompile_Empty
Value: 112
Counter: XrtExecuteChained_Empty
Value: 112
Counter: XrtExecute_Empty
Value: 112
Counter: XrtRead_Empty
Value: 112
Counter: XrtReleaseAllocationHandle_Empty
Value: 112
Counter: XrtReleaseCompileHandle_Empty
Value: 112
Counter: XrtSessionCount
Value: 9
Counter: XrtSubTuple_Empty
Value: 112
Counter: aten::_local_scalar_dense
Value: 5
Counter: xla::_softmax
Value: 400
Counter: xla::_softmax_backward_data
Value: 400
Counter: xla::_unsafe_view
Value: 200
Counter: xla::add
Value: 11825
Counter: xla::add_
Value: 4244
Counter: xla::addcdiv_
Value: 160
Counter: xla::addcmul_
Value: 160
Counter: xla::addmm
Value: 2800
Counter: xla::as_strided
Value: 441
Counter: xla::bernoulli_
Value: 405
Counter: xla::bmm
Value: 600
Counter: xla::cat
Value: 2610
Counter: xla::clone
Value: 2600
Counter: xla::copy_
Value: 671
Counter: xla::div
Value: 365
Counter: xla::div_
Value: 405
Counter: xla::embedding
Value: 205
Counter: xla::embedding_dense_backward
Value: 205
Counter: xla::empty
Value: 1155
Counter: xla::empty_strided
Value: 241
Counter: xla::expand
Value: 210
Counter: xla::fill_
Value: 5
Counter: xla::gather
Value: 200
Counter: xla::index_select
Value: 205
Counter: xla::log
Value: 200
Counter: xla::masked_scatter_
Value: 5
Counter: xla::masked_select
Value: 5
Counter: xla::mean
Value: 5
Counter: xla::mm
Value: 6180
Counter: xla::mul
Value: 7410
Counter: xla::mul_
Value: 2720
Counter: xla::neg
Value: 1600
Counter: xla::scatter_add_
Value: 200
Counter: xla::select
Value: 600
Counter: xla::sigmoid_
Value: 2400
Counter: xla::sigmoid_backward
Value: 2400
Counter: xla::slice
Value: 690
Counter: xla::split
Value: 2400
Counter: xla::sqrt
Value: 160
Counter: xla::squeeze
Value: 1205
Counter: xla::stack
Value: 1235
Counter: xla::sub
Value: 1200
Counter: xla::sum
Value: 3400
Counter: xla::t
Value: 12380
Counter: xla::tanh
Value: 200
Counter: xla::tanh_
Value: 1200
Counter: xla::tanh_backward
Value: 1400
Counter: xla::transpose
Value: 800
Counter: xla::unbind
Value: 1235
Counter: xla::unsqueeze
Value: 1600
Counter: xla::view
Value: 4810
Counter: xla::zero_
Value: 632
Metric: XrtAllocateFromTensor
TotalSamples: 901
Accumulator: 03s909ms989.031us
Mean: 003ms228.623us
StdDev: 009ms357.802us
Rate: 8.05805 / second
Percentiles: 25%=399.367us; 50%=001ms420.577us; 80%=004ms729.022us; 90%=005ms951.230us; 95%=006ms745.987us; 99%=072ms471.285us
Metric: XrtCompile
TotalSamples: 32
Accumulator: 09m16s544ms530.353us
Mean: 17s361ms735.324us
StdDev: 17s911ms264.581us
Rate: 0.324586 / second
Percentiles: 25%=053ms130.715us; 50%=31s218ms238.650us; 80%=36s723ms015.319us; 90%=36s792ms500.821us; 95%=36s934ms145.867us; 99%=36s063ms320.434us
Metric: XrtExecute
TotalSamples: 80
Accumulator: 18s071ms622.717us
Mean: 226ms882.784us
StdDev: 219ms630.605us
Rate: 0.72729 / second
Percentiles: 25%=061ms063.161us; 50%=070ms738.138us; 80%=471ms303.711us; 90%=517ms931.714us; 95%=598ms271.658us; 99%=781ms136.971us
Metric: XrtReadLiteral
TotalSamples: 40
Accumulator: 030ms686.025us
Mean: 742.151us
StdDev: 510.653us
Rate: 0.415284 / second
Percentiles: 25%=541.531us; 50%=612.526us; 80%=832.256us; 90%=959.902us; 95%=001ms448.644us; 99%=004ms681.102us
Metric: XrtReleaseAllocation
TotalSamples: 175
Accumulator: 012ms473.894us
Mean: 071.279us
StdDev: 096.979us
Rate: 1.59095 / second
Percentiles: 25%=016.327us; 50%=026.028us; 80%=129.571us; 90%=243.275us; 95%=282.826us; 99%=360.606us
^ @dlibenzi Let me know if this is what you were looking for.
Sorry, I forgot what we are debugging here 😄 From the last report things look OK. It has been run for too few steps to tell whether it stabilized, compile wise.
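To judge whether compilation has stabilized, one option is to compare counters between reports taken at different steps. The sketch below assumes the plain-text layout of metrics_report() shown above ("Counter: <name>" followed by "Value: <n>"); parse_counters, new_compiles and cpu_fallbacks are hypothetical helpers, not torch_xla APIs. Counters named aten::... mark ops that fell back to the CPU; aten::_local_scalar_dense is what a .item() call on a device tensor goes through.

```python
import re

# Matches "Counter: <name>" followed (on the next line) by "Value: <int>".
COUNTER_RE = re.compile(r"Counter: (\S+)\s+Value: (\d+)")


def parse_counters(report):
    """Hypothetical helper: turn a metrics_report() text dump into a dict."""
    return {name: int(value) for name, value in COUNTER_RE.findall(report)}


def new_compiles(earlier, later):
    # Growth in UncachedCompile between two reports; if it stays at 0 over
    # later steps, the graphs are being served from the compilation cache.
    return later.get("UncachedCompile", 0) - earlier.get("UncachedCompile", 0)


def cpu_fallbacks(counters):
    # Counters named "aten::..." count ops that ran on the CPU instead of XLA.
    return {name: n for name, n in counters.items() if name.startswith("aten::")}


# Excerpt from the report after iteration 4 above.
report_iter_4 = """
Counter: CachedCompile
Value: 6
Counter: UncachedCompile
Value: 4
Counter: aten::_local_scalar_dense
Value: 5
Counter: xla::masked_select
Value: 5
"""

counters = parse_counters(report_iter_4)
print(counters["UncachedCompile"], cpu_fallbacks(counters))
```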
For the deeper pad/pack sequence, I have not had time to look into it (sorry about that). We had a chat with @mruberry and @ailzhang and we agreed the only sane way to use that is pack on CPU and pad on XLA device.
Hey @dlibenzi, 😄, yes I was busy with other things for a while.
Do you want me to run for more steps, @dlibenzi?
We had a chat with @mruberry and @ailzhang and we agreed the only sane way to use that is pack on CPU and pad on XLA device.
What does this mean for this issue or next steps?
Further, @mruberry asked me "can you pack as part of your network's input pipeline that runs on CPU". Are you still interested in this?
Yes, more steps. About 10 should tell a better story.
Yes to that as well: pack in the input pipeline, and pad from inside the training step.
@dlibenzi, as promised, here is a metrics debug report after 10 batch iterations; all parameters are as described in the notes of my previous post. Let me know what you think.
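A minimal sketch of where the two calls could live, assuming a DataLoader-style collate function; collate_packed, unpack_in_step and TOTAL_LENGTH are hypothetical names. Whether the GRU then accepts the packed batch on an XLA device is exactly the open bug in this issue, so the RNN call that would sit between the two steps is omitted, and the example runs on CPU only.

```python
import torch
from torch.nn.utils.rnn import pack_sequence, pad_packed_sequence

TOTAL_LENGTH = 40  # fixed sequence length, as in the training logs above


def collate_packed(batch):
    # Input pipeline (CPU): pack a list of variable-length 1-D token tensors.
    # Sorting and computing batch_sizes from the lengths is host-side work.
    return pack_sequence(batch, enforce_sorted=False)


def unpack_in_step(packed, device):
    # Training step: move the packed batch to the device and pad it back to a
    # fixed TOTAL_LENGTH, so every step sees the same (TOTAL_LENGTH, batch) shape.
    # In a real model, `packed` would be the RNN's packed output.
    packed = packed.to(device)
    padded, lengths = pad_packed_sequence(packed, total_length=TOTAL_LENGTH)
    return padded, lengths


batch = [torch.tensor([4, 5, 6]), torch.tensor([7]), torch.tensor([8, 9])]
packed = collate_packed(batch)
padded, lengths = unpack_in_step(packed, torch.device("cpu"))
print(padded.shape, lengths.tolist())
```

Padding to a fixed total_length keeps tensor shapes static from step to step, which is what XLA needs to avoid recompiling.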
I will now try to make packing a part of my input pipeline on CPU as indicated.
Epoch 0/49 - ETA: 0:10:32 Batch 10/156 Average batch training time 4301ms
Batch:
training: loss 11.605.
Moving average:
training: loss 10.455.
INFO: MetricDebugCallback::on_batch_training_completed : XLA METRIC DEBUG REPORT AFTER ITER 10 :
Metric: CompileTime
TotalSamples: 4
Accumulator: 01m26s428ms237.702us
ValueRate: 858ms724.600us / second
Rate: 0.0396965 / second
Percentiles: 1%=114ms046.633us; 5%=114ms046.633us; 10%=114ms046.633us; 20%=114ms046.633us; 50%=38s443ms588.270us; 80%=39s524ms243.144us; 90%=39s524ms243.144us; 95%=39s524ms243.144us; 99%=39s524ms243.144us
Metric: DeviceLockWait
TotalSamples: 22
Accumulator: 085.498us
ValueRate: 000.641us / second
Rate: 0.165049 / second
Percentiles: 1%=002.886us; 5%=003.086us; 10%=003.113us; 20%=003.131us; 50%=003.464us; 80%=003.844us; 90%=004.150us; 95%=004.502us; 99%=012.110us
Metric: ExecuteTime
TotalSamples: 22
Accumulator: 06s615ms612.831us
ValueRate: 042ms169.237us / second
Rate: 0.165234 / second
Percentiles: 1%=007ms066.400us; 5%=062ms097.164us; 10%=062ms487.875us; 20%=063ms592.919us; 50%=076ms699.757us; 80%=506ms819.089us; 90%=510ms086.590us; 95%=518ms480.405us; 99%=674ms742.384us
Metric: InboundData
TotalSamples: 11
Accumulator: 44.00B
ValueRate: 0.37B / second
Rate: 0.0915768 / second
Percentiles: 1%=4.00B; 5%=4.00B; 10%=4.00B; 20%=4.00B; 50%=4.00B; 80%=4.00B; 90%=4.00B; 95%=4.00B; 99%=4.00B
Metric: InputOutputAliasCount
TotalSamples: 3
Accumulator: 283.00
ValueRate: 4.55 / second
Rate: 0.0481894 / second
Percentiles: 1%=31.00; 5%=31.00; 10%=31.00; 20%=31.00; 50%=126.00; 80%=126.00; 90%=126.00; 95%=126.00; 99%=126.00
Metric: IrValueTensorToXlaData
TotalSamples: 53
Accumulator: 05s935ms358.610us
ValueRate: 037ms332.731us / second
Rate: 0.40091 / second
Percentiles: 1%=001ms314.478us; 5%=002ms507.220us; 10%=002ms659.477us; 20%=002ms884.705us; 50%=019ms117.200us; 80%=211ms645.112us; 90%=340ms971.160us; 95%=403ms097.573us; 99%=605ms230.950us
Metric: OutboundData
TotalSamples: 99
Accumulator: 359.57MB
ValueRate: 2.66MB / second
Rate: 0.732311 / second
Percentiles: 1%=4.00B; 5%=4.00B; 10%=4.00B; 20%=4.00B; 50%=800.00B; 80%=267.19KB; 90%=11.44MB; 95%=11.44MB; 99%=93.22MB
Metric: ReleaseDataHandlesTime
TotalSamples: 45
Accumulator: 03s224ms442.461us
ValueRate: 024ms216.670us / second
Rate: 0.337965 / second
Percentiles: 1%=620.475us; 5%=799.744us; 10%=831.020us; 20%=967.969us; 50%=002ms541.256us; 80%=222ms389.161us; 90%=324ms819.376us; 95%=376ms452.137us; 99%=466ms084.961us
Metric: TensorsGraphSize
TotalSamples: 22
Accumulator: 483458.00
ValueRate: 3632.55 / second
Rate: 0.165301 / second
Percentiles: 1%=157.00; 5%=10923.00; 10%=10923.00; 20%=10923.00; 50%=10923.00; 80%=36327.00; 90%=36327.00; 95%=36327.00; 99%=36327.00
Metric: TransferFromServerTime
TotalSamples: 11
Accumulator: 025ms812.503us
ValueRate: 206.568us / second
Rate: 0.0915768 / second
Percentiles: 1%=001ms420.842us; 5%=001ms420.842us; 10%=001ms479.343us; 20%=001ms489.190us; 50%=002ms928.930us; 80%=002ms111.464us; 90%=003ms460.742us; 95%=006ms619.715us; 99%=006ms619.715us
Metric: TransferToServerTime
TotalSamples: 99
Accumulator: 06s705ms991.980us
ValueRate: 042ms359.305us / second
Rate: 0.73507 / second
Percentiles: 1%=001ms093.955us; 5%=001ms307.741us; 10%=001ms423.728us; 20%=002ms618.483us; 50%=004ms364.349us; 80%=044ms255.991us; 90%=286ms285.890us; 95%=376ms886.684us; 99%=605ms135.443us
Metric: TransferToServerTransformTime
TotalSamples: 99
Accumulator: 398ms003.896us
ValueRate: 003ms944.066us / second
Rate: 0.732311 / second
Percentiles: 1%=041.533us; 5%=046.862us; 10%=053.281us; 20%=063.335us; 50%=195.711us; 80%=004ms646.402us; 90%=007ms957.361us; 95%=022ms001.853us; 99%=096ms154.207us
Counter: CachedCompile
Value: 18
Counter: CreateCompileHandles
Value: 4
Counter: CreateDataHandles
Value: 1461
Counter: CreateXlaTensor
Value: 199796
Counter: DestroyDataHandles
Value: 1299
Counter: DestroyXlaTensor
Value: 199646
Counter: MarkStep
Value: 11
Counter: ReleaseDataHandles
Value: 1299
Counter: UncachedCompile
Value: 4
Counter: XRTAllocateFromTensor_Empty
Value: 47
Counter: XrtCompile_Empty
Value: 32
Counter: XrtExecuteChained_Empty
Value: 32
Counter: XrtExecute_Empty
Value: 32
Counter: XrtRead_Empty
Value: 32
Counter: XrtReleaseAllocationHandle_Empty
Value: 32
Counter: XrtReleaseCompileHandle_Empty
Value: 32
Counter: XrtSessionCount
Value: 4
Counter: XrtSubTuple_Empty
Value: 32
Counter: aten::_local_scalar_dense
Value: 11
Counter: xla::_softmax
Value: 880
Counter: xla::_softmax_backward_data
Value: 880
Counter: xla::_unsafe_view
Value: 440
Counter: xla::add
Value: 26015
Counter: xla::add_
Value: 9374
Counter: xla::addcdiv_
Value: 352
Counter: xla::addcmul_
Value: 352
Counter: xla::addmm
Value: 6160
Counter: xla::as_strided
Value: 933
Counter: xla::bernoulli_
Value: 891
Counter: xla::bmm
Value: 1320
Counter: xla::cat
Value: 5742
Counter: xla::clone
Value: 5720
Counter: xla::copy_
Value: 1439
Counter: xla::div
Value: 803
Counter: xla::div_
Value: 891
Counter: xla::embedding
Value: 451
Counter: xla::embedding_dense_backward
Value: 451
Counter: xla::empty
Value: 2427
Counter: xla::empty_strided
Value: 493
Counter: xla::expand
Value: 462
Counter: xla::fill_
Value: 11
Counter: xla::gather
Value: 440
Counter: xla::index_select
Value: 451
Counter: xla::log
Value: 440
Counter: xla::masked_scatter_
Value: 11
Counter: xla::masked_select
Value: 11
Counter: xla::mean
Value: 11
Counter: xla::mm
Value: 13596
Counter: xla::mul
Value: 16302
Counter: xla::mul_
Value: 5984
Counter: xla::neg
Value: 3520
Counter: xla::scatter_add_
Value: 440
Counter: xla::select
Value: 1320
Counter: xla::sigmoid_
Value: 5280
Counter: xla::sigmoid_backward
Value: 5280
Counter: xla::slice
Value: 1518
Counter: xla::split
Value: 5280
Counter: xla::sqrt
Value: 352
Counter: xla::squeeze
Value: 2651
Counter: xla::stack
Value: 2717
Counter: xla::sub
Value: 2640
Counter: xla::sum
Value: 7480
Counter: xla::t
Value: 27236
Counter: xla::tanh
Value: 440
Counter: xla::tanh_
Value: 2640
Counter: xla::tanh_backward
Value: 3080
Counter: xla::transpose
Value: 1760
Counter: xla::unbind
Value: 2717
Counter: xla::unsqueeze
Value: 3520
Counter: xla::view
Value: 10582
Counter: xla::zero_
Value: 1352
Metric: XrtAllocateFromTensor
TotalSamples: 1257
Accumulator: 03s467ms925.537us
Mean: 002ms845.530us
StdDev: 004ms549.401us
Rate: 8.45732 / second
Percentiles: 25%=415.303us; 50%=946.654us; 80%=003ms927.578us; 90%=004ms128.344us; 95%=005ms874.917us; 99%=006ms857.999us
Metric: XrtCompile
TotalSamples: 32
Accumulator: 10m40s801ms255.940us
Mean: 18s119ms789.248us
StdDev: 18s657ms359.529us
Rate: 0.317522 / second
Percentiles: 25%=048ms247.089us; 50%=33s017ms957.903us; 80%=37s235ms200.347us; 90%=37s289ms079.357us; 95%=37s319ms358.383us; 99%=38s982ms255.098us
Metric: XrtExecute
TotalSamples: 171
Accumulator: 45s075ms975.925us
Mean: 264ms596.350us
StdDev: 235ms526.196us
Rate: 1.28365 / second
Percentiles: 25%=061ms097.722us; 50%=070ms609.567us; 80%=498ms713.629us; 90%=550ms680.843us; 95%=672ms554.584us; 99%=803ms268.751us
Metric: XrtReadLiteral
TotalSamples: 83
Accumulator: 054ms685.265us
Mean: 646.810us
StdDev: 156.971us
Rate: 0.69059 / second
Percentiles: 25%=552.789us; 50%=616.012us; 80%=757.129us; 90%=817.673us; 95%=981.137us; 99%=001ms201.167us
Metric: XrtReleaseAllocation
TotalSamples: 352
Accumulator: 030ms675.938us
Mean: 084.307us
StdDev: 112.424us
Rate: 2.64226 / second
Percentiles: 25%=017.130us; 50%=027.386us; 80%=200.698us; 90%=275.434us; 95%=314.798us; 99%=421.636us
@mruberry, @dlibenzi, @ailzhang
The following is implemented in this Colab Notebook.
Here, @mruberry asked : "can you pack as part of your network's input pipeline that runs on CPU?"
In order to do so, I wrapped my original collate function and added a pack_padded_sequence call on the input sequences:
from torch import nn

def create_pack_on_cpu_collate_fn(PAD_token, fixed_sequence_length):
    base_collate_fn = create_sentence_pairs_collate_fn(PAD_token,
                                                       fixed_sequence_length)

    def collate_fn(indexed_sentence_pairs):
        padded_input_batch, \
            input_lengths, \
            padded_output_batch, \
            output_mask, \
            max_output_len = base_collate_fn(indexed_sentence_pairs)

        # Pack the padded input sequences on CPU, inside the input pipeline
        packed_input_batch = nn.utils.rnn.pack_padded_sequence(padded_input_batch,
                                                               input_lengths.squeeze(),
                                                               enforce_sorted=False)
        # Padded length, needed later to unpack to a fixed size
        total_length = padded_input_batch.size(0)

        return packed_input_batch, \
            total_length, \
            padded_output_batch, \
            output_mask, \
            max_output_len

    return collate_fn
The problem is that, when a batch is read, packed_input_batch has been transformed into a list of Tensors. packed_input_batch is retrieved here in ChatbotTrainer::_evaluate_loss:
packed_input_batch, \
total_length, \
output_batch, \
output_mask, \
max_output_len = batch_data
Because packed_input_batch has become a list instead of a PackedSequence, the code further down the line breaks.
I suspect that the tensors in the list are a representation of the packed_input_batch instance attributes. Do you agree?
If this is the case, how can I reassemble the original PackedSequence instance from the list of tensors?
Finally, can I use a PackedSequence as the input to the first layer in the encoder (see the EncoderRNN class), which is an embedding? If not, I would also need to evaluate the embedding layer as part of the input pipeline, which I don't think is our intention.
Let me know if I properly interpreted the request to "pack as part of your network's input pipeline that runs on CPU", and how to continue from here.
I have no clue what pad/pack do, and ATM I have no time to study it. Sorry 😑
From the metrics you posted, they seem good. It seems like compilations stabilized.
As for packed_input_batch turning into a list: where does batch_data come from?
Hey @dlibenzi,
It was @mruberry who asked to "pack as part of your network's input pipeline that runs on CPU", so I hope he can respond.
Packing/padding (unpacking) is used to deal with batches that contain sequences (e.g. chats) of different lengths. To fit them in a tensor with a constant length (dim. 0), they are padded with a pad value after the end of each sequence. By packing the tensors, the RNN knows the duration of each individual sequence, so that it does not evaluate a sequence beyond its duration. This considerably improves the ability to train RNNs, because backprop doesn't need to deal with the 'non-information' of the trailing padding values.
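A minimal sketch of what packing does, assuming a time-major padded batch (the default batch_first=False), PAD value 0 and made-up token ids: the packed data keeps only the real tokens, batch_sizes records how many sequences are still active at each time step, and pad_packed_sequence restores the padded layout.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

# Time-major batch: max_len=4, batch=3; each column is one sequence, 0 = PAD.
padded = torch.tensor([[1, 4, 6],
                       [2, 5, 0],
                       [3, 0, 0],
                       [7, 0, 0]])
lengths = torch.tensor([4, 2, 1])  # already sorted in decreasing order

packed = pack_padded_sequence(padded, lengths)
print(packed.data)         # tensor([1, 4, 6, 2, 5, 3, 7]): real tokens, step by step
print(packed.batch_sizes)  # tensor([3, 2, 1, 1]): active sequences per time step

# Unpacking restores the padded tensor and also returns the lengths.
unpacked, out_lengths = pad_packed_sequence(packed)
```

The RNN iterates over batch_sizes, so at each step it only processes the sequences that have not ended yet.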
Effectively, batch_data comes directly from the torch.utils.data.DataLoader, as instantiated in the setup_datasets function. The batch_data tuple is collated by the collation function created using create_pack_on_cpu_collate_fn.
Looking a bit deeper into the PyTorch PackedSequence class, returned by pack_padded_sequence, I notice that it's a subclass of namedtuple. To know for sure, I should debug this with an actual debugger (not in Colab), but I think that after my collate function returns the batch_data containing the PackedSequence (packed_input_batch), the DataLoader "serializes" the namedtuple field values.
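If, as suspected, the fields arrive as a plain list in namedtuple order (data, batch_sizes, sorted_indices, unsorted_indices), a PackedSequence can presumably be reassembled by passing them back to its constructor. This is a sketch under that assumption: rebuild_packed is a hypothetical helper, and the PyTorch docs advise against creating PackedSequence instances manually, so treat it as a workaround rather than a supported API.

```python
import torch
from torch.nn.utils.rnn import (PackedSequence, pack_padded_sequence,
                                pad_packed_sequence)

def rebuild_packed(fields):
    """Hypothetical helper: reassemble a PackedSequence from its fields, given
    in namedtuple order (data, batch_sizes, sorted_indices, unsorted_indices)."""
    data, batch_sizes, sorted_indices, unsorted_indices = fields
    # batch_sizes must be a CPU int64 tensor, even if data lives on a device
    return PackedSequence(data, batch_sizes.cpu(), sorted_indices, unsorted_indices)

padded = torch.randn(5, 3, 2)        # (max_len, batch, features)
lengths = torch.tensor([2, 5, 3])    # unsorted, hence enforce_sorted=False
packed = pack_padded_sequence(padded, lengths, enforce_sorted=False)

as_list = list(packed)               # a namedtuple flattened into its fields
restored = rebuild_packed(as_list)
```

As pointed out later in the thread, a generic device transfer would also move batch_sizes off the CPU, which is why the sketch moves it back.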
In any case, the more I think about it, the less sense it makes to pack this early, on CPU, in my input pipeline, because the sequences first need to go through an embedding layer, which does not take a PackedSequence as input; it is not a Tensor type. If packing this early is the solution, the PackedSequence should mimic a Tensor.
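One possible way around the embedding concern, not proposed in the thread and offered only as a sketch: an embedding lookup treats every token independently, so it can be applied to packed.data directly and the result wrapped in a new PackedSequence with the same batch_sizes and indices. embed_packed is a hypothetical helper and the sizes are arbitrary; this does not address the varying packed shapes discussed later.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import PackedSequence, pack_padded_sequence

def embed_packed(embedding, packed):
    """Hypothetical helper: apply a token-wise layer to packed token ids."""
    return PackedSequence(embedding(packed.data), packed.batch_sizes,
                          packed.sorted_indices, packed.unsorted_indices)

torch.manual_seed(0)
embedding = nn.Embedding(num_embeddings=10, embedding_dim=4, padding_idx=0)
tokens = torch.tensor([[1, 4, 6],
                       [2, 5, 0],
                       [3, 0, 0]])       # (max_len, batch), 0 = PAD
lengths = torch.tensor([3, 2, 1])

# Pack the token ids (e.g. in the input pipeline) and embed afterwards.
packed_tokens = pack_padded_sequence(tokens, lengths)
embedded_packed = embed_packed(embedding, packed_tokens)

# Reference: embed the padded batch first, then pack.
reference = pack_padded_sequence(embedding(tokens), lengths)
```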
Is your Colab up to date with your recent changes?
I mentioned an updated version in my comment of 23 hours ago: https://colab.research.google.com/drive/10wxgY6sp8DbOR_qyGFHN5ZFlFAeYT7-P
Cheers,
FS
EDIT: PS: this code is broken due to the issues discussed in my last comments.
I have found a couple of issues.
Our code is probably not correctly translating the PackedSequence data structure when we send it to the device. But even if it did, we would translate all tensors to XLA device tensors, which makes the pad code unhappy (one tensor needs to be on CPU).
I am writing a PR to address that.
@dlibenzi, thanks for taking the time to work on this. I see what you have done in the PR. I could wrap my PackedSequence in a DataWrapper, which would allow me to get my PackedSequence back 'on the other side'. However, this does not yet address the more fundamental issue that I can't use the pre-packed data as input in my EncoderRNN class, because it first needs to go through an embedding layer there; the embedding layer does not accept a PackedSequence object.
As mentioned, a solution is to apply the embedding layer in my collate function, before packing the sequences (of embedding vectors) there. Although this might technically work, it has two disadvantages:
a) from an architectural point of view it does not make sense, because I would place a part of my model in the batch collate function;
b) it would mean that part of my gradients (for the embedding vectors) would be on CPU.
So my question is, is packing this early really the best solution? Because @mruberry did not explain why he wanted the packing done earlier in the input pipeline on CPU, I can't really reason any further about it.
I'm looking back at the original issues described at the beginning of the issue report:
pack_padded_sequence doesn't work on TPU, while it works on GPU. We are working on a solution for pack_padded_sequence. ✔︎👍🏼
I reran the original Colab notebook to train the chatbot on TPU, to see whether, with the nightly PyTorch/XLA build, the loss now decreases when packing and unpacking are disabled. I noticed the following:
I get the following warning many times per batch iteration:
/pytorch/torch/csrc/autograd/variable.cpp:401: UserWarning: This view requires gradients but its base or another view of the same base has been modified inplace. Running a backward pass through an inplace update on view tensors is a WIP for the XLA backend and may result in incorrect gradient computation in certain cases. Note this warning is being triggered on the inplace update (not the corresponding backward pass), and this update is safe if a backward pass is not run. To work around this limitation and to silence this warning, please replace the inplace operation by the corresponding out-of-place operation.
Unfortunately, it doesn't say which in-place operations are performed here. But could this be the cause of the loss increase we have been seeing?
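One way to locate the offending in-place operation, sketched under the assumption that PyTorch forwards this C++ warning to Python's warnings module as a UserWarning (the "UserWarning:" label in the log suggests it does): escalate the warning to an exception for a single step and read the Python traceback. find_warning_source is a hypothetical helper.

```python
import traceback
import warnings

def find_warning_source(fn, *args, category=UserWarning, **kwargs):
    """Run fn once with `category` warnings turned into exceptions.

    Returns the formatted traceback of the first such warning, which points at
    the Python line that triggered it, or None if fn finished without one.
    """
    with warnings.catch_warnings():
        warnings.simplefilter("error", category)
        try:
            fn(*args, **kwargs)
        except category:
            return traceback.format_exc()
    return None
```

For example, print(find_warning_source(training_step, batch)) with a hypothetical training_step, run on a throwaway batch, since the step is aborted at the first warning.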
Because of the above warning, I could not really check the training speed on TPU, but from what I understand from the debug metrics, this might be resolved now?
I have changed our pack/pad API implementation in the above PR. I tested it with this example I found:
https://gist.github.com/dlibenzi/8cac2a955d50358687c313f5420a88a2
And got this IR graph:
https://gist.github.com/dlibenzi/7bf1995ee443071839fa1211a28b2bae
Seems OK, but my fear is compile stabilization.
We added the warning you are seeing recently. @ailzhang: we are looking into fixing the underlying potential issue.
BTW ... yeah, pack/pad need to run on the device, with an auxiliary CPU tensor as a companion.
Hey @dlibenzi, thanks for this. I've looked at your gist together with your changes in the PR, and I see that the only thing I likely need to do is use seq_lengths.cpu().numpy(). Is there a way for me to use your branch on Colab so I can try it out, i.e. is there a torch_xla build for your specific branch available?
Actually, in my example I converted seq_lengths to the device, but they do not need to be, AFAICT.
You will have to wait until tomorrow, when my PR will show up in the nightly build.
Yes, of course, seq_lengths can just stay on CPU from the start, I understand that. In practice, though, the combination of DataLoader with DistributedSampler transfers it to the device automatically, AFAIK.
So, I understand your PR will be merged to master? Cool. I can wait for that until tomorrow.
Cheers!
You should be able to marshal the CPU tensors as such, with the DataWrapper. Then it remains to be seen whether the compilations stabilize under real training load.
Yes, it should be in tomorrow's nightly builds.
Hi @dlibenzi,
To test your fixes (PR #1839), I created a new version of the Colab script, with the only difference that I enabled the use of pack_padded_sequence/pad_packed_sequence, and I transferred the sequence lengths back to CPU as you do in the gist you shared:
packed = nn.utils.rnn.pack_padded_sequence(embedded,
input_lengths.cpu().numpy(),
enforce_sorted=self.enforce_sorted)
Note that I already take care of ordering the samples by sequence length in my collate method.
The original masked_loss method, which makes use of masked_select, was used. Further, I used num_cores = 1.
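masked_select returns a tensor whose size depends on how many mask entries are True, which is a data-dependent shape and may be unfriendly to XLA compilation (xla::masked_select does appear in the counters). A hedged sketch of a static-shape alternative, assuming masked_loss follows the PyTorch chatbot tutorial's maskNLLLoss (inp holds per-token probabilities, target the token ids, mask a boolean tensor); both function names are hypothetical and this has not been checked on a TPU.

```python
import torch

def masked_nll_select(inp, target, mask):
    """Tutorial-style masked NLL: the masked_select output size depends on mask."""
    cross_entropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    return cross_entropy.masked_select(mask).mean()

def masked_nll_static(inp, target, mask):
    """Same value, but every intermediate tensor has a fixed shape."""
    cross_entropy = -torch.log(torch.gather(inp, 1, target.view(-1, 1)).squeeze(1))
    # torch.where (rather than multiplying by the mask) avoids inf * 0 = NaN
    masked = torch.where(mask, cross_entropy, torch.zeros_like(cross_entropy))
    return masked.sum() / mask.sum()
```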
The training does not crash now, so that is good. But I do notice the following:
Epoch 0/49 - ETA: 23:53:30 Batch 9/1255 Average batch training time 68973ms
Batch:
training: loss 10.163.
Moving average:
training: loss 10.116.
/pytorch/torch/csrc/autograd/variable.cpp:401: UserWarning: This view requires gradients but its base or another view of the same base has been modified inplace. Running a backward pass through an inplace update on view tensors is a WIP for the XLA backend and may result in incorrect gradient computation in certain cases. Note this warning is being triggered on the inplace update (not the corresponding backward pass), and this update is safe if a backward pass is not run. To work around this limitation and to silence this warning, please replace the inplace operation by the corresponding out-of-place operation.
...
/pytorch/torch/csrc/autograd/variable.cpp:401: UserWarning: This view requires gradients but its base or another view of the same base has been modified inplace. Running a backward pass through an inplace update on view tensors is a WIP for the XLA backend and may result in incorrect gradient computation in certain cases. Note this warning is being triggered on the inplace update (not the corresponding backward pass), and this update is safe if a backward pass is not run. To work around this limitation and to silence this warning, please replace the inplace operation by the corresponding out-of-place operation.
INFO: MetricDebugCallback::on_batch_training_completed : XLA METRIC DEBUG REPORT AFTER ITER 10 :
Metric: CompileTime
TotalSamples: 27
Accumulator: 12m19s600ms861.779us
ValueRate: 969ms975.758us / second
Rate: 0.0354215 / second
Percentiles: 1%=014ms324.312us; 5%=021ms432.868us; 10%=103ms716.515us; 20%=11s062ms056.263us; 50%=13s127ms507.002us; 80%=54s132ms774.048us; 90%=57s793ms217.181us; 95%=58s546ms008.095us; 99%=58s853ms997.859us
Metric: DeviceLockWait
TotalSamples: 66
Accumulator: 07s793ms848.868us
ValueRate: 009ms903.409us / second
Rate: 0.0865064 / second
Percentiles: 1%=002.864us; 5%=003.177us; 10%=003.564us; 20%=004.075us; 50%=005.395us; 80%=011ms518.557us; 90%=499ms879.408us; 95%=610ms750.793us; 99%=979ms362.579us
Metric: ExecuteTime
TotalSamples: 55
Accumulator: 09s521ms452.744us
ValueRate: 011ms171.546us / second
Rate: 0.0721045 / second
Percentiles: 1%=002ms966.902us; 5%=002ms216.400us; 10%=002ms465.960us; 20%=003ms824.185us; 50%=012ms761.798us; 80%=481ms917.180us; 90%=540ms220.626us; 95%=816ms038.452us; 99%=998ms862.235us
Metric: InboundData
TotalSamples: 44
Accumulator: 12.93KB
ValueRate: 18.72B / second
Rate: 0.0621929 / second
Percentiles: 1%=0.00B; 5%=0.00B; 10%=0.00B; 20%=0.00B; 50%=400.00B; 80%=800.00B; 90%=800.00B; 95%=800.00B; 99%=800.00B
Metric: InputOutputAliasCount
TotalSamples: 14
Accumulator: 1023.00
ValueRate: 1.44 / second
Rate: 0.0197695 / second
Percentiles: 1%=31.00; 5%=31.00; 10%=45.00; 20%=45.00; 50%=82.00; 80%=82.00; 90%=82.00; 95%=82.00; 99%=82.00
Metric: IrValueTensorToXlaData
TotalSamples: 64
Accumulator: 02s947ms835.242us
ValueRate: 003ms794.865us / second
Rate: 0.091878 / second
Percentiles: 1%=001ms214.198us; 5%=001ms258.800us; 10%=001ms313.010us; 20%=001ms433.756us; 50%=002ms116.456us; 80%=041ms583.913us; 90%=048ms507.184us; 95%=084ms408.210us; 99%=657ms167.075us
Metric: OutboundData
TotalSamples: 110
Accumulator: 359.58MB
ValueRate: 481.72KB / second
Rate: 0.143911 / second
Percentiles: 1%=4.00B; 5%=4.00B; 10%=4.00B; 20%=4.00B; 50%=800.00B; 80%=15.62KB; 90%=11.44MB; 95%=11.44MB; 99%=93.22MB
Metric: ReleaseDataHandlesTime
TotalSamples: 275
Accumulator: 600ms450.598us
ValueRate: 787.189us / second
Rate: 0.360524 / second
Percentiles: 1%=460.476us; 5%=512.792us; 10%=554.073us; 20%=600.081us; 50%=765.369us; 80%=001ms123.711us; 90%=002ms516.860us; 95%=004ms004.091us; 99%=046ms863.435us
Metric: TensorsGraphSize
TotalSamples: 56
Accumulator: 537692.00
ValueRate: 704.90 / second
Rate: 0.0734146 / second
Percentiles: 1%=2.00; 5%=2.00; 10%=2.00; 20%=9.00; 50%=405.00; 80%=11511.00; 90%=37530.00; 95%=38371.00; 99%=38661.00
Metric: TransferFromServerTime
TotalSamples: 44
Accumulator: 077ms703.341us
ValueRate: 108.418us / second
Rate: 0.0621929 / second
Percentiles: 1%=001.463us; 5%=001.601us; 10%=001.773us; 20%=002.583us; 50%=002ms607.147us; 80%=002ms154.747us; 90%=004ms482.412us; 95%=005ms687.710us; 99%=005ms361.099us
Metric: TransferToServerTime
TotalSamples: 110
Accumulator: 02s150ms808.545us
ValueRate: 003ms814.546us / second
Rate: 0.144013 / second
Percentiles: 1%=998.674us; 5%=001ms182.306us; 10%=001ms238.723us; 20%=001ms387.516us; 50%=002ms027.784us; 80%=013ms982.131us; 90%=043ms005.464us; 95%=048ms145.956us; 99%=449ms448.454us
Metric: TransferToServerTransformTime
TotalSamples: 110
Accumulator: 283ms946.375us
ValueRate: 370.174us / second
Rate: 0.143911 / second
Percentiles: 1%=038.736us; 5%=060.968us; 10%=076.311us; 20%=087.728us; 50%=160.793us; 80%=003ms009.087us; 90%=005ms911.395us; 95%=008ms243.738us; 99%=024ms953.028us
Counter: CachedCompile
Value: 29
Counter: CreateCompileHandles
Value: 27
Counter: CreateDataHandles
Value: 1586
Counter: CreateXlaTensor
Value: 210539
Counter: DestroyDataHandles
Value: 1394
Counter: DestroyXlaTensor
Value: 210389
Counter: MarkStep
Value: 33
Counter: ReleaseDataHandles
Value: 1394
Counter: UncachedCompile
Value: 27
Counter: XRTAllocateFromTensor_Empty
Value: 33
Counter: XrtCompile_Empty
Value: 128
Counter: XrtExecuteChained_Empty
Value: 128
Counter: XrtExecute_Empty
Value: 128
Counter: XrtRead_Empty
Value: 128
Counter: XrtReleaseAllocationHandle_Empty
Value: 128
Counter: XrtReleaseCompileHandle_Empty
Value: 128
Counter: XrtSessionCount
Value: 10
Counter: XrtSubTuple_Empty
Value: 128
Counter: aten::_local_scalar_dense
Value: 11
Counter: xla::_pack_padded_sequence
Value: 11
Counter: xla::_softmax
Value: 880
Counter: xla::_softmax_backward_data
Value: 880
Counter: xla::_unsafe_view
Value: 440
Counter: xla::add
Value: 26810
Counter: xla::add_
Value: 9062
Counter: xla::addcdiv_
Value: 352
Counter: xla::addcmul_
Value: 352
Counter: xla::addmm
Value: 5952
Counter: xla::arange_out
Value: 11
Counter: xla::as_strided
Value: 1548
Counter: xla::bernoulli_
Value: 891
Counter: xla::bmm
Value: 1320
Counter: xla::cat
Value: 6193
Counter: xla::clone
Value: 5814
Counter: xla::copy_
Value: 5903
Counter: xla::div
Value: 803
Counter: xla::div_
Value: 891
Counter: xla::embedding
Value: 451
Counter: xla::embedding_dense_backward
Value: 451
Counter: xla::empty
Value: 6230
Counter: xla::empty_strided
Value: 806
Counter: xla::expand
Value: 462
Counter: xla::fill_
Value: 22
Counter: xla::gather
Value: 440
Counter: xla::index_add_
Value: 33
Counter: xla::index_select
Value: 484
Counter: xla::log
Value: 440
Counter: xla::masked_scatter_
Value: 11
Counter: xla::masked_select
Value: 11
Counter: xla::mean
Value: 11
Counter: xla::mm
Value: 13180
Counter: xla::mul
Value: 15782
Counter: xla::mul_
Value: 5776
Counter: xla::neg
Value: 3416
Counter: xla::scatter_
Value: 11
Counter: xla::scatter_add_
Value: 440
Counter: xla::select
Value: 1734
Counter: xla::sigmoid_
Value: 5072
Counter: xla::sigmoid_backward
Value: 5072
Counter: xla::slice
Value: 13826
Counter: xla::split
Value: 5072
Counter: xla::sqrt
Value: 352
Counter: xla::squeeze
Value: 2651
Counter: xla::stack
Value: 2651
Counter: xla::sub
Value: 2536
Counter: xla::sum
Value: 7272
Counter: xla::t
Value: 26404
Counter: xla::tanh
Value: 440
Counter: xla::tanh_
Value: 2536
Counter: xla::tanh_backward
Value: 2976
Counter: xla::transpose
Value: 1760
Counter: xla::unbind
Value: 2651
Counter: xla::unsqueeze
Value: 3520
Counter: xla::view
Value: 11280
Counter: xla::zero_
Value: 4809
Metric: XrtAllocateFromTensor
TotalSamples: 170
Accumulator: 493ms383.877us
Mean: 003ms902.258us
StdDev: 009ms059.180us
Rate: 0.222562 / second
Percentiles: 25%=365.484us; 50%=001ms114.062us; 80%=004ms073.779us; 90%=005ms030.601us; 95%=006ms078.927us; 99%=084ms574.457us
Metric: XrtCompile
TotalSamples: 27
Accumulator: 12m14s352ms693.563us
Mean: 27s198ms210.873us
StdDev: 23s047ms470.482us
Rate: 0.0354214 / second
Percentiles: 25%=12s520ms961.722us; 50%=13s057ms546.359us; 80%=54s816ms571.979us; 90%=56s481ms706.198us; 95%=57s264ms213.373us; 99%=58s590ms349.306us
Metric: XrtExecute
TotalSamples: 56
Accumulator: 08s331ms967.339us
Mean: 149ms767.274us
StdDev: 249ms615.802us
Rate: 0.0734139 / second
Percentiles: 25%=001ms329.121us; 50%=010ms803.221us; 80%=137ms986.433us; 90%=538ms173.909us; 95%=814ms911.052us; 99%=996ms784.191us
Metric: XrtExecutorEvict
TotalSamples: 11
Accumulator: 1.22GB
Mean: 113.84MB
StdDev: 96.06MB
Rate: 0.0507931 / second
Percentiles: 25%=59.42MB; 50%=60.22MB; 80%=235.58MB; 90%=240.94MB; 95%=241.98MB; 99%=241.98MB
Metric: XrtReadLiteral
TotalSamples: 33
Accumulator: 019ms235.947us
Mean: 582.907us
StdDev: 318.369us
Rate: 0.0466446 / second
Percentiles: 25%=430.794us; 50%=499.384us; 80%=685.396us; 90%=794.497us; 95%=935.492us; 99%=002ms183.959us
Metric: XrtReleaseAllocation
TotalSamples: 276
Accumulator: 009ms287.557us
Mean: 033.651us
StdDev: 044.863us
Rate: 0.361827 / second
Percentiles: 25%=012.513us; 50%=020.354us; 80%=037.916us; 90%=068.796us; 95%=130.986us; 99%=217.753us
Epoch 0/49 - ETA: 1 day, 0:02:46 Batch 10/1255 Average batch training time 69474ms
Batch:
training: loss 10.204.
Moving average:
training: loss 10.124.
So the issue is, as suspected, recompilations. Even when batching the sequences into the same padded size, the CPU tensor representing the real lengths effectively produces a different recipe of PyTorch ops issued to XLA, which triggers recompilations. I did a simple experiment here:
https://gist.github.com/dlibenzi/2dec80fbf3594ee3a3ae00f2b79f36d9
The good news is that there are multiple folks training language models on TPU, so the PyTorch pad/pack code is presumably not mandatory. @taylanbil, can you point to the code of the language models you have dealt with, WRT sequence lengths, padding, binning, etc.?
IIRC pad/pack really only applies to RNNs (GRUs, LSTMs, etc.). I actually don't know of any good examples of those that run on TPUs.
For transformer-based models, we don't do packing (it doesn't make sense for attention). AFAIU pad/pack produces a tensor of a different shape every time: one dimension of the tensor is the number of non-pad tokens in the batch, which is likely highly variable from batch to batch. The reasoning behind it is that it saves flops. However, this must be causing a ton of compiles on TPUs due to shapes varying all over the place.
One idea would be to further "pad" the packed sequence so it always has the same shape, but that probably defeats the purpose of packing padded sequences in the first place. Instead of packing, if the RNN allows non-packed inputs, I'd suggest forgetting about the flop savings and trying the non-packed inputs.
@taylanbil @dlibenzi Thanks for your answers. The reason pad/pack is used for RNNs is not to save compute (maybe that is a nice side effect); it really influences how effectively the RNNs can be trained, i.e. how low we can get the loss, and how fast.
When you don't use packing, the padding values are treated as valid sequence values; this implies that the RNN has to learn that these trailing padding values do not add any information. In my experience, this means that the loss goes down more slowly and never reaches a value as low as when the sequences are packed.
So the use of pack/pad is very important when you are working with RNNs; if it was only to save flops I wouldn't have cared that much.
Another thing I have to note is that I set up my input pipeline such that the shape of the padded input sequences tensor is constant for every batch, and the shape of the tensor after unpacking (pad_packed_sequence) is constant for every batch as well. So, AFAIU, this should not trigger recompilations. Something deeper must be triggering a recompilation.
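A sketch of the encoder-side setup described here, assuming a GRU encoder in the style of the tutorial's EncoderRNN. TinyEncoder is a hypothetical stand-in and the sizes are arbitrary. The lengths stay a CPU tensor, and passing total_length to pad_packed_sequence keeps the unpacked output at the fixed padded length rather than at the longest sequence in the batch. The packed tensors in between still change shape from batch to batch.

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

class TinyEncoder(nn.Module):
    """Hypothetical stand-in for EncoderRNN: embedding -> packed GRU -> pad."""

    def __init__(self, vocab_size=20, hidden_size=8):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden_size, padding_idx=0)
        self.gru = nn.GRU(hidden_size, hidden_size)

    def forward(self, input_seq, input_lengths):
        # input_seq: (max_len, batch) token ids; input_lengths: int64 lengths
        total_length = input_seq.size(0)
        embedded = self.embedding(input_seq)
        packed = pack_padded_sequence(embedded, input_lengths.cpu(),
                                      enforce_sorted=False)
        outputs, hidden = self.gru(packed)
        # total_length pads back to the fixed input length, not max(lengths)
        outputs, _ = pad_packed_sequence(outputs, total_length=total_length)
        return outputs, hidden

torch.manual_seed(0)
encoder = TinyEncoder()
batch = torch.randint(1, 20, (10, 4))   # fixed padded length of 10
lengths = torch.tensor([6, 3, 5, 2])    # the longest real sequence is 6
outputs, hidden = encoder(batch, lengths)
```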
I found this nice gist that shows precisely how the sequences are packed.
Maybe this is what you wanted to explain as a solution, @taylanbil: if we could make packed_input.data and packed_input.batch_sizes (in the above gist) constant for every batch, it would likely resolve this, if I understand correctly.
So that would mean extending the data attribute of the PackedSequence with the padding values (which are already available from the input sequences), and extending the batch_sizes attribute with zeros.
The RNN can stop processing input sequences when a batch size of zero is reached, in this way no extra computation is done too!
Would this be a potential solution?
This also explains why you get the same issues when you bucket same-length sequences per batch, @dlibenzi. Although the input sequences tensor shape is constant from batch to batch, the data and batch_sizes attributes of PackedSequence vary in length from batch to batch.
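A small illustration with made-up lengths: every batch below has the same padded shape (10, 4, 8), yet the packed data has sum(lengths) rows and batch_sizes has max(lengths) entries, so both change from batch to batch. packed_shapes is a hypothetical helper.

```python
import torch
from torch.nn.utils.rnn import pack_padded_sequence

def packed_shapes(padded, batches_of_lengths):
    """Return (packed.data shape, batch_sizes) for each set of lengths."""
    shapes = []
    for lengths in batches_of_lengths:
        packed = pack_padded_sequence(padded, torch.tensor(lengths))
        shapes.append((tuple(packed.data.shape), packed.batch_sizes.tolist()))
    return shapes

padded = torch.zeros(10, 4, 8)  # (max_len, batch, features), identical every batch
for shape in packed_shapes(padded, [[10, 7, 7, 3], [9, 9, 2, 2], [5, 4, 4, 1]]):
    print(shape)
# ((27, 8), [4, 4, 4, 3, 3, 3, 3, 1, 1, 1])
# ((22, 8), [4, 4, 2, 2, 2, 2, 2, 2, 2])
# ((14, 8), [4, 3, 3, 3, 1])
```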
Let me know what you think.
Thanks for the explanation @visionscaper.
If we could make packed_input.data and packed_input.batch_sizes (in the above gist) constant for every batch, it would likely resolve this if I understand correctly.
My gut feeling is, this is true. To easily test this, you could make your ParallelLoader spit out the same input over and over again and see whether it compiles as often as it does normally.
The RNN can stop processing input sequences when a batch size of zero is reached, in this way no extra computation is done too!
That sounds like it may produce different graphs and thus may cause recompilations too. Definitely worth trying though.
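A minimal way to run the suggested experiment, assuming the training loop simply iterates over a loader: repeat_first_batch (a hypothetical helper) replays one batch a fixed number of times. If the UncachedCompile counter in the metric report stops growing after the first step, the varying inputs are what drive recompilation.

```python
import itertools

def repeat_first_batch(loader, times):
    """Hypothetical helper: yield the loader's first batch `times` times."""
    first_batch = next(iter(loader))
    return itertools.repeat(first_batch, times)
```

Usage would look like `for batch in repeat_first_batch(train_loader, 10): train_step(batch)`, where train_loader and train_step stand in for the notebook's own loader and step.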
Yes, of course: if you create a limited set of max-length bins, and within each bin the effective lengths are repeated many times, the compilations will be amortized. I hacked my example to run 10 times on the same input, and it only compiles once:
https://gist.github.com/dlibenzi/8681d01937efb044fc4c2ddbde9c465e
That's because the bin-size+effective-lengths combo remains constant, and hence hits the compilation cache.
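A sketch of the binning idea, under stated assumptions: bucket_by_length is a hypothetical helper, bin_edges is an illustrative list of allowed padded lengths, and samples are sorted by length inside each bin so that batches tend to repeat the same length patterns (and come out in decreasing order, as enforce_sorted=True expects). How often the exact effective-lengths combination repeats still depends on the data.

```python
import bisect
from collections import defaultdict

def bucket_by_length(lengths, bin_edges, batch_size):
    """Group sample indices into batches that share one padded bin size.

    lengths:   effective length of each sample
    bin_edges: sorted allowed padded lengths, e.g. [8, 16, 32]
    Returns a list of (padded_len, [sample indices]) batches.
    """
    buckets = defaultdict(list)
    for idx, n in enumerate(lengths):
        pos = bisect.bisect_left(bin_edges, n)  # smallest bin that fits n
        if pos == len(bin_edges):
            raise ValueError(f"length {n} exceeds largest bin {bin_edges[-1]}")
        buckets[bin_edges[pos]].append(idx)

    batches = []
    for padded_len in sorted(buckets):
        # Longest first within a bin, so lengths are in decreasing order
        idxs = sorted(buckets[padded_len], key=lambda i: lengths[i], reverse=True)
        for start in range(0, len(idxs), batch_size):
            batches.append((padded_len, idxs[start:start + batch_size]))
    return batches

print(bucket_by_length([3, 7, 9, 15, 2, 8], bin_edges=[8, 16], batch_size=2))
# [(8, [5, 1]), (8, [0, 4]), (16, [3, 2])]
```

The last batch of a bin can be smaller than batch_size; dropping or padding it keeps the batch dimension fixed too.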
❓ Questions and Help
This might relate to one or more bugs, but since I'm not sure, I'm posing this issue as a pytorch/XLA usage question. Thanks for your time in advance!
EDIT 1: you can find the Google Colab notebook for training on TPU here: chatbot-training-test-pytorch-xla-10012020.ipynb. To run it, make a copy of it first.
EDIT 2: The CUDA (GPU) version of the Google Colab notebook can be found here: chatbot-training-test-pytorch-gpu-10012020.ipynb
My questions basically are (see the details below):
1. Could it be that pack_padded_sequence contains a bug?
2. How can I use pack_padded_sequence in conjunction with pytorch/xla? (Without needing to transfer padded sequences to CPU and transferring the resulting packed tensor back to the TPU)
3. Why is the loss NaN, or why does it increase, when training on TPU, while that is not the case on GPU?
4. Is this the correct way of using multiple optimisers with pytorch/xla?
To test training with PyTorch on TPUs, I created a Google Colab notebook (configured to use Python 3/TPU) to train a simple sequence-to-sequence model (chatbot). The code defining the model is actually derived from the PyTorch chatbot tutorial. Further, I used contrib/colab/resnet18-training-xrt-1-15.ipynb as an example of how to adapt my code to apply pytorch/xla.
[CASE A] When packing padded sequences with pack_padded_sequence in the encoder, like so: ... results in the following error:
Note that both embedded and input_lengths are transferred to the TPU.
[CASE B] I managed to code around this error by transferring these variables to CPU and transferring the resulting packed tensor back to the TPU:
Now training doesn't crash anymore, but the loss is NaN, even after the first batch iteration.
[CASE C] The last thing I tried was completely removing packing and unpacking of the padded sequences. In this case the loss is not NaN anymore, but my loss is increasing instead of decreasing!
Some remarks:
I tested the same code (removing the XLA-specific code) using Google Colab with a GPU; the training works perfectly.
CASE A (original packing) worked (no errors, loss decreases). CASE B (first transfer to CPU) works, but with the following caveat:
Gives error:
Using packed.cuda() instead worked (no errors and loss decreases):
CASE C also worked just fine.
Something that might be out of the ordinary for this training experiment is that two optimisers are defined, one for encoding and one for decoding:
Effectively, I'm doing a step as follows (please see method _update_model_parameters):
Is this the correct way of using multiple optimisers with pytorch/XLA?
UPDATE: I started using torch_xla.distributed.parallel_loader.ParallelLoader and subsequently updated the optimizer step effectively as follows:
num_cores is the number of TPU cores we are using.
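A hedged sketch of the common pattern for two optimisers. This is not the author's _update_model_parameters, whose code is not shown here. The pattern is to zero both optimisers, call backward once on the shared loss, and then step each one. On TPU, xm.optimizer_step from torch_xla.core.xla_model is the call used in the torch_xla examples; it reduces gradients across the participating cores before stepping. update_model_parameters and use_xla are hypothetical names.

```python
import torch
from torch import nn

def update_model_parameters(loss, optimizers, use_xla=False):
    """Backpropagate once, then step every optimizer (e.g. encoder and decoder)."""
    for opt in optimizers:
        opt.zero_grad()
    loss.backward()
    if use_xla:
        # Only imported when actually running on an XLA device
        import torch_xla.core.xla_model as xm
        for opt in optimizers:
            xm.optimizer_step(opt)  # reduces gradients across cores, then steps
    else:
        for opt in optimizers:
            opt.step()
```

Zeroing the gradients after the forward pass but before backward is fine, because gradients only accumulate during backward.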