pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

"AssertionError: Invalid partition, found dependency cycles" error when using "openxla_eval" backend for `torch.compile()` with SPMD on v3-8 TPU (Kaggle) on HF LLAMA. Compiling with "openxla" backend works, but forward pass is significantly slower when increasing batch_size. #5834

Open defdet opened 11 months ago

defdet commented 11 months ago

🐛 Bug

After compiling the model with the "openxla_eval" backend, the forward pass fails with "AssertionError: Invalid partition, found dependency cycles". The forward pass without compilation works as usual, and compiling with the "openxla" backend also works, but with an issue: when increasing batch_size (for example, from 1 to 2), the forward pass for each batch becomes more than 2 times slower, which results in generation that is ~1.2 times slower overall. Generation should have become ~2 times faster instead.

To Reproduce

Unfortunately, I can't reproduce the issue in Colab because there isn't enough CPU RAM. Please take a look at this Kaggle notebook, which requires no further setup/login information. The notebook, the linked model, and the external Python script for SPMD are all set to public.

Steps to reproduce the behavior (a condensed sketch follows the list):

  1. Install the latest torch_xla, torch (CPU build), and transformers
  2. Prepare SPMD (shard the model and inputs)
  3. Compile the model with either torch.compile(backend='openxla_eval') or torch.compile(backend='openxla')
  4. Perform a forward pass under torch.no_grad()
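
For reference, a condensed sketch of what these steps look like in code; MODEL, the mesh layout, and the partition_module helper mirror the linked notebook and spmd_util.py, so treat this as an illustration rather than the exact notebook code:

import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.runtime as xr
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh
from transformers import AutoModelForCausalLM, AutoTokenizer
from spmd_util import partition_module  # external SPMD script from the notebook

xr.use_spmd()
device = xm.xla_device()

MODEL = "garage-bAInd/Stable-Platypus2-13B"
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16).to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL)
tokenizer.pad_token = tokenizer.eos_token

# Step 2: build a (dp, fsdp, mp) mesh over all 8 TPU cores and shard the model weights.
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.array(range(num_devices)), (1, num_devices, 1), ('dp', 'fsdp', 'mp'))
partition_module(model, mesh)

# Step 3: compile. 'openxla' runs (but slows down for batch_size > 1);
# 'openxla_eval' raises "Invalid partition, found dependency cycles".
dynamo_model = torch.compile(model, backend='openxla_eval')

# Step 4: shard the inputs and run a forward pass.
inputs = tokenizer(["test prompt"] * 2, return_tensors="pt", padding='max_length', max_length=512)
input_ids = inputs.input_ids.to(device)
attention_mask = inputs.attention_mask.to(device)
xs.mark_sharding(input_ids, mesh, (0, 1))
xs.mark_sharding(attention_mask, mesh, (0, 1))
with torch.no_grad():
    outputs = dynamo_model(input_ids, attention_mask)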

The whole error traceback will be provided in the next message for ease of reading. I cannot provide a metrics summary for 'openxla_eval' since there isn't any successful forward pass.

Expected behavior

If the 'openxla_eval' backend is supported, a successful forward pass after compilation with this backend is expected. If only 'openxla' is supported, an overall increase in speed is expected as batch_size increases (since there is definitely enough memory and 8 TPU cores are available).

Environment

Additional context

When we either look at the TQDM progress or measure the time of the forward pass, it seems like it starts fast and then slows down. I'm assuming it's due to recompilations, but there's nothing there except the forward pass.
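
(One way to check the recompilation theory is the torch_xla debug metrics report; a minimal sketch, assuming it is run after a few forward passes:)

import torch_xla.debug.metrics as met

# If 'CompileTime' keeps growing across steps with identical input shapes,
# the graph is being recompiled; the full report also lists op counters
# and device transfers.
print(met.metric_data('CompileTime'))
print(met.metrics_report())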

defdet commented 11 months ago

Whole error traceback:

AssertionError Traceback (most recent call last) Cell In[11], line 22 20 xs.mark_sharding(attention_mask, mesh, (0, 1)) 21 with torch.no_grad(): ---> 22 outputs = dynamo_model(input_ids, attention_mask)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs) 1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1517 else: -> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs) 1522 # If we don't have any hooks, we want to skip the rest of the logic in 1523 # this function, and just call forward. 1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1525 or _global_backward_pre_hooks or _global_backward_hooks 1526 or _global_forward_hooks or _global_forward_pre_hooks): -> 1527 return forward_call(*args, **kwargs) 1529 try: 1530 result = None

File /usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:328, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs) 326 dynamic_ctx.__enter__() 327 try: --> 328 return fn(*args, **kwargs) 329 finally: 330 set_eval_frame(prior)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs) 1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1517 else: -> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs) 1522 # If we don't have any hooks, we want to skip the rest of the logic in 1523 # this function, and just call forward. 1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1525 or _global_backward_pre_hooks or _global_backward_hooks 1526 or _global_forward_hooks or _global_forward_pre_hooks): -> 1527 return forward_call(*args, **kwargs) 1529 try: 1530 result = None

File /usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:1034, in LlamaForCausalLM.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict) 1031 return_dict = return_dict if return_dict is not None else self.config.use_return_dict 1033 # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn) -> 1034 outputs = self.model( 1035 input_ids=input_ids, 1036 attention_mask=attention_mask, 1037 position_ids=position_ids, 1038 past_key_values=past_key_values, 1039 inputs_embeds=inputs_embeds, 1040 use_cache=use_cache, 1041 output_attentions=output_attentions, 1042 output_hidden_states=output_hidden_states, 1043 return_dict=return_dict, 1044 ) 1046 hidden_states = outputs[0] 1047 if self.config.pretraining_tp > 1:

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1518, in Module._wrapped_call_impl(self, *args, **kwargs) 1516 return self._compiled_call_impl(*args, **kwargs) # type: ignore[misc] 1517 else: -> 1518 return self._call_impl(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch/nn/modules/module.py:1527, in Module._call_impl(self, *args, **kwargs) 1522 # If we don't have any hooks, we want to skip the rest of the logic in 1523 # this function, and just call forward. 1524 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks 1525 or _global_backward_pre_hooks or _global_backward_hooks 1526 or _global_forward_hooks or _global_forward_pre_hooks): -> 1527 return forward_call(*args, **kwargs) 1529 try: 1530 result = None

File /usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:886, in LlamaModel.forward(self, input_ids, attention_mask, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict) 883 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 884 else: 885 # 4d mask is passed through the layers --> 886 attention_mask = _prepare_4d_causal_attention_mask( 887 attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length 888 ) 890 # embed positions 891 hidden_states = inputs_embeds

File /usr/local/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py:886, in (___stack0, self, position_ids, past_key_values, inputs_embeds, use_cache, output_attentions, output_hidden_states, return_dict) 883 attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None 884 else: 885 # 4d mask is passed through the layers --> 886 attention_mask = _prepare_4d_causal_attention_mask( 887 attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length 888 ) 890 # embed positions 891 hidden_states = inputs_embeds

File /usr/local/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:328, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs) 326 dynamic_ctx.__enter__() 327 try: --> 328 return fn(*args, **kwargs) 329 finally: 330 set_eval_frame(prior)

File /usr/local/lib/python3.10/site-packages/torch/_dynamo/external_utils.py:17, in wrap_inline.<locals>.inner(*args, **kwargs) 15 @functools.wraps(fn) 16 def inner(*args, **kwargs): ---> 17 return fn(*args, **kwargs)

File /usr/local/lib/python3.10/site-packages/torch/_dynamo/backends/torchxla.py:49, in xla_backend_helper.<locals>.fwd(*args) 47 nonlocal compiled_graph 48 if compiled_graph is None: ---> 49 compiled_graph = bridge.extract_compiled_graph(model, args) 50 del model 51 return compiled_graph(*args)

File /usr/local/lib/python3.10/site-packages/torch_xla/core/dynamo_bridge.py:514, in extract_compiled_graph(xla_model, xla_args) 511 partition.nodes = topo_sort(partition.nodes) 513 # fuse partitions and exectue to collect inputs --> 514 partitioned_graph = partitioner.fuse_partitions(partitions) 515 InputCollector(partitioned_graph).run(*xla_args) 517 # compile each submodule and replace it with a call

File /usr/local/lib/python3.10/site-packages/torch/fx/passes/infra/partitioner.py:217, in CapabilityBasedPartitioner.fuse_partitions(self, partitions) 215 logger.debug("Fusing partitions...") 216 # fuse_by_partitions expects partitions in List[List[Node]]: [ [node0, node1], [node2, node3] ] --> 217 return fuse_by_partitions(self.graph_module, [list(partition.nodes) for partition in partitions])

File /usr/local/lib/python3.10/site-packages/torch/fx/passes/utils/fuser_utils.py:224, in fuse_by_partitions(gm, partitions) 221 sorted_nodes = topo_sort(nodes) 223 submodule_name = "fused_" + str(partition_id) --> 224 sub_gm, orig_inputs, orig_outputs = fuse_as_graphmodule(gm, sorted_nodes, submodule_name) 226 insert_subgm(gm, sub_gm, orig_inputs, orig_outputs) 228 erase_nodes(gm, sorted_nodes)

File /usr/local/lib/python3.10/site-packages/torch/fx/passes/utils/fuser_utils.py:123, in fuse_as_graphmodule(gm, nodes, module_name) 120 assert node in gm.graph.nodes, f"{node} is not found in graph module {gm._get_name()}" 122 # validates partition doesn't introduce dependency circles in the graph --> 123 assert validate_partition(nodes), "Invalid partition, found dependency cycles" 125 subgraph = Graph() 127 node_to_placeholder: Dict[Node, Node] = {} # mapping of nodes from old graph to placeholder in new graph

AssertionError: Invalid partition, found dependency cycles

yeounoh commented 10 months ago

Hi @defdet thanks for reporting, will take a look.

yeounoh commented 10 months ago

@defdet could you provide the full partition_module implementation so we can see where you are applying the annotations? Some ops & sharding were not supported in Dynamo in 2.1 (the version you are using) -- the Dynamo + SPMD integration is still experimental, and we will track and debug this in the upcoming releases. cc @JackCaoG @wonjoolee95

yeounoh commented 10 months ago

I just ran the full inference loop with openxla_eval and with just the input shardings (without partition_module):

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████| 3/3 [00:13<00:00,  4.34s/it]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1000/1000 [01:47<00:00,  9.32it/s]
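
(For concreteness, "just the input shardings" means something along these lines: the mesh is built as in the notebook, partition_module is never called, and only the batch tensors are annotated -- a sketch, not the exact code that was run:)

# Mesh over all devices, but no weight sharding (partition_module is skipped).
num_devices = xr.global_runtime_device_count()
mesh = Mesh(np.array(range(num_devices)), (1, num_devices, 1), ('dp', 'fsdp', 'mp'))

dynamo_model = torch.compile(model, backend='openxla_eval')

xs.mark_sharding(input_ids, mesh, (0, 1))        # shard only the model inputs
xs.mark_sharding(attention_mask, mesh, (0, 1))
with torch.no_grad():
    outputs = dynamo_model(input_ids, attention_mask)
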
defdet commented 10 months ago

@yeounoh Thanks for the explanation. Of course, here's the code for partition_module, taken from this repo:

import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig
)

GPTNEOX_RULES = (
    # embeddings
    ("gpt_neox\\.embed_in", ("mp", "fsdp")),
    # attention
    ("attention\\.query_key_value$", ("fsdp", "mp")),
    ("attention\\.dense$", ("mp", "fsdp")),
    # mlp
    ("mlp\\.dense_h_to_4h$", ("fsdp", "mp")),
    ("mlp\\.dense_4h_to_h$", ("mp", "fsdp")),
    # output
    ("embed_out", ("fsdp", "mp")),
)

T5_RULES = (
    # embeddings
    ("shared$", ("mp", "fsdp")),
    ("embed_tokens$", ("mp", "fsdp")),

    # attention
    ("q$", ("fsdp", "mp")),
    ("k$", ("fsdp", "mp")),
    ("v$", ("fsdp", "mp")),
    ("o$", ("mp", "fsdp")),

    # mlp
    ("w$", ("fsdp", "mp")),
    ("wi_0$", ("fsdp", "mp")),
    ("wi_1$", ("fsdp", "mp")),
    ("wo$", ("mp", "fsdp")),

    # seq2seq lm head
    ("lm_head", ("fsdp", "mp")),
)

LLAMA_RULES = (
    ("model\.embed_tokens", ("mp", "fsdp")),
    ("self_attn\.(q_proj|k_proj|v_proj)", ("fsdp", "mp")),
    ("self_attn\.o_proj", ("mp", "fsdp")),
    ("mlp\.gate_proj", ("fsdp", "mp")),
    ("mlp\.down_proj", ("mp", "fsdp")),
    ("mlp\.up_proj", ("fsdp", "mp")),
    ("lm_head", ("fsdp", "mp")),
)

ALL_RULES = [
    (GPTNeoXConfig, GPTNEOX_RULES),
    (T5Config, T5_RULES),
    (LlamaConfig, LLAMA_RULES),
]

def find_rule(model):
    for config, rule in ALL_RULES:
        if model.config.__class__ == config:
            return rule
    raise Exception("unsupported model to partitioning")

strkey2id = {
    "dp": 0,
    "fsdp": 1,
    "mp": 2
}

def partition_module(model, mesh, device=xm.xla_device(), verbose=False):
    partition_specs = find_rule(model)
    rule = [(k, tuple([strkey2id[x] for x in v])) for k, v in partition_specs]

    # print(rule)

    for name, module in model.named_modules():
        module.to(device)
        # print(name, module.__class__.__name__)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            for rule_pattern, spec in rule:
                if re.findall(rule_pattern, name):
                    if verbose:
                        print("match", rule_pattern, name)

                    xs.mark_sharding(module.weight, mesh, spec)
                    break

def partition_module_dp(model, mesh, device=xm.xla_device(), verbose=False):
    spec = (1, 2)

    for name, module in model.named_modules():
        module.to(device)
        if isinstance(module, (nn.Embedding, nn.Linear)):
            xs.mark_sharding(module.weight, mesh, spec)

You said that some ops and sharding aren't supported in Dynamo in 2.1. Does that mean there's a chance they're supported in the nightly version?

defdet commented 10 months ago

If the model works with just the input shardings, I guess that must mean the issue is with sharding the model. It definitely works without Dynamo, though.

yeounoh commented 10 months ago

> If the model works with just the input shardings, I guess that must mean the issue is with sharding the model. It definitely works without Dynamo, though.

Not sure, there may be some ops support still missing with the latest.

yeounoh commented 10 months ago

Ok, the issue is the CPU fallback for einsum. If you replace it with matmul, e.g., in modeling_llama.py in the transformers library:

# freqs = torch.einsum("i,j->ij", t, self.inv_freq)
freqs = torch.matmul(t.unsqueeze(1), self.inv_freq.unsqueeze(0))

then your partition_module would work with both openxla and openxla_eval backends. Also, I don't see much slowdown with larger batch sizes. A larger batch size (>>2) will result in longer per-step latency, but overall your generation is faster since you are processing more data at a time.

TL;DR, we don't have full support for einsum in Dynamo yet.
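
(A quick standalone check that the two expressions are numerically equivalent, using made-up tensor sizes:)

import torch

t = torch.arange(8, dtype=torch.float32)         # positions
inv_freq = torch.rand(4)                         # inverse frequencies

a = torch.einsum("i,j->ij", t, inv_freq)                      # original outer product
b = torch.matmul(t.unsqueeze(1), inv_freq.unsqueeze(0))       # matmul replacement
assert torch.allclose(a, b)                                   # same (8, 4) result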

I used the nightly. Again, the SPMD + Dynamo integration is experimental and not Alpha/Beta.

@defdet let me know if you could verify, then we can close.

defdet commented 10 months ago

I'm trying to verify but am still running into some issues. Using 2.1 (not the nightly version) results in the same error. However, when using the nightly version of both torch_xla and torch, the kernel dies when trying to execute num_devices = xr.global_runtime_device_count(). I'm assuming I didn't install some necessary package. Can you please provide commands for installing the nightly version correctly? The commands I used are:

!pip uninstall torch torchvision torchaudio torchtext transformers torchdata -y
!pip3 install --pre torch torchvision torchaudio --index-url https://download.pytorch.org/whl/nightly/cpu
!pip install transformers git+https://github.com/defdet/transformers-fixed-llama  # Replaced every instance of freqs = torch.einsum("i,j->ij", t, self.inv_freq) with freqs = torch.matmul(t.unsqueeze(1), self.inv_freq.unsqueeze(0))
!pip install torch_xla https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl -q
!pip3 install datasets sentencepiece einops -q

yeounoh commented 10 months ago

Hi @defdet ,

You could try this to install the nightly; the release wheels are public:

!pip install --user \
    https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp310-cp310-linux_x86_64.whl \
   'torch_xla[tpuvm] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl'

I also made some minor changes to your code, in addition to the HF modeling_llama.py:

import os
import pandas as pd
import numpy as np
import datasets
import torch.optim as optim
import torch_xla.debug.profiler as xp
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp # We also import mp modules if we wanna use that for some reason
import torch_xla.distributed.parallel_loader as pl
import torch_xla.test.test_utils as test_utils
import torch
import torch.nn as nn
import re
import torch_xla.experimental.xla_sharding as xs
import torch_xla.core.xla_model as xm
from transformers import (
    GPTNeoXConfig, T5Config, LlamaConfig, AutoTokenizer, AutoModelForCausalLM, DataCollatorWithPadding, AutoConfig
)

from transformers import logging as hf_logging
import torch.nn.functional as F
import torch_xla.runtime as xr

xr.use_spmd()

import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharded_tensor import XLAShardedTensor
from torch_xla.experimental.xla_sharding import Mesh

from spmd_util import partition_module
from tqdm import tqdm

os.environ['USE_TORCH']="1"
os.environ['XLA_HLO_DEBUG'] = '1'
os.environ["PJRT_DEVICE"] = "TPU"
#os.environ.pop('TPU_PROCESS_ADDRESSES')
#os.environ.pop('CLOUD_TPU_TASK_ID')

device = xm.xla_device()

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch, random
MODEL = "garage-bAInd/Stable-Platypus2-13B"
model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True)
model = model.to(device)
tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)

TOPICS = ["Car-free cities",
"Does the electoral college work?",
"Exploring Venus",
"The Face on Mars",
"Facial action coding system",
"Seeking multiple opinions",
"Phones and driving"]

config = AutoConfig.from_pretrained(MODEL)
num_devices = xr.global_runtime_device_count()
mesh_shape = (1, num_devices, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('dp', 'fsdp', 'mp'))

# partition the module
partition_module(model, mesh)

model.eval()

torch._dynamo.reset()
dynamo_model = torch.compile(model, backend='openxla_eval')  # 'openxla' works, but with issue
tokenizer.pad_token = tokenizer.eos_token
batch_size = 2 # Increase in batch_size leads to slower forward pass with openxla backend compilation
max_length = 512
for _ in tqdm(range(1000)): # TQDM shows slowdown with openxla backend. No slowdown noticed without compilation
    topic = TOPICS[random.randint(0, 6)]
    prompt = f'''
### Instruction:
Write an essay based on the topic provided as if you were a student. Your essay needs to be unique and convincing and not very long. Output nothing but the essay.

{topic}

### Response:

    '''
    inputs = tokenizer([prompt for _ in range(batch_size)], return_tensors="pt", padding='max_length', max_length=max_length)
    input_ids, attention_mask = inputs.input_ids.to(device), inputs.attention_mask.to(device)
    xs.mark_sharding(input_ids, mesh, (0, 1))
    xs.mark_sharding(attention_mask, mesh, (0, 1))
    with torch.no_grad():
        outputs = dynamo_model(input_ids, attention_mask) # Fails here. We can use uncompiled `model`, but it's slower.
defdet commented 10 months ago

Hello @yeounoh,

Thanks for the commands for installing the correct packages. There are still some issues though: it seems that without the correct versions of torchtext, torchaudio, and torchdata (or at least some of them), import torch, torch_xla fails. It looks like I need versions of them that match torch 2.2.0, but I can't find how to install those. I can, however, use the nightly CPU version of torch, and the forward pass after openxla compilation is successful, but the batch size issue still remains (you can see that in the updated version of the notebook I provided earlier). You mentioned that with the correct torch versions this issue also goes away. Could you please provide the command for installing torchdata and torchvision? Since the underlying error has been solved, I'm ready to close the issue.

yeounoh commented 10 months ago

Hi @defdet, maybe try pip3 install --pre torchvision torchdata and see if that works? Otherwise, I would build from source to install the nightly (make sure that you are on the commit from the same date as your torch nightly).

By the way, import torch, torch_xla shouldn't depend on torchtext, torchaudio, torchdata. Maybe try to uninstall all those if not needed.

EDIT: also, I don't think you have ported some of the changes I made in the above script (like model = model.to(device)), or maybe I am not seeing the latest notebook?

defdet commented 10 months ago

@yeounoh, I indeed forgot to update the code; that's done now. I still haven't been able to solve the import torch failure though. pip3 install --pre torchvision torchdata installs torch==2.1.0, which doesn't work. I tried uninstalling torchtext, torchaudio, and torchdata, but then I get this error at from torch._C import *: ImportError: libopenblas.so.0: cannot open shared object file: No such file or directory (full traceback in the next message and as an updated notebook version on Kaggle). I feel like the error lies within the dependencies that are installed with torchtext, torchaudio, and torchdata, but I'm really confused (also, if you don't uninstall them, pip says they conflict with the 2.2.0 version of torch). I also tried running the nightly CPU version after fully applying your modifications, but to no avail.

defdet commented 10 months ago

Full traceback:

ImportError Traceback (most recent call last) Cell In[4], line 1 ----> 1 import torch

File ~/.local/lib/python3.10/site-packages/torch/__init__.py:237 235 if USE_GLOBAL_DEPS: 236 _load_global_deps() --> 237 from torch._C import * # noqa: F403 239 # Appease the type checker; ordinarily this binding is inserted by the 240 # torch._C module initialization code in C 241 if TYPE_CHECKING:

ImportError: libopenblas.so.0: cannot open shared object file: No such file or directory

yeounoh commented 10 months ago

I see, @defdet this looks like another dependency issue on the system... if you had a TPUVM, then the easy way to pip install nightly would be (from my earlier comment):

!pip install --user \
    https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch-nightly-cp310-cp310-linux_x86_64.whl \
   'torch_xla[tpuvm] @ https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-nightly-cp310-cp310-linux_x86_64.whl'

otherwise, since this is not working in your Kaggle environment, it seems like you need some additional support from the Kaggle team to resolve the dependency issue. For instance, I'm not sure if you have permission to override system dependencies like OpenBLAS?

Hm, as an alternative, we are releasing 2.2 soon, so if Kaggle is planning to support the official release, then maybe you could wait for it?

defdet commented 10 months ago

@yeounoh, the hardware provided by Kaggle is actually a TPU VM v3-8 - maybe the specific release you provided is for v4? I'm also going to look for help on Kaggle's side, as you advised. But waiting for the official release seems reasonable - I don't see why Kaggle wouldn't support it, since it supports 2.1. Can you give a rough estimate of when you're releasing 2.2?

yeounoh commented 10 months ago

Hi @defdet, we don't release separate versions for different TPU or HW types. Both r2.1 and the nightly should work on TPU v3 and TPU v4. The next release is scheduled for sometime in January; we will release alongside PyTorch 2.2. Hope this helps.

defdet commented 8 months ago

@yeounoh, sorry for such a late response, but I've managed to install torch 2.3 in the Kaggle env and can confirm there is very little, if any, slowdown with a larger batch_size! openxla_eval doesn't work again, since the code for modeling_llama has been updated once more, but openxla works just fine (I suspect another CPU fallback - not an issue with torch_xla). I can finally close the issue. Thanks a lot for your help! edit: notebook that works
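
(A way to check the CPU-fallback suspicion: ops without an XLA lowering show up as aten:: counters in the torch_xla debug metrics. A minimal sketch:)

import torch_xla.debug.metrics as met

# Counters named 'aten::<op>' indicate ops that fell back to CPU execution.
fallbacks = [name for name in met.counter_names() if name.startswith('aten::')]
print(fallbacks)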

defdet commented 8 months ago

So sorry, I was actually wrong. It's pretty interesting, though. I've replaced the code for modeling_llama with the same code as the 4.35.2 release (with the changes applied), and here are some measurements:

  1. The uncompiled model takes ~60 ms per forward pass for any batch size. I think it's faster than before.
  2. The compiled model with either backend takes ~30 ms per forward pass for batch size 1.
  3. For batch size 2 (and larger batch sizes as we scale), the compiled model with either backend takes ~30 ms for about the first 30 forward passes, then for some reason slows down to ~70 ms (about 120 ms for batch_size=3) and stays at that speed.

The notebook with the replaced code has been updated to show the repro. From a purely practical standpoint, for any batch size larger than 1 the uncompiled model runs much faster. I'm really confused by these results. Can you please explain what could be the reason for this?
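
(A note on the timings above: XLA executes asynchronously, so per-step numbers are only meaningful if the device is synchronized before the clock stops. A minimal sketch of how a single forward pass could be timed, reusing dynamo_model and the inputs from the notebook:)

import time
import torch
import torch_xla.core.xla_model as xm

start = time.perf_counter()
with torch.no_grad():
    outputs = dynamo_model(input_ids, attention_mask)
xm.wait_device_ops()                     # block until all queued XLA work has finished
print(f"forward: {(time.perf_counter() - start) * 1000:.1f} ms")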