gau-nernst opened this issue 1 year ago.
Good catch! I had Conv2d in mind when I first wrote it.
Looks like we just need to instantiate lora_A and lora_B differently depending on the kind of convolution.
Happy to review and merge it if someone wants to implement and test it. Otherwise, I'll do it in the near future.
I have an idea, but it would change how the current Conv2d LoRA works. We can treat convolution as a matmul whose input is a flattened "window". For example, for Conv2d the input is a window with (kernel_size, kernel_size) spatial dimensions, and the flattened input dimension is in_channels * kernel_size * kernel_size. This extends naturally to Conv1d and Conv3d.
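The "convolution as a matmul over flattened windows" view can be checked numerically. The sketch below is a minimal NumPy illustration, assuming stride 1, no padding and no batch dimension; the helpers `im2col`, `conv2d_as_matmul` and `conv2d_direct` are hypothetical names, not part of loralib. It builds a LoRA delta with lora_A of shape (rank, in_channels * kh * kw) and lora_B of shape (out_channels, rank), using a non-square (3, 5) kernel, and confirms that the matmul form agrees with a direct loop-based convolution.

```python
import numpy as np


def im2col(x, kh, kw):
    """Flatten every (kh, kw) window of x, shape (C, H, W), into one column.

    Returns shape (C * kh * kw, out_h * out_w). The (C, kh, kw) ordering
    matches weight.reshape(out_channels, C * kh * kw). Stride 1, no padding.
    """
    c, h, w = x.shape
    out_h, out_w = h - kh + 1, w - kw + 1
    cols = np.empty((c, kh, kw, out_h, out_w), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            cols[:, i, j] = x[:, i:i + out_h, j:j + out_w]
    return cols.reshape(c * kh * kw, out_h * out_w)


def conv2d_as_matmul(x, weight):
    """Conv2d (cross-correlation, as in PyTorch) computed as a single matmul."""
    out_c, in_c, kh, kw = weight.shape
    out_h, out_w = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    flat_w = weight.reshape(out_c, in_c * kh * kw)
    return (flat_w @ im2col(x, kh, kw)).reshape(out_c, out_h, out_w)


def conv2d_direct(x, weight):
    """Reference Conv2d with explicit loops, used only for checking."""
    out_c, _, kh, kw = weight.shape
    out_h, out_w = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    out = np.zeros((out_c, out_h, out_w))
    for o in range(out_c):
        for p in range(out_h):
            for q in range(out_w):
                out[o, p, q] = np.sum(weight[o] * x[:, p:p + kh, q:q + kw])
    return out


# Hypothetical LoRA factors in the flattened-window view, with a non-square kernel.
rng = np.random.default_rng(0)
in_c, out_c, rank, kh, kw = 3, 4, 2, 3, 5
x = rng.standard_normal((in_c, 8, 9))
lora_A = rng.standard_normal((rank, in_c * kh * kw))  # (rank, in_channels * kh * kw)
lora_B = rng.standard_normal((out_c, rank))           # (out_channels, rank)
delta_w = (lora_B @ lora_A).reshape(out_c, in_c, kh, kw)

print(np.allclose(conv2d_as_matmul(x, delta_w), conv2d_direct(x, delta_w)))  # True
```

Nothing in the flattening requires kh == kw, and the same reshape works for one or three spatial dimensions.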
There are two benefits to this implementation: (1) the kernel size doesn't need to be the same in all spatial dimensions, and (2) we can run the LoRA branch as convolutions in the forward pass instead of merging weights, similar to the Linear implementation (relevant issue: #54). The first convolution (with lora_A) is a normal convolution with the same kernel size, while the second convolution (with lora_B) is a point-wise (aka 1x1) convolution. I haven't tested it, but from what I understand, it should work.
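The claim that the unmerged two-convolution forward pass matches the merged weight follows from associativity: B @ (A @ cols) = (B @ A) @ cols. A minimal NumPy sketch, assuming stride 1, no padding and no groups, with hypothetical helper names: lora_A is reshaped into a regular (rank, in_channels, kh, kw) kernel, and lora_B of shape (out_channels, rank) is applied as a point-wise convolution via `np.einsum`.

```python
import numpy as np


def conv2d(x, weight):
    """Conv2d via im2col (stride 1, no padding). x: (C, H, W), weight: (O, C, kh, kw)."""
    out_c, in_c, kh, kw = weight.shape
    out_h, out_w = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    cols = np.empty((in_c, kh, kw, out_h, out_w), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            cols[:, i, j] = x[:, i:i + out_h, j:j + out_w]
    flat = cols.reshape(in_c * kh * kw, out_h * out_w)
    return (weight.reshape(out_c, -1) @ flat).reshape(out_c, out_h, out_w)


def lora_branch_two_convs(x, lora_A, lora_B, kh, kw):
    """Unmerged LoRA branch: normal conv with lora_A, then 1x1 conv with lora_B."""
    rank, in_c = lora_A.shape[0], x.shape[0]
    hidden = conv2d(x, lora_A.reshape(rank, in_c, kh, kw))  # (rank, H', W')
    return np.einsum("or,rhw->ohw", lora_B, hidden)         # point-wise (1x1) conv


def lora_branch_merged(x, lora_A, lora_B, kh, kw):
    """Merged LoRA branch: one conv with (lora_B @ lora_A) reshaped into a kernel."""
    out_c, in_c = lora_B.shape[0], x.shape[0]
    return conv2d(x, (lora_B @ lora_A).reshape(out_c, in_c, kh, kw))


rng = np.random.default_rng(1)
in_c, out_c, rank, kh, kw = 3, 5, 2, 3, 3
x = rng.standard_normal((in_c, 7, 7))
lora_A = rng.standard_normal((rank, in_c * kh * kw))
lora_B = rng.standard_normal((out_c, rank))
two_convs = lora_branch_two_convs(x, lora_A, lora_B, kh, kw)
merged = lora_branch_merged(x, lora_A, lora_B, kh, kw)
print(np.allclose(two_convs, merged))  # True
```

The two-convolution path never materialises the full out_channels x in_channels x kh x kw delta, which is the same motivation as running the Linear LoRA branch unmerged.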
The situation becomes slightly more complicated when grouped convolution is involved (groups > 1). I'm thinking of accounting for groups in the input channels of lora_A, so lora_A becomes (rank, in_channels / groups * kernel_size * kernel_size). We can still implement the forward pass of the LoRA branch as two convolutions with lora_A and lora_B, using grouped convolution for lora_A, similar to the original convolution branch. A problem might arise when we try to merge weights, though. Due to how grouped convolution works, I think the merged weights might not be lora_B @ lora_A (I will need to test this). If that's the case, we need a different calculation to merge weights.
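One way to see why a plain lora_B @ lora_A may not work: if lora_B were an ordinary (ungrouped) 1x1 convolution, every output channel would mix hidden channels that come from different input groups, which a grouped weight of shape (out_channels, in_channels / groups, kh, kw) cannot represent. The sketch below assumes instead that lora_B is also applied as a grouped 1x1 convolution with shape (out_channels, rank / groups), and that rank is divisible by groups; both are assumptions, not something settled in this thread. Under that assumption the merge becomes block-wise, one lora_B_g @ lora_A_g product per group (the helper `merge_grouped_lora` is hypothetical). Note that the full lora_B @ lora_A is not even shape-compatible here.

```python
import numpy as np


def conv2d(x, weight):
    """Conv2d via im2col (stride 1, no padding). x: (C, H, W), weight: (O, C, kh, kw)."""
    out_c, in_c, kh, kw = weight.shape
    out_h, out_w = x.shape[1] - kh + 1, x.shape[2] - kw + 1
    cols = np.empty((in_c, kh, kw, out_h, out_w), dtype=x.dtype)
    for i in range(kh):
        for j in range(kw):
            cols[:, i, j] = x[:, i:i + out_h, j:j + out_w]
    flat = cols.reshape(in_c * kh * kw, out_h * out_w)
    return (weight.reshape(out_c, -1) @ flat).reshape(out_c, out_h, out_w)


def grouped_conv2d(x, weight, groups):
    """Grouped Conv2d: group g reads input-channel slice g and writes output-channel slice g."""
    out_c, in_per_group = weight.shape[:2]
    out_per_group = out_c // groups
    outs = [
        conv2d(x[g * in_per_group:(g + 1) * in_per_group],
               weight[g * out_per_group:(g + 1) * out_per_group])
        for g in range(groups)
    ]
    return np.concatenate(outs, axis=0)


def merge_grouped_lora(lora_A, lora_B, groups, kh, kw):
    """Hypothetical block-wise merge: group g contributes lora_B_g @ lora_A_g."""
    rank, out_c = lora_A.shape[0], lora_B.shape[0]
    r_g, o_g = rank // groups, out_c // groups
    blocks = [lora_B[g * o_g:(g + 1) * o_g] @ lora_A[g * r_g:(g + 1) * r_g]
              for g in range(groups)]
    in_per_group = lora_A.shape[1] // (kh * kw)
    return np.concatenate(blocks, axis=0).reshape(out_c, in_per_group, kh, kw)


rng = np.random.default_rng(2)
in_c, out_c, rank, groups, kh, kw = 4, 6, 2, 2, 3, 3
x = rng.standard_normal((in_c, 6, 6))
# lora_A: (rank, in_channels / groups * kh * kw), applied as a grouped conv.
lora_A = rng.standard_normal((rank, in_c // groups * kh * kw))
# Assumption: lora_B is a grouped 1x1 conv with shape (out_channels, rank / groups).
lora_B = rng.standard_normal((out_c, rank // groups))

hidden = grouped_conv2d(x, lora_A.reshape(rank, in_c // groups, kh, kw), groups)
two_convs = grouped_conv2d(hidden, lora_B.reshape(out_c, rank // groups, 1, 1), groups)
merged = grouped_conv2d(x, merge_grouped_lora(lora_A, lora_B, groups, kh, kw), groups)
print(np.allclose(two_convs, merged))  # True
```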
Another way to handle groups > 1 is to follow your current implementation, which puts groups in the output of lora_B (out_channels / groups, rank). This sacrifices the ability to use convolution for the LoRA branch's forward pass, but keeps the ability to merge weights with a simple lora_B @ lora_A.
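A small NumPy sketch of this trade-off, using the simplified lora_B shape (out_channels / groups, rank) from the comment and assuming lora_A is (rank, in_channels * kernel_size * kernel_size); `merge_by_reshape` is a hypothetical helper. The product lora_B @ lora_A has exactly as many entries as the grouped weight (out_channels, in_channels / groups, k, k), so merging is just a reshape. However, each row of the product spills across several output channels after the reshape, so the factors no longer correspond to a "conv, then 1x1 conv" pipeline.

```python
import numpy as np


def merge_by_reshape(lora_A, lora_B, weight_shape):
    """Merge LoRA factors into a conv weight by reshaping lora_B @ lora_A."""
    delta = lora_B @ lora_A
    needed = int(np.prod(weight_shape))
    if delta.size != needed:
        raise ValueError(f"B @ A has {delta.size} entries, the weight needs {needed}")
    return delta.reshape(weight_shape)


in_c, out_c, rank, groups, k = 4, 6, 2, 2, 3
rng = np.random.default_rng(3)
lora_A = rng.standard_normal((rank, in_c * k * k))     # (rank, in_channels * k * k)
lora_B = rng.standard_normal((out_c // groups, rank))  # (out_channels / groups, rank)
weight_shape = (out_c, in_c // groups, k, k)           # grouped Conv2d weight
delta_w = merge_by_reshape(lora_A, lora_B, weight_shape)

# Row 0 of B @ A (in_channels * k * k = 36 values) fills output channels 0 and 1
# (18 values each), so rows of lora_B do not map to individual output channels.
row0 = (lora_B @ lora_A)[0]
print(np.allclose(delta_w[:2].ravel(), row0))  # True
```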
Let me know what you think @edwardjhu. Thank you!
Hello, I also encountered the same problem. Is this improvement feasible in your subsequent experiments? @gau-nernst @edwardjhu
I have changed the initialization of the lora_B parameters so that the new implementation works beyond the 2d case (Pull Request #157). I have tested it, and it works for 1d to 3d. Also, the LoRA parameters' shapes are the same as before in the 2d case. I didn't test the grouped case. Please let me know if the grouped case needs to be fixed. @gau-nernst @edwardjhu
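For intuition about how only lora_B's shape needs to change, here is one hypothetical shape rule; it is not necessarily what PR #157 does. It assumes lora_A stays at (rank * k, in_channels * k), which together with a 2d lora_B of (out_channels // groups * k, rank * k) gives the 2d B @ A shape (out_channels // groups * kernel_size, in_channels * kernel_size). Scaling lora_B's rows by k ** (dims - 1) makes B @ A have as many entries as the weight (out_channels, in_channels // groups, k, ..., k) for any number of spatial dims, while leaving the 2d shapes unchanged.

```python
from math import prod


def generalized_lora_shapes(in_channels, out_channels, kernel_size, rank, dims, groups=1):
    """Hypothetical shape rule for a ConvLoRA-style layer with `dims` spatial dims.

    Keeps lora_A at (rank * k, in_channels * k) and scales lora_B's rows by
    k ** (dims - 1) so that lora_B @ lora_A has as many entries as the weight.
    """
    k = kernel_size
    a_shape = (rank * k, in_channels * k)
    b_shape = (out_channels // groups * k ** (dims - 1), rank * k)
    weight_shape = (out_channels, in_channels // groups) + (k,) * dims
    return a_shape, b_shape, weight_shape


for dims in (1, 2, 3):
    a, b, w = generalized_lora_shapes(in_channels=4, out_channels=8, kernel_size=3,
                                      rank=2, dims=dims)
    product_entries = b[0] * a[1]  # lora_B @ lora_A has shape (b[0], a[1])
    print(dims, b, product_entries == prod(w))
```

With these arguments the check holds for Conv1d, Conv2d and Conv3d, and the dims=2 case reproduces the existing 2d shapes.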
Class ConvLoRA currently only works for Conv2d. By inspecting the shape of B @ A, which is (out_channels // groups * kernel_size, in_channels * kernel_size), we can see that it is only compatible with Conv2d. For reference, weight shape for