ibeltagy opened this issue 4 years ago
Sorry, super busy this week. I will try to run _sliding_chunks_matmul_qk and check the generated HLO next week.
I looked at the debug output you posted, it seems like compilation stabilized after step 4.
[MetricsData; step=7]
Metric: CompileTime
TotalSamples: 4
Accumulator: 02m10s739ms497.948us
ValueRate: 891ms118.518us / second
Rate: 0.0274741 / second
Percentiles: 1%=387ms939.013us; 5%=387ms939.013us; 10%=387ms939.013us; 20%=387ms939.013us; 50%=53s200ms848.581us; 80%=53s279ms453.114us; 90%=53s279ms453.114us; 95%=53s279ms453.114us; 99%=53s279ms453.114us
Metric: DeviceLockWait
TotalSamples: 14
Accumulator: 004ms757.483us
ValueRate: 021.585us / second
Rate: 0.0804229 / second
Percentiles: 1%=001.957us; 5%=001.957us; 10%=002.068us; 20%=002.168us; 50%=003.370us; 80%=004.188us; 90%=004.416us; 95%=004ms717.704us; 99%=004ms717.704us
Metric: ExecuteTime
TotalSamples: 12
Accumulator: 06s548ms137.257us
ValueRate: 032ms024.149us / second
Rate: 0.0692647 / second
Percentiles: 1%=009ms817.567us; 5%=009ms817.567us; 10%=304ms702.798us; 20%=304ms040.119us; 50%=352ms252.134us; 80%=698ms730.039us; 90%=777ms375.913us; 95%=788ms461.432us; 99%=788ms461.432us
From the above metrics, it seems like the majority of the time was spent on compilation (2m vs 6s). Could you run it for a few more steps and post the debug data again? That will help us better understand the speed bottleneck. 😄
Here's the new debug info: debug.tar.gz. I ran it for more steps, but the last CompileTime Accumulator still shows 2 minutes, even though if you check the logs file you will see that it finishes compiling in the first few steps and is no longer compiling by the end. Could this be an issue with the step-wise metric report?
Yup, from the metrics it looks like it does not recompile after step 3 (TotalSamples stayed at 4 from step 3 to step 15). Accumulator just shows the accumulated time spent during this training run. We will try to come up with a document that explains the metrics better soon.
[MetricsData; step=15]
Metric: CompileTime
TotalSamples: 4
Accumulator: 02m16s671ms485.638us
ValueRate: 895ms494.158us / second
Rate: 0.0264018 / second
Percentiles: 1%=420ms911.135us; 5%=420ms911.135us; 10%=420ms911.135us; 20%=420ms911.135us; 50%=55s485ms865.246us; 80%=56s059ms455.713us; 90%=56s059ms455.713us; 95%=56s059ms455.713us; 99%=56s059ms455.713us
Metric: DeviceLockWait
TotalSamples: 30
Accumulator: 003ms086.302us
ValueRate: 012.542us / second
Rate: 0.121917 / second
Percentiles: 1%=001.792us; 5%=001.867us; 10%=001.955us; 20%=002.106us; 50%=003.197us; 80%=003.529us; 90%=003.974us; 95%=004.257us; 99%=003ms004.072us
Metric: ExecuteTime
TotalSamples: 28
Accumulator: 14s586ms120.483us
ValueRate: 055ms405.222us / second
Rate: 0.114186 / second
Percentiles: 1%=010ms718.758us; 5%=303ms316.149us; 10%=304ms597.979us; 20%=304ms652.587us; 50%=359ms343.030us; 80%=698ms693.051us; 90%=699ms788.650us; 95%=781ms773.674us; 99%=797ms489.440us
It looks like the per-step execute time tops out around 1s, and the median (50%=359ms343.030us) also looks reasonable. I don't think there is a bug in our compiler lowering, but there might be things we can improve. How many steps do you run for this model? Does this 1s execute time per step match your observation when running the full model?
Yes, it takes slightly longer than 1s/step, which is usable but slower than on GPUs, and I was wondering if you have insights on how to optimize it. AFAIK, small changes in things like slicing and reshaping can lead to big gains on TPUs.
For reference, the Hugging Face RoBERTa model on TPU is 1.18x faster than on GPUs, while Longformer is about 1.7x slower, making it 2x slower than expected. As mentioned earlier, this is probably just a matter of optimizing the functions _sliding_chunks_matmul_qk and _sliding_chunks_matmul_pv.
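For context, the core pattern in those functions is roughly the following (a simplified CPU sketch for illustration, not the actual Longformer implementation; chunked_qk is a hypothetical name): split q and k along the sequence dimension into overlapping chunks of size 2w with stride w, then run a dense matmul inside each chunk so every query only scores keys within its chunk.

```python
import torch

def chunked_qk(q, k, w):
    # Simplified sketch: overlapping chunks of size 2w with stride w,
    # then a dense matmul inside each chunk.
    # q, k: [batch*heads, seqlen, head_dim]; seqlen must be a multiple of w.
    q_c = q.unfold(1, 2 * w, w).transpose(2, 3)  # [bh, nchunks, 2w, d]
    k_c = k.unfold(1, 2 * w, w).transpose(2, 3)  # [bh, nchunks, 2w, d]
    return torch.einsum('bcxd,bcyd->bcxy', q_c, k_c)  # [bh, nchunks, 2w, 2w]

q = torch.randn(12, 4096, 64)
k = torch.randn(12, 4096, 64)
scores = chunked_qk(q, k, 256)
print(scores.shape)  # torch.Size([12, 15, 512, 512])
```

The 15 chunks of shape [12, 512, 64] here line up with the 15 xla::select outputs visible in the IR graph later in this thread.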
@ibeltagy I don't see any obvious optimization. @davidel any thoughts? I will try to run _sliding_chunks_matmul_qk and _sliding_chunks_matmul_pv separately, get the HLO, and check with the XLA team.
Certainly that unfold loop is much more expensive than on CPU/GPU, which can represent the tensor slices with stride/size games over the existing buffer. To understand where the bottleneck is, we'd need to export XLA_IR_DEBUG and XLA_HLO_DEBUG (it will run slowly) and get an xprof trace. But it is almost certainly the unfold loop.
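Those flags are just environment variables set before launching training (train.py below is a placeholder for whatever entrypoint you use):

```shell
# Annotate the lowered IR/HLO with Python source locations (runs slower)
export XLA_IR_DEBUG=1
export XLA_HLO_DEBUG=1
python train.py   # placeholder for the actual training entrypoint
```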
Hi @ibeltagy, I worked with Blake from the XLA team and have a new convolution-based approach to unfold which uses much less memory than my first approach. It should also be faster than the slice approach. If you could try this idea and let us know whether it helps, that would be great. If you encounter any long compile or slow execute times, we are happy to investigate.
import torch
import torch_xla
import torch_xla.core.xla_model as xm
# base: [12, 4096, 64]; unfold window len=768, step=256
base = torch.randn(12, 4096, 64, device=xm.xla_device())
res_native = base.cpu().unfold(1, 768, 256)  # [12, 14, 64, 768]
# conv approach on xla
reshape_base = base.reshape([12, 16, 256, 64])
# transpose the input and filter to better utilize TPU performance
transpose_base = reshape_base.permute(3, 2, 0, 1)  # [64, 256, 12, 16]
filter = torch.eye(768, device=xm.xla_device()).view([1, 3, 256, 768]).permute(3, 2, 0, 1)  # [768, 256, 1, 3]
res = torch.nn.functional.conv2d(transpose_base, filter)  # [64, 768, 12, 14]
res = res.permute(2, 3, 0, 1)  # [12, 14, 64, 768]
print(res.cpu().equal(res_native))
I chatted with Blake about the general rules of thumb for optimizing performance on the TPU. The idea is to do things in as big pieces as possible. The input was vectorized by [12, 64]; after the reshape and transpose it is vectorized by [12, 64, 256], which maps better to the TPU's (8, 128) registers.
If this approach works well, I will change my lowering of unfold to use it, so you won't need this manual work in the future.
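For a quick sanity check of the equivalence without a TPU, the same computation can be run entirely on CPU (a standalone sketch mirroring the snippet above; the filter variable is renamed filt here only to avoid shadowing the builtin):

```python
import torch

base = torch.randn(12, 4096, 64)
res_native = base.unfold(1, 768, 256)  # [12, 14, 64, 768]

# conv formulation: split the 4096 dim into 16 chunks of 256
reshape_base = base.reshape(12, 16, 256, 64)
transpose_base = reshape_base.permute(3, 2, 0, 1)  # [64, 256, 12, 16]
# identity filter that reassembles windows of 3 adjacent chunks (3 * 256 = 768)
filt = torch.eye(768).view(1, 3, 256, 768).permute(3, 2, 0, 1)  # [768, 256, 1, 3]
res = torch.nn.functional.conv2d(transpose_base, filt)  # [64, 768, 12, 14]
res = res.permute(2, 3, 0, 1)  # [12, 14, 64, 768]

print(torch.allclose(res, res_native))  # True
```

The products are exact (the filter entries are 0 or 1), so the conv output matches unfold bitwise on CPU.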
Thanks. Will give it a try and let you know.
Here's the IR graph with the iterative unfold. Does anything stand out other than the unfold? How do you know which operations are expensive? I will try the convolution trick and post the IR graph for that too.
IR {
%0 = bf16[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
%1 = bf16[768,768]{0,1} aten::permute(%0), location=linear@functional.py:1676, dims=(1, 0), ROOT=0
%2 = s64[1,4096]{1,0} xla::device_data(), location=forward@test_tpu.py:53, device=TPU:0
%3 = s64[1,4096]{1,0} xla::select(%2), location=forward@test_tpu.py:53, dim=0, start=0, end=1, stride=1
%4 = s64[1,4096]{1,0} xla::select(%3), location=forward@test_tpu.py:53, dim=1, start=0, end=4096, stride=1
%5 = s64[1,4096,1]{2,1,0} aten::view(%4), location=forward@test_tpu.py:53, output_size=(1, 4096, 1)
%6 = s64[1,4096,768]{2,1,0} aten::expand(%5), location=forward@test_tpu.py:53, size=(1, 4096, 768)
%7 = bf16[1,4096,768]{2,1,0} xla::cast(%6), location=forward@test_tpu.py:53, type=bf16, dtype=Float, stype=Long
%8 = bf16[4096,1,768]{2,0,1} aten::permute(%7), location=linear@functional.py:1676, dims=(1, 0, 2)
%9 = bf16[4096,768]{1,0} aten::view(%8), location=linear@functional.py:1676, output_size=(4096, 768), ROOT=1
%10 = bf16[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
%11 = bf16[768,768]{0,1} aten::permute(%10), location=linear@functional.py:1676, dims=(1, 0), ROOT=2
%12 = bf16[4096,1,768]{2,0,1} aten::permute(%7), location=linear@functional.py:1676, dims=(1, 0, 2)
%13 = bf16[4096,768]{1,0} aten::view(%12), location=linear@functional.py:1676, output_size=(4096, 768), ROOT=3
%14 = bf16[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
%15 = bf16[768,768]{0,1} aten::permute(%14), location=linear@functional.py:1676, dims=(1, 0), ROOT=4
%16 = bf16[4096,1,768]{2,0,1} aten::permute(%7), location=linear@functional.py:1676, dims=(1, 0, 2)
%17 = bf16[4096,768]{1,0} aten::view(%16), location=linear@functional.py:1676, output_size=(4096, 768), ROOT=5
%18 = bf16[] xla::device_data(), location=forward@longformer.py:152, device=TPU:0
%19 = bf16[] prim::Constant(), location=linear@functional.py:1678, value=1
%20 = bf16[768]{0} aten::expand(%19), location=linear@functional.py:1678, size=(768)
%21 = bf16[768]{0} xla::device_data(), location=linear@functional.py:1678, device=TPU:0
%22 = bf16[768]{0} aten::mul(%21, %20), location=linear@functional.py:1678
%23 = bf16[4096,768]{1,0} aten::mm(%9, %1), location=linear@functional.py:1676
%24 = bf16[4096,1,768]{2,1,0} aten::view(%23), location=linear@functional.py:1678, output_size=(4096, 1, 768)
%25 = bf16[4096,1,768]{2,1,0} aten::add(%24, %22), location=linear@functional.py:1678
%26 = bf16[4096,768]{1,0} aten::view(%25), location=forward@longformer.py:152, output_size=(4096, 768)
%27 = bf16[4096,1,768]{2,1,0} aten::view(%26), location=forward@longformer.py:152, output_size=(4096, 1, 768)
%28 = bf16[4096,1,768]{2,1,0} aten::div(%27, %18), location=forward@longformer.py:152
%29 = bf16[4096,768]{1,0} aten::view(%28), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 768)
%30 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%31 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%30), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%32 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%31), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%33 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%32), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%34 = bf16[12,4096,64]{2,1,0} aten::view(%33), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%35 = bf16[12,512,64]{2,1,0} xla::select(%34), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=0, end=512, stride=1, ROOT=6
%36 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%37 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%36), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%38 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%37), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%39 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%38), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%40 = bf16[12,4096,64]{2,1,0} aten::view(%39), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%41 = bf16[12,512,64]{2,1,0} xla::select(%40), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=256, end=768, stride=1, ROOT=7
%42 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%43 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%42), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%44 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%43), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%45 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%44), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%46 = bf16[12,4096,64]{2,1,0} aten::view(%45), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%47 = bf16[12,512,64]{2,1,0} xla::select(%46), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=512, end=1024, stride=1, ROOT=8
%48 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%49 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%48), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%50 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%49), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%51 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%50), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%52 = bf16[12,4096,64]{2,1,0} aten::view(%51), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%53 = bf16[12,512,64]{2,1,0} xla::select(%52), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=768, end=1280, stride=1, ROOT=9
%54 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%55 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%54), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%56 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%55), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%57 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%56), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%58 = bf16[12,4096,64]{2,1,0} aten::view(%57), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%59 = bf16[12,512,64]{2,1,0} xla::select(%58), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1024, end=1536, stride=1, ROOT=10
%60 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%61 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%60), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%62 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%61), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%63 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%62), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%64 = bf16[12,4096,64]{2,1,0} aten::view(%63), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%65 = bf16[12,512,64]{2,1,0} xla::select(%64), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1280, end=1792, stride=1, ROOT=11
%66 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%67 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%66), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%68 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%67), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%69 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%68), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%70 = bf16[12,4096,64]{2,1,0} aten::view(%69), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%71 = bf16[12,512,64]{2,1,0} xla::select(%70), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1536, end=2048, stride=1, ROOT=12
%72 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%73 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%72), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%74 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%73), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%75 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%74), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%76 = bf16[12,4096,64]{2,1,0} aten::view(%75), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%77 = bf16[12,512,64]{2,1,0} xla::select(%76), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1792, end=2304, stride=1, ROOT=13
%78 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%79 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%78), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%80 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%79), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%81 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%80), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%82 = bf16[12,4096,64]{2,1,0} aten::view(%81), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%83 = bf16[12,512,64]{2,1,0} xla::select(%82), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2048, end=2560, stride=1, ROOT=14
%84 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%85 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%84), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%86 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%85), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%87 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%86), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%88 = bf16[12,4096,64]{2,1,0} aten::view(%87), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%89 = bf16[12,512,64]{2,1,0} xla::select(%88), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2304, end=2816, stride=1, ROOT=15
%90 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%91 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%90), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%92 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%91), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%93 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%92), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%94 = bf16[12,4096,64]{2,1,0} aten::view(%93), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%95 = bf16[12,512,64]{2,1,0} xla::select(%94), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2560, end=3072, stride=1, ROOT=16
%96 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%97 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%96), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%98 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%97), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%99 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%98), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%100 = bf16[12,4096,64]{2,1,0} aten::view(%99), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%101 = bf16[12,512,64]{2,1,0} xla::select(%100), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2816, end=3328, stride=1, ROOT=17
%102 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%103 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%102), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%104 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%103), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%105 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%104), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%106 = bf16[12,4096,64]{2,1,0} aten::view(%105), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%107 = bf16[12,512,64]{2,1,0} xla::select(%106), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3072, end=3584, stride=1, ROOT=18
%108 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%109 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%108), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%110 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%109), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%111 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%110), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%112 = bf16[12,4096,64]{2,1,0} aten::view(%111), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%113 = bf16[12,512,64]{2,1,0} xla::select(%112), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3328, end=3840, stride=1, ROOT=19
%114 = bf16[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%115 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%114), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%116 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%115), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%117 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%116), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%118 = bf16[12,4096,64]{2,1,0} aten::view(%117), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%119 = bf16[12,512,64]{2,1,0} xla::select(%118), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3584, end=4096, stride=1, ROOT=20
%120 = bf16[] prim::Constant(), location=linear@functional.py:1678, value=1
%121 = bf16[768]{0} aten::expand(%120), location=linear@functional.py:1678, size=(768)
%122 = bf16[768]{0} xla::device_data(), location=linear@functional.py:1678, device=TPU:0
%123 = bf16[768]{0} aten::mul(%122, %121), location=linear@functional.py:1678
%124 = bf16[4096,768]{1,0} aten::mm(%13, %11), location=linear@functional.py:1676
%125 = bf16[4096,1,768]{2,1,0} aten::view(%124), location=linear@functional.py:1678, output_size=(4096, 1, 768)
%126 = bf16[4096,1,768]{2,1,0} aten::add(%125, %123), location=linear@functional.py:1678
%127 = bf16[4096,768]{1,0} aten::view(%126), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 768)
%128 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%129 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%128), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%130 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%129), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%131 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%130), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%132 = bf16[12,4096,64]{2,1,0} aten::view(%131), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%133 = bf16[12,512,64]{2,1,0} xla::select(%132), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=0, end=512, stride=1, ROOT=21
%134 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%135 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%134), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%136 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%135), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%137 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%136), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%138 = bf16[12,4096,64]{2,1,0} aten::view(%137), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%139 = bf16[12,512,64]{2,1,0} xla::select(%138), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=256, end=768, stride=1, ROOT=22
%140 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%141 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%140), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%142 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%141), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%143 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%142), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%144 = bf16[12,4096,64]{2,1,0} aten::view(%143), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%145 = bf16[12,512,64]{2,1,0} xla::select(%144), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=512, end=1024, stride=1, ROOT=23
%146 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%147 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%146), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%148 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%147), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%149 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%148), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%150 = bf16[12,4096,64]{2,1,0} aten::view(%149), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%151 = bf16[12,512,64]{2,1,0} xla::select(%150), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=768, end=1280, stride=1, ROOT=24
%152 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%153 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%152), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%154 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%153), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%155 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%154), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%156 = bf16[12,4096,64]{2,1,0} aten::view(%155), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%157 = bf16[12,512,64]{2,1,0} xla::select(%156), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1024, end=1536, stride=1, ROOT=25
%158 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%159 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%158), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%160 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%159), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%161 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%160), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%162 = bf16[12,4096,64]{2,1,0} aten::view(%161), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%163 = bf16[12,512,64]{2,1,0} xla::select(%162), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1280, end=1792, stride=1, ROOT=26
%164 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%165 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%164), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%166 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%165), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%167 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%166), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%168 = bf16[12,4096,64]{2,1,0} aten::view(%167), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%169 = bf16[12,512,64]{2,1,0} xla::select(%168), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1536, end=2048, stride=1, ROOT=27
%170 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%171 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%170), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%172 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%171), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%173 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%172), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%174 = bf16[12,4096,64]{2,1,0} aten::view(%173), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%175 = bf16[12,512,64]{2,1,0} xla::select(%174), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1792, end=2304, stride=1, ROOT=28
%176 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%177 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%176), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%178 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%177), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%179 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%178), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%180 = bf16[12,4096,64]{2,1,0} aten::view(%179), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%181 = bf16[12,512,64]{2,1,0} xla::select(%180), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2048, end=2560, stride=1, ROOT=29
%182 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%183 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%182), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%184 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%183), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%185 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%184), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%186 = bf16[12,4096,64]{2,1,0} aten::view(%185), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%187 = bf16[12,512,64]{2,1,0} xla::select(%186), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2304, end=2816, stride=1, ROOT=30
%188 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%189 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%188), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%190 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%189), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%191 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%190), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%192 = bf16[12,4096,64]{2,1,0} aten::view(%191), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%193 = bf16[12,512,64]{2,1,0} xla::select(%192), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2560, end=3072, stride=1, ROOT=31
%194 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%195 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%194), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%196 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%195), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%197 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%196), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%198 = bf16[12,4096,64]{2,1,0} aten::view(%197), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%199 = bf16[12,512,64]{2,1,0} xla::select(%198), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2816, end=3328, stride=1, ROOT=32
%200 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%201 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%200), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%202 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%201), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%203 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%202), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%204 = bf16[12,4096,64]{2,1,0} aten::view(%203), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%205 = bf16[12,512,64]{2,1,0} xla::select(%204), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3072, end=3584, stride=1, ROOT=33
%206 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%207 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%206), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%208 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%207), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%209 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%208), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%210 = bf16[12,4096,64]{2,1,0} aten::view(%209), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%211 = bf16[12,512,64]{2,1,0} xla::select(%210), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3328, end=3840, stride=1, ROOT=34
%212 = bf16[4096,1,768]{2,1,0} aten::view(%127), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 768)
%213 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%212), location=_unfold_loop@sliding_chunks.py:17, output_size=(4096, 1, 12, 64)
%214 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%213), location=_unfold_loop@sliding_chunks.py:17, dims=(1, 0, 2, 3)
%215 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%214), location=_unfold_loop@sliding_chunks.py:17, dims=(0, 2, 1, 3)
%216 = bf16[12,4096,64]{2,1,0} aten::view(%215), location=_unfold_loop@sliding_chunks.py:17, output_size=(12, 4096, 64)
%217 = bf16[12,512,64]{2,1,0} xla::select(%216), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3584, end=4096, stride=1, ROOT=35
%218 = bf16[12,15,512,64]{3,2,1,0} aten::stack(%35, %41, %47, %53, %59, %65, %71, %77, %83, %89, %95, %101, %107, %113, %119), location=_unfold_loop@sliding_chunks.py:17, dim=1
%219 = bf16[12,15,512,64]{3,2,1,0} aten::permute(%218), location=einsum@functional.py:327, dims=(0, 1, 2, 3)
%220 = bf16[12,15,512,1,64]{4,3,2,1,0} aten::view(%219), location=einsum@functional.py:327, output_size=(12, 15, 512, 1, 64)
%221 = bf16[12,15,512,64,1]{3,4,2,1,0} aten::permute(%220), location=einsum@functional.py:327, dims=(0, 1, 2, 4, 3)
%222 = bf16[180,512,64]{2,1,0} aten::view(%221), location=einsum@functional.py:327, output_size=(180, 512, 64), ROOT=36
%223 = bf16[12,15,512,64]{3,2,1,0} aten::stack(%133, %139, %145, %151, %157, %163, %169, %175, %181, %187, %193, %199, %205, %211, %217), location=_unfold_loop@sliding_chunks.py:17, dim=1
%224 = bf16[12,15,512,64]{3,2,1,0} aten::permute(%223), location=einsum@functional.py:327, dims=(0, 1, 2, 3)
%225 = bf16[12,15,1,512,64]{4,3,2,1,0} aten::view(%224), location=einsum@functional.py:327, output_size=(12, 15, 1, 512, 64)
%226 = bf16[12,15,64,512,1]{2,3,4,1,0} aten::permute(%225), location=einsum@functional.py:327, dims=(0, 1, 4, 3, 2)
%227 = bf16[180,64,512]{2,1,0} aten::view(%226), location=einsum@functional.py:327, output_size=(180, 64, 512), ROOT=37
%228 = pred[1,256,1,257]{3,1,2,0} xla::device_data(), location=_get_invalid_locations_mask@diagonaled_mm_tvm.py:312, device=TPU:0
%229 = pred[] prim::Constant(), location=_get_invalid_locations_mask@diagonaled_mm_tvm.py:312, value=0
%230 = pred[1,256,1,257]{3,2,1,0} aten::expand(%229), location=_get_invalid_locations_mask@diagonaled_mm_tvm.py:312, size=(1, 256, 1, 257)
%231 = pred[1,256,1,257]{3,2,1,0} xla::as_strided_view_update(%230, %228), location=mask_invalid_locations@diagonaled_mm_tvm.py:323, size=(1, 256, 1, 257), stride=(65792, 257, 257, 1), storage_offset=0
%232 = pred[1,256,1,257]{3,2,1,0} aten::as_strided(%231), location=mask_invalid_locations@diagonaled_mm_tvm.py:323, size=(1, 256, 1, 257), stride=(65792, 257, 257, 1), storage_offset=0, ROOT=39
%233 = pred[1,256,12,257]{3,2,1,0} aten::expand(%232), location=mask_invalid_locations@diagonaled_mm_tvm.py:323, size=(1, 256, 12, 257), ROOT=42
%234 = pred[1,256,1,257]{3,1,2,0} xla::device_data(), location=_get_invalid_locations_mask@diagonaled_mm_tvm.py:313, device=TPU:0
%235 = pred[] prim::Constant(), location=_get_invalid_locations_mask@diagonaled_mm_tvm.py:313, value=0
%236 = pred[1,256,1,257]{3,2,1,0} aten::expand(%235), location=_get_invalid_locations_mask@diagonaled_mm_tvm.py:313, size=(1, 256, 1, 257)
%237 = pred[1,256,1,257]{3,2,1,0} xla::as_strided_view_update(%236, %234), location=mask_invalid_locations@diagonaled_mm_tvm.py:319, size=(1, 256, 1, 257), stride=(65792, 257, 257, 1), storage_offset=0
%238 = pred[1,256,1,257]{3,2,1,0} aten::as_strided(%237), location=mask_invalid_locations@diagonaled_mm_tvm.py:319, size=(1, 256, 1, 257), stride=(65792, 257, 257, 1), storage_offset=0, ROOT=40
%239 = pred[1,256,12,257]{3,2,1,0} aten::expand(%238), location=mask_invalid_locations@diagonaled_mm_tvm.py:319, size=(1, 256, 12, 257), ROOT=41
%240 = bf16[180,512,512]{2,1,0} aten::matmul(%222, %227), location=einsum@functional.py:327
%241 = bf16[12,15,512,1,512]{4,3,2,1,0} aten::view(%240), location=_pad@functional.py:3547, output_size=(12, 15, 512, 1, 512)
%242 = bf16[12,15,512,512,1]{3,4,2,1,0} aten::permute(%241), location=_pad@functional.py:3547, dims=(0, 1, 2, 4, 3)
%243 = bf16[12,15,512,512]{3,2,1,0} aten::view(%242), location=_pad@functional.py:3547, output_size=(12, 15, 512, 512)
%244 = bf16[12,15,513,512]{3,2,1,0} aten::constant_pad_nd(%243), location=_pad@functional.py:3547, pad=(0, 0, 0, 1, 0, 0, 0, 0), value=0
%245 = bf16[12,15,512,513]{3,2,1,0} aten::view(%244), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, output_size=(12, 15, 512, 513)
%246 = bf16[12,15,512,513]{3,2,1,0} xla::select(%245), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, dim=0, start=0, end=12, stride=1
%247 = bf16[12,1,512,513]{3,2,1,0} xla::generic_slice(%246), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, base_indices=(0, 0, 0, 0), sizes=(12, 1, 512, 513)
%248 = bf16[12,512,513]{2,1,0} aten::view(%247), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, output_size=(12, 512, 513)
%249 = bf16[12,255,513]{2,1,0} xla::select(%248), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, dim=1, start=0, end=255, stride=1
%250 = bf16[12,255,255]{2,1,0} xla::select(%249), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, dim=2, start=258, end=513, stride=1
%251 = bf16[12,255,255]{2,1,0} aten::view(%250), location=sliding_chunks_matmul_qk@sliding_chunks.py:97, output_size=(12, 255, 255)
%252 = bf16[12,15,512,513]{3,2,1,0} aten::view(%244), location=sliding_chunks_matmul_qk@sliding_chunks.py:96, output_size=(12, 15, 512, 513)
%253 = bf16[12,15,512,513]{3,2,1,0} xla::select(%252), location=sliding_chunks_matmul_qk@sliding_chunks.py:96, dim=0, start=0, end=12, stride=1
%254 = bf16[12,15,512,513]{3,2,1,0} xla::select(%253), location=sliding_chunks_matmul_qk@sliding_chunks.py:96, dim=1, start=0, end=15, stride=1
%255 = bf16[12,15,256,513]{3,2,1,0} xla::select(%254), location=sliding_chunks_matmul_qk@sliding_chunks.py:96, dim=2, start=255, end=511, stride=1
%256 = bf16[12,15,256,256]{3,2,1,0} xla::select(%255), location=sliding_chunks_matmul_qk@sliding_chunks.py:96, dim=3, start=257, end=513, stride=1
%257 = bf16[12,15,256,256]{3,2,1,0} aten::view(%256), location=sliding_chunks_matmul_qk@sliding_chunks.py:96, output_size=(12, 15, 256, 256)
%258 = bf16[12,15,512,513]{3,2,1,0} aten::view(%244), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, output_size=(12, 15, 512, 513)
%259 = bf16[12,15,512,513]{3,2,1,0} xla::select(%258), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, dim=0, start=0, end=12, stride=1
%260 = bf16[12,1,512,513]{3,2,1,0} xla::generic_slice(%259), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, base_indices=(0, 14, 0, 0), sizes=(12, 1, 512, 513)
%261 = bf16[12,512,513]{2,1,0} aten::view(%260), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, output_size=(12, 512, 513)
%262 = bf16[12,256,513]{2,1,0} xla::select(%261), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, dim=1, start=256, end=512, stride=1
%263 = bf16[12,256,257]{2,1,0} xla::select(%262), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, dim=2, start=0, end=257, stride=1
%264 = bf16[12,256,257]{2,1,0} aten::view(%263), location=sliding_chunks_matmul_qk@sliding_chunks.py:94, output_size=(12, 256, 257)
%265 = bf16[12,15,512,513]{3,2,1,0} aten::view(%244), location=sliding_chunks_matmul_qk@sliding_chunks.py:93, output_size=(12, 15, 512, 513)
%266 = bf16[12,15,512,513]{3,2,1,0} xla::select(%265), location=sliding_chunks_matmul_qk@sliding_chunks.py:93, dim=0, start=0, end=12, stride=1
%267 = bf16[12,15,512,513]{3,2,1,0} xla::select(%266), location=sliding_chunks_matmul_qk@sliding_chunks.py:93, dim=1, start=0, end=15, stride=1
%268 = bf16[12,15,256,513]{3,2,1,0} xla::select(%267), location=sliding_chunks_matmul_qk@sliding_chunks.py:93, dim=2, start=0, end=256, stride=1
%269 = bf16[12,15,256,257]{3,2,1,0} xla::select(%268), location=sliding_chunks_matmul_qk@sliding_chunks.py:93, dim=3, start=0, end=257, stride=1
%270 = bf16[12,15,256,257]{3,2,1,0} aten::view(%269), location=sliding_chunks_matmul_qk@sliding_chunks.py:93, output_size=(12, 15, 256, 257)
%271 = bf16[] prim::Constant(), location=sliding_chunks_matmul_qk@sliding_chunks.py:89, value=0
%272 = bf16[12,16,256,513]{3,2,1,0} aten::expand(%271), location=sliding_chunks_matmul_qk@sliding_chunks.py:89, size=(12, 16, 256, 513)
%273 = bf16[12,16,256,513]{3,2,1,0} xla::select(%272), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
%274 = bf16[12,15,256,513]{3,2,1,0} xla::select(%273), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=0, end=15, stride=1
%275 = bf16[12,15,256,513]{3,2,1,0} xla::select(%274), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=0, end=256, stride=1
%276 = bf16[12,15,256,513]{3,2,1,0} xla::unselect(%275, %270), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=3, start=256, end=513, stride=1
%277 = bf16[12,15,256,513]{3,2,1,0} xla::unselect(%274, %276), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=0, end=256, stride=1
%278 = bf16[12,16,256,513]{3,2,1,0} xla::unselect(%273, %277), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=0, end=15, stride=1
%279 = bf16[12,16,256,513]{3,2,1,0} xla::unselect(%272, %278), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
%280 = bf16[12,16,256,513]{3,2,1,0} xla::select(%279), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
%281 = bf16[12,1,256,513]{3,2,1,0} xla::generic_slice(%280), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, base_indices=(0, 15, 0, 0), sizes=(12, 1, 256, 513)
%282 = bf16[12,256,513]{2,1,0} aten::view(%281), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, output_size=(12, 256, 513)
%283 = bf16[12,256,513]{2,1,0} xla::select(%282), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=0, end=256, stride=1
%284 = bf16[12,256,513]{2,1,0} xla::unselect(%283, %264), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=256, end=513, stride=1
%285 = bf16[12,256,513]{2,1,0} xla::unselect(%282, %284), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=0, end=256, stride=1
%286 = bf16[12,1,256,513]{3,2,1,0} aten::view(%285), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, output_size=(12, 1, 256, 513)
%287 = bf16[12,16,256,513]{3,2,1,0} xla::update_slice(%280, %286), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, base_indices=(0, 15, 0, 0)
%288 = bf16[12,16,256,513]{3,2,1,0} xla::unselect(%279, %287), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
%289 = bf16[12,16,256,513]{3,2,1,0} xla::select(%288), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
%290 = bf16[12,15,256,513]{3,2,1,0} xla::select(%289), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=1, end=16, stride=1
%291 = bf16[12,15,256,513]{3,2,1,0} xla::select(%290), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=0, end=256, stride=1
%292 = bf16[12,15,256,513]{3,2,1,0} xla::unselect(%291, %257), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=3, start=0, end=256, stride=1
%293 = bf16[12,15,256,513]{3,2,1,0} xla::unselect(%290, %292), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=0, end=256, stride=1
%294 = bf16[12,16,256,513]{3,2,1,0} xla::unselect(%289, %293), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=1, end=16, stride=1
%295 = bf16[12,16,256,513]{3,2,1,0} xla::unselect(%288, %294), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
%296 = bf16[12,16,256,513]{3,2,1,0} xla::select(%295), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
%297 = bf16[12,1,256,513]{3,2,1,0} xla::generic_slice(%296), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, base_indices=(0, 0, 0, 0), sizes=(12, 1, 256, 513)
%298 = bf16[12,256,513]{2,1,0} aten::view(%297), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, output_size=(12, 256, 513)
%299 = bf16[12,255,513]{2,1,0} xla::select(%298), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=1, end=256, stride=1
%300 = bf16[12,255,513]{2,1,0} xla::unselect(%299, %251), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=1, end=256, stride=1
%301 = bf16[12,256,513]{2,1,0} xla::unselect(%298, %300), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=1, end=256, stride=1
%302 = bf16[12,1,256,513]{3,2,1,0} aten::view(%301), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, output_size=(12, 1, 256, 513)
%303 = bf16[12,16,256,513]{3,2,1,0} xla::update_slice(%296, %302), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, base_indices=(0, 0, 0, 0)
%304 = bf16[12,16,256,513]{3,2,1,0} xla::unselect(%295, %303), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=12, stride=1
%305 = bf16[1,12,4096,513]{3,2,1,0} aten::view(%304), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, output_size=(1, 12, 4096, 513)
%306 = bf16[1,4096,12,513]{3,1,2,0} aten::permute(%305), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dims=(0, 2, 1, 3)
%307 = bf16[1,4096,12,513]{3,1,2,0} xla::select(%306), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=0, start=0, end=1, stride=1
%308 = bf16[1,256,12,513]{3,1,2,0} xla::select(%307), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=1, start=0, end=256, stride=1
%309 = bf16[1,256,12,513]{3,1,2,0} xla::select(%308), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=2, start=0, end=12, stride=1
%310 = bf16[1,256,12,257]{3,1,2,0} xla::select(%309), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, dim=3, start=0, end=257, stride=1
%311 = bf16[1,256,12,257]{3,1,2,0} aten::masked_fill(%310, %239), scope=aten::masked_fill.1, location=mask_invalid_locations@diagonaled_mm_tvm.py:320, value=-inf
%312 = bf16[1,12,4096,513]{3,2,1,0} aten::view(%304), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, output_size=(1, 12, 4096, 513)
%313 = bf16[1,4096,12,513]{3,1,2,0} aten::permute(%312), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dims=(0, 2, 1, 3)
%314 = bf16[1,4096,12,513]{3,1,2,0} xla::select(%313), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=0, start=0, end=1, stride=1
%315 = bf16[1,256,12,513]{3,1,2,0} xla::select(%314), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=1, start=0, end=256, stride=1
%316 = bf16[1,256,12,513]{3,1,2,0} xla::select(%315), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=2, start=0, end=12, stride=1
%317 = bf16[1,256,12,513]{3,1,2,0} xla::unselect(%316, %311), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=3, start=0, end=257, stride=1
%318 = bf16[1,256,12,513]{3,1,2,0} xla::unselect(%315, %317), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=2, start=0, end=12, stride=1
%319 = bf16[1,4096,12,513]{3,1,2,0} xla::unselect(%314, %318), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=1, start=0, end=256, stride=1
%320 = bf16[1,4096,12,513]{3,1,2,0} xla::unselect(%313, %319), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=0, start=0, end=1, stride=1
%321 = bf16[1,12,4096,513]{3,2,1,0} aten::permute(%320), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dims=(0, 2, 1, 3)
%322 = bf16[12,16,256,513]{3,2,1,0} aten::view(%321), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, output_size=(12, 16, 256, 513)
%323 = bf16[1,12,4096,513]{3,2,1,0} aten::view(%322), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, output_size=(1, 12, 4096, 513)
%324 = bf16[1,4096,12,513]{3,1,2,0} aten::permute(%323), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dims=(0, 2, 1, 3)
%325 = bf16[1,4096,12,513]{3,1,2,0} xla::select(%324), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=0, start=0, end=1, stride=1
%326 = bf16[1,256,12,513]{3,1,2,0} xla::select(%325), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=1, start=3840, end=4096, stride=1
%327 = bf16[1,256,12,513]{3,1,2,0} xla::select(%326), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=2, start=0, end=12, stride=1
%328 = bf16[1,256,12,257]{3,1,2,0} xla::select(%327), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, dim=3, start=256, end=513, stride=1
%329 = bf16[1,256,12,257]{3,1,2,0} aten::masked_fill(%328, %233), scope=aten::masked_fill.2, location=mask_invalid_locations@diagonaled_mm_tvm.py:324, value=-inf
%330 = bf16[1,12,4096,513]{3,2,1,0} aten::view(%322), location=softmax@functional.py:1500, output_size=(1, 12, 4096, 513)
%331 = bf16[1,4096,12,513]{3,1,2,0} aten::permute(%330), location=softmax@functional.py:1500, dims=(0, 2, 1, 3)
%332 = bf16[1,4096,12,513]{3,1,2,0} xla::select(%331), location=softmax@functional.py:1500, dim=0, start=0, end=1, stride=1
%333 = bf16[1,256,12,513]{3,1,2,0} xla::select(%332), location=softmax@functional.py:1500, dim=1, start=3840, end=4096, stride=1
%334 = bf16[1,256,12,513]{3,1,2,0} xla::select(%333), location=softmax@functional.py:1500, dim=2, start=0, end=12, stride=1
%335 = bf16[1,256,12,513]{3,1,2,0} xla::unselect(%334, %329), location=softmax@functional.py:1500, dim=3, start=256, end=513, stride=1
%336 = bf16[1,256,12,513]{3,1,2,0} xla::unselect(%333, %335), location=softmax@functional.py:1500, dim=2, start=0, end=12, stride=1
%337 = bf16[1,4096,12,513]{3,1,2,0} xla::unselect(%332, %336), location=softmax@functional.py:1500, dim=1, start=3840, end=4096, stride=1
%338 = bf16[1,4096,12,513]{3,1,2,0} xla::unselect(%331, %337), location=softmax@functional.py:1500, dim=0, start=0, end=1, stride=1
%339 = bf16[1,12,4096,513]{3,2,1,0} aten::permute(%338), location=softmax@functional.py:1500, dims=(0, 2, 1, 3)
%340 = bf16[12,16,256,513]{3,2,1,0} aten::view(%339), location=softmax@functional.py:1500, output_size=(12, 16, 256, 513)
%341 = bf16[1,12,4096,513]{3,2,1,0} aten::view(%340), location=softmax@functional.py:1500, output_size=(1, 12, 4096, 513)
%342 = bf16[1,4096,12,513]{3,1,2,0} aten::permute(%341), location=softmax@functional.py:1500, dims=(0, 2, 1, 3), ROOT=38
%343 = bf16[1,4096,12,513]{3,1,2,0} aten::softmax(%342), location=softmax@functional.py:1500, dim=3, dtype=6, ROOT=43
%344 = bf16[] xla::device_data(), location=dropout@functional.py:973, device=TPU:0
%345 = bf16[1,4096,12,513]{3,2,1,0} aten::expand(%344), location=dropout@functional.py:973, size=(1, 4096, 12, 513)
%346 = s64[] xla::device_data(), location=dropout@functional.py:973, device=TPU:0
%347 = s64[] prim::Constant(), location=dropout@functional.py:973, value=214013
%348 = s64[] aten::mul(%347, %346), location=dropout@functional.py:973
%349 = s64[] prim::Constant(), location=dropout@functional.py:973, value=2.53101e+06
%350 = s64[] aten::add(%349, %348), location=dropout@functional.py:973
%351 = bf16[] xla::device_data(), location=dropout@functional.py:973, device=TPU:0
%352 = bf16[1,4096,12,513]{3,2,1,0} aten::expand(%351), location=dropout@functional.py:973, size=(1, 4096, 12, 513)
%353 = bf16[1,4096,12,513]{3,2,1,0} aten::bernoulli(%352, %350), location=dropout@functional.py:973
%354 = bf16[1,4096,12,513]{3,2,1,0} aten::div(%353, %345), location=dropout@functional.py:973, ROOT=44
%355 = bf16[] prim::Constant(), location=linear@functional.py:1678, value=1
%356 = bf16[768]{0} aten::expand(%355), location=linear@functional.py:1678, size=(768)
%357 = bf16[768]{0} xla::device_data(), location=linear@functional.py:1678, device=TPU:0
%358 = bf16[768]{0} aten::mul(%357, %356), location=linear@functional.py:1678
%359 = bf16[4096,768]{1,0} aten::mm(%17, %15), location=linear@functional.py:1676
%360 = bf16[4096,1,768]{2,1,0} aten::view(%359), location=linear@functional.py:1678, output_size=(4096, 1, 768)
%361 = bf16[4096,1,768]{2,1,0} aten::add(%360, %358), location=linear@functional.py:1678
%362 = bf16[4096,768]{1,0} aten::view(%361), location=_pad@functional.py:3547, output_size=(4096, 768)
%363 = bf16[4096,1,768]{2,1,0} aten::view(%362), location=_pad@functional.py:3547, output_size=(4096, 1, 768)
%364 = bf16[4096,1,12,64]{3,2,1,0} aten::view(%363), location=_pad@functional.py:3547, output_size=(4096, 1, 12, 64)
%365 = bf16[1,4096,12,64]{3,2,0,1} aten::permute(%364), location=_pad@functional.py:3547, dims=(1, 0, 2, 3)
%366 = bf16[1,12,4096,64]{3,1,0,2} aten::permute(%365), location=_pad@functional.py:3547, dims=(0, 2, 1, 3)
%367 = bf16[12,4096,64]{2,1,0} aten::view(%366), location=_pad@functional.py:3547, output_size=(12, 4096, 64)
%368 = bf16[12,4608,64]{2,1,0} aten::constant_pad_nd(%367), location=_pad@functional.py:3547, pad=(0, 0, 256, 256, 0, 0), value=-1
%369 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=0, end=768, stride=1, ROOT=45
%370 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=256, end=1024, stride=1, ROOT=46
%371 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=512, end=1280, stride=1, ROOT=47
%372 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=768, end=1536, stride=1, ROOT=48
%373 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1024, end=1792, stride=1, ROOT=49
%374 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1280, end=2048, stride=1, ROOT=50
%375 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1536, end=2304, stride=1, ROOT=51
%376 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=1792, end=2560, stride=1, ROOT=52
%377 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2048, end=2816, stride=1, ROOT=53
%378 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2304, end=3072, stride=1, ROOT=54
%379 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2560, end=3328, stride=1, ROOT=55
%380 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=2816, end=3584, stride=1, ROOT=56
%381 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3072, end=3840, stride=1, ROOT=57
%382 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3328, end=4096, stride=1, ROOT=58
%383 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3584, end=4352, stride=1, ROOT=59
%384 = bf16[12,768,64]{2,1,0} xla::select(%368), location=_unfold_loop@sliding_chunks.py:17, dim=1, start=3840, end=4608, stride=1, ROOT=60
%385 = bf16[1,4096,12,513]{3,2,1,0} aten::mul(%343, %354), location=dropout@functional.py:973
%386 = bf16[1,12,4096,513]{3,1,2,0} aten::permute(%385), location=_pad@functional.py:3547, dims=(0, 2, 1, 3)
%387 = bf16[12,16,256,513]{3,2,1,0} aten::view(%386), location=_pad@functional.py:3547, output_size=(12, 16, 256, 513)
%388 = bf16[12,16,256,770]{3,2,1,0} aten::constant_pad_nd(%387), location=_pad@functional.py:3547, pad=(0, 257, 0, 0, 0, 0, 0, 0), value=0
%389 = bf16[12,16,197120]{2,1,0} aten::view(%388), location=einsum@functional.py:327, output_size=(12, 16, 197120)
%390 = bf16[12,16,197120]{2,1,0} xla::select(%389), location=einsum@functional.py:327, dim=0, start=0, end=12, stride=1
%391 = bf16[12,16,197120]{2,1,0} xla::select(%390), location=einsum@functional.py:327, dim=1, start=0, end=16, stride=1
%392 = bf16[12,16,196864]{2,1,0} xla::select(%391), location=einsum@functional.py:327, dim=2, start=0, end=196864, stride=1
%393 = bf16[12,16,256,769]{3,2,1,0} aten::view(%392), location=einsum@functional.py:327, output_size=(12, 16, 256, 769)
%394 = bf16[12,16,256,769]{3,2,1,0} xla::select(%393), location=einsum@functional.py:327, dim=0, start=0, end=12, stride=1
%395 = bf16[12,16,256,769]{3,2,1,0} xla::select(%394), location=einsum@functional.py:327, dim=1, start=0, end=16, stride=1
%396 = bf16[12,16,256,769]{3,2,1,0} xla::select(%395), location=einsum@functional.py:327, dim=2, start=0, end=256, stride=1
%397 = bf16[12,16,256,768]{3,2,1,0} xla::select(%396), location=einsum@functional.py:327, dim=3, start=0, end=768, stride=1
%398 = bf16[12,16,256,768]{3,2,1,0} aten::permute(%397), location=einsum@functional.py:327, dims=(0, 1, 2, 3)
%399 = bf16[12,16,256,1,768]{4,3,2,1,0} aten::view(%398), location=einsum@functional.py:327, output_size=(12, 16, 256, 1, 768)
%400 = bf16[12,16,256,768,1]{3,4,2,1,0} aten::permute(%399), location=einsum@functional.py:327, dims=(0, 1, 2, 4, 3)
%401 = bf16[192,256,768]{2,1,0} aten::view(%400), location=einsum@functional.py:327, output_size=(192, 256, 768), ROOT=61
%402 = bf16[12,16,768,64]{3,2,1,0} aten::stack(%369, %370, %371, %372, %373, %374, %375, %376, %377, %378, %379, %380, %381, %382, %383, %384), location=_unfold_loop@sliding_chunks.py:17, dim=1
%403 = bf16[12,16,64,768]{2,3,1,0} aten::permute(%402), location=einsum@functional.py:327, dims=(0, 1, 3, 2)
%404 = bf16[12,16,1,64,768]{4,3,2,1,0} aten::view(%403), location=einsum@functional.py:327, output_size=(12, 16, 1, 64, 768)
%405 = bf16[12,16,768,64,1]{2,3,4,1,0} aten::permute(%404), location=einsum@functional.py:327, dims=(0, 1, 4, 3, 2)
%406 = bf16[192,768,64]{2,1,0} aten::view(%405), location=einsum@functional.py:327, output_size=(192, 768, 64), ROOT=62
%407 = bf16[] prim::Constant(), location=forward@longformer.py:227, value=1
%408 = bf16[] prim::Constant(), location=forward@longformer.py:227, value=0
%409 = bf16[] aten::mul(%408, %407), location=forward@longformer.py:227
%410 = bf16[192,256,64]{2,1,0} aten::matmul(%401, %406), location=einsum@functional.py:327
%411 = bf16[12,16,256,1,64]{4,3,2,1,0} aten::view(%410), location=forward@longformer.py:227, output_size=(12, 16, 256, 1, 64)
%412 = bf16[12,16,256,64,1]{3,4,2,1,0} aten::permute(%411), location=forward@longformer.py:227, dims=(0, 1, 2, 4, 3)
%413 = bf16[12,16,256,64]{3,2,1,0} aten::view(%412), location=forward@longformer.py:227, output_size=(12, 16, 256, 64)
%414 = bf16[1,12,4096,64]{3,2,1,0} aten::view(%413), location=forward@longformer.py:227, output_size=(1, 12, 4096, 64)
%415 = bf16[1,4096,12,64]{3,1,2,0} aten::permute(%414), location=forward@longformer.py:227, dims=(0, 2, 1, 3)
%416 = bf16[1,4096,12,64]{3,2,1,0} aten::add(%415, %409), location=forward@longformer.py:227, ROOT=63
%417 = bf16[4096,1,12,64]{3,2,0,1} aten::permute(%416), location=training_step@test_tpu.py:61, dims=(1, 0, 2, 3)
%418 = bf16[4096,1,768]{2,1,0} aten::view(%417), location=training_step@test_tpu.py:61, output_size=(4096, 1, 768)
%419 = bf16[1,4096,768]{2,0,1} aten::permute(%418), location=training_step@test_tpu.py:61, dims=(1, 0, 2), ROOT=64
%420 = bf16[] aten::sum(%419), location=training_step@test_tpu.py:61, dimensions=(0, 1, 2), keep_reduced_dimensions=0, dtype=6, ROOT=65
}
@ibeltagy I looked at the graph with Blake and he mentioned that the view
op can be expensive. On XLA, view usually gets lowered to a reshape plus a transpose. Transpose is typically free on TPU, but reshape is expensive. Once you've tried the conv approach and have a new HLO, we can run the xprof tool to better understand the performance.
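As a rough eager-mode analogy to the lowering behavior above (this is not the XLA lowering itself, just an illustration of why the two ops differ in cost): a pure dimension permutation can be expressed without moving any data, while reshaping a transposed tensor forces a real copy.

```python
import torch

x = torch.randn(2, 3, 4)

# A pure permutation is a transpose: in eager mode it is metadata-only
# (no data copied), and the corresponding HLO transpose is typically
# cheap on TPU.
t = x.permute(0, 2, 1)
same_storage = t.data_ptr() == x.data_ptr()

# Reshaping the transposed tensor can no longer be expressed as a view,
# so it materializes a copy -- the analogue of the expensive reshape
# lowering.
r = t.reshape(2, 12)
copied = r.data_ptr() != x.data_ptr()
```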
Tried the convolution trick; it makes the model 20% faster while using the same amount of memory, which is great. The new graph is below, and I am looking for opportunities to save both compute and memory.
The part I suspect is the least efficient is lines 70-100 in the graph below, which correspond to these lines of code, but I could be totally wrong.
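For reference, a minimal sketch of the identity-kernel convolution trick for extracting overlapping windows (the helper name `unfold_via_conv` and the 1-D shapes are illustrative assumptions; the actual `_unfold_conv` in sliding_chunks.py operates on the multi-head tensors and may differ in detail):

```python
import torch
import torch.nn.functional as F

def unfold_via_conv(x, window, step):
    """Extract overlapping windows of x (batch, seqlen) with a conv1d
    whose kernel is an identity matrix, instead of Tensor.unfold."""
    # Identity kernel: out_channels=window, in_channels=1, kernel_size=window.
    # Output channel k at spatial position t then picks out x[:, t*step + k].
    eye = torch.eye(window, dtype=x.dtype).view(window, 1, window)
    out = F.conv1d(x.unsqueeze(1), eye, stride=step)  # (batch, window, num_windows)
    return out.transpose(1, 2)                        # (batch, num_windows, window)

x = torch.arange(16, dtype=torch.float32).view(1, 16)
chunks = unfold_via_conv(x, window=4, step=2)  # matches x.unfold(1, 4, 2)
```

The appeal on XLA is that the conv replaces the chain of strided `select`s produced by the loop-based unfold with a single well-optimized convolution op.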
IR {
%0 = f32[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
%1 = f32[768,768]{0,1} aten::permute(%0), location=linear@functional.py:1676, dims=(1, 0), ROOT=0
%2 = s64[1,4096]{1,0} xla::device_data(), location=forward@test_tpu.py:53, device=TPU:0
%3 = s64[1,4096]{1,0} xla::select(%2), location=forward@test_tpu.py:53, dim=0, start=0, end=1, stride=1
%4 = s64[1,4096]{1,0} xla::select(%3), location=forward@test_tpu.py:53, dim=1, start=0, end=4096, stride=1
%5 = s64[1,4096,1]{2,1,0} aten::view(%4), location=forward@test_tpu.py:53, output_size=(1, 4096, 1)
%6 = s64[1,4096,768]{2,1,0} aten::expand(%5), location=forward@test_tpu.py:53, size=(1, 4096, 768)
%7 = f32[1,4096,768]{2,1,0} xla::cast(%6), location=forward@test_tpu.py:53, type=f32, dtype=Float, stype=Long
%8 = f32[4096,1,768]{2,0,1} aten::permute(%7), location=linear@functional.py:1676, dims=(1, 0, 2)
%9 = f32[4096,768]{1,0} aten::view(%8), location=linear@functional.py:1676, output_size=(4096, 768), ROOT=1
%10 = f32[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
%11 = f32[768,768]{0,1} aten::permute(%10), location=linear@functional.py:1676, dims=(1, 0), ROOT=2
%12 = f32[4096,1,768]{2,0,1} aten::permute(%7), location=linear@functional.py:1676, dims=(1, 0, 2)
%13 = f32[4096,768]{1,0} aten::view(%12), location=linear@functional.py:1676, output_size=(4096, 768), ROOT=3
%14 = f32[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
%15 = f32[768,768]{0,1} aten::permute(%14), location=linear@functional.py:1676, dims=(1, 0), ROOT=4
%16 = f32[4096,1,768]{2,0,1} aten::permute(%7), location=linear@functional.py:1676, dims=(1, 0, 2)
%17 = f32[4096,768]{1,0} aten::view(%16), location=linear@functional.py:1676, output_size=(4096, 768), ROOT=5
%18 = f32[] xla::device_data(), location=forward@longformer.py:155, device=TPU:0
%19 = f32[] prim::Constant(), location=linear@functional.py:1678, value=1
%20 = f32[768]{0} aten::expand(%19), location=linear@functional.py:1678, size=(768)
%21 = f32[768]{0} xla::device_data(), location=linear@functional.py:1678, device=TPU:0
%22 = f32[768]{0} aten::mul(%21, %20), location=linear@functional.py:1678
%23 = f32[4096,768]{1,0} aten::mm(%9, %1), location=linear@functional.py:1676
%24 = f32[4096,1,768]{2,1,0} aten::view(%23), location=linear@functional.py:1678, output_size=(4096, 1, 768)
%25 = f32[4096,1,768]{2,1,0} aten::add(%24, %22), location=linear@functional.py:1678
%26 = f32[4096,768]{1,0} aten::view(%25), location=forward@longformer.py:155, output_size=(4096, 768)
%27 = f32[4096,1,768]{2,1,0} aten::view(%26), location=forward@longformer.py:155, output_size=(4096, 1, 768)
%28 = f32[4096,1,768]{2,1,0} aten::div(%27, %18), location=forward@longformer.py:155
%29 = f32[4096,768]{1,0} aten::view(%28), location=_unfold_conv@sliding_chunks.py:27, output_size=(4096, 768)
%30 = f32[4096,1,768]{2,1,0} aten::view(%29), location=_unfold_conv@sliding_chunks.py:27, output_size=(4096, 1, 768)
%31 = f32[4096,1,12,64]{3,2,1,0} aten::view(%30), location=_unfold_conv@sliding_chunks.py:27, output_size=(4096, 1, 12, 64)
%32 = f32[1,4096,12,64]{3,2,0,1} aten::permute(%31), location=_unfold_conv@sliding_chunks.py:27, dims=(1, 0, 2, 3)
%33 = f32[1,12,4096,64]{3,1,0,2} aten::permute(%32), location=_unfold_conv@sliding_chunks.py:27, dims=(0, 2, 1, 3)
%34 = f32[12,4096,64]{2,1,0} aten::view(%33), location=_unfold_conv@sliding_chunks.py:27, output_size=(12, 4096, 64)
%35 = f32[12,16,256,64]{3,2,1,0} aten::view(%34), location=_unfold_conv@sliding_chunks.py:27, output_size=(12, 16, 256, 64)
%36 = f32[64,256,12,16]{0,1,3,2} aten::permute(%35), location=_unfold_conv@sliding_chunks.py:27, dims=(3, 2, 0, 1), ROOT=6
%37 = f32[512,512]{1,0} aten::eye(), location=_unfold_conv@sliding_chunks.py:26
%38 = f32[1,2,256,512]{3,2,1,0} aten::view(%37), location=_unfold_conv@sliding_chunks.py:27, output_size=(1, 2, 256, 512)
%39 = f32[512,256,1,2]{0,1,3,2} aten::permute(%38), location=_unfold_conv@sliding_chunks.py:27, dims=(3, 2, 0, 1), ROOT=7
%40 = f32[] prim::Constant(), location=linear@functional.py:1678, value=1
%41 = f32[768]{0} aten::expand(%40), location=linear@functional.py:1678, size=(768)
%42 = f32[768]{0} xla::device_data(), location=linear@functional.py:1678, device=TPU:0
%43 = f32[768]{0} aten::mul(%42, %41), location=linear@functional.py:1678
%44 = f32[4096,768]{1,0} aten::mm(%13, %11), location=linear@functional.py:1676
%45 = f32[4096,1,768]{2,1,0} aten::view(%44), location=linear@functional.py:1678, output_size=(4096, 1, 768)
%46 = f32[4096,1,768]{2,1,0} aten::add(%45, %43), location=linear@functional.py:1678
%47 = f32[4096,768]{1,0} aten::view(%46), location=_unfold_conv@sliding_chunks.py:27, output_size=(4096, 768)
%48 = f32[4096,1,768]{2,1,0} aten::view(%47), location=_unfold_conv@sliding_chunks.py:27, output_size=(4096, 1, 768)
%49 = f32[4096,1,12,64]{3,2,1,0} aten::view(%48), location=_unfold_conv@sliding_chunks.py:27, output_size=(4096, 1, 12, 64)
%50 = f32[1,4096,12,64]{3,2,0,1} aten::permute(%49), location=_unfold_conv@sliding_chunks.py:27, dims=(1, 0, 2, 3)
%51 = f32[1,12,4096,64]{3,1,0,2} aten::permute(%50), location=_unfold_conv@sliding_chunks.py:27, dims=(0, 2, 1, 3)
%52 = f32[12,4096,64]{2,1,0} aten::view(%51), location=_unfold_conv@sliding_chunks.py:27, output_size=(12, 4096, 64)
%53 = f32[12,16,256,64]{3,2,1,0} aten::view(%52), location=_unfold_conv@sliding_chunks.py:27, output_size=(12, 16, 256, 64)
%54 = f32[64,256,12,16]{0,1,3,2} aten::permute(%53), location=_unfold_conv@sliding_chunks.py:27, dims=(3, 2, 0, 1), ROOT=8
%55 = f32[512,512]{1,0} aten::eye(), location=_unfold_conv@sliding_chunks.py:26
%56 = f32[1,2,256,512]{3,2,1,0} aten::view(%55), location=_unfold_conv@sliding_chunks.py:27, output_size=(1, 2, 256, 512)
%57 = f32[512,256,1,2]{0,1,3,2} aten::permute(%56), location=_unfold_conv@sliding_chunks.py:27, dims=(3, 2, 0, 1), ROOT=9
%58 = f32[64,512,12,15]{3,2,1,0} aten::convolution_overrideable(%36, %39), location=_unfold_conv@sliding_chunks.py:27, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%59 = f32[12,15,512,64]{1,0,2,3} aten::permute(%58), location=einsum@functional.py:327, dims=(2, 3, 1, 0)
%60 = f32[12,15,512,64]{1,0,2,3} aten::permute(%59), location=einsum@functional.py:327, dims=(0, 1, 2, 3)
%61 = f32[12,15,512,1,64]{4,3,2,1,0} aten::view(%60), location=einsum@functional.py:327, output_size=(12, 15, 512, 1, 64)
%62 = f32[12,15,512,64,1]{3,4,2,1,0} aten::permute(%61), location=einsum@functional.py:327, dims=(0, 1, 2, 4, 3)
%63 = f32[180,512,64]{2,1,0} aten::view(%62), location=einsum@functional.py:327, output_size=(180, 512, 64), ROOT=10
%64 = f32[64,512,12,15]{3,2,1,0} aten::convolution_overrideable(%54, %57), location=_unfold_conv@sliding_chunks.py:27, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%65 = f32[12,15,512,64]{1,0,2,3} aten::permute(%64), location=einsum@functional.py:327, dims=(2, 3, 1, 0)
%66 = f32[12,15,512,64]{1,0,2,3} aten::permute(%65), location=einsum@functional.py:327, dims=(0, 1, 2, 3)
%67 = f32[12,15,1,512,64]{4,3,2,1,0} aten::view(%66), location=einsum@functional.py:327, output_size=(12, 15, 1, 512, 64)
%68 = f32[12,15,64,512,1]{2,3,4,1,0} aten::permute(%67), location=einsum@functional.py:327, dims=(0, 1, 4, 3, 2)
%69 = f32[180,64,512]{2,1,0} aten::view(%68), location=einsum@functional.py:327, output_size=(180, 64, 512), ROOT=11
%70 = f32[180,512,512]{2,1,0} aten::matmul(%63, %69), location=einsum@functional.py:327
%71 = f32[12,15,512,1,512]{4,3,2,1,0} aten::view(%70), location=_pad@functional.py:3547, output_size=(12, 15, 512, 1, 512)
%72 = f32[12,15,512,512,1]{3,4,2,1,0} aten::permute(%71), location=_pad@functional.py:3547, dims=(0, 1, 2, 4, 3)
%73 = f32[12,15,512,512]{3,2,1,0} aten::view(%72), location=_pad@functional.py:3547, output_size=(12, 15, 512, 512)
%74 = f32[12,15,513,512]{3,2,1,0} aten::constant_pad_nd(%73), location=_pad@functional.py:3547, pad=(0, 0, 0, 1, 0, 0, 0, 0), value=0
%75 = f32[12,15,512,513]{3,2,1,0} aten::view(%74), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, output_size=(12, 15, 512, 513)
%76 = f32[12,15,512,513]{3,2,1,0} xla::select(%75), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, dim=0, start=0, end=12, stride=1
%77 = f32[12,1,512,513]{3,2,1,0} xla::generic_slice(%76), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, base_indices=(0, 0, 0, 0), sizes=(12, 1, 512, 513)
%78 = f32[12,512,513]{2,1,0} aten::view(%77), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, output_size=(12, 512, 513)
%79 = f32[12,255,513]{2,1,0} xla::select(%78), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, dim=1, start=0, end=255, stride=1
%80 = f32[12,255,255]{2,1,0} xla::select(%79), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, dim=2, start=258, end=513, stride=1
%81 = f32[12,255,255]{2,1,0} aten::view(%80), location=sliding_chunks_matmul_qk@sliding_chunks.py:108, output_size=(12, 255, 255)
%82 = f32[12,15,512,513]{3,2,1,0} aten::view(%74), location=sliding_chunks_matmul_qk@sliding_chunks.py:107, output_size=(12, 15, 512, 513)
%83 = f32[12,15,512,513]{3,2,1,0} xla::select(%82), location=sliding_chunks_matmul_qk@sliding_chunks.py:107, dim=0, start=0, end=12, stride=1
%84 = f32[12,15,512,513]{3,2,1,0} xla::select(%83), location=sliding_chunks_matmul_qk@sliding_chunks.py:107, dim=1, start=0, end=15, stride=1
%85 = f32[12,15,256,513]{3,2,1,0} xla::select(%84), location=sliding_chunks_matmul_qk@sliding_chunks.py:107, dim=2, start=255, end=511, stride=1
%86 = f32[12,15,256,256]{3,2,1,0} xla::select(%85), location=sliding_chunks_matmul_qk@sliding_chunks.py:107, dim=3, start=257, end=513, stride=1
%87 = f32[12,15,256,256]{3,2,1,0} aten::view(%86), location=sliding_chunks_matmul_qk@sliding_chunks.py:107, output_size=(12, 15, 256, 256)
%88 = f32[12,15,512,513]{3,2,1,0} aten::view(%74), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, output_size=(12, 15, 512, 513)
%89 = f32[12,15,512,513]{3,2,1,0} xla::select(%88), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, dim=0, start=0, end=12, stride=1
%90 = f32[12,1,512,513]{3,2,1,0} xla::generic_slice(%89), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, base_indices=(0, 14, 0, 0), sizes=(12, 1, 512, 513)
%91 = f32[12,512,513]{2,1,0} aten::view(%90), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, output_size=(12, 512, 513)
%92 = f32[12,256,513]{2,1,0} xla::select(%91), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, dim=1, start=256, end=512, stride=1
%93 = f32[12,256,257]{2,1,0} xla::select(%92), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, dim=2, start=0, end=257, stride=1
%94 = f32[12,256,257]{2,1,0} aten::view(%93), location=sliding_chunks_matmul_qk@sliding_chunks.py:105, output_size=(12, 256, 257)
%95 = f32[12,15,512,513]{3,2,1,0} aten::view(%74), location=sliding_chunks_matmul_qk@sliding_chunks.py:104, output_size=(12, 15, 512, 513)
%96 = f32[12,15,512,513]{3,2,1,0} xla::select(%95), location=sliding_chunks_matmul_qk@sliding_chunks.py:104, dim=0, start=0, end=12, stride=1
%97 = f32[12,15,512,513]{3,2,1,0} xla::select(%96), location=sliding_chunks_matmul_qk@sliding_chunks.py:104, dim=1, start=0, end=15, stride=1
%98 = f32[12,15,256,513]{3,2,1,0} xla::select(%97), location=sliding_chunks_matmul_qk@sliding_chunks.py:104, dim=2, start=0, end=256, stride=1
%99 = f32[12,15,256,257]{3,2,1,0} xla::select(%98), location=sliding_chunks_matmul_qk@sliding_chunks.py:104, dim=3, start=0, end=257, stride=1
%100 = f32[12,15,256,257]{3,2,1,0} aten::view(%99), location=sliding_chunks_matmul_qk@sliding_chunks.py:104, output_size=(12, 15, 256, 257)
%101 = f32[] prim::Constant(), location=sliding_chunks_matmul_qk@sliding_chunks.py:100, value=0
%102 = f32[12,16,256,513]{3,2,1,0} aten::expand(%101), location=sliding_chunks_matmul_qk@sliding_chunks.py:100, size=(12, 16, 256, 513)
%103 = f32[12,16,256,513]{3,2,1,0} xla::select(%102), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
%104 = f32[12,15,256,513]{3,2,1,0} xla::select(%103), location=softmax@functional.py:1500, dim=1, start=0, end=15, stride=1
%105 = f32[12,15,256,513]{3,2,1,0} xla::select(%104), location=softmax@functional.py:1500, dim=2, start=0, end=256, stride=1
%106 = f32[12,15,256,513]{3,2,1,0} xla::unselect(%105, %100), location=softmax@functional.py:1500, dim=3, start=256, end=513, stride=1
%107 = f32[12,15,256,513]{3,2,1,0} xla::unselect(%104, %106), location=softmax@functional.py:1500, dim=2, start=0, end=256, stride=1
%108 = f32[12,16,256,513]{3,2,1,0} xla::unselect(%103, %107), location=softmax@functional.py:1500, dim=1, start=0, end=15, stride=1
%109 = f32[12,16,256,513]{3,2,1,0} xla::unselect(%102, %108), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
%110 = f32[12,16,256,513]{3,2,1,0} xla::select(%109), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
%111 = f32[12,1,256,513]{3,2,1,0} xla::generic_slice(%110), location=softmax@functional.py:1500, base_indices=(0, 15, 0, 0), sizes=(12, 1, 256, 513)
%112 = f32[12,256,513]{2,1,0} aten::view(%111), location=softmax@functional.py:1500, output_size=(12, 256, 513)
%113 = f32[12,256,513]{2,1,0} xla::select(%112), location=softmax@functional.py:1500, dim=1, start=0, end=256, stride=1
%114 = f32[12,256,513]{2,1,0} xla::unselect(%113, %94), location=softmax@functional.py:1500, dim=2, start=256, end=513, stride=1
%115 = f32[12,256,513]{2,1,0} xla::unselect(%112, %114), location=softmax@functional.py:1500, dim=1, start=0, end=256, stride=1
%116 = f32[12,1,256,513]{3,2,1,0} aten::view(%115), location=softmax@functional.py:1500, output_size=(12, 1, 256, 513)
%117 = f32[12,16,256,513]{3,2,1,0} xla::update_slice(%110, %116), location=softmax@functional.py:1500, base_indices=(0, 15, 0, 0)
%118 = f32[12,16,256,513]{3,2,1,0} xla::unselect(%109, %117), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
%119 = f32[12,16,256,513]{3,2,1,0} xla::select(%118), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
%120 = f32[12,15,256,513]{3,2,1,0} xla::select(%119), location=softmax@functional.py:1500, dim=1, start=1, end=16, stride=1
%121 = f32[12,15,256,513]{3,2,1,0} xla::select(%120), location=softmax@functional.py:1500, dim=2, start=0, end=256, stride=1
%122 = f32[12,15,256,513]{3,2,1,0} xla::unselect(%121, %87), location=softmax@functional.py:1500, dim=3, start=0, end=256, stride=1
%123 = f32[12,15,256,513]{3,2,1,0} xla::unselect(%120, %122), location=softmax@functional.py:1500, dim=2, start=0, end=256, stride=1
%124 = f32[12,16,256,513]{3,2,1,0} xla::unselect(%119, %123), location=softmax@functional.py:1500, dim=1, start=1, end=16, stride=1
%125 = f32[12,16,256,513]{3,2,1,0} xla::unselect(%118, %124), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
%126 = f32[12,16,256,513]{3,2,1,0} xla::select(%125), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
%127 = f32[12,1,256,513]{3,2,1,0} xla::generic_slice(%126), location=softmax@functional.py:1500, base_indices=(0, 0, 0, 0), sizes=(12, 1, 256, 513)
%128 = f32[12,256,513]{2,1,0} aten::view(%127), location=softmax@functional.py:1500, output_size=(12, 256, 513)
%129 = f32[12,255,513]{2,1,0} xla::select(%128), location=softmax@functional.py:1500, dim=1, start=1, end=256, stride=1
%130 = f32[12,255,513]{2,1,0} xla::unselect(%129, %81), location=softmax@functional.py:1500, dim=2, start=1, end=256, stride=1
%131 = f32[12,256,513]{2,1,0} xla::unselect(%128, %130), location=softmax@functional.py:1500, dim=1, start=1, end=256, stride=1
%132 = f32[12,1,256,513]{3,2,1,0} aten::view(%131), location=softmax@functional.py:1500, output_size=(12, 1, 256, 513)
%133 = f32[12,16,256,513]{3,2,1,0} xla::update_slice(%126, %132), location=softmax@functional.py:1500, base_indices=(0, 0, 0, 0)
%134 = f32[12,16,256,513]{3,2,1,0} xla::unselect(%125, %133), location=softmax@functional.py:1500, dim=0, start=0, end=12, stride=1
%135 = f32[1,12,4096,513]{3,2,1,0} aten::view(%134), location=softmax@functional.py:1500, output_size=(1, 12, 4096, 513)
%136 = f32[1,4096,12,513]{3,1,2,0} aten::permute(%135), location=softmax@functional.py:1500, dims=(0, 2, 1, 3), ROOT=12
%137 = f32[1,4096,12,513]{3,1,2,0} aten::softmax(%136), location=softmax@functional.py:1500, dim=3, dtype=6, ROOT=13
%138 = f32[] xla::device_data(), location=dropout@functional.py:973, device=TPU:0
%139 = f32[1,4096,12,513]{3,2,1,0} aten::expand(%138), location=dropout@functional.py:973, size=(1, 4096, 12, 513)
%140 = s64[] xla::device_data(), location=dropout@functional.py:973, device=TPU:0
%141 = s64[] prim::Constant(), location=dropout@functional.py:973, value=214013
%142 = s64[] aten::mul(%141, %140), location=dropout@functional.py:973
%143 = s64[] prim::Constant(), location=dropout@functional.py:973, value=2.53101e+06
%144 = s64[] aten::add(%143, %142), location=dropout@functional.py:973
%145 = f32[] xla::device_data(), location=dropout@functional.py:973, device=TPU:0
%146 = f32[1,4096,12,513]{3,2,1,0} aten::expand(%145), location=dropout@functional.py:973, size=(1, 4096, 12, 513)
%147 = f32[1,4096,12,513]{3,2,1,0} aten::bernoulli(%146, %144), location=dropout@functional.py:973
%148 = f32[1,4096,12,513]{3,2,1,0} aten::div(%147, %139), location=dropout@functional.py:973, ROOT=14
%149 = f32[] prim::Constant(), location=linear@functional.py:1678, value=1
%150 = f32[768]{0} aten::expand(%149), location=linear@functional.py:1678, size=(768)
%151 = f32[768]{0} xla::device_data(), location=linear@functional.py:1678, device=TPU:0
%152 = f32[768]{0} aten::mul(%151, %150), location=linear@functional.py:1678
%153 = f32[4096,768]{1,0} aten::mm(%17, %15), location=linear@functional.py:1676
%154 = f32[4096,1,768]{2,1,0} aten::view(%153), location=linear@functional.py:1678, output_size=(4096, 1, 768)
%155 = f32[4096,1,768]{2,1,0} aten::add(%154, %152), location=linear@functional.py:1678
%156 = f32[4096,768]{1,0} aten::view(%155), location=_pad@functional.py:3547, output_size=(4096, 768)
%157 = f32[4096,1,768]{2,1,0} aten::view(%156), location=_pad@functional.py:3547, output_size=(4096, 1, 768)
%158 = f32[4096,1,12,64]{3,2,1,0} aten::view(%157), location=_pad@functional.py:3547, output_size=(4096, 1, 12, 64)
%159 = f32[1,4096,12,64]{3,2,0,1} aten::permute(%158), location=_pad@functional.py:3547, dims=(1, 0, 2, 3)
%160 = f32[1,12,4096,64]{3,1,0,2} aten::permute(%159), location=_pad@functional.py:3547, dims=(0, 2, 1, 3)
%161 = f32[12,4096,64]{2,1,0} aten::view(%160), location=_pad@functional.py:3547, output_size=(12, 4096, 64)
%162 = f32[12,4608,64]{2,1,0} aten::constant_pad_nd(%161), location=_pad@functional.py:3547, pad=(0, 0, 256, 256, 0, 0), value=-1
%163 = f32[12,18,256,64]{3,2,1,0} aten::view(%162), location=_unfold_conv@sliding_chunks.py:27, output_size=(12, 18, 256, 64)
%164 = f32[64,256,12,18]{0,1,3,2} aten::permute(%163), location=_unfold_conv@sliding_chunks.py:27, dims=(3, 2, 0, 1), ROOT=15
%165 = f32[768,768]{1,0} aten::eye(), location=_unfold_conv@sliding_chunks.py:26
%166 = f32[1,3,256,768]{3,2,1,0} aten::view(%165), location=_unfold_conv@sliding_chunks.py:27, output_size=(1, 3, 256, 768)
%167 = f32[768,256,1,3]{0,1,3,2} aten::permute(%166), location=_unfold_conv@sliding_chunks.py:27, dims=(3, 2, 0, 1), ROOT=16
%168 = f32[1,4096,12,513]{3,2,1,0} aten::mul(%137, %148), location=dropout@functional.py:973
%169 = f32[1,12,4096,513]{3,1,2,0} aten::permute(%168), location=_pad@functional.py:3547, dims=(0, 2, 1, 3)
%170 = f32[12,16,256,513]{3,2,1,0} aten::view(%169), location=_pad@functional.py:3547, output_size=(12, 16, 256, 513)
%171 = f32[12,16,256,770]{3,2,1,0} aten::constant_pad_nd(%170), location=_pad@functional.py:3547, pad=(0, 257, 0, 0, 0, 0, 0, 0), value=0
%172 = f32[12,16,197120]{2,1,0} aten::view(%171), location=einsum@functional.py:327, output_size=(12, 16, 197120)
%173 = f32[12,16,197120]{2,1,0} xla::select(%172), location=einsum@functional.py:327, dim=0, start=0, end=12, stride=1
%174 = f32[12,16,197120]{2,1,0} xla::select(%173), location=einsum@functional.py:327, dim=1, start=0, end=16, stride=1
%175 = f32[12,16,196864]{2,1,0} xla::select(%174), location=einsum@functional.py:327, dim=2, start=0, end=196864, stride=1
%176 = f32[12,16,256,769]{3,2,1,0} aten::view(%175), location=einsum@functional.py:327, output_size=(12, 16, 256, 769)
%177 = f32[12,16,256,769]{3,2,1,0} xla::select(%176), location=einsum@functional.py:327, dim=0, start=0, end=12, stride=1
%178 = f32[12,16,256,769]{3,2,1,0} xla::select(%177), location=einsum@functional.py:327, dim=1, start=0, end=16, stride=1
%179 = f32[12,16,256,769]{3,2,1,0} xla::select(%178), location=einsum@functional.py:327, dim=2, start=0, end=256, stride=1
%180 = f32[12,16,256,768]{3,2,1,0} xla::select(%179), location=einsum@functional.py:327, dim=3, start=0, end=768, stride=1
%181 = f32[12,16,256,768]{3,2,1,0} aten::permute(%180), location=einsum@functional.py:327, dims=(0, 1, 2, 3)
%182 = f32[12,16,256,1,768]{4,3,2,1,0} aten::view(%181), location=einsum@functional.py:327, output_size=(12, 16, 256, 1, 768)
%183 = f32[12,16,256,768,1]{3,4,2,1,0} aten::permute(%182), location=einsum@functional.py:327, dims=(0, 1, 2, 4, 3)
%184 = f32[192,256,768]{2,1,0} aten::view(%183), location=einsum@functional.py:327, output_size=(192, 256, 768), ROOT=17
%185 = f32[64,768,12,16]{3,2,1,0} aten::convolution_overrideable(%164, %167), location=_unfold_conv@sliding_chunks.py:27, stride=(1, 1), padding=(0, 0), dilation=(1, 1), transpose=0, output_padding=(0, 0), groups=1
%186 = f32[12,16,768,64]{1,0,2,3} aten::permute(%185), location=einsum@functional.py:327, dims=(2, 3, 1, 0)
%187 = f32[12,16,64,768]{1,0,3,2} aten::permute(%186), location=einsum@functional.py:327, dims=(0, 1, 3, 2)
%188 = f32[12,16,1,64,768]{4,3,2,1,0} aten::view(%187), location=einsum@functional.py:327, output_size=(12, 16, 1, 64, 768)
%189 = f32[12,16,768,64,1]{2,3,4,1,0} aten::permute(%188), location=einsum@functional.py:327, dims=(0, 1, 4, 3, 2)
%190 = f32[192,768,64]{2,1,0} aten::view(%189), location=einsum@functional.py:327, output_size=(192, 768, 64), ROOT=18
%191 = f32[] prim::Constant(), location=forward@longformer.py:230, value=1
%192 = f32[] prim::Constant(), location=forward@longformer.py:230, value=0
%193 = f32[] aten::mul(%192, %191), location=forward@longformer.py:230
%194 = f32[192,256,64]{2,1,0} aten::matmul(%184, %190), location=einsum@functional.py:327
%195 = f32[12,16,256,1,64]{4,3,2,1,0} aten::view(%194), location=forward@longformer.py:230, output_size=(12, 16, 256, 1, 64)
%196 = f32[12,16,256,64,1]{3,4,2,1,0} aten::permute(%195), location=forward@longformer.py:230, dims=(0, 1, 2, 4, 3)
%197 = f32[12,16,256,64]{3,2,1,0} aten::view(%196), location=forward@longformer.py:230, output_size=(12, 16, 256, 64)
%198 = f32[1,12,4096,64]{3,2,1,0} aten::view(%197), location=forward@longformer.py:230, output_size=(1, 12, 4096, 64)
%199 = f32[1,4096,12,64]{3,1,2,0} aten::permute(%198), location=forward@longformer.py:230, dims=(0, 2, 1, 3)
%200 = f32[1,4096,12,64]{3,2,1,0} aten::add(%199, %193), location=forward@longformer.py:230, ROOT=19
%201 = f32[4096,1,12,64]{3,2,0,1} aten::permute(%200), location=training_step@test_tpu.py:61, dims=(1, 0, 2, 3)
%202 = f32[4096,1,768]{2,1,0} aten::view(%201), location=training_step@test_tpu.py:61, output_size=(4096, 1, 768)
%203 = f32[1,4096,768]{2,0,1} aten::permute(%202), location=training_step@test_tpu.py:61, dims=(1, 0, 2), ROOT=20
%204 = f32[] aten::sum(%203), location=training_step@test_tpu.py:61, dimensions=(0, 1, 2), keep_reduced_dimensions=0, dtype=6, ROOT=21
}```
What is xprof? Is this something I can run myself, or is it internal?
Yes, it's this capture_tpu_profile tool: https://cloud.google.com/tpu/docs/cloud-tpu-tools#install_cloud_tpu_profiler
Unfortunately, currently there are some TF APIs bundled into the profiler client so you'll need to run the profiler from a VM that has TF installed, and also be in the same network as the TPU (so on GCP same project, same network). We're working on breaking out the TF bits fyi.
@ibeltagy If you get the HLO graph, maybe @JackCaoG can run it with run_hlo_module
internally and get the xprof data.
From there to actually being able to view it externally ... would require TensorBoard instances to be brought up, unless things changed.
@davidel, with the HLO graph, I am assuming you mean the list of operations like
IR {
%0 = f32[768,768]{1,0} xla::device_data(), location=linear@functional.py:1676, device=TPU:0
....
}
which I posted in the previous comment, right? The graph I posted earlier is for a tiny part of the code, or is it more useful to get the whole thing?
That is the IR graph.
The HLO graph is the graph fed to XLA, obtained by lowering the IR into HLO.
There are a few ways to get it, but to be sure, it's best to run debug_run.py
with the --hlo flag for 5-6 steps and post the tarball.
Debugging information with the HLO graph: debug-hlo.tar.gz
Just curious: why am I getting 20 graphs even though I didn't get any aten:: ops and, AFAICT, all the forward/backward passes are the same?
We dump every XLA execution graph in the mode debug_run.py runs with, regardless of whether it is a new or an old one.
@JackCaoG try to get the xprof of the attached graph and have a look with the XLA team.
@ibeltagy I have collected the xprof data and opened an issue with the XLA team.
@ibeltagy Could you share your GPU setup and memory/runtime, and how it compares to your TPU setup after the convolution trick? FYI, we usually do chip-to-chip comparisons; each TPU chip has 2 TPU cores, so 4x V100 against a v3-8 is a fair comparison.
Looking at the profile, the most expensive parts are the xla::unselect operations, followed by the reshapes and transposes internal to einsum.
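For context, a minimal sketch (shapes illustrative, not taken from the repo) of how an einsum contraction like the ones in the IR above relates to a direct batched matmul; the IR shows einsum lowering into permutes, views, and a batched matmul:

```python
import torch

# Query/key chunks with illustrative shapes (batch*heads, chunks, chunk_len, head_dim).
q = torch.randn(12, 15, 512, 64)
k = torch.randn(12, 15, 512, 64)

# einsum contraction over the head dimension, as in sliding_chunks_matmul_qk-style code.
scores_einsum = torch.einsum('bcxd,bcyd->bcxy', q, k)

# The equivalent contraction written directly as a batched matmul.
scores_matmul = torch.matmul(q, k.transpose(-1, -2))

assert torch.allclose(scores_einsum, scores_matmul, atol=1e-5)
```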
Thanks, @blakehechtman. einsum is called multiple times in the code; do you know which one? Any suggestions on how to fix this? Is this expensive in terms of memory or compute time? If compute, any suggestions on how to improve memory as well?
The xla::unselect ops should come from in-place updates of views.
Actually, when I looked deeper, I think the einsums were fine; it was only the in-place updates of views that seemed to be very expensive, around 80% of the step time.
80%!! OK, I will try to reorganize the code to reduce the number of .view and .reshape operations. Would that be enough?
Any thoughts about memory optimizations as well?
Views are "OK" if you read from them. They generate xla::unselect
once you write into them.
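To illustrate the pattern (a minimal sketch, not code from the Longformer repo): writing into a sliced view is what lowers to xla::unselect on XLA devices, while building the same tensor out-of-place avoids it:

```python
import torch

# In-place write through a sliced view: on an XLA device this shows up
# as xla::unselect in the lowered graph.
x = torch.zeros(4, 513)
x[:, 256:] = 1.0

# The same result built out-of-place, with no write into a view.
y = torch.cat([torch.zeros(4, 256), torch.ones(4, 257)], dim=1)

assert torch.equal(x, y)
```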
@ibeltagy FYI, if you search for unselect in the HLO, the metadata should tell you where they are coming from.
%pad.350 = pred[12,256,513]{2,1,0} pad(pred[12,256,257]{2,1,0} %broadcast.346, pred[] %constant.347), padding=0_0x0_0x256_0, metadata={op_type="xla::unselect" op_name="aten::masked_fill.1" source_file="mask_invalid_locations@diagonaled_mm_tvm.py" source_line=320}
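The metadata above points at an in-place masked_fill through a view. A hedged sketch of that pattern and a functional rewrite (illustrative shapes from the padded scores tensor; this is not the repo's mask_invalid_locations implementation):

```python
import torch

scores = torch.randn(12, 256, 513)

# In-place masked_fill_ through a sliced view: this is the shape of code
# that lowers to xla::unselect on XLA.
a = scores.clone()
a[:, :, :256].masked_fill_(torch.ones(12, 256, 256, dtype=torch.bool),
                           float('-inf'))

# Out-of-place masked_fill on the whole tensor avoids writing into a view.
full_mask = torch.zeros(12, 256, 513, dtype=torch.bool)
full_mask[:, :, :256] = True
b = scores.masked_fill(full_mask, float('-inf'))

assert torch.equal(a, b)
```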
Hi @ibeltagy, an XLA-side optimization was merged recently that should improve the speed of unselect and of Longformer. The change is in today's nightly build. Please let me know if it helps when you get some time to try it out.
Great, thanks @JackCaoG for the update. Will give it a try and let you know.
❓ Questions and Help
I have managed to run a version of Longformer on pytorch-xla. It is memory efficient and reasonably fast. That said, it is still 2x slower than what I would expect, so if you have any insights on how to optimize the model implementation, that would be great. The model code is here. It is the same as RoBERTa, with the only difference being the self-attention operation. In particular, the two matrix multiplications here and here are replaced with the two functions
_sliding_chunks_matmul_qk
and _sliding_chunks_matmul_pv
. I am also attaching the debug output, which has a dump of the HLO graph: debug.tar.gz. The only difference between the GPU code and the TPU code is that the GPU code uses
as_strided
while the TPU code uses an iterative version of unfold
(the if statement here; check issue https://github.com/pytorch/xla/issues/2239 for more details.) Thank you.
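To make the as_strided-vs-unfold difference concrete, here is a hypothetical helper (chunk_overlapping is my name, not the repo's) sketching the sliding-chunks idea of splitting a sequence into overlapping chunks, built both ways:

```python
import torch

def chunk_overlapping(x, w):
    """Split a contiguous (seqlen, dim) tensor into overlapping chunks of
    size 2*w with stride w, two ways. Illustrative only."""
    seqlen, dim = x.shape
    n_chunks = seqlen // w - 1

    # as_strided version: zero-copy on GPU, but problematic on TPU.
    strided = x.as_strided(
        size=(n_chunks, 2 * w, dim),
        stride=(w * x.stride(0), x.stride(0), x.stride(1)))

    # Iterative version: builds the same chunks by slicing and stacking.
    stacked = torch.stack([x[i * w:i * w + 2 * w] for i in range(n_chunks)])
    return strided, stacked

x = torch.randn(8, 4)
a, b = chunk_overlapping(x, w=2)
assert torch.equal(a, b)
```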