pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Optimizing the implementation of Longformer #2368

Open ibeltagy opened 4 years ago

ibeltagy commented 4 years ago

❓ Questions and Help

I have managed to run a version of Longformer on pytorch-xla. It is memory-efficient and reasonably fast, but it is still about 2x slower than I would expect, so any insights on how to optimize the model implementation would be great. The model code is here. It is the same as RoBERTa except for the self-attention operation: the two matrix multiplications here and here are replaced with the two functions _sliding_chunks_matmul_qk and _sliding_chunks_matmul_pv. I am also attaching the debug output, which includes a dump of the HLO graph: debug.tar.gz.

The only difference between the GPU code and the TPU code is that the GPU code uses as_strided, while the TPU code uses an iterative version of unfold (the if statement here; see issue https://github.com/pytorch/xla/issues/2239 for more details).
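To make the difference between the two code paths concrete, here is a minimal CPU sketch. It assumes a [B, L, D] input split into overlapping chunks of length 2*w with step w; the helper names `chunk_views` and `chunk_loop` are hypothetical and this is not the actual Longformer code. With L=4096 and w=256 it yields 15 chunks of 512 rows, which matches the 15 `xla::select` slices visible in the IR dump later in the thread.

```python
import torch


def chunk_views(x: torch.Tensor, w: int) -> torch.Tensor:
    """Overlapping chunks of length 2*w, step w, along dim 1, as a zero-copy view.

    x: [B, L, D] with L % w == 0. Returns [B, L // w - 1, 2 * w, D].
    """
    b, length, d = x.shape
    n_chunks = length // w - 1
    size = (b, n_chunks, 2 * w, d)
    # Moving to the next chunk skips w rows; inside a chunk, rows keep their stride.
    stride = (x.stride(0), w * x.stride(1), x.stride(1), x.stride(2))
    return x.as_strided(size, stride)


def chunk_loop(x: torch.Tensor, w: int) -> torch.Tensor:
    """Same chunks built by slicing and stacking, which materializes a copy."""
    n_chunks = x.size(1) // w - 1
    return torch.stack([x[:, i * w : i * w + 2 * w] for i in range(n_chunks)], dim=1)


x = torch.randn(12, 4096, 64)
views = chunk_views(x, 256)
copies = chunk_loop(x, 256)
print(views.shape)                        # torch.Size([12, 15, 512, 64])
print(torch.equal(views, copies))         # True
print(views.data_ptr() == x.data_ptr())   # True: the view shares x's buffer
print(copies.data_ptr() == x.data_ptr())  # False: the loop allocates new memory
```

The as_strided version costs no memory traffic on CPU/GPU because it only reinterprets the existing buffer, while the loop version builds a new tensor out of many slices.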

Thank you.

JackCaoG commented 4 years ago

Sorry, super busy this week. I will try to run _sliding_chunks_matmul_qk and check the hlo generated next week.

JackCaoG commented 4 years ago

I looked at the debug output you posted; it seems like compilation stabilized after step 4.

[MetricsData; step=7]
Metric: CompileTime
  TotalSamples: 4
  Accumulator: 02m10s739ms497.948us
  ValueRate: 891ms118.518us / second
  Rate: 0.0274741 / second
  Percentiles: 1%=387ms939.013us; 5%=387ms939.013us; 10%=387ms939.013us; 20%=387ms939.013us; 50%=53s200ms848.581us; 80%=53s279ms453.114us; 90%=53s279ms453.114us; 95%=53s279ms453.114us; 99%=53s279ms453.114us
Metric: DeviceLockWait
  TotalSamples: 14
  Accumulator: 004ms757.483us
  ValueRate: 021.585us / second
  Rate: 0.0804229 / second
  Percentiles: 1%=001.957us; 5%=001.957us; 10%=002.068us; 20%=002.168us; 50%=003.370us; 80%=004.188us; 90%=004.416us; 95%=004ms717.704us; 99%=004ms717.704us
Metric: ExecuteTime
  TotalSamples: 12
  Accumulator: 06s548ms137.257us
  ValueRate: 032ms024.149us / second
  Rate: 0.0692647 / second
  Percentiles: 1%=009ms817.567us; 5%=009ms817.567us; 10%=304ms702.798us; 20%=304ms040.119us; 50%=352ms252.134us; 80%=698ms730.039us; 90%=777ms375.913us; 95%=788ms461.432us; 99%=788ms461.432us

From the metrics above, it seems like the majority of the time was spent on compilation (2m vs 6s). Could you run it for a few more steps and post the debug data again? That will help us better understand the speed bottleneck. 😄

ibeltagy commented 4 years ago

Here's the new debug info: debug.tar.gz. I ran it for more steps, but the last CompileTime Accumulator still shows 2 minutes, even though the log file shows that it finishes compiling in the first few steps and is no longer compiling at the end. Could this be an issue with the step-wise metric report?

JackCaoG commented 4 years ago

Yup, from the metrics it looks like it has not recompiled since step 3 (TotalSamples stayed at 4 from step 3 to step 15). Accumulator just shows the total time accumulated over the whole training run so far. We will try to write a document that explains the metrics better soon.

[MetricsData; step=15]
Metric: CompileTime
  TotalSamples: 4
  Accumulator: 02m16s671ms485.638us
  ValueRate: 895ms494.158us / second
  Rate: 0.0264018 / second
  Percentiles: 1%=420ms911.135us; 5%=420ms911.135us; 10%=420ms911.135us; 20%=420ms911.135us; 50%=55s485ms865.246us; 80%=56s059ms455.713us; 90%=56s059ms455.713us; 95%=56s059ms455.713us; 99%=56s059ms455.713us
Metric: DeviceLockWait
  TotalSamples: 30
  Accumulator: 003ms086.302us
  ValueRate: 012.542us / second
  Rate: 0.121917 / second
  Percentiles: 1%=001.792us; 5%=001.867us; 10%=001.955us; 20%=002.106us; 50%=003.197us; 80%=003.529us; 90%=003.974us; 95%=004.257us; 99%=003ms004.072us
Metric: ExecuteTime
  TotalSamples: 28
  Accumulator: 14s586ms120.483us
  ValueRate: 055ms405.222us / second
  Rate: 0.114186 / second
  Percentiles: 1%=010ms718.758us; 5%=303ms316.149us; 10%=304ms597.979us; 20%=304ms652.587us; 50%=359ms343.030us; 80%=698ms693.051us; 90%=699ms788.650us; 95%=781ms773.674us; 99%=797ms489.440us

It looks like each step's execution time is around 1s, and the median (50%=359ms343.030us) also looks reasonable. I don't think there is a bug in our compiler lowering, but there might be things we can improve. How many steps do you run for this model? Does this 1s execution time per step match what you observe when running the full model?
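The "around 1s" figure can be checked from the step=15 report: Accumulator is a running total, so dividing it by TotalSamples gives a mean per sample, and dividing ExecuteTime's total by the step count gives a rough per-step cost. A minimal sketch, assuming the duration strings only use the m/s/ms/us suffixes seen in these reports; `parse_duration` is a hypothetical helper, not a torch_xla API, and the step count of 15 is read from the report header, so it is approximate.

```python
import re

# Unit suffixes used in the metrics report, in seconds.
_UNITS = {"m": 60.0, "s": 1.0, "ms": 1e-3, "us": 1e-6}
# "ms" and "us" must be tried before "m" and "s" so they are not split.
_PART_RE = re.compile(r"(\d+(?:\.\d+)?)(ms|us|m|s)")


def parse_duration(text: str) -> float:
    """Convert a report duration such as '02m16s671ms485.638us' to seconds."""
    return sum(float(num) * _UNITS[unit] for num, unit in _PART_RE.findall(text))


compile_total = parse_duration("02m16s671ms485.638us")  # CompileTime Accumulator
execute_total = parse_duration("14s586ms120.483us")     # ExecuteTime Accumulator

print(f"mean compile:       {compile_total / 4:.1f} s")          # 34.2 s (4 samples)
print(f"mean execution:     {execute_total / 28 * 1e3:.0f} ms")  # 521 ms (28 samples)
print(f"execution per step: {execute_total / 15:.2f} s")         # 0.97 s (~15 steps)
```

With 28 executions over roughly 15 steps, the mean execution is about 0.52s while the per-step total is about 0.97s, which is consistent with the ~1s estimate; the 2-minute CompileTime total is a one-off cost spread over only 4 compilations.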

ibeltagy commented 4 years ago

Yes, it takes slightly longer than 1s/step. That is usable but slower than on GPUs, so I was wondering if you have insights on how to optimize it. AFAIK, small changes in things like slicing and reshaping can lead to big gains on TPUs.

For reference, the Hugging Face RoBERTa model is 1.18x faster on TPU than on GPUs, while Longformer is about 1.7x slower, which makes it roughly 2x slower than expected (1.18 × 1.7 ≈ 2). As mentioned earlier, the fix probably comes down to optimizing the functions _sliding_chunks_matmul_qk and _sliding_chunks_matmul_pv.

JackCaoG commented 4 years ago

@ibeltagy I don't see any obvious optimization. @davidel any thoughts? I will try to run _sliding_chunks_matmul_qk and _sliding_chunks_matmul_pv separately, get the HLO and check with the XLA team.

davidel commented 4 years ago

Certainly, that unfold loop is much more expensive than on CPU/GPU, which can represent the tensor slices by playing stride/size games over the existing buffer. To understand where the bottleneck is, you would need to export XLA_IR_DEBUG and XLA_HLO_DEBUG (it will run slowly) and capture an xprof trace. But it is almost certainly the unfold loop.

JackCaoG commented 4 years ago

Hi @ibeltagy, I worked with Blake from the XLA team and came up with a new convolution-based approach to unfold, which uses much less memory than my first approach. It should also be faster than the slice approach. If you could try this idea and let us know whether it helps, that would be great. If you encounter long compile or slow execution times, we are happy to investigate.

import torch
import torch_xla
import torch_xla.core.xla_model as xm

device = xm.xla_device()

# [12, 4096, 64], len=768, step=256
base = torch.randn(12, 4096, 64, device=device)
res_native = base.cpu().unfold(1, 768, 256)  # [12, 14, 64, 768]

# conv approach on xla
reshape_base = base.reshape([12, 16, 256, 64])
# transpose the input and filter to make better use of the TPU
transpose_base = reshape_base.permute(3, 2, 0, 1)  # [64, 256, 12, 16]
# identity filter: output channel o copies input channel o % 256 at width offset o // 256
weight = torch.eye(768, device=device).view([1, 3, 256, 768]).permute(3, 2, 0, 1)  # [768, 256, 1, 3]
res = torch.nn.functional.conv2d(transpose_base, weight)  # [64, 768, 12, 14]
res = res.permute(2, 3, 0, 1)  # [12, 14, 64, 768]

# move the XLA result to CPU before comparing it with the CPU reference
print(res.cpu().equal(res_native))

I chatted with Blake about the general rule of thumb for optimizing performance on the TPU. I think the idea is to do things in pieces that are as big as possible. The input was vectorized by [12, 64]; after the reshape and transpose it is vectorized by [12, 64, 256], which maps better to the TPU's 8 x 128 registers.
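For trying the trick off-TPU, here is a minimal CPU-runnable sketch that generalizes the snippet above to any `size`/`step`, assuming the input is [B, L, D], it is unfolded along dim 1, and both L and `size` are multiples of `step`. The helper name `unfold_via_conv` is hypothetical, not a torch_xla API.

```python
import torch
import torch.nn.functional as F


def unfold_via_conv(x: torch.Tensor, size: int, step: int) -> torch.Tensor:
    """Equivalent of x.unfold(1, size, step) for x of shape [B, L, D],
    built from a conv2d with an identity filter.

    Assumes L % step == 0 and size % step == 0.
    Returns [B, (L - size) // step + 1, D, size], like Tensor.unfold.
    """
    b, length, d = x.shape
    assert length % step == 0 and size % step == 0
    n_blocks = length // step  # the sequence viewed as blocks of `step` rows
    k = size // step           # blocks covered by one window (the conv kernel width)
    # [B, n_blocks, step, D] -> [D, step, B, n_blocks]: conv batch=D, channels=step
    inp = x.reshape(b, n_blocks, step, d).permute(3, 2, 0, 1)
    # Output channel o copies input channel o % step from block offset o // step.
    weight = torch.eye(size, dtype=x.dtype, device=x.device)
    weight = weight.view(1, k, step, size).permute(3, 2, 0, 1)  # [size, step, 1, k]
    out = F.conv2d(inp, weight)     # [D, size, B, n_blocks - k + 1]
    return out.permute(2, 3, 0, 1)  # [B, n_windows, D, size]


x = torch.randn(2, 64, 8)
print(unfold_via_conv(x, 16, 4).shape)  # torch.Size([2, 13, 8, 16])
print(torch.allclose(unfold_via_conv(x, 16, 4), x.unfold(1, 16, 4)))  # True
```

With `size=768, step=256` this reduces to the snippet above (k=3 kernel taps over 16 blocks). The comparison uses `torch.allclose` rather than exact equality so that a backend choosing a different convolution algorithm does not cause a spurious mismatch.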

JackCaoG commented 4 years ago

If this approach works well I will change my lowering of unfold to this so you won't need this manual work in the future.

ibeltagy commented 4 years ago

Thanks. Will give it a try and let you know.

ibeltagy commented 4 years ago

Here's the IR graph with the iterative unfold. Does anything stand out other than the unfold? How do you know which operations are expensive? I will try the convolution trick and post the IR graph for that too.
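Counting nodes is not a cost model, but it gives a quick view of which source lines generate most of the graph. A minimal sketch, assuming the IR text is available as a string; `count_ir_ops` is a hypothetical helper that parses the `%N = type op(...), location=...` lines in the dump that follows. Per-op cost still needs a profile such as an xprof trace.

```python
import re
from collections import Counter

# Matches e.g. "%35 = bf16[12,512,64]{2,1,0} xla::select(%34), location=..., ..."
NODE_RE = re.compile(r"^\s*%\d+\s*=\s*\S+\s+([\w:]+)\(.*?location=([^,]+)")


def count_ir_ops(ir_text: str) -> Counter:
    """Count (location, op) pairs in a textual PyTorch/XLA IR dump."""
    counts = Counter()
    for line in ir_text.splitlines():
        m = NODE_RE.match(line)
        if m:
            op, location = m.groups()
            counts[(location, op)] += 1
    return counts


sample = """IR {
  %0 = bf16[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
  %34 = bf16[12,4096,64]{2,1,0} aten::view(%33), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %35 = bf16[12,512,64]{2,1,0} xla::select(%34), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=0, end=512, stride=1, ROOT=6
  %41 = bf16[12,512,64]{2,1,0} xla::select(%40), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=256, end=768, stride=1, ROOT=7
"""
for (location, op), n in count_ir_ops(sample).most_common():
    print(f"{n:4d}  {op:20s} {location}")
```

Run on the full dump, this kind of tally makes it easy to see that `_unfold_loop@sliding_chunks.py:17` emits a view/permute/select group for every window of every input.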

IR {
  %0 = bf16[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
  %1 = bf16[768,768]{0,1} aten::permute(%0), location=linear@functional.py:1676, dims=(1, 0), ROOT=0
  %2 = s64[1,4096]{1,0} xla::device_data(), location=forward@test_tpu.py:53, device=TPU:0
  %3 = s64[1,4096]{1,0} xla::select(%2), location=forward@test_tpu.py:53, dim=0, start=0, end=1, stride=1
  %4 = s64[1,4096]{1,0} xla::select(%3), location=forward@test_tpu.py:53, dim=1, start=0, end=4096, stride=1
  %5 = s64[1,4096,1]{2,1,0} aten::view(%4), location=forward@test_tpu.py:53, output_size=(1, 4096, 1)
  %6 = s64[1,4096,768]{2,1,0} aten::expand(%5), location=forward@test_tpu.py:53, size=(1, 4096, 768)
  %7 = bf16[1,4096,768]{2,1,0} xla::cast(%6), location=forward@test_tpu.py:53, type=bf16, dtype=Float, stype=Long
  %8 = bf16[4096,1,768]{2,0,1} aten::permute(%7), location=linear@functional.py:1676, dims=(1, 0, 2)
  %9 = bf16[4096,768]{1,0} aten::view(%8), location=linear@functional.py:1676, output_size=(4096, 768), ROOT=1
  %10 = bf16[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
  %11 = bf16[768,768]{0,1} aten::permute(%10), location=linear@functional.py:1676, dims=(1, 0), ROOT=2
  %12 = bf16[4096,1,768]{2,0,1} aten::permute(%7), location=linear@functional.py:1676, dims=(1, 0, 2)
  %13 = bf16[4096,768]{1,0} aten::view(%12), location=linear@functional.py:1676, output_size=(4096, 768), ROOT=3
  %14 = bf16[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
  %15 = bf16[768,768]{0,1} aten::permute(%14), location=linear@functional.py:1676, dims=(1, 0), ROOT=4
  %16 = bf16[4096,1,768]{2,0,1} aten::permute(%7), location=linear@functional.py:1676, dims=(1, 0, 2)
  %17 = bf16[4096,768]{1,0} aten::view(%16), location=linear@functional.py:1676, output_size=(4096, 768), ROOT=5
  %18 = bf16[] xla::device_data(), location=forward@longformer.py:152, device=TPU:0
  %19 = bf16[] prim::Constant(), location=linear@functional.py:1678, value=1
  %20 = bf16[768]{0} aten::expand(%19), location=linear@functional.py:1678, size=(768)
  %21 = bf16[768]{0} xla::device_data(), location=linear@functional.py:1678, device=TPU:0
  %22 = bf16[768]{0} aten::mul(%21, %20), location=linear@functional.py:1678
  %23 = bf16[4096,768]{1,0} aten::mm(%9, %1), location=linear@functional.py:1676
  %24 = bf16[4096,1,768]{2,1,0} aten::view(%23), location=linear@functional.py:1678, output_size=(4096, 1, 768)
  %25 = bf16[4096,1,768]{2,1,0} aten::add(%24, %22), location=linear@functional.py:1678
  %26 = bf16[4096,768]{1,0} aten::view(%25), location=forward@longformer.py:152, output_size=(4096, 768)
  %27 = bf16[4096,1,768]{2,1,0} aten::view(%26), location=forward@longformer.py:152, output_size=(4096, 1, 768)
  %28 = bf16[4096,1,768]{2,1,0} aten::div(%27, %18), location=forward@longformer.py:152
  %29 = bf16[4096,768]{1,0} aten::view(%28), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 768)
  %30 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %31 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%30), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %32 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%31), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %33 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%32), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %34 = bf16[12,4096,64]{2,1,0} aten::view(%33), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %35 = bf16[12,512,64]{2,1,0} xla::select(%34), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=0, end=512, stride=1, ROOT=6
  %36 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %37 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%36), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %38 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%37), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %39 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%38), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %40 = bf16[12,4096,64]{2,1,0} aten::view(%39), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %41 = bf16[12,512,64]{2,1,0} xla::select(%40), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=256, end=768, stride=1, ROOT=7
  %42 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %43 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%42), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %44 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%43), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %45 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%44), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %46 = bf16[12,4096,64]{2,1,0} aten::view(%45), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %47 = bf16[12,512,64]{2,1,0} xla::select(%46), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=512, end=1024, stride=1, ROOT=8
  %48 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %49 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%48), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %50 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%49), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %51 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%50), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %52 = bf16[12,4096,64]{2,1,0} aten::view(%51), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %53 = bf16[12,512,64]{2,1,0} xla::select(%52), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=768, end=1280, stride=1, ROOT=9
  %54 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %55 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%54), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %56 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%55), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %57 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%56), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %58 = bf16[12,4096,64]{2,1,0} aten::view(%57), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %59 = bf16[12,512,64]{2,1,0} xla::select(%58), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1024, end=1536, stride=1, ROOT=10
  %60 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %61 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%60), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %62 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%61), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %63 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%62), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %64 = bf16[12,4096,64]{2,1,0} aten::view(%63), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %65 = bf16[12,512,64]{2,1,0} xla::select(%64), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1280, end=1792, stride=1, ROOT=11
  %66 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %67 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%66), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %68 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%67), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %69 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%68), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %70 = bf16[12,4096,64]{2,1,0} aten::view(%69), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %71 = bf16[12,512,64]{2,1,0} xla::select(%70), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1536, end=2048, stride=1, ROOT=12
  %72 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %73 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%72), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %74 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%73), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %75 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%74), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %76 = bf16[12,4096,64]{2,1,0} aten::view(%75), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %77 = bf16[12,512,64]{2,1,0} xla::select(%76), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1792, end=2304, stride=1, ROOT=13
  %78 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %79 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%78), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %80 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%79), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %81 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%80), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %82 = bf16[12,4096,64]{2,1,0} aten::view(%81), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %83 = bf16[12,512,64]{2,1,0} xla::select(%82), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2048, end=2560, stride=1, ROOT=14
  %84 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %85 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%84), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %86 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%85), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %87 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%86), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %88 = bf16[12,4096,64]{2,1,0} aten::view(%87), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %89 = bf16[12,512,64]{2,1,0} xla::select(%88), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2304, end=2816, stride=1, ROOT=15
  %90 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %91 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%90), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %92 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%91), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %93 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%92), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %94 = bf16[12,4096,64]{2,1,0} aten::view(%93), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %95 = bf16[12,512,64]{2,1,0} xla::select(%94), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2560, end=3072, stride=1, ROOT=16
  %96 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %97 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%96), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %98 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%97), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %99 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%98), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %100 = bf16[12,4096,64]{2,1,0} aten::view(%99), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %101 = bf16[12,512,64]{2,1,0} xla::select(%100), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2816, end=3328, stride=1, ROOT=17
  %102 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %103 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%102), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %104 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%103), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %105 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%104), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %106 = bf16[12,4096,64]{2,1,0} aten::view(%105), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %107 = bf16[12,512,64]{2,1,0} xla::select(%106), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3072, end=3584, stride=1, ROOT=18
  %108 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %109 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%108), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %110 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%109), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %111 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%110), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %112 = bf16[12,4096,64]{2,1,0} aten::view(%111), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %113 = bf16[12,512,64]{2,1,0} xla::select(%112), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3328, end=3840, stride=1, ROOT=19
  %114 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %115 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%114), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %116 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%115), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %117 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%116), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %118 = bf16[12,4096,64]{2,1,0} aten::view(%117), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %119 = bf16[12,512,64]{2,1,0} xla::select(%118), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3584, end=4096, stride=1, ROOT=20
  %120 = bf16[] prim::Constant(), location=linear@functional.py:1678, value=1
  %121 = bf16[768]{0} aten::expand(%120), location=linear@functional.py:1678, size=(768)
  %122 = bf16[768]{0} xla::device_data(), location=linear@functional.py:1678, device=TPU:0
  %123 = bf16[768]{0} aten::mul(%122, %121), location=linear@functional.py:1678
  %124 = bf16[4096,768]{1,0} aten::mm(%13, %11), location=linear@functional.py:1676
  %125 = bf16[4096,1,768]{2,1,0} aten::view(%124), location=linear@functional.py:1678, output_size=(4096, 1, 768)
  %126 = bf16[4096,1,768]{2,1,0} aten::add(%125, %123), location=linear@functional.py:1678
  %127 = bf16[4096,768]{1,0} aten::view(%126), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 768)
  %128 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %129 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%128), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %130 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%129), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %131 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%130), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %132 = bf16[12,4096,64]{2,1,0} aten::view(%131), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %133 = bf16[12,512,64]{2,1,0} xla::select(%132), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=0, end=512, stride=1, ROOT=21
  %134 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %135 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%134), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %136 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%135), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %137 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%136), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %138 = bf16[12,4096,64]{2,1,0} aten::view(%137), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %139 = bf16[12,512,64]{2,1,0} xla::select(%138), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=256, end=768, stride=1, ROOT=22
  %140 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %141 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%140), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %142 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%141), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %143 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%142), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %144 = bf16[12,4096,64]{2,1,0} aten::view(%143), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %145 = bf16[12,512,64]{2,1,0} xla::select(%144), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=512, end=1024, stride=1, ROOT=23
  %146 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %147 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%146), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %148 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%147), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %149 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%148), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %150 = bf16[12,4096,64]{2,1,0} aten::view(%149), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %151 = bf16[12,512,64]{2,1,0} xla::select(%150), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=768, end=1280, stride=1, ROOT=24
  %152 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %153 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%152), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %154 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%153), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %155 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%154), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %156 = bf16[12,4096,64]{2,1,0} aten::view(%155), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %157 = bf16[12,512,64]{2,1,0} xla::select(%156), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1024, end=1536, stride=1, ROOT=25
  %158 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %159 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%158), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %160 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%159), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %161 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%160), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %162 = bf16[12,4096,64]{2,1,0} aten::view(%161), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %163 = bf16[12,512,64]{2,1,0} xla::select(%162), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1280, end=1792, stride=1, ROOT=26
  %164 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %165 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%164), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %166 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%165), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %167 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%166), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %168 = bf16[12,4096,64]{2,1,0} aten::view(%167), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %169 = bf16[12,512,64]{2,1,0} xla::select(%168), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1536, end=2048, stride=1, ROOT=27
  %170 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %171 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%170), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %172 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%171), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %173 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%172), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %174 = bf16[12,4096,64]{2,1,0} aten::view(%173), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %175 = bf16[12,512,64]{2,1,0} xla::select(%174), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1792, end=2304, stride=1, ROOT=28
  %176 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %177 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%176), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %178 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%177), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %179 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%178), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %180 = bf16[12,4096,64]{2,1,0} aten::view(%179), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %181 = bf16[12,512,64]{2,1,0} xla::select(%180), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2048, end=2560, stride=1, ROOT=29
  %182 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %183 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%182), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %184 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%183), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %185 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%184), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %186 = bf16[12,4096,64]{2,1,0} aten::view(%185), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %187 = bf16[12,512,64]{2,1,0} xla::select(%186), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2304, end=2816, stride=1, ROOT=30
  %188 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %189 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%188), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %190 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%189), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %191 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%190), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %192 = bf16[12,4096,64]{2,1,0} aten::view(%191), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %193 = bf16[12,512,64]{2,1,0} xla::select(%192), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2560, end=3072, stride=1, ROOT=31
  %194 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %195 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%194), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %196 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%195), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %197 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%196), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %198 = bf16[12,4096,64]{2,1,0} aten::view(%197), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %199 = bf16[12,512,64]{2,1,0} xla::select(%198), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2816, end=3328, stride=1, ROOT=32
  %200 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %201 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%200), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %202 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%201), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %203 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%202), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %204 = bf16[12,4096,64]{2,1,0} aten::view(%203), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %205 = bf16[12,512,64]{2,1,0} xla::select(%204), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3072, end=3584, stride=1, ROOT=33
  %206 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %207 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%206), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %208 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%207), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %209 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%208), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %210 = bf16[12,4096,64]{2,1,0} aten::view(%209), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %211 = bf16[12,512,64]{2,1,0} xla::select(%210), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3328, end=3840, stride=1, ROOT=34
  %212 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
  %213 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%212), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
  %214 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%213), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
  %215 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%214), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
  %216 = bf16[12,4096,64]{2,1,0} aten::view(%215), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
  %217 = bf16[12,512,64]{2,1,0} xla::select(%216), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3584, end=4096, stride=1, ROOT=35
  %218 = bf16[12,15,512,64]{3,2,1,0} aten::stack(%35, %41, %47, %53, %59, %65, %71, %77, %83, %89, %95, %101, %107, %113, %119), location=_unfold_loop@sliding_chunks.py:17, dim=1
  %219 = bf16[12,15,512,64]{3,2,1,0} aten::permute(%218), location=einsum@functional.py:327, dims=(0, 1, 2, 3)
  %220 = bf16[12,15,512,1,64]{4,3,2,1,0} aten::view(%219), location=einsum@functional.py:327, output_size=(12, 15, 512, 1, 64)
  %221 = bf16[12,15,512,64,1]{3,4,2,1,0} aten::permute(%220), location=einsum@functional.py:327, dims=(0, 1, 2, 4, 3)
  %222 = bf16[180,512,64]{2,1,0} aten::view(%221), location=einsum@functional.py:327, output_size=(180, 512, 64), ROOT=36
  %223 = bf16[12,15,512,64]{3,2,1,0} aten::stack(%133, %139, %145, %151, %157, %163, %169, %175, %181, %187, %193, %199, %205, %211, %217), location=_unfold_loop@sliding_chunks.py:17, dim=1
  %224 = bf16[12,15,512,64]{3,2,1,0} aten::permute(%223), location=einsum@functional.py:327, dims=(0, 1, 2, 3)
  %225 = bf16[12,15,1,512,64]{4,3,2,1,0} aten::view(%224), location=einsum@functional.py:327, output_size=(12, 15, 1, 512, 64)
  %226 = bf16[12,15,64,512,1]{2,3,4,1,0} aten::permute(%225), location=einsum@functional.py:327, dims=(0, 1, 4, 3, 2)
  %227 = bf16[180,64,512]{2,1,0} aten::view(%226), location=einsum@functional.py:327, output_size=(180, 64, 512), ROOT=37
  %228 = pred[1,256,1,257]{3,1,2,0} xla::device_data(), location=_get_invalid_locations_mask@diagonaled_mm_tvm.py:312, device=TPU:0
  %229 = pred[] prim::Constant(), location=_get_invalid_locations_mask@diagonaled_mm_tvm.py:312, value=0
  %230 = pred[1,256,1,257]{3,2,1,0} aten::expand(%229), location=_get_invalid_locations_mask@diagonaled_mm_tvm.py:312, size=(1, 256, 1, 257)
  %231 = pred[1,256,1,257]{3,2,1,0} xla::as_strided_view_update(%230, %228), location=mask_invalid_locations@diagonaled_mm_tvm.py:323, size=(1, 256, 1, 257), stride=(65792, 257, 257, 1), storage_offset=0
  %232 = pred[1,256,1,257]{3,2,1,0} aten::as_strided(%231), location=mask_invalid_locations@diagonaled_mm_tvm.py:323, size=(1, 256, 1, 257), stride=(65792, 257, 257, 1), storage_offset=0, ROOT=39
  %233 = pred[1,256,12,257]{3,2,1,0} aten::expand(%232), location=mask_invalid_locations@diagonaled_mm_tvm.py:323, size=(1, 256, 12, 257), ROOT=42
  %234 = pred[1,256,1,257]{3,1,2,0} xla::device_data(), location=_get_invalid_locations_mask@diagonaled_mm_tvm.py:313, device=TPU:0
  %235 = pred[] prim::Constant(), location=_get_invalid_locations_mask@diagonaled_mm_tvm.py:313, value=0
  %236 = pred[1,256,1,257]{3,2,1,0} aten::expand(%235), location=_get_invalid_locations_mask@diagonaled_mm_tvm.py:313, size=(1, 256, 1, 257)
  %237 = pred[1,256,1,257]{3,2,1,0} xla::as_strided_view_update(%236, %234), location=mask_invalid_locations@diagonaled_mm_tvm.py:319, size=(1, 256, 1, 257), stride=(65792, 257, 257, 1), storage_offset=0
  %238 = pred[1,256,1,257]{3,2,1,0} aten::as_strided(%237), location=mask_invalid_locations@diagonaled_mm_tvm.py:319, size=(1, 256, 1, 257), stride=(65792, 257, 257, 1), storage_offset=0, ROOT=40
  %239 = pred[1,256,12,257]{3,2,1,0} aten::expand(%238), location=mask_invalid_locations@diagonaled_mm_tvm.py:319, size=(1, 256, 12, 257), ROOT=41
  %240 = bf16[180,512,512]{2,1,0} aten::matmul(%222, %227), location=einsum@functional.py:327
  %241 = bf16[12,15,512,1,512]{4,3,2,1,0} aten::view(%240), location=_pad@functional.py:3547, output_size=(12, 15, 512, 1, 512)
  %242 = bf16[12,15,512,512,1]{3,4,2,1,0} aten::permute(%241), location=_pad@functional.py:3547, dims=(0, 1, 2, 4, 3)
  %243 = bf16[12,15,512,512]{3,2,1,0} aten::view(%242), location=_pad@functional.py:3547, output_size=(12, 15, 512, 512)
  %244 = bf16[12,15,513,512]{3,2,1,0} aten::constant_pad_nd(%243), location=_pad@functional.py:3547, pad=(0, 0, 0, 1, 0, 0, 0, 0), value=0
  %245 = bf16[12,15,512,513]{3,2,1,0} aten::view(%244), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, output_size=(12, 15, 512, 513)
  %246 = bf16[12,15,512,513]{3,2,1,0} xla::select(%245), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, dim=0, start=0, end=12, stride=1
  %247 = bf16[12,1,512,513]{3,2,1,0} xla::generic_slice(%246), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, base_indices=(0, 0, 0, 0), sizes=(12, 1, 512, 513)
  %248 = bf16[12,512,513]{2,1,0} aten::view(%247), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, output_size=(12, 512, 513)
  %249 = bf16[12,255,513]{2,1,0} xla::select(%248), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, dim=1, start=0, end=255, stride=1
  %250 = bf16[12,255,255]{2,1,0} xla::select(%249), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, dim=2, start=258, end=513, stride=1
  %251 = bf16[12,255,255]{2,1,0} aten::view(%250), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, output_size=(12, 255, 255)
  %252 = bf16[12,15,512,513]{3,2,1,0} aten::view(%244), location=sliding_chunks_matmul_qk@sliding_chunks.py:96, output_size=(12, 15, 512, 513)
  %253 = bf16[12,15,512,513]{3,2,1,0} xla::select(%252), location=sliding_chunks_matmul_qk@sliding_chunks.py:96, dim=0, start=0, end=12, stride=1
  %254 = bf16[12,15,512,513]{3,2,1,0} xla::select(%253), location=sliding_chunks_matmul_qk@sliding_chunks.py:96, dim=1, start=0, end=15, stride=1
  %255 = bf16[12,15,256,513]{3,2,1,0} xla::select(%254), location=sliding_chunks_matmul_qk@sliding_chunks.py:96, dim=2, start=255, end=511, stride=1
  %256 = bf16[12,15,256,256]{3,2,1,0} xla::select(%255), location=sliding_chunks_matmul_qk@sliding_chunks.py:96, dim=3, start=257, end=513, stride=1
  %257 = bf16[12,15,256,256]{3,2,1,0} aten::view(%256), location=sliding_chunks_matmul_qk@sliding_chunks.py:96, output_size=(12, 15, 256, 256)
  %258 = bf16[12,15,512,513]{3,2,1,0} aten::view(%244), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, output_size=(12, 15, 512, 513)
  %259 = bf16[12,15,512,513]{3,2,1,0} xla::select(%258), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, dim=0, start=0, end=12, stride=1
  %260 = bf16[12,1,512,513]{3,2,1,0} xla::generic_slice(%259), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, base_indices=(0, 14, 0, 0), sizes=(12, 1, 512, 513)
  %261 = bf16[12,512,513]{2,1,0} aten::view(%260), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, output_size=(12, 512, 513)
  %262 = bf16[12,256,513]{2,1,0} xla::select(%261), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, dim=1, start=256, end=512, stride=1
  %263 = bf16[12,256,257]{2,1,0} xla::select(%262), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, dim=2, start=0, end=257, stride=1
  %264 = bf16[12,256,257]{2,1,0} aten::view(%263), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, output_size=(12, 256, 257)
  %265 = bf16[12,15,512,513]{3,2,1,0} aten::view(%244), location=sliding_chunks_matmul_qk@sliding_chunks.py:93, output_size=(12, 15, 512, 513)
  %266 = bf16[12,15,512,513]{3,2,1,0} xla::select(%265), location=sliding_chunks_matmul_qk@sliding_chunks.py:93, dim=0, start=0, end=12, stride=1
  %267 = bf16[12,15,512,513]{3,2,1,0} xla::select(%266), location=sliding_chunks_matmul_qk@sliding_chunks.py:93, dim=1, start=0, end=15, stride=1
  %268 = bf16[12,15,256,513]{3,2,1,0} xla::select(%267), location=sliding_chunks_matmul_qk@sliding_chunks.py:93, dim=2, start=0, end=256, stride=1
  %269 = bf16[12,15,256,257]{3,2,1,0} xla::select(%268), location=sliding_chunks_matmul_qk@sliding_chunks.py:93, dim=3, start=0, end=257, stride=1
  %270 = bf16[12,15,256,257]{3,2,1,0} aten::view(%269), location=sliding_chunks_matmul_qk@sliding_chunks.py:93, output_size=(12, 15, 256, 257)
  %271 = bf16[] prim::Constant(), location=sliding_chunks_matmul_qk@sliding_chunks.py:89, value=0
  %272 = bf16[12,16,256,513]{3,2,1,0} aten::expand(%271), location=sliding_chunks_matmul_qk@sliding_chunks.py:89, size=(12, 16, 256, 513)
  %273 = bf16[12,16,256,513]{3,2,1,0} xla::select(%272), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
  %274 = bf16[12,15,256,513]{3,2,1,0} xla::select(%273), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=0, end=15, stride=1
  %275 = bf16[12,15,256,513]{3,2,1,0} xla::select(%274), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=0, end=256, stride=1
  %276 = bf16[12,15,256,513]{3,2,1,0} xla::unselect(%275, %270), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=3, start=256, end=513, stride=1
  %277 = bf16[12,15,256,513]{3,2,1,0} xla::unselect(%274, %276), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=0, end=256, stride=1
  %278 = bf16[12,16,256,513]{3,2,1,0} xla::unselect(%273, %277), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=0, end=15, stride=1
  %279 = bf16[12,16,256,513]{3,2,1,0} xla::unselect(%272, %278), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
  %280 = bf16[12,16,256,513]{3,2,1,0} xla::select(%279), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
  %281 = bf16[12,1,256,513]{3,2,1,0} xla::generic_slice(%280), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, base_indices=(0, 15, 0, 0), sizes=(12, 1, 256, 513)
  %282 = bf16[12,256,513]{2,1,0} aten::view(%281), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, output_size=(12, 256, 513)
  %283 = bf16[12,256,513]{2,1,0} xla::select(%282), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=0, end=256, stride=1
  %284 = bf16[12,256,513]{2,1,0} xla::unselect(%283, %264), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=256, end=513, stride=1
  %285 = bf16[12,256,513]{2,1,0} xla::unselect(%282, %284), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=0, end=256, stride=1
  %286 = bf16[12,1,256,513]{3,2,1,0} aten::view(%285), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, output_size=(12, 1, 256, 513)
  %287 = bf16[12,16,256,513]{3,2,1,0} xla::update_slice(%280, %286), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, base_indices=(0, 15, 0, 0)
  %288 = bf16[12,16,256,513]{3,2,1,0} xla::unselect(%279, %287), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
  %289 = bf16[12,16,256,513]{3,2,1,0} xla::select(%288), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
  %290 = bf16[12,15,256,513]{3,2,1,0} xla::select(%289), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=1, end=16, stride=1
  %291 = bf16[12,15,256,513]{3,2,1,0} xla::select(%290), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=0, end=256, stride=1
  %292 = bf16[12,15,256,513]{3,2,1,0} xla::unselect(%291, %257), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=3, start=0, end=256, stride=1
  %293 = bf16[12,15,256,513]{3,2,1,0} xla::unselect(%290, %292), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=0, end=256, stride=1
  %294 = bf16[12,16,256,513]{3,2,1,0} xla::unselect(%289, %293), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=1, end=16, stride=1
  %295 = bf16[12,16,256,513]{3,2,1,0} xla::unselect(%288, %294), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
  %296 = bf16[12,16,256,513]{3,2,1,0} xla::select(%295), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
  %297 = bf16[12,1,256,513]{3,2,1,0} xla::generic_slice(%296), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, base_indices=(0, 0, 0, 0), sizes=(12, 1, 256, 513)
  %298 = bf16[12,256,513]{2,1,0} aten::view(%297), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, output_size=(12, 256, 513)
  %299 = bf16[12,255,513]{2,1,0} xla::select(%298), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=1, end=256, stride=1
  %300 = bf16[12,255,513]{2,1,0} xla::unselect(%299, %251), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=1, end=256, stride=1
  %301 = bf16[12,256,513]{2,1,0} xla::unselect(%298, %300), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=1, end=256, stride=1
  %302 = bf16[12,1,256,513]{3,2,1,0} aten::view(%301), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, output_size=(12, 1, 256, 513)
  %303 = bf16[12,16,256,513]{3,2,1,0} xla::update_slice(%296, %302), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, base_indices=(0, 0, 0, 0)
  %304 = bf16[12,16,256,513]{3,2,1,0} xla::unselect(%295, %303), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
  %305 = bf16[1,12,4096,513]{3,2,1,0} aten::view(%304), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, output_size=(1, 12, 4096, 513)
  %306 = bf16[1,4096,12,513]{3,1,2,0} aten::permute(%305), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dims=(0, 2, 1, 3)
  %307 = bf16[1,4096,12,513]{3,1,2,0} xla::select(%306), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=1, stride=1
  %308 = bf16[1,256,12,513]{3,1,2,0} xla::select(%307), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=0, end=256, stride=1
  %309 = bf16[1,256,12,513]{3,1,2,0} xla::select(%308), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=0, end=12, stride=1
  %310 = bf16[1,256,12,257]{3,1,2,0} xla::select(%309), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=3, start=0, end=257, stride=1
  %311 = bf16[1,256,12,257]{3,1,2,0} aten::masked_fill(%310, %239), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, value=-inf
  %312 = bf16[1,12,4096,513]{3,2,1,0} aten::view(%304), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, output_size=(1, 12, 4096, 513)
  %313 = bf16[1,4096,12,513]{3,1,2,0} aten::permute(%312), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dims=(0, 2, 1, 3)
  %314 = bf16[1,4096,12,513]{3,1,2,0} xla::select(%313), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=0, start=0, end=1, stride=1
  %315 = bf16[1,256,12,513]{3,1,2,0} xla::select(%314), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=1, start=0, end=256, stride=1
  %316 = bf16[1,256,12,513]{3,1,2,0} xla::select(%315), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=2, start=0, end=12, stride=1
  %317 = bf16[1,256,12,513]{3,1,2,0} xla::unselect(%316, %311), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=3, start=0, end=257, stride=1
  %318 = bf16[1,256,12,513]{3,1,2,0} xla::unselect(%315, %317), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=2, start=0, end=12, stride=1
  %319 = bf16[1,4096,12,513]{3,1,2,0} xla::unselect(%314, %318), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=1, start=0, end=256, stride=1
  %320 = bf16[1,4096,12,513]{3,1,2,0} xla::unselect(%313, %319), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=0, start=0, end=1, stride=1
  %321 = bf16[1,12,4096,513]{3,2,1,0} aten::permute(%320), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dims=(0, 2, 1, 3)
  %322 = bf16[12,16,256,513]{3,2,1,0} aten::view(%321), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, output_size=(12, 16, 256, 513)
  %323 = bf16[1,12,4096,513]{3,2,1,0} aten::view(%322), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, output_size=(1, 12, 4096, 513)
  %324 = bf16[1,4096,12,513]{3,1,2,0} aten::permute(%323), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dims=(0, 2, 1, 3)
  %325 = bf16[1,4096,12,513]{3,1,2,0} xla::select(%324), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=0, start=0, end=1, stride=1
  %326 = bf16[1,256,12,513]{3,1,2,0} xla::select(%325), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=1, start=3840, end=4096, stride=1
  %327 = bf16[1,256,12,513]{3,1,2,0} xla::select(%326), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=2, start=0, end=12, stride=1
  %328 = bf16[1,256,12,257]{3,1,2,0} xla::select(%327), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=3, start=256, end=513, stride=1
  %329 = bf16[1,256,12,257]{3,1,2,0} aten::masked_fill(%328, %233), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, value=-inf
  %330 = bf16[1,12,4096,513]{3,2,1,0} aten::view(%322), location=softmax@functional.py:1500, output_size=(1, 12, 4096, 513)
  %331 = bf16[1,4096,12,513]{3,1,2,0} aten::permute(%330), location=softmax@functional.py:1500, dims=(0, 2, 1, 3)
  %332 = bf16[1,4096,12,513]{3,1,2,0} xla::select(%331), location=softmax@functional.py:1500, dim=0, start=0, end=1, stride=1
  %333 = bf16[1,256,12,513]{3,1,2,0} xla::select(%332), location=softmax@functional.py:1500, dim=1, start=3840, end=4096, stride=1
  %334 = bf16[1,256,12,513]{3,1,2,0} xla::select(%333), location=softmax@functional.py:1500, dim=2, start=0, end=12, stride=1
  %335 = bf16[1,256,12,513]{3,1,2,0} xla::unselect(%334, %329), location=softmax@functional.py:1500, dim=3, start=256, end=513, stride=1
  %336 = bf16[1,256,12,513]{3,1,2,0} xla::unselect(%333, %335), location=softmax@functional.py:1500, dim=2, start=0, end=12, stride=1
  %337 = bf16[1,4096,12,513]{3,1,2,0} xla::unselect(%332, %336), location=softmax@functional.py:1500, dim=1, start=3840, end=4096, stride=1
  %338 = bf16[1,4096,12,513]{3,1,2,0} xla::unselect(%331, %337), location=softmax@functional.py:1500, dim=0, start=0, end=1, stride=1
  %339 = bf16[1,12,4096,513]{3,2,1,0} aten::permute(%338), location=softmax@functional.py:1500, dims=(0, 2, 1, 3)
  %340 = bf16[12,16,256,513]{3,2,1,0} aten::view(%339), location=softmax@functional.py:1500, output_size=(12, 16, 256, 513)
  %341 = bf16[1,12,4096,513]{3,2,1,0} aten::view(%340), location=softmax@functional.py:1500, output_size=(1, 12, 4096, 513)
  %342 = bf16[1,4096,12,513]{3,1,2,0} aten::permute(%341), location=softmax@functional.py:1500, dims=(0, 2, 1, 3), ROOT=38
  %343 = bf16[1,4096,12,513]{3,1,2,0} aten::softmax(%342), location=softmax@functional.py:1500, dim=3, dtype=6, ROOT=43
  %344 = bf16[] xla::device_data(), location=dropout@functional.py:973, device=TPU:0
  %345 = bf16[1,4096,12,513]{3,2,1,0} aten::expand(%344), location=dropout@functional.py:973, size=(1, 4096, 12, 513)
  %346 = s64[] xla::device_data(), location=dropout@functional.py:973, device=TPU:0
  %347 = s64[] prim::Constant(), location=dropout@functional.py:973, value=214013
  %348 = s64[] aten::mul(%347, %346), location=dropout@functional.py:973
  %349 = s64[] prim::Constant(), location=dropout@functional.py:973, value=2.53101e+06
  %350 = s64[] aten::add(%349, %348), location=dropout@functional.py:973
  %351 = bf16[] xla::device_data(), location=dropout@functional.py:973, device=TPU:0
  %352 = bf16[1,4096,12,513]{3,2,1,0} aten::expand(%351), location=dropout@functional.py:973, size=(1, 4096, 12, 513)
  %353 = bf16[1,4096,12,513]{3,2,1,0} aten::bernoulli(%352, %350), location=dropout@functional.py:973
  %354 = bf16[1,4096,12,513]{3,2,1,0} aten::div(%353, %345), location=dropout@functional.py:973, ROOT=44
  %355 = bf16[] prim::Constant(), location=linear@functional.py:1678, value=1
  %356 = bf16[768]{0} aten::expand(%355), location=linear@functional.py:1678, size=(768)
  %357 = bf16[768]{0} xla::device_data(), location=linear@functional.py:1678, device=TPU:0
  %358 = bf16[768]{0} aten::mul(%357, %356), location=linear@functional.py:1678
  %359 = bf16[4096,768]{1,0} aten::mm(%17, %15), location=linear@functional.py:1676
  %360 = bf16[4096,1,768]{2,1,0} aten::view(%359), location=linear@functional.py:1678, output_size=(4096, 1, 768)
  %361 = bf16[4096,1,768]{2,1,0} aten::add(%360, %358), location=linear@functional.py:1678
  %362 = bf16[4096,768]{1,0} aten::view(%361), location=_pad@functional.py:3547, output_size=(4096, 768)
  %363 = bf16[4096,1,768]{2,1,0} aten::view(%362), location=_pad@functional.py:3547, output_size=(4096, 1, 768)
  %364 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%363), location=_pad@functional.py:3547, output_size=(4096, 1, 12, 64)
  %365 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%364), location=_pad@functional.py:3547, dims=(1, 0, 2, 3)
  %366 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%365), location=_pad@functional.py:3547, dims=(0, 2, 1, 3)
  %367 = bf16[12,4096,64]{2,1,0} aten::view(%366), location=_pad@functional.py:3547, output_size=(12, 4096, 64)
  %368 = bf16[12,4608,64]{2,1,0} aten::constant_pad_nd(%367), location=_pad@functional.py:3547, pad=(0, 0, 256, 256, 0, 0), value=-1
  %369 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=0, end=768, stride=1, ROOT=45
  %370 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=256, end=1024, stride=1, ROOT=46
  %371 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=512, end=1280, stride=1, ROOT=47
  %372 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=768, end=1536, stride=1, ROOT=48
  %373 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1024, end=1792, stride=1, ROOT=49
  %374 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1280, end=2048, stride=1, ROOT=50
  %375 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1536, end=2304, stride=1, ROOT=51
  %376 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1792, end=2560, stride=1, ROOT=52
  %377 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2048, end=2816, stride=1, ROOT=53
  %378 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2304, end=3072, stride=1, ROOT=54
  %379 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2560, end=3328, stride=1, ROOT=55
  %380 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2816, end=3584, stride=1, ROOT=56
  %381 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3072, end=3840, stride=1, ROOT=57
  %382 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3328, end=4096, stride=1, ROOT=58
  %383 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3584, end=4352, stride=1, ROOT=59
  %384 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3840, end=4608, stride=1, ROOT=60
  %385 = bf16[1,4096,12,513]{3,2,1,0} aten::mul(%343, %354), location=dropout@functional.py:973
  %386 = bf16[1,12,4096,513]{3,1,2,0} aten::permute(%385), location=_pad@functional.py:3547, dims=(0, 2, 1, 3)
  %387 = bf16[12,16,256,513]{3,2,1,0} aten::view(%386), location=_pad@functional.py:3547, output_size=(12, 16, 256, 513)
  %388 = bf16[12,16,256,770]{3,2,1,0} aten::constant_pad_nd(%387), location=_pad@functional.py:3547, pad=(0, 257, 0, 0, 0, 0, 0, 0), value=0
  %389 = bf16[12,16,197120]{2,1,0} aten::view(%388), location=einsum@functional.py:327, output_size=(12, 16, 197120)
  %390 = bf16[12,16,197120]{2,1,0} xla::select(%389), location=einsum@functional.py:327, dim=0, start=0, end=12, stride=1
  %391 = bf16[12,16,197120]{2,1,0} xla::select(%390), location=einsum@functional.py:327, dim=1, start=0, end=16, stride=1
  %392 = bf16[12,16,196864]{2,1,0} xla::select(%391), location=einsum@functional.py:327, dim=2, start=0, end=196864, stride=1
  %393 = bf16[12,16,256,769]{3,2,1,0} aten::view(%392), location=einsum@functional.py:327, output_size=(12, 16, 256, 769)
  %394 = bf16[12,16,256,769]{3,2,1,0} xla::select(%393), location=einsum@functional.py:327, dim=0, start=0, end=12, stride=1
  %395 = bf16[12,16,256,769]{3,2,1,0} xla::select(%394), location=einsum@functional.py:327, dim=1, start=0, end=16, stride=1
  %396 = bf16[12,16,256,769]{3,2,1,0} xla::select(%395), location=einsum@functional.py:327, dim=2, start=0, end=256, stride=1
  %397 = bf16[12,16,256,768]{3,2,1,0} xla::select(%396), location=einsum@functional.py:327, dim=3, start=0, end=768, stride=1
  %398 = bf16[12,16,256,768]{3,2,1,0} aten::permute(%397), location=einsum@functional.py:327, dims=(0, 1, 2, 3)
  %399 = bf16[12,16,256,1,768]{4,3,2,1,0} aten::view(%398), location=einsum@functional.py:327, output_size=(12, 16, 256, 1, 768)
  %400 = bf16[12,16,256,768,1]{3,4,2,1,0} aten::permute(%399), location=einsum@functional.py:327, dims=(0, 1, 2, 4, 3)
  %401 = bf16[192,256,768]{2,1,0} aten::view(%400), location=einsum@functional.py:327, output_size=(192, 256, 768), ROOT=61
  %402 = bf16[12,16,768,64]{3,2,1,0} aten::stack(%369, %370, %371, %372, %373, %374, %375, %376, %377, %378, %379, %380, %381, %382, %383, %384), location=_unfold_loop@sliding_chunks.py:17, dim=1
  %403 = bf16[12,16,64,768]{2,3,1,0} aten::permute(%402), location=einsum@functional.py:327, dims=(0, 1, 3, 2)
  %404 = bf16[12,16,1,64,768]{4,3,2,1,0} aten::view(%403), location=einsum@functional.py:327, output_size=(12, 16, 1, 64, 768)
  %405 = bf16[12,16,768,64,1]{2,3,4,1,0} aten::permute(%404), location=einsum@functional.py:327, dims=(0, 1, 4, 3, 2)
  %406 = bf16[192,768,64]{2,1,0} aten::view(%405), location=einsum@functional.py:327, output_size=(192, 768, 64), ROOT=62
  %407 = bf16[] prim::Constant(), location=forward@longformer.py:227, value=1
  %408 = bf16[] prim::Constant(), location=forward@longformer.py:227, value=0
  %409 = bf16[] aten::mul(%408, %407), location=forward@longformer.py:227
  %410 = bf16[192,256,64]{2,1,0} aten::matmul(%401, %406), location=einsum@functional.py:327
  %411 = bf16[12,16,256,1,64]{4,3,2,1,0} aten::view(%410), location=forward@longformer.py:227, output_size=(12, 16, 256, 1, 64)
  %412 = bf16[12,16,256,64,1]{3,4,2,1,0} aten::permute(%411), location=forward@longformer.py:227, dims=(0, 1, 2, 4, 3)
  %413 = bf16[12,16,256,64]{3,2,1,0} aten::view(%412), location=forward@longformer.py:227, output_size=(12, 16, 256, 64)
  %414 = bf16[1,12,4096,64]{3,2,1,0} aten::view(%413), location=forward@longformer.py:227, output_size=(1, 12, 4096, 64)
  %415 = bf16[1,4096,12,64]{3,1,2,0} aten::permute(%414), location=forward@longformer.py:227, dims=(0, 2, 1, 3)
  %416 = bf16[1,4096,12,64]{3,2,1,0} aten::add(%415, %409), location=forward@longformer.py:227, ROOT=63
  %417 = bf16[4096,1,12,64]{3,2,0,1} aten::permute(%416), location=training_step@test_tpu.py:61, dims=(1, 0, 2, 3)
  %418 = bf16[4096,1,768]{2,1,0} aten::view(%417), location=training_step@test_tpu.py:61, output_size=(4096, 1, 768)
  %419 = bf16[1,4096,768]{2,0,1} aten::permute(%418), location=training_step@test_tpu.py:61, dims=(1, 0, 2), ROOT=64
  %420 = bf16[] aten::sum(%419), location=training_step@test_tpu.py:61, dimensions=(0, 1, 2), keep_reduced_dimensions=0, dtype=6, ROOT=65
}
JackCaoG commented 4 years ago

@ibeltagy I looked at the graph with Blake, and he mentioned that view ops can be expensive. On XLA, a view is usually lowered to a reshape plus a transpose. Transposes are typically free on TPU, but reshapes are expensive. Once you have tried the conv approach and have a new HLO graph, we can run the xprof tool to better understand the performance.
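The snippet below is a CPU-side analogy in plain PyTorch. It is not XLA's lowering itself. It shows why a permute followed by a reshape is not free. A transpose only rewrites strides. After it, though, the tensor is no longer contiguous, so merging the transposed dimensions needs a copy into a new buffer. Treat the parallel to TPU layouts as an assumption.

```python
import torch

x = torch.arange(24, dtype=torch.float32).view(2, 3, 4)

# A transpose/permute only rewrites strides; the storage is shared.
t = x.transpose(1, 2)                      # shape (2, 4, 3), non-contiguous
shares_storage_after_transpose = t.data_ptr() == x.data_ptr()

# Merging the transposed dims cannot be expressed with strides alone,
# so .view() raises and .reshape() silently copies the data.
try:
    t.view(2, 12)
    view_ok = True
except RuntimeError:
    view_ok = False
r = t.reshape(2, 12)
shares_storage_after_reshape = r.data_ptr() == x.data_ptr()

print(shares_storage_after_transpose, view_ok, shares_storage_after_reshape)
# True False False
```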

ibeltagy commented 4 years ago

Tried the convolution trick: it makes the model 20% faster and uses the same amount of memory, which is great. The new graph is below. I am looking for opportunities to save both compute and memory.

I suspect the least efficient part is lines 70-100 of the graph below, which correspond to these lines of code, but I could be totally wrong.


IR {
  %0 = f32[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
  %1 = f32[768,768]{0,1} aten::permute(%0), location=linear@functional.py:1676, dims=(1, 0), ROOT=0
  %2 = s64[1,4096]{1,0} xla::device_data(), location=forward@test_tpu.py:53, device=TPU:0
  %3 = s64[1,4096]{1,0} xla::select(%2), location=forward@test_tpu.py:53, dim=0, start=0, end=1, stride=1
  %4 = s64[1,4096]{1,0} xla::select(%3), location=forward@test_tpu.py:53, dim=1, start=0, end=4096, stride=1
  %5 = s64[1,4096,1]{2,1,0} aten::view(%4), location=forward@test_tpu.py:53, output_size=(1, 4096, 1)
  %6 = s64[1,4096,768]{2,1,0} aten::expand(%5), location=forward@test_tpu.py:53, size=(1, 4096, 768)
  %7 = f32[1,4096,768]{2,1,0} xla::cast(%6), location=forward@test_tpu.py:53, type=f32, dtype=Float, stype=Long
  %8 = f32[4096,1,768]{2,0,1} aten::permute(%7), location=linear@functional.py:1676, dims=(1, 0, 2)
  %9 = f32[4096,768]{1,0} aten::view(%8), location=linear@functional.py:1676, output_size=(4096, 768), ROOT=1
  %10 = f32[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
  %11 = f32[768,768]{0,1} aten::permute(%10), location=linear@functional.py:1676, dims=(1, 0), ROOT=2
  %12 = f32[4096,1,768]{2,0,1} aten::permute(%7), location=linear@functional.py:1676, dims=(1, 0, 2)
  %13 = f32[4096,768]{1,0} aten::view(%12), location=linear@functional.py:1676, output_size=(4096, 768), ROOT=3
  %14 = f32[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
  %15 = f32[768,768]{0,1} aten::permute(%14), location=linear@functional.py:1676, dims=(1, 0), ROOT=4
  %16 = f32[4096,1,768]{2,0,1} aten::permute(%7), location=linear@functional.py:1676, dims=(1, 0, 2)
  %17 = f32[4096,768]{1,0} aten::view(%16), location=linear@functional.py:1676, output_size=(4096, 768), ROOT=5
  %18 = f32[] xla::device_data(), location=forward@longformer.py:155, device=TPU:0
  %19 = f32[] prim::Constant(), location=linear@functional.py:1678, value=1
  %20 = f32[768]{0} aten::expand(%19), location=linear@functional.py:1678, size=(768)
  %21 = f32[768]{0} xla::device_data(), location=linear@functional.py:1678, device=TPU:0
  %22 = f32[768]{0} aten::mul(%21, %20), location=linear@functional.py:1678
  %23 = f32[4096,768]{1,0} aten::mm(%9, %1), location=linear@functional.py:1676
  %24 = f32[4096,1,768]{2,1,0} aten::view(%23), location=linear@functional.py:1678, output_size=(4096, 1, 768)
  %25 = f32[4096,1,768]{2,1,0} aten::add(%24, %22), location=linear@functional.py:1678
  %26 = f32[4096,768]{1,0} aten::view(%25), location=forward@longformer.py:155, output_size=(4096, 768)
  %27 = f32[4096,1,768]{2,1,0} aten::view(%26), location=forward@longformer.py:155, output_size=(4096, 1, 768)
  %28 = f32[4096,1,768]{2,1,0} aten::div(%27, %18), location=forward@longformer.py:155
  %29 = f32[4096,768]{1,0} aten::view(%28), location=_unfold_conv@sliding_chunks.py:27, output_size=(4096, 768)
  %30 = f32[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_conv@sliding_chunks.py:27, output_size=(4096, 1, 768)
  %31 = f32[4096,1,12,64]{3,2,1,0} aten::view(%30), location=_unfold_conv@sliding_chunks.py:27, output_size=(4096, 1, 12, 64)
  %32 = f32[1,4096,12,64]{3,2,0,1} aten::permute(%31), location=_unfold_conv@sliding_chunks.py:27, dims=(1, 0, 2, 3)
  %33 = f32[1,12,4096,64]{3,1,0,2} aten::permute(%32), location=_unfold_conv@sliding_chunks.py:27, dims=(0, 2, 1, 3)
  %34 = f32[12,4096,64]{2,1,0} aten::view(%33), location=_unfold_conv@sliding_chunks.py:27, output_size=(12, 4096, 64)
  %35 = f32[12,16,256,64]{3,2,1,0} aten::view(%34), location=_unfold_conv@sliding_chunks.py:27, output_size=(12, 16, 256, 64)
  %36 = f32[64,256,12,16]{0,1,3,2} aten::permute(%35), location=_unfold_conv@sliding_chunks.py:27, dims=(3, 2, 0, 1), ROOT=6
  %37 = f32[512,512]{1,0} aten::eye(), location=_unfold_conv@sliding_chunks.py:26
  %38 = f32[1,2,256,512]{3,2,1,0} aten::view(%37), location=_unfold_conv@sliding_chunks.py:27, output_size=(1, 2, 256, 512)
  %39 = f32[512,256,1,2]{0,1,3,2} aten::permute(%38), location=_unfold_conv@sliding_chunks.py:27, dims=(3, 2, 0, 1), ROOT=7
  %40 = f32[] prim::Constant(), location=linear@functional.py:1678, value=1
  %41 = f32[768]{0} aten::expand(%40), location=linear@functional.py:1678, size=(768)
  %42 = f32[768]{0} xla::device_data(), location=linear@functional.py:1678, device=TPU:0
  %43 = f32[768]{0} aten::mul(%42, %41), location=linear@functional.py:1678
  %44 = f32[4096,768]{1,0} aten::mm(%13, %11), location=linear@functional.py:1676
  %45 = f32[4096,1,768]{2,1,0} aten::view(%44), location=linear@functional.py:1678, output_size=(4096, 1, 768)
  %46 = f32[4096,1,768]{2,1,0} aten::add(%45, %43), location=linear@functional.py:1678
  %47 = f32[4096,768]{1,0} aten::view(%46), location=_unfold_conv@sliding_chunks.py:27, output_size=(4096, 768)
  %48 = f32[4096,1,768]{2,1,0} aten::view(%47), location=_unfold_conv@sliding_chunks.py:27, output_size=(4096, 1, 768)
  %49 = f32[4096,1,12,64]{3,2,1,0} aten::view(%48), location=_unfold_conv@sliding_chunks.py:27, output_size=(4096, 1, 12, 64)
  %50 = f32[1,4096,12,64]{3,2,0,1} aten::permute(%49), location=_unfold_conv@sliding_chunks.py:27, dims=(1, 0, 2, 3)
  %51 = f32[1,12,4096,64]{3,1,0,2} aten::permute(%50), location=_unfold_conv@sliding_chunks.py:27, dims=(0, 2, 1, 3)
  %52 = f32[12,4096,64]{2,1,0} aten::view(%51), location=_unfold_conv@sliding_chunks.py:27, output_size=(12, 4096, 64)
  %53 = f32[12,16,256,64]{3,2,1,0} aten::view(%52), location=_unfold_conv@sliding_chunks.py:27, output_size=(12, 16, 256, 64)
  %54 = f32[64,256,12,16]{0,1,3,2} aten::permute(%53), location=_unfold_conv@sliding_chunks.py:27, dims=(3, 2, 0, 1), ROOT=8
  %55 = f32[512,512]{1,0} aten::eye(), location=_unfold_conv@sliding_chunks.py:26
  %56 = f32[1,2,256,512]{3,2,1,0} aten::view(%55), location=_unfold_conv@sliding_chunks.py:27, output_size=(1, 2, 256, 512)
  %57 = f32[512,256,1,2]{0,1,3,2} aten::permute(%56), location=_unfold_conv@sliding_chunks.py:27, dims=(3, 2, 0, 1), ROOT=9
  %58 = f32[64,512,12,15]{3,2,1,0} aten::convolution_overrideable(%36, %39), location=_unfold_conv@sliding_chunks.py:27, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %59 = f32[12,15,512,64]{1,0,2,3} aten::permute(%58), location=einsum@functional.py:327, dims=(2, 3, 1, 0)
  %60 = f32[12,15,512,64]{1,0,2,3} aten::permute(%59), location=einsum@functional.py:327, dims=(0, 1, 2, 3)
  %61 = f32[12,15,512,1,64]{4,3,2,1,0} aten::view(%60), location=einsum@functional.py:327, output_size=(12, 15, 512, 1, 64)
  %62 = f32[12,15,512,64,1]{3,4,2,1,0} aten::permute(%61), location=einsum@functional.py:327, dims=(0, 1, 2, 4, 3)
  %63 = f32[180,512,64]{2,1,0} aten::view(%62), location=einsum@functional.py:327, output_size=(180, 512, 64), ROOT=10
  %64 = f32[64,512,12,15]{3,2,1,0} aten::convolution_overrideable(%54, %57), location=_unfold_conv@sliding_chunks.py:27, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %65 = f32[12,15,512,64]{1,0,2,3} aten::permute(%64), location=einsum@functional.py:327, dims=(2, 3, 1, 0)
  %66 = f32[12,15,512,64]{1,0,2,3} aten::permute(%65), location=einsum@functional.py:327, dims=(0, 1, 2, 3)
  %67 = f32[12,15,1,512,64]{4,3,2,1,0} aten::view(%66), location=einsum@functional.py:327, output_size=(12, 15, 1, 512, 64)
  %68 = f32[12,15,64,512,1]{2,3,4,1,0} aten::permute(%67), location=einsum@functional.py:327, dims=(0, 1, 4, 3, 2)
  %69 = f32[180,64,512]{2,1,0} aten::view(%68), location=einsum@functional.py:327, output_size=(180, 64, 512), ROOT=11
  %70 = f32[180,512,512]{2,1,0} aten::matmul(%63, %69), location=einsum@functional.py:327
  %71 = f32[12,15,512,1,512]{4,3,2,1,0} aten::view(%70), location=_pad@functional.py:3547, output_size=(12, 15, 512, 1, 512)
  %72 = f32[12,15,512,512,1]{3,4,2,1,0} aten::permute(%71), location=_pad@functional.py:3547, dims=(0, 1, 2, 4, 3)
  %73 = f32[12,15,512,512]{3,2,1,0} aten::view(%72), location=_pad@functional.py:3547, output_size=(12, 15, 512, 512)
  %74 = f32[12,15,513,512]{3,2,1,0} aten::constant_pad_nd(%73), location=_pad@functional.py:3547, pad=(0, 0, 0, 1, 0, 0, 0, 0), value=0
  %75 = f32[12,15,512,513]{3,2,1,0} aten::view(%74), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, output_size=(12, 15, 512, 513)
  %76 = f32[12,15,512,513]{3,2,1,0} xla::select(%75), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, dim=0, start=0, end=12, stride=1
  %77 = f32[12,1,512,513]{3,2,1,0} xla::generic_slice(%76), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, base_indices=(0, 0, 0, 0), sizes=(12, 1, 512, 513)
  %78 = f32[12,512,513]{2,1,0} aten::view(%77), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, output_size=(12, 512, 513)
  %79 = f32[12,255,513]{2,1,0} xla::select(%78), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, dim=1, start=0, end=255, stride=1
  %80 = f32[12,255,255]{2,1,0} xla::select(%79), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, dim=2, start=258, end=513, stride=1
  %81 = f32[12,255,255]{2,1,0} aten::view(%80), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, output_size=(12, 255, 255)
  %82 = f32[12,15,512,513]{3,2,1,0} aten::view(%74), location=sliding_chunks_matmul_qk@sliding_chunks.py:107, output_size=(12, 15, 512, 513)
  %83 = f32[12,15,512,513]{3,2,1,0} xla::select(%82), location=sliding_chunks_matmul_qk@sliding_chunks.py:107, dim=0, start=0, end=12, stride=1
  %84 = f32[12,15,512,513]{3,2,1,0} xla::select(%83), location=sliding_chunks_matmul_qk@sliding_chunks.py:107, dim=1, start=0, end=15, stride=1
  %85 = f32[12,15,256,513]{3,2,1,0} xla::select(%84), location=sliding_chunks_matmul_qk@sliding_chunks.py:107, dim=2, start=255, end=511, stride=1
  %86 = f32[12,15,256,256]{3,2,1,0} xla::select(%85), location=sliding_chunks_matmul_qk@sliding_chunks.py:107, dim=3, start=257, end=513, stride=1
  %87 = f32[12,15,256,256]{3,2,1,0} aten::view(%86), location=sliding_chunks_matmul_qk@sliding_chunks.py:107, output_size=(12, 15, 256, 256)
  %88 = f32[12,15,512,513]{3,2,1,0} aten::view(%74), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, output_size=(12, 15, 512, 513)
  %89 = f32[12,15,512,513]{3,2,1,0} xla::select(%88), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, dim=0, start=0, end=12, stride=1
  %90 = f32[12,1,512,513]{3,2,1,0} xla::generic_slice(%89), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, base_indices=(0, 14, 0, 0), sizes=(12, 1, 512, 513)
  %91 = f32[12,512,513]{2,1,0} aten::view(%90), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, output_size=(12, 512, 513)
  %92 = f32[12,256,513]{2,1,0} xla::select(%91), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, dim=1, start=256, end=512, stride=1
  %93 = f32[12,256,257]{2,1,0} xla::select(%92), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, dim=2, start=0, end=257, stride=1
  %94 = f32[12,256,257]{2,1,0} aten::view(%93), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, output_size=(12, 256, 257)
  %95 = f32[12,15,512,513]{3,2,1,0} aten::view(%74), location=sliding_chunks_matmul_qk@sliding_chunks.py:104, output_size=(12, 15, 512, 513)
  %96 = f32[12,15,512,513]{3,2,1,0} xla::select(%95), location=sliding_chunks_matmul_qk@sliding_chunks.py:104, dim=0, start=0, end=12, stride=1
  %97 = f32[12,15,512,513]{3,2,1,0} xla::select(%96), location=sliding_chunks_matmul_qk@sliding_chunks.py:104, dim=1, start=0, end=15, stride=1
  %98 = f32[12,15,256,513]{3,2,1,0} xla::select(%97), location=sliding_chunks_matmul_qk@sliding_chunks.py:104, dim=2, start=0, end=256, stride=1
  %99 = f32[12,15,256,257]{3,2,1,0} xla::select(%98), location=sliding_chunks_matmul_qk@sliding_chunks.py:104, dim=3, start=0, end=257, stride=1
  %100 = f32[12,15,256,257]{3,2,1,0} aten::view(%99), location=sliding_chunks_matmul_qk@sliding_chunks.py:104, output_size=(12, 15, 256, 257)
  %101 = f32[] prim::Constant(), location=sliding_chunks_matmul_qk@sliding_chunks.py:100, value=0
  %102 = f32[12,16,256,513]{3,2,1,0} aten::expand(%101), location=sliding_chunks_matmul_qk@sliding_chunks.py:100, size=(12, 16, 256, 513)
  %103 = f32[12,16,256,513]{3,2,1,0} xla::select(%102), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
  %104 = f32[12,15,256,513]{3,2,1,0} xla::select(%103), location=softmax@functional.py:1500, dim=1, start=0, end=15, stride=1
  %105 = f32[12,15,256,513]{3,2,1,0} xla::select(%104), location=softmax@functional.py:1500, dim=2, start=0, end=256, stride=1
  %106 = f32[12,15,256,513]{3,2,1,0} xla::unselect(%105, %100), location=softmax@functional.py:1500, dim=3, start=256, end=513, stride=1
  %107 = f32[12,15,256,513]{3,2,1,0} xla::unselect(%104, %106), location=softmax@functional.py:1500, dim=2, start=0, end=256, stride=1
  %108 = f32[12,16,256,513]{3,2,1,0} xla::unselect(%103, %107), location=softmax@functional.py:1500, dim=1, start=0, end=15, stride=1
  %109 = f32[12,16,256,513]{3,2,1,0} xla::unselect(%102, %108), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
  %110 = f32[12,16,256,513]{3,2,1,0} xla::select(%109), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
  %111 = f32[12,1,256,513]{3,2,1,0} xla::generic_slice(%110), location=softmax@functional.py:1500, base_indices=(0, 15, 0, 0), sizes=(12, 1, 256, 513)
  %112 = f32[12,256,513]{2,1,0} aten::view(%111), location=softmax@functional.py:1500, output_size=(12, 256, 513)
  %113 = f32[12,256,513]{2,1,0} xla::select(%112), location=softmax@functional.py:1500, dim=1, start=0, end=256, stride=1
  %114 = f32[12,256,513]{2,1,0} xla::unselect(%113, %94), location=softmax@functional.py:1500, dim=2, start=256, end=513, stride=1
  %115 = f32[12,256,513]{2,1,0} xla::unselect(%112, %114), location=softmax@functional.py:1500, dim=1, start=0, end=256, stride=1
  %116 = f32[12,1,256,513]{3,2,1,0} aten::view(%115), location=softmax@functional.py:1500, output_size=(12, 1, 256, 513)
  %117 = f32[12,16,256,513]{3,2,1,0} xla::update_slice(%110, %116), location=softmax@functional.py:1500, base_indices=(0, 15, 0, 0)
  %118 = f32[12,16,256,513]{3,2,1,0} xla::unselect(%109, %117), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
  %119 = f32[12,16,256,513]{3,2,1,0} xla::select(%118), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
  %120 = f32[12,15,256,513]{3,2,1,0} xla::select(%119), location=softmax@functional.py:1500, dim=1, start=1, end=16, stride=1
  %121 = f32[12,15,256,513]{3,2,1,0} xla::select(%120), location=softmax@functional.py:1500, dim=2, start=0, end=256, stride=1
  %122 = f32[12,15,256,513]{3,2,1,0} xla::unselect(%121, %87), location=softmax@functional.py:1500, dim=3, start=0, end=256, stride=1
  %123 = f32[12,15,256,513]{3,2,1,0} xla::unselect(%120, %122), location=softmax@functional.py:1500, dim=2, start=0, end=256, stride=1
  %124 = f32[12,16,256,513]{3,2,1,0} xla::unselect(%119, %123), location=softmax@functional.py:1500, dim=1, start=1, end=16, stride=1
  %125 = f32[12,16,256,513]{3,2,1,0} xla::unselect(%118, %124), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
  %126 = f32[12,16,256,513]{3,2,1,0} xla::select(%125), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
  %127 = f32[12,1,256,513]{3,2,1,0} xla::generic_slice(%126), location=softmax@functional.py:1500, base_indices=(0, 0, 0, 0), sizes=(12, 1, 256, 513)
  %128 = f32[12,256,513]{2,1,0} aten::view(%127), location=softmax@functional.py:1500, output_size=(12, 256, 513)
  %129 = f32[12,255,513]{2,1,0} xla::select(%128), location=softmax@functional.py:1500, dim=1, start=1, end=256, stride=1
  %130 = f32[12,255,513]{2,1,0} xla::unselect(%129, %81), location=softmax@functional.py:1500, dim=2, start=1, end=256, stride=1
  %131 = f32[12,256,513]{2,1,0} xla::unselect(%128, %130), location=softmax@functional.py:1500, dim=1, start=1, end=256, stride=1
  %132 = f32[12,1,256,513]{3,2,1,0} aten::view(%131), location=softmax@functional.py:1500, output_size=(12, 1, 256, 513)
  %133 = f32[12,16,256,513]{3,2,1,0} xla::update_slice(%126, %132), location=softmax@functional.py:1500, base_indices=(0, 0, 0, 0)
  %134 = f32[12,16,256,513]{3,2,1,0} xla::unselect(%125, %133), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
  %135 = f32[1,12,4096,513]{3,2,1,0} aten::view(%134), location=softmax@functional.py:1500, output_size=(1, 12, 4096, 513)
  %136 = f32[1,4096,12,513]{3,1,2,0} aten::permute(%135), location=softmax@functional.py:1500, dims=(0, 2, 1, 3), ROOT=12
  %137 = f32[1,4096,12,513]{3,1,2,0} aten::softmax(%136), location=softmax@functional.py:1500, dim=3, dtype=6, ROOT=13
  %138 = f32[] xla::device_data(), location=dropout@functional.py:973, device=TPU:0
  %139 = f32[1,4096,12,513]{3,2,1,0} aten::expand(%138), location=dropout@functional.py:973, size=(1, 4096, 12, 513)
  %140 = s64[] xla::device_data(), location=dropout@functional.py:973, device=TPU:0
  %141 = s64[] prim::Constant(), location=dropout@functional.py:973, value=214013
  %142 = s64[] aten::mul(%141, %140), location=dropout@functional.py:973
  %143 = s64[] prim::Constant(), location=dropout@functional.py:973, value=2.53101e+06
  %144 = s64[] aten::add(%143, %142), location=dropout@functional.py:973
  %145 = f32[] xla::device_data(), location=dropout@functional.py:973, device=TPU:0
  %146 = f32[1,4096,12,513]{3,2,1,0} aten::expand(%145), location=dropout@functional.py:973, size=(1, 4096, 12, 513)
  %147 = f32[1,4096,12,513]{3,2,1,0} aten::bernoulli(%146, %144), location=dropout@functional.py:973
  %148 = f32[1,4096,12,513]{3,2,1,0} aten::div(%147, %139), location=dropout@functional.py:973, ROOT=14
  %149 = f32[] prim::Constant(), location=linear@functional.py:1678, value=1
  %150 = f32[768]{0} aten::expand(%149), location=linear@functional.py:1678, size=(768)
  %151 = f32[768]{0} xla::device_data(), location=linear@functional.py:1678, device=TPU:0
  %152 = f32[768]{0} aten::mul(%151, %150), location=linear@functional.py:1678
  %153 = f32[4096,768]{1,0} aten::mm(%17, %15), location=linear@functional.py:1676
  %154 = f32[4096,1,768]{2,1,0} aten::view(%153), location=linear@functional.py:1678, output_size=(4096, 1, 768)
  %155 = f32[4096,1,768]{2,1,0} aten::add(%154, %152), location=linear@functional.py:1678
  %156 = f32[4096,768]{1,0} aten::view(%155), location=_pad@functional.py:3547, output_size=(4096, 768)
  %157 = f32[4096,1,768]{2,1,0} aten::view(%156), location=_pad@functional.py:3547, output_size=(4096, 1, 768)
  %158 = f32[4096,1,12,64]{3,2,1,0} aten::view(%157), location=_pad@functional.py:3547, output_size=(4096, 1, 12, 64)
  %159 = f32[1,4096,12,64]{3,2,0,1} aten::permute(%158), location=_pad@functional.py:3547, dims=(1, 0, 2, 3)
  %160 = f32[1,12,4096,64]{3,1,0,2} aten::permute(%159), location=_pad@functional.py:3547, dims=(0, 2, 1, 3)
  %161 = f32[12,4096,64]{2,1,0} aten::view(%160), location=_pad@functional.py:3547, output_size=(12, 4096, 64)
  %162 = f32[12,4608,64]{2,1,0} aten::constant_pad_nd(%161), location=_pad@functional.py:3547, pad=(0, 0, 256, 256, 0, 0), value=-1
  %163 = f32[12,18,256,64]{3,2,1,0} aten::view(%162), location=_unfold_conv@sliding_chunks.py:27, output_size=(12, 18, 256, 64)
  %164 = f32[64,256,12,18]{0,1,3,2} aten::permute(%163), location=_unfold_conv@sliding_chunks.py:27, dims=(3, 2, 0, 1), ROOT=15
  %165 = f32[768,768]{1,0} aten::eye(), location=_unfold_conv@sliding_chunks.py:26
  %166 = f32[1,3,256,768]{3,2,1,0} aten::view(%165), location=_unfold_conv@sliding_chunks.py:27, output_size=(1, 3, 256, 768)
  %167 = f32[768,256,1,3]{0,1,3,2} aten::permute(%166), location=_unfold_conv@sliding_chunks.py:27, dims=(3, 2, 0, 1), ROOT=16
  %168 = f32[1,4096,12,513]{3,2,1,0} aten::mul(%137, %148), location=dropout@functional.py:973
  %169 = f32[1,12,4096,513]{3,1,2,0} aten::permute(%168), location=_pad@functional.py:3547, dims=(0, 2, 1, 3)
  %170 = f32[12,16,256,513]{3,2,1,0} aten::view(%169), location=_pad@functional.py:3547, output_size=(12, 16, 256, 513)
  %171 = f32[12,16,256,770]{3,2,1,0} aten::constant_pad_nd(%170), location=_pad@functional.py:3547, pad=(0, 257, 0, 0, 0, 0, 0, 0), value=0
  %172 = f32[12,16,197120]{2,1,0} aten::view(%171), location=einsum@functional.py:327, output_size=(12, 16, 197120)
  %173 = f32[12,16,197120]{2,1,0} xla::select(%172), location=einsum@functional.py:327, dim=0, start=0, end=12, stride=1
  %174 = f32[12,16,197120]{2,1,0} xla::select(%173), location=einsum@functional.py:327, dim=1, start=0, end=16, stride=1
  %175 = f32[12,16,196864]{2,1,0} xla::select(%174), location=einsum@functional.py:327, dim=2, start=0, end=196864, stride=1
  %176 = f32[12,16,256,769]{3,2,1,0} aten::view(%175), location=einsum@functional.py:327, output_size=(12, 16, 256, 769)
  %177 = f32[12,16,256,769]{3,2,1,0} xla::select(%176), location=einsum@functional.py:327, dim=0, start=0, end=12, stride=1
  %178 = f32[12,16,256,769]{3,2,1,0} xla::select(%177), location=einsum@functional.py:327, dim=1, start=0, end=16, stride=1
  %179 = f32[12,16,256,769]{3,2,1,0} xla::select(%178), location=einsum@functional.py:327, dim=2, start=0, end=256, stride=1
  %180 = f32[12,16,256,768]{3,2,1,0} xla::select(%179), location=einsum@functional.py:327, dim=3, start=0, end=768, stride=1
  %181 = f32[12,16,256,768]{3,2,1,0} aten::permute(%180), location=einsum@functional.py:327, dims=(0, 1, 2, 3)
  %182 = f32[12,16,256,1,768]{4,3,2,1,0} aten::view(%181), location=einsum@functional.py:327, output_size=(12, 16, 256, 1, 768)
  %183 = f32[12,16,256,768,1]{3,4,2,1,0} aten::permute(%182), location=einsum@functional.py:327, dims=(0, 1, 2, 4, 3)
  %184 = f32[192,256,768]{2,1,0} aten::view(%183), location=einsum@functional.py:327, output_size=(192, 256, 768), ROOT=17
  %185 = f32[64,768,12,16]{3,2,1,0} aten::convolution_overrideable(%164, %167), location=_unfold_conv@sliding_chunks.py:27, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
  %186 = f32[12,16,768,64]{1,0,2,3} aten::permute(%185), location=einsum@functional.py:327, dims=(2, 3, 1, 0)
  %187 = f32[12,16,64,768]{1,0,3,2} aten::permute(%186), location=einsum@functional.py:327, dims=(0, 1, 3, 2)
  %188 = f32[12,16,1,64,768]{4,3,2,1,0} aten::view(%187), location=einsum@functional.py:327, output_size=(12, 16, 1, 64, 768)
  %189 = f32[12,16,768,64,1]{2,3,4,1,0} aten::permute(%188), location=einsum@functional.py:327, dims=(0, 1, 4, 3, 2)
  %190 = f32[192,768,64]{2,1,0} aten::view(%189), location=einsum@functional.py:327, output_size=(192, 768, 64), ROOT=18
  %191 = f32[] prim::Constant(), location=forward@longformer.py:230, value=1
  %192 = f32[] prim::Constant(), location=forward@longformer.py:230, value=0
  %193 = f32[] aten::mul(%192, %191), location=forward@longformer.py:230
  %194 = f32[192,256,64]{2,1,0} aten::matmul(%184, %190), location=einsum@functional.py:327
  %195 = f32[12,16,256,1,64]{4,3,2,1,0} aten::view(%194), location=forward@longformer.py:230, output_size=(12, 16, 256, 1, 64)
  %196 = f32[12,16,256,64,1]{3,4,2,1,0} aten::permute(%195), location=forward@longformer.py:230, dims=(0, 1, 2, 4, 3)
  %197 = f32[12,16,256,64]{3,2,1,0} aten::view(%196), location=forward@longformer.py:230, output_size=(12, 16, 256, 64)
  %198 = f32[1,12,4096,64]{3,2,1,0} aten::view(%197), location=forward@longformer.py:230, output_size=(1, 12, 4096, 64)
  %199 = f32[1,4096,12,64]{3,1,2,0} aten::permute(%198), location=forward@longformer.py:230, dims=(0, 2, 1, 3)
  %200 = f32[1,4096,12,64]{3,2,1,0} aten::add(%199, %193), location=forward@longformer.py:230, ROOT=19
  %201 = f32[4096,1,12,64]{3,2,0,1} aten::permute(%200), location=training_step@test_tpu.py:61, dims=(1, 0, 2, 3)
  %202 = f32[4096,1,768]{2,1,0} aten::view(%201), location=training_step@test_tpu.py:61, output_size=(4096, 1, 768)
  %203 = f32[1,4096,768]{2,0,1} aten::permute(%202), location=training_step@test_tpu.py:61, dims=(1, 0, 2), ROOT=20
  %204 = f32[] aten::sum(%203), location=training_step@test_tpu.py:61, dimensions=(0, 1, 2), keep_reduced_dimensions=0, dtype=6, ROOT=21
}
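The _unfold_conv nodes in the graph above show how the convolution trick works. In %37-%39, an eye(512) is reshaped into a (512, 256, 1, 2) kernel. Node %58 then convolves that kernel over the chunked tensor, so each output window is two adjacent 256-long chunks concatenated. Below is a minimal sketch of the idea. It assumes an input of shape (B, T, D), and the name unfold_via_conv is hypothetical. It is reconstructed from the IR, not copied from sliding_chunks.py.

```python
import torch
import torch.nn.functional as F


def unfold_via_conv(x: torch.Tensor, w: int) -> torch.Tensor:
    """Overlapping windows of length 2*w with stride w, computed as a conv.

    x: (B, T, D) with T divisible by w and T // w >= 2.
    Returns (B, T // w - 1, 2 * w, D); window k is chunks k and k + 1
    concatenated along the sequence axis.
    """
    B, T, D = x.shape
    n = T // w
    chunks = x.reshape(B, n, w, D)
    # Treat D as the conv batch, w as input channels, (B, n) as a 2-D image.
    inp = chunks.permute(3, 2, 0, 1).contiguous()                  # (D, w, B, n)
    # Identity kernel: output channel j copies position j % w of chunk k + j // w.
    eye = torch.eye(2 * w, dtype=x.dtype, device=x.device)
    weight = eye.view(1, 2, w, 2 * w).permute(3, 2, 0, 1).contiguous()  # (2w, w, 1, 2)
    out = F.conv2d(inp, weight)                                     # (D, 2w, B, n - 1)
    return out.permute(2, 3, 1, 0)                                  # (B, n - 1, 2w, D)


x = torch.randn(2, 16, 3)
windows = unfold_via_conv(x, w=4)   # shape (2, 3, 8, 3)
```

The design trades extra FLOPs for simplicity: the identity kernel multiplies mostly by zeros. In return, a single conv op replaces the loop of slices and the stack seen in the earlier graph.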
ibeltagy commented 4 years ago

What is xprof? Is it something I can run myself, or is it internal?

jysohn23 commented 4 years ago

Yes, it's the capture_tpu_profile tool: https://cloud.google.com/tpu/docs/cloud-tpu-tools#install_cloud_tpu_profiler

Unfortunately, some TF APIs are currently bundled into the profiler client. So you'll need to run the profiler from a VM that has TF installed and is on the same network as the TPU (on GCP: same project, same network). FYI, we're working on splitting out the TF bits.

davidel commented 4 years ago

@ibeltagy If you get the HLO graph, maybe @JackCaoG can run it internally with run_hlo_module and get the xprof data. Viewing that data externally would then require bringing up TensorBoard instances, unless things have changed.

ibeltagy commented 4 years ago

@davidel, by the HLO graph, I assume you mean the list of operations like

IR {
  %0 = f32[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
   ....
}

which I posted in the previous comment, right? The graph I posted covers only a small part of the code. Would it be more useful to get the whole thing?

davidel commented 4 years ago

That is the IR graph. The HLO graph is what gets fed to XLA after the IR is lowered to HLO. There are a few ways to get it, but to be safe it's best to run debug_run.py with the --hlo flag for 5-6 steps and post the tarball.

ibeltagy commented 4 years ago

Debugging information with the HLO graph: debug-hlo.tar.gz

Just curious: why am I getting 20 graphs, even though I didn't get any aten:: counters and, AFAICT, all the forward/backward passes are the same?

davidel commented 4 years ago

In the mode debug_run.py runs in, we dump every XLA execution graph, regardless of whether it is a new graph or one that was already compiled.
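Below is a toy model of that behavior. It assumes graphs are keyed by a hash of their text, which is our assumption and not torch_xla's actual implementation. Compilation happens only on a cache miss, while the debug dump records every execution. So many dumped graphs can coexist with only a few compilations. ToyGraphCache and the graph names are hypothetical.

```python
import hashlib


class ToyGraphCache:
    """Hypothetical illustration of compile-on-miss plus dump-every-run."""

    def __init__(self) -> None:
        self.executables = {}      # graph hash -> "compiled" artifact
        self.compile_count = 0
        self.dumped_graphs = []    # what a debug run would write out

    def execute(self, graph_text: str) -> str:
        key = hashlib.sha256(graph_text.encode()).hexdigest()
        if key not in self.executables:          # cache miss: compile once
            self.compile_count += 1
            self.executables[key] = f"executable_{self.compile_count}"
        self.dumped_graphs.append(graph_text)    # every execution is dumped
        return self.executables[key]


cache = ToyGraphCache()
for step in range(5):
    for graph in ("graph_a", "graph_b", "graph_c", "graph_d"):
        cache.execute(graph)
print(len(cache.dumped_graphs), cache.compile_count)  # 20 4
```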

@JackCaoG, try to get the xprof of the attached graph and take a look at it with the XLA team.

graph_0021.hlo.txt

JackCaoG commented 4 years ago

@ibeltagy I have collected the xprof data and opened an issue with the XLA team.

JackCaoG commented 4 years ago

@ibeltagy Could you share your GPU setup and its memory/runtime, and how it compares to your TPU setup after the convolution trick? FYI, we usually do chip-to-chip comparisons. Each TPU chip has 2 TPU cores, so 4 V100s vs. a v3-8 is a fair comparison.
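Here is the chip-to-chip arithmetic made concrete. A v3-8 exposes 8 TPU cores, and at 2 cores per chip that is 4 chips, hence the comparison against 4 V100s. The helpers below are hypothetical. They encode that arithmetic plus a per-chip throughput normalization. The normalization is one reasonable convention, not something prescribed in this thread.

```python
TPU_CORES_PER_CHIP = 2  # from the comment above: each TPU chip has 2 cores


def tpu_chip_count(num_cores: int, cores_per_chip: int = TPU_CORES_PER_CHIP) -> int:
    """Number of chips in a TPU slice, e.g. a v3-8 has 8 cores."""
    if num_cores % cores_per_chip:
        raise ValueError("core count must be a multiple of cores_per_chip")
    return num_cores // cores_per_chip


def per_chip_throughput(samples_per_sec: float, num_chips: int) -> float:
    """Normalize measured throughput so TPU and GPU runs can be compared."""
    return samples_per_sec / num_chips


# A v3-8 has 8 cores -> 4 chips, so the matching GPU setup is 4 V100s.
print(tpu_chip_count(8))  # 4
```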

blakehechtman commented 4 years ago

Looking at the profile, the most expensive parts are the xla::unselect operations, followed by the reshapes and transposes internal to aten::einsum.
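In the graph above, the einsum nodes (%59-%70 and %181-%194) wrap each batched matrix multiply in extra permute/view ops. For the two contraction patterns that appear there, an equivalent formulation calls torch.matmul directly on 4-D tensors. The einsum subscripts are inferred from the tensor shapes in the IR. Whether the direct matmul actually avoids the reshapes after XLA lowering is an assumption that would need checking in a profile.

```python
import torch

# Shapes loosely mirror the IR, shrunk for a quick check:
# (batch*heads, chunks, rows, head_dim).
q = torch.randn(2, 3, 8, 4)
k = torch.randn(2, 3, 8, 4)
# Chunked Q @ K^T: einsum vs. a direct batched matmul.
qk_einsum = torch.einsum("bcxd,bcyd->bcxy", q, k)
qk_matmul = torch.matmul(q, k.transpose(-1, -2))

# Chunked attention probabilities @ V.
p = torch.randn(2, 3, 4, 8)   # (batch*heads, chunks, rows, window)
v = torch.randn(2, 3, 8, 4)   # (batch*heads, chunks, window, head_dim)
pv_einsum = torch.einsum("bcxw,bcwd->bcxd", p, v)
pv_matmul = torch.matmul(p, v)

print(torch.allclose(qk_einsum, qk_matmul, atol=1e-5),
      torch.allclose(pv_einsum, pv_matmul, atol=1e-5))
```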

ibeltagy commented 4 years ago

Thanks, @blakehechtman. einsum is called multiple times in the code; do you know which call it is? Any suggestions on how to fix this? Is it expensive in terms of memory or compute time? If compute, any suggestions on how to improve memory as well?

davidel commented 4 years ago

The xla::unselect ops should come from in-place updates to views.

blakehechtman commented 4 years ago

Actually, when I looked deeper, I think the einsums were fine. Only the in-place updates of views seemed to be very expensive, accounting for around 80% of the step time.

ibeltagy commented 4 years ago

80%!! OK, I will try to reorganize the code to reduce the number of .view and .reshape operations. Would that be enough?

Any thoughts about memory optimizations as well?

davidel commented 4 years ago

Views are "OK" if you read from them. They generate xla::unselect once you write into them.
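In the graph above, the xla::unselect chains (%101-%134) read slices attributed to sliding_chunks.py lines 100-108. They come from allocating a zero tensor and then writing four slices of the chunked attention scores into it. The sketch below reconstructs those four writes from the slice bounds in the IR (w = 256 and 15 chunks in the original). It also shows one way to build the same tensor with only slicing, F.pad and torch.cat, so nothing is written through a view. Both function names are hypothetical. Whether the concatenation lowers to cheaper HLO on TPU is an assumption to verify with a profile.

```python
import torch
import torch.nn.functional as F


def diagonals_inplace(chunk_attn: torch.Tensor, w: int) -> torch.Tensor:
    """Slice-assignment version (the pattern that yields xla::unselect).

    chunk_attn: (B, H, C, 2w, 2w + 1) scores for C overlapping chunks, w >= 2.
    Returns (B, H, C + 1, w, 2w + 1).
    """
    B, H, C = chunk_attn.shape[:3]
    out = chunk_attn.new_zeros((B, H, C + 1, w, 2 * w + 1))
    out[:, :, :-1, :, w:] = chunk_attn[:, :, :, :w, : w + 1]
    out[:, :, -1, :, w:] = chunk_attn[:, :, -1, w:, : w + 1]
    out[:, :, 1:, :, :w] = chunk_attn[:, :, :, -(w + 1) : -1, w + 1 :]
    out[:, :, 0, 1:w, 1:w] = chunk_attn[:, :, 0, : w - 1, 1 - w :]
    return out


def diagonals_functional(chunk_attn: torch.Tensor, w: int) -> torch.Tensor:
    """Same result built only from slices, pad and cat (no writes into views)."""
    # Columns w..2w: the top rows of every chunk, plus the bottom rows of the last one.
    right = torch.cat(
        [chunk_attn[:, :, :, :w, : w + 1], chunk_attn[:, :, -1:, w:, : w + 1]], dim=2
    )                                                              # (B, H, C + 1, w, w + 1)
    # Columns 0..w-1 of output chunk 0: zero first row and first column.
    first_left = F.pad(chunk_attn[:, :, 0, : w - 1, 1 - w :], (1, 0, 1, 0))  # (B, H, w, w)
    # Columns 0..w-1 of output chunks 1..C: bottom rows of the previous chunk.
    rest_left = chunk_attn[:, :, :, -(w + 1) : -1, w + 1 :]       # (B, H, C, w, w)
    left = torch.cat([first_left.unsqueeze(2), rest_left], dim=2)  # (B, H, C + 1, w, w)
    return torch.cat([left, right], dim=-1)                       # (B, H, C + 1, w, 2w + 1)


scores = torch.randn(2, 3, 4, 8, 9)  # B=2, H=3, C=4, w=4
same = torch.equal(diagonals_inplace(scores, 4), diagonals_functional(scores, 4))
```

Both versions only copy values, so their outputs should match exactly rather than just approximately.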

JackCaoG commented 4 years ago

@ibeltagy FYI, if you search for unselect in the HLO, the metadata should tell you where each one is coming from.

  %pad.350 = pred[12,256,513]{2,1,0} pad(pred[12,256,257]{2,1,0} %broadcast.346, pred[] %constant.347), padding=0_0x0_0x256_0, metadata={op_type="xla::unselect" op_name="aten::masked_fill.1" source_file="mask_invalid_locations@diagonaled_mm_tvm.py" source_line=320}
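Following that suggestion, a small helper can tally where the unselects originate across a dumped HLO file such as graph_0021.hlo.txt. This is a sketch. It assumes the metadata fields appear in the order shown in the example line above: op_type first, then source_file and source_line. count_unselect_sources is a hypothetical name and is not part of torch_xla.

```python
import re
from collections import Counter

# Assumes the metadata layout shown in the example HLO line above.
_UNSELECT_META = re.compile(
    r'op_type="xla::unselect"[^}]*?source_file="([^"]+)"\s+source_line=(\d+)'
)


def count_unselect_sources(hlo_text: str) -> Counter:
    """Count xla::unselect HLO instructions per (source_file, source_line)."""
    counts = Counter()
    for match in _UNSELECT_META.finditer(hlo_text):
        counts[(match.group(1), int(match.group(2)))] += 1
    return counts


# Usage, assuming a dump produced by debug_run.py --hlo:
# with open("graph_0021.hlo.txt") as f:
#     for (src, line), n in count_unselect_sources(f.read()).most_common(10):
#         print(f"{n:5d}  {src}:{line}")
```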
stale[bot] commented 4 years ago

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your contributions.

JackCaoG commented 4 years ago

Hi @ibeltagy, an XLA-side optimization was merged recently that should improve the speed of unselect, and therefore of Longformer. The change is in today's nightly build. Please let me know whether it helps when you get some time to try it out.

ibeltagy commented 4 years ago

Great, thanks @JackCaoG for the update. Will give it a try and let you know.