PuR3Luck closed this issue 10 months ago.
I am experimenting with using one transformer block and then switching out LoRAs for each application of the transformer block, hence I would like to train both neural networks at the same time.
@PuR3Luck The LoraWeight class is a pytree, so you can return it from jitted functions. In this case even that shouldn't be necessary; I think the easiest way would be something like this:
def the_lora_weights_model(params, inputs):
    # do whatever you're doing here to produce A and B
    return A, B

def original_model(W, inputs):
    # actual model logic
    return output

lora_model = lorax.lora(original_model)  # this makes it so that the model can handle LoraWeight arguments

def combined_model(W, inner_params, inputs):
    A, B = the_lora_weights_model(...)  # invoke this however it's supposed to be called
    lora_weight = lorax.LoraWeight(w=W, a=A, b=B)
    output = lora_model(lora_weight, ...)  # invoke the model like normal except pass the lora weight
    return output
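As a quick illustration of the pytree point, here's a minimal untested sketch (arbitrary shapes, using the same w/a/b fields as above) of returning a LoraWeight from a jitted function:

import jax
import jax.numpy as jnp
import lorax

@jax.jit
def make_lora_weight(w, a, b):
    # Returning a LoraWeight works because it's registered as a pytree
    return lorax.LoraWeight(w=w, a=a, b=b)

lw = make_lora_weight(jnp.zeros((8, 8)), jnp.zeros((4, 8)), jnp.zeros((8, 4)))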
If you give me some more details or code for your setup I can give you something more customized.
Thanks for the help!
What I am planning to do is to replace all of the transformer blocks of an autoregressive transformer written in Flax with just one transformer block, which is updated with LoRAs parameterised by a neural hypernetwork. The hypernetwork is given a signal for the depth of the layer and possibly the most recent token generated. This may let me perform inference-time quality-throughput tradeoffs.
I think this might be viable, as One Wide Feedforward is All You Need (https://arxiv.org/pdf/2309.01826.pdf) showed that replacing all FFNs with just one common FFN can work. I hypothesised that combining this approach (a common FFN and maybe shared self-attention) with LoRAs might allow me to regain the lost accuracy while potentially overfitting less due to the lower parameter count.
It might even be faster than larger models: with a smaller parameter count it would be more compute-bound and less memory-bound.
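Roughly, the hypernetwork I have in mind would look something like the sketch below (class name, embedding sizes, and factor shapes are all placeholders, not code from my actual setup):

import flax.linen as nn

class DepthLoraHypernetwork(nn.Module):
    lora_rank: int
    d_model: int
    max_depth: int

    @nn.compact
    def __call__(self, depth_idx):
        # Embed the layer-depth signal, then project it to the two LoRA factors
        h = nn.Embed(self.max_depth, 128)(depth_idx)
        h = nn.relu(nn.Dense(256)(h))
        a = nn.Dense(self.lora_rank * self.d_model)(h).reshape(self.lora_rank, self.d_model)
        b = nn.Dense(self.d_model * self.lora_rank)(h).reshape(self.d_model, self.lora_rank)
        return a, b

# e.g. hypernet = DepthLoraHypernetwork(lora_rank=8, d_model=512, max_depth=16)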
On a separate note, if I want to train both the transformer block and the hypernetwork at the same time, I should not use the wrap_optimiser() function, right?
I tried the "shared weights but different LoRAs" approach a few months ago. I found that the parameter efficiency of dense weights was better (I had to kick the LoRA dimension up so high that it wasn't actually reducing the number of parameters). Your hypernetwork idea is interesting; maybe it will help.
I should not use the wrap_optimiser() function, right?
In your case it sounds like all the parameters are trainable, so it should be unnecessary.
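For example, a plain optax optimizer over both parameter trees should be enough; here is a sketch with made-up stand-ins for the two trees:

import jax.numpy as jnp
import optax

# Hypothetical stand-ins for the transformer block and hypernetwork parameter trees
transformer_block_params = {'kernel': jnp.ones((4, 4))}
hypernetwork_params = {'kernel': jnp.ones((4, 2))}

# No lorax optimizer wrapping needed when everything is trainable
all_params = (transformer_block_params, hypernetwork_params)
optimiser = optax.adam(1e-3)
opt_state = optimiser.init(all_params)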
In your case did you pretrain from scratch or factorise out LoRA matrices from a pretrained model?
I trained from scratch. I was doing language modeling on penn treebank. I definitely can't say I did a thorough enough experiment to be sure that it doesn't work, since I only spent a couple hours on it.
Would it be possible to still do it if I wanted to write all my networks in Flax (as I see you are manually passing in all the parameters)? Also, do you see an elegant way of using the pytree such that my hypernetwork knows the dimensions of its outputs on initialisation, so that I can easily pass the output (the As and Bs for all parts of the main network) of the hypernetwork into LoraWeight?
Here's some example code showing how I'd handle it. The rough outline of the strategy is:
1. Define the hypernetwork as its own Flax module (nn.compact lets the hypernetwork do input-shape-dependent initialization).
2. Write a joint_call function which re-populates the weights for the model using a combination of weight sharing and calls to the hypernetwork, then calls the model using these parameters. This function will be compatible with JAX transformations like grad, vmap, and jit.

import flax.linen as nn
import lorax
import jax
import jax.numpy as jnp
class MyFlaxNetwork(nn.Module):
    """Example network"""
    def setup(self):
        self.blocks = [{'a': nn.Dense(64), 'b': nn.Dense(64)} for _ in range(5)]
        self.out_proj = nn.Dense(1)

    def __call__(self, x):
        for block in self.blocks:
            x = jax.nn.relu(block['a'](x) + block['b'](x))
        return self.out_proj(x)

class LoraMakerNetwork(nn.Module):
    @nn.compact  # have to use this so our parameters can depend on the input shape
    def __call__(self, W, *whatever_other_inputs_you_want):
        M, N = W.shape
        k = 4
        a = jax.random.normal(jax.random.PRNGKey(0), (k, N))
        b = jax.random.normal(jax.random.PRNGKey(1), (M, k))
        some_param = self.param('some_param', lambda rng_key: jnp.ones(()))
        return a, b

def main():
    model = MyFlaxNetwork()
    params = model.init(jax.random.PRNGKey(0), jnp.ones(64))
    for i in range(1, 5):
        # Delete the extra params we don't care about
        # This step might need to change quite a bit depending on
        # what exactly you want to share or keep separate
        del params['params'][f'blocks_{i}_a']['kernel']
        del params['params'][f'blocks_{i}_b']['kernel']

    lora_model = LoraMakerNetwork()
    lora_model_params = lora_model.init(jax.random.PRNGKey(0), params['params']['blocks_0_a']['kernel'])

    @jax.jit
    def joint_call(params, lora_model_params, input_data):
        # Copy the params tree so we can mutate it
        # This doesn't actually copy the data on the GPU, it just copies the pytree at tracing time
        modified_params = jax.tree_map(lambda x: x, params)

        # Step 1: Overwrite the original param tree with LoraWeight instances
        for i in range(5):
            for k in ['a', 'b']:
                shared_param_name = f'blocks_0_{k}'
                write_params_name = f'blocks_{i}_{k}'
                w_shared = params['params'][shared_param_name]['kernel']

                # This will do the same thing for every layer, but presumably you'll be passing some other inputs
                a, b = lora_model.apply(lora_model_params, w_shared)
                lora_weight = lorax.LoraWeight(w=w_shared, a=a, b=b)
                modified_params['params'][write_params_name]['kernel'] = lora_weight

        # Step 2: Run model using modified params tree
        wrapped_model = lorax.lora(model.apply)
        return wrapped_model(modified_params, input_data)

    inp = jax.random.normal(jax.random.PRNGKey(0), (64,))
    print(joint_call(params, lora_model_params, inp))

if __name__ == '__main__':
    main()
Feel free to let me know if you have any questions or if this won't work.
Thanks for the prompt reply!
Do you think there is a way to wrap the main function in a Flax module? I think wrapping it in a Flax module would make testing and initialising the model more convenient. My only worry is that, since I don't think the init method is called inside a Flax module, I am not sure of a way to del the parameters.
I am leaving a simplified version of my model code below (I am not fully certain of the implementation of the LoRATransformer class without the init method).
class TransformerBlock(nn.Module):
    d_model: int
    num_heads: int

    @nn.compact
    def __call__(self, x):
        attn = nn.SelfAttention(num_heads=self.num_heads, name="self_attn")(x)
        x = attn + x
        x = nn.LayerNorm()(x)
        x = nn.Sequential([
            nn.Dense(features=4 * self.d_model, name="ffn_1"),
            nn.gelu,
            nn.Dense(features=self.d_model, name="ffn_2"),
        ])(x)
        x = nn.Dropout(rate=0.1)(x)
        x = x + attn
        x = nn.LayerNorm()(x)
        return x

class Transformer(nn.Module):
    depth: int
    d_model: int
    num_heads: int

    def setup(self):
        self.blocks = [TransformerBlock(d_model=self.d_model, num_heads=self.num_heads) for _ in range(self.depth)]
        self.output = nn.Dense(features=1)

    def __call__(self, x):
        for block in self.blocks:
            x = block(x)
        x = self.output(x)
        return x

class LoRATransformer(nn.Module):
    depth: int
    lora_rank: int
    d_model: int
    num_heads: int

    def setup(self):
        self.network = Transformer(depth=self.depth, d_model=self.d_model, num_heads=self.num_heads)
        self.lora_hypernetwork = LoRA_Hypernetwork(lora_rank=self.lora_rank)
        for i in range(1, self.depth + 1):
            pass  # not sure how to delete/share the parameters here

    def __call__(self, x):
        raise NotImplementedError
So I think you probably don't want the LoRATransformer to be a Module, since you want to be able to manipulate the parameters of the Transformer module. Inside the context of a Flax module, the parameters are hidden from you. That's why I implemented the joint_call function.
So am I right to interpret your response as saying that I cannot implement LoRATransformer as a Flax module?
You probably can get it working that way, but it will be harder.
Oh whoops didn't mean to close.
@davisyoshida Just to confirm: I use the joint_call method to perform the training, right? And applying a loss through the joint_call function would optimise both the transformer block and hypernetwork parameters, right?
Also, do I have to initialize a separate hypernetwork for each matrix that has a unique shape? I presume so.
And how should I modify your code if I want to add dropout, especially the line
wrapped_model = lorax.lora(model.apply)
I use the joint_call method to perform the training, right?
Yes. To differentiate your loss function with respect to both sets of parameters, you can either pass them together in a tuple, or use the argnums param to jax.grad.
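For concreteness, here's a toy sketch of both options (placeholder arrays and loss, not your actual model):

import jax
import jax.numpy as jnp

params = jnp.ones(3)
lora_model_params = jnp.ones(2)
x = jnp.ones(4)

# Option 1: bundle both trees into one argument (a tuple is itself a pytree)
def loss_fn(both_params, x):
    p, h = both_params
    return jnp.sum(p) + jnp.sum(h) * jnp.sum(x)  # placeholder loss

grads = jax.grad(loss_fn)((params, lora_model_params), x)

# Option 2: keep them as separate arguments and differentiate w.r.t. both via argnums
def loss_fn2(params, lora_model_params, x):
    return jnp.sum(params) + jnp.sum(lora_model_params) * jnp.sum(x)  # placeholder loss

g_params, g_lora = jax.grad(loss_fn2, argnums=(0, 1))(params, lora_model_params, x)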
Also, do I have to initialize a separate hypernetwork for each matrix that has a unique shape? I presume so.
I think this depends on your architecture, not the particular implementation. If you have some architecture that can handle multiple output shapes, it should be fine to use that.
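If you do go with one hypernetwork per distinct kernel shape, the bookkeeping could look something like this (a sketch assuming the LoraMakerNetwork class from the earlier example is in scope; the kernels dict here is made up):

import jax
import jax.numpy as jnp

# Made-up shared kernels with two distinct shapes
kernels = {
    'attn': jnp.ones((64, 64)),
    'ffn': jnp.ones((64, 256)),
}

hypernets, hypernet_params = {}, {}
for name, kernel in kernels.items():
    if kernel.shape not in hypernets:
        hypernets[kernel.shape] = LoraMakerNetwork()
        hypernet_params[kernel.shape] = hypernets[kernel.shape].init(jax.random.PRNGKey(0), kernel)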
how should I modify your code if I want to add dropout
Usually dropout goes in the model code; you shouldn't need to add anything extra to the code I supplied. If you have something else in mind, give me a few more details on how you want to apply dropout and I can give some advice.
Following the implementation of dropout from the Flax linen docs, dropout is enabled by passing a training boolean, combined with a PRNG key for the dropout, in the apply method. So I was wondering how to modify the model.apply method, as the lorax.lora function wraps model.apply. I suspect I should modify the code to be wrapped_model = lorax.lora(model), then follow the dropout guide and use wrapped_model.apply instead, with the training flag and PRNG key.
https://flax.readthedocs.io/en/latest/guides/training_techniques/dropout.html
@PuR3Luck The lora() transform knows about functions, not models, so you won't be able to access attributes like .apply after using it. If it doesn't work to call wrapped_model(modified_params, input_data, training=True), then you could solve that with a partial function application:
from functools import partial
apply_with_dropout = partial(model.apply, training=True)
wrapped_model = lorax.lora(apply_with_dropout)
Just checking, my final code should look something like this, right?
import flax.linen as nn
import lorax
import optax
import jax
import jax.numpy as jnp
from functools import partial

class MyFlaxNetwork(nn.Module):
    """Example network"""
    def setup(self):
        self.blocks = [{'a': nn.Dense(64), 'b': nn.Dense(64)} for _ in range(5)]
        self.out_proj = nn.Dense(1)

    def __call__(self, x):
        for block in self.blocks:
            x = jax.nn.relu(block['a'](x) + block['b'](x))
        return self.out_proj(x)

class LoraMakerNetwork(nn.Module):
    @nn.compact  # have to use this so our parameters can depend on the input shape
    def __call__(self, W, *whatever_other_inputs_you_want):
        M, N = W.shape
        k = 4
        a = jax.random.normal(jax.random.PRNGKey(0), (k, N))
        b = jax.random.normal(jax.random.PRNGKey(1), (M, k))
        some_param = self.param('some_param', lambda rng_key: jnp.ones(()))
        return a, b

def main():
    model = MyFlaxNetwork()
    params = model.init(jax.random.PRNGKey(0), jnp.ones(64))
    for i in range(1, 5):
        # Delete the extra params we don't care about
        # This step might need to change quite a bit depending on
        # what exactly you want to share or keep separate
        del params['params'][f'blocks_{i}_a']['kernel']
        del params['params'][f'blocks_{i}_b']['kernel']

    lora_model = LoraMakerNetwork()
    lora_model_params = lora_model.init(jax.random.PRNGKey(0), params['params']['blocks_0_a']['kernel'])

    @jax.jit
    def joint_call(params, lora_model_params, input_data):
        # Copy the params tree so we can mutate it
        # This doesn't actually copy the data on the GPU, it just copies the pytree at tracing time
        modified_params = jax.tree_map(lambda x: x, params)

        # Step 1: Overwrite the original param tree with LoraWeight instances
        for i in range(5):
            for k in ['a', 'b']:
                shared_param_name = f'blocks_0_{k}'
                write_params_name = f'blocks_{i}_{k}'
                w_shared = params['params'][shared_param_name]['kernel']

                # This will do the same thing for every layer, but presumably you'll be passing some other inputs
                a, b = lora_model.apply(lora_model_params, w_shared)
                lora_weight = lorax.LoraWeight(w=w_shared, a=a, b=b)
                modified_params['params'][write_params_name]['kernel'] = lora_weight

        apply_fn = partial(model.apply, training=True, rngs={"dropout": jax.random.PRNGKey(0)})

        # Step 2: Run model using modified params tree
        wrapped_model = lorax.lora(apply_fn)
        return wrapped_model(modified_params, input_data)

    all_params = (params_0_a, params_0_b, lora_model_params)
    optimiser = optax.adam(1e-3)
    opt_state = optimiser.init(all_params)

    @jax.jit
    def update_fn(params, opt_state, joint_call, x):
        grad_fn = jax.value_and_grad(joint_call)
        loss, grad = grad_fn(params, x)
        updates, new_opt_state = optimiser.update(grad, opt_state, params=params)
        updated_params = optax.apply_updates(params, updates)
        return loss, new_opt_state, updated_params

    for i in range(epochs):
        loss, opt_state, params = update_fn(loss, opt_state, params)

if __name__ == '__main__':
    main()
Well, this doesn't actually have a loss function or targets to use for joint_call, but other than that the layout seems roughly correct. You don't need to pass joint_call as an argument to the update function; it's already accessible because they're in the same scope.
I thought that functions passed to jax.jit must not use global variables and functions though?
You only need to avoid that if their values will change. For stuff like a function you're only going to assign once, there's no issue.
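For example (a toy sketch, unrelated to lorax):

import jax
import jax.numpy as jnp

def helper(x):  # defined once at module scope and never reassigned
    return x * 2.0

@jax.jit
def f(x):
    # Closing over helper is fine since it never changes; jit only cares about the array arguments
    return helper(x) + 1.0

print(f(jnp.ones(3)))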
I didn't know that. Thanks for sharing!
I resolved the missing loss function and joint_call targets, so the code should look like this, right?
import flax.linen as nn
import lorax
import optax
import jax
import jax.numpy as jnp
from functools import partial

class MyFlaxNetwork(nn.Module):
    """Example network"""
    def setup(self):
        self.blocks = [{'a': nn.Dense(64), 'b': nn.Dense(64)} for _ in range(5)]
        self.out_proj = nn.Dense(1)

    def __call__(self, x):
        for block in self.blocks:
            x = jax.nn.relu(block['a'](x) + block['b'](x))
        return self.out_proj(x)

class LoraMakerNetwork(nn.Module):
    @nn.compact  # have to use this so our parameters can depend on the input shape
    def __call__(self, W, *whatever_other_inputs_you_want):
        M, N = W.shape
        k = 4
        a = jax.random.normal(jax.random.PRNGKey(0), (k, N))
        b = jax.random.normal(jax.random.PRNGKey(1), (M, k))
        some_param = self.param('some_param', lambda rng_key: jnp.ones(()))
        return a, b

def main():
    model = MyFlaxNetwork()
    params = model.init(jax.random.PRNGKey(0), jnp.ones(64))
    for i in range(1, 5):
        # Delete the extra params we don't care about
        # This step might need to change quite a bit depending on
        # what exactly you want to share or keep separate
        del params['params'][f'blocks_{i}_a']['kernel']
        del params['params'][f'blocks_{i}_b']['kernel']

    lora_model = LoraMakerNetwork()
    lora_model_params = lora_model.init(jax.random.PRNGKey(0), params['params']['blocks_0_a']['kernel'])

    @jax.jit
    def joint_call(params, lora_model_params, input_data):
        # Copy the params tree so we can mutate it
        # This doesn't actually copy the data on the GPU, it just copies the pytree at tracing time
        modified_params = jax.tree_map(lambda x: x, params)

        # Step 1: Overwrite the original param tree with LoraWeight instances
        for i in range(5):
            for k in ['a', 'b']:
                shared_param_name = f'blocks_0_{k}'
                write_params_name = f'blocks_{i}_{k}'
                w_shared = params['params'][shared_param_name]['kernel']

                # This will do the same thing for every layer, but presumably you'll be passing some other inputs
                a, b = lora_model.apply(lora_model_params, w_shared)
                lora_weight = lorax.LoraWeight(w=w_shared, a=a, b=b)
                modified_params['params'][write_params_name]['kernel'] = lora_weight

        apply_fn = partial(model.apply, training=True, rngs={"dropout": jax.random.PRNGKey(0)})

        # Step 2: Run model using modified params tree
        wrapped_model = lorax.lora(apply_fn)
        return wrapped_model(modified_params, input_data)

    all_params = (params, lora_model_params)
    optimiser = optax.adam(1e-3)
    opt_state = optimiser.init(all_params)

    def loss_fn(params, data):
        output = joint_call(
            params[0],
            params[1],
            data
        )
        # Calculate loss
        return loss

    @jax.jit
    def update_fn(params, opt_state, data):
        grad_fn = jax.value_and_grad(loss_fn)
        loss, grad = grad_fn(params, data)
        updates, new_opt_state = optimiser.update(grad, opt_state, params=params)
        updated_params = optax.apply_updates(params, updates)
        return loss, new_opt_state, updated_params

    for i in range(epochs):
        loss, opt_state, all_params = update_fn(all_params, opt_state, data)
        print(f"Loss: {loss}")

if __name__ == '__main__':
    main()
Looks like it should work yeah
Closing this, feel free to let me know if there are any more problems though.
I am currently playing around with training the Transformer. I am also now wondering how I could apply this to convolutions, as the kernel for a convolution has 3 dimensions.
I am applying a 1D convolution over a sequence with dimensions (seq_len, dimension) with multiple out_features, so the kernel has 3 dimensions. I saw in the lorax code that LoraWeight has a mechanism to handle convolutions. Could you provide some advice?
@davisyoshida The network seems to very frequently output NaN values for "deep" networks of around 16 or so layers, i.e. the output of the network is NaN. This behaviour is weird to me, as I didn't really observe it when using normal dense models, but it also seems to appear if I scale the model width. Do you have any suggestions to fix this? In general it seems extremely sensitive to the hyperparameters. I have tried different weight initialisations but it does not seem to help much.
@PuR3Luck This sounds like something more to do with the hypernetwork part than the Lora part. I don't have any advice for you since I haven't used hypernetworks.
Ok thanks for all the help!
I would like to use a separate neural network to predict LoRA weights for a main neural network, while training both neural networks at the same time. How can I manipulate the pytrees to achieve this, if it is possible at all?