pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

IR Compilation takes 25 minutes each time #1412

Closed hrbigelow closed 4 years ago

hrbigelow commented 4 years ago

❓ Questions and Help

I have been experimenting with Google Colab with TPU for the last two weeks using iterations of the same repo and program. However, about four days ago, everything slowed down by a factor of about 50x. The graph compilation phases used to take ~3-4 seconds, and are now taking 25 minutes. (Below is a link to the IR output)

I saw in the README it said that "performance may at times be severely impacted when running in Colab compared to creating your own VM and TPU pair"

Why is this the case? Are the TPUs being used by multiple users? If so, is there a command in Colab/Jupyter that can show the current usage? Or is there another reason?

Thanks very much!

xla_ir_output.txt

ailzhang commented 4 years ago

@hrbigelow Would you mind pasting the new metrics report here? Thanks!

hrbigelow commented 4 years ago

@hrbigelow Would you mind pasting the new metrics report here? Thanks!

Sure! Thank you Ailing.

xla.report.9241694.txt xla_metrics.9241694.txt

Couldn't paste the save_tensors file - it exceeded the 10MB limit. Do you need it as well?

ailzhang commented 4 years ago

Oh, I'm seeing aten::pow falling back to CPU - likely due to this: https://github.com/pytorch/xla/blob/master/torch_xla/csrc/aten_xla_type.cpp#L2135. cc @dlibenzi: is it possible to make XLA support pow for types other than floating point somehow? I also see a few local_scalar_dense calls in the report, which means we're going back to CPU a lot through .item(). @hrbigelow, did any code change happen between the ~3-4 second and the 25-minute compile times?
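
For reference, the counters mentioned above come from the metrics report, which can be printed directly from Python (a minimal sketch; any aten::* counter in it marks an op that fell back to CPU):

import torch_xla.debug.metrics as met

# Look for aten::* counters such as aten::pow or aten::_local_scalar_dense:
# those ops are being routed back to the CPU instead of running on the TPU.
print(met.metrics_report())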

dlibenzi commented 4 years ago

Let me check. What type would that be? Integer? Complex?

hrbigelow commented 4 years ago

Hi @ailzhang Yes, there were lots of code changes between "normal" speeds and the dramatic slowdown. However, I've gone back and tried older commits as well, and they are also extremely slow. Also, the .item() calls only happen when I print progress, and I still see the extreme slowness if I change the progress interval to every 10 steps.

dlibenzi commented 4 years ago

I see 51 steps and more than 5000 item() calls. Something seems off.

dlibenzi commented 4 years ago

Compilations seem to stabilize and execution time is OK. It's just A LOT of item() calls (about 100 per step). I will look into the pow(), but it would be nice to know which data type it is called on.
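
To illustrate the cost (a toy sketch, not the user's model): on XLA, tensors are lazy, and every .item() forces the pending graph to execute.

import torch
import torch_xla.core.xla_model as xm

device = xm.xla_device()
w = torch.randn(10, 1, device=device)

for step in range(100):
    x = torch.randn(8, 10, device=device)
    loss = (x @ w).pow(2).mean()
    # print(step, loss.item())    # <- forces a device sync on every single step
    if step % 20 == 0:            # cheaper: read back only every N steps
        print(step, loss.item())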

hrbigelow commented 4 years ago

I see 51 steps and more than 5000 item() calls. Something seems off.

@dlibenzi which report are you looking at that you see the 5000 item() calls? I grepped 'item' or 'Item' from both xla.report.9241694.txt and xla_metrics.9241694.txt but there is no mention of it.

The pow() may be this:

        self.register_buffer('norm_gamma', torch.tensor(norm_gamma))  
        self.register_buffer('two', torch.tensor(2, dtype=torch.int32))
        self.register_buffer('one', torch.tensor(1.0))

    def forward(self, quant_pred, target_wav):

        log_pred = self.logsoftmax(quant_pred)
        target_wav_gather = target_wav.long().unsqueeze(1)
        log_pred_target = torch.gather(log_pred, 1, target_wav_gather)

        rec_loss = - log_pred_target.mean()
=>        ze_norm = (self.bottleneck.ze ** self.two).sum(dim=1).sqrt()

so, self.bottleneck.ze will be a torch.float32, and self.two is a torch.int32
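
A possible interim workaround (an assumption on my part, not something suggested in the thread) is to keep the exponent as a Python scalar or a float tensor, so pow stays lowered to XLA instead of hitting the integer-exponent fallback; the shapes below are made up:

import torch

ze = torch.randn(8, 64, 1000)                  # stand-in for self.bottleneck.ze (float32)
two_int = torch.tensor(2, dtype=torch.int32)

ze_norm_slow = (ze ** two_int).sum(dim=1).sqrt()            # float base, int32 exponent: the pattern that falls back to CPU on XLA
ze_norm = (ze ** 2).sum(dim=1).sqrt()                       # Python scalar exponent: stays lowered to XLA
ze_norm_alt = (ze ** torch.tensor(2.0)).sum(dim=1).sqrt()   # float tensor exponent also avoids the fallback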

hrbigelow commented 4 years ago

By the way, I'm trying to run this on an actual GCP VM + TPU v2-8, rather than Colab. At the moment I'm getting:

hrbigelow@hrbigelow:~$ sudo pip3 install torch_xla-nightly-cp36-cp36m-linux_x86_64.whl
ERROR: torch_xla-nightly-cp36-cp36m-linux_x86_64.whl is not a supported wheel on this platform.
hrbigelow@hrbigelow:~$ sudo pip install torch_xla-nightly-cp36-cp36m-linux_x86_64.whl
DEPRECATION: Python 2.7 will reach the end of its life on January 1st, 2020. Please upgrade your Python as Python 2.7 won't be maintained after that date. A future version of pip will drop support for Python 2.7. More details about Python 2 support in pip, can be found at https://pip.pypa.io/en/latest/development/release-process/#python-2-support
ERROR: torch_xla-nightly-cp36-cp36m-linux_x86_64.whl is not a supported wheel on this platform.
hrbigelow@hrbigelow:~$

jysohn23 commented 4 years ago

For installing the wheels, don't install with sudo and only install on python3.6 as that's what we currently build, test, and support.

Also, we already provide you with GCE Images with all the tools you need pre-installed. For an example gcloud command on how to create a GCE instance with our images check out this tutorial.
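
For reference, the VM-creation step looks roughly like the sketch below; the instance name, image family/project, and machine type here are assumptions, so take the exact values from the linked tutorial:

gcloud compute instances create my-pytorch-vm \
  --zone=us-central1-a \
  --machine-type=n1-standard-16 \
  --image-family=torch-xla \
  --image-project=ml-images \
  --boot-disk-size=200GB \
  --scopes=https://www.googleapis.com/auth/cloud-platform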

hrbigelow commented 4 years ago

Thank you Jin Young! I'll do that.

dlibenzi commented 4 years ago

I see 51 steps and more than 5000 item() calls. Something seems off.

@dlibenzi which report are you looking at that you see the 5000 item() calls? I grepped 'item' or 'Item' from both xla.report.9241694.txt and xla_metrics.9241694.txt but there is no mention of it.

[...] so, self.bottleneck.ze will be a torch.float32, and self.two is a torch.int32

This is the symptom of item(): aten::_local_scalar_dense

We are fixing pow() to support integer exponents.

hrbigelow commented 4 years ago

For installing the wheels, don't install with sudo and only install on python3.6 as that's what we currently build, test, and support.

Also, we already provide you with GCE Images with all the tools you need pre-installed. For an example gcloud command on how to create a GCE instance with our images check out this tutorial.

Hi @jysohn23 I am following the tutorial you linked. But, is it difficult to set up the VM and TPU instances to use torch-xla-nightly or something more recent than 0.5?

EDIT: Oops, sorry about that, I didn't read the README carefully enough. I see the nightly option in conda.

jysohn23 commented 4 years ago

Ah yes 😄 you can just swap in pytorch-nightly when creating the TPU and activate the torch-xla-nightly conda environment instead on the GCE VM.
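
For example (a sketch based on the gcloud command quoted later in this thread; the TPU name, zone, and network are placeholders):

gcloud compute tpus create my-tpu \
  --zone=us-central1-a \
  --network=default \
  --version=pytorch-nightly \
  --accelerator-type=v3-8

# then, on the GCE VM:
conda activate torch-xla-nightly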

hrbigelow commented 4 years ago

So, I'm running the model using torch-xla-nightly on GCP VM + TPU v3-8. But, it is still taking 25 minutes to compile one of the IR graphs. I can't imagine this is normal - as the entire model is quite small - just WaveNet and a similar-sized encoder. Is something wrong?

EDIT: It looks like there are five separate IR graphs being compiled. The first four all compile in 2-3 seconds, but the 5th one takes 25 minutes, according to the output messages when I set export TF_CPP_VMODULE=tensor=5. Is there a way to match up the graph hash in the output to the content in the other reports?

I note that I used us-central1-a for the VM, us-central1-b for the TPU.

Also, one question: why is it necessary to specify a PyTorch version when creating a TPU instance? Doesn't the VM send the library over to the TPU upon launching an application? And if not, what happens if the versions don't match?

 gcloud compute tpus create transformer-pytorch-tutorial \
--zone=us-central1-a \
--network=default \
--range=10.2.3.0 \
--version=pytorch-0.5 \
--accelerator-type=v3-8

dlibenzi commented 4 years ago

Can you export the HLO graphs? We might be hitting a corner case of bad scalability in XLA.

hrbigelow commented 4 years ago

Will do...

Current settings are to use a single TPU core, and:

%env TF_CPP_VMODULE=tensor=5
%env XLA_IR_DEBUG=1
%env XLA_HLO_DEBUG=1

Do I need to also add:

%env XLA_SAVE_TENSORS_FMT=hlo

?

dlibenzi commented 4 years ago
export XLA_SAVE_TENSORS_FILE=/PATH/...

No need for TF_CPP_VMODULE (unless you want to see it).
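
So the dump environment ends up looking something like this (the file paths are placeholders; every variable here appears elsewhere in this thread):

export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1
export XLA_SAVE_TENSORS_FMT=hlo
export XLA_SAVE_TENSORS_FILE=/tmp/xla_save_tensors.txt
export XLA_METRICS_FILE=/tmp/xla_metrics.txt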

hrbigelow commented 4 years ago

@dlibenzi Here they are, thanks!

The command-line was:

%run ae-wavenet/train.py new -af ae-wavenet/par/arch.ae.json -tf ae-wavenet/par/train.basic.json -hw TPU-single -pi 1 -lrr 1e-6 -lrs 0 -nw 100 -nb 8 -si 5000 ./basic.full.%.ckpt ae-wavenet/dat/librispeech.some.dat

using commit 03b7286

xla.report.03b7286.txt xla_metrics.03b7286.txt xla_save_tensors.03b7286.txt.gz

dlibenzi commented 4 years ago

OK, yeah, this is bad:

Metric: CompileTime
  TotalSamples: 11
  Accumulator: 03h20m41s719ms403.570us
  ValueRate: 980ms521.460us / second
  Rate: 0.00089934 / second
  Percentiles: 1%=001ms218.934us; 5%=001ms218.934us; 10%=002ms907.720us; 20%=003ms642.511us; 50%=014ms566.838us; 80%=56m57s729ms582.701us; 90%=01h12m30s105ms129.406us; 95%=01h12m14s652ms713.986us; 99%=01h12m14s652ms713.986us

Especially given this:

Metric: SyncTensorsGraphSize
  TotalSamples: 433
  Accumulator: 373200.00
  ValueRate: 24.71 / second
  Rate: 0.0286655 / second
  Percentiles: 1%=3.00; 5%=3.00; 10%=3.00; 20%=3.00; 50%=3.00; 80%=1458.00; 90%=1474.00; 95%=3695.00; 99%=12490.00

I noticed pow() is now cured 😄

You seem to have posted text graphs, not hlo though.

hrbigelow commented 4 years ago

Hi @dlibenzi I am re-running it with XLA_SAVE_TENSORS_FMT=hlo, is that the right thing to do?

I did a test run and indeed see a bit different output. Am running the full run now and will post when finished.

Here is a snippet of the test run:

  launch_instance (/usr/local/lib/python3.6/dist-packages/traitlets/config/application.py:664)
  <module> (/usr/local/lib/python3.6/dist-packages/ipykernel_launcher.py:16)
  _run_code (/usr/lib/python3.6/runpy.py:85)
  _run_module_as_main (/usr/lib/python3.6/runpy.py:193)

HloModule IrToHlo.4

ENTRY %IrToHlo.4 (param_0.1: f32[64,768,1], param_1.2: f32[64]) -> (f32[64,768,1], f32[64]) {
  %param_0.1 = f32[64,768,1]{1,0,2} parameter(0), metadata={op_type="xla::device_data" source_file="_apply@module.py" source_line=226}
  %param_1.2 = f32[64]{0} parameter(1), metadata={op_type="xla::device_data" source_file="_apply@module.py" source_line=226}
  ROOT %tuple.3 = (f32[64,768,1]{1,0,2}, f32[64]{0}) tuple(f32[64,768,1]{1,0,2} %param_0.1, f32[64]{0} %param_1.2)
}

[ScheduleSyncTensorsGraph]
TensorsGraphInfo:
  run (/content/ae-wavenet/model.py:208)
  run_batch (/content/ae-wavenet/model.py:351)

Let me know if this is correct. The full environment is:

%env TF_CPP_VMODULE=
%env XLA_IR_DEBUG=1
%env XLA_HLO_DEBUG=1
%env XLA_SAVE_TENSORS_FMT=hlo
VER=!git -C ae-wavenet rev-parse --short HEAD
XLA_SAVE_TENSORS_FILE='xla_save_tensors.{}.txt'.format(VER[0])
XLA_METRICS_FILE='xla_metrics.{}.txt'.format(VER[0])
!echo > {XLA_SAVE_TENSORS_FILE}
!echo > {XLA_METRICS_FILE}
%env XLA_METRICS_FILE={XLA_METRICS_FILE}
%env XLA_SAVE_TENSORS_FILE={XLA_SAVE_TENSORS_FILE}
# %run -b ae-wavenet/model.py:359 -d ae-wavenet/train.py new -af ae-wavenet/par/arch.ae.json -tf ae-wavenet/par/train.basic.json -hw TPU-single -pi 10 -lrr 1e-6 -lrs 0 -nw 100 -nb 8 -si 5000 ./basic.full.%.ckpt ae-wavenet/dat/librispeech.some.dat
%run ae-wavenet/train.py new -af ae-wavenet/par/arch.ae.json -tf ae-wavenet/par/train.basic.json -hw TPU-single -pi 1 -lrr 1e-6 -lrs 0 -nw 100 -nb 8 -si 5000 ./basic.full.%.ckpt ae-wavenet/dat/librispeech.some.dat

hrbigelow commented 4 years ago

Hi @dlibenzi

Here are the metrics and save_tensors files from the above run.
xla_metrics.03b7286.txt xla_save_tensors.03b7286.txt.gz

It ran for about 3 hours before I interrupted. The output is:

Already on 'master'
Your branch is up to date with 'origin/master'.
Already up to date.
env: TF_CPP_VMODULE=
env: XLA_IR_DEBUG=1
env: XLA_HLO_DEBUG=1
env: XLA_SAVE_TENSORS_FMT=hlo
Command line:  ae-wavenet/train.py new -af ae-wavenet/par/arch.ae.json -tf ae-wavenet/par/train.basic.json -hw TPU-single -pi 1 -lrr 1e-6 -lrs 0 -nw 100 -nb 8 -si 5000 ./basic.full.%.ckpt ae-wavenet/dat/librispeech.some.dat
env: XLA_METRICS_FILE=xla_metrics.03b7286.txt
env: XLA_SAVE_TENSORS_FILE=xla_save_tensors.03b7286.txt
Using TPU-single
Training parameters used:
Namespace(arch_file='ae-wavenet/par/arch.ae.json', bn_n_out=64, bn_type='ae', bn_vq_gamma=0.25, bn_vq_n_embed=4096, ckpt_template='./basic.full.%.ckpt', dat_file='ae-wavenet/dat/librispeech.some.dat', dec_filter_sz=2, dec_jitter_prob=0.12, dec_lc_upsample_filt_sizes=[25, 16, 16, 16], dec_lc_upsample_strides=[5, 4, 4, 4], dec_n_block_layers=10, dec_n_blocks=2, dec_n_dil=256, dec_n_global_embed=10, dec_n_lc_out=128, dec_n_post=256, dec_n_quant=256, dec_n_res=368, dec_n_skp=256, enc_n_out=768, hwtype='TPU-single', learning_rate_rates=[1e-06], learning_rate_steps=[0], max_steps=1e+20, n_batch=8, n_win_batch=100, pre_hop_sz=160, pre_n_mels=80, pre_n_mfcc=13, pre_sample_rate=16000, pre_win_sz=400, progress_interval=1, random_seed=2507, save_interval=5000, train_file='ae-wavenet/par/train.basic.json')
Initializing model and data source...Done.
step    loss    tprb_m  rec norm    bn_grad_sd
From worker 0
0   10.030  0.0043  6.517   3.513   2.96e-06
1   10.018  0.00367 6.503   3.515   6.89e-06
2   10.156  0.00413 6.567   3.588   3.09e-06
3   10.180  0.00399 6.582   3.598   4.93e-06
4   9.976   0.00387 6.528   3.448   3.05e-06
5   9.994   0.0038  6.494   3.501   3.16e-06
6   10.002  0.00359 6.686   3.316   2.97e-06
7   10.004  0.00342 6.659   3.345   3.97e-06
8   9.867   0.00434 6.436   3.430   4.29e-06
9   9.803   0.00388 6.532   3.270   5.5e-06
10  9.906   0.00372 6.558   3.348   3.51e-06
11  9.643   0.00362 6.534   3.109   3.95e-06
12  9.795   0.00417 6.381   3.414   3.85e-06
13  9.788   0.00331 6.578   3.210   3.48e-06

I also tried:

cat xla_save_tensors.03b7286.txt | ~/ai/xla/scripts/grab_graphs.py --graphdir=graphs03b7286 > xla.report.03b7286.txt

but it generated an empty report.

dlibenzi commented 4 years ago

Thank you! I will take a look to see if anything stands out. A 1h compile time is no good 😄

dlibenzi commented 4 years ago

Yeah, it is taking forever. We are debugging it internally.

hrbigelow commented 4 years ago

I will experiment with slashing parts of my model to see what is causing it.

Hi @dlibenzi, I am not quite finished with the ablation experiments yet, but I wondered whether dilated convolutions might pose a problem. WaveNet (the decoder portion of my model) uses a stack of 20 convolutions, each with kernel size 2, but with dilations (1, 2, 4, 8, ..., 512, 1, 2, 4, 8, ..., 512).

One dramatic speedup came from setting these dilations all to 1, although I'm not sure if that's just because of the reduction in the size of the activation tensors - I will be testing that as well soon...

The branch I'm testing is called ablation in the ae-wavenet repo.
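
For reference, a minimal sketch of the kind of dilated stack being described (the channel count and input length here are illustrative, not the actual ae-wavenet configuration):

import torch
import torch.nn as nn

# 2 blocks of 10 layers, kernel size 2, dilations 1, 2, 4, ..., 512 repeated.
dilations = [2 ** i for i in range(10)] * 2
stack = nn.Sequential(*[nn.Conv1d(64, 64, kernel_size=2, dilation=d) for d in dilations])

x = torch.randn(8, 64, 4096)
y = stack(x)     # with no padding, each layer shrinks the time axis by its dilation
print(y.shape)   # torch.Size([8, 64, 2050])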

hrbigelow commented 4 years ago

UPDATE 2:

I just realized that this change is actually not very minor, as it severs all the gradients that go back to the encoder. So, all this experiment does is narrow down the problem to the encoder...

UPDATE:

Okay, so, I found out that the slowness is completely cured by this small change: here

wavenet.py:93

    def forward(self, x, cond):
        """
        B, T: batchsize, win_size (determined from input)
        C, R, D, S: n_cond, n_res, n_dil, n_skp
        x: (B, R, T) (necessary shape for Conv1d)
        cond: (B, C, T) (necessary shape for Conv1d)
        returns: sig: (B, R, T), skp: (B, S, T) 
        """
        #cond_lead = self.cond_lead()
        #skip_lead = self.skip_lead()

        # filt = self.conv_signal(x) + self.proj_signal(cond[:,:,self.cond_lead:])
        filt = self.conv_signal(x)
        # gate = self.conv_gate(x) + self.proj_gate(cond[:,:,self.cond_lead:])
        gate = self.conv_gate(x)
        z = torch.tanh(filt) * torch.sigmoid(gate)
        sig = self.dil_res(z)
        skp = self.dil_skp(z[:,:,self.skip_lead:])
        sig += x[:,:,self.left_wing_size:]
        return sig, skp 

The code above is the forward() of a GatedResidualCondConv -- there are 20 of them in an nn.Sequential in WaveNet. The commented-out lines, which include the added proj_signal and proj_gate, are the original, extremely slow version. By removing them, the code runs 50x faster.

Note that self.cond_lead is an int tensor initialized at setup and constant throughout training, although it has a different value for each of the 20 GatedResidualCondConv modules. However, the cond argument is the same tensor being injected into each of them.

Here is the HLO graph for the fast run

xla_save_tensors.4128fcf.txt.gz xla_metrics.4128fcf.txt

dlibenzi commented 4 years ago

Thanks! We suspected it was something related to scatter/gather compilations. We are looking into it, modulo vacation days here in the US. Of course, you should not be changing your model to remove things; otherwise it won't be mathematically the same.

But yes, compilation time is MUCH faster in your last example:

Metric: CompileTime
  TotalSamples: 10
  Accumulator: 30s128ms368.026us
  ValueRate: 959ms51.005us / second
  Rate: 0.318322 / second
  Percentiles: 1%=021ms711.773us; 5%=021ms711.773us; 10%=107ms44.059us; 20%=120ms480.693us; 50%=03s807ms47.982us; 80%=10s034ms767.878us; 90%=11s725ms43.083us; 95%=11s725ms43.083us; 99%=11s725ms43.083us

But the graph is less than half the size:

Metric: SyncTensorsGraphSize
  TotalSamples: 1542
  Accumulator: 691398.00
  ValueRate: 18915.81 / second
  Rate: 41.6799 / second
  Percentiles: 1%=3.00; 5%=3.00; 10%=3.00; 20%=3.00; 50%=3.00; 80%=164.00; 90%=802.00; 95%=943.00; 99%=7429.00

hrbigelow commented 4 years ago

Sounds good, thanks @dlibenzi . I will continue the ablation experiments until I can excise the smallest possible piece which causes significant slowdown. I hope to find a graph with nearly the same size but fast compilation.

hrbigelow commented 4 years ago

Okay, so I believe I found the culprit now - it is torch.take

-        lc_dense_trim = torch.take(lc_dense,
-                lcond_slice.unsqueeze(1).expand(-1, D2, -1))
-        # lc_dense_trim = lc_dense[:,:,:2146]
+        # lc_dense_trim = torch.take(lc_dense,
+        #         lcond_slice.unsqueeze(1).expand(-1, D2, -1))
+        lc_dense_trim = lc_dense[:,:,:2146]

Commit 464ac56 is the slow one that includes torch.take, while a1c97e1 is the fast one that uses explicit slicing.

xla_save_tensors.464ac56.txt.gz xla_save_tensors.a1c97e1.txt.gz xla_metrics.464ac56.txt xla_metrics.a1c97e1.txt

dlibenzi commented 4 years ago

Does it happen if you do only the forward (just print the loss)? How many indices do you hand over to take() relative to the total size of the source tensor? You can try playing with XLA_DENSE_GATHER_FACTOR (set it to 1 and 100000) to see if you notice any difference. We know that slow compilations are linked to scatter/gather lowering, but this info can be useful.

hrbigelow commented 4 years ago

Hi @dlibenzi

Does it happen if you do only the forward (just print the loss)?

With torch.take and NO call to backward(), the rate is 1 step/second. With NO torch.take but with backward(), the rate is 5 steps/second.

How many indices you hand over to take() WRT the total size of the source tensor?

About 80% of the source tensor is taken. Moreover, the indices are consecutive runs - I am using torch.take to implement a slicing operation.
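
For concreteness, here is a sketch of that pattern (sizes and start offsets are made up): a flat index tensor made of consecutive runs makes torch.take() behave like a per-batch slice.

import torch

B, C, T, W = 4, 128, 2600, 2400           # batch, channels, full length, window
starts = torch.tensor([0, 37, 11, 160])   # hypothetical per-batch start offsets

# Flat index of element (b, c, t) in a (B, C, T) tensor is b*C*T + c*T + t, so
# these indices pick lc_dense[b, :, starts[b]:starts[b]+W] for each batch item.
idx = (starts.view(B, 1, 1)
       + torch.arange(W).view(1, 1, W)
       + torch.arange(C).view(1, C, 1) * T
       + torch.arange(B).view(B, 1, 1) * C * T)

lc_dense = torch.randn(B, C, T)
lc_dense_trim = torch.take(lc_dense, idx)   # (B, C, W)

# Same result via explicit per-batch slicing.
ref = torch.stack([lc_dense[b, :, starts[b]:starts[b] + W] for b in range(B)])
assert torch.equal(lc_dense_trim, ref)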

You can try playing with XLA_DENSE_GATHER_FACTOR (set it to 1 and 100000) to see if you read any difference.

I tried both settings and there was no difference - in both cases, only two steps had completed in 5 minutes. However, without the environment variable set at all, it may have been slower still.

We know that slow compilations are linked to scatter/gather lowering, but this info can be useful.

dlibenzi commented 4 years ago

Yes, the slicing op is much faster than a gather, especially when the gather lists many indices. Unfortunately, in your case the slice window is not static (IIRC). We go back to the topic of XLA having a DynamicSlice op but no PyTorch op (and backward) that can leverage it. But in your case, even though take() is slower execution-wise, it should not be taking forever to compile when present in a graph (this is an XLA scatter lowering issue). I can try to gauge whether having a new PyTorch op that leverages it could be an option.

hrbigelow commented 4 years ago

Hi @dlibenzi that would be great. I looked at the docs for DynamicSlice but actually that doesn't do what I need in this case, because I need a different start position for each value of the batch index, as described here.

For the moment I'm stuck, but I will ponder if there is a way to avoid this heterogeneous logic. I understand why it poses a problem.

I thought about two other possibilities:

1) Is there any "tensor map" operation? That is, something that can apply the same op along slices of input tensors?

2) If I did use a for-loop like:

# b, batch_indices_tensor, and slice are all tensors on TPU
# also, the shape of batch_indices_tensor, slice, lcond, and lcond_trim are all
# constant throughout the program
for b in batch_indices_tensor:
    lcond_trim[b,...] = lcond[b, slice[b,0]:slice[b,1]]

Would this work? I realize that the for-loop must be executed on the CPU, but would pytorch/xla at least be able to detect that the operations themselves have a static shape?

dlibenzi commented 4 years ago

DynamicSlice would help in your case because you have variable start_indices and constant slice sizes.

tensors = []
for b in batch_indices_tensor:
  tensors.append(torch.dynamic_slice(lcond, dim, b, const_window_size))
batch = torch.stack(tensors, dim)

Or:

tensors = []
for b in batch_indices_tensor:
  tensors.append(torch.dynamic_slice(lcond, b, const_window_size_tensor))
batch = torch.stack(tensors, dim)

If b is a multi-dimensional index.

The issue in your example is that b needs to be a tensor in order to not exit the tensor world, and all the START+SIZE slicing ops I am aware of in PyTorch take Python scalars.

But yes, your code above will work if you call b.item(), but that is going to be pretty bad since it will trigger a graph execution for every item() call. You could try to use this:

torch_xla._XLAC._xla_sync_multi([batch_indices_tensors...], [])

So that a single graph execution will materialize device data for all the tensors you are going to call item() for.

But this really needs a DynamicSlice to be decent.
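
A rough sketch of that workaround (the shapes are invented, and as noted it is still far from ideal because of the per-batch item() calls):

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
B, C, T, W = 8, 128, 2600, 2400
lcond = torch.randn(B, C, T, device=device)
starts = torch.randint(0, T - W, (B,), device=device)   # per-batch start offsets

# Materialize the start indices with a single graph execution, so the item()
# calls below read already-computed device data instead of each forcing a sync.
torch_xla._XLAC._xla_sync_multi([starts], [])

# Fixed-size windows with variable starts, one narrow() per batch item.
rows = [lcond[b].narrow(1, starts[b].item(), W) for b in range(B)]
lcond_trim = torch.stack(rows)   # (B, C, W)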

hrbigelow commented 4 years ago

@dlibenzi Okay, it sounds like I'll have to rework my code to avoid the slicing.

One thing to point out: my call to torch.take() that was taking so long is actually:

# lc_dense.shape() == [B, 128, 2600]
# lc_dense_trim.shape() == [B, 128, 2400]
lc_dense_trim = torch.take(lc_dense, lcond_slice.unsqueeze(1).expand(-1, 128, -1))

The slice is the same for every channel (the second dimension) within a given value of the first (batch) dimension.

Would it be possible to have a version of torch.take() that does broadcasting and takes advantage of the low entropy of the index tensor with all of its repetition?

dlibenzi commented 4 years ago

We cannot just make up ops 😉 This needs to be coordinated with PyTorch mainstream, but in this case, take() is the wrong API.

Would you mind trying this with the slow case?

export XLA_USE_32BIT_LONG=1

dlibenzi commented 4 years ago

But that take (unless I misinterpreted your data formats), alone, does not seem to be the issue. The issue is the scatter that is coming off the backward.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()
B = 16
lc_dense = torch.randn(B, 128, 2600, device=device)
lcond_slice = torch.arange(100, 2500, device=device).unsqueeze(0).unsqueeze(0).expand(B, 128, -1)
lc_dense_trim = torch.take(lc_dense, lcond_slice)
print(torch_xla._XLAC._get_xla_tensors_hlo([lc_dense_trim]))
HloModule IrToHlo.22

ENTRY %IrToHlo.22 (p0.8: f32[16,128,2600]) -> (f32[16,128,2400]) {
  %p0.8 = f32[16,128,2600]{2,1,0} parameter(0), metadata={op_type="xla::device_data" source_file="<module>@zfufu.py" source_line=12}
  %reshape.9 = f32[5324800]{0} reshape(f32[16,128,2600]{2,1,0} %p0.8), metadata={op_type="aten::take" source_file="<module>@zfufu.py" source_line=12}
  %constant.1 = s64[2400]{0} constant({...}), metadata={op_type="prim::Constant" source_file="<module>@zfufu.py" source_line=11}
  %reshape.2 = s64[1,2400]{1,0} reshape(s64[2400]{0} %constant.1), metadata={op_type="aten::view" source_file="<module>@zfufu.py" source_line=11}
  %reshape.3 = s64[1,1,2400]{2,1,0} reshape(s64[1,2400]{1,0} %reshape.2), metadata={op_type="aten::view" source_file="<module>@zfufu.py" source_line=11}
  %reshape.4 = s64[1,1,2400]{2,1,0} reshape(s64[1,1,2400]{2,1,0} %reshape.3), metadata={op_type="aten::expand" source_file="<module>@zfufu.py" source_line=11}
  %broadcast.5 = s64[1,1,2400]{2,1,0} broadcast(s64[1,1,2400]{2,1,0} %reshape.4), dimensions={0,1,2}, metadata={op_type="aten::expand" source_file="<module>@zfufu.py" source_line=11}
  %reshape.6 = s64[2400]{0} reshape(s64[1,1,2400]{2,1,0} %broadcast.5), metadata={op_type="aten::expand" source_file="<module>@zfufu.py" source_line=11}
  %broadcast.7 = s64[16,128,2400]{2,1,0} broadcast(s64[2400]{0} %reshape.6), dimensions={2}, metadata={op_type="aten::expand" source_file="<module>@zfufu.py" source_line=11}
  %reshape.10 = s64[4915200]{0} reshape(s64[16,128,2400]{2,1,0} %broadcast.7), metadata={op_type="aten::take" source_file="<module>@zfufu.py" source_line=12}
  %constant.12 = s64[] constant(0), metadata={op_type="aten::take" source_file="<module>@zfufu.py" source_line=12}
  %broadcast.13 = s64[4915200]{0} broadcast(s64[] %constant.12), dimensions={}, metadata={op_type="aten::take" source_file="<module>@zfufu.py" source_line=12}
  %compare.14 = pred[4915200]{0} compare(s64[4915200]{0} %reshape.10, s64[4915200]{0} %broadcast.13), direction=GE, metadata={op_type="aten::take" source_file="<module>@zfufu.py" source_line=12}
  %constant.11 = s64[] constant(5324800), metadata={op_type="aten::take" source_file="<module>@zfufu.py" source_line=12}
  %broadcast.15 = s64[4915200]{0} broadcast(s64[] %constant.11), dimensions={}, metadata={op_type="aten::take" source_file="<module>@zfufu.py" source_line=12}
  %add.16 = s64[4915200]{0} add(s64[4915200]{0} %reshape.10, s64[4915200]{0} %broadcast.15), metadata={op_type="aten::take" source_file="<module>@zfufu.py" source_line=12}
  %select.17 = s64[4915200]{0} select(pred[4915200]{0} %compare.14, s64[4915200]{0} %reshape.10, s64[4915200]{0} %add.16), metadata={op_type="aten::take" source_file="<module>@zfufu.py" source_line=12}
  %convert.18 = u32[4915200]{0} convert(s64[4915200]{0} %select.17), metadata={op_type="aten::take" source_file="<module>@zfufu.py" source_line=12}
  %gather.19 = f32[4915200]{0} gather(f32[5324800]{0} %reshape.9, u32[4915200]{0} %convert.18), offset_dims={}, collapsed_slice_dims={0}, start_index_map={0}, index_vector_dim=1, slice_sizes={1}, metadata={op_type="aten::take" source_file="<module>@zfufu.py" source_line=12}
  %reshape.20 = f32[16,128,2400]{2,1,0} reshape(f32[4915200]{0} %gather.19), metadata={op_type="aten::take" source_file="<module>@zfufu.py" source_line=12}
  ROOT %tuple.21 = (f32[16,128,2400]{2,1,0}) tuple(f32[16,128,2400]{2,1,0} %reshape.20)
}

Which compiles in 2s.

dlibenzi commented 4 years ago

An example of an API which would work in your case is a torch.narrow() where the start argument is a tensor instead of a scalar:

https://pytorch.org/docs/stable/torch.html#torch.narrow

This is a relatively simple extension on the PyTorch side, which would allow us, via an API like the one below on the C++ side, to map that to XLA DynamicSlice and DynamicUpdateSlice (for its backward):

static at::Tensor narrow_copy(const at::Tensor & self, int64_t dim, const at::Tensor& start, int64_t length);

hrbigelow commented 4 years ago

But that take (unless I misinterpreted your data formats), alone, does not seem to be the issue. The issue is the scatter that is coming off the backward.

Actually no. Even in the experiments without the backward call (first two table rows), the presence of take() still slows the iteration to a crawl.

Here is a four-way comparison. The third column is the number of seconds for steps 10-110. The first several steps are slower and then the rate levels out. In the last two experiments I didn't see the first step at all for more than 10 minutes, and in previous experiments that ran overnight, I saw anywhere from 2-30 steps elapse. If you would like, I can run these experiments and get the actual times.

backward()  torch.take  sec (s10-s110)  commit
no          no          77              1c19037
no          yes         --              6a2d625
yes         no          137             28e9afc
yes         yes         --              57dd0cb

dlibenzi commented 4 years ago

It is hard to tell what ends up in the graph with your experiments. We do have an internal bug open on the XLA scatter compilation bottleneck, but we have not seen any issues with XLA gather. Maybe use the new debug_run.py script, let it run 10..20 steps, and post the debug archives?

./scripts/debug_run.py --tidy --outfile /tmp/debug_run.tar.gz -- python -u SCRIPT ARGS...

hrbigelow commented 4 years ago

Maybe using the new debug_run.py script, let run 10..20 steps, and post the debug archives?

Hi @dlibenzi, here is a run (with 200 steps actually). This run, commit 57dd0cb, has both the backward() call and the take() call. The 'dlo32' means decoder local-condition output channels = 32. Previous runs (that took all night to produce 30 steps) would have had dlo = 128, so the take() would've been four times larger.

UPDATE: adding in other runs, all run for 20 steps, reporting interval 20, dlo=128

backward()  torch.take  commit   debug_out
no          no          1c19037  1c19037.dlo128.debug_run.tar.gz
no          yes         6a2d625  6a2d625.dlo128.debug_run.tar.gz
yes         no          28e9afc  28e9afc.dlo128.debug_run.tar.gz
yes         yes         57dd0cb  57dd0cb.dlo128.debug_run.tar.gz

dlibenzi commented 4 years ago

I see you using --progress_interval=1. This ends up generating a lot of item() calls per step:

Counter: MarkStep
  Value: 232
Counter: aten::_local_scalar_dense
  Value: 23562

Can we try to set that to something like 20?

The compilations do not look too bad in your data. They stabilize at 10 compilations; you have a couple at around 5 minutes, and the others are in the seconds range.

But if you say the bigger (128) one is taking long time, we need a few steps of that data as well. Very likely that is going to generate an XLA gather with A LOT of indices, which could create issues.

hrbigelow commented 4 years ago

Sure, I will do a debug run with progress interval set to 10 and dlo = 128 and post result.

So, when you say the compilations don't look too bad, does that mean there could be something else causing the extreme slowness? Because even with this run at dlo=32, the per-step rate is about 10x slower than without torch.take(), and with dlo=128 it will be on the order of 1000x slower (i.e. 20 steps overnight, as opposed to 20,000).

dlibenzi commented 4 years ago

Doing 100 item() calls per step does not help 😉 The reporting of things like the loss and other intermediate values should be throttled to every N steps (we use 20 in our example code), and xm.add_step_closure() can be used as a further optimization:

https://github.com/pytorch/xla/blob/59ce59c77248c5557e4e5ecb91ac4d41391b1d03/test/test_train_mp_mnist.py#L129
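
A minimal sketch of that pattern (the helper name and interval are placeholders, loosely following the linked MNIST example):

import torch_xla.core.xla_model as xm

def maybe_report(step, loss, interval=20):
    # The closure runs after the step's graph has executed, so reading the loss
    # there does not force an extra mid-step graph execution.
    if step % interval == 0:
        xm.add_step_closure(lambda: print('step', step, 'loss', loss.item()))

# inside the training loop, instead of calling loss.item() every step:
#     maybe_report(step, loss)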

hrbigelow commented 4 years ago

Hi @dlibenzi

Here is the report using dlo=128, reporting interval 20, for 20 steps. The revision of the code used is 57dd0cb, which includes both the torch.take() call, and the backward() call.

57dd0cb.dlo128.debug_run.tar.gz

dlibenzi commented 4 years ago

Thanks! Now the number of item() calls is much lower and we are back to 1h+ compilation times. The compilation stabilizes (which is good), but 1h+ compile times are so bad it's not even funny. We have an internal bug open. The take() alone (see my HLO graph above) does not seem to trip the compilation bug; it must be its interaction with bigger graphs. So far I don't see much we can do to hack things quickly. We will look at whether we can add the narrow() extension, and I will push for the XLA compile-time bug fix.

hrbigelow commented 4 years ago

Hi @dlibenzi, see the update to this comment - I added the other runs' debug output.

dlibenzi commented 4 years ago

We have found the compilation issue. The fix should be in tomorrow's nightly builds of the TPU VM (pytorch-nightly). But the remaining issue will be that, even though compilation will be faster, execution will still be very slow, as torch.take() is not the right API for this case.

hrbigelow commented 4 years ago

Great work, @dlibenzi and all! Really appreciate it. In the meantime I'm rewriting my model not to rely on torch.take(). I'm looking forward to deploying on TPU with the help of torch_xla!

Thanks very much. Feel free to close the issue.

dlibenzi commented 4 years ago

Hopefully we can get the "dynamic narrow" in, in a relatively short time.