facebookresearch / fairseq

Facebook AI Research Sequence-to-Sequence Toolkit written in Python.

CUDA error when evaluating fairseq with `torch.compile`. #4975

Open 0x6b64 opened 1 year ago

0x6b64 commented 1 year ago

🐛 Bug

Hi, I'm training roberta_large with DDP, with the torch.compile API wrapping the model definition in the trainer. This API was introduced in PyTorch 2.0. The error doesn't happen without the torch.compile wrapper, so it is most likely a bug in the Triton codegen; but since it only happens with fairseq and not with other models such as Hugging Face GPT2 or BERT-large, it's worth auditing whether fairseq is doing something unusual.

This is the one-line change I've made to fairseq, in the trainer's model property: https://github.com/facebookresearch/fairseq/blob/main/fairseq/trainer.py#L253

I've also created a ticket in PyTorch issues: https://github.com/pytorch/pytorch/issues/93378

    @property
    def model(self):
        if self._wrapped_model is None:
            if self.use_distributed_wrapper:
                self._wrapped_model = models.DistributedFairseqModel(
                    self.cfg.distributed_training,
                    self._model,
                    process_group=self.data_parallel_process_group,
                    device=self.device,
                )
                self._wrapped_model = torch.compile(self._wrapped_model)  # <- added line
            else:
                self._wrapped_model = self._model
        return self._wrapped_model
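
For reference, the same wrapping pattern can be sketched outside fairseq. Below is a minimal, illustrative-only example (not part of the fairseq change): a toy nn.Linear is wrapped in DistributedDataParallel first and then in torch.compile, mirroring where the added line sits in the trainer property. `toy_main` is a hypothetical name and the script assumes a torchrun launch.

    # Minimal sketch of the wrapping pattern above, outside fairseq (illustrative only).
    # DDP wraps the module first, then torch.compile wraps the DDP module,
    # mirroring the placement of the added line in the trainer property.
    import torch
    import torch.distributed as dist
    from torch.nn.parallel import DistributedDataParallel as DDP

    def toy_main():
        dist.init_process_group("nccl")
        rank = dist.get_rank()
        torch.cuda.set_device(rank)

        model = torch.nn.Linear(512, 512).cuda(rank)
        ddp_model = DDP(model, device_ids=[rank])
        compiled = torch.compile(ddp_model)  # same placement as the added line

        x = torch.randn(8, 512, device=rank)
        compiled(x).sum().backward()

    if __name__ == "__main__":
        toy_main()  # launched with e.g.: torchrun --nproc_per_node=8 toy_repro.py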

To Reproduce

Steps to reproduce the behavior (always include the command you ran):

  1. Run the command below

I'm working with the wikitext dataset: https://huggingface.co/datasets/wikitext/tree/main

mpirun -np 8 \
fairseq-train wikitext-103 \
--adam-eps 1e-06 \
--arch roberta_large \
--attention-dropout 0.1 \
--clip-norm 0.0 \
--criterion masked_lm \
--distributed-backend nccl \
--distributed-no-spawn \
--dropout 0.1 \
--encoder-embed-dim 2048 \
--encoder-ffn-embed-dim 8192 \
--encoder-layers 24 \
--log-format simple \
--log-interval 10 \
--lr 0.0001 \
--lr-scheduler polynomial_decay \
--max-sentences 8 \
--max-update 500 \
--optimizer adam \
--sample-break-mode complete \
--skip-invalid-size-inputs-valid-test \
--task masked_lm \
--tokens-per-sample 512 \
--total-num-update 100 \
--update-freq 1 \
--weight-decay 0.01 \
--no-save \
--memory-efficient-fp16 \
--skip-invalid-size-inputs-valid-test \
--no-last-checkpoints 
  2. See error

Here is the stacktrace:

  File "/opt/conda/bin/fairseq-train", line 8, in <module>
    sys.exit(cli_main())
  File "/fsx/roberta/fairseq_master/fairseq/fairseq_cli/train.py", line 574, in cli_main
    distributed_utils.call_main(cfg, main)
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/distributed/utils.py", line 389, in call_main
    distributed_main(cfg.distributed_training.device_id, main, cfg, kwargs)
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/distributed/utils.py", line 362, in distributed_main
    main(cfg, **kwargs)
  File "/fsx/roberta/fairseq_master/fairseq/fairseq_cli/train.py", line 205, in main
    valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
  File "/opt/conda/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/fsx/roberta/fairseq_master/fairseq/fairseq_cli/train.py", line 331, in train
    log_output = trainer.train_step(samples)
  File "/opt/conda/lib/python3.9/contextlib.py", line 79, in inner
    return func(*args, **kwds)
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/trainer.py", line 869, in train_step
    raise e
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/trainer.py", line 844, in train_step
    loss, sample_size_i, logging_output = self.task.train_step(
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/tasks/fairseq_task.py", line 531, in train_step
    loss, sample_size, logging_output = criterion(model, sample)
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/criterions/masked_lm.py", line 58, in forward
    logits = model(**sample["net_input"], masked_tokens=masked_tokens)[0]
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/distributed/module_proxy_wrapper.py", line 56, in forward
    return self.module(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1157, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/parallel/distributed.py", line 1111, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])  # type: ignore[index]
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/models/roberta/model.py", line 255, in forward
    x, extra = self.encoder(src_tokens, features_only, return_all_hiddens, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/models/roberta/model.py", line 601, in forward
    x, extra = self.extract_features(
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/models/roberta/model.py", line 609, in extract_features
    encoder_out = self.sentence_encoder(
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/models/transformer/transformer_encoder.py", line 165, in forward
    return self.forward_scriptable(
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/models/transformer/transformer_encoder.py", line 173, in forward_scriptable
    def forward_scriptable(
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/models/transformer/transformer_encoder.py", line 212, in <graph break in forward_scriptable>
    x, encoder_embedding = self.forward_embedding(src_tokens, token_embeddings)
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/models/transformer/transformer_encoder.py", line 230, in <graph break in forward_scriptable>
    lr = layer(
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/modules/transformer_layer.py", line 197, in forward
    x, _ = self.self_attn(
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/fsx/roberta/fairseq_master/fairseq/fairseq/modules/multihead_attention.py", line 469, in forward
    def forward(
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 82, in forward
    return self.dynamo_ctx(self._orig_mod.forward)(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "<eval_with_key>.4171", line 16, in forward
  File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1488, in _call_impl
    return forward_call(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/optimizations/distributed.py", line 239, in forward
    x = self.submod(*args)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/eval_frame.py", line 211, in _fn
    return fn(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2489, in forward
    return compiled_fn(full_args)
  File "/opt/conda/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 996, in g
    return f(*args)
  File "/opt/conda/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 2058, in debug_compiled_function
    return compiled_function(*args)
  File "/opt/conda/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1922, in compiled_function
    all_outs = CompiledFunction.apply(*args_with_synthetic_bases)
  File "/opt/conda/lib/python3.9/site-packages/torch/autograd/function.py", line 508, in apply
    return super().apply(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1706, in forward
    fw_outs = call_func_with_args(
  File "/opt/conda/lib/python3.9/site-packages/torch/_functorch/aot_autograd.py", line 1021, in call_func_with_args
    out = normalize_as_list(f(args))
  File "/opt/conda/lib/python3.9/site-packages/torch/_inductor/compile_fx.py", line 220, in run
    return model(new_inputs)
  File "/tmp/torchinductor_ec2-user/qc/cqc4z4g2oi6k5dtsfp7vzop4z7ucxgcvqy64pfz464n5yjjemmhf.py", line 309, in call
    triton__2.run(buf3, primals_5, buf7, 16384, 499, grid=grid(16384, 499), stream=stream4)
  File "/opt/conda/lib/python3.9/site-packages/torch/_inductor/triton_ops/autotune.py", line 180, in run
    self.autotune_to_one_config(*args, grid=grid)
  File "/opt/conda/lib/python3.9/site-packages/torch/_dynamo/utils.py", line 160, in time_wrapper
    r = func(*args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_inductor/triton_ops/autotune.py", line 167, in autotune_to_one_config
    timings = {
  File "/opt/conda/lib/python3.9/site-packages/torch/_inductor/triton_ops/autotune.py", line 168, in <dictcomp>
    launcher: self.bench(launcher, *cloned_args, **kwargs)
  File "/opt/conda/lib/python3.9/site-packages/torch/_inductor/triton_ops/autotune.py", line 149, in bench
    return do_bench(kernel_call, rep=40, fast_flush=True)
  File "/opt/conda/lib/python3.9/site-packages/triton/testing.py", line 141, in do_bench
    torch.cuda.synchronize()
  File "/opt/conda/lib/python3.9/site-packages/torch/cuda/__init__.py", line 597, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
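
One way to narrow this down (a sketch, not something already tried in this issue): force synchronous CUDA error reporting as the error message suggests, and compare torch.compile backends. The "eager" and "aot_eager" backends bypass Inductor/Triton codegen entirely, so if only the default "inductor" backend hits the illegal access, that points at the generated Triton kernels rather than the traced graph. The toy module below is illustrative only.

    # Sketch: backend bisection plus synchronous error reporting.
    # CUDA_LAUNCH_BLOCKING must be set before any CUDA work so the failing
    # kernel is reported at the correct frame.
    import os
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

    import torch

    def check(backend):
        model = torch.nn.Linear(512, 512).cuda()
        compiled = torch.compile(model, backend=backend)
        x = torch.randn(8, 512, device="cuda")
        compiled(x).sum().backward()
        print(backend, "ok")

    # "eager" runs TorchDynamo only, "aot_eager" adds AOTAutograd, and
    # "inductor" (the default) adds Triton codegen.
    for backend in ("eager", "aot_eager", "inductor"):
        check(backend)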

Code sample

One-line code change to fairseq is shown above.

Expected behavior

Expect training to start.

Environment

Additional context

VarunGumma commented 1 year ago

When you train models, do you get very verbose statements like the following about the optimizations being performed? If so, how are you tackling it?

[2023-03-17 10:18:15,712] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 17
[2023-03-17 10:18:15,712] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper
[2023-03-17 10:18:15,887] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-03-17 10:18:15,953] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing forward (RETURN_VALUE)
[2023-03-17 10:18:15,954] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-03-17 10:18:16,279] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 18
[2023-03-17 10:18:16,292] torch._inductor.graph: [INFO] Using FallbackKernel: torch.ops.aten._scaled_dot_product_flash_attention.default
[2023-03-17 10:18:16,293] torch._inductor.utils: [INFO] using triton random, expect difference from eager
[2023-03-17 10:18:16,398] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 18
[2023-03-17 10:18:16,399] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper
[2023-03-17 10:18:16,580] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in forward>
[2023-03-17 10:18:16,620] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing forward
[2023-03-17 10:18:16,657] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing extract_features
[2023-03-17 10:18:16,694] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing extract_features_scriptable
[2023-03-17 10:18:16,722] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-03-17 10:18:16,817] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 19
[2023-03-17 10:18:16,820] torch._inductor.graph: [INFO] Using FallbackKernel: aten.cumsum
[2023-03-17 10:18:16,831] torch._inductor.utils: [INFO] using triton random, expect difference from eager
[2023-03-17 10:18:16,912] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 19
[2023-03-17 10:18:16,912] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper
[2023-03-17 10:18:17,079] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in extract_features_scriptable>
[2023-03-17 10:18:17,946] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing <graph break in extract_features_scriptable> (RETURN_VALUE)
[2023-03-17 10:18:17,960] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-03-17 10:18:22,078] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 20
[2023-03-17 10:18:22,138] torch._inductor.utils: [INFO] using triton random, expect difference from eager
[2023-03-17 10:18:24,039] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 20
[2023-03-17 10:18:24,040] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper
[2023-03-17 10:18:24,350] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo start tracing <graph break in forward>
[2023-03-17 10:18:24,353] torch._dynamo.symbolic_convert: [INFO] Step 1: torchdynamo done tracing <graph break in forward> (RETURN_VALUE)
[2023-03-17 10:18:24,354] torch._dynamo.output_graph: [INFO] Step 2: calling compiler function debug_wrapper
[2023-03-17 10:18:24,375] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling FORWARDS graph 21
[2023-03-17 10:18:24,416] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling FORWARDS graph 21
[2023-03-17 10:18:24,416] torch._dynamo.output_graph: [INFO] Step 2: done compiler function debug_wrapper
[2023-03-17 10:18:24,633] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling BACKWARDS graph 21
[2023-03-17 10:18:24,849] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling BACKWARDS graph 21
[2023-03-17 10:18:24,873] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling BACKWARDS graph 20
[2023-03-17 10:18:28,375] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling BACKWARDS graph 20
[2023-03-17 10:18:28,378] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling BACKWARDS graph 19
[2023-03-17 10:18:28,735] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling BACKWARDS graph 19
[2023-03-17 10:18:28,739] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling BACKWARDS graph 18
[2023-03-17 10:18:28,774] torch._inductor.graph: [INFO] Using FallbackKernel: torch.ops.aten._scaled_dot_product_flash_attention_backward.default
[2023-03-17 10:18:29,456] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling BACKWARDS graph 18
[2023-03-17 10:18:29,459] torch._inductor.compile_fx: [INFO] Step 3: torchinductor compiling BACKWARDS graph 17
[2023-03-17 10:18:29,494] torch._inductor.graph: [INFO] Using FallbackKernel: torch.ops.aten._scaled_dot_product_flash_attention_backward.default
[2023-03-17 10:18:29,993] torch._inductor.compile_fx: [INFO] Step 3: torchinductor done compiling BACKWARDS graph 17
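
A possible way to quiet those per-graph messages (a sketch, untested here): in the PyTorch 2.0-era API, TorchDynamo exposes a log-level config that also governs the Inductor compile messages shown above, so raising it to WARNING before training should hide the INFO lines. Newer PyTorch releases replace this knob with the TORCH_LOGS mechanism.

    # Sketch, assuming the PyTorch 2.0-era torch._dynamo.config.log_level knob.
    import logging
    import torch._dynamo

    torch._dynamo.config.log_level = logging.WARNING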