davisyoshida / easy-lora-and-gptq

JAX notebook showing how to LoRA + GPTQ arbitrary models
MIT License

qLoRA notebook for LongT5 #1

Closed: buttercutter closed this issue 1 year ago

buttercutter commented 1 year ago

I am trying to get a LongT5 version of QLoRA to run, but I hit an error in the notebook. Could you help check? The error traces back to this line:

optimizer_state = optimizer.init(params)

TypeError: zeros_like requires ndarray or scalar arguments, got <class 'generator'> at position 0.

davisyoshida commented 1 year ago

(QLoRA uses a data-free quantization method; this notebook uses GPT-Q.)

params = model.parameters()  # Returns an iterable over the parameters

I'm guessing the issue is that params is a generator. It looks like the issue is not related to this repo.
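
For reference, here's a minimal sketch (my own illustration, not code from the notebook) of what optax's init actually expects: a pytree of JAX arrays, not a generator.

import jax.numpy as jnp
import optax

optimizer = optax.adamw(learning_rate=1e-4)

# Works: a pytree (dict/list/tuple) whose leaves are jnp arrays.
params = {'w': jnp.ones((4, 4)), 'b': jnp.zeros((4,))}
opt_state = optimizer.init(params)

# Fails with the zeros_like TypeError above: JAX treats the generator as a
# single opaque leaf, so optax can't build zeroed optimizer state for it.
bad_params = (p for p in [jnp.ones((4, 4))])
# opt_state = optimizer.init(bad_params)  # TypeError: ... got <class 'generator'>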

buttercutter commented 1 year ago

Just a side comment: https://github.com/davisyoshida/lorax#minimal-example uses

params = [jax.random.normal(jax.random.PRNGKey(i), (dim, dim)) / (dim ** 0.5) for i in range(30)]

which does not make sense for a pre-trained model such as LongT5.

davisyoshida commented 1 year ago

@buttercutter Yeah, that's just an example showing the API; I have an example of how to apply it to a pretrained HuggingFace model here. There's nothing really HuggingFace-specific in there, though.
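
For LongT5 specifically, something along these lines should give you a real parameter pytree to hand to the LoRA helpers (a sketch, assuming the Flax LongT5 port in transformers and the google/long-t5-local-base checkpoint):

from transformers import FlaxLongT5ForConditionalGeneration

# Loads the pretrained Flax weights; model.params is a nested dict of JAX
# arrays, i.e. the kind of pytree the helpers expect instead of the random
# matrices in the README's minimal example.
model = FlaxLongT5ForConditionalGeneration.from_pretrained("google/long-t5-local-base")
params = model.params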

buttercutter commented 1 year ago

I tried to change your example to use LongT5 instead.

However, I ran into the following runtime error:

Cell In[20], line 46, in main()
     39 lora_spec = simple_spec(params, decision_fn=decision_fn, tune_vectors=True)
     41 # Cast input parameters to np.float32
     42 #params = lax.convert_element_type(params, np.float32, params.dtype)
     43 
     44 # Split the parameters up into tunable and frozen ones, and initialize a pair of LoRA matrices for each parameter
     45 # which had a spec value other than LORA_FULL or LORA_FREEZE
---> 46 freeze_params, tune_params = init_lora(params, lora_spec, jax.random.PRNGKey(0))
     48 optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)
     50 # Make sure to only pass the tunable parameters to the optimizer

File ~/anaconda3/envs/arc/lib/python3.11/site-packages/lorax/helpers.py:42, in init_lora(param_tree, spec, rng, stddev, dtype, alpha, is_leaf)
     37     b = jax.random.normal(rng, (*window_shape, in_channels, spec_val), dtype=param.dtype) * stddev
     38     return LoraNode(a, b, alpha=alpha)
     40 return (
     41     jax.tree_map(freeze_getter, param_tree, spec, is_leaf=is_leaf),
---> 42     jax.tree_util.tree_map_with_path(tune_getter, param_tree, spec, is_leaf=is_leaf)
     43 )

File ~/anaconda3/envs/arc/lib/python3.11/site-packages/jax/_src/tree_util.py:788, in tree_map_with_path(f, tree, is_leaf, *rest)
    786 keypath_leaves = list(zip(*keypath_leaves))
    787 all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
--> 788 return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))

File ~/anaconda3/envs/arc/lib/python3.11/site-packages/jax/_src/tree_util.py:788, in <genexpr>(.0)
    786 keypath_leaves = list(zip(*keypath_leaves))
    787 all_keypath_leaves = keypath_leaves + [treedef.flatten_up_to(r) for r in rest]
--> 788 return treedef.unflatten(f(*xs) for xs in zip(*all_keypath_leaves))

File ~/anaconda3/envs/arc/lib/python3.11/site-packages/lorax/helpers.py:25, in init_lora.<locals>.tune_getter(path, param, spec_val)
     22 if len(param.shape) == 2:
     23     b_dim, a_dim = param.shape
---> 25     b = jnp.zeros((b_dim, spec_val), dtype=param.dtype)
     26     a = jax.random.normal(rng, (spec_val, a_dim), dtype=param.dtype) * stddev
     27     return LoraNode(a, b, alpha=alpha)

File ~/anaconda3/envs/arc/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:2151, in zeros(shape, dtype)
   2149 if isinstance(shape, types.GeneratorType):
   2150   raise TypeError("expected sequence object with len >= 0 or a single integer")
-> 2151 dtypes.check_user_dtype_supported(dtype, "zeros")
   2152 shape = canonicalize_shape(shape)
   2153 return lax.full(shape, 0, _jnp_dtype(dtype))

File ~/anaconda3/envs/arc/lib/python3.11/site-packages/jax/_src/dtypes.py:596, in check_user_dtype_supported(dtype, fun_name)
    594 if isinstance(dtype, type) and dtype in {bool, int, float, builtins.complex}:
    595   return
--> 596 np_dtype = np.dtype(dtype)
    597 if int4 is not None:
    598   is_custom_dtype = np_dtype.type in [*_custom_float_scalar_types, int4, uint4]

TypeError: Cannot interpret 'torch.float32' as a data type

davisyoshida commented 1 year ago

It sounds like you're trying to pass in torch tensors, but this is for JAX.
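
If you do want to start from the PyTorch weights, one workaround (my own sketch, not something from the notebook) is to convert them into JAX arrays first; torch dtypes like torch.float32 mean nothing to jnp.zeros, which is exactly the error above. Loading the Flax checkpoint directly is usually cleaner, though.

import jax.numpy as jnp
import torch

def torch_params_to_jax(module: torch.nn.Module):
    # Detach each tensor, move it to CPU, and re-wrap the numpy buffer as a
    # jnp array; the result is a flat dict keyed by parameter name.
    return {
        name: jnp.asarray(p.detach().cpu().numpy())
        for name, p in module.named_parameters()
    }

# Stand-in module for the example; a LongT5ForConditionalGeneration instance
# would work the same way, just with many more entries.
jax_params = torch_params_to_jax(torch.nn.Linear(8, 8))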

buttercutter commented 1 year ago

For the LongT5 model, I am not really sure how to make use of the (frozen_params, tunable_params) tuple and input_ids without running into runtime errors.

Cell In[8], line 21, in main.<locals>.lora_forward(params, input_ids)
     19 @lora
     20 def lora_forward(params, input_ids):
---> 21     return model(input_ids)

File ~/anaconda3/envs/arc/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/arc/lib/python3.11/site-packages/transformers/models/longt5/modeling_longt5.py:2030, in LongT5ForConditionalGeneration.forward(self, input_ids, attention_mask, decoder_input_ids, decoder_attention_mask, head_mask, decoder_head_mask, cross_attn_head_mask, encoder_outputs, past_key_values, inputs_embeds, decoder_inputs_embeds, labels, use_cache, output_attentions, output_hidden_states, return_dict)
   2027 # Encode if needed (training, first prediction pass)
   2028 if encoder_outputs is None:
   2029     # Convert encoder inputs in embeddings if needed
-> 2030     encoder_outputs = self.encoder(
   2031         input_ids=input_ids,
   2032         attention_mask=attention_mask,
   2033         inputs_embeds=inputs_embeds,
   2034         head_mask=head_mask,
   2035         output_attentions=output_attentions,
   2036         output_hidden_states=output_hidden_states,
   2037         return_dict=return_dict,
   2038     )
   2039 elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
   2040     encoder_outputs = BaseModelOutput(
   2041         last_hidden_state=encoder_outputs[0],
   2042         hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
   2043         attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
   2044     )

File ~/anaconda3/envs/arc/lib/python3.11/site-packages/torch/nn/modules/module.py:1501, in Module._call_impl(self, *args, **kwargs)
   1496 # If we don't have any hooks, we want to skip the rest of the logic in
   1497 # this function, and just call forward.
   1498 if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks or self._forward_pre_hooks
   1499         or _global_backward_pre_hooks or _global_backward_hooks
   1500         or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501     return forward_call(*args, **kwargs)
   1502 # Do not call functions when jit is used
   1503 full_backward_hooks, non_full_backward_hooks = [], []

File ~/anaconda3/envs/arc/lib/python3.11/site-packages/transformers/models/longt5/modeling_longt5.py:1436, in LongT5Stack.forward(self, input_ids, attention_mask, encoder_hidden_states, encoder_attention_mask, inputs_embeds, head_mask, cross_attn_head_mask, past_key_values, use_cache, output_attentions, output_hidden_states, return_dict)
   1432     raise ValueError(
   1433         f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
   1434     )
   1435 elif input_ids is not None:
-> 1436     input_shape = input_ids.size()
   1437     input_ids = input_ids.view(-1, input_shape[-1])
   1438 elif inputs_embeds is not None:

TypeError: 'int' object is not callable

davisyoshida commented 1 year ago

I am not particularly interested in reading stack traces for you, sorry.

buttercutter commented 1 year ago

Oh wait, I am not posting this to ask for help with the traceback.

For https://chat.openai.com/share/2d026c56-7b00-490e-b7fd-62a76c7e49c2 and https://github.com/davisyoshida/lorax/blob/13034a29d0354d69da87ed26b7adf0627548c930/lorax/helpers.py#L8 , is there a way to adapt init_lora() for a model such as LongT5, which does not expect a separate frozen_params and tunable_params split?

Note: https://github.com/microsoft/LoRA/blob/a0d5efec36d74b5dce257492cc6943402573c4f3/loralib/layers.py#L48-L49 still splits the weights into two separate parameter sets.

davisyoshida commented 1 year ago

The point of the lorax transform is to convert a function of the form f(params, x) into one of the form f(arg_tuple, x), where arg_tuple = (frozen_params, tunable_params). In other words, you start with a model which doesn't expect the parameter tuple, and it gets transformed into one which does.
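
As a rough sketch of that pattern with a toy pure-JAX model, using the names that appear earlier in this thread (simple_spec, init_lora, the @lora decorator); exact import paths and signatures may differ between lorax versions, so treat this as illustrative:

import jax
import jax.numpy as jnp
import optax
from lorax import lora, simple_spec, init_lora  # adjust imports for your lorax version

# A pure-JAX "model": params is a pytree of jnp arrays, not a torch Module.
def forward(params, x):
    return x @ params['w'] + params['b']

params = {
    'w': jax.random.normal(jax.random.PRNGKey(0), (16, 16)) / 4.0,
    'b': jnp.zeros((16,)),
}

# Decide per-parameter what to do; here every matrix gets LoRA rank 4,
# and tune_vectors=True makes 1-D params (biases etc.) fully tunable.
spec = simple_spec(params, decision_fn=lambda path, param: 4, tune_vectors=True)
frozen_params, tunable_params = init_lora(params, spec, jax.random.PRNGKey(1))

# The transform: forward(params, x) becomes a function that takes the
# (frozen, tunable) pair in place of the original params.
lora_forward = lora(forward)
out = lora_forward((frozen_params, tunable_params), jnp.ones((2, 16)))

# Only the tunable half goes to the optimizer.
optimizer = optax.adamw(learning_rate=1e-4, weight_decay=1e-4)
opt_state = optimizer.init(tunable_params)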