erfanzar / EasyDeL

Accelerate, Optimize performance with streamlined training and serving options with JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

AssertionError: Precision DEFAULT requested together with quantization. #147

Closed: peterniu19 closed this issue 4 months ago

peterniu19 commented 5 months ago

Hello, amazing work! I tried LoRA fine-tuning on a Kaggle TPU. However, when I set `bits` to 4, the error below occurred. The error persists even when I pass `precision=jax.lax.Precision("default")`. Could you tell me how to do 4-bit LoRA training with EasyDeL?

model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path,
                                                            config_kwargs={"attn_mechanism":"sharded_vanilla",'max_position_embeddings': max_length,'bits':4},
                                                            sharding_axis_dims=(1,1,1,-1),
                                                            input_shape=(1,max_length),
                                                            precision=jax.lax.Precision("default"))
/usr/local/lib/python3.10/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html
  from .autonotebook import tqdm as notebook_tqdm
E0426 10:24:50.323884904   83398 oauth2_credentials.cc:238]            oauth_fetch: UNKNOWN:C-ares status is not ARES_SUCCESS qtype=A name=metadata.google.internal. is_balancer=0: Domain name not found {grpc_status:2, created_time:"2024-04-26T10:24:50.323861278+00:00"}
/usr/local/lib/python3.10/multiprocessing/popen_fork.py:66: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.
  self.pid = os.fork()
/usr/local/lib/python3.10/site-packages/pydantic/_internal/_fields.py:149: UserWarning: Field "model_name" has conflict with protected namespace "model_".

You may be able to resolve this warning by setting `model_config['protected_namespaces'] = ()`.
  warnings.warn(
Loading checkpoint shards: 100%|██████████| 15/15 [00:18<00:00,  1.21s/it]
---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
Cell In[1], line 28
     26 huggingface_repo_id_or_path = "/kaggle/input/yi-34b-chat"
     27 max_length = 8192
---> 28 model, params = AutoEasyDelModelForCausalLM.from_pretrained(huggingface_repo_id_or_path,
     29                                                             config_kwargs={"attn_mechanism":"sharded_vanilla",'max_position_embeddings': max_length,'bits':4},
     30                                                             sharding_axis_dims=(1,1,1,-1),
     31                                                             input_shape=(1,max_length),
     32                                                             precision='DEFAULT')
     35 tokenizer = AutoTokenizer.from_pretrained(
     36     huggingface_repo_id_or_path,
     37     trust_remote_code=True
     38 )
     39 tokenizer.pad_token = tokenizer.eos_token

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/auto_easydel_model.py:445, in AutoEasyDelModelForCausalLM.from_pretrained(cls, pretrained_model_name_or_path, device, dtype, param_dtype, precision, sharding_axis_dims, sharding_axis_names, query_partition_spec, generation_query_partition_spec, key_partition_spec, value_partition_spec, bias_partition_spec, generation_bias_partition_spec, attention_partition_spec, shard_attention_computation, input_shape, shard_fns, backend, config_kwargs, auto_shard_params, partition_rules, load_in_8bit, bit_targeted_params, **kwargs)
    442         setattr(cfg, k, v)
    444 logger.debug("creating easydel model")
--> 445 ed_model = module(
    446     config=cfg,
    447     _do_init=False,
    448     dtype=dtype,
    449     param_dtype=param_dtype,
    450     precision=precision,
    451     input_shape=input_shape
    452 )
    454 needs = [
    455     s.replace(".kernel", ".weight").replace(".scale", ".weight").replace(".embedding", ".weight") for s in
    456     list(flax.traverse_util.flatten_dict(ed_model.params_shape_tree, sep=".").keys())
    457 ]
    458 for k in list(state_dict.keys()):

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:597, in FlaxLlamaPreTrainedModel.__init__(self, config, input_shape, seed, dtype, _do_init, **kwargs)
    579 """
    580 The __init__ function is called when the class is instantiated.
    581 It sets up the instance of the class, and defines what happens when it's created.
   (...)
    594 
    595 """
    596 module = self.module_class(config=config, dtype=dtype, **kwargs)
--> 597 super().__init__(config, module, input_shape=input_shape,
    598                  seed=seed, dtype=dtype, _do_init=_do_init)

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/easydel_modelling_utils.py:447, in EasyDelFlaxPretrainedModel.__init__(self, config, module, input_shape, seed, dtype, param_dtype, precision, _do_init)
    436 def __init__(
    437         self,
    438         config: PretrainedConfig,
   (...)
    445         _do_init: bool = True,
    446 ):
--> 447     super().__init__(
    448         config=config,
    449         module=module,
    450         input_shape=input_shape,
    451         seed=seed,
    452         dtype=dtype,
    453         _do_init=_do_init
    454     )

File /usr/local/lib/python3.10/site-packages/transformers/modeling_flax_utils.py:224, in FlaxPreTrainedModel.__init__(self, config, module, input_shape, seed, dtype, _do_init)
    222 else:
    223     init_fn = partial(self.init_weights, input_shape=input_shape)
--> 224     params_shape_tree = jax.eval_shape(init_fn, self.key)
    226     logger.info(
    227         "Model weights are not initialized as `_do_init` is set to `False`. "
    228         f"Make sure to call `{self.__class__.__name__}.init_weights` manually to initialize the weights."
    229     )
    231 # get the shape of the parameters

    [... skipping hidden 13 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:632, in FlaxLlamaPreTrainedModel.init_weights(self, rng, input_shape, params)
    622     module_init_outputs = self.module.init(
    623         rngs,
    624         input_ids,
   (...)
    629         return_dict=False,
    630     )
    631 else:
--> 632     module_init_outputs = self.module.init(
    633         rngs, input_ids, attention_mask, position_ids, return_dict=False)
    635 random_params = module_init_outputs["params"]
    637 if params is not None:

    [... skipping hidden 9 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:1067, in FlaxLlamaForCausalLMModule.__call__(self, input_ids, attention_mask, position_ids, deterministic, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
   1062 if position_ids is None:
   1063     position_ids = jnp.broadcast_to(
   1064         jnp.clip(jnp.cumsum(attention_mask, axis=-1) - 1, a_min=0),
   1065         (batch_size, seq_length)
   1066     )
-> 1067 outputs = self.model(
   1068     input_ids,
   1069     attention_mask,
   1070     position_ids,
   1071     deterministic=deterministic,
   1072     init_cache=init_cache,
   1073     output_attentions=output_attentions,
   1074     output_hidden_states=output_hidden_states,
   1075     return_dict=return_dict,
   1076     extra_embedding=extra_embedding
   1077 )
   1079 hidden_states = outputs[0]
   1081 if self.config.tie_word_embeddings:

    [... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:964, in FlaxLlamaModule.__call__(self, input_ids, attention_mask, position_ids, deterministic, inputs_embeds, init_cache, output_attentions, output_hidden_states, return_dict, extra_embedding)
    959 inputs_embeds = inputs_embeds + \
    960                 extra_embedding if extra_embedding is not None else inputs_embeds
    961 hidden_states = self.dropout(
    962     inputs_embeds, deterministic=deterministic)
--> 964 outputs = self.layers(
    965     hidden_states=hidden_states,
    966     freq_cis=self.freq_cis,
    967     attention_mask=attention_mask,
    968     position_ids=position_ids,
    969     causal_mask=self.causal_mask,
    970     deterministic=deterministic,
    971     init_cache=init_cache,
    972     output_attentions=output_attentions,
    973     output_hidden_states=output_hidden_states,
    974     return_dict=return_dict,
    975 )
    977 hidden_states = outputs[0]
    978 hidden_states = self.norm(hidden_states)

    [... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:850, in FlaxLlamaBlockCollection.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, deterministic, init_cache, output_attentions, output_hidden_states, return_dict)
    847 if output_hidden_states:
    848     all_hidden_states += (hidden_states,)
--> 850 layer_outputs = block(
    851     hidden_states=hidden_states,
    852     freq_cis=freq_cis,
    853     attention_mask=attention_mask,
    854     position_ids=position_ids,
    855     causal_mask=causal_mask,
    856     deterministic=deterministic,
    857     init_cache=init_cache,
    858     output_attentions=output_attentions,
    859     fcm_mask=fcm_mask,
    860 )
    861 hidden_states = layer_outputs[0]
    863 if output_attentions:

    [... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:530, in FlaxLlamaBlock.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
    497 def __call__(
    498         self,
    499         hidden_states: chex.Array,
   (...)
    508         fcm_mask: Optional[jnp.ndarray] = None,
    509 ):
    510     """
    511     The __call__ function is the main function of a TransformerEncoderLayer.
    512     It takes in hidden states, frequency-domain inputs, and masks as input. It then
   (...)
    528 
    529     """
--> 530     attn_outputs = self.self_attn(
    531         self.input_layernorm(hidden_states),
    532         freq_cis,
    533         attention_mask,
    534         position_ids,
    535         causal_mask,
    536         segment_ids,
    537         deterministic,
    538         init_cache,
    539         output_attentions,
    540         fcm_mask,
    541     )
    542     attn_output = attn_outputs[0]
    543     hidden_states = hidden_states + attn_output

    [... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:567, in core_remat_static.<locals>.inner(scope_fn, repack_fn, variable_groups, rng_groups, *args)
    564   y = fn(scope, *args)
    565   return y, repack_fn(scope)
--> 567 return rematted(variable_groups, rng_groups, *dyn_args)

    [... skipping hidden 7 frame]

File /usr/local/lib/python3.10/site-packages/flax/linen/partitioning.py:564, in core_remat_static.<locals>.inner.<locals>.rematted(variable_groups, rng_groups, *dyn_args)
    562 args = _repack_remat_args(dyn_args, static_args)
    563 scope = scope_fn(variable_groups, rng_groups)
--> 564 y = fn(scope, *args)
    565 return y, repack_fn(scope)

    [... skipping hidden 3 frame]

File /usr/local/lib/python3.10/site-packages/EasyDel/modules/llama/modelling_llama_flax.py:273, in FlaxLlamaAttention.__call__(self, hidden_states, freq_cis, attention_mask, position_ids, causal_mask, segment_ids, deterministic, init_cache, output_attentions, fcm_mask)
    248 """
    249 
    250 The __call__ function is the main function of a JAX module. It defines how the module behaves when called
   (...)
    266 
    267 """
    268 batch_size, sequence_length = hidden_states.shape[:2]
    269 (
    270     query_states,
    271     key_states,
    272     value_states
--> 273 ) = self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)
    275 query_states = query_states.reshape(
    276     batch_size, sequence_length, self.config.num_attention_heads, self.head_dim)
    277 key_states = key_states.reshape(
    278     batch_size, sequence_length, self.config.num_key_value_heads, self.head_dim)

    [... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/fjformer/linen/linen.py:235, in Linear.__call__(self, inputs)
    233 else:
    234     dot_general = lax.dot_general
--> 235 y = dot_general(
    236     inputs,
    237     kernel,
    238     (((inputs.ndim - 1,), (0,)), ((), ())),
    239     precision=self.precision,
    240 )
    241 if bias is not None:
    242     y += jnp.reshape(bias, (1,) * (y.ndim - 1) + (-1,))

    [... skipping hidden 2 frame]

File /usr/local/lib/python3.10/site-packages/fjformer/bits/q_flax.py:176, in QDotGeneral.__call__(self, lhs, rhs, dimension_numbers, precision, preferred_element_type)
    166 @nn.compact
    167 def __call__(
    168         self,
   (...)
    173         preferred_element_type=None,
    174 ):
    175     aqt_dg = self.make_aqt_dg(lhs.shape, rhs.shape, dimension_numbers)
--> 176     return aqt_dg(
    177         lhs,
    178         rhs,
    179         dimension_numbers,
    180         precision,
    181         preferred_element_type=preferred_element_type,
    182     )

File /usr/local/lib/python3.10/site-packages/fjformer/bits/q_dot_general.py:567, in make_dot_general.<locals>.ret_dg(***failed resolving arguments***)
    550 """
    551 The ret_dg function is a wrapper around the dg function.
    552 It takes in two QTensors, lhs and rhs, and returns a QTensor out.
   (...)
    563 :return: A function that returns a function
    564 """
    565 del preferred_element_type
    566 assert (
--> 567         precision is None
    568 ), f'Precision {precision} requested together with quantization.'
    570 msg = 'AQT is not yet optimized to accept quantized types directly. '
    571 msg += f'lhs.dtype: {lhs.dtype}, rhs.dtype: {rhs.dtype}'

AssertionError: Precision DEFAULT requested together with quantization.
erfanzar commented 5 months ago

Hi, and thanks for using EasyDeL!

model, params = AutoEasyDelModelForCausalLM.from_pretrained(
    huggingface_repo_id_or_path,
    config_kwargs={"attn_mechanism": "sharded_vanilla", "max_position_embeddings": max_length, "bits": 4},
    sharding_axis_dims=(1, 1, 1, -1),
    input_shape=(1, max_length),
    precision=None,
)

Use this one: when `bits` is set, the quantized (AQT) dot_general requires `precision=None`.
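
For context, the assertion comes from fjformer's quantized dot_general (visible at the bottom of the traceback), which only accepts `precision is None` once 4/8-bit quantization is active. A minimal illustrative guard, not part of the EasyDeL API, if you want to toggle `bits` from one place:

import jax

# Illustrative helper (hypothetical, not an EasyDeL function): choose a
# precision that is compatible with the quantized (AQT) dot_general, which
# asserts `precision is None` whenever quantization is enabled.
def compatible_precision(bits):
    if bits is not None:
        return None                      # required by fjformer's AQT path
    return jax.lax.Precision("default")  # fine for unquantized matmuls

You would then pass `precision=compatible_precision(4)` together with `'bits': 4` in `config_kwargs`.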

peterniu19 commented 5 months ago

Thank you, it works. I'm still confused about how changing `bits` helps reduce TPU memory during training. I can train a model successfully without the `bits` parameter, but when I set `bits` to 4 or 8 it says TPU memory is exhausted. Could you please provide some insight into this?

erfanzar commented 5 months ago

Yes, it actually works better at larger scale. Right now you are getting an out-of-memory error because you are trying to use gradient checkpointing; but imagine you are using many more TPUs and have enough memory to choose the benefit of lower-precision operations instead of checkpointing the gradients of operations or modules.
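
For a rough sense of where the savings would come from at larger scale, here is back-of-the-envelope arithmetic only (it ignores activations, optimizer state, and quantization scale/zero-point overhead) comparing the raw weight memory of a 34B-parameter model such as Yi-34B at different bit widths:

# Back-of-the-envelope weight-memory comparison; a real run also pays for
# activations, optimizer state, and quantization bookkeeping.
def weight_gib(n_params: float, bits: int) -> float:
    return n_params * bits / 8 / 2**30

for bits in (16, 8, 4):
    print(f"{bits:>2}-bit weights, 34B params: ~{weight_gib(34e9, bits):.0f} GiB")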

peterniu19 commented 5 months ago

Thank you for the explanation. I reduced the batch size and ran an experiment comparing runs with and without bits=8, keeping all other settings constant. I observed that both the speed and the TPU utilization (monitored via wandb with use_wandb=True) remained the same. I'm wondering why these metrics did not differ between the two experiments.

erfanzar commented 5 months ago

Actually, the TPU monitoring is just a tool to find out how much memory the model and buffers take up, so it doesn't record the memory used during the forward and backward passes. You can measure that too, and it's not difficult, but there is a high chance that your training loop will crash.
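
If you do want a live view of device memory during a run, recent JAX versions expose per-device counters via `Device.memory_stats()` on TPU/GPU backends. A rough sketch (not an EasyDeL feature; the available keys vary by backend and JAX version, so everything is accessed defensively):

import jax

# Rough sketch: poll per-device memory counters between training steps.
# `memory_stats()` may return None or expose different keys depending on
# the backend and JAX version.
def log_device_memory():
    for device in jax.local_devices():
        stats = device.memory_stats() or {}
        in_use = stats.get("bytes_in_use", 0) / 2**30
        peak = stats.get("peak_bytes_in_use", 0) / 2**30
        print(f"{device}: {in_use:.2f} GiB in use, peak {peak:.2f} GiB")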