pytorch / TensorRT

PyTorch/TorchScript/FX compiler for NVIDIA GPUs using TensorRT
https://pytorch.org/TensorRT

🐛 [Bug] `FxGraphCachePickler.get_hash(new_gm)` takes up a large portion of the total compile time, which makes reusing cached engine slow #3249

Open · zewenli98 opened this issue 1 week ago

zewenli98 commented 1 week ago

Bug Description

When using the engine cache feature with Llama2-7b, I found that reusing a cached engine is quite slow, even slower than building a non-refittable engine from scratch. The reason is that, in engine caching, we use `FxGraphCachePickler.get_hash(new_gm)` to calculate the hash value of the graph, and that call takes up a large portion of the total compile time. For reference, here are the running times for each config:

non-refittable: [353483.9375, 267709.59375, 271658.46875] ms
REFIT w/o engine caching: [393867.96875, 309644.25, 317363.09375] ms
REFIT_IDENTICAL w/o engine caching: [1108521.0, 1243913.75, 1105686.25] ms
REFIT w/ engine caching: [727660.4375, 410300.1875, 413173.3125] ms
REFIT_IDENTICAL w/ engine caching: [1545847.625, 416759.03125, 412087.03125] ms

There are 5 configs and I ran each of them three times. The timing cache was removed before running each config.

For REFIT w/ engine caching and REFIT_IDENTICAL w/ engine caching, the first run is slower than the other two because it spends time building the engine from scratch and saving it to the cache. The other two runs directly reuse the cached engine and refit it with new weights.
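Conceptually, the cached path works like the sketch below (simplified and illustrative: `build_engine`, `refit`, and the dict-backed cache are stand-ins rather than the real Torch-TensorRT API, and plain `pickle` stands in for `FxGraphCachePickler`):

```python
import hashlib
import pickle

# Illustrative stubs; not the actual Torch-TensorRT engine APIs.
def build_engine(gm):
    return object()  # pretend this is an expensive TensorRT build

def refit(engine, weights):
    pass  # pretend this swaps the new weights into the engine

engine_cache = {}

def get_or_build_engine(gm, weights):
    # The expensive step this issue is about: serialize the whole
    # GraphModule and hash the bytes to form the cache key.
    key = hashlib.sha256(pickle.dumps(gm)).hexdigest()
    if key in engine_cache:
        engine = engine_cache[key]  # 2nd/3rd runs: reuse the cached engine
        refit(engine, weights)      # ...and refit it with the new weights
    else:
        engine = build_engine(gm)   # 1st run: build from scratch
        engine_cache[key] = engine  # ...and save it to the cache
    return engine
```

Even on a cache hit, the hash must be computed before the cache can be consulted, which is why the second and third runs still pay the full `get_hash()` cost.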

At first, we assumed that the last two runs would be very fast. However, taking the second run of REFIT w/ engine caching as an example, it took 410300.1875 ms, while recompiling a non-refittable engine from scratch took only 267709.59375 ms. To break down where the time goes, let's take REFIT w/ engine caching as an example:

1st run:
get_hash(): 323009 ms
  - FxGraphCachePickler.get_hash(new_gm): 321425 ms
build engine + save in cache: 359696 ms
---------- total compile time: 727660 ms
2nd run:
get_hash(): 323275.46875 ms
  - FxGraphCachePickler.get_hash(new_gm): 321687 ms
reuse cached engine + refit: 41489 ms
---------- total compile time: 410300 ms
3rd run:
get_hash(): 322430 ms
  - FxGraphCachePickler.get_hash(new_gm): 320844 ms
reuse cached engine + refit: 44538 ms
---------- total compile time: 413173 ms

The `FxGraphCachePickler.get_hash(new_gm)` call inside `get_hash()` takes over 320 seconds in every run, a very large portion of the total compile time.

narendasan commented 1 week ago

@zewenli98 can you try profiling the hash function? It would be interesting to know which parts are taking the longest.
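For example, a minimal profiling sketch (assuming `new_gm` is the `torch.fx.GraphModule` being hashed, and that `FxGraphCachePickler` is importable from `torch._inductor.codecache` in the PyTorch build in use):

```python
import cProfile
import pstats

from torch._inductor.codecache import FxGraphCachePickler

def profile_hash(new_gm):
    # Run the hash under cProfile and print the 20 most expensive
    # callees by cumulative time.
    prof = cProfile.Profile()
    prof.enable()
    FxGraphCachePickler.get_hash(new_gm)
    prof.disable()
    pstats.Stats(prof).sort_stats("cumulative").print_stats(20)
```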

zewenli98 commented 1 week ago

@narendasan `FxGraphCachePickler.get_hash(new_gm)` has only two lines:

```python
serialized_data = cls.dumps(obj)     # pickle the GraphModule into bytes
return sha256_hash(serialized_data)  # hash the pickled bytes
```

The first line, `serialized_data = cls.dumps(obj)`, takes most of the time.
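A rough per-line timing sketch that shows the split (assuming `sha256_hash` can be imported from the same `torch._inductor.codecache` module):

```python
import time

from torch._inductor.codecache import FxGraphCachePickler, sha256_hash

def time_hash_parts(new_gm):
    # Time serialization and hashing separately; for Llama2-7b the
    # dumps() call is expected to dominate.
    t0 = time.perf_counter()
    serialized_data = FxGraphCachePickler.dumps(new_gm)
    t1 = time.perf_counter()
    sha256_hash(serialized_data)
    t2 = time.perf_counter()
    print(f"dumps: {(t1 - t0) * 1e3:.0f} ms, sha256: {(t2 - t1) * 1e3:.0f} ms")
```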

zewenli98 commented 5 days ago

The PyTorch team is working on this issue.