turboderp / exllamav2

A fast inference library for running LLMs locally on modern consumer-class GPUs
MIT License

optimization: put rfn_sum on CUDA and move the .item() call out of the for loop #455

Closed · kallewoof closed this 1 month ago

kallewoof commented 1 month ago

Edit: Looking closer at the profiler output, the .cuda() call makes a rather suspicious jump in time there. The total time is lower, so I think this is still an improvement, but I will do some more checks and then un-draft.

The .item() call inside the for loops is quite expensive. It is better to keep rfn_sum on CUDA and call .item() once at the end, outside the loop.
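A minimal sketch of the pattern, assuming a relative-Frobenius-norm error metric; the actual rfn_sum loop lives in exllamav2's conversion script, and `batches` and the tensor shapes here are hypothetical stand-ins:

```python
import torch

# Hypothetical stand-in for the converter's per-layer calibration batches.
batches = [(torch.randn(64, 4096, device="cuda"),
            torch.randn(64, 4096, device="cuda")) for _ in range(8)]

# Before: the accumulator is a Python float, so every iteration pays for a
# GPU-to-CPU synchronization inside .item().
rfn_sum = 0.0
for x, xref in batches:
    rfn_sum += (torch.linalg.norm(x - xref) / torch.linalg.norm(xref)).item()

# After: accumulate on the GPU and synchronize once, outside the loop.
rfn_sum = torch.zeros((), device="cuda")
for x, xref in batches:
    rfn_sum += torch.linalg.norm(x - xref) / torch.linalg.norm(xref)
total = rfn_sum.item()  # single device-to-host sync
```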

>>> p.sort_stats(SortKey.TIME).print_stats(10)
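For context, a dump like the ones below can be collected and loaded with the standard-library cProfile/pstats modules; the collection step and the main() entry point here are assumptions, only the print_stats call above appears in the thread:

```python
import cProfile
import pstats
from pstats import SortKey

# Collect a profile into a stats file (main() is a hypothetical entry point).
cProfile.run("main()", "profiler.out")

# Load the dump and print the ten most expensive functions by internal time.
p = pstats.Stats("profiler.out")
p.sort_stats(SortKey.TIME).print_stats(10)
```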

Before the patch (one single layer iteration):

(quantize)
Wed May 15 18:29:10 2024    profiler.out

         12347828 function calls (12238743 primitive calls) in 385.159 seconds

   Ordered by: internal time
   List reduced from 8601 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     3800  179.361    0.047  179.361    0.047 {method 'item' of 'torch._C.TensorBase' objects}
     7658  109.105    0.014  109.105    0.014 {method 'cuda' of 'torch._C.TensorBase' objects}
    50352   37.682    0.001   37.682    0.001 {method 'to' of 'torch._C.TensorBase' objects}
        2   12.386    6.193   12.386    6.193 {built-in method safetensors._safetensors_rust.serialize_file}
      201    6.898    0.034    6.898    0.034 {method 'tobytes' of 'numpy.ndarray' objects}
   103/83    6.272    0.061    7.027    0.085 {built-in method _imp.create_dynamic}
    40973    3.364    0.000    3.364    0.000 {built-in method torch.tensor}
     1926    2.967    0.002    2.967    0.002 {method 'read' of '_io.BufferedReader' objects}
    32853    2.239    0.000    2.239    0.000 {built-in method torch.cat}
      207    1.934    0.009    2.319    0.011 {method 'get_tensor' of 'builtins.safe_open' objects}

(calibrate)
      648    0.370    0.001    0.370    0.001 {method 'item' of 'torch._C.TensorBase' objects}

With this patch:

(quantize)
Wed May 15 18:54:54 2024    profiler.out

         12343629 function calls (12234544 primitive calls) in 369.516 seconds

   Ordered by: internal time
   List reduced from 9213 to 10 due to restriction <10>

   ncalls  tottime  percall  cumtime  percall filename:lineno(function)
     7677  287.720    0.037  287.720    0.037 {method 'cuda' of 'torch._C.TensorBase' objects}
    50352   38.081    0.001   38.081    0.001 {method 'to' of 'torch._C.TensorBase' objects}
        2   11.782    5.891   11.782    5.891 {built-in method safetensors._safetensors_rust.serialize_file}
      201    6.923    0.034    6.923    0.034 {method 'tobytes' of 'numpy.ndarray' objects}
    40992    3.642    0.000    3.642    0.000 {built-in method torch.tensor}
    32857    2.996    0.000    2.996    0.000 {method 'encode' of 'tokenizers.Tokenizer' objects}
    32853    2.956    0.000    2.956    0.000 {built-in method torch.cat}
        1    1.804    1.804    1.804    1.804 {built-in method torch.embedding}
        1    1.056    1.056  369.517  369.517 convert.py:1(<module>)
       19    0.899    0.047    0.899    0.047 {method 'item' of 'torch._C.TensorBase' objects}

(The calibrate-stage .item() time is now so low that it doesn't show up in the profiler output.)
turboderp commented 1 month ago

I very much appreciate this.

As I recall there's a whole bunch of synchronization points there, as the converter is swapping stuff to system memory like crazy. I think this change likely just shifts those points around a little bit, which would show up in a CPU profiler but may not affect wall time at the end of the day.
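A minimal sketch of that effect, with hypothetical tensors: CUDA kernels are queued asynchronously and the CPU only blocks at a synchronization point such as .item(), so a CPU profiler charges the GPU wait to whichever call happens to sync first:

```python
import torch

x = torch.randn(8192, 8192, device="cuda")
y = x @ x     # returns almost immediately; the matmul is only queued
s = y.sum()   # still asynchronous
v = s.item()  # synchronization point: a profiler charges the GPU wait here
```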

The change looks fine, though, so if it does help I'm happy to merge it.

kallewoof commented 1 month ago

> As I recall there's a whole bunch of synchronization points there, as the converter is swapping stuff to system memory like crazy. I think this change likely just shifts those points around a little bit, which would show up in a CPU profiler but may not affect wall time at the end of the day.

Ultimately it should cause a speed increase, as it simply does less work (calling .item() once instead of once per iteration).

> The change looks fine, though, so if it does help I'm happy to merge it.

I am looking closer at the .cuda() bump and I think I have a few more optimizations for you.

~~For example, the test_error function calls .cuda() twice, for x and xref, but I think we can drop the xref call completely and instead move the results in xtest back to the CPU. That way we avoid moving the target_states to CUDA entirely. Testing that change now; should have results in a bit.~~ Edit: That was a downgrade. I misunderstood the sizes of these things. :)

kallewoof commented 1 month ago

Closing in favor of #456, but I will reopen this if that PR turns out not to work.