Open MasterXiong opened 3 weeks ago
Yeah so basically I'd try using lorax to lorafy the parameters of your base transformer. Take a look at them in the python interpreter, and figure out where the A/B values you want to generate are. Then, during your forward pass just generate the As and Bs and populate the tree with them. If you give me some minimal example code I can give you more specific pointers.
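One way to "take a look" at the parameters is to walk the pytree and list every dense kernel together with the A/B shapes a low-rank update would need. The sketch below deliberately does not call lorax; the parameter tree is a hypothetical stand-in whose names and shapes are assumptions (Flax-style auto-generated names, tiny dimensions), and the convention `delta_W = A @ B` with `A: (d_in, r)` and `B: (r, d_out)` is a choice made for this example.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-in for the parameters of one encoder block. The names and
# shapes are assumptions for illustration, not read from a real checkpoint.
hidden, heads, head_dim, mlp_dim = 8, 2, 4, 16
params = {
    "encoderblock_0": {
        "MultiHeadDotProductAttention_0": {
            "query": {
                "kernel": jnp.zeros((hidden, heads, head_dim)),
                "bias": jnp.zeros((heads, head_dim)),
            },
        },
        "MlpBlock_0": {
            "Dense_0": {
                "kernel": jnp.zeros((hidden, mlp_dim)),
                "bias": jnp.zeros((mlp_dim,)),
            },
        },
    }
}


def lora_shapes(params, rank):
    """Map the path of each 2-D kernel to the (A, B) shapes of a rank-`rank` update."""
    shapes = {}
    leaves, _ = jax.tree_util.tree_flatten_with_path(params)
    for path, leaf in leaves:
        name = jax.tree_util.keystr(path)
        # Only 2-D kernels are handled here; 3-D attention kernels would need
        # to be reshaped to 2-D first.
        if name.endswith("['kernel']") and leaf.ndim == 2:
            d_in, d_out = leaf.shape
            shapes[name] = ((d_in, rank), (rank, d_out))
    return shapes


shapes = lora_shapes(params, rank=2)
for name, (a_shape, b_shape) in shapes.items():
    print(name, "A:", a_shape, "B:", b_shape)
```

The 3-D query kernel is skipped on purpose: a hypernetwork head for it would have to emit factors for a reshaped 2-D view of that kernel.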
Thanks for your help! Here is a minimal example of the transformer module I use:
# Imports and type aliases are assumed here; the original snippet omitted them.
from typing import Any, Callable, Optional, Tuple

import flax.linen as nn
import jax
import jax.numpy as jnp

PRNGKey = Any
Shape = Tuple[int, ...]
Dtype = Any


class MlpBlock(nn.Module):
    """Transformer MLP / feed-forward block."""

    mlp_dim: int
    dtype: Dtype = jnp.float32
    out_dim: Optional[int] = None
    dropout_rate: float = 0.1
    kernel_init: Callable[
        [PRNGKey, Shape, Dtype], jax.Array
    ] = nn.initializers.xavier_uniform()
    bias_init: Callable[[PRNGKey, Shape, Dtype], jax.Array] = nn.initializers.normal(
        stddev=1e-6
    )

    @nn.compact
    def __call__(self, inputs, *, deterministic):
        """Applies Transformer MlpBlock module."""
        actual_out_dim = inputs.shape[-1] if self.out_dim is None else self.out_dim
        x = nn.Dense(
            features=self.mlp_dim,
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            name="Dense_0",
        )(inputs)
        x = nn.gelu(x)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        output = nn.Dense(
            features=actual_out_dim,
            dtype=self.dtype,
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            name="Dense_1",
        )(x)
        output = nn.Dropout(rate=self.dropout_rate)(output, deterministic=deterministic)
        return output


class Encoder1DBlock(nn.Module):
    """Transformer encoder layer.

    Attributes:
      mlp_dim: dimension of the mlp on top of attention block.
      num_heads: Number of heads in nn.MultiHeadDotProductAttention
      dtype: the dtype of the computation (default: float32).
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout for attention heads.
    """

    mlp_dim: int
    num_heads: int
    dtype: Dtype = jnp.float32
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1

    @nn.compact
    def __call__(self, inputs, attention_mask, *, deterministic):
        """Applies Encoder1DBlock module.

        Args:
          inputs: Inputs to the layer.
          attention_mask: Mask passed to the attention layer.
          deterministic: Dropout will not be applied when set to true.

        Returns:
          output after transformer encoder block.
        """
        # Attention block.
        assert inputs.ndim == 3, f"Expected (batch, seq, hidden) got {inputs.shape}"
        x = nn.LayerNorm(dtype=self.dtype)(inputs)
        x = nn.MultiHeadDotProductAttention(
            dtype=self.dtype,
            kernel_init=nn.initializers.xavier_uniform(),
            broadcast_dropout=False,
            deterministic=deterministic,
            dropout_rate=self.attention_dropout_rate,
            num_heads=self.num_heads,
        )(x, x, mask=attention_mask)
        x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=deterministic)
        x = x + inputs

        # MLP block.
        y = nn.LayerNorm(dtype=self.dtype)(x)
        y = MlpBlock(
            mlp_dim=self.mlp_dim, dtype=self.dtype, dropout_rate=self.dropout_rate
        )(y, deterministic=deterministic)
        return x + y

The encoder block is stacked over multiple layers to form the final model, as shown below:
class Transformer(nn.Module):
    """Transformer Model Encoder for sequence to sequence translation.

    Attributes:
      num_layers: number of layers
      mlp_dim: dimension of the mlp on top of attention block
      num_attention_heads: Number of heads in nn.MultiHeadDotProductAttention
      dropout_rate: dropout rate.
      attention_dropout_rate: dropout rate in self attention.
    """

    num_layers: int
    mlp_dim: int
    num_attention_heads: int
    dropout_rate: float = 0.1
    attention_dropout_rate: float = 0.1
    add_position_embedding: bool = False

    @nn.compact
    def __call__(self, x, attention_mask, *, train):
        """Applies Transformer model on the inputs.

        Args:
          x: Inputs to the layer.
          attention_mask: Mask passed to every encoder block.
          train: Set to `True` when training.

        Returns:
          output of a transformer encoder.
        """
        assert x.ndim == 3  # (batch, len, emb)

        if self.add_position_embedding:
            # AddPositionEmbs is defined elsewhere (not shown in this snippet).
            x = AddPositionEmbs(
                posemb_init=nn.initializers.normal(stddev=0.02),  # from BERT.
                name="posembed_input",
            )(x)
            x = nn.Dropout(rate=self.dropout_rate)(x, deterministic=not train)

        # Input Encoder
        for lyr in range(self.num_layers):
            x = Encoder1DBlock(
                mlp_dim=self.mlp_dim,
                dropout_rate=self.dropout_rate,
                attention_dropout_rate=self.attention_dropout_rate,
                name=f"encoderblock_{lyr}",
                num_heads=self.num_attention_heads,
            )(x, attention_mask, deterministic=not train)
        encoded = nn.LayerNorm(name="encoder_norm")(x)

        return encoded

My thought is to initialize the hypernetwork parameters inside the `Transformer` class. The hypernetwork takes some context information (such as the task id and the layer id in the model) as input, and has separate output heads that generate the LoRA parameters for the value, key, query, and MLP layers respectively. I would then pass the corresponding LoRA parameters to each `Encoder1DBlock` layer when calling it. Do you think this makes sense? Or is there a better way to do this with `lorax`? Thanks!
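The hypernetwork described above could look roughly like the following plain-JAX sketch. Everything in it is an assumption made for illustration: the function names `init_hypernet` and `generate_lora`, the additive task/layer embedding used as context, the single linear head per target, and the simplification that all four targets share the same `d_in` and `d_out` (the MLP kernels would differ in practice). It does not use lorax or Flax.

```python
import jax
import jax.numpy as jnp

TARGETS = ("query", "key", "value", "mlp")  # one pair of output heads per target


def init_hypernet(key, num_tasks, num_layers, emb_dim, d_in, d_out, rank):
    """Create task/layer embeddings plus linear heads that emit A and B."""
    k_task, k_layer, k_heads = jax.random.split(key, 3)
    params = {
        "task_emb": 0.02 * jax.random.normal(k_task, (num_tasks, emb_dim)),
        "layer_emb": 0.02 * jax.random.normal(k_layer, (num_layers, emb_dim)),
        "heads": {},
    }
    for name, k in zip(TARGETS, jax.random.split(k_heads, len(TARGETS))):
        params["heads"][name] = {
            "to_a": 0.02 * jax.random.normal(k, (emb_dim, d_in * rank)),
            # Zero-initialised so that every generated update A @ B starts at zero.
            "to_b": jnp.zeros((emb_dim, rank * d_out)),
        }
    return params


def generate_lora(hn_params, task_id, layer_id, d_in, d_out, rank):
    """Return {target: (A, B)} for one (task, layer) pair."""
    context = hn_params["task_emb"][task_id] + hn_params["layer_emb"][layer_id]
    out = {}
    for name, head in hn_params["heads"].items():
        a = (context @ head["to_a"]).reshape(d_in, rank)
        b = (context @ head["to_b"]).reshape(rank, d_out)
        out[name] = (a, b)
    return out


hn = init_hypernet(
    jax.random.PRNGKey(0), num_tasks=3, num_layers=2, emb_dim=5, d_in=8, d_out=8, rank=2
)
lora = generate_lora(hn, task_id=1, layer_id=0, d_in=8, d_out=8, rank=2)
```

Calling `generate_lora` once per layer yields the per-layer factors that would then be handed to the corresponding encoder block.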
I think the correct implementation depends a lot on what you want the hypernetwork to take as input. Are you going to compute all the lora weights at once, before beginning execution of your model? If so, I'd recommend making the hypernetwork a separate model (or at least making a new flax module which holds both the transformer and the hypernetwork).
The reason is that it's actually a little bit annoying to interact with parameters from within the context of a Flax model, since Flax makes everything appear to be OOP-y. You'd probably have to look into using `lift` (https://flax.readthedocs.io/en/latest/developer_notes/lift.html).
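A minimal sketch of the "separate model" layout, under stated assumptions: the frozen transformer is replaced by a toy stack of dense layers, the hypernetwork is a single task embedding with one pair of heads per layer, and the names `base_forward`, `hypernet`, and `adapted_forward` are hypothetical. The point is only the ordering: all LoRA factors are generated first, outside the base model, and the base model then runs with `W + A @ B`.

```python
import jax
import jax.numpy as jnp


def base_forward(params, x):
    """Toy stand-in for the frozen transformer: a stack of dense layers."""
    for w in params:
        x = jnp.tanh(x @ w)
    return x


def hypernet(hn_params, task_id):
    """Generate one (A, B) pair per base layer from a task embedding."""
    ctx = hn_params["task_emb"][task_id]
    return [
        (
            jnp.einsum("e,eir->ir", ctx, head["to_a"]),  # A: (d_in, rank)
            jnp.einsum("e,ero->ro", ctx, head["to_b"]),  # B: (rank, d_out)
        )
        for head in hn_params["heads"]
    ]


def adapted_forward(hn_params, base_params, task_id, x):
    """Compute all LoRA factors up front, then run the base model on W + A @ B."""
    lora = hypernet(hn_params, task_id)
    merged = [w + a @ b for w, (a, b) in zip(base_params, lora)]
    return base_forward(merged, x)


dim, rank, emb, num_tasks, num_layers = 4, 2, 3, 5, 2
k_base, k_emb, k_heads, k_x = jax.random.split(jax.random.PRNGKey(0), 4)
base_params = [
    jax.random.normal(k, (dim, dim)) for k in jax.random.split(k_base, num_layers)
]
hn_params = {
    "task_emb": jax.random.normal(k_emb, (num_tasks, emb)),
    "heads": [
        {
            "to_a": jax.random.normal(k, (emb, dim, rank)),
            "to_b": jnp.zeros((emb, rank, dim)),  # update starts at zero
        }
        for k in jax.random.split(k_heads, num_layers)
    ],
}
x = jax.random.normal(k_x, (1, dim))
y = adapted_forward(hn_params, base_params, 0, x)
```

Because both sets of parameters are ordinary pytrees passed in as arguments, nothing here needs access to parameters from inside a module.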
Hi,
Thanks for open-sourcing this brilliant package! Similar to #6, I also want to apply hypernetworks (HN) to LoRA, but my setting is a bit different. I have a base transformer model, and I want to use LoRA to adapt it to multiple different tasks. Instead of learning a separate adapter for each task, I want to use a single HN to generate different LoRA parameters for each task, by conditioning the HN on a task context that differs from task to task. During fine-tuning, only the parameters of the HN are updated, while the base transformer stays frozen. I was wondering if you could give some suggestions or examples on how to implement this HN idea in `lorax`? Thanks a lot for your help!
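The training setup described here (only the HN is updated, the base weights stay frozen) can be sketched in plain JAX by differentiating with respect to the hypernetwork parameters alone. This is an illustration under assumptions, not lorax's own training utilities: the base model is a single dense layer, the loss is mean squared error, the optimizer is plain SGD with a made-up learning rate, and all names are hypothetical.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # arbitrary value for the sketch


def forward(hn_params, base_w, task_id, x):
    """Single frozen layer adapted by a task-conditioned low-rank update."""
    ctx = hn_params["task_emb"][task_id]
    a = jnp.einsum("e,eir->ir", ctx, hn_params["to_a"])  # (d_in, rank)
    b = jnp.einsum("e,ero->ro", ctx, hn_params["to_b"])  # (rank, d_out)
    return x @ (base_w + a @ b)


def loss_fn(hn_params, base_w, task_id, x, y):
    return jnp.mean((forward(hn_params, base_w, task_id, x) - y) ** 2)


@jax.jit
def train_step(hn_params, base_w, task_id, x, y):
    # value_and_grad differentiates w.r.t. the first argument only, so
    # gradients exist for hn_params and base_w is never updated.
    loss, grads = jax.value_and_grad(loss_fn)(hn_params, base_w, task_id, x, y)
    new_hn = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, hn_params, grads
    )
    return new_hn, loss


dim, rank, emb, num_tasks = 4, 2, 3, 5
k_w, k_e, k_a, k_x, k_y = jax.random.split(jax.random.PRNGKey(0), 5)
base_w = jax.random.normal(k_w, (dim, dim))
hn_params = {
    "task_emb": jax.random.normal(k_e, (num_tasks, emb)),
    "to_a": jax.random.normal(k_a, (emb, dim, rank)),
    "to_b": jnp.zeros((emb, rank, dim)),
}
x = jax.random.normal(k_x, (8, dim))
y = jax.random.normal(k_y, (8, dim))
new_hn, loss = train_step(hn_params, base_w, 2, x, y)
```

Batches from different tasks can be fed through the same `train_step` by changing `task_id`, so one set of HN parameters serves every task.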