bmaltais / dehydrate


Compression using sparse tensors and thresholding #1

Open AmericanPresidentJimmyCarter opened 1 year ago

AmericanPresidentJimmyCarter commented 1 year ago

Not sure what the original script did since it had a default alpha value of 0.

```python
parser.add_argument("--str", type=float, help="Strength of the rehydration (-0.05..0.05)", default=0, required=False)
alpha = args.str
...
for key in tqdm(theta_0.keys(), desc="Stage 1/2?: merge common keys"):
    if "model" in key and key in theta_1:
        theta_0[key] = theta_0[key] * (1 + alpha) + theta_1[key] * (1 - alpha)
```

That appears to have just replaced the final weights with identical weights, since theta_0[key] = theta_0[key] * 1 + theta_1[key] * 0 = theta_0[key].

Constructing a patch with patch = M2 - M1 and then computing M1 + patch = M2 does not work due to lossy floating-point arithmetic, and it seems to give me corrupted models.
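A dense delta patch, with the convention patch = M2 - M1 so that M1 + patch reconstructs M2, can be sketched as below. This is a minimal sketch on toy tensors rather than real checkpoints: the key name `model.w`, the tensor sizes and the helpers `make_delta_patch` / `apply_delta_patch` are hypothetical. It measures the worst-case round-trip error in float32 and float16, so you can see how much rounding this single subtract/add step introduces on its own.

```python
import torch

# Hypothetical helpers: patch = M2 - M1, so that M1 + patch reconstructs M2 (up to rounding)
def make_delta_patch(base: dict, finetuned: dict) -> dict:
    return {k: finetuned[k] - base[k] for k in base if k in finetuned}

def apply_delta_patch(base: dict, patch: dict) -> dict:
    return {k: (base[k] + patch[k]) if k in patch else base[k] for k in base}

# Toy stand-ins for checkpoint state dicts (hypothetical key name and sizes)
torch.manual_seed(0)
base = {"model.w": torch.randn(64, 64)}
finetuned = {"model.w": base["model.w"] + 1e-3 * torch.randn(64, 64)}

roundtrip_error = {}
for dtype in (torch.float32, torch.float16):
    b = {k: v.to(dtype) for k, v in base.items()}
    f = {k: v.to(dtype) for k, v in finetuned.items()}
    restored = apply_delta_patch(b, make_delta_patch(b, f))
    # Largest absolute difference between restored and finetuned weights
    roundtrip_error[dtype] = (restored["model.w"].float() - f["model.w"].float()).abs().max().item()
    print(dtype, roundtrip_error[dtype])
```

Casting or saving the patch in a lower precision than the models is another place where precision can be lost, so it may be worth measuring that step separately.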

What I did instead was to threshold each tensor's delta based on its own min/max values and then apply only the most-changed values as a sparse patch.
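The thresholding idea can be sketched as follows. The gist is not reproduced here, so the exact rule is an assumption: this hypothetical `threshold_delta` keeps positive deltas at or above `max / strength` and negative deltas at or below `min / strength` (so `strength = 2` would correspond to the min/2 and max/2 cutoffs mentioned later in the thread), zeroes everything else, and stores the result with `Tensor.to_sparse()`.

```python
import torch

def threshold_delta(delta: torch.Tensor, strength: float) -> torch.Tensor:
    """Keep only the most-changed entries of one tensor's delta (assumed min/max rule)."""
    hi, lo = delta.max(), delta.min()
    # Positive deltas at or above max/strength, negative deltas at or below min/strength
    keep = ((delta > 0) & (delta >= hi / strength)) | ((delta < 0) & (delta <= lo / strength))
    return torch.where(keep, delta, torch.zeros_like(delta))

def make_sparse_patch(base: dict, finetuned: dict, strength: float = 2.0) -> dict:
    patch = {}
    for key, w in base.items():
        if "model" in key and key in finetuned:
            # Store only the surviving entries as a sparse COO tensor
            patch[key] = threshold_delta(finetuned[key] - w, strength).to_sparse()
    return patch

# Toy example with a hypothetical key name
base = {"model.w": torch.zeros(2, 3)}
finetuned = {"model.w": torch.tensor([[0.0, 0.5, -0.1], [1.0, -2.0, 0.2]])}
patch = make_sparse_patch(base, finetuned, strength=2.0)
print(patch["model.w"].to_dense())
```

Raising `strength` lowers both cutoffs, so more entries survive and the patch grows, which matches the file-size behaviour described in the next paragraph.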

Here's my code: https://gist.github.com/AmericanPresidentJimmyCarter/1947162f371e601ce183070443f41dc2

This results in a patch of about 60 MB, but the quality of the transfer is not great. You can increase the THRESHOLD_STRENGTH variable, but the file size increases dramatically while the quality does not seem to improve much. Here are the results at THRESHOLD_STRENGTH == 2:

[Image: elon_disney_patch]

The test code is:

```python
_, extra_data = engine.sample(
    'elon musk in modern disney style',
    4,
    'heun',
    679566949,
    25,
    scale=7.,
)
```

This uses the DreamBooth-trained weights from https://huggingface.co/nitrosocke/mo-di-diffusion

AmericanPresidentJimmyCarter commented 1 year ago

For completeness' sake, I tried using SD 1.5 as the source weights, which resulted in a patch of about 140 MB at THRESHOLD_STRENGTH == 2. The images came closer to those from the DreamBooth weights but were still not very good.

[Image: elon_compressed]

bmaltais commented 1 year ago

> Not sure what the original script did since it had a default alpha value of 0.

Essentially it was doing nothing unless you passed a value... but yes, that feature was highly "experimental". I am happy to see you have improved on it. I will play around with the new code and see how it works. I might merge it in.

bmaltais commented 1 year ago

I tried the script and I am not getting the expected results. My expectations:

  1. Running compress creates a patch representing the changes the DreamBooth model made on top of the base model
  2. Applying the patch back onto the base model should give back the DreamBooth model
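These two expectations can be checked on toy tensors. A minimal sketch, assuming the same hypothetical min/max rule as a stand-in for the gist's thresholding: apply a thresholded sparse patch to the base weights and measure how far the result is from the finetuned weights at several strengths. Entries dropped by the threshold keep their base values, so the restored weights only approximate the finetuned ones, and the gap shrinks as more entries are kept. Key names, sizes and helper names are illustrative.

```python
import torch

def threshold_delta(delta: torch.Tensor, strength: float) -> torch.Tensor:
    # Hypothetical min/max rule (stand-in for the gist's thresholding)
    hi, lo = delta.max(), delta.min()
    keep = ((delta > 0) & (delta >= hi / strength)) | ((delta < 0) & (delta <= lo / strength))
    return torch.where(keep, delta, torch.zeros_like(delta))

def apply_sparse_patch(base: dict, patch: dict) -> dict:
    # Densify each sparse patch tensor and add it to the matching base tensor
    return {k: (base[k] + patch[k].to_dense()) if k in patch else base[k] for k in base}

torch.manual_seed(0)
base = {"model.w": torch.randn(128, 128)}
finetuned = {"model.w": base["model.w"] + 0.01 * torch.randn(128, 128)}

errors = []
for strength in (1.5, 2.0, 4.0, 16.0):
    patch = {k: threshold_delta(finetuned[k] - base[k], strength).to_sparse() for k in base}
    restored = apply_sparse_patch(base, patch)
    # L2 distance between the restored and the finetuned weights
    errors.append((restored["model.w"] - finetuned["model.w"]).norm().item())
    print(strength, errors[-1])
```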

What I have tried:

Create the patch: `python .\compress.py -m D:\models\v1-5-pruned.ckpt -m2 D:\models\sks_man-1e-6-3000-sd15.ckpt -o patch.file compress`

Restore the DreamBooth model from the patch: `python .\compress.py -m D:\models\v1-5-pruned.ckpt -p patch.file inflate -o patch.ckpt`

Result:

The same image as if the patch had not been applied to the base model.

SD15: [Image: 04508-0-sks man]

patch.ckpt: [Image: 04509-0-sks man]

Am I missing something?

bmaltais commented 1 year ago

OK... I found the hardcoded value. Your script appears to do the same as mine when applying a --loss such that only part of the model is extracted... We just used different methods to get there. I used the torch.where() function.
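A `torch.where()`-based cutoff of the kind described here might look like the sketch below. The semantics of `loss` (one absolute cutoff applied to every tensor) and the helper name `where_patch` are assumptions for illustration, not the actual `--loss` implementation.

```python
import torch

def where_patch(delta: torch.Tensor, loss: float) -> torch.Tensor:
    # Hypothetical: keep entries whose absolute change exceeds one fixed cutoff
    return torch.where(delta.abs() > loss, delta, torch.zeros_like(delta))

delta = torch.tensor([0.001, -0.05, 0.2])
print(where_patch(delta, 0.01))
```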

AmericanPresidentJimmyCarter commented 1 year ago

You will be unable to get a 1:1 restoration of the original finetuned weights with a patch of just 62 MB. The purpose of thresholding and sparsifying is to produce very highly compressed patches that, when applied, approximate the end result.

AmericanPresidentJimmyCarter commented 1 year ago

It also looks like your loss function does not take the min/max deltas of each tensor into account, which helps scale the thresholding on a per-tensor basis.
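The difference between one global cutoff and per-tensor scaling shows up when tensors change by very different amounts. In this toy comparison (hypothetical tensors and helper names; the per-tensor rule is the same assumed min/max rule), a global cutoff of 1e-2 drops every entry of the small-delta tensor, while per-tensor scaling still keeps its most-changed entries.

```python
import torch

def global_cutoff(delta: torch.Tensor, loss: float) -> torch.Tensor:
    # One absolute cutoff shared by every tensor
    return torch.where(delta.abs() > loss, delta, torch.zeros_like(delta))

def per_tensor_cutoff(delta: torch.Tensor, strength: float) -> torch.Tensor:
    # Cutoffs scaled by this tensor's own max/min delta (assumed rule)
    hi, lo = delta.max(), delta.min()
    keep = ((delta > 0) & (delta >= hi / strength)) | ((delta < 0) & (delta <= lo / strength))
    return torch.where(keep, delta, torch.zeros_like(delta))

deltas = {
    "model.big": torch.tensor([0.20, -0.30, 0.005, 0.05]),
    "model.small": torch.tensor([2e-4, -3e-4, 1e-5, 5e-5]),
}
counts = {}
for name, d in deltas.items():
    kept_global = int((global_cutoff(d, 1e-2) != 0).sum())
    kept_per_tensor = int((per_tensor_cutoff(d, 2.0) != 0).sum())
    counts[name] = (kept_global, kept_per_tensor)
    print(name, counts[name])  # model.big (3, 2), then model.small (0, 2)
```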

AmericanPresidentJimmyCarter commented 1 year ago

Thinking about it, min/2 and max/2 are probably not the right cutoffs either; it might be better to use something like the median. It's hard to know which cutoff can approximate the original model, or whether another method such as https://github.com/samuela/git-re-basin should be used instead.
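A median-based cutoff is one way to try the alternative suggested here. The hypothetical `median_cutoff` below keeps entries whose absolute change exceeds `k` times the median absolute change of that tensor; `k` is an illustrative knob, not something from the thread.

```python
import torch

def median_cutoff(delta: torch.Tensor, k: float = 2.0) -> torch.Tensor:
    # Hypothetical: cutoff = k * median(|delta|), computed per tensor
    cutoff = k * delta.abs().median()
    return torch.where(delta.abs() > cutoff, delta, torch.zeros_like(delta))

delta = torch.tensor([0.1, -0.2, 0.3, -4.0, 5.0])
print(median_cutoff(delta, k=2.0))
```

Note that `torch.median` returns the lower of the two middle values when the number of elements is even.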

The point, more or less, was to figure out which portions of the finetuning contribute most to the weights and then create a sparse tensor patch from those, which is much smaller than a normal (dense) tensor.

bmaltais commented 1 year ago

> The point was more or less -- try to figure out which portions of the finetuning contribute most to the weights and then create a sparse tensor patch based on that, which is much smaller than a normal tensor.

I see. This is interesting... Let me know if you find a way to make it work reliably. For some reason I think each DreamBooth model will require a different threshold to create a proper patch. During my testing, thresholds of 2 or 3 never created a patch that represented the model in a meaningful way. Going above 3 started to produce results, but with a much bigger patch. Eventually the patch was 10x bigger than the ckpt itself.

bmaltais commented 1 year ago

Is it possible that sparse tensors are less storage-efficient than regular tensors? Saving patches can result in files larger than the original ckpt... This should not be possible... I am wondering whether the sparse tensors are actually less compact than the original dense tensors and result in bigger-than-needed files.

AmericanPresidentJimmyCarter commented 1 year ago

> Is it possible that sparse tensors are less storage efficient than regular tensors?

Sparse tensors will be less efficient if the patch does not contain a large number of zeroed values. A sparse tensor is a compressed representation that stores only the non-zero values together with their coordinates.
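The storage cost can be estimated directly. A PyTorch COO sparse tensor keeps one int64 index per dimension for every non-zero entry, plus the value itself, so each stored entry costs far more than one dense element. This sketch counts in-memory bytes only (serialization overhead from `torch.save` is ignored); the helper names are hypothetical.

```python
import torch

def dense_bytes(t: torch.Tensor) -> int:
    return t.numel() * t.element_size()

def sparse_coo_bytes(t: torch.Tensor) -> int:
    # COO layout: an int64 index per dimension per non-zero, plus the value itself
    s = t.to_sparse().coalesce()
    idx, vals = s.indices(), s.values()
    return idx.numel() * idx.element_size() + vals.numel() * vals.element_size()

mostly_zero = torch.zeros(100, 100)
mostly_zero[0, :10] = 1.0  # 10 of 10,000 entries are non-zero
fully_dense = torch.ones(100, 100)

print(dense_bytes(mostly_zero), sparse_coo_bytes(mostly_zero))  # 40000 200
print(dense_bytes(fully_dense), sparse_coo_bytes(fully_dense))  # 40000 200000
```

For a 2-D float32 tensor each stored entry costs 2 × 8 + 4 = 20 bytes versus 4 bytes dense, so the sparse form only wins below roughly 20% density (about 11% for float16, where it is 18 bytes versus 2). A patch that keeps more entries than that can end up larger than the dense weights.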

AmericanPresidentJimmyCarter commented 1 year ago

Birch-san says that the appropriate compression probably depends on each model's activation function:

> With regard to "what change in the weights is significant enough that it merits keeping": doesn't it depend on the activation function? Maybe you'd want to compute a mapping function that maps the old activation curve to the new one, and threshold based on whether the mapping changes it significantly.
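One possible reading of this suggestion, sketched for a single linear layer: instead of thresholding the raw weight deltas, run a batch of probe inputs through the activation function with the base and the finetuned weights, and keep only the output rows whose activations change noticeably. The probe inputs, the SiLU activation, the `tol` value and the helper names are all assumptions for illustration; this is not Birch-san's method.

```python
import torch
import torch.nn.functional as F

def activation_change_per_row(w_base, w_ft, probe, act=F.silu):
    """Mean absolute change of each output unit's activation over the probe batch."""
    out_base = act(probe @ w_base.T)  # (batch, out_features)
    out_ft = act(probe @ w_ft.T)
    return (out_ft - out_base).abs().mean(dim=0)  # (out_features,)

def activation_aware_patch(w_base, w_ft, probe, act=F.silu, tol=1e-3):
    delta = w_ft - w_base
    keep_rows = activation_change_per_row(w_base, w_ft, probe, act) > tol
    # Zero whole rows whose effect on the activations stays below tol
    return delta * keep_rows.unsqueeze(1).to(delta.dtype)

torch.manual_seed(0)
w_base = torch.randn(4, 3)   # hypothetical (out_features, in_features) weight
w_ft = w_base.clone()
w_ft[1] += 1.0               # only output unit 1 is "finetuned"
probe = torch.randn(16, 3)   # hypothetical probe inputs
patch = activation_aware_patch(w_base, w_ft, probe)
print(patch)
```

Convolution weights would presumably need reshaping to (out_channels, -1) first, and realistic probe inputs would have to come from the model's actual intermediate activations.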