pytorch / torchtitan

A native PyTorch Library for large model training
BSD 3-Clause "New" or "Revised" License

2D whole model compile fails at embedding layer #534

Open tianyu-l opened 3 weeks ago

tianyu-l commented 3 weeks ago

Specifically, it fails when handling the DTensor `_MaskPartial` placement of the sharded embedding.
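For readers unfamiliar with this placement, a vocabulary-sharded embedding lookup can be sketched in plain Python as below. This is only an illustration under assumptions, not the DTensor implementation: the function name `sharded_embedding_lookup`, the list-based table, and the even split of the vocabulary are all hypothetical, and the cross-rank reduction is modeled as a plain sum. The row-zeroing step is meant to mirror the `tensor[self.data, :] = 0.0` line in `apply_mask` that appears in the traceback in the error log.

```python
def sharded_embedding_lookup(token_ids, table, num_shards):
    """Simulate an embedding lookup whose table is sharded by vocabulary rows.

    Hypothetical helper for illustration only. Each simulated rank owns a
    contiguous slice of rows, masks token ids it does not own, looks up the
    rest locally, zeroes the masked rows, and the partial results are summed.
    """
    vocab_size = len(table)
    dim = len(table[0])
    assert vocab_size % num_shards == 0, "sketch assumes an even split"
    shard_size = vocab_size // num_shards

    partials = []
    for rank in range(num_shards):
        lo, hi = rank * shard_size, (rank + 1) * shard_size
        local_table = table[lo:hi]

        # True where the token id is NOT owned by this rank.
        mask = [not (lo <= t < hi) for t in token_ids]

        # Redirect masked ids to local row 0 so the lookup stays in range.
        local_ids = [0 if m else t - lo for t, m in zip(token_ids, mask)]
        out = [list(local_table[i]) for i in local_ids]

        # Zero the rows produced by masked ids (analogous to apply_mask).
        for row, m in zip(out, mask):
            if m:
                row[:] = [0.0] * dim
        partials.append(out)

    # Sum the partial results across ranks (stand-in for the collective).
    return [
        [sum(p[i][j] for p in partials) for j in range(dim)]
        for i in range(len(token_ids))
    ]
```

Because every token id is owned by exactly one rank, the sum of the partial outputs equals an ordinary lookup. In the error above, the mask is a real tensor while the masked output is a FakeTensor, which is why the zeroing step trips the FakeTensor validation.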

This only happens with whole-model compile. TransformerBlock-level compilation (the default), combined with separately compiling the embedding layer, does not have this issue.
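The two strategies can be sketched on a toy model as follows. This is a minimal sketch under assumptions, not the torchtitan code: `ToyTransformer`, `compile_per_block`, and `compile_whole_model` are hypothetical names, no tensor parallelism or FSDP is applied (so it will not reproduce the failure), and the `backend` parameter is exposed only so the sketch can be tried with `backend="eager"` on a machine without a compiler toolchain.

```python
import torch
import torch.nn as nn


class ToyTransformer(nn.Module):
    """Hypothetical stand-in for the model: an embedding plus a stack of blocks."""

    def __init__(self, vocab_size=128, dim=16, n_layers=2):
        super().__init__()
        self.tok_embeddings = nn.Embedding(vocab_size, dim)
        self.layers = nn.ModuleDict(
            {
                str(i): nn.Sequential(nn.Linear(dim, dim), nn.ReLU())
                for i in range(n_layers)
            }
        )

    def forward(self, tokens):
        h = self.tok_embeddings(tokens)
        for layer in self.layers.values():
            h = layer(h)
        return h


def compile_per_block(model, backend="inductor"):
    """Block-level strategy: compile each block, and the embedding, separately."""
    for name, block in model.layers.named_children():
        # Replace the block in place with its compiled wrapper.
        model.layers.register_module(name, torch.compile(block, backend=backend))
    model.tok_embeddings = torch.compile(model.tok_embeddings, backend=backend)
    return model


def compile_whole_model(model, backend="inductor"):
    """Whole-model strategy: a single torch.compile call around the entire model."""
    return torch.compile(model, backend=backend)
```

With the block-level variant, the top-level `forward` runs in eager mode and only the wrapped submodules are traced; with the whole-model variant, Dynamo also traces through the top-level `forward` and any hooks attached to the embedding, which is where the trace in the error log fails.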

error log

./run_llama_train.sh
+ NGPU=8
+ LOG_RANK=0
+ CONFIG_FILE=./train_configs/llama3_8b.toml
+ overrides=
+ '[' 0 -ne 0 ']'
+ torchrun --nproc_per_node=8 --rdzv_backend c10d --rdzv_endpoint=localhost:0 --local-ranks-filter 0 --role rank --tee 3 train.py --job.config_file ./train_configs/llama3_8b.toml
W0819 17:09:53.189000 1633288 torch/distributed/run.py:793]
W0819 17:09:53.189000 1633288 torch/distributed/run.py:793] *****************************************
W0819 17:09:53.189000 1633288 torch/distributed/run.py:793] Setting OMP_NUM_THREADS environment variable for each process to be 1 in default, to avoid your system being overloaded, please further tune the variable for optimal performance in your application as needed.
W0819 17:09:53.189000 1633288 torch/distributed/run.py:793] *****************************************
[rank0]:2024-08-19 17:09:55,187 - root - INFO - Starting job: Llama 3 8B training
[rank0]:2024-08-19 17:09:58,867 - root - WARNING - ENV[TORCH_NCCL_ASYNC_ERROR_HANDLING] = 1 will be overridden to 3 based on job config
[rank0]:2024-08-19 17:09:58,879 - root - INFO - GPU capacity: NVIDIA H100 (0) with 95.04GiB memory
[rank0]:2024-08-19 17:09:58,879 - root - INFO - Building 2-D device mesh with ['dp', 'tp'], [4, 2]
[rank0]:2024-08-19 17:09:58,906 - root - INFO - Building tiktoken tokenizer locally from ./torchtitan/datasets/tokenizer/original/tokenizer.model
[rank0]:2024-08-19 17:09:59,059 - root - INFO - TikTokenizer built: #words 128256, BOS ID 128000, EOS ID 128001
[rank0]:2024-08-19 17:09:59,059 - root - INFO - Preparing c4 dataset from allenai/c4
[rank0]:2024-08-19 17:10:06,989 - root - INFO - Building llama3 8B with ModelArgs(dim=4096, n_layers=32, n_heads=32, n_kv_heads=8, vocab_size=128256, multiple_of=1024, ffn_dim_multiplier=1.3, norm_eps=1e-05, rope_theta=500000, max_batch_size=32, max_seq_len=8192, depth_init=True, norm_type='rmsnorm')
[rank0]:2024-08-19 17:10:07,101 - root - INFO - Model llama3 8B size: 8,030,261,248 total parameters
[rank0]:2024-08-19 17:10:07,155 - root - INFO - Applied Tensor Parallelism to the model
[rank0]:2024-08-19 17:10:07,425 - root - WARNING - detected that the pytorch is built from source. Please make sure the PR (https://github.com/pytorch/pytorch/pull/130760) is included in pytorch for correct 2D/3D DCP usage.
[rank0]:2024-08-19 17:10:07,475 - root - INFO - Applied FSDP to the model
[rank0]:2024-08-19 17:10:07,835 - root - INFO - GPU memory usage for model: 3.78GiB(3.98%)
[rank0]:2024-08-19 17:10:07,836 - root - INFO - Training starts at step 1, with local batch size 1, global batch size 4, sequence length 8192, total steps 10 (warmup 2)
[rank0]:NCCL version 2.21.5+cuda12.0
[rank0]:[rank0]: Traceback (most recent call last):
[rank0]:[rank0]:   File "/data/users/lty/torchtitan/train.py", line 424, in
[rank0]:[rank0]:     main(config)
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/distributed/elastic/multiprocessing/errors/__init__.py", line 355, in wrapper
[rank0]:[rank0]:     return f(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/torchtitan/train.py", line 299, in main
[rank0]:[rank0]:     pred = model(input_ids)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/nn/modules/module.py", line 1788, in _call_impl
[rank0]:[rank0]:     result = forward_call(*args, **kwargs)
[rank0]:[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/eval_frame.py", line 509, in _fn
[rank0]:[rank0]:     return fn(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/nn/modules/module.py", line 1747, in _call_impl
[rank0]:[rank0]:     return forward_call(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/torchtitan/torchtitan/models/llama/model.py", line 436, in forward
[rank0]:[rank0]:     h = self.tok_embeddings(tokens) if self.tok_embeddings else tokens
[rank0]:[rank0]:         ^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/nn/modules/module.py", line 1736, in _wrapped_call_impl
[rank0]:[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/nn/modules/module.py", line 1801, in _call_impl
[rank0]:[rank0]:     hook_result = hook(self, args, result)
[rank0]:[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 1238, in __call__
[rank0]:[rank0]:     return self._torchdynamo_orig_callable(
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 1039, in __call__
[rank0]:[rank0]:     result = self._inner_convert(
[rank0]:[rank0]:              ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 514, in __call__
[rank0]:[rank0]:     return _compile(
[rank0]:[rank0]:            ^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 902, in _compile
[rank0]:[rank0]:     guarded_code = compile_inner(code, one_graph, hooks, transform)
[rank0]:[rank0]:                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 653, in compile_inner
[rank0]:[rank0]:     return _compile_inner(code, one_graph, hooks, transform)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_utils_internal.py", line 87, in wrapper_function
[rank0]:[rank0]:     return function(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 686, in _compile_inner
[rank0]:[rank0]:     out_code = transform_code_object(code, transform)
[rank0]:[rank0]:                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/bytecode_transformation.py", line 1322, in transform_code_object
[rank0]:[rank0]:     transformations(instructions, code_options)
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 208, in _fn
[rank0]:[rank0]:     return fn(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/convert_frame.py", line 622, in transform
[rank0]:[rank0]:     tracer.run()
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2731, in run
[rank0]:[rank0]:     super().run()
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 958, in run
[rank0]:[rank0]:     while self.step():
[rank0]:[rank0]:           ^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 870, in step
[rank0]:[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 558, in wrapper
[rank0]:[rank0]:     return inner_fn(self, inst)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2242, in CALL
[rank0]:[rank0]:     self._call(inst)
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2236, in _call
[rank0]:[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 805, in call_function
[rank0]:[rank0]:     self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
[rank0]:[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/variables/lazy.py", line 156, in realize_and_forward
[rank0]:[rank0]:     return getattr(self.realize(), name)(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/variables/functions.py", line 906, in call_function
[rank0]:[rank0]:     return self.func.call_function(tx, merged_args, merged_kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/variables/functions.py", line 322, in call_function
[rank0]:[rank0]:     return super().call_function(tx, args, kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/variables/functions.py", line 106, in call_function
[rank0]:[rank0]:     return tx.inline_user_function_return(self, [*self.self_args(), *args], kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 811, in inline_user_function_return
[rank0]:[rank0]:     return InliningInstructionTranslator.inline_call(self, fn, args, kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2946, in inline_call
[rank0]:[rank0]:     return cls.inline_call_(parent, func, args, kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 3062, in inline_call_
[rank0]:[rank0]:     tracer.run()
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 958, in run
[rank0]:[rank0]:     while self.step():
[rank0]:[rank0]:           ^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 870, in step
[rank0]:[rank0]:     self.dispatch_table[inst.opcode](self, inst)
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 558, in wrapper
[rank0]:[rank0]:     return inner_fn(self, inst)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2242, in CALL
[rank0]:[rank0]:     self._call(inst)
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 2236, in _call
[rank0]:[rank0]:     self.call_function(fn, args, kwargs)
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/symbolic_convert.py", line 805, in call_function
[rank0]:[rank0]:     self.push(fn.call_function(self, args, kwargs)) # type: ignore[arg-type]
[rank0]:[rank0]:               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/variables/misc.py", line 970, in call_function
[rank0]:[rank0]:     return self.obj.call_method(tx, self.name, args, kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/variables/tensor.py", line 527, in call_method
[rank0]:[rank0]:     result = handler_method(*args, **kwargs)
[rank0]:[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/variables/tensor.py", line 905, in method_redistribute
[rank0]:[rank0]:     return wrap_fx_proxy(
[rank0]:[rank0]:            ^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/variables/builder.py", line 1916, in wrap_fx_proxy
[rank0]:[rank0]:     return wrap_fx_proxy_cls(target_cls=TensorVariable, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/variables/builder.py", line 2003, in wrap_fx_proxy_cls
[rank0]:[rank0]:     example_value = get_fake_value(proxy.node, tx, allow_non_graph_fake=True)
[rank0]:[rank0]:                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/utils.py", line 2051, in get_fake_value
[rank0]:[rank0]:     raise TorchRuntimeError(str(e)).with_traceback(e.__traceback__) from None
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/utils.py", line 1983, in get_fake_value
[rank0]:[rank0]:     ret_val = wrap_fake_exception(
[rank0]:[rank0]:               ^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/utils.py", line 1468, in wrap_fake_exception
[rank0]:[rank0]:     return fn()
[rank0]:[rank0]:            ^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/utils.py", line 1984, in
[rank0]:[rank0]:     lambda: run_node(tx.output, node, args, kwargs, nnmodule)
[rank0]:[rank0]:             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/utils.py", line 2119, in run_node
[rank0]:[rank0]:     raise RuntimeError(make_error_message(e)).with_traceback(
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/utils.py", line 2101, in run_node
[rank0]:[rank0]:     return node.target(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_dynamo/variables/tensor.py", line 898, in redistribute_fn_with_prim_types
[rank0]:[rank0]:     return x.redistribute(*args_as_value, **kwargs_as_value)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/distributed/_tensor/api.py", line 541, in redistribute
[rank0]:[rank0]:     return Redistribute.apply(self, device_mesh, placements, async_op)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/autograd/function.py", line 575, in apply
[rank0]:[rank0]:     return super().apply(*args, **kwargs) # type: ignore[misc]
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/distributed/_tensor/_redistribute.py", line 295, in forward
[rank0]:[rank0]:     output = redistribute_local_tensor(
[rank0]:[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/distributed/_tensor/_redistribute.py", line 214, in redistribute_local_tensor
[rank0]:[rank0]:     new_local_tensor = partial_spec._reduce_shard_value(
[rank0]:[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/distributed/_tensor/ops/_embedding_ops.py", line 143, in _reduce_shard_value
[rank0]:[rank0]:     self.mask_buffer.apply_mask(tensor)
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/distributed/_tensor/ops/_embedding_ops.py", line 67, in apply_mask
[rank0]:[rank0]:     tensor[self.data, :] = 0.0
[rank0]:[rank0]:     ~~~~~~^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/utils/_stats.py", line 21, in wrapper
[rank0]:[rank0]:     return fn(*args, **kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 1251, in __torch_dispatch__
[rank0]:[rank0]:     return self.dispatch(func, types, args, kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 1705, in dispatch
[rank0]:[rank0]:     return self._cached_dispatch_impl(func, types, args, kwargs)
[rank0]:[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 1361, in _cached_dispatch_impl
[rank0]:[rank0]:     output = self._dispatch_impl(func, types, args, kwargs)
[rank0]:[rank0]:              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 1800, in _dispatch_impl
[rank0]:[rank0]:     (flat_args, flat_arg_fake_tensors) = self.validate_and_convert_non_fake_tensors(
[rank0]:[rank0]:                                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 2104, in validate_and_convert_non_fake_tensors
[rank0]:[rank0]:     validated_args = [validate(a) for a in flat_args]
[rank0]:[rank0]:                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 2104, in
[rank0]:[rank0]:     validated_args = [validate(a) for a in flat_args]
[rank0]:[rank0]:                       ^^^^^^^^^^^
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/_subclasses/fake_tensor.py", line 2092, in validate
[rank0]:[rank0]:     raise AssertionError(
[rank0]:[rank0]: torch._dynamo.exc.TorchRuntimeError: Failed running call_function .redistribute_fn_with_prim_types at 0x7fd6b53b8a40>(*(DTensor(local_tensor=FakeTensor(..., device='cuda:0', size=(1, 8192, 4096), dtype=torch.bfloat16), device_mesh=DeviceMesh('cuda', [0, 1], mesh_dim_names=('tp',)), placements=(_MaskPartial(offset_shape=(128256, 4096), offset_dim=0),)),), **{}):
[rank0]:[rank0]: Please convert all Tensors to FakeTensors first or instantiate FakeTensorMode with 'allow_non_fake_inputs'. Found in aten.index_put_.default(FakeTensor(..., device='cuda:0', size=(1, 8192, 4096), dtype=torch.bfloat16), [tensor([...], device='cuda:0', size=(1, 8192))], FakeTensor(..., size=(), dtype=torch.bfloat16))
[rank0]:
[rank0]:[rank0]: from user code:
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/distributed/_tensor/api.py", line 895, in
[rank0]:[rank0]:     lambda mod, inputs, outputs: output_fn(mod, outputs, device_mesh)
[rank0]:[rank0]:   File "/data/users/lty/pytorch/torch/distributed/tensor/parallel/style.py", line 251, in _prepare_output_fn
[rank0]:[rank0]:     outputs = outputs.redistribute(placements=output_layouts, async_op=True)
[rank0]:
[rank0]:[rank0]: Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

bdhirsh commented 3 weeks ago

hmmm I wonder if this is the same as what wanchao and I saw with this: https://github.com/pytorch/pytorch/issues/130028#issuecomment-2234077127

tianyu-l commented 3 weeks ago

It looks like Wanchao and Brian are already aware of this. Given how hard it is to tackle, let's stick with TransformerBlock-level compilation for now.

Also as of 08/19:

  • Whole-model compile does not seem to work well with SAC (selective activation checkpointing): throughput dropped quite a bit (5700 -> 5000 tok/s, roughly 12%, on the 8B model) compared with block-level compile.
  • Whole-model compile provides a ~1.6% throughput gain, but triggers recompilation warnings.