erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options with JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

model loading #32

Closed: JinSeoungwoo closed this issue 8 months ago

JinSeoungwoo commented 9 months ago

I received the file Mistral-1.566301941871643-69 at the end of the model's training. Is there a way to convert this save file to model.bin, or to load it on a TPU to check that it works?

Thank you for the support!

erfanzar commented 9 months ago

Yes, you can use JAXServer for that.

erfanzar commented 9 months ago

Link to the JAXServer docs

JinSeoungwoo commented 9 months ago

Additionally, can you let me know how to convert the flax model into pytorch_model.bin?

JinSeoungwoo commented 9 months ago

> Yes, you can use JAXServer for that.

I got an error with the code below:

from transformers import AutoTokenizer
# MistralConfig, FlaxMistralForCausalLM and JAXServer come from EasyDeL;
# the exact import paths may differ between versions.
from EasyDel import MistralConfig, FlaxMistralForCausalLM, JAXServer

config = MistralConfig(rotary_type="complex")
model = FlaxMistralForCausalLM(config, _do_init=False)

tokenizer = AutoTokenizer.from_pretrained(
    'mistralai/Mistral-7B-v0.1',
    model_max_length=4096,
    padding_side="left",
    add_eos_token=True,
)
tokenizer.pad_token = tokenizer.eos_token

server = JAXServer.load(
    path='/my/ckpt-path/Mistral-Test',
    model=model,
    tokenizer=tokenizer,
    config_model=config,
    add_params_field=True,
    config=None,
    init_shape=(1, 1)
)

Error:

Traceback (most recent call last):
  File "test.py", line 12, in <module>
    server = JAXServer.load(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 414, in load
    server.compile(verbose=verbose)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 475, in compile
    for r, a in self.process(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 615, in process
    predicted_token = self.greedy_generate(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 506, in greedy_generate
    return self._greedy_generate(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 286, in greedy_generate
    predict = model.generate(
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/flax_utils.py", line 417, in generate
    return self._greedy_search(
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/flax_utils.py", line 636, in _greedy_search
    state = greedy_search_body_fn(state)
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/flax_utils.py", line 612, in greedy_search_body_fn
    model_outputs = model(state.running_token, params=params, **state.model_kwargs)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 561, in __call__
    outputs = self.module.apply(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 784, in __call__
    outputs = self.model(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 708, in __call__
    outputs = self.layers(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 638, in __call__
    output = layer(
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 432, in __call__
    attention_output = self.self_attn(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 347, in __call__
    q, k, v, attention_mask = self.concatenate_to_cache_(q, k, v, attention_mask)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 316, in concatenate_to_cache_
    attention_mask = nn.combine_masks(pad_mask, attention_mask)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/attention.py", line 506, in combine_masks
    assert all(
AssertionError: masks must have same rank: (4, 2)

erfanzar commented 9 months ago

> Additionally, can you let me know how to convert the flax model into pytorch_model.bin?

Use mistral_flax_to_pt from the transform functions (I'm not sure I got the name of that func exactly right ;\ ). Right now the Mistral models have a computation problem that I'm trying to fix as soon as I can.

JinSeoungwoo commented 9 months ago

Additionally, can you let me know how to convert the flax model into pytorch_model_bin?

use mistral_flax_to_pt in transform functions (I don't know if I said the right name for that func ;\ ) right now mistral models have a computing problem that I'm trying to fix them as soon as I can

By any chance, can you show an example of loading and applying a flax model to use the mistral_flax_to_pt function? I'm having trouble as I'm not familiar with flax... Sorry for the many requests.

erfanzar commented 9 months ago

> Additionally, can you let me know how to convert the flax model into pytorch_model.bin?
>
> Use mistral_flax_to_pt from the transform functions (I'm not sure I got the name of that func exactly right ;\ ). Right now the Mistral models have a computation problem that I'm trying to fix as soon as I can.
>
> By any chance, could you show an example of loading a flax model and applying the mistral_flax_to_pt function to it? I'm having trouble since I'm not familiar with flax... Sorry for the many requests.

Use mistral_convert_flax_to_pt, and that's fine, you can ask any question you want, I'm here to help <3

JinSeoungwoo commented 9 months ago

Is this code right for converting the ckpt to pytorch_model.bin?

import torch
from flax.traverse_util import flatten_dict
# StreamingCheckpointer / mistral_convert_flax_to_pt import paths may differ between versions

_, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
flax_params = flatten_dict(flax_params['params'], sep='.')

# convert the flattened flax params into a PyTorch state dict
pytorch_state_dict = mistral_convert_flax_to_pt(flax_params, MistralConfig())

torch.save(pytorch_state_dict, 'pytorch_model.bin')

I'm wondering if it's OK to use MistralConfig() as-is as the config for mistral_convert_flax_to_pt.

erfanzar commented 8 months ago

Yes, the code is correct, but use the config of the model you actually want to convert from EasyDel to torch/HF, because things like num_hidden_layers, etc. are taken from the given config.
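
For example, a minimal sketch continuing from the snippet above (the hyperparameter values here are placeholders, use the ones from your own training run; the top-level import path is assumed):

from EasyDel import MistralConfig  # import path assumed; adjust to your EasyDeL version

# placeholder values, replace with the architecture settings from your training run
config = MistralConfig(
    hidden_size=4096,
    intermediate_size=14336,
    num_hidden_layers=32,
    num_attention_heads=32,
    num_key_value_heads=8,
    max_position_embeddings=4096,
)
pytorch_state_dict = mistral_convert_flax_to_pt(flax_params, config)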

JinSeoungwoo commented 8 months ago

> Yes, the code is correct, but use the config of the model you actually want to convert from EasyDel to torch/HF, because things like num_hidden_layers, etc. are taken from the given config.

Hmm... so I can't just use the MistralConfig class as-is? Also, if I have to use a custom config, can I use the one I used for training?

erfanzar commented 8 months ago

Yes, you should use the config that was used to train the model (it's saved in the W&B project if you don't remember it).
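
If the config was logged to W&B, here is a rough sketch of pulling it back (the run path and key filtering are hypothetical; the EasyDeL import path is assumed):

import wandb
from EasyDel import MistralConfig  # import path assumed

api = wandb.Api()
run = api.run("my-entity/my-project/run-id")   # hypothetical run path
saved_cfg = run.config                         # hyperparameters logged at training time

# keep only the keys MistralConfig actually understands, then rebuild the config
valid_keys = MistralConfig().to_dict().keys()
config = MistralConfig(**{k: v for k, v in saved_cfg.items() if k in valid_keys})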

JinSeoungwoo commented 8 months ago

Is full fine-tuning the only fine-tuning method currently supported?

JinSeoungwoo commented 8 months ago

Also, are there any plans to add a linear scheduler with warmup steps?

JinSeoungwoo commented 8 months ago

> Also, are there any plans to add a linear scheduler with warmup steps?

I just made a warmup_linear scheduler function:

from typing import Optional

import chex
import optax


def get_adamw_with_warmup_linear_scheduler(
        steps: int,
        learning_rate_start: float = 5e-5,
        learning_rate_end: float = 1e-5,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        eps_root: float = 0.0,
        weight_decay: float = 1e-1,
        gradient_accumulation_steps: int = 1,
        mu_dtype: Optional[chex.ArrayDType] = None,
        warmup_steps: int = 500
):
    """
    AdamW optimizer with a linear warmup followed by a linear decay.

    :param steps: total number of training steps
    :param learning_rate_start: peak learning rate reached at the end of warmup
    :param learning_rate_end: final learning rate at the end of training
    :param b1:
    :param b2:
    :param eps:
    :param eps_root:
    :param weight_decay:
    :param gradient_accumulation_steps:
    :param mu_dtype:
    :param warmup_steps: number of steps for the warmup phase

    :return: optimizer (optax GradientTransformation) and the combined schedule
    """
    # linear warmup from ~0 to the peak learning rate, then linear decay to the end value
    scheduler_warmup = optax.linear_schedule(
        init_value=5e-8, end_value=learning_rate_start, transition_steps=warmup_steps
    )
    scheduler_decay = optax.linear_schedule(
        init_value=learning_rate_start, end_value=learning_rate_end, transition_steps=steps - warmup_steps
    )
    scheduler_combined = optax.join_schedules(
        schedules=[scheduler_warmup, scheduler_decay], boundaries=[warmup_steps]
    )

    tx = optax.chain(
        optax.scale_by_adam(
            b1=b1,
            b2=b2,
            eps=eps,
            eps_root=eps_root,
            mu_dtype=mu_dtype
        ),
        optax.add_decayed_weights(
            weight_decay=weight_decay
        ),
        optax.scale_by_schedule(scheduler_combined),
        optax.scale(-1)
    )
    if gradient_accumulation_steps > 1:
        tx = optax.MultiSteps(
            tx, gradient_accumulation_steps
        )
    return tx, scheduler_combined
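
For reference, a quick usage sketch of the function above (the toy params/grads are only there to show the optax plumbing):

import jax.numpy as jnp
import optax

# toy parameter pytree, just to demonstrate the optimizer wiring
params = {"w": jnp.zeros((4, 4))}

tx, scheduler = get_adamw_with_warmup_linear_scheduler(
    steps=10_000,
    warmup_steps=500,
    learning_rate_start=5e-5,
    learning_rate_end=1e-5,
)
opt_state = tx.init(params)

grads = {"w": jnp.ones((4, 4))}                        # stand-in gradients
updates, opt_state = tx.update(grads, opt_state, params)
params = optax.apply_updates(params, updates)

print(scheduler(0), scheduler(500))                    # learning rate at step 0 vs. end of warmup
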
erfanzar commented 8 months ago

> Is full fine-tuning the only fine-tuning method currently supported?

RLHF is also supported for fine-tuning models, but only for LLaMA 1, Falcon, and MPT models right now.

erfanzar commented 8 months ago

> Also, are there any plans to add a linear scheduler with warmup steps?

I'll create that for you in the next update on the main branch.

erfanzar commented 8 months ago

> Also, are there any plans to add a linear scheduler with warmup steps?
>
> I just made a warmup_linear scheduler function (scheduler code quoted above)

Thanks! Update fjutils to 0.0.20 and reinstall EasyDel, and then you can use warm_up_linear.