CERC-AAI / multimodal

An implementation of model parallel autoregressive transformers on GPUs, based on the DeepSpeed library.
Apache License 2.0

Model-parallel version of loralib components to use with this codebase #21

Closed · Rabbidon closed this 1 year ago

Rabbidon commented 1 year ago

Currently there is no way to utilise the loralib library with this codebase, since loralib lacks support for model parallelism.

I would like to implement model-parallel versions of the loralib components needed to swap out the attention layer in the adapter for the corresponding LoRA version.

The particular thing we care about is the type of attn._modules["query_key_value"] in the adapter constructor, before we wrap it in the adapter wrapper. In the previous codebase this had type nn.Linear, which we can naively swap out for loralib's MergedLinear with no issues: attn._modules["query_key_value"] = loralib.layers.MergedLinear(attn._modules["query_key_value"].in_features, attn._modules["query_key_value"].out_features, r=8, enable_lora=[True, False, True])
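
Spelled out as a runnable sketch (the attn handle and the fused query/key/value layout come from the snippet above; the helper name swap_in_lora is just for illustration):

```python
import loralib
import torch.nn as nn

def swap_in_lora(attn: nn.Module) -> None:
    # In the previous codebase the fused query/key/value projection was a
    # plain nn.Linear, so loralib's MergedLinear is a drop-in replacement.
    qkv: nn.Linear = attn._modules["query_key_value"]
    attn._modules["query_key_value"] = loralib.layers.MergedLinear(
        qkv.in_features,
        qkv.out_features,
        r=8,
        # Apply LoRA to the query and value chunks of the fused projection,
        # but leave the key chunk untouched.
        enable_lora=[True, False, True],
    )
```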

In the current codebase, due to model parallelism, attn._modules["query_key_value"] has type ColumnParallelLinear, which does not even extend nn.Linear (it extends Module). Therefore, when we naively swap it out for a MergedLinear, code that expects a ColumnParallelLinear breaks.

We should make an analogue of MergedLinear that stands in the same relation to the original MergedLinear as ColumnParallelLinear does to nn.Linear. This should work as a drop-in substitute; a sketch of what such a layer could look like follows below.
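
A minimal sketch of one way to build it, assuming the usual Megatron-style convention that ColumnParallelLinear shards its weight along the output dimension. The class name ColumnParallelMergedLinear and the wrap-the-base-module approach are illustrative only (a faithful analogue would subclass ColumnParallelLinear so isinstance checks still pass), and the per-chunk enable_lora masking that MergedLinear performs is omitted for brevity:

```python
import math
import torch
import torch.nn as nn

class ColumnParallelMergedLinear(nn.Module):
    """LoRA low-rank update on top of a column-parallel linear layer.

    Because ColumnParallelLinear shards its weight along the output
    dimension, the LoRA B factor (out x r) is sharded the same way on
    each tensor-parallel rank, while the LoRA A factor (r x in) is
    replicated across ranks, so no extra communication is needed.
    """

    def __init__(self, base: nn.Module, r: int = 8, lora_alpha: int = 16):
        super().__init__()
        self.base = base  # the existing ColumnParallelLinear, kept frozen
        out_per_partition, in_features = base.weight.shape
        self.lora_A = nn.Parameter(torch.zeros(r, in_features))
        self.lora_B = nn.Parameter(torch.zeros(out_per_partition, r))
        self.scaling = lora_alpha / r
        nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))  # B stays zero
        for p in self.base.parameters():
            p.requires_grad = False  # only the LoRA factors are trained

    def forward(self, x):
        delta = (x @ self.lora_A.T @ self.lora_B.T) * self.scaling
        out = self.base(x)
        # Some ColumnParallelLinear variants return (output, bias) when the
        # bias add is deferred to the caller; handle both conventions.
        if isinstance(out, tuple):
            output, bias = out
            return output + delta, bias
        return out + delta
```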

Rabbidon commented 1 year ago

We did this