Currently there is no way to utilise the loralib library with this codebase, since loralib lacks support for model parallelism.
I would like to implement model-parallel versions of the loralib components needed to swap the adapter's attention layer for its LoRA equivalent.
The particular thing that we care about is the type of attn._modules["query_key_value"] in the adapter constructor, before we wrap it in the adapter wrapper.
In the previous codebase this had type nn.Linear, which we can naively swap out for loralib's MergedLinear with no issues:

    attn._modules["query_key_value"] = loralib.layers.MergedLinear(
        attn._modules["query_key_value"].in_features,
        attn._modules["query_key_value"].out_features,
        r=8,
        enable_lora=[True, False, True],
    )
In the current codebase, due to model parallelism, the type of attn._modules["query_key_value"] is ColumnParallelLinear, which doesn't even extend nn.Linear (it extends Module). Therefore, when we naively swap it out for a MergedLinear, code that expects a ColumnParallelLinear breaks.
We should make an analogue of MergedLinear that is to the original MergedLinear as ColumnParallelLinear is to nn.Linear; that should work as a drop-in substitute.
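Whatever shape the parallel analogue takes, it has to reproduce the computation loralib's MergedLinear performs on the merged QKV projection: a base linear map plus a low-rank update scattered only into the groups enabled by enable_lora. A minimal numpy sketch of that computation under stated assumptions (the group layout, the names d_model and r, and the stacking of the LoRA factors here are illustrative, not loralib's actual internals):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, r = 16, 4
enable_lora = [True, False, True]          # LoRA on Q and V, not on K
n_groups = len(enable_lora)
n_enabled = sum(enable_lora)

# Base merged QKV weight, (3 * d_model, d_model), as in query_key_value.
W = rng.standard_normal((n_groups * d_model, d_model))

# Low-rank factors for the enabled groups only. B is zero-initialised,
# so the layer starts out identical to the base projection.
lora_A = rng.standard_normal((n_enabled * r, d_model))
lora_B = np.zeros((n_enabled * d_model, r))

def forward(x):
    """y = x W^T plus the per-group low-rank update (x A_g^T) B_g^T."""
    base = x @ W.T
    delta = np.zeros_like(base)
    for i, g in enumerate(g for g, on in enumerate(enable_lora) if on):
        A_g = lora_A[i * r:(i + 1) * r]               # (r, d_model)
        B_g = lora_B[i * d_model:(i + 1) * d_model]   # (d_model, r)
        delta[:, g * d_model:(g + 1) * d_model] = x @ A_g.T @ B_g.T
    return base + delta

x = rng.standard_normal((2, d_model))
y0 = forward(x)
assert np.allclose(y0, x @ W.T)            # zero-init B: matches base layer

lora_B[:] = rng.standard_normal(lora_B.shape)
y1 = forward(x)
k = slice(1 * d_model, 2 * d_model)        # K group has LoRA disabled
assert np.allclose(y1[:, k], y0[:, k])     # K unchanged; Q and V shift
```

In the column-parallel setting, W, lora_B, and the output groups are sharded along the output dimension across ranks, while lora_A (like the input) would be replicated, which is why a subclass of ColumnParallelLinear rather than of nn.Linear is the natural home for this logic.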