**zewenli98** opened this issue 1 week ago
@zewenli98 can you try profiling the hash function? Would be interesting to know which parts are taking the longest
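One way to profile that call is with `cProfile`; a minimal sketch, assuming `FxGraphCachePickler` is importable from `torch._inductor.codecache` and that `get_hash` is callable as a classmethod (its location and signature have moved between PyTorch releases), with a tiny traced module standing in for the real Llama2-7b `GraphModule`:

```python
import cProfile
import pstats

import torch
# Assumed import path; FxGraphCachePickler has moved between PyTorch releases.
from torch._inductor.codecache import FxGraphCachePickler


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


# Stand-in for the GraphModule produced during Torch-TensorRT compilation.
new_gm = torch.fx.symbolic_trace(TinyModel())

profiler = cProfile.Profile()
profiler.enable()
FxGraphCachePickler.get_hash(new_gm)  # the call under suspicion
profiler.disable()

# Show the 20 most expensive calls by cumulative time.
pstats.Stats(profiler).sort_stats("cumulative").print_stats(20)
```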
@narendasan `FxGraphCachePickler.get_hash(new_gm)` has only two lines:

```python
serialized_data = cls.dumps(obj)
return sha256_hash(serialized_data)
```

`serialized_data = cls.dumps(obj)` takes the most time. The PyTorch team is working on the issue.
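To confirm that, the two steps can be timed separately. A rough sketch, again assuming `FxGraphCachePickler` is importable from `torch._inductor.codecache`, and using `hashlib.sha256` as a stand-in for the internal `sha256_hash` helper:

```python
import hashlib
import time

import torch
from torch._inductor.codecache import FxGraphCachePickler  # assumed location


class TinyModel(torch.nn.Module):
    def forward(self, x):
        return torch.relu(x) + 1


new_gm = torch.fx.symbolic_trace(TinyModel())  # stand-in for the real graph

t0 = time.perf_counter()
serialized_data = FxGraphCachePickler.dumps(new_gm)   # step 1: pickle the graph
t1 = time.perf_counter()
digest = hashlib.sha256(serialized_data).hexdigest()  # step 2: hash the bytes
t2 = time.perf_counter()

print(f"dumps():  {(t1 - t0) * 1000:.1f} ms")
print(f"sha256(): {(t2 - t1) * 1000:.1f} ms")
```

On a graph as large as Llama2-7b, the first number is expected to dominate.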
## Bug Description
When using the engine cache feature on Llama2-7b, I found that reusing a cached engine is quite slow, even slower than building a non-refittable engine from scratch. The reason is that engine caching uses `FxGraphCachePickler.get_hash(new_gm)` to calculate the hash value, and this call takes up a large portion of the total compile time. For reference, I measured the running time of several parts: there are 5 configs and I ran each one three times; before running each config, the timing cache was removed.
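For illustration, the measurement loop looks roughly like the sketch below; `compile_with_config()` and `TIMING_CACHE_PATH` are hypothetical placeholders for the actual Torch-TensorRT compilation call and timing-cache file used in the experiments:

```python
import os
import time

TIMING_CACHE_PATH = "/tmp/timing_cache.bin"  # hypothetical timing-cache location
configs = [...]  # the 5 configurations measured in the experiments


def compile_with_config(config):
    """Hypothetical placeholder for torch_tensorrt compilation under one config."""
    ...


for config in configs:
    # The timing cache is removed before each config so runs start from a clean state.
    if os.path.exists(TIMING_CACHE_PATH):
        os.remove(TIMING_CACHE_PATH)
    for run in range(3):  # each config is run three times
        start = time.perf_counter()
        compile_with_config(config)
        elapsed_ms = (time.perf_counter() - start) * 1000
        print(f"{config} run {run}: {elapsed_ms:.1f} ms")
```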
For `REFIT w/ engine caching` and `REFIT_IDENTICAL w/ engine caching`, the first run is slower than the other two runs because it spends time building the engine from scratch and saving it to the cache; the other two runs directly reuse the cached engine and refit it with new weights.

At first, we assumed those two runs would be very fast. However, taking the second run of `REFIT w/ engine caching` as an example, it took 410300.1875 ms, while recompiling a non-refittable engine took just 267709.59375 ms.
To break down the time cost of each part, let's take `REFIT w/ engine caching` as an example: the `FxGraphCachePickler.get_hash(new_gm)` call in `get_hash()` took up a very large portion of the total time.