intel / intel-extension-for-pytorch

A Python package for extending the official PyTorch that can easily obtain performance on Intel platform
Apache License 2.0

aten::bmm op is much slower in float16 for llm rest token generation #433

Open rnwang04 opened 1 year ago

rnwang04 commented 1 year ago

Describe the bug

Hi all, when I did some profiling with open-llama-3b on Arc A770, I found that in float16, aten::bmm becomes extremely slow compared to float32 (111.4ms vs 22.5ms). I wonder whether this behavior is normal or whether it is an issue that needs to be fixed?

reproduce code:

import time

import numpy as np
import torch
import intel_extension_for_pytorch as ipex
from transformers import AutoModelForCausalLM, LlamaTokenizer

if __name__ == '__main__':
    model_path = "openlm-research/open_llama_3b"
    llama_model = AutoModelForCausalLM.from_pretrained(model_path)
    llama_model = llama_model.half().to('xpu')
    tokenizer = LlamaTokenizer.from_pretrained(model_path)

    input_str = "Once upon a time, there existed a little girl who liked to have adventures. She wanted to go to places and meet new people, and have fun"

    with torch.inference_mode():
        # warmup
        torch.xpu.synchronize()
        input_ids = tokenizer.encode(input_str, return_tensors="pt").to('xpu')
        print("input length is: ", len((input_ids[0])))
        output = llama_model.generate(input_ids, do_sample=False, max_new_tokens=32)
        output_str = tokenizer.decode(output[0], skip_special_tokens=True)
        torch.xpu.synchronize()
        e2e_time = []

        for i in range(5):
            st = time.time()
            torch.xpu.synchronize()
            input_ids = tokenizer.encode(input_str, return_tensors="pt").to('xpu')
            output = llama_model.generate(input_ids, do_sample=False, max_new_tokens=32)
            output_str = tokenizer.decode(output[0], skip_special_tokens=True)
            torch.xpu.synchronize()
            end = time.time()
            print(f"cost {end - st:.4f}s")
            e2e_time.append(end-st)
        print(output)
        print(output_str)
        print("mean e2e time is : ", np.mean(e2e_time))

        with torch.autograd.profiler_legacy.profile(enabled=True, use_xpu=True) as prof:
            output = llama_model.generate(input_ids[:,:1], do_sample=False, max_new_tokens=32)
        print(prof.key_averages().table(sort_by="self_xpu_time_total", row_limit=-1))
        with open("./llama3b_fp16.log", "w") as fw:
            fw.write(prof.key_averages(group_by_input_shape=True).table(sort_by="self_xpu_time_total"))
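
Since the script already writes the per-shape table to llama3b_fp16.log, a small follow-up appended right after the profiling block can make the slow shapes easier to spot. This is only a sketch: the event attributes used below (self_xpu_time_total in particular) are assumed from the sort_by argument above, not verified against the IPEX profiler API.

# Hypothetical follow-up to the profiling block above: print only the aten::bmm
# rows grouped by input shape, sorted by their XPU time, so the shapes that
# dominate the fp16 slowdown stand out. Assumes the IPEX-patched event averages
# expose .self_xpu_time_total (implied by sort_by="self_xpu_time_total" above).
events = prof.key_averages(group_by_input_shape=True)
bmm_rows = [e for e in events if e.key == "aten::bmm"]
for e in sorted(bmm_rows, key=lambda e: e.self_xpu_time_total, reverse=True):
    print(f"{e.key:10s}  {str(e.input_shapes):60s}  calls={e.count:4d}  "
          f"self XPU={e.self_xpu_time_total / 1000:.3f} ms")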

fp32 profile result:

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg      Self XPU    Self XPU %     XPU total  XPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     aten::mm        11.18%     200.866ms        11.68%     209.803ms      35.827us        1.644s        91.90%        1.644s     280.685us          5856  
                    aten::cat         4.60%      82.583ms         5.33%      95.710ms      28.656us      26.582ms         1.49%      26.582ms       7.959us          3340  
                    aten::bmm         5.67%     101.857ms         5.83%     104.694ms      62.917us      21.452ms         1.20%      21.452ms      12.892us          1664  
                    aten::mul         5.00%      89.926ms         5.42%      97.367ms      12.731us      21.439ms         1.20%      21.439ms       2.803us          7648  
                   aten::mean         1.11%      19.950ms         1.20%      21.521ms      12.690us      17.145ms         0.96%      17.145ms      10.109us          1696  
                    aten::add         3.33%      59.863ms         3.61%      64.794ms      11.004us      16.959ms         0.95%      16.959ms       2.880us          5888  
                  aten::index         1.53%      27.451ms         1.88%      33.701ms      20.253us      13.388ms         0.75%      13.388ms       8.045us          1664  
               aten::_softmax         0.92%      16.558ms         1.82%      32.627ms      19.608us       5.401ms         0.30%      10.802ms       6.492us          1664  
                    aten::neg         0.92%      16.471ms         1.77%      31.734ms       9.536us       4.581ms         0.26%       9.162ms       2.753us          3328  
                    aten::pow         1.98%      35.544ms         3.23%      58.024ms      17.106us       4.451ms         0.25%       8.902ms       2.624us          3392  
                  aten::rsqrt         0.91%      16.343ms         1.73%      31.048ms       9.153us       4.095ms         0.23%       8.190ms       2.415us          3392  
                   aten::silu         0.60%      10.732ms         1.15%      20.579ms      12.367us       3.227ms         0.18%       6.453ms       3.878us          1664  
                    aten::div         0.71%      12.715ms         0.76%      13.683ms      16.446us       2.235ms         0.12%       2.235ms       2.687us           832  
                    aten::max         0.06%       1.075ms         0.07%       1.182ms      18.461us       1.127ms         0.06%       1.127ms      17.613us            64  
                  aten::copy_         0.10%       1.844ms         0.10%       1.844ms      11.384us     905.996us         0.05%     905.996us       5.593us           162  
           aten::index_select         0.03%     487.612us         0.03%     549.922us      17.185us     356.408us         0.02%     356.408us      11.138us            32  
                    aten::sub         0.21%       3.814ms         0.41%       7.409ms      38.586us     335.088us         0.02%     670.176us       3.491us           192  
                 aten::cumsum         0.05%     922.489us         0.10%       1.871ms      29.239us     315.952us         0.02%     631.904us       9.873us            64  
                   aten::prod         0.03%     523.535us         0.06%       1.038ms      32.438us     289.640us         0.02%     519.792us      16.244us            32  
           aten::masked_fill_         0.04%     735.955us         0.04%     735.955us      11.499us     220.584us         0.01%     220.584us       3.447us            64  
                     aten::eq         0.04%     763.114us         0.05%     863.105us      13.077us     177.528us         0.01%     177.528us       2.690us            66  
                     aten::ne         0.02%     369.455us         0.02%     414.197us      12.944us     126.880us         0.01%     126.880us       3.965us            32  
                  aten::fill_         0.03%     555.214us         0.03%     555.214us      16.330us      82.888us         0.00%      82.888us       2.438us            34  
    aten::_local_scalar_dense        53.24%     956.655ms        53.24%     956.655ms      28.137ms      46.280us         0.00%      46.280us       1.361us            34  
                    aten::any         0.00%      29.339us         0.00%      58.803us      29.401us      10.400us         0.00%      20.800us      10.400us             2  
                    aten::sum         0.00%      16.952us         0.00%      32.484us      32.484us       9.880us         0.00%      17.576us      17.576us             1  
                     aten::gt         0.00%      15.273us         0.00%      16.279us      16.279us       2.704us         0.00%       2.704us       2.704us             1  
                  aten::slice         0.49%       8.839ms         0.82%      14.744ms       1.239us       0.000us         0.00%       0.000us       0.000us         11904  
             aten::as_strided         1.02%      18.348ms         1.02%      18.348ms       0.566us       0.000us         0.00%       0.000us       0.000us         32417  
                  aten::empty         1.42%      25.593ms         1.42%      25.593ms       0.750us       0.000us         0.00%       0.000us       0.000us         34144  
                aten::resize_         0.71%      12.720ms         0.71%      12.720ms       1.151us       0.000us         0.00%       0.000us       0.000us         11056  
                   aten::item         0.00%      38.850us        53.24%     956.693ms      28.138ms       0.000us         0.00%      46.280us       1.361us            34  
                   aten::ones         0.00%       4.494us         0.01%     197.404us      98.702us       0.000us         0.00%       3.952us       1.976us             2  
                 aten::select         0.01%     117.995us         0.01%     149.055us       2.329us       0.000us         0.00%       0.000us       0.000us            64  
                     aten::to         0.12%       2.221ms         0.21%       3.761ms       0.363us       0.000us         0.00%     732.940us       0.071us         10371  
               aten::_to_copy         0.01%     153.829us         0.09%       1.539ms      15.708us       0.000us         0.00%     732.940us       7.479us            98  
          aten::empty_strided         0.26%       4.601ms         0.26%       4.601ms       1.343us       0.000us         0.00%       0.000us       0.000us          3426  
             aten::is_nonzero         0.00%      40.858us        53.24%     956.716ms      28.991ms       0.000us         0.00%      45.136us       1.368us            33  
             aten::lift_fresh         0.00%       0.345us         0.00%       0.345us       0.345us       0.000us         0.00%       0.000us       0.000us             1  
                aten::detach_         0.00%       0.847us         0.00%       0.847us       0.847us       0.000us         0.00%       0.000us       0.000us             1  
             aten::resize_as_         0.00%      27.048us         0.00%      27.048us       0.845us       0.000us         0.00%       0.000us       0.000us            32  
                   aten::view         0.26%       4.634ms         0.26%       4.634ms       0.551us       0.000us         0.00%       0.000us       0.000us          8416  
              aten::embedding         0.01%      91.433us         0.04%     740.564us      23.143us       0.000us         0.00%     356.408us      11.138us            32  
                aten::reshape         0.41%       7.409ms         0.57%      10.257ms       1.752us       0.000us         0.00%       0.000us       0.000us          5856  
         aten::_reshape_alias         0.16%       2.848ms         0.16%       2.848ms       0.486us       0.000us         0.00%       0.000us       0.000us          5856  
              aten::unsqueeze         0.06%       1.105ms         0.11%       1.969ms       1.080us       0.000us         0.00%       0.000us       0.000us          1823  
                 aten::expand         0.13%       2.281ms         0.22%       4.034ms       1.178us       0.000us         0.00%       0.000us       0.000us          3424  
                   aten::rsub         0.01%     132.348us         0.08%       1.446ms      22.587us       0.000us         0.00%     218.088us       3.408us            64  
            aten::masked_fill         0.00%      71.511us         0.04%     746.848us      23.339us       0.000us         0.00%     214.656us       6.708us            32  
                  aten::clone         0.00%      60.138us         0.02%     392.320us      12.260us       0.000us         0.00%      87.984us       2.749us            32  
             aten::empty_like         0.00%      32.722us         0.00%      60.759us       1.899us       0.000us         0.00%       0.000us       0.000us            32  
            aten::result_type         0.04%     761.450us         0.04%     761.450us       0.224us       0.000us         0.00%       0.000us       0.000us          3392  
               aten::can_cast         0.02%     349.212us         0.02%     349.212us       0.206us       0.000us         0.00%       0.000us       0.000us          1696  
                 aten::linear         0.33%       5.972ms        13.73%     246.654ms      42.120us       0.000us         0.00%        1.644s     280.685us          5856  
                      aten::t         0.30%       5.476ms         0.90%      16.178ms       2.763us       0.000us         0.00%       0.000us       0.000us          5856  
              aten::transpose         0.47%       8.439ms         0.87%      15.657ms       1.563us       0.000us         0.00%       0.000us       0.000us         10016  
                 aten::matmul         0.86%      15.501ms        19.28%     346.433ms      46.068us       0.000us         0.00%        1.665s     221.429us          7520  
           aten::_unsafe_view         0.19%       3.441ms         0.19%       3.441ms       0.458us       0.000us         0.00%       0.000us       0.000us          7520  
                aten::squeeze         0.16%       2.858ms         0.23%       4.198ms       1.261us       0.000us         0.00%       0.000us       0.000us          3328  
                 aten::narrow         0.13%       2.312ms         0.33%       5.858ms       1.760us       0.000us         0.00%       0.000us       0.000us          3328  
                aten::softmax         0.07%       1.256ms         1.05%      18.833ms      22.636us       0.000us         0.00%       5.401ms       6.492us           832  
                 aten::argmax         0.01%     155.471us         0.10%       1.803ms      28.166us       0.000us         0.00%       1.564ms      24.443us            64  
               aten::new_ones         0.00%      68.748us         0.03%     486.377us      15.199us       0.000us         0.00%      78.936us       2.467us            32  
              aten::new_empty         0.00%      18.553us         0.00%      53.763us       1.680us       0.000us         0.00%       0.000us       0.000us            32  
                   aten::tile         0.00%      34.271us         0.04%     756.959us      23.655us       0.000us         0.00%      85.072us       2.658us            32  
                 aten::repeat         0.01%     146.751us         0.04%     722.688us      22.584us       0.000us         0.00%      85.072us       2.658us            32  
                  aten::alias         0.00%      21.186us         0.00%      21.186us       0.662us       0.000us         0.00%       0.000us       0.000us            32  
                 aten::unfold         0.00%      75.919us         0.01%     105.108us       1.642us       0.000us         0.00%       0.000us       0.000us            64  
              aten::expand_as         0.00%      19.713us         0.00%      49.219us       1.538us       0.000us         0.00%       0.000us       0.000us            32  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.797s
Self XPU time total: 1.789s

fp16 profile result:

-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                         Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg      Self XPU    Self XPU %     XPU total  XPU time avg    # of Calls  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                     aten::mm        16.02%     201.651ms        16.66%     209.739ms      35.816us     837.075ms        69.84%     837.075ms     142.943us          5856  
                    aten::bmm         7.16%      90.154ms         7.35%      92.512ms      55.596us     111.417ms         9.30%     111.417ms      66.957us          1664  
                  aten::copy_         3.45%      43.406ms         3.45%      43.406ms       8.268us      92.354ms         7.71%      92.354ms      17.591us          5250  
                    aten::mul         5.07%      63.890ms         5.62%      70.708ms       9.245us      31.044ms         2.59%      31.044ms       4.059us          7648  
                    aten::cat         5.98%      75.285ms         6.89%      86.695ms      25.957us      30.640ms         2.56%      30.640ms       9.174us          3340  
                    aten::add         4.62%      58.172ms         4.97%      62.604ms      10.632us      22.312ms         1.86%      22.312ms       3.789us          5888  
                   aten::mean         1.42%      17.847ms         1.53%      19.253ms      11.352us      21.141ms         1.76%      21.141ms      12.465us          1696  
                  aten::index         1.98%      24.887ms         2.42%      30.472ms      18.312us      16.788ms         1.40%      16.788ms      10.089us          1664  
               aten::_softmax         0.80%      10.068ms         1.57%      19.738ms      11.862us       7.391ms         0.62%      14.783ms       8.884us          1664  
                    aten::neg         1.26%      15.882ms         2.44%      30.701ms       9.225us       5.797ms         0.48%      11.594ms       3.484us          3328  
                   aten::silu         0.85%      10.669ms         1.64%      20.639ms      12.403us       5.553ms         0.46%      11.106ms       6.674us          1664  
                    aten::pow         1.64%      20.588ms         3.23%      40.691ms      11.996us       4.970ms         0.41%       9.939ms       2.930us          3392  
                  aten::rsqrt         1.21%      15.203ms         2.30%      28.971ms       8.541us       4.722ms         0.39%       9.443ms       2.784us          3392  
                    aten::div         0.91%      11.448ms         0.98%      12.367ms      14.864us       3.866ms         0.32%       3.866ms       4.647us           832  
                    aten::max         0.08%     951.099us         0.08%       1.056ms      16.501us       1.289ms         0.11%       1.289ms      20.145us            64  
                    aten::sub         0.20%       2.544ms         0.39%       4.933ms      25.692us     427.024us         0.04%     854.048us       4.448us           192  
           aten::index_select         0.03%     420.413us         0.04%     466.950us      14.592us     383.344us         0.03%     383.344us      11.980us            32  
                 aten::cumsum         0.07%     873.633us         0.14%       1.802ms      28.155us     315.640us         0.03%     631.280us       9.864us            64  
                   aten::prod         0.03%     399.204us         0.06%     788.925us      24.654us     311.480us         0.03%     932.152us      29.130us            32  
           aten::masked_fill_         0.05%     572.749us         0.05%     572.749us       8.949us     222.664us         0.02%     222.664us       3.479us            64  
                     aten::eq         0.06%     713.505us         0.06%     803.088us      12.168us     180.024us         0.02%     180.024us       2.728us            66  
                     aten::ne         0.03%     316.734us         0.03%     357.768us      11.180us     156.624us         0.01%     156.624us       4.894us            32  
                  aten::fill_         0.03%     321.506us         0.03%     321.506us       9.456us      89.856us         0.01%      89.856us       2.643us            34  
    aten::_local_scalar_dense        34.71%     437.046ms        34.71%     437.046ms      12.854ms      47.112us         0.00%      47.112us       1.386us            34  
                    aten::any         0.00%      21.589us         0.00%      43.819us      21.910us      10.608us         0.00%      21.216us      10.608us             2  
                    aten::sum         0.00%      13.213us         0.00%      26.270us      26.270us       9.360us         0.00%      16.536us      16.536us             1  
                     aten::gt         0.00%      11.710us         0.00%      12.759us      12.759us       2.704us         0.00%       2.704us       2.704us             1  
                  aten::slice         0.61%       7.722ms         1.04%      13.144ms       1.104us       0.000us         0.00%       0.000us       0.000us         11904  
             aten::as_strided         1.31%      16.535ms         1.31%      16.535ms       0.510us       0.000us         0.00%       0.000us       0.000us         32417  
                  aten::empty         1.87%      23.520ms         1.87%      23.520ms       0.689us       0.000us         0.00%       0.000us       0.000us         34144  
                aten::resize_         0.88%      11.139ms         0.88%      11.139ms       1.008us       0.000us         0.00%       0.000us       0.000us         11056  
                   aten::item         0.00%      43.065us        34.71%     437.089ms      12.856ms       0.000us         0.00%      47.112us       1.386us            34  
                   aten::ones         0.00%       3.937us         0.00%      25.745us      12.873us       0.000us         0.00%       3.640us       1.820us             2  
                 aten::select         0.01%      97.937us         0.01%     123.410us       1.928us       0.000us         0.00%       0.000us       0.000us            64  
                     aten::to         0.43%       5.356ms         4.77%      60.000ms       5.785us       0.000us         0.00%      92.142ms       8.885us         10371  
               aten::_to_copy         0.45%       5.637ms         4.34%      54.644ms      10.537us       0.000us         0.00%      92.142ms      17.768us          5186  
          aten::empty_strided         0.79%       9.999ms         0.79%       9.999ms       1.174us       0.000us         0.00%       0.000us       0.000us          8514  
             aten::is_nonzero         0.00%      31.431us        34.71%     437.097ms      13.245ms       0.000us         0.00%      46.020us       1.395us            33  
             aten::lift_fresh         0.00%       0.238us         0.00%       0.238us       0.238us       0.000us         0.00%       0.000us       0.000us             1  
                aten::detach_         0.00%       0.265us         0.00%       0.265us       0.265us       0.000us         0.00%       0.000us       0.000us             1  
             aten::resize_as_         0.00%      33.467us         0.00%      33.467us       1.046us       0.000us         0.00%       0.000us       0.000us            32  
                   aten::view         0.35%       4.344ms         0.35%       4.344ms       0.516us       0.000us         0.00%       0.000us       0.000us          8416  
              aten::embedding         0.01%      90.001us         0.05%     651.865us      20.371us       0.000us         0.00%     383.344us      11.980us            32  
                aten::reshape         0.48%       6.056ms         0.69%       8.739ms       1.492us       0.000us         0.00%       0.000us       0.000us          5856  
         aten::_reshape_alias         0.21%       2.683ms         0.21%       2.683ms       0.458us       0.000us         0.00%       0.000us       0.000us          5856  
              aten::unsqueeze         0.08%     982.663us         0.14%       1.758ms       0.964us       0.000us         0.00%       0.000us       0.000us          1823  
                 aten::expand         0.16%       1.959ms         0.28%       3.527ms       1.030us       0.000us         0.00%       0.000us       0.000us          3424  
                   aten::rsub         0.01%     124.916us         0.08%       1.052ms      16.443us       0.000us         0.00%     312.832us       4.888us            64  
            aten::masked_fill         0.00%      59.325us         0.06%     694.597us      21.706us       0.000us         0.00%     229.736us       7.179us            32  
                  aten::clone         0.00%      52.104us         0.03%     360.548us      11.267us       0.000us         0.00%     103.272us       3.227us            32  
             aten::empty_like         0.00%      25.949us         0.00%      50.608us       1.582us       0.000us         0.00%       0.000us       0.000us            32  
            aten::result_type         0.05%     641.702us         0.05%     641.702us       0.189us       0.000us         0.00%       0.000us       0.000us          3392  
               aten::can_cast         0.02%     252.416us         0.02%     252.416us       0.149us       0.000us         0.00%       0.000us       0.000us          1696  
                 aten::linear         0.42%       5.262ms        20.60%     259.368ms      44.291us       0.000us         0.00%     837.075ms     142.943us          5856  
                      aten::t         0.39%       4.951ms         2.44%      30.702ms       5.243us       0.000us         0.00%       0.000us       0.000us          5856  
              aten::transpose         1.89%      23.754ms         2.39%      30.133ms       3.009us       0.000us         0.00%       0.000us       0.000us         10016  
                 aten::matmul         1.17%      14.720ms        26.29%     331.038ms      44.021us       0.000us         0.00%     948.491ms     126.129us          7520  
           aten::_unsafe_view         0.24%       2.959ms         0.24%       2.959ms       0.393us       0.000us         0.00%       0.000us       0.000us          7520  
                aten::squeeze         0.20%       2.533ms         0.30%       3.757ms       1.129us       0.000us         0.00%       0.000us       0.000us          3328  
                 aten::narrow         0.17%       2.109ms         0.41%       5.203ms       1.563us       0.000us         0.00%       0.000us       0.000us          3328  
                aten::softmax         0.13%       1.649ms         1.83%      23.029ms      27.679us       0.000us         0.00%      19.332ms      23.236us           832  
                 aten::argmax         0.01%     138.929us         0.13%       1.622ms      25.338us       0.000us         0.00%       1.737ms      27.141us            64  
               aten::new_ones         0.00%      53.200us         0.03%     402.552us      12.580us       0.000us         0.00%      86.216us       2.694us            32  
              aten::new_empty         0.00%      17.398us         0.00%      48.004us       1.500us       0.000us         0.00%       0.000us       0.000us            32  
                   aten::tile         0.00%      31.082us         0.05%     655.268us      20.477us       0.000us         0.00%     108.472us       3.390us            32  
                 aten::repeat         0.01%     117.771us         0.05%     624.186us      19.506us       0.000us         0.00%     108.472us       3.390us            32  
                  aten::alias         0.00%      12.692us         0.00%      12.692us       0.397us       0.000us         0.00%       0.000us       0.000us            32  
                 aten::unfold         0.01%      64.772us         0.01%      94.910us       1.483us       0.000us         0.00%       0.000us       0.000us            64  
              aten::expand_as         0.00%      16.572us         0.00%      43.442us       1.358us       0.000us         0.00%       0.000us       0.000us            32  
-----------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
Self CPU time total: 1.259s
Self XPU time total: 1.199s

Other ops (mul / add / ...) also become slower in fp16, but aten::bmm is the most obvious: per call it goes from about 12.9us in fp32 to about 67.0us in fp16, roughly 5x slower.

Versions

Collecting environment information...
PyTorch version: 2.0.1a0+cxx11.abi
PyTorch CXX11 ABI: Yes
IPEX version: 2.0.110+xpu
IPEX commit: ba7f6c127
Build type: Release

OS: Ubuntu 22.04.3 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: N/A
IGC version: 2023.2.0 (2023.2.0.20230721)
CMake version: version 3.22.1
Libc version: glibc-2.35

Python version: 3.9.17 (main, Jul  5 2023, 20:41:20)  [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.19.0-41-generic-x86_64-with-glibc2.35
Is XPU available: True
DPCPP runtime version: 2023.2.1
MKL version: 2023.2.0
GPU models and configuration: 
[0] _DeviceProperties(name='Intel(R) Arc(TM) A770 Graphics', platform_name='Intel(R) Level-Zero', dev_type='gpu, support_fp64=0, total_memory=15473MB, max_compute_units=512, gpu_eu_count=512)
[1] _DeviceProperties(name='Intel(R) UHD Graphics 770', platform_name='Intel(R) Level-Zero', dev_type='gpu, support_fp64=0, total_memory=51240MB, max_compute_units=32, gpu_eu_count=32)
Intel OpenCL ICD version: 23.17.26241.33-647~22.04
Level Zero version: 1.3.26241.33-647~22.04

CPU:
Architecture:                    x86_64
CPU op-mode(s):                  32-bit, 64-bit
Address sizes:                   46 bits physical, 48 bits virtual
Byte Order:                      Little Endian
CPU(s):                          24
On-line CPU(s) list:             0-23
Vendor ID:                       GenuineIntel
Model name:                      12th Gen Intel(R) Core(TM) i9-12900K
CPU family:                      6
Model:                           151
Thread(s) per core:              2
Core(s) per socket:              16
Socket(s):                       1
Stepping:                        2
CPU max MHz:                     5200.0000
CPU min MHz:                     800.0000
BogoMIPS:                        6374.40
Flags:                           fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 monitor ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb invpcid_single ssbd ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid rdseed adx smap clflushopt clwb intel_pt sha_ni xsaveopt xsavec xgetbv1 xsaves split_lock_detect avx_vnni dtherm ida arat pln pts hwp hwp_notify hwp_act_window hwp_epp hwp_pkg_req hfi umip pku ospke waitpkg gfni vaes vpclmulqdq tme rdpid movdiri movdir64b fsrm md_clear serialize pconfig arch_lbr ibt flush_l1d arch_capabilities
Virtualization:                  VT-x
L1d cache:                       640 KiB (16 instances)
L1i cache:                       768 KiB (16 instances)
L2 cache:                        14 MiB (10 instances)
L3 cache:                        30 MiB (1 instance)
NUMA node(s):                    1
NUMA node0 CPU(s):               0-23
Vulnerability Itlb multihit:     Not affected
Vulnerability L1tf:              Not affected
Vulnerability Mds:               Not affected
Vulnerability Meltdown:          Not affected
Vulnerability Mmio stale data:   Not affected
Vulnerability Retbleed:          Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl
Vulnerability Spectre v1:        Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2:        Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds:             Not affected
Vulnerability Tsx async abort:   Not affected

Versions of relevant libraries:
[pip3] intel-extension-for-pytorch==2.0.110+xpu
[pip3] numpy==1.24.4
[pip3] torch==2.0.1a0+cxx11.abi
[pip3] torchvision==0.15.2a0+cxx11.abi
[conda] intel-extension-for-pytorch 2.0.110+xpu              pypi_0    pypi
[conda] numpy                     1.24.4                   pypi_0    pypi
[conda] torch                     2.0.1a0+cxx11.abi          pypi_0    pypi
[conda] torchvision               0.15.2a0+cxx11.abi          pypi_0    pypi
jgong5 commented 1 year ago

Are you able to get a similar result with a standalone bmm op?

rnwang04 commented 1 year ago

Yes, I can get a similar result with a standalone bmm op. I found that if I just loop bmm over the same input, fp16 is much faster than fp32. However, if each input in the loop has a different shape, fp16 becomes much slower.

import time
import torch
import intel_extension_for_pytorch as ipex

st1 = time.time()
for i in range(32):
    a1 = torch.rand(32, 1, i).to('xpu')
    b1 = torch.rand(32, i, 100).to('xpu')
    o = torch.bmm(a1, b1)
    a2 = torch.rand(32, 1, 100).to('xpu')
    b2 = torch.rand(32, 100, i).to('xpu')
    o = torch.bmm(a2, b2)
st2 = time.time()
print(f"bmm fp32 cost : {(st2-st1) * 1000}ms")

st1 = time.time()
for i in range(1000):
    a1 = torch.rand(32, 1, i).half().to('xpu')
    b1 = torch.rand(32, i, 100).half().to('xpu')
    o = torch.bmm(a1, b1)
    a2 = torch.rand(32, 1, 100).half().to('xpu')
    b2 = torch.rand(32, 100, i).half().to('xpu')
    o = torch.bmm(a2, b2)
st2 = time.time()
print(f"bmm fp16 cost : {(st2-st1) * 1000}ms")

a132_list = [torch.rand(32, 1, i).to('xpu') for i in range(32)]
a232_list = [torch.rand(32, 1, 100).to('xpu') for i in range(32)]
b132_list = [torch.rand(32, i, 100).to('xpu') for i in range(32)]
b232_list = [torch.rand(32, 100, i).to('xpu') for i in range(32)]

st1 = time.time()
for i in range(32):
    o = torch.bmm(a132_list[i], b132_list[i])
    o = torch.bmm(a232_list[i], b232_list[i])
st2 = time.time()
print(f"bmm fp32 cost : {(st2-st1) * 1000}ms")

a116_list = [torch.rand(32, 1, i).half().to('xpu') for i in range(32)]
a216_list = [torch.rand(32, 1, 100).half().to('xpu') for _ in range(32)]
b116_list = [torch.rand(32, i, 100).half().to('xpu') for i in range(32)]
b216_list = [torch.rand(32, 100, i).half().to('xpu') for i in range(32)]

st1 = time.time()
for i in range(32):
    o = torch.bmm(a116_list[i], b116_list[i])
    o = torch.bmm(a216_list[i], b216_list[i])
st2 = time.time()
print(f"bmm fp16 cost : {(st2-st1) * 1000}ms")

output is:

bmm fp32 cost : 146.5318202972412ms
bmm fp16 cost : 12086.34614944458ms
bmm fp32 cost : 15.504598617553711ms
bmm fp16 cost : 88.47284317016602ms
jgong5 commented 1 year ago

@rnwang04 Thanks for the benchmarking. I have two comments:

  1. It seems you are counting not only the bmm time but also the torch.rand time. Can you count only the bmm time in your benchmark?
  2. Can you do a warm-up with the same loop/inputs before timing? I just want to make sure there is no additional overhead like JIT codegen.
rnwang04 commented 1 year ago

@jgong5 Thanks for the reply! I have updated the test script based on your comments: the torch.rand time is now excluded and a warmup is added. Now the time of aten::bmm is almost the same for fp16 & fp32. This is my new script:

import time
import torch
import intel_extension_for_pytorch as ipex

a132_list = [torch.rand(32, 1, i).to('xpu') for i in range(32)]
a232_list = [torch.rand(32, 1, 100).to('xpu') for _ in range(32)]
b132_list = [torch.rand(32, i, 100).to('xpu') for i in range(32)]
b232_list = [torch.rand(32, 100, i).to('xpu') for i in range(32)]

# warmup for fp32
for i in range(32):
    o = torch.bmm(a132_list[i], b132_list[i])
    o = torch.bmm(a232_list[i], b232_list[i])
torch.xpu.synchronize()

st1 = time.time()
for i in range(32):
    o = torch.bmm(a132_list[i], b132_list[i])
    o = torch.bmm(a232_list[i], b232_list[i])
torch.xpu.synchronize()
st2 = time.time()
print(f"bmm fp32 cost : {(st2-st1) * 1000}ms")

a116_list = [torch.rand(32, 1, i).half().to('xpu') for i in range(32)]
a216_list = [torch.rand(32, 1, 100).half().to('xpu') for _ in range(32)]
b116_list = [torch.rand(32, i, 100).half().to('xpu') for i in range(32)]
b216_list = [torch.rand(32, 100, i).half().to('xpu') for i in range(32)]

# warmup for fp16
for i in range(32):
    o = torch.bmm(a116_list[i], b116_list[i])
    o = torch.bmm(a216_list[i], b216_list[i])
torch.xpu.synchronize()

st1 = time.time()
for i in range(32):
    o = torch.bmm(a116_list[i], b116_list[i])
    o = torch.bmm(a216_list[i], b216_list[i])
torch.xpu.synchronize()
st2 = time.time()
print(f"bmm fp16 cost : {(st2-st1) * 1000}ms")

and this is my output:

bmm fp32 cost : 1.3713836669921875ms
bmm fp16 cost : 1.3885498046875ms

Now my question is why this op behaves differently in the unit test and in the llm. I think in the llama-3b profile above I have already done the warmup, but as you can see, aten::bmm is still much slower in fp16.

jgong5 commented 1 year ago

Now my question is why this op behaves differently in the unit test and in the llm. I think in the llama-3b profile above I have already done the warmup, but as you can see, aten::bmm is still much slower in fp16.

I guess that is due to the dynamic shapes? A quick test could be to feed the model with all the sequence lengths during warmup. I don't mean that is the best practice, but I just want to confirm whether that is the problem.
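
A minimal sketch of that quick test, assuming the standalone bmm shapes used earlier in this thread stand in for the per-token shapes the model sees (the helper and its parameters are illustrative, not part of the original scripts): warm up fp16 bmm over every kv length that decoding will encounter, then time the same sweep.

import time
import torch
import intel_extension_for_pytorch as ipex  # noqa: F401  (enables the 'xpu' device)

def sweep_bmm(dtype, heads=32, head_dim=100, max_len=256):
    # One query token against a growing kv cache, one bmm per length,
    # mimicking rest-token generation. Shapes are illustrative assumptions.
    return [(torch.rand(heads, 1, k, dtype=dtype, device='xpu'),
             torch.rand(heads, k, head_dim, dtype=dtype, device='xpu'))
            for k in range(1, max_len + 1)]

pairs = sweep_bmm(torch.float16)

# Warmup over *all* shapes, so any per-shape kernel selection / compilation
# cost is paid before timing.
for a, b in pairs:
    torch.bmm(a, b)
torch.xpu.synchronize()

st = time.time()
for a, b in pairs:
    torch.bmm(a, b)
torch.xpu.synchronize()
print(f"fp16 bmm over {len(pairs)} shapes after full-shape warmup: "
      f"{(time.time() - st) * 1000:.2f} ms")

If fp16 is fast after this full-shape warmup but still slow inside the generate() profile, that would point at per-shape overhead rather than the bmm kernels themselves.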