pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration

Observing negative numbers in PyTorch profiling #101861

Closed: jianan-gu closed this issue 3 months ago

jianan-gu commented 1 year ago

🐛 Describe the bug

We found negative numbers in the PyTorch profiler output, which makes it hard for users to get reliable per-operator profiling results:

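```python
import torch
from transformers import LlamaForCausalLM, LlamaTokenizer

model = LlamaForCausalLM.from_pretrained("decapoda-research/llama-7b-hf", low_cpu_mem_usage=True, torch_dtype=torch.bfloat16)
tokenizer = LlamaTokenizer.from_pretrained("decapoda-research/llama-7b-hf")
model = model.eval()
prompt = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"
input_size = tokenizer(prompt, return_tensors="pt").input_ids.size(dim=1)

num_iter = 3
num_warmup = 2

def trace_handler(prof):
    # Print per-operator statistics, sorted by self CPU time.
    print(prof.key_averages().table(
        sort_by="self_cpu_time_total", row_limit=-1))

# NOTE: the original profiler arguments were lost in this copy; the arguments
# below are an assumed reconstruction (warm up for num_warmup steps, then
# record a single step and hand the result to trace_handler).
with torch.profiler.profile(
        activities=[torch.profiler.ProfilerActivity.CPU],
        schedule=torch.profiler.schedule(wait=0, warmup=num_warmup, active=1),
        on_trace_ready=trace_handler,
        ) as prof:
    with torch.cpu.amp.autocast(enabled=True, dtype=torch.bfloat16):
        for i in range(num_iter):
            input_ids = tokenizer(prompt, return_tensors="pt").input_ids
            output = model.generate(input_ids, max_new_tokens=32, do_sample=False, temperature=0.9, num_beams=4)
            gen_text = tokenizer.batch_decode(output, skip_special_tokens=True)
            print(gen_text, flush=True)
            prof.step()  # advance the profiler schedule once per iteration
```

Profiler output (sorted by self CPU time):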
                     Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
                 aten::mm        72.77%        2.775s        72.77%        2.775s     385.375us          7200  
             aten::linear        14.62%     557.318ms        96.82%        3.692s     383.617us          9624  
       aten::index_select         2.71%     103.334ms         2.73%     103.926ms      49.916us          2082  
                aten::cat         2.36%      90.020ms         2.56%      97.640ms      20.144us          4847  
                aten::bmm         2.29%      87.493ms         2.29%      87.493ms      42.721us          2048  
                aten::mul         1.76%      67.101ms         1.96%      74.875ms       8.068us          9280  
             aten::matmul         1.46%      55.748ms        77.53%        2.956s     292.904us         10093  
                aten::add         1.18%      44.884ms         1.32%      50.220ms       6.822us          7362  
              aten::copy_         1.07%      40.672ms         1.07%      40.672ms       2.349us         17317  
               aten::silu         0.86%      32.886ms         0.86%      32.886ms      32.115us          1024  
           aten::_to_copy         0.73%      27.821ms         1.63%      62.339ms       3.739us         16674  
                 aten::to         0.38%      14.327ms         1.82%      69.255ms       3.472us         19945  
               aten::topk         0.37%      14.210ms         0.37%      14.210ms     444.062us            32  
                aten::pow         0.35%      13.345ms         0.35%      13.360ms       6.423us          2080  
              aten::index         0.31%      12.008ms         0.39%      14.734ms       7.084us          2080  
          aten::transpose         0.30%      11.467ms         0.31%      11.720ms       0.951us         12320  
           aten::_softmax         0.27%      10.438ms         0.27%      10.438ms      10.193us          1024  
              aten::slice         0.25%       9.689ms         0.26%       9.746ms       0.521us         18722  
                aten::sum         0.25%       9.355ms         0.25%       9.587ms       4.605us          2082  
            aten::reshape         0.23%       8.724ms         0.34%      12.862ms       0.891us         14434  
                aten::neg         0.22%       8.420ms         0.22%       8.420ms       4.111us          2048  
                  aten::t         0.21%       8.072ms         0.36%      13.617ms       1.891us          7200  
               aten::mean         0.18%       6.981ms         0.69%      26.352ms      12.669us          2080  
               aten::div_         0.18%       6.962ms         0.28%      10.676ms       4.886us          2185  
                aten::div         0.17%       6.423ms         0.26%       9.769ms       9.251us          1056  
             aten::narrow         0.16%       6.030ms         0.17%       6.536ms       0.798us          8192  
            aten::maximum         0.14%       5.213ms         0.18%       6.820ms       6.660us          1024  
             aten::expand         0.08%       3.104ms         0.08%       3.121ms       0.750us          4163  
            aten::squeeze         0.07%       2.580ms         0.07%       2.590ms       0.632us          4096  
            aten::softmax         0.06%       2.424ms         0.45%      17.091ms      16.690us          1024  
              aten::rsqrt         0.06%       2.204ms         0.06%       2.204ms       1.060us          2080  
                aten::max         0.05%       1.842ms         0.21%       8.023ms       7.583us          1058  
       aten::_unsafe_view         0.05%       1.812ms         0.05%       1.812ms       0.193us          9376  
       aten::_log_softmax         0.04%       1.636ms         0.04%       1.657ms      51.781us            32  
          aten::unsqueeze         0.04%       1.384ms         0.04%       1.384ms       0.627us          2209  
            aten::detach_         0.04%       1.344ms         0.04%       1.377ms       1.339us          1028  
     aten::_reshape_alias         0.03%       1.046ms         0.03%       1.046ms       0.348us          3007  
             aten::select         0.02%     918.000us         0.02%     922.000us       0.154us          6003  
             aten::unbind         0.02%     766.000us         0.02%     820.000us       8.454us            97  
               aten::view         0.02%     714.000us         0.02%     714.000us       0.049us         14597  
              aten::clone         0.02%     700.000us         0.08%       2.887ms      15.036us           192  
         aten::as_strided         0.01%     521.000us         0.01%     521.000us       0.011us         49595  
        aten::log_softmax         0.01%     458.000us         0.04%       1.705ms      53.281us            32  
                aten::sub         0.01%     306.000us         0.01%     367.000us       5.734us            64  
      aten::empty_strided         0.01%     288.000us         0.01%     288.000us       0.017us         16674  
              aten::fill_         0.01%     236.000us         0.01%     236.000us       0.110us          2154  
                aten::all         0.01%     235.000us         0.01%     335.000us      10.469us            32  
          aten::remainder         0.01%     227.000us         0.01%     227.000us       7.094us            32  
             aten::cumsum         0.01%     218.000us         0.01%     218.000us       6.412us            34  
          aten::embedding         0.01%     212.000us         0.03%     965.000us      30.156us            32  
              aten::zeros         0.01%     207.000us         0.01%     208.000us       2.122us            98  
       aten::masked_fill_         0.01%     198.000us         0.01%     198.000us       3.046us            65  
                 aten::eq         0.00%     166.000us         0.00%     166.000us       4.882us            34  
         aten::empty_like         0.00%     137.000us         0.00%     164.000us       0.854us           192  
           aten::new_ones         0.00%     119.000us         0.01%     192.000us       6.000us            32  
        aten::masked_fill         0.00%     115.000us         0.01%     223.000us       6.969us            32  
               aten::item         0.00%     102.000us         0.00%     169.000us       0.626us           270  
               aten::rsub         0.00%     101.000us         0.01%     291.000us       9.094us            32  
              aten::empty         0.00%      93.000us         0.00%      93.000us       0.027us          3476  
         aten::is_nonzero         0.00%      90.000us         0.01%     256.000us       2.586us            99  
aten::_local_scalar_dense         0.00%      76.000us         0.00%      76.000us       0.281us           270  
         aten::contiguous         0.00%      70.000us         0.01%     236.000us       7.375us            32  
            aten::view_as         0.00%      63.000us         0.00%      70.000us       2.188us            32  
                  detach_         0.00%      51.000us         0.00%      51.000us       0.050us          1028  
          aten::expand_as         0.00%      46.000us         0.00%     137.000us       4.281us            32  
          aten::new_empty         0.00%      33.000us         0.00%      40.000us       1.250us            32  
  aten::repeat_interleave         0.00%      28.000us         0.00%      66.000us      16.500us             4  
                aten::any         0.00%      11.000us         0.00%      17.000us      17.000us             1  
                 aten::lt         0.00%      11.000us         0.00%      11.000us       5.500us             2  
        aten::result_type         0.00%      10.000us         0.00%      10.000us       0.005us          2080  
             aten::arange         0.00%       9.000us         0.00%      15.000us       7.500us             2  
       aten::resolve_conj         0.00%       8.000us         0.00%       8.000us       0.000us         18497  
                aten::min         0.00%       7.000us         0.00%       7.000us       7.000us             1  
                   detach         0.00%       5.000us         0.00%       5.000us       5.000us             1  
                 aten::gt         0.00%       4.000us         0.00%       4.000us       4.000us             1  
         aten::lift_fresh         0.00%       3.000us         0.00%       3.000us       0.003us          1063  
               aten::ones         0.00%       3.000us         0.00%       4.000us       4.000us             1  
               aten::full         0.00%       3.000us         0.00%       3.000us       3.000us             1  
             aten::detach         0.00%       2.000us         0.00%       7.000us       7.000us             1  
              aten::zero_         0.00%       0.000us         0.00%       0.000us       0.000us            98  
            aten::resize_         0.00%       0.000us         0.00%       0.000us       0.000us             1  
        aten::resolve_neg         0.00%       0.000us         0.00%       0.000us       0.000us             1  
            ProfilerStep        -6.49%  -247365.000us       100.00%        3.813s        3.813s             1

Self CPU time total: 3.813s


Collecting environment information...
PyTorch version: 2.1.0.dev20230518+cpu
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: CentOS Stream 8 (x86_64)
GCC version: (GCC) 11.2.1 20210728 (Red Hat 11.2.1-1)
Clang version: 14.0.0 (Red Hat 14.0.0-1.module_el8.7.0+1142+5343df54)
CMake version: version 3.22.1
Libc version: glibc-2.28

Python version: 3.8.16 (default, Mar 2 2023, 03:21:46) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-4.18.0-365.el8.x86_64-x86_64-with-glibc2.17
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

Versions of relevant libraries:
[pip3] numpy==1.24.3
[pip3] torch==2.1.0.dev20230518+cpu
[conda] blas 1.0 mkl
[conda] mkl 2023.1.0 h6d00ec8_46342
[conda] mkl-include 2023.1.0 pypi_0 pypi
[conda] mkl-service 2.4.0 py38h5eee18b_1
[conda] mkl-static 2023.1.0 pypi_0 pypi
[conda] mkl_fft 1.3.6 py38h417a72b_1
[conda] mkl_random 1.2.2 py38h417a72b_1
[conda] numpy 1.24.3 py38hf6e8229_1
[conda] numpy-base 1.24.3 py38h060ed82_1
[conda] torch 2.1.0.dev20230518+cpu pypi_0 pypi

cc @robieta @chaekit @aaronenyeshi @ngimel @nbcsm @guotuofeng @guyang3532 @gaoteng-git @tiffzhaofb @dzhulgakov @davidberard98

jingxu10 commented 1 year ago

I'll look into it.

gujinghui commented 3 months ago
                 Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
        ProfilerStep        -6.49%  -247365.000us       100.00%        3.813s        3.813s             1

This issue should only occur on very fast CPUs, so it is hard to reproduce on ordinary CPUs. Because it is CPU-related, however, it can occur on any accelerator + CPU system running PyTorch 2.3 or earlier.

The root cause appears to be in PyTorch's profiler post-processing, which truncates timestamps instead of rounding them when converting from ns to us.
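A minimal illustration of the two conversion schemes, assuming non-negative integer nanosecond timestamps; `ns_to_us_trunc` and `ns_to_us_round` are hypothetical helper names, not PyTorch functions. Truncation always rounds down, so each conversion can drop up to 999 ns, while round-half-up is off by at most 500 ns in either direction.

```python
def ns_to_us_trunc(ns: int) -> int:
    # Truncating conversion: drops the sub-microsecond remainder (0..999 ns).
    return ns // 1000


def ns_to_us_round(ns: int) -> int:
    # Round-half-up conversion: error is at most 500 ns in either direction.
    return (ns + 500) // 1000


# Raw values taken from the worked example in this comment.
for ns in (315877, 170887, 486499):
    print(ns, ns_to_us_trunc(ns), ns_to_us_round(ns))
```

The key point for this issue is that the truncation error is always in the same direction, and it is applied more than once when an end timestamp is rebuilt from a start and a duration.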

Consider three operators: grandparent_op, parent_op, and child_op. The expected timeline is grandparent_op_start --> parent_op_start --> child_op_start --> child_op_end --> parent_op_end --> grandparent_op_end, and the self time of grandparent_op (its "Self CPU" value) is grandparent_op_self = grandparent_op_end - grandparent_op_start - parent_op_duration.

In some cases, PyTorch's post-processing gets the relationship between these three ops wrong: because child_op_end appears later than parent_op_end, it treats child_op as a child of grandparent_op instead of parent_op. The timeline then becomes grandparent_op_start --> parent_op_start --> child_op_start --> parent_op_end --> child_op_end --> grandparent_op_end. As a result, grandparent_op_self = grandparent_op_end - grandparent_op_start - parent_op_duration - child_op_duration. Since child_op almost entirely overlaps parent_op, that overlap is subtracted twice, and the self time becomes negative whenever parent_op_duration + child_op_duration exceeds grandparent_op's total duration.
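The effect can be reproduced with a simplified, hypothetical tree builder (this is a sketch, not PyTorch's actual post-processing code): events are visited in start order, and an open event is kept as a candidate parent only if it fully contains the next event. The grandparent_op timestamps (310 us and 490 us) are made up for illustration; parent_op and child_op use the derived microsecond values from the worked example in this comment.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Op:
    # Hypothetical event record: a name plus start/end timestamps in us.
    name: str
    start: int
    end: int
    children: List["Op"] = field(default_factory=list)

    @property
    def duration(self) -> int:
        return self.end - self.start

    @property
    def self_time(self) -> int:
        # Self time = own duration minus the durations of direct children.
        return self.duration - sum(c.duration for c in self.children)


def build_tree(ops: List[Op]) -> List[Op]:
    """Attach each op to the innermost open op that fully contains it."""
    roots: List[Op] = []
    stack: List[Op] = []
    # Visit ops by start time; for equal starts, the longer op comes first.
    for op in sorted(ops, key=lambda o: (o.start, -o.end)):
        # Drop candidates that cannot contain this op: either they ended
        # before it started, or it ends after they do.
        while stack and (op.start >= stack[-1].end or op.end > stack[-1].end):
            stack.pop()
        if stack:
            stack[-1].children.append(op)
        else:
            roots.append(op)
        stack.append(op)
    return roots


gp = Op("grandparent_op", 310, 490)
parent = Op("parent_op", 315, 485)
child = Op("child_op", 319, 486)  # derived end is 1 us past parent's end
build_tree([gp, parent, child])
print([c.name for c in gp.children])  # ['parent_op', 'child_op']
print(gp.self_time)  # -157
```

Because child_op's end lands one microsecond after parent_op's, the containment check pops parent_op and attaches child_op to grandparent_op, so both durations (170 us and 167 us) are subtracted from a 180 us span.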

Why is parent_op_end earlier than child_op_end? Because an operator's end timestamp is not taken from the raw timestamp; it is the result of two truncating computations: op_end = op_start + op_duration, where op_start is the truncated raw_op_start and op_duration = raw_op_end - raw_op_start, also truncated.

For example, suppose parent_start_ns = 315877, parent_end_ns = 486764, child_start_ns = 319059, and child_end_ns = 486499, so child_op really does end before parent_op. Then parent_duration_us = (486764 - 315877) / 1000 = 170 and child_duration_us = (486499 - 319059) / 1000 = 167, both truncated. Likewise, parent_start_us = 315877 / 1000 = 315, so parent_end_us = parent_start_us + parent_duration_us = 485, while child_start_us = 319059 / 1000 = 319, so child_end_us = child_start_us + child_duration_us = 486. The derived child_end_us is now later than parent_end_us, even though the raw child end is earlier.
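The arithmetic above can be replayed in a few lines of Python. `derived_us` is a hypothetical helper that follows the computation described here (truncated start plus truncated duration); it is not the profiler's actual code. The last check shows that truncating each raw end on its own cannot reverse the order of two ends, since floor division is monotonic; at worst the two ends tie.

```python
from typing import Tuple


def derived_us(start_ns: int, end_ns: int) -> Tuple[int, int]:
    # Mirrors the described post-processing: truncate the start and the
    # duration separately, then rebuild the end as start + duration.
    start_us = start_ns // 1000
    duration_us = (end_ns - start_ns) // 1000
    return start_us, start_us + duration_us


parent = derived_us(315877, 486764)  # (315, 485)
child = derived_us(319059, 486499)   # (319, 486)

# Raw nanosecond timestamps: child ends first, so it nests inside parent.
print(486499 <= 486764)                  # True
# Derived microsecond timestamps: the child now appears to outlive its parent.
print(child[1] <= parent[1])             # False
# Truncating each raw end directly preserves the order (486 <= 486).
print(486499 // 1000 <= 486764 // 1000)  # True
```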

Fortunately, on the latest PyTorch master, the profiler post-processing has raised timestamp precision from us to ns, which sidesteps this issue. It remains a latent problem, though, if timestamps are ever truncated to a coarser unit again. Therefore, we will submit a PR that uses raw_op_end, instead of the computed end, to check the parent relationship.
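A minimal sketch of the idea behind that fix, assuming each event keeps its raw nanosecond timestamps: decide parent/child relationships from the raw values, and use derived microsecond values only for display. `Event`, `contains_derived`, and `contains_raw` are hypothetical names for illustration; this is not the code from the associated PR.

```python
from dataclasses import dataclass


@dataclass
class Event:
    # Hypothetical event record that keeps the raw nanosecond timestamps
    # alongside the derived microsecond ones used for display.
    name: str
    raw_start_ns: int
    raw_end_ns: int

    @property
    def start_us(self) -> int:
        return self.raw_start_ns // 1000

    @property
    def end_us(self) -> int:
        # Derived end: truncated start plus truncated duration.
        return self.start_us + (self.raw_end_ns - self.raw_start_ns) // 1000


def contains_derived(parent: Event, child: Event) -> bool:
    # Containment check on derived timestamps (prone to truncation artifacts).
    return parent.start_us <= child.start_us and child.end_us <= parent.end_us


def contains_raw(parent: Event, child: Event) -> bool:
    # Containment check on raw timestamps, as the proposed fix describes.
    return (parent.raw_start_ns <= child.raw_start_ns
            and child.raw_end_ns <= parent.raw_end_ns)


parent = Event("parent_op", 315877, 486764)
child = Event("child_op", 319059, 486499)
print(contains_derived(parent, child))  # False: child looks like a sibling
print(contains_raw(parent, child))      # True: child nests under parent
```

Deciding nesting from raw timestamps keeps the tree correct regardless of the display precision, so a later change back to coarser units would not reintroduce negative self times.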

zejun-chen commented 3 months ago

Associated fix PR: