szxiangjn / world-model-for-language-model


Information regarding fisher matrix #4

Closed thakral-kartik closed 8 months ago

thakral-kartik commented 8 months ago

Hi, I really liked your work! Could you please confirm how the Fisher matrix was calculated? Did you train LoRA adapters on the pretraining data and then use the weights of those adapters for the final objective? Basically, I am confused about whether the Fisher matrix was calculated for the LoRA adapters or for the original model. Thanks.

szxiangjn commented 8 months ago

Hi,

Thanks for your interest in our work! In this work, the Fisher matrix is calculated for the original model. We freeze the model and calculate its gradients to get the Fisher matrix.
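To illustrate what "freeze the model and calculate its gradients" can look like, here is a minimal sketch of an empirical diagonal Fisher estimate. It is not the code from this repository: it assumes a toy logistic-regression model in NumPy instead of a language model, and the function name `diagonal_fisher` and the random data are hypothetical. The weights are never updated; only the squared per-example gradients of the log-likelihood are averaged.

```python
import numpy as np

def diagonal_fisher(w, X, y):
    """Empirical diagonal Fisher for a toy logistic-regression model.

    The weights `w` stay frozen: we only evaluate the per-example gradient
    of the log-likelihood and average its elementwise square.
    """
    p = 1.0 / (1.0 + np.exp(-(X @ w)))        # predicted P(y=1 | x), shape (n,)
    per_example_grads = (y - p)[:, None] * X  # d log p(y|x) / d w, shape (n, d)
    return np.mean(per_example_grads ** 2, axis=0)

# Hypothetical toy data, for illustration only.
rng = np.random.default_rng(0)
X = rng.normal(size=(32, 5))
y = (rng.random(32) < 0.5).astype(float)
w = rng.normal(size=5)

fisher = diagonal_fisher(w, X, y)  # one non-negative value per weight
```

The result has one entry per weight of the frozen model, which is why it shares the shape of the original weights rather than the shape of any adapter.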

thakral-kartik commented 8 months ago

Hi, Thanks for your reply.

I still don't follow the EWC regularization update rule. Could you please clarify the following, if possible? If the Fisher matrix is calculated for the original model, how can the LoRA matrices A and B be updated using it during fine-tuning? The dimensions of the LoRA matrices and of the Fisher matrix for a given layer would be different (https://github.com/szxiangjn/world-model-for-language-model/blob/main/run.py#L71).

Thanks in advance for your support!

szxiangjn commented 8 months ago

As you can see at (https://github.com/szxiangjn/world-model-for-language-model/blob/main/run.py#L66), we multiply A by B to get a matrix with the same dimensions as the model weights.
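The shape argument can be sketched as follows. This is a minimal NumPy illustration, not the code in `run.py`: it assumes the common LoRA convention where A has shape (rank, in_dim) and B has shape (out_dim, rank), omits any LoRA scaling factor, and uses a standard EWC-style quadratic penalty. The names `ewc_penalty_lora` and `lam` are hypothetical.

```python
import numpy as np

def ewc_penalty_lora(fisher, A, B, lam=1.0):
    """EWC-style penalty on the weight change implied by a LoRA update.

    The product B @ A has the same shape as the original weight matrix,
    so it can be weighted elementwise by the Fisher values of that layer.
    """
    delta_w = B @ A  # shape (out_dim, in_dim), same as the frozen weight
    assert delta_w.shape == fisher.shape
    return 0.5 * lam * np.sum(fisher * delta_w ** 2)

# Hypothetical sizes, for illustration only.
out_dim, in_dim, rank = 6, 4, 2
rng = np.random.default_rng(0)
fisher = rng.random((out_dim, in_dim))  # stand-in for a layer's Fisher values

A = rng.normal(size=(rank, in_dim))
B = np.zeros((out_dim, rank))           # B starts at zero in the usual LoRA setup
penalty_at_init = ewc_penalty_lora(fisher, A, B)

B = rng.normal(size=(out_dim, rank))    # pretend some training has happened
penalty_after = ewc_penalty_lora(fisher, A, B)
```

Because the penalty is a function of the product, gradients flow back to both A and B even though the Fisher values were computed for the full-size weight.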

thakral-kartik commented 8 months ago

Got it, thanks!!!