logix-project / logix

AI Logging for Interpretability and Explainability🔬
Apache License 2.0

Analytically compute LoRA Hessian #53

Closed hwijeen closed 10 months ago

hwijeen commented 10 months ago

Previously, Analog spent one full epoch collecting "LoRA covariances" (covariances of the compressed activations). This PR computes the same quantity analytically, using the projection matrix.
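A minimal sketch of why the analytic shortcut works (names and shapes here are illustrative, not logix's actual API): if `P` is the fixed LoRA projection matrix and we already have the covariance of the raw activations, the covariance of the projected activations `P @ a` is just `P @ cov @ P.T`, so no extra data pass is needed.

```python
import torch

d, r = 8, 2                      # hidden dim, LoRA rank (hypothetical sizes)
torch.manual_seed(0)
P = torch.randn(r, d) / d**0.5   # fixed projection matrix, rank r
acts = torch.randn(100, d)       # a batch of raw activations

# Empirical route: project every activation, then accumulate the covariance
# (this is what costs an extra epoch on real data).
proj = acts @ P.T                       # (100, r)
cov_empirical = proj.T @ proj / len(acts)

# Analytic route: reuse the raw covariance that is already available.
cov_raw = acts.T @ acts / len(acts)     # (d, d)
cov_analytic = P @ cov_raw @ P.T        # (r, r)

# Both routes give the same matrix, up to floating-point error.
assert torch.allclose(cov_empirical, cov_analytic, atol=1e-5)
```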

This includes some hacky patches; we may need to reconsider how we handle finalize, clear, ...

I confirmed that the analytical approach produces the same covariances and the same influence scores.

```
Loaded scores from files/analog/test/if_analog.pt
Loaded scores from files/analog/test_analytical/if_analog.pt
scores: torch.Size([2, 2])
scores2: torch.Size([2, 2])
Average correlation: 1.0
```
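For reference, here is one way the "Average correlation" line above could be computed (the actual comparison script is not shown in this thread, so the helper name and the row-wise Pearson convention are assumptions):

```python
import torch

def avg_correlation(a: torch.Tensor, b: torch.Tensor) -> float:
    """Average per-row Pearson correlation between two score matrices."""
    corrs = []
    for x, y in zip(a, b):
        # corrcoef of the 2-row stack; the off-diagonal entry is Pearson r.
        corrs.append(torch.corrcoef(torch.stack([x, y]))[0, 1])
    return torch.stack(corrs).mean().item()

scores = torch.tensor([[1.0, 2.0], [3.0, 5.0]])
scores2 = scores.clone()   # identical scores, as in the PR's check
print(avg_correlation(scores, scores2))  # -> 1.0 for identical inputs
```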
```
[ins] In [2]: before = torch.load("test/state/hessian_state.pt")

[ins] In [3]: after = torch.load("test_analytical/state/hessian_state.pt")

[ins] In [4]: before["model.bert.pooler.dense.analog_lora_B"]["forward"]
Out[4]:
tensor([[ 0.0570,  0.0020, -0.0055,  ..., -0.0119,  0.0037,  0.0246],
        [ 0.0020,  0.0077, -0.0225,  ..., -0.0108,  0.0262, -0.0005],
        [-0.0055, -0.0225,  0.0653,  ...,  0.0312, -0.0761,  0.0016],
        ...,
        [-0.0119, -0.0108,  0.0312,  ...,  0.0165, -0.0360, -0.0033],
        [ 0.0037,  0.0262, -0.0761,  ..., -0.0360,  0.0890, -0.0030],
        [ 0.0246, -0.0005,  0.0016,  ..., -0.0033, -0.0030,  0.0108]])

[ins] In [6]: after["model.bert.pooler.dense.analog_lora_B"]["forward"]
Out[6]:
tensor([[ 0.0570,  0.0020, -0.0055,  ..., -0.0119,  0.0037,  0.0246],
        [ 0.0020,  0.0077, -0.0225,  ..., -0.0108,  0.0262, -0.0005],
        [-0.0055, -0.0225,  0.0653,  ...,  0.0312, -0.0761,  0.0016],
        ...,
        [-0.0119, -0.0108,  0.0312,  ...,  0.0165, -0.0360, -0.0033],
        [ 0.0037,  0.0262, -0.0761,  ..., -0.0360,  0.0890, -0.0030],
        [ 0.0246, -0.0005,  0.0016,  ..., -0.0033, -0.0030,  0.0108]])
```
hage1005 commented 10 months ago

Just a random thought: in a similar way, could LoRA's EKFAC lambda also be computed analytically if we have the original EKFAC lambda?

Seems not; it's not useful anyway.

hwijeen commented 10 months ago

@sangkeun00 IIRC, you mentioned that this PR needs to be readdressed once you finish the refactoring / LoRAEmbedding. Can you confirm that this is still the case? If so, I will close this PR and open another PR for the small fixes (the issue with the in-place operation, ...)

hwijeen commented 10 months ago

closing as outdated