davisyoshida / lorax

LoRA for arbitrary JAX models and functions
MIT License

Generating LoRA parameters with hypernetworks #12

Open MasterXiong opened 3 weeks ago

MasterXiong commented 3 weeks ago

Hi,

Thanks for open-sourcing this brilliant package! Similar to #6, I also want to apply hypernetworks (HN) to LoRA, but my setting is a bit different. I have a base transformer model, and I want to use LoRA to adapt it to multiple tasks. Instead of learning a separate adapter for each task, I want a single HN to generate different LoRA parameters for each task by conditioning it on task-specific context. During fine-tuning, only the HN parameters are updated, while the base transformer stays frozen. Could you give some suggestions or examples on how to implement this HN idea in lorax? Thanks a lot for your help!
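Written as equations, the setup described above looks roughly like this. The notation is an assumption for illustration, not lorax's: $c_t$ is the context for task $t$, $\ell$ indexes the adapted layer, $h_\phi$ is the hypernetwork, and kernels are stored Flax-style with shape $d_\text{in} \times d_\text{out}$ (so the input is multiplied as $x W$).

```math
W_t^{(\ell)} = W_0^{(\ell)} + \frac{\alpha}{r}\, A_t^{(\ell)} B_t^{(\ell)}, \qquad
\bigl(A_t^{(\ell)}, B_t^{(\ell)}\bigr) = h_\phi\bigl(c_t, \ell\bigr), \qquad
A_t^{(\ell)} \in \mathbb{R}^{d_\text{in} \times r},\quad B_t^{(\ell)} \in \mathbb{R}^{r \times d_\text{out}}
```

$W_0^{(\ell)}$ is frozen and only $\phi$ receives gradients, so one set of hypernetwork weights serves every task.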

davisyoshida commented 3 weeks ago

Yeah, so basically I'd try using lorax to lorafy the parameters of your base transformer. Take a look at them in the Python interpreter and figure out where the A/B values you want to generate are. Then, during your forward pass, just generate the As and Bs and populate the tree with them. If you give me some minimal example code, I can give you more specific pointers.
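The "generate the As and Bs and populate the tree" step could look roughly like the sketch below. It is a pure-JAX stand-in that does not use lorax's own lorafied tree: it writes the low-rank update straight into a plain Flax-style param dict. The tree layout, the `populate_lora` helper, and the sizes are all hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes and a hand-built, Flax-style frozen param tree.
# nn.Dense stores its kernel as (in_features, out_features).
D_MODEL, MLP_DIM, RANK = 8, 16, 2
k0, k1, ka, kb = jax.random.split(jax.random.PRNGKey(0), 4)
base_params = {
    "MlpBlock_0": {
        "Dense_0": {"kernel": jax.random.normal(k0, (D_MODEL, MLP_DIM)), "bias": jnp.zeros(MLP_DIM)},
        "Dense_1": {"kernel": jax.random.normal(k1, (MLP_DIM, D_MODEL)), "bias": jnp.zeros(D_MODEL)},
    }
}


def populate_lora(params, factors, alpha=1.0):
    """Hypothetical helper: return a new tree where each module named in
    `factors` gets kernel -> kernel + (alpha / r) * A @ B; everything else is
    left untouched.

    `factors` mirrors the nesting of `params` down to module level and holds
    (A, B) tuples at the leaves, a layout invented for this sketch.
    """
    if isinstance(factors, tuple):
        a, b = factors
        scale = alpha / a.shape[-1]  # r is the shared inner dimension
        return {**params, "kernel": params["kernel"] + scale * (a @ b)}
    return {
        k: populate_lora(v, factors[k], alpha) if k in factors else v
        for k, v in params.items()
    }


# In the real setup, A and B would come from the hypernetwork's forward pass.
a = jax.random.normal(ka, (D_MODEL, RANK))
b = jax.random.normal(kb, (RANK, MLP_DIM))
merged = populate_lora(base_params, {"MlpBlock_0": {"Dense_0": (a, b)}})
```

Because the merge happens inside the forward function, wrapping it (together with the hypernetwork call that produces `a` and `b`) in `jax.grad` differentiates through A and B into the hypernetwork parameters.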

MasterXiong commented 2 weeks ago

Thanks for your help! Here is a minimal example of the transformer module I use:

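import jax
import jax.numpy as jnp
import flax.linen as nn
from typing import Any, Callable, Optional

# Type aliases assumed by the annotations below (not defined in the original snippet).
PRNGKey = jax.Array
Shape = tuple[int, ...]
Dtype = Any


class MlpBlock(nn.Module):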
    """Transformer MLP / feed-forward block."""

    mlp_dim: int
    dtype: Dtype = jnp.float32
    out_dim: Optional[int] = None
    dropout_rate: float = 0.1
    kernel_init: Callable[
        [PRNGKey, Shape, Dtype], jax.Array
    ] = nn.initializers.xavier_uniform()
    bias_init: Callable[[PRNGKey, Shape, Dtype], jax.Array] = nn.initializers.normal(
        stddev=1e-6
    )

    @nn.compact
    def __call__(self, inputs, *, deterministic):
        """Applies Transformer MlpBlock module."""
        actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
        x = nn.Dense(
            features=self.mlp_dim,
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            name="Dense_0",
        )(inputs)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        output = nn.Dense(
            features=actual_out_dim,
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            name="Dense_1",
        )(x)
        output = nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic)
        return output

class Encoder1DBlock(nn.Module):
    """Transformer encoder layer.

    Attributes:
      mlp_dim: dimension of the mlp on top of attention block.
      num_heads: Number of heads in nn.MultiHeadDotProductAttention.
      dtype: the dtype of the computation (default: float32).
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout for attention heads.
    """

    mlp_dim: int
    num_heads: int
    dtype: Dtype = jnp.float32
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, attention_mask, *, deterministic):
        """Applies Encoder1DBlock module.

        Args:
          inputs: Inputs to the layer, shaped (batch, seq, hidden).
          attention_mask: Mask passed to the self-attention layer.
          deterministic: Dropout will not be applied when set to true.

        Returns:
          output after transformer encoder block.
        """

        # Attention block.
        assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = nn.MultiHeadDotProductAttention(
            dtype=self.dtype,
            kernel_init=nn.initializers.xavier_uniform(),
            broadcast_dropout=False,
            deterministic=deterministic,
            dropout_rate=self.attention_dropout_rate,
            num_heads=self.num_heads,
        )(x, x, mask=attention_mask)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = MlpBlock(
            mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate
        )(y, deterministic=deterministic)

        return x + y

The encoder block is stacked over multiple layers to form the final model, like below:

class Transformer(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation.

    Attributes:
      num_layers: number of layers.
      mlp_dim: dimension of the mlp on top of attention block.
      num_attention_heads: Number of heads in nn.MultiHeadDotProductAttention.
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout rate in self attention.
      add_position_embedding: whether to add learned position embeddings to the inputs.
    """

    num_layers: int
    mlp_dim: int
    num_attention_heads: int
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1
    add_position_embedding: bool = False

    @nn.compact
    def __call__(self, x, attention_mask, *, train):
        """Applies Transformer model on the inputs.

        Args:
          x: Inputs to the layer, shaped (batch, len, emb).
          attention_mask: Mask passed to every encoder block's self-attention.
          train: Set to `True` when training.

        Returns:
          output of a transformer encoder.
        """
        assert x.ndim == 3  # (batch, len, emb)

        if self.add_position_embedding:
            x = AddPositionEmbs(
                posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
                name="posembed_input",
            )(x)
            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

        # Input Encoder
        for lyr in range(self.num_layers):
            x = Encoder1DBlock(
                mlp_dim=self.mlp_dim,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                name=f"encoderblock_{lyr}",
                num_heads=self.num_attention_heads,
            )(x, attention_mask, deterministic=not train)
        encoded = nn.LayerNorm(name="encoder_norm")(x)

        return encoded

My thought is to initialize the hypernetwork parameters inside class Transformer. The hypernetwork takes some context information (like the task id and the layer id in the model) as input and has separate output heads that generate the LoRA parameters for the query, key, value, and MLP layers respectively. I would then pass the corresponding LoRA parameters to each Encoder1DBlock layer when calling it. Do you think this makes sense? Or is there a better way to do this with lorax? Thanks!
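A minimal sketch of the hypernetwork described above, written in plain JAX so the shapes are easy to inspect. It assumes learned task and layer embeddings summed into one context vector, a one-layer MLP trunk, and one linear head per adapted weight that emits flattened A and B factors. All names, sizes, and the `TARGETS` table are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes for the sketch.
NUM_TASKS, NUM_LAYERS = 4, 3
CTX_DIM, HIDDEN, RANK = 16, 32, 2
D_MODEL, MLP_DIM = 8, 16

# (in_features, out_features) of each adapted weight viewed as a 2-D matrix.
# Flax's attention q/k/v kernels are 3-D (in, num_heads, head_dim), so the
# generated delta would be reshaped to that layout before being added.
TARGETS = {
    "query": (D_MODEL, D_MODEL),
    "key": (D_MODEL, D_MODEL),
    "value": (D_MODEL, D_MODEL),
    "mlp_in": (D_MODEL, MLP_DIM),
    "mlp_out": (MLP_DIM, D_MODEL),
}


def init_hypernet(key):
    keys = jax.random.split(key, 3 + len(TARGETS))
    params = {
        "task_emb": 0.02 * jax.random.normal(keys[0], (NUM_TASKS, CTX_DIM)),
        "layer_emb": 0.02 * jax.random.normal(keys[1], (NUM_LAYERS, CTX_DIM)),
        "trunk": jax.random.normal(keys[2], (CTX_DIM, HIDDEN)) / jnp.sqrt(CTX_DIM),
        "heads": {},
    }
    for i, (name, (d_in, d_out)) in enumerate(TARGETS.items()):
        params["heads"][name] = {
            "A": jax.random.normal(keys[3 + i], (HIDDEN, d_in * RANK)) / jnp.sqrt(HIDDEN),
            # Zero-init the B head so every task starts from the unmodified
            # base model, mirroring the usual LoRA initialization.
            "B": jnp.zeros((HIDDEN, RANK * d_out)),
        }
    return params


def hypernet(params, task_id, layer_id):
    """Map (task id, layer id) to (A, B) factors for every target weight."""
    ctx = params["task_emb"][task_id] + params["layer_emb"][layer_id]
    h = jax.nn.gelu(ctx @ params["trunk"])
    out = {}
    for name, (d_in, d_out) in TARGETS.items():
        head = params["heads"][name]
        a = (h @ head["A"]).reshape(d_in, RANK)
        b = (h @ head["B"]).reshape(RANK, d_out)
        out[name] = (a, b)
    return out


# Generate factors for every layer of task 1 at once by vmapping over layer ids.
hn_params = init_hypernet(jax.random.PRNGKey(0))
per_layer = jax.vmap(hypernet, in_axes=(None, None, 0))(hn_params, 1, jnp.arange(NUM_LAYERS))
```

Each entry of `per_layer` carries a leading layer axis, so `jax.tree_util.tree_map(lambda t: t[i], per_layer)` gives the slice you would pass to `encoderblock_{i}`.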

davisyoshida commented 2 weeks ago

I think the correct implementation depends a lot on what you want the hypernetwork to take as input. Are you going to compute all the LoRA weights at once, before beginning execution of your model? If so, I'd recommend making the hypernetwork a separate model (or at least making a new flax module which holds both the transformer and the hypernetwork).
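Treating the hypernetwork as a separate model might look like the minimal sketch below. A toy single-layer `base_model` stands in for the frozen transformer, all LoRA factors are produced before the base forward pass, and only the hypernetwork parameters are differentiated and updated. The model, sizes, learning rate, and target are made up for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes for a toy run.
D_IN, D_OUT, RANK, NUM_TASKS, CTX = 4, 3, 2, 2, 8


def base_model(params, x):
    # Stand-in for the frozen transformer: a single dense layer.
    return x @ params["kernel"] + params["bias"]


def hypernet(hn, task_id):
    # Stand-in hypernetwork: task embedding -> flattened LoRA factors.
    ctx = hn["task_emb"][task_id]
    a = (ctx @ hn["to_a"]).reshape(D_IN, RANK)
    b = (ctx @ hn["to_b"]).reshape(RANK, D_OUT)
    return a, b


def loss_fn(hn, base, task_id, x, y):
    # All LoRA factors are generated before the base model runs.
    a, b = hypernet(hn, task_id)
    adapted = {**base, "kernel": base["kernel"] + (a @ b) / RANK}
    pred = base_model(adapted, x)
    return jnp.mean((pred - y) ** 2)


@jax.jit
def train_step(hn, base, task_id, x, y):
    # value_and_grad differentiates w.r.t. argument 0 (the hypernetwork) only,
    # so the base parameters are never updated.
    loss, grads = jax.value_and_grad(loss_fn)(hn, base, task_id, x, y)
    hn = jax.tree_util.tree_map(lambda p, g: p - 1e-2 * g, hn, grads)
    return hn, loss


keys = jax.random.split(jax.random.PRNGKey(0), 5)
base = {"kernel": jax.random.normal(keys[0], (D_IN, D_OUT)), "bias": jnp.zeros(D_OUT)}
hn = {
    "task_emb": jax.random.normal(keys[1], (NUM_TASKS, CTX)),
    "to_a": 0.1 * jax.random.normal(keys[2], (CTX, D_IN * RANK)),
    "to_b": 0.1 * jax.random.normal(keys[3], (CTX, RANK * D_OUT)),
}
x = jax.random.normal(keys[4], (16, D_IN))
y = jnp.zeros((16, D_OUT))  # toy target for task 0

losses = []
for _ in range(50):
    hn, loss = train_step(hn, base, 0, x, y)
    losses.append(float(loss))
```

In a real setup `base_model` would be the transformer's `apply` called on the adapted param tree. Since the base parameters are only ever an input to `loss_fn`, they need no gradients or optimizer state.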

The reason is that it's actually a little bit annoying to interact with parameters from within the context of a flax model, since Flax makes everything appear OOP-y. You'd probably have to look into using lift (https://flax.readthedocs.io/en/latest/developer_notes/lift.html).