(QLoRA uses a data-free quantization method. This is using GPTQ.)
params = model.parameters() # Returns an iterable over the parameters
I'm guessing the issue is that params is a generator. It looks like the issue is not related to this repo.
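A minimal sketch of the generator point, assuming the parameters come from a PyTorch module (the model below is just a stand-in): model.parameters() returns a one-shot generator, so materializing it into a list is the usual fix when downstream code needs to iterate it more than once or take its length.

```python
import torch

model = torch.nn.Linear(4, 4)        # stand-in for the real model
params = list(model.parameters())    # concrete list instead of a generator
print(len(params))                   # a bare generator would not support len()
```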
Just a side comment.
https://github.com/davisyoshida/lorax#minimal-example uses params = [jax.random.normal(jax.random.PRNGKey(i), (dim, dim)) / (dim ** 0.5) for i in range(30)], which does not make sense for a pre-trained model such as LongT5.
@buttercutter Yeah that's just an example for showing the API, I have an example of how to apply it to a pretrained HuggingFace model here. There's nothing really HuggingFace specific in there though.
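For reference, a hedged sketch of what that pattern might look like for LongT5. The calls mirror the ones quoted later in this thread (simple_spec, init_lora, the @lora decorator, optax.adamw); the Flax model class, the model id, the decision_fn policy, and the top-level lorax imports are my assumptions rather than quotes from the linked example.

```python
import jax
import optax
from transformers import FlaxLongT5ForConditionalGeneration
from lorax import lora, init_lora, simple_spec, LORA_FREEZE  # assumed import path

# Load the Flax (JAX) variant of LongT5 so the parameters are JAX arrays, not torch tensors
model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
params = model.params

def decision_fn(path, param):
    # Illustrative policy: rank-8 LoRA on every 2-D weight, freeze everything else
    return 8 if param.ndim == 2 else LORA_FREEZE

lora_spec = simple_spec(params, decision_fn=decision_fn, tune_vectors=True)
frozen_params, tunable_params = init_lora(params, lora_spec, jax.random.PRNGKey(0))

@lora
def forward(lora_params, input_ids, decoder_input_ids):
    out = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids, params=lora_params)
    return out.logits

# Only the tunable (LoRA) parameters go to the optimizer
optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)
opt_state = optimizer.init(tunable_params)

# After the transform, the forward pass is called with the parameter tuple:
#   forward((frozen_params, tunable_params), input_ids, decoder_input_ids)
```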
I tried to change your example to use LongT5 instead.
However, I ran into the following runtime error:
Cell In[20], line 46, in main()
39 lora_spec = simple_spec(params, decision_fn=decision_fn, tune_vectors=True)
41 # Cast input parameters to np.float32
42 #params = lax.convert_element_type(params, np.float32, params.dtype)
43
44 # Split the parameters up into tunable and frozen ones, and initialize a pair of LoRA matrices for each parameter
45 # which had a spec value other than LORA_FULL or LORA_FREEZE
---> 46 freeze_params, tune_params = init_lora(params, lora_spec, jax.random.PRNGKey(0))
48 optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)
50 # Make sure to only pass the tunable parameters to the optimizer
File ~/anaconda3/envs/arc/lib/python3.11/site-packages/lorax/helpers.py:42, in init_lora(param_tree, spec, rng, stddev, dtype, alpha, is_leaf)
37 b = jax.random.normal(rng, (*window_shape, in_channels, spec_val), dtype=param.dtype) * stddev
38 return LoraNode(a, b, alpha=alpha)
40 return (
41 jax.tree_map(freeze_getter, param_tree, spec, is_leaf=is_leaf),
---> 42 jax.tree_util.tree_map_with_path(tune_getter, param_tree, spec, is_leaf=is_leaf)
43 )
File ~/anaconda3/envs/arc/lib/python3.11/site-packages/jax/_src/tree_util.py:788, in tree_map_with_path(f, tree, is_leaf, *rest)
786 keypath_leaves = list(zip(*keypath_leaves))
787 all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
--> 788 return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
File ~/anaconda3/envs/arc/lib/python3.11/site-packages/jax/_src/tree_util.py:788, in <genexpr>(.0)
786 keypath_leaves = list(zip(*keypath_leaves))
787 all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
--> 788 return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))
File ~/anaconda3/envs/arc/lib/python3.11/site-packages/lorax/helpers.py:25, in init_lora.<locals>.tune_getter(path, param, spec_val)
22 if len(param.shape) == 2:
23 b_dim, a_dim = param.shape
---> 25 b = jnp.zeros((b_dim, spec_val), dtype=param.dtype)
26 a = jax.random.normal(rng, (spec_val, a_dim), dtype=param.dtype) * stddev
27 return LoraNode(a, b, alpha=alpha)
File ~/anaconda3/envs/arc/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2151, in zeros(shape, dtype)
2149 if isinstance(shape, types.GeneratorType):
2150 raise TypeError("expected sequence object with len >= 0 or a single integer")
-> 2151 dtypes.check_user_dtype_supported(dtype, "zeros")
2152 shape = canonicalize_shape(shape)
2153 return lax.full(shape, 0, _jnp_dtype(dtype))
File ~/anaconda3/envs/arc/lib/python3.11/site-packages/jax/_src/dtypes.py:596, in check_user_dtype_supported(dtype, fun_name)
594 if isinstance(dtype, type) and dtype in {bool, int, float, builtins.complex}:
595 return
--> 596 np_dtype = np.dtype(dtype)
597 if int4 is not None:
598 is_custom_dtype = np_dtype.type in [*_custom_float_scalar_types, int4, uint4]
TypeError: Cannot interpret 'torch.float32' as a data type
It sounds like you're trying to pass in torch tensors, but this is for JAX.
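For completeness, a rough sketch of one way to get a JAX-compatible parameter tree out of a PyTorch checkpoint; the helper below is illustrative (not part of lorax), and using a Flax variant of the model, if one is available, avoids the conversion entirely.

```python
import jax.numpy as jnp
import torch

def to_jax(tree):
    """Recursively convert torch tensors in a nested dict/list/tuple to JAX arrays."""
    if isinstance(tree, torch.Tensor):
        return jnp.asarray(tree.detach().cpu().numpy())
    if isinstance(tree, dict):
        return {k: to_jax(v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(to_jax(v) for v in tree)
    return tree

# e.g. params = to_jax(dict(model.named_parameters()))
# so that param.dtype is a JAX/NumPy dtype rather than torch.float32
```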
For the LongT5 model, I am not really sure how to make use of the (frozen_params, tunable_params) tuple together with input_ids without hitting runtime errors.
Cell In[8], line 21, in main.<locals>.lora_forward(params, input_ids)
19 @lora
20 def lora_forward(params, input_ids):
---> 21 return model(input_ids)
File ~/anaconda3/envs/arc/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/arc/lib/python3.11/site-packages/transformers/models/longt5/modeling_longt5.py:2030, in LongT5ForConditionalGeneration.forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
2027 # Encode if needed (training, first prediction pass)
2028 if encoder_outputs is None:
2029 # Convert encoder inputs in embeddings if needed
-> 2030 encoder_outputs = self.encoder(
2031 input_ids=input_ids,
2032 attention_mask=attention_mask,
2033 inputs_embeds=inputs_embeds,
2034 head_mask=head_mask,
2035 output_attentions=output_attentions,
2036 output_hidden_states=output_hidden_states,
2037 return_dict=return_dict,
2038 )
2039 elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
2040 encoder_outputs = BaseModelOutput(
2041 last_hidden_state=encoder_outputs[0],
2042 hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
2043 attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
2044 )
File ~/anaconda3/envs/arc/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
1496 # If we don't have any hooks, we want to skip the rest of the logic in
1497 # this function, and just call forward.
1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
1499 or _global_backward_pre_hooks or _global_backward_hooks
1500 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501 return forward_call(*args, **kwargs)
1502 # Do not call functions when jit is used
1503 full_backward_hooks, non_full_backward_hooks = [], []
File ~/anaconda3/envs/arc/lib/python3.11/site-packages/transformers/models/longt5/modeling_longt5.py:1436, in LongT5Stack.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
1432 raise ValueError(
1433 f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
1434 )
1435 elif input_ids is not None:
-> 1436 input_shape = input_ids.size()
1437 input_ids = input_ids.view(-1, input_shape[-1])
1438 elif inputs_embeds is not None:
TypeError: 'int' object is not callable
I am not particularly interested in reading stack traces for you, sorry.
Oh wait, I am not posting this to ask for help with the traceback.
Regarding https://chat.openai.com/share/2d026c56-7b00-490e-b7fd-62a76c7e49c2 and https://github.com/davisyoshida/lorax/blob/13034a29d0354d69da87ed26b7adf0627548c930/lorax/helpers.py#L8, is there a way to adapt init_lora() so that it works properly for models such as LongT5 without requiring both frozen_params and tunable_params?
Note: https://github.com/microsoft/LoRA/blob/a0d5efec36d74b5dce257492cc6943402573c4f3/loralib/layers.py#L48-L49 still splits into two separate params
The point of the lorax transform is to convert a function of the form f(params, x) into a function of the form f(arg_tuple, x), where arg_tuple = (frozen_params, tunable_params). So you start with a model which doesn't expect the parameter tuple, and it gets transformed into one which does.
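For illustration, here is a toy sketch of what that change of signature means. This is not the library's implementation; it only mimics the shape of the transformed function, where the effective weight is the frozen weight plus the low-rank update b @ a.

```python
import jax.numpy as jnp

def f(params, x):
    # Original model: expects a single parameter tree
    return x @ params["w"]

def f_transformed(arg_tuple, x):
    # Transformed model: expects (frozen_params, tunable_params) and applies
    # the low-rank update on top of the frozen weight
    frozen_params, tunable_params = arg_tuple
    w = frozen_params["w"] + tunable_params["w"]["b"] @ tunable_params["w"]["a"]
    return x @ w

rank, d = 4, 16
frozen = {"w": jnp.ones((d, d))}
tunable = {"w": {"a": jnp.zeros((rank, d)), "b": jnp.zeros((d, rank))}}
x = jnp.ones((2, d))
print(f_transformed((frozen, tunable), x).shape)  # (2, 16)
```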
I am trying to get a LongT5 version of QLoRA to run.
I have an error in the notebook; could you help check it? The error traces back to this line of code:
optimizer_state = optimizer.init(params)
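A minimal sketch of the likely fix, following the comment in the earlier cell about only passing the tunable parameters to the optimizer: initialize the optimizer state from the tunable half of the tuple returned by init_lora, not from the full parameter tree. The tiny pytree below is a stand-in for the real tune_params.

```python
import jax.numpy as jnp
import optax

# Stand-in for the tunable LoRA parameters returned by init_lora
tune_params = {"layer0": {"a": jnp.zeros((4, 16)), "b": jnp.zeros((16, 4))}}

optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)
optimizer_state = optimizer.init(tune_params)  # not optimizer.init(params)

# In the notebook this would be:
#   freeze_params, tune_params = init_lora(params, lora_spec, jax.random.PRNGKey(0))
#   optimizer_state = optimizer.init(tune_params)
```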