davisyoshida / lorax

LoRA for arbitrary JAX models and functions
MIT License

Predicting LoRA weights #6

Closed PuR3Luck closed 10 months ago

PuR3Luck commented 11 months ago

I would like to use a separate neural network to predict LoRA weights for a main neural network, while training both neural networks at the same time. How can I manipulate the pytrees to achieve this, if it is possible at all?

PuR3Luck commented 11 months ago

I am experimenting with using one transformer block and switching out LoRAs for each application of the block, hence I would like to train both neural networks at the same time.

davisyoshida commented 11 months ago

@PuR3Luck The LoraWeight class is a pytree, so you can return it from jitted functions. In this case even that shouldn't be necessary; I think the easiest way would be something like this:

def the_lora_weights_model(params, inputs):
    # do whatever you're doing here to produce A and B
    return A, B

def original_model(W, inputs):
    # actual model logic
    return output

lora_model = lorax.lora(original_model) # this makes it so that the model can handle LoraWeight arguments

def combined_model(W, inner_params, inputs):
    A, B = the_lora_weights_model(...) # invoke this however it's supposed to be called

    lora_weight = lorax.LoraWeight(w=W, a=A, b=B)
    output = lora_model(lora_weight, ...) # invoke the model like normal except pass the lora weight
    return output

If you give me some more details or code for your setup I can give you something more customized.

PuR3Luck commented 11 months ago

Thanks for the help!

What I am planning to do is to replace all of the transformer blocks of an autoregressive transformer written in Flax with just one transformer block, that is updated with LoRAs parameterised by a neural hypernetwork. The hypernetwork is given a signal for the depth of the layer and possibly the most recent token generated. This may be able to let me perform inference time quality-throughput tradeoffs.

I think this might be viable because One Wide Feedforward is All You Need (https://arxiv.org/pdf/2309.01826.pdf) showed that replacing all FFNs with just one common FFN can work. I hypothesize that combining this approach (a common FFN and maybe common self-attention) with LoRAs might allow me to regain the lost accuracy while potentially overfitting less due to the lower parameter count.

It might even be faster than larger models, since the smaller parameter count should make it more compute-bound and less memory-bound.
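For illustration, the idea described above (one shared block reused at every depth, with a hypernetwork that turns a depth signal into LoRA factors) could be sketched in plain JAX roughly as follows. Everything here is an assumption made for the sketch: the sizes, the names `init_params`, `predict_lora` and `forward`, the learned per-layer embedding used as the depth signal, and the zero initialisation of the `B` head. The low-rank update is materialised as `W + B @ A` instead of going through lorax, purely to keep the example dependency-free.

```python
import jax
import jax.numpy as jnp

D_MODEL, RANK, DEPTH, EMB = 8, 2, 3, 4  # illustrative sizes (assumptions)

def init_params(key):
    k_w, k_emb, k_a = jax.random.split(key, 3)
    # A single weight matrix shared by every layer of the main network.
    shared = {"w": jax.random.normal(k_w, (D_MODEL, D_MODEL)) / jnp.sqrt(D_MODEL)}
    hyper = {
        # One learned embedding per layer index: the "depth signal".
        "depth_emb": jax.random.normal(k_emb, (DEPTH, EMB)),
        # Linear heads mapping the embedding to the flattened A and B factors.
        "to_a": jax.random.normal(k_a, (EMB, RANK * D_MODEL)) * 0.01,
        # Zero-initialised so every layer starts out as the plain shared block.
        "to_b": jnp.zeros((EMB, D_MODEL * RANK)),
    }
    return shared, hyper

def predict_lora(hyper, layer_idx):
    emb = hyper["depth_emb"][layer_idx]
    a = (emb @ hyper["to_a"]).reshape(RANK, D_MODEL)  # k x N
    b = (emb @ hyper["to_b"]).reshape(D_MODEL, RANK)  # M x k
    return a, b

def forward(shared, hyper, x):
    for layer_idx in range(DEPTH):
        a, b = predict_lora(hyper, layer_idx)
        w_eff = shared["w"] + b @ a  # materialised stand-in for a LoRA weight
        x = jax.nn.relu(x @ w_eff)
    return x

shared, hyper = init_params(jax.random.PRNGKey(0))
x = jnp.ones((5, D_MODEL))
out = forward(shared, hyper, x)
print(out.shape)
```

Because `forward` is an ordinary function of both parameter trees, it can be differentiated with respect to the shared weights and the hypernetwork at the same time.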

PuR3Luck commented 11 months ago

On a separate note, if I want to train both the transformer block and the hypernetwork at the same time, I should not use the wrap_optimiser() function right?

davisyoshida commented 11 months ago

I tried the "shared weights but different LoRAs" approach a few months ago. I found that the parameter efficiency of dense weights was better (I had to kick the LoRA dimension up so high that it wasn't actually reducing the number of parameters). Your hypernetwork idea is interesting; maybe it will help.
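To make the parameter-count point concrete, here is a small back-of-the-envelope sketch. It assumes the usual LoRA factorisation of an M x N weight into B (M x k) and A (k x N); the helper names are made up for the example. A rank-k update costs k * (M + N) parameters, so it stops saving anything once k reaches M * N / (M + N).

```python
def dense_params(m: int, n: int) -> int:
    """Parameters in a full M x N weight matrix."""
    return m * n

def lora_params(m: int, n: int, rank: int) -> int:
    """Parameters in the factors B (M x k) and A (k x N)."""
    return rank * (m + n)

def break_even_rank(m: int, n: int) -> float:
    """Rank at which the LoRA factors cost as much as the dense matrix."""
    return (m * n) / (m + n)

m = n = 512
print(dense_params(m, n))     # 262144
print(lora_params(m, n, 8))   # 8192
print(break_even_rank(m, n))  # 256.0
```

For a square 512 x 512 layer, any rank at or above 256 is no cheaper than a separate dense matrix per layer.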

I should not use the wrap_optimiser() function right?

In your case it sounds like all the parameters are trainable, so it should be unnecessary.

PuR3Luck commented 11 months ago

In your case did you pretrain from scratch or factorise out LoRA matrices from a pretrained model?

davisyoshida commented 11 months ago

I trained from scratch. I was doing language modeling on Penn Treebank. I definitely can't say I did a thorough enough experiment to be sure that it doesn't work, since I only spent a couple of hours on it.

PuR3Luck commented 11 months ago

Would it still be possible to do this if I wanted to write all my networks in Flax (I see you are manually passing in all the parameters)? Also, do you see an elegant way of using the pytree so that my hypernetwork knows the dimensions of its outputs at initialisation, so that I can easily pass its outputs (the As and Bs for all parts of the main network) into LoraWeight?

davisyoshida commented 10 months ago

Here's some example code showing how I'd handle it. The rough outline of the strategy is:

  1. Initialize weights for both your model and the hypernetwork (using nn.compact lets the hypernetwork do input-shape-dependent initialization)
  2. Delete the unneeded weights for the outer model (if this is costing too much GPU memory, you can do some workarounds to avoid actually initializing these weights that get created then deleted)
  3. Define a joint_call function which re-populates the weights for the model using a combination of weight sharing and calls to the hypernetwork, then calls the model using these parameters. This function will be compatible with JAX transformations like grad, vmap, and jit.

import flax.linen as nn
import lorax
import jax
import jax.numpy as jnp

class MyFlaxNetwork(nn.Module):
    """Example network"""
    def setup(self):
        self.blocks = [{'a': nn.Dense(64), 'b': nn.Dense(64)} for _ in range(5)]
        self.out_proj = nn.Dense(1)

    def __call__(self, x):
        for block in self.blocks:
            x = jax.nn.relu(block['a'](x) + block['b'](x))

        return self.out_proj(x)

class LoraMakerNetwork(nn.Module):
    @nn.compact # have to use this so our parameters can depend on the input shape
    def __call__(self, W, *whatever_other_inputs_you_want):
        M, N = W.shape
        k = 4
        a = jax.random.normal(jax.random.PRNGKey(0), (k, N))
        b = jax.random.normal(jax.random.PRNGKey(1), (M, k))

        some_param = self.param('some_param', lambda rng_key: jnp.ones(()))
        return a, b

def main():
    model = MyFlaxNetwork()
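    params = model.init(jax.random.PRNGKey(0), jnp.ones(64))
    # If your Flax version returns a FrozenDict here, convert it with
    # flax.core.unfreeze(params) first so the del statements below can mutate the tree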

    for i in range(1, 5):
        # Delete the extra params we don't care about
        # This step might need to change quite a bit depending on
        # what exactly you want to share or keep separate
        del params['params'][f'blocks_{i}_a']['kernel']
        del params['params'][f'blocks_{i}_b']['kernel']

    lora_model = LoraMakerNetwork()
    lora_model_params = lora_model.init(jax.random.PRNGKey(0), params['params']['blocks_0_a']['kernel'])

    @jax.jit
    def joint_call(params, lora_model_params, input_data):
        # Copy the params tree so we can mutate it
        # This doesn't actually copy the data on the GPU, it just copies the pytree at tracing time
        modified_params = jax.tree_util.tree_map(lambda x: x, params)

        # Step 1: Overwrite the original param tree with LoraWeight instances
        for i in range(5):
            for k in ['a', 'b']:
                shared_param_name = f'blocks_0_{k}'
                write_params_name = f'blocks_{i}_{k}'
                w_shared = params['params'][shared_param_name]['kernel']

                # This will do the same thing for every layer, but presumably you'll be passing some other inputs
                a, b = lora_model.apply(lora_model_params, w_shared)
                lora_weight = lorax.LoraWeight(w=w_shared, a=a, b=b)
                modified_params['params'][write_params_name]['kernel'] = lora_weight

        # Step 2: Run model using modified params tree
        wrapped_model = lorax.lora(model.apply)
        return wrapped_model(modified_params, input_data)

    inp = jax.random.normal(jax.random.PRNGKey(0), (64,))

    print(joint_call(params, lora_model_params, inp))

if __name__ == '__main__':
    main()

Feel free to let me know if you have any questions or if this won't work.

PuR3Luck commented 10 months ago

Thanks for the prompt reply!

Do you think there is a way to wrap the main function in a Flax module? I think that wrapping it in a Flax module would make testing and initialising the model more convenient. My only worry is that, since I don't think the init method is called inside a Flax module, I am not sure how to del the parameters.

PuR3Luck commented 10 months ago

I am leaving a simplified version of my model code below (I am not fully certain how to implement the LoRATransformer class without the init method).

class TransformerBlock(nn.Module):
  d_model: int
  num_heads: int
  @nn.compact
  def __call__(self,x):
    attn = nn.SelfAttention(num_heads=self.num_heads,name="self_attn")(x)
    x = attn + x
    x = nn.LayerNorm()(x)
    x = nn.Sequential([
      nn.Dense(features=4*self.d_model,name="ffn_1"),
      nn.gelu,  # flax.linen exposes gelu as a function, which nn.Sequential accepts
      nn.Dense(features=self.d_model,name="ffn_2"),
    ])(x)
    x = nn.Dropout(rate=0.1)(x)
    x = x + attn
    x = nn.LayerNorm()(x)
    return x

class Transformer(nn.Module):
  depth: int
  d_model: int
  num_heads: int

  def setup(self):
    self.blocks = [TransformerBlock(d_model=self.d_model,num_heads=self.num_heads) for _ in range(self.depth)]
    self.output = nn.Dense(features=1)

  def __call__(self,x):
    for block in self.blocks:
      x = block(x)
    x = self.output(x)
    return x

class LoRATransformer(nn.Module):
  depth: int
  lora_rank: int
  d_model: int
  num_heads: int

  def setup(self):
    network = Transformer(depth=self.depth,d_model=self.d_model,num_heads=self.num_heads)
    lora_hypernetwork = LoRA_Hypernetwork(lora_rank=self.lora_rank)

    for i in range(1,self.depth+1):

  def __call__(self,x):
    raise NotImplementedError

davisyoshida commented 10 months ago

So I think you probably don't want the LoRATransformer to be a Module, since you want to be able to manipulate the parameters of the Transformer module. Inside the context of a flax module, the parameters are hidden from you. That's why I implemented the joint_call function.

PuR3Luck commented 10 months ago

So am I right to interpret your response as saying that I cannot implement LoRATransformer as a Flax module?

davisyoshida commented 10 months ago

You probably can get it working that way, but it will be harder.

davisyoshida commented 10 months ago

Oh whoops didn't mean to close.

PuR3Luck commented 10 months ago

@davisyoshida Just to confirm: I use the joint_call function to perform the training, and applying a loss through joint_call would optimise both the transformer block and the hypernetwork parameters, right?

PuR3Luck commented 10 months ago

Also do I have to initialize a separate hypernetwork for each matrix that has a unique shape? I presume so

PuR3Luck commented 10 months ago

And how should I modify your code if I want to add dropout, especially the line

wrapped_model = lorax.lora(model.apply)

davisyoshida commented 10 months ago

I use the joint call method to perform the training right

Yes. To differentiate your loss function with respect to both sets of parameters, you can either pass them together in a tuple, or use the argnums param to jax.grad.
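As a rough illustration of the two options, here is a minimal sketch with a toy loss standing in for one built on joint_call. The names `joint_loss`, `tuple_loss`, `model_params` and `hyper_params` are made up for the example, and the "hypernetwork" is reduced to a single additive term.

```python
import jax
import jax.numpy as jnp

def joint_loss(model_params, hyper_params, x, y):
    # Toy stand-in for a loss built on joint_call: both argument trees affect the prediction.
    pred = x @ (model_params["w"] + hyper_params["delta"])
    return jnp.mean((pred - y) ** 2)

model_params = {"w": jnp.ones((3, 2))}
hyper_params = {"delta": jnp.zeros((3, 2))}
x = jnp.ones((4, 3))
y = jnp.zeros((4, 2))

# Option 1: keep the arguments separate and ask for gradients of both positions.
g_model, g_hyper = jax.grad(joint_loss, argnums=(0, 1))(model_params, hyper_params, x, y)

# Option 2: bundle both trees into one tuple; jax.grad differentiates the first argument.
def tuple_loss(all_params, x, y):
    model_p, hyper_p = all_params
    return joint_loss(model_p, hyper_p, x, y)

g_tuple = jax.grad(tuple_loss)((model_params, hyper_params), x, y)

print(g_model["w"].shape, g_tuple[1]["delta"].shape)
```

Both routes give the same gradients. The tuple form is convenient when a single optimizer state should cover everything, since the gradient comes back with the same structure as the tuple.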

Also do I have to initialize a separate hypernetwork for each matrix that has a unique shape? I presume so

I think this depends on your architecture, not the particular implementation. If you have some architecture that can handle multiple output shapes, it should be fine to use that.
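One possible way to let a single hypernetwork serve several weight shapes, sketched under heavy assumptions: it emits factors for a maximum dimension and slices them down to the target shape. This is only one design among many and not something lorax prescribes; `MAX_DIM`, `COND`, `init_hyper` and `predict_lora` are hypothetical names, and the conditioning vector is a placeholder.

```python
import jax
import jax.numpy as jnp

RANK, MAX_DIM, COND = 2, 16, 4  # illustrative sizes (assumptions)

def init_hyper(key):
    k_a, k_b = jax.random.split(key)
    return {
        "to_a": jax.random.normal(k_a, (COND, RANK * MAX_DIM)) * 0.01,
        "to_b": jax.random.normal(k_b, (COND, MAX_DIM * RANK)) * 0.01,
    }

def predict_lora(hyper, cond, w_shape):
    # w_shape holds static Python ints, so the slices below also work under jit
    # as long as w_shape is not passed as a traced value.
    m, n = w_shape
    a_full = (cond @ hyper["to_a"]).reshape(RANK, MAX_DIM)
    b_full = (cond @ hyper["to_b"]).reshape(MAX_DIM, RANK)
    return a_full[:, :n], b_full[:m, :]  # A is k x N, B is M x k

hyper = init_hyper(jax.random.PRNGKey(0))
cond = jnp.ones((COND,))
a1, b1 = predict_lora(hyper, cond, (8, 16))
a2, b2 = predict_lora(hyper, cond, (16, 4))
print((b1 @ a1).shape, (b2 @ a2).shape)
```

Whether sharing output units across differently shaped matrices trains well is an open question; per-shape output heads on a shared trunk would be the obvious alternative.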

how should I modify your code if I want to add dropout

Usually dropout goes in the model code, you shouldn't need to add anything extra to the code I supplied. If you have something else in mind give me a few more details on how you want to apply dropout and I can give some advice.

PuR3Luck commented 10 months ago

Following the dropout implementation in the Flax Linen docs, dropout is enabled by passing a training boolean together with a PRNG key for dropout to the apply method, so I was wondering how to modify the model.apply call, given that lorax.lora wraps model.apply. I suspect I should change the code to wrapped_model = lorax.lora(model), then follow the dropout guide and use wrapped_model.apply instead, with the training flag and PRNG key.

https://flax.readthedocs.io/en/latest/guides/training_techniques/dropout.html

davisyoshida commented 10 months ago

@PuR3Luck The lora() transform knows about functions, not models, so you won't be able to access attributes like .apply after using it. If it doesn't work to call wrapped_model(modified_params, input_data, training=True), then you could solve that with a partial function application:

from functools import partial

apply_with_dropout = partial(model.apply, training=True)
wrapped_model = lorax.lora(apply_with_dropout)
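
The partial-application trick can be tried out without Flax or lorax. In the sketch below, `apply_fn` is a hypothetical hand-written stand-in for a `model.apply` that takes a `training` flag and an `rngs` dict, and `train_apply` is a made-up helper. The point is only that binding the keyword arguments first leaves a plain `(params, x)` function, which is the kind of callable a function transform can wrap.

```python
from functools import partial

import jax
import jax.numpy as jnp

def apply_fn(params, x, training=False, rngs=None):
    """Hypothetical stand-in for a Flax `model.apply` whose model uses dropout."""
    h = x @ params["w"]
    if training:
        # Inverted dropout with keep probability 0.9.
        keep = jax.random.bernoulli(rngs["dropout"], 0.9, h.shape)
        h = jnp.where(keep, h / 0.9, 0.0)
    return h

def train_apply(params, x, dropout_key):
    # Bind the flags first; the resulting callable only takes (params, x).
    bound = partial(apply_fn, training=True, rngs={"dropout": dropout_key})
    return bound(params, x)

params = {"w": jnp.ones((4, 4))}
x = jnp.ones((2, 4))
key = jax.random.PRNGKey(0)

out_a = train_apply(params, x, jax.random.fold_in(key, 0))  # key for step 0
out_b = train_apply(params, x, jax.random.fold_in(key, 1))  # key for step 1
eval_out = apply_fn(params, x)                              # dropout disabled

print(out_a.shape, eval_out.shape)
```

Deriving a fresh key per step, as with `jax.random.fold_in` here, gives each step its own dropout mask; reusing one fixed key would apply the same mask every time.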
PuR3Luck commented 10 months ago

Just checking my final code should look something like this right?

import flax.linen as nn
import lorax
import jax
import jax.numpy as jnp
from functools import partial
import optax

class MyFlaxNetwork(nn.Module):
    """Example network"""
    def setup(self):
        self.blocks = [{'a': nn.Dense(64), 'b': nn.Dense(64)} for _ in range(5)]
        self.out_proj = nn.Dense(1)

    def __call__(self, x):
        for block in self.blocks:
            x = jax.nn.relu(block['a'](x) + block['b'](x))

        return self.out_proj(x)

class LoraMakerNetwork(nn.Module):
    @nn.compact # have to use this so our parameters can depend on the input shape
    def __call__(self, W, *whatever_other_inputs_you_want):
        M, N = W.shape
        k = 4
        a = jax.random.normal(jax.random.PRNGKey(0), (k, N))
        b = jax.random.normal(jax.random.PRNGKey(1), (M, k))

        some_param = self.param('some_param', lambda rng_key: jnp.ones(()))
        return a, b

def main():
    model = MyFlaxNetwork()
    params = model.init(jax.random.PRNGKey(0), jnp.ones(64))

    for i in range(1, 5):
        # Delete the extra params we don't care about
        # This step might need to change quite a bit depending on
        # what exactly you want to share or keep separate
        del params['params'][f'blocks_{i}_a']['kernel']
        del params['params'][f'blocks_{i}_b']['kernel']

    lora_model = LoraMakerNetwork()
    lora_model_params = lora_model.init(jax.random.PRNGKey(0), params['params']['blocks_0_a']['kernel'])

    @jax.jit
    def joint_call(params, lora_model_params, input_data):
        # Copy the params tree so we can mutate it
        # This doesn't actually copy the data on the GPU, it just copies the pytree at tracing time
        modified_params = jax.tree_util.tree_map(lambda x: x, params)

        # Step 1: Overwrite the original param tree with LoraWeight instances
        for i in range(5):
            for k in ['a', 'b']:
                shared_param_name = f'blocks_0_{k}'
                write_params_name = f'blocks_{i}_{k}'
                w_shared = params['params'][shared_param_name]['kernel']

                # This will do the same thing for every layer, but presumably you'll be passing some other inputs
                a, b = lora_model.apply(lora_model_params, w_shared)
                lora_weight = lorax.LoraWeight(w=w_shared, a=a, b=b)
                modified_params['params'][write_params_name]['kernel'] = lora_weight

        # Assumes the model's __call__ accepts a `training` flag (the example network above does not)
        apply_fn = partial(model.apply, training=True, rngs={"dropout": jax.random.PRNGKey(0)})

        # Step 2: Run model using modified params tree
        wrapped_model = lorax.lora(apply_fn)
        return wrapped_model(modified_params, input_data)

    all_params = (params_0_a, params_0_b, lora_model_params)

    optimiser = optax.adam(1e-3)

    opt_state = optimiser.init(all_params)

    @jax.jit
    def update_fn(params, opt_state, joint_call, x):
        grad_fn = jax.value_and_grad(joint_call)
        loss, grad = grad_fn(params, x)
        updates, new_opt_state = optimiser.update(grad, opt_state, params=params)
        updated_params = optax.apply_updates(params, updates)
        return loss, new_opt_state, updated_params

    for i in range(epochs):
        loss, opt_state, params = update_fn(loss, opt_state, params)

if __name__ == '__main__':
    main()

davisyoshida commented 10 months ago

Well, this doesn't actually have a loss function or targets to use for joint_call, but other than that the layout seems roughly correct. You don't need to pass joint_call as an argument to the update function; it's already accessible because they're in the same scope.

PuR3Luck commented 10 months ago

I thought that functions wrapped in jax.jit must not use global variables and functions, though?

davisyoshida commented 10 months ago

You only need to avoid that if their values will change. For something like a function that you're only going to assign once, there's no issue.
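A small sketch of why changing values are the problem: a jitted function reads any outside variable while it is being traced and bakes that value into the compiled code. The names `scale` and `scaled` are made up for the example.

```python
import jax
import jax.numpy as jnp

scale = 2.0  # outside value captured by the jitted function below

@jax.jit
def scaled(x):
    # `scale` is read when the function is traced, then baked into the compiled code.
    return x * scale

first = scaled(jnp.array(3.0))   # traces and compiles with scale == 2.0

scale = 10.0                     # rebinding the variable does not trigger a retrace
second = scaled(jnp.array(3.0))  # same input shape and dtype, so the cached version runs

print(first, second)  # both are 6.0: the change to `scale` was silently ignored
```

A function such as joint_call that is defined once and never reassigned has no such staleness problem, which is why closing over it is fine.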

PuR3Luck commented 10 months ago

I didn't know that. Thanks for sharing!

PuR3Luck commented 10 months ago

I resolved the missing loss function and joint call targets, so the code should look like this right?

import flax.linen as nn
import lorax
import jax
import jax.numpy as jnp
from functools import partial
import optax

class MyFlaxNetwork(nn.Module):
    """Example network"""
    def setup(self):
        self.blocks = [{'a': nn.Dense(64), 'b': nn.Dense(64)} for _ in range(5)]
        self.out_proj = nn.Dense(1)

    def __call__(self, x):
        for block in self.blocks:
            x = jax.nn.relu(block['a'](x) + block['b'](x))

        return self.out_proj(x)

class LoraMakerNetwork(nn.Module):
    @nn.compact # have to use this so our parameters can depend on the input shape
    def __call__(self, W, *whatever_other_inputs_you_want):
        M, N = W.shape
        k = 4
        a = jax.random.normal(jax.random.PRNGKey(0), (k, N))
        b = jax.random.normal(jax.random.PRNGKey(1), (M, k))

        some_param = self.param('some_param', lambda rng_key: jnp.ones(()))
        return a, b

def main():
    model = MyFlaxNetwork()
    params = model.init(jax.random.PRNGKey(0), jnp.ones(64))

    for i in range(1, 5):
        # Delete the extra params we don't care about
        # This step might need to change quite a bit depending on
        # what exactly you want to share or keep separate
        del params['params'][f'blocks_{i}_a']['kernel']
        del params['params'][f'blocks_{i}_b']['kernel']

    lora_model = LoraMakerNetwork()
    lora_model_params = lora_model.init(jax.random.PRNGKey(0), params['params']['blocks_0_a']['kernel'])

    @jax.jit
    def joint_call(params, lora_model_params, input_data):
        # Copy the params tree so we can mutate it
        # This doesn't actually copy the data on the GPU, it just copies the pytree at tracing time
        modified_params = jax.tree_util.tree_map(lambda x: x, params)

        # Step 1: Overwrite the original param tree with LoraWeight instances
        for i in range(5):
            for k in ['a', 'b']:
                shared_param_name = f'blocks_0_{k}'
                write_params_name = f'blocks_{i}_{k}'
                w_shared = params['params'][shared_param_name]['kernel']

                # This will do the same thing for every layer, but presumably you'll be passing some other inputs
                a, b = lora_model.apply(lora_model_params, w_shared)
                lora_weight = lorax.LoraWeight(w=w_shared, a=a, b=b)
                modified_params['params'][write_params_name]['kernel'] = lora_weight

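        # Assumes the model's __call__ accepts a `training` flag (the example network above does not).
        # A fixed key reuses the same dropout mask on every call; pass a fresh key per step in real training.
        apply_fn = partial(model.apply, training=True, rngs={"dropout": jax.random.PRNGKey(0)})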

        # Step 2: Run model using modified params tree
        wrapped_model = lorax.lora(apply_fn)
        return wrapped_model(modified_params, input_data)

    all_params = (params, lora_model_params)

    optimiser = optax.adam(1e-3)

    opt_state = optimiser.init(all_params)

    def loss_fn(params,data):
        output = joint_call(
            params[0],
            params[1],
            data
            )
        # Calculate `loss` from `output` and the targets here (omitted in this sketch)
        return loss

    @jax.jit
    def update_fn(params, opt_state, data):
        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(params, data)
        updates, new_opt_state = optimiser.update(grad, opt_state, params=params)
        updated_params = optax.apply_updates(params, updates)
        return loss, new_opt_state, updated_params

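    # `epochs` and `data` are assumed to be defined elsewhere
    for i in range(epochs):
        # Thread the combined (model, hypernetwork) tuple through the loop, since loss_fn indexes into it
        loss, opt_state, all_params = update_fn(all_params, opt_state, data)
        print(f"Loss:{loss}")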

if __name__ == '__main__':
    main()

davisyoshida commented 10 months ago

Looks like it should work yeah

davisyoshida commented 10 months ago

Closing this, feel free to let me know if there are any more problems though.

PuR3Luck commented 10 months ago

I am currently playing around with training the Transformer. I am also now wondering how I could apply this to convolutions as the kernel for the convolution has 3 dimensions?

PuR3Luck commented 10 months ago

I am applying a 1D convolution over a sequence with dimensions (seq_len, dimension) and multiple out_features, so the kernel has 3 dimensions. I saw in the lorax code that LoraWeight has a mechanism to handle convolutions. Could you provide some advice?

PuR3Luck commented 10 months ago

@davisyoshida The network very frequently outputs NaN values for "deep" networks of around 16 or so layers. This behaviour is weird to me, as I didn't really observe it when using normal dense models, and the issue also seems to appear if I scale the model width. Do you have any suggestions for fixing this? In general it seems extremely sensitive to the hyperparameters. I have tried different weight initialisations, but that does not seem to help much.

davisyoshida commented 10 months ago

@PuR3Luck This sounds like something more to do with the hypernetwork part than the Lora part. I don't have any advice for you since I haven't used hypernetworks.

PuR3Luck commented 10 months ago

Ok thanks for all the help!