JinSeoungwoo closed this issue 8 months ago.
Yes, you can use JAXServer for that.
Additionally, can you let me know how to convert a Flax model into pytorch_model.bin?
I got an error with the code below:
```python
from transformers import AutoTokenizer
# import path may vary between EasyDel versions
from EasyDel import JAXServer, MistralConfig, FlaxMistralForCausalLM

config = MistralConfig(rotary_type="complex")
model = FlaxMistralForCausalLM(config, _do_init=False)
tokenizer = AutoTokenizer.from_pretrained(
    'mistralai/Mistral-7B-v0.1',
    model_max_length=4096,
    padding_side="left",
    add_eos_token=True,
)
tokenizer.pad_token = tokenizer.eos_token

server = JAXServer.load(
    path='/my/ckpt-path/Mistral-Test',
    model=model,
    tokenizer=tokenizer,
    config_model=config,
    add_params_field=True,
    config=None,
    init_shape=(1, 1),
)
```
Error:
```
Traceback (most recent call last):
  File "test.py", line 12, in <module>
    server = JAXServer.load(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 414, in load
    server.compile(verbose=verbose)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 475, in compile
    for r, a in self.process(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 615, in process
    predicted_token = self.greedy_generate(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 506, in greedy_generate
    return self._greedy_generate(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/serve/serve_utils.py", line 286, in greedy_generate
    predict = model.generate(
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/flax_utils.py", line 417, in generate
    return self._greedy_search(
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/flax_utils.py", line 636, in _greedy_search
    state = greedy_search_body_fn(state)
  File "/usr/local/lib/python3.8/dist-packages/transformers/generation/flax_utils.py", line 612, in greedy_search_body_fn
    model_outputs = model(state.running_token, params=params, **state.model_kwargs)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 561, in __call__
    outputs = self.module.apply(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 784, in __call__
    outputs = self.model(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 708, in __call__
    outputs = self.layers(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 638, in __call__
    output = layer(
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/partitioning.py", line 553, in inner
    return rematted(variable_groups, rng_groups, *dyn_args)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/partitioning.py", line 550, in rematted
    y = fn(scope, *args)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 432, in __call__
    attention_output = self.self_attn(
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 347, in __call__
    q, k, v, attention_mask = self.concatenate_to_cache_(q, k, v, attention_mask)
  File "/usr/local/lib/python3.8/dist-packages/EasyDel/modules/mistral/modelling_mistral_flax.py", line 316, in concatenate_to_cache_
    attention_mask = nn.combine_masks(pad_mask, attention_mask)
  File "/usr/local/lib/python3.8/dist-packages/flax/linen/attention.py", line 506, in combine_masks
    assert all(
AssertionError: masks must have same rank: (4, 2)
```
Additionally, can you let me know how to convert the Flax model into pytorch_model.bin?
Use mistral_flax_to_pt in the transform functions (I'm not sure I got the name of that function right ;\ ). Right now, Mistral models have a computation problem that I'm trying to fix as soon as I can.
By any chance, can you show an example of loading and applying a Flax model to use the mistral_flax_to_pt function? I'm having trouble because I'm not familiar with Flax... Sorry for the many requests.
Use mistral_convert_flax_to_pt.
And that's fine; you can ask any question you want. I'm here to help <3
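Is this code right to convert the checkpoint to pytorch_model.bin?
```python
import torch
from flax.traverse_util import flatten_dict
# StreamingCheckpointer, MistralConfig and mistral_convert_flax_to_pt come from EasyDel / fjutils;
# their exact import paths depend on the installed versions.

# `path` is the checkpoint file saved by EasyDel at the end of training
_, flax_params = StreamingCheckpointer.load_trainstate_checkpoint(path)
flax_params = flatten_dict(flax_params['params'], sep='.')
pytorch_state_dict = mistral_convert_flax_to_pt(flax_params, MistralConfig())
torch.save(pytorch_state_dict, 'pytorch_model.bin')
```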
I'm also wondering whether it's OK to use a default MistralConfig() as the config for mistral_convert_flax_to_pt.
Yes, the code is correct, but use the config of the model you want to convert from EasyDel to PyTorch/HF, because values such as num_hidden_layers are taken from the given config.
Hmm... so I can't just use the MistralConfig class as is? Also, if I have to use a custom config, can I use the one I used for training?
Yes, you should use the config that was used to train the model (it's saved in your W&B project if you don't remember it).
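Why the training config matters: a converter typically walks the checkpoint layer by layer using config values such as `num_hidden_layers`, so a config built with default values may describe a different number of layers than the model you actually trained. The sketch below uses a hypothetical `TinyConfig` dataclass and a hypothetical `convert_layers` helper (neither is EasyDel's API) to show the two failure modes: too many layers raises a `KeyError`, too few silently drops weights.

```python
from dataclasses import dataclass
from typing import Dict

import numpy as np


@dataclass
class TinyConfig:
    # Hypothetical stand-in for a model config; only the field used below.
    num_hidden_layers: int


def convert_layers(flat_params: Dict[str, np.ndarray], config: TinyConfig) -> Dict[str, np.ndarray]:
    """Hypothetical converter: copies one transposed kernel per layer listed in the config."""
    state_dict = {}
    for i in range(config.num_hidden_layers):
        # KeyError here if the config lists more layers than the checkpoint has.
        kernel = flat_params[f"model.layers.{i}.mlp.down_proj.kernel"]
        state_dict[f"model.layers.{i}.mlp.down_proj.weight"] = kernel.T
    return state_dict


# A checkpoint trained with 2 layers.
trained = {f"model.layers.{i}.mlp.down_proj.kernel": np.zeros((4, 2), np.float32) for i in range(2)}

# Matching config: every layer is converted.
ok = convert_layers(trained, TinyConfig(num_hidden_layers=2))

# Config with more layers than the checkpoint: conversion fails.
try:
    convert_layers(trained, TinyConfig(num_hidden_layers=4))
    failed_key = None
except KeyError as err:
    failed_key = err.args[0]
print(failed_key)  # model.layers.2.mlp.down_proj.kernel

# Config with fewer layers: no error, but layer 1 is silently dropped.
partial = convert_layers(trained, TinyConfig(num_hidden_layers=1))
print(len(partial))  # 1
```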
Is full fine-tuning the only fine-tuning method currently supported?
Also, are there any plans to create a linear scheduler with warmup steps?
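I just made a warmup_linear scheduler function:
```python
from typing import Optional

import chex
import optax


def get_adamw_with_warmup_linear_scheduler(
        steps: int,
        learning_rate_start: float = 5e-5,
        learning_rate_end: float = 1e-5,
        b1: float = 0.9,
        b2: float = 0.999,
        eps: float = 1e-8,
        eps_root: float = 0.0,
        weight_decay: float = 1e-1,
        gradient_accumulation_steps: int = 1,
        mu_dtype: Optional[chex.ArrayDType] = None,
        warmup_steps: int = 500
):
    """
    AdamW with a linear warmup followed by a linear decay.

    :param steps: total number of training steps (warmup included)
    :param learning_rate_start: peak learning rate, reached at the end of warmup
    :param learning_rate_end: learning rate reached at the final step
    :param b1: Adam beta1
    :param b2: Adam beta2
    :param eps: Adam epsilon
    :param eps_root: epsilon added inside the square root of Adam's denominator
    :param weight_decay: decoupled weight-decay coefficient
    :param gradient_accumulation_steps: number of micro-batches accumulated per update
    :param mu_dtype: optional dtype for Adam's first-moment accumulator
    :param warmup_steps: number of steps in the warmup phase
    :return: (optimizer, combined scheduler)
    """
    # Warmup: ramp linearly from 5e-8 up to learning_rate_start.
    scheduler_warmup = optax.linear_schedule(
        init_value=5e-8, end_value=learning_rate_start, transition_steps=warmup_steps
    )
    # Decay: ramp linearly from learning_rate_start down to learning_rate_end over the remaining steps.
    scheduler_decay = optax.linear_schedule(
        init_value=learning_rate_start, end_value=learning_rate_end, transition_steps=steps - warmup_steps
    )
    # Switch from warmup to decay at step `warmup_steps`.
    scheduler_combined = optax.join_schedules(
        schedules=[scheduler_warmup, scheduler_decay], boundaries=[warmup_steps]
    )
    tx = optax.chain(
        optax.scale_by_adam(
            b1=b1,
            b2=b2,
            eps=eps,
            eps_root=eps_root,
            mu_dtype=mu_dtype
        ),
        # Decoupled weight decay (AdamW), applied before scaling by the learning rate.
        optax.add_decayed_weights(
            weight_decay=weight_decay
        ),
        optax.scale_by_schedule(scheduler_combined),
        # Gradient descent: move against the update direction.
        optax.scale(-1)
    )
    if gradient_accumulation_steps > 1:
        tx = optax.MultiSteps(
            tx, gradient_accumulation_steps
        )
    return tx, scheduler_combined
```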
Is full fine-tuning the only fine-tuning method currently supported?
RLHF is also supported for fine-tuning, but only for LLaMA 1, Falcon and MPT models right now.
Also, are there any plans to create a linear scheduler with warmup steps?
I'll create that for you in the next update on the main branch.
Thanks!
Update fjutils to 0.0.20 and reinstall EasyDel, and then you can use warm_up_linear.
I received the file Mistral-1.566301941871643-69 at the end of the model's training, and I was wondering if there is a way to convert this checkpoint file to model.bin, or to load it onto a TPU to check that it works.
Thank you for the support!