erfanzar / EasyDeL

Accelerate, Optimize performance with streamlined training and serving options with JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0
206 stars 25 forks

How to train on multiple nodes? #10

Nightbringers closed this 11 months ago

Nightbringers commented 1 year ago

I don't know how to train on multiple hosts. Can you give an example? Thank you.

erfanzar commented 1 year ago

Hello, yes, for sure. But which backend do you want to use for multi-host training: multiple TPU pods or GPU servers?

Nightbringers commented 1 year ago

GPU servers and a Slurm cluster.

erfanzar commented 1 year ago

Can you please test this code?

import jax
from EasyDel import TrainArguments, CausalLMTrainer

num_processes = 6  # total number of processes (nodes) taking part in training
process_id = 0  # index of the current node, between 0 and num_processes - 1; different on every node
coordinator_address = 'ip:port'  # e.g. '192.168.1.12:8600'; make sure the firewall does not block this port
jax.distributed.initialize(coordinator_address=coordinator_address,
                           num_processes=num_processes,
                           process_id=process_id)

train_args = TrainArguments(
    backend='gpu',
    sharding_array=(num_processes, -1, 1),
    use_wandb=True,
    use_pjit_attention_force=False
)

trainer = CausalLMTrainer(
    arguments=train_args,
    dataset_train=...,  # to be passed: the tokenized training dataset
    ckpt_path=...  # to be passed: path to a checkpoint, or None
)
# To fine-tune an existing model, pass its parameters here, wrapped like frozen({"params": ...})
parameters = None
trainer.train(model_parameters=parameters)
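On a Slurm cluster the per-node values above do not have to be hard-coded. A minimal sketch, assuming one JAX process per node launched with srun, reads them from Slurm's standard SLURM_PROCID and SLURM_NTASKS variables; the helper name slurm_process_info is hypothetical and not part of EasyDeL:

```python
import os


def slurm_process_info(env=None):
    """Return (process_id, num_processes) from Slurm's standard task variables.

    Assumes one JAX process per Slurm task (e.g. launched with `srun`).
    `env` defaults to os.environ; passing a plain dict makes the helper testable.
    """
    env = os.environ if env is None else env
    process_id = int(env.get("SLURM_PROCID", "0"))
    num_processes = int(env.get("SLURM_NTASKS", "1"))
    if not 0 <= process_id < num_processes:
        raise ValueError(f"invalid Slurm task {process_id} of {num_processes}")
    return process_id, num_processes
```

Each node would then pass these values to jax.distributed.initialize(coordinator_address=..., num_processes=..., process_id=...). JAX can also detect Slurm environments on its own when jax.distributed.initialize() is called with no arguments, which may be simpler if it works on your cluster.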
Nightbringers commented 1 year ago

I ran into some issues when running EasyDeL/examples/training/causal-lm/llama.py. I'm using llama-13b. I first converted the HF LLaMA checkpoint to Flax using this code:

model = AutoModelForCausalLM.from_pretrained(path)
state_dict = model.state_dict()
flax_params = llama_convert_hf_to_flax(state_dict, num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, device=device)
save_ckpt(flax_params, 'flax_param_easydel')

Is my conversion code correct?

Then I ran llama.py and got this error:

ValueError: One of pjit outputs was given the sharding of NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('fsdp',)), which implies that the global size of its dimension 0 should be divisible by 8, but it is equal to 40076 (full shape: (40076, 5120))
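The error is a divisibility check: a dimension assigned to a mesh axis must split evenly across that axis. With fsdp=8, the first dimension of the (40076, 5120) array cannot, because 40076 = 8 × 5009 + 4. A small sketch of the same check; the shardable helper and the dict-based mesh are illustrative assumptions, not EasyDeL API:

```python
def shardable(shape, spec, mesh):
    """True if every dimension named in `spec` divides evenly by its mesh axis size.

    `spec` lists one mesh-axis name (or None) per leading dimension,
    in the same spirit as a PartitionSpec.
    """
    return all(
        dim % mesh[axis] == 0
        for dim, axis in zip(shape, spec)
        if axis is not None
    )


MESH = {"dp": 1, "fsdp": 8, "mp": 1}  # the mesh reported in the error message
```

The last assertion in the checks below is the useful case: placing the 40076-row dimension on the size-1 dp axis and the hidden dimension (5120 = 8 × 640) on fsdp makes the shapes valid.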

erfanzar commented 1 year ago

When you run it, pass fully_fsdp=False to config.get_partition_rules; that will fix this problem.

Nightbringers commented 1 year ago

This seems to disable fully_fsdp, but I want to use fully_fsdp.

erfanzar commented 1 year ago

Which model are you trying to use? Can you give me its config so I can write a custom partition rule for it?

Nightbringers commented 1 year ago

llama-13b v2, but the model's vocabulary size has changed. Does that have any impact?

erfanzar commented 1 year ago

Yes, that's why you get this error. Can you tell me your vocab size? Is it a version with added EOS/BOS tokens that has 32002 tokens?

erfanzar commented 1 year ago

Use this:

from jax.sharding import PartitionSpec as PS
from EasyDel import TrainArguments

partition_rules = (

    ("transformer/wte/embedding", PS('dp', "fsdp")),

    ("attention/(wq|wk|wv)/kernel", PS("fsdp")),
    ("attention/wo/kernel", PS("fsdp")),

    ("feed_forward/w1/kernel", PS("fsdp")),
    ("feed_forward/w2/kernel", PS("fsdp")),
    ("feed_forward/w3/kernel", PS("fsdp")),

    ("attention_norm/kernel", PS('fsdp')),
    ("ffn_norm/kernel", PS('fsdp')),

    ("transformer/ln_f/kernel", PS('fsdp')),
    ("lm_head/kernel", PS("fsdp", 'dp')),
    ('.*', PS('fsdp')),
)

train_args = TrainArguments(
    custom_rule=partition_rules,
    # ... your other training arguments
)

This should work fine if you have only changed the vocab size.
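Each rule pairs a regular expression with a PartitionSpec. A minimal sketch of how such a rule table can be resolved against a parameter path, assuming EasyLM-style behaviour where the first pattern found in the '/'-joined path wins (check EasyDeL's own matcher before relying on this); plain tuples stand in for PartitionSpec so the sketch runs without JAX, and match_rule is a hypothetical name:

```python
import re

# A subset of the rules above; tuples stand in for PS(...) axis names.
RULES = (
    ("transformer/wte/embedding", ("dp", "fsdp")),
    ("attention/(wq|wk|wv)/kernel", ("fsdp",)),
    ("lm_head/kernel", ("fsdp", "dp")),
    (".*", ("fsdp",)),  # catch-all, so it must stay last
)


def match_rule(param_path, rules=RULES):
    """Return the spec of the first rule whose pattern occurs in param_path."""
    for pattern, spec in rules:
        if re.search(pattern, param_path) is not None:
            return spec
    raise ValueError(f"no partition rule matches {param_path!r}")
```

With these rules the vocabulary dimension of the embedding (its rows) and of lm_head (its columns, since a Flax Dense kernel is stored as (in_features, out_features)) lands on dp, so with dp=1 only the hidden size has to divide by the fsdp axis.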

Nightbringers commented 1 year ago

Yes, it worked!

erfanzar commented 1 year ago

If you have any other issues, please let me know <3

Nightbringers commented 1 year ago

The loss calculation has run into an issue. The error is:

  File "/EasyDeL/EasyDel/trainer/fsdp_train.py", line 406, in train
    sharded_trainstate, loss, accuracy = self.sharded_train_step_fn(sharded_trainstate,
  File "/EasyDeL/EasyDel/trainer/fsdp_train.py", line 309, in fsdp_trainstep
    (loss, accuracy), grad = grad_fn(state.params)
  File "/EasyDeL/EasyDel/trainer/fsdp_train.py", line 303, in calculate_loss
    loss, accuracy = loss_fn(
  File "/anaconda3-2/envs/py3.10/lib/python3.10/site-packages/fjutils/easylm.py", line 495, in blockwise_cross_entropy
    logits = rearrange(logits, '(n c) d -> n c d', c=chunk_size)
  File "/anaconda3-2/envs/py3.10/lib/python3.10/site-packages/einops/einops.py", line 483, in rearrange
    return reduce(cast(Tensor, tensor), pattern, reduction='rearrange', **axes_lengths)
  File "/anaconda3-2/envs/py3.10/lib/python3.10/site-packages/einops/einops.py", line 420, in reduce
    raise EinopsError(message + '\n {}'.format(e))
einops.EinopsError: Error while processing rearrange-reduction pattern "(n c) d -> n c d".
 Input tensor shape: (16368, 55296). Additional info: {'c': 1024}.
 Shape mismatch, can't divide axis of length 16368 in chunks of 1024

My vocab size is 55296 and sequence_length is 1024.

erfanzar commented 1 year ago

Set loss_remat to '':

train_args = TrainArguments(
    custom_rule=partition_rules,
    loss_remat=''
)

This will work. The current error occurs because you are using blockwise cross-entropy, which splits the flattened token axis (length 16368 here) into chunks of loss_chunk=1024, and 16368 is not divisible by 1024 (your vocab size is fine: 55296 = 54 × 1024). So you can either set loss_remat to '' or change loss_chunk.
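If you prefer to keep blockwise cross-entropy, loss_chunk has to divide the flattened token count exactly. A small hypothetical helper (not an EasyDeL function) that picks the largest valid chunk no bigger than a target:

```python
def largest_valid_chunk(num_tokens, max_chunk):
    """Largest chunk size <= max_chunk that divides num_tokens exactly."""
    for chunk in range(min(max_chunk, num_tokens), 0, -1):
        if num_tokens % chunk == 0:
            return chunk
    raise ValueError("num_tokens must be positive")
```

For the failing run this gives 1023, since 16368 = 16 × 1023; setting loss_remat='' sidesteps the constraint altogether.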

Nightbringers commented 1 year ago
    if self.arguments.loss_remat != '':
        blockwise_cross = functools.partial(
            blockwise_cross_entropy,
            chunk_size=self.arguments.loss_chunk,
            policy=self.arguments.loss_remat
        )
        loss_fn = blockwise_cross
    else:
        loss_fn = cross_entropy_loss_and_accuracy

What is the difference between blockwise_cross and cross_entropy_loss_and_accuracy? What is the difference between blockwise crossentropy and crossentropy? Are there any advantages to using blockwise_cross?

erfanzar commented 1 year ago

blockwise_cross uses jax.lax.scan to calculate the loss, which means it is computed chunk by chunk. This might be slower, but it is more memory-efficient (it uses less memory). It may not be usable or suitable for every situation.
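To make the difference concrete, here is a minimal sketch (not the fjutils or EasyDeL implementation) of a plain cross-entropy next to a chunked one written with jax.lax.scan. Both return the same mean loss. The chunked version only materialises one chunk's log-softmax at a time, and wrapping the step in jax.checkpoint recomputes it during the backward pass instead of storing it, which is presumably the kind of trade-off the policy argument (loss_remat) tunes. The function names are hypothetical:

```python
import jax
import jax.numpy as jnp


def token_nll(logits, labels):
    """Per-token negative log-likelihood for [tokens, vocab] logits and [tokens] labels."""
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    return -jnp.take_along_axis(log_probs, labels[:, None], axis=-1)[:, 0]


def cross_entropy(logits, labels):
    """Mean loss computed over all tokens at once."""
    return token_nll(logits, labels).mean()


def chunked_cross_entropy(logits, labels, chunk_size):
    """Same mean loss, accumulated chunk by chunk with jax.lax.scan."""
    num_tokens, vocab = logits.shape
    if num_tokens % chunk_size != 0:
        # Same constraint as the einops error: '(n c) d' needs an exact split.
        raise ValueError("number of tokens must be divisible by chunk_size")
    chunks = (logits.reshape(-1, chunk_size, vocab), labels.reshape(-1, chunk_size))

    @jax.checkpoint  # recompute each chunk's softmax in the backward pass
    def step(total, chunk):
        chunk_logits, chunk_labels = chunk
        return total + token_nll(chunk_logits, chunk_labels).sum(), None

    total, _ = jax.lax.scan(step, jnp.zeros((), logits.dtype), chunks)
    return total / num_tokens
```

The scan keeps the softmax intermediates at one chunk instead of the whole [tokens, vocab] array, at the cost of a sequential loop, which matches the "slower but uses less memory" description.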

Nightbringers commented 1 year ago

How do I convert flax_params to an HF model? I am trying to convert llama2-13b flax_params to an HF model. I first converted the HF model to flax_params, then tried to convert back but got an error. This is the code:

flax_params = llama_convert_hf_to_flax(state_dict, num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, device=device)
pt_params = llama_convert_flax_to_pt(flax_params, n_layers=40, dim=5120, num_attention_heads=40)

The error is:

EasyDeL/EasyDel/transform/llama.py", line 218, in llama_convert_flax_to_pt
    torch_params[key] = torch.from_numpy(tensor.astype(dtype=dtype))
                                         ^^^^^^^^^^^^^
AttributeError: 'dict' object has no attribute 'astype'

erfanzar commented 1 year ago

You have to use flatten_dict to convert the Flax params to PyTorch, like this:

from flax.traverse_util import unflatten_dict, flatten_dict
flax_params = llama_convert_hf_to_flax(state_dict, num_hidden_layers=40, num_attention_heads=40, hidden_size=5120,device = device)
flax_params = flatten_dict(flax_params)
pt_params = llama_convert_flax_to_pt(flax_params, n_layers=40, dim=5120, num_attention_heads=40)
Nightbringers commented 1 year ago

That error is gone, but pt_params = llama_convert_flax_to_pt(flax_params, n_layers=40, dim=5120, num_attention_heads=40) now raises another error:

EasyDeL/EasyDel/transform/llama.py", line 229, in llama_convert_flax_to_pt
    torch_params[f"transformer.h.{layer_i}.attention.wq.kernel"]
KeyError: 'transformer.h.0.attention.wq.kernel'
erfanzar commented 1 year ago

Can I have access to your model, or can you tell me which model you are using right now?

Nightbringers commented 1 year ago

llama2-13b.
I ran this:

flax_params = llama_convert_hf_to_flax(state_dict, num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, device=device)
flax_params = flatten_dict(flax_params)
print(flax_params.keys())

and got:

dict_keys([('transformer', 'wte', 'embedding'), ('transformer', 'ln_f', 'kernel'), ('transformer', 'h', '0', 'attention', 'wq', 'kernel'), ('transformer', 'h', '0', 'attention', 'wk', 'kernel'), ('transformer', 'h', '0', 'attention', 'wv', 'kernel'), ('transformer', 'h', '0', 'attention', 'wo', 'kernel'), ('transformer', 'h', '0', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '0', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '0', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '0', 'attention_norm', 'kernel'), ('transformer', 'h', '0', 'ffn_norm', 'kernel'), ('transformer', 'h', '1', 'attention', 'wq', 'kernel'), ('transformer', 'h', '1', 'attention', 'wk', 'kernel'), ('transformer', 'h', '1', 'attention', 'wv', 'kernel'), ('transformer', 'h', '1', 'attention', 'wo', 'kernel'), ('transformer', 'h', '1', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '1', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '1', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '1', 'attention_norm', 'kernel'), ('transformer', 'h', '1', 'ffn_norm', 'kernel'), ('transformer', 'h', '2', 'attention', 'wq', 'kernel'), ('transformer', 'h', '2', 'attention', 'wk', 'kernel'), ('transformer', 'h', '2', 'attention', 'wv', 'kernel'), ('transformer', 'h', '2', 'attention', 'wo', 'kernel'), ('transformer', 'h', '2', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '2', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '2', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '2', 'attention_norm', 'kernel'), ('transformer', 'h', '2', 'ffn_norm', 'kernel'), ('transformer', 'h', '3', 'attention', 'wq', 'kernel'), ('transformer', 'h', '3', 'attention', 'wk', 'kernel'), ('transformer', 'h', '3', 'attention', 'wv', 'kernel'), ('transformer', 'h', '3', 'attention', 'wo', 'kernel'), ('transformer', 'h', '3', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '3', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '3', 'feed_forward', 'w3', 'kernel'), 
('transformer', 'h', '3', 'attention_norm', 'kernel'), ('transformer', 'h', '3', 'ffn_norm', 'kernel'), ('transformer', 'h', '4', 'attention', 'wq', 'kernel'), ('transformer', 'h', '4', 'attention', 'wk', 'kernel'), ('transformer', 'h', '4', 'attention', 'wv', 'kernel'), ('transformer', 'h', '4', 'attention', 'wo', 'kernel'), ('transformer', 'h', '4', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '4', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '4', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '4', 'attention_norm', 'kernel'), ('transformer', 'h', '4', 'ffn_norm', 'kernel'), ('transformer', 'h', '5', 'attention', 'wq', 'kernel'), ('transformer', 'h', '5', 'attention', 'wk', 'kernel'), ('transformer', 'h', '5', 'attention', 'wv', 'kernel'), ('transformer', 'h', '5', 'attention', 'wo', 'kernel'), ('transformer', 'h', '5', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '5', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '5', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '5', 'attention_norm', 'kernel'), ('transformer', 'h', '5', 'ffn_norm', 'kernel'), ('transformer', 'h', '6', 'attention', 'wq', 'kernel'), ('transformer', 'h', '6', 'attention', 'wk', 'kernel'), ('transformer', 'h', '6', 'attention', 'wv', 'kernel'), ('transformer', 'h', '6', 'attention', 'wo', 'kernel'), ('transformer', 'h', '6', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '6', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '6', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '6', 'attention_norm', 'kernel'), ('transformer', 'h', '6', 'ffn_norm', 'kernel'), ('transformer', 'h', '7', 'attention', 'wq', 'kernel'), ('transformer', 'h', '7', 'attention', 'wk', 'kernel'), ('transformer', 'h', '7', 'attention', 'wv', 'kernel'), ('transformer', 'h', '7', 'attention', 'wo', 'kernel'), ('transformer', 'h', '7', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '7', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '7', 
'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '7', 'attention_norm', 'kernel'), ('transformer', 'h', '7', 'ffn_norm', 'kernel'), ('transformer', 'h', '8', 'attention', 'wq', 'kernel'), ('transformer', 'h', '8', 'attention', 'wk', 'kernel'), ('transformer', 'h', '8', 'attention', 'wv', 'kernel'), ('transformer', 'h', '8', 'attention', 'wo', 'kernel'), ('transformer', 'h', '8', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '8', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '8', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '8', 'attention_norm', 'kernel'), ('transformer', 'h', '8', 'ffn_norm', 'kernel'), ('transformer', 'h', '9', 'attention', 'wq', 'kernel'), ('transformer', 'h', '9', 'attention', 'wk', 'kernel'), ('transformer', 'h', '9', 'attention', 'wv', 'kernel'), ('transformer', 'h', '9', 'attention', 'wo', 'kernel'), ('transformer', 'h', '9', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '9', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '9', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '9', 'attention_norm', 'kernel'), ('transformer', 'h', '9', 'ffn_norm', 'kernel'), ('transformer', 'h', '10', 'attention', 'wq', 'kernel'), ('transformer', 'h', '10', 'attention', 'wk', 'kernel'), ('transformer', 'h', '10', 'attention', 'wv', 'kernel'), ('transformer', 'h', '10', 'attention', 'wo', 'kernel'), ('transformer', 'h', '10', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '10', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '10', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '10', 'attention_norm', 'kernel'), ('transformer', 'h', '10', 'ffn_norm', 'kernel'), ('transformer', 'h', '11', 'attention', 'wq', 'kernel'), ('transformer', 'h', '11', 'attention', 'wk', 'kernel'), ('transformer', 'h', '11', 'attention', 'wv', 'kernel'), ('transformer', 'h', '11', 'attention', 'wo', 'kernel'), ('transformer', 'h', '11', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '11', 'feed_forward', 'w2', 
'kernel'), ('transformer', 'h', '11', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '11', 'attention_norm', 'kernel'), ('transformer', 'h', '11', 'ffn_norm', 'kernel'), ('transformer', 'h', '12', 'attention', 'wq', 'kernel'), ('transformer', 'h', '12', 'attention', 'wk', 'kernel'), ('transformer', 'h', '12', 'attention', 'wv', 'kernel'), ('transformer', 'h', '12', 'attention', 'wo', 'kernel'), ('transformer', 'h', '12', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '12', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '12', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '12', 'attention_norm', 'kernel'), ('transformer', 'h', '12', 'ffn_norm', 'kernel'), ('transformer', 'h', '13', 'attention', 'wq', 'kernel'), ('transformer', 'h', '13', 'attention', 'wk', 'kernel'), ('transformer', 'h', '13', 'attention', 'wv', 'kernel'), ('transformer', 'h', '13', 'attention', 'wo', 'kernel'), ('transformer', 'h', '13', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '13', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '13', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '13', 'attention_norm', 'kernel'), ('transformer', 'h', '13', 'ffn_norm', 'kernel'), ('transformer', 'h', '14', 'attention', 'wq', 'kernel'), ('transformer', 'h', '14', 'attention', 'wk', 'kernel'), ('transformer', 'h', '14', 'attention', 'wv', 'kernel'), ('transformer', 'h', '14', 'attention', 'wo', 'kernel'), ('transformer', 'h', '14', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '14', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '14', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '14', 'attention_norm', 'kernel'), ('transformer', 'h', '14', 'ffn_norm', 'kernel'), ('transformer', 'h', '15', 'attention', 'wq', 'kernel'), ('transformer', 'h', '15', 'attention', 'wk', 'kernel'), ('transformer', 'h', '15', 'attention', 'wv', 'kernel'), ('transformer', 'h', '15', 'attention', 'wo', 'kernel'), ('transformer', 'h', '15', 'feed_forward', 'w1', 
'kernel'), ('transformer', 'h', '15', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '15', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '15', 'attention_norm', 'kernel'), ('transformer', 'h', '15', 'ffn_norm', 'kernel'), ('transformer', 'h', '16', 'attention', 'wq', 'kernel'), ('transformer', 'h', '16', 'attention', 'wk', 'kernel'), ('transformer', 'h', '16', 'attention', 'wv', 'kernel'), ('transformer', 'h', '16', 'attention', 'wo', 'kernel'), ('transformer', 'h', '16', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '16', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '16', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '16', 'attention_norm', 'kernel'), ('transformer', 'h', '16', 'ffn_norm', 'kernel'), ('transformer', 'h', '17', 'attention', 'wq', 'kernel'), ('transformer', 'h', '17', 'attention', 'wk', 'kernel'), ('transformer', 'h', '17', 'attention', 'wv', 'kernel'), ('transformer', 'h', '17', 'attention', 'wo', 'kernel'), ('transformer', 'h', '17', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '17', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '17', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '17', 'attention_norm', 'kernel'), ('transformer', 'h', '17', 'ffn_norm', 'kernel'), ('transformer', 'h', '18', 'attention', 'wq', 'kernel'), ('transformer', 'h', '18', 'attention', 'wk', 'kernel'), ('transformer', 'h', '18', 'attention', 'wv', 'kernel'), ('transformer', 'h', '18', 'attention', 'wo', 'kernel'), ('transformer', 'h', '18', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '18', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '18', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '18', 'attention_norm', 'kernel'), ('transformer', 'h', '18', 'ffn_norm', 'kernel'), ('transformer', 'h', '19', 'attention', 'wq', 'kernel'), ('transformer', 'h', '19', 'attention', 'wk', 'kernel'), ('transformer', 'h', '19', 'attention', 'wv', 'kernel'), ('transformer', 'h', '19', 'attention', 'wo', 
'kernel'), ('transformer', 'h', '19', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '19', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '19', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '19', 'attention_norm', 'kernel'), ('transformer', 'h', '19', 'ffn_norm', 'kernel'), ('transformer', 'h', '20', 'attention', 'wq', 'kernel'), ('transformer', 'h', '20', 'attention', 'wk', 'kernel'), ('transformer', 'h', '20', 'attention', 'wv', 'kernel'), ('transformer', 'h', '20', 'attention', 'wo', 'kernel'), ('transformer', 'h', '20', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '20', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '20', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '20', 'attention_norm', 'kernel'), ('transformer', 'h', '20', 'ffn_norm', 'kernel'), ('transformer', 'h', '21', 'attention', 'wq', 'kernel'), ('transformer', 'h', '21', 'attention', 'wk', 'kernel'), ('transformer', 'h', '21', 'attention', 'wv', 'kernel'), ('transformer', 'h', '21', 'attention', 'wo', 'kernel'), ('transformer', 'h', '21', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '21', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '21', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '21', 'attention_norm', 'kernel'), ('transformer', 'h', '21', 'ffn_norm', 'kernel'), ('transformer', 'h', '22', 'attention', 'wq', 'kernel'), ('transformer', 'h', '22', 'attention', 'wk', 'kernel'), ('transformer', 'h', '22', 'attention', 'wv', 'kernel'), ('transformer', 'h', '22', 'attention', 'wo', 'kernel'), ('transformer', 'h', '22', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '22', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '22', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '22', 'attention_norm', 'kernel'), ('transformer', 'h', '22', 'ffn_norm', 'kernel'), ('transformer', 'h', '23', 'attention', 'wq', 'kernel'), ('transformer', 'h', '23', 'attention', 'wk', 'kernel'), ('transformer', 'h', '23', 'attention', 'wv', 
'kernel'), ('transformer', 'h', '23', 'attention', 'wo', 'kernel'), ('transformer', 'h', '23', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '23', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '23', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '23', 'attention_norm', 'kernel'), ('transformer', 'h', '23', 'ffn_norm', 'kernel'), ('transformer', 'h', '24', 'attention', 'wq', 'kernel'), ('transformer', 'h', '24', 'attention', 'wk', 'kernel'), ('transformer', 'h', '24', 'attention', 'wv', 'kernel'), ('transformer', 'h', '24', 'attention', 'wo', 'kernel'), ('transformer', 'h', '24', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '24', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '24', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '24', 'attention_norm', 'kernel'), ('transformer', 'h', '24', 'ffn_norm', 'kernel'), ('transformer', 'h', '25', 'attention', 'wq', 'kernel'), ('transformer', 'h', '25', 'attention', 'wk', 'kernel'), ('transformer', 'h', '25', 'attention', 'wv', 'kernel'), ('transformer', 'h', '25', 'attention', 'wo', 'kernel'), ('transformer', 'h', '25', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '25', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '25', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '25', 'attention_norm', 'kernel'), ('transformer', 'h', '25', 'ffn_norm', 'kernel'), ('transformer', 'h', '26', 'attention', 'wq', 'kernel'), ('transformer', 'h', '26', 'attention', 'wk', 'kernel'), ('transformer', 'h', '26', 'attention', 'wv', 'kernel'), ('transformer', 'h', '26', 'attention', 'wo', 'kernel'), ('transformer', 'h', '26', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '26', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '26', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '26', 'attention_norm', 'kernel'), ('transformer', 'h', '26', 'ffn_norm', 'kernel'), ('transformer', 'h', '27', 'attention', 'wq', 'kernel'), ('transformer', 'h', '27', 'attention', 'wk', 
'kernel'), ('transformer', 'h', '27', 'attention', 'wv', 'kernel'), ('transformer', 'h', '27', 'attention', 'wo', 'kernel'), ('transformer', 'h', '27', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '27', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '27', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '27', 'attention_norm', 'kernel'), ('transformer', 'h', '27', 'ffn_norm', 'kernel'), ('transformer', 'h', '28', 'attention', 'wq', 'kernel'), ('transformer', 'h', '28', 'attention', 'wk', 'kernel'), ('transformer', 'h', '28', 'attention', 'wv', 'kernel'), ('transformer', 'h', '28', 'attention', 'wo', 'kernel'), ('transformer', 'h', '28', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '28', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '28', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '28', 'attention_norm', 'kernel'), ('transformer', 'h', '28', 'ffn_norm', 'kernel'), ('transformer', 'h', '29', 'attention', 'wq', 'kernel'), ('transformer', 'h', '29', 'attention', 'wk', 'kernel'), ('transformer', 'h', '29', 'attention', 'wv', 'kernel'), ('transformer', 'h', '29', 'attention', 'wo', 'kernel'), ('transformer', 'h', '29', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '29', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '29', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '29', 'attention_norm', 'kernel'), ('transformer', 'h', '29', 'ffn_norm', 'kernel'), ('transformer', 'h', '30', 'attention', 'wq', 'kernel'), ('transformer', 'h', '30', 'attention', 'wk', 'kernel'), ('transformer', 'h', '30', 'attention', 'wv', 'kernel'), ('transformer', 'h', '30', 'attention', 'wo', 'kernel'), ('transformer', 'h', '30', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '30', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '30', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '30', 'attention_norm', 'kernel'), ('transformer', 'h', '30', 'ffn_norm', 'kernel'), ('transformer', 'h', '31', 'attention', 'wq', 
'kernel'), ('transformer', 'h', '31', 'attention', 'wk', 'kernel'), ('transformer', 'h', '31', 'attention', 'wv', 'kernel'), ('transformer', 'h', '31', 'attention', 'wo', 'kernel'), ('transformer', 'h', '31', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '31', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '31', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '31', 'attention_norm', 'kernel'), ('transformer', 'h', '31', 'ffn_norm', 'kernel'), ('transformer', 'h', '32', 'attention', 'wq', 'kernel'), ('transformer', 'h', '32', 'attention', 'wk', 'kernel'), ('transformer', 'h', '32', 'attention', 'wv', 'kernel'), ('transformer', 'h', '32', 'attention', 'wo', 'kernel'), ('transformer', 'h', '32', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '32', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '32', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '32', 'attention_norm', 'kernel'), ('transformer', 'h', '32', 'ffn_norm', 'kernel'), ('transformer', 'h', '33', 'attention', 'wq', 'kernel'), ('transformer', 'h', '33', 'attention', 'wk', 'kernel'), ('transformer', 'h', '33', 'attention', 'wv', 'kernel'), ('transformer', 'h', '33', 'attention', 'wo', 'kernel'), ('transformer', 'h', '33', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '33', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '33', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '33', 'attention_norm', 'kernel'), ('transformer', 'h', '33', 'ffn_norm', 'kernel'), ('transformer', 'h', '34', 'attention', 'wq', 'kernel'), ('transformer', 'h', '34', 'attention', 'wk', 'kernel'), ('transformer', 'h', '34', 'attention', 'wv', 'kernel'), ('transformer', 'h', '34', 'attention', 'wo', 'kernel'), ('transformer', 'h', '34', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '34', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '34', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '34', 'attention_norm', 'kernel'), ('transformer', 'h', '34', 'ffn_norm', 
'kernel'), ('transformer', 'h', '35', 'attention', 'wq', 'kernel'), ('transformer', 'h', '35', 'attention', 'wk', 'kernel'), ('transformer', 'h', '35', 'attention', 'wv', 'kernel'), ('transformer', 'h', '35', 'attention', 'wo', 'kernel'), ('transformer', 'h', '35', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '35', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '35', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '35', 'attention_norm', 'kernel'), ('transformer', 'h', '35', 'ffn_norm', 'kernel'), ('transformer', 'h', '36', 'attention', 'wq', 'kernel'), ('transformer', 'h', '36', 'attention', 'wk', 'kernel'), ('transformer', 'h', '36', 'attention', 'wv', 'kernel'), ('transformer', 'h', '36', 'attention', 'wo', 'kernel'), ('transformer', 'h', '36', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '36', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '36', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '36', 'attention_norm', 'kernel'), ('transformer', 'h', '36', 'ffn_norm', 'kernel'), ('transformer', 'h', '37', 'attention', 'wq', 'kernel'), ('transformer', 'h', '37', 'attention', 'wk', 'kernel'), ('transformer', 'h', '37', 'attention', 'wv', 'kernel'), ('transformer', 'h', '37', 'attention', 'wo', 'kernel'), ('transformer', 'h', '37', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '37', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '37', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '37', 'attention_norm', 'kernel'), ('transformer', 'h', '37', 'ffn_norm', 'kernel'), ('transformer', 'h', '38', 'attention', 'wq', 'kernel'), ('transformer', 'h', '38', 'attention', 'wk', 'kernel'), ('transformer', 'h', '38', 'attention', 'wv', 'kernel'), ('transformer', 'h', '38', 'attention', 'wo', 'kernel'), ('transformer', 'h', '38', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '38', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '38', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '38', 
'attention_norm', 'kernel'), ('transformer', 'h', '38', 'ffn_norm', 'kernel'), ('transformer', 'h', '39', 'attention', 'wq', 'kernel'), ('transformer', 'h', '39', 'attention', 'wk', 'kernel'), ('transformer', 'h', '39', 'attention', 'wv', 'kernel'), ('transformer', 'h', '39', 'attention', 'wo', 'kernel'), ('transformer', 'h', '39', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '39', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '39', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '39', 'attention_norm', 'kernel'), ('transformer', 'h', '39', 'ffn_norm', 'kernel'), ('lm_head', 'kernel')])

erfanzar commented 1 year ago

Sorry, I explained part of it wrong. You should not use this:

from flax.traverse_util import unflatten_dict, flatten_dict
flax_params = llama_convert_hf_to_flax(state_dict, num_hidden_layers=40, num_attention_heads=40, hidden_size=5120,device = device)
flax_params = flatten_dict(flax_params)
pt_params = llama_convert_flax_to_pt(flax_params, n_layers=40, dim=5120, num_attention_heads=40)

Use this instead:

from flax.traverse_util import unflatten_dict, flatten_dict
flax_params = llama_convert_hf_to_flax(state_dict, num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, device=device)
flax_params = flatten_dict(flax_params, sep='.')  # dot-joined string keys, e.g. 'transformer.h.0.attention.wq.kernel'
pt_params = llama_convert_flax_to_pt(flax_params, n_layers=40, dim=5120, num_attention_heads=40)
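The KeyError comes from the key format: without sep, flatten_dict returns tuple keys such as ('transformer', 'h', '0', 'attention', 'wq', 'kernel'), whereas llama_convert_flax_to_pt looks up dot-joined strings such as 'transformer.h.0.attention.wq.kernel'. The pure-Python stand-in below is a hypothetical flatten helper that only mimics flax.traverse_util.flatten_dict for nested dicts (so it runs without Flax); it shows both key forms:

```python
def flatten(tree, sep=None, prefix=()):
    """Mimic flatten_dict: tuple keys by default, sep-joined string keys if sep is given."""
    flat = {}
    for key, value in tree.items():
        path = prefix + (key,)
        if isinstance(value, dict):
            flat.update(flatten(value, sep=sep, prefix=path))
        else:
            flat[path if sep is None else sep.join(path)] = value
    return flat
```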
erfanzar commented 1 year ago

Docs are available at https://erfanzar.github.io/EasyDeL/docs/

Nightbringers commented 1 year ago

Thanks, I will keep testing when I have time.

Nightbringers commented 1 year ago

Import error:

Traceback (most recent call last):
  File "EasyDeL/test_llama.py", line 1, in <module>
    from EasyDel import TrainArguments, CausalLMTrainer
  File "EasyDeL/EasyDel/__init__.py", line 2, in <module>
    from .modules import FlaxLlamaModel, FlaxLlamaForCausalLM, LlamaConfig, \
  File "/EasyDeL/EasyDel/modules/__init__.py", line 5, in <module>
    from .falcon import FalconConfig, FlaxFalconModel, FlaxFalconForCausalLM
  File "/EasyDeL/EasyDel/modules/falcon/__init__.py", line 1, in <module>
    from .modelling_falcon_flax import FlaxFalconForCausalLM, FlaxFalconModel, FalconConfig
  File "/EasyDeL/EasyDel/modules/falcon/modelling_falcon_flax.py", line 519, in <module>
    class FlaxFalconPretrainedModel(FlaxPreTrainedModel):
  File "/EasyDeL/EasyDel/modules/falcon/modelling_falcon_flax.py", line 528, in FlaxFalconPretrainedModel
    precision: Optional[None, jax.lax.Precision] = jax.lax.Precision('fastest')
  File "/anaconda3/envs/easy/lib/python3.11/typing.py", line 355, in inner
    return func(*args, **kwds)
           ^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/envs/easy/lib/python3.11/typing.py", line 478, in __getitem__
    return self._getitem(self, parameters)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/envs/easy/lib/python3.11/typing.py", line 700, in Optional
    arg = _type_check(parameters, f"{self} requires a single type.")
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/anaconda3/envs/easy/lib/python3.11/typing.py", line 197, in _type_check
    raise TypeError(f"{msg} Got {arg!r:.100}.")
TypeError: typing.Optional requires a single type. Got (None, <class 'jax._src.lax.lax.Precision'>).
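The failing annotation passes two arguments to typing.Optional, which accepts exactly one type and already includes None. A small sketch of the valid spellings; the Precision class here is a stand-in for jax.lax.Precision so the snippet runs without JAX:

```python
from typing import Optional, Union


class Precision:
    """Stand-in for jax.lax.Precision."""


# Optional takes a single type and adds None itself.
PrecisionOrNone = Optional[Precision]
# Equivalent explicit form; Union ignores argument order when comparing.
SamePrecisionOrNone = Union[None, Precision]
```

So an annotation like precision: Optional[jax.lax.Precision] is presumably what the original line intended.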
erfanzar commented 1 year ago

Fixed. I'm sorry for that error :)

Nightbringers commented 1 year ago

When training on multiple hosts, does the dataset need any additional processing?
Also, in total_batch_size=FLAGS.batch_size, is total_batch_size the sum of the per-host batch sizes? For example, with total_batch_size = 16 and 2 nodes that each have 8 GPUs, is the per-GPU batch size 1 and the per-node batch size 8? Is that right?

erfanzar commented 1 year ago

Yes. To use EasyDeL you should preprocess your dataset: pass a tokenized dataset that contains input_ids and an attention mask.
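A minimal sketch of what a tokenized, padded batch might look like, assuming right padding, pad id 0 and the field names input_ids and attention_mask (all assumptions; in practice a Hugging Face tokenizer produces these fields and you should use its pad token). The pad_batch helper is hypothetical:

```python
def pad_batch(token_lists, max_length, pad_id=0):
    """Right-pad or truncate token id lists into input_ids and attention_mask rows."""
    input_ids, attention_mask = [], []
    for tokens in token_lists:
        tokens = tokens[:max_length]  # truncate overly long examples
        padding = max_length - len(tokens)
        input_ids.append(tokens + [pad_id] * padding)
        attention_mask.append([1] * len(tokens) + [0] * padding)  # 1 = real token, 0 = padding
    return {"input_ids": input_ids, "attention_mask": attention_mask}
```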

As for batch size, the value you pass is the per-step batch size, and it is multiplied by the number of gradient accumulation steps. For example, if you pass a batch size of 8 to the trainer with 8 gradient accumulation steps, the total batch size for the data loader becomes 64. With 2 hosts, that is 32 per host, and with 8 GPUs per machine, it is 4 per GPU.
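The arithmetic above can be written as a small hypothetical helper (not an EasyDeL function); it only reproduces the breakdown described in the reply and does not model how EasyDeL actually schedules micro-batches:

```python
def batch_breakdown(batch_size, grad_accum_steps, num_hosts, gpus_per_host):
    """Return (total, per_host, per_gpu) following the reply's arithmetic."""
    total = batch_size * grad_accum_steps
    if total % num_hosts or (total // num_hosts) % gpus_per_host:
        raise ValueError("batch does not split evenly across hosts and GPUs")
    per_host = total // num_hosts
    per_gpu = per_host // gpus_per_host
    return total, per_host, per_gpu
```

The question's example (total batch 16, no accumulation, 2 nodes with 8 GPUs each) gives 8 per node and 1 per GPU under this breakdown.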

Nightbringers commented 1 year ago

warning: Linking two modules of different target triples: 'LLVMDialectModule' is 'nvptx64-nvidia-gpulibs' whereas '' is 'nvptx64-nvidia-cuda'

Does this warning have any impact?

What is use_pjit_attention_force? What is the difference between use_pjit_attention_force=False and use_pjit_attention_force=True?

Nightbringers commented 1 year ago

Also, use_flash_attention does not seem to work: the speed is the same whether it is set to True or False.

Nightbringers commented 1 year ago

I set max_sequence_length = 10240 and got this error: ValueError: Incompatible shapes for broadcasting: (1, 1, 1, 10240) and requested shape (1, 1, 8192, 8192). I suppose this is because my model has max_position_embeddings=8192?

Also, when I use a large sequence_length, the loss becomes NaN.

erfanzar commented 1 year ago

Yes, you are right. You should change your model's max length.
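The shapes in the error suggest a (1, 1, 1, 10240) attention mask being broadcast against a causal mask precomputed for max_position_embeddings=8192 and then sliced to the sequence length. A NumPy sketch that reproduces the same failure under that assumption about how the model builds its masks; mask_shapes_compatible is a hypothetical helper, and small sizes stand in for 10240 and 8192:

```python
import numpy as np


def mask_shapes_compatible(max_sequence_length, max_position_embeddings):
    """Mimic slicing a precomputed causal mask and broadcasting the padding mask onto it."""
    causal_mask = np.tri(max_position_embeddings, dtype=bool)[None, None]  # [1, 1, pos, pos]
    # Slicing past the end silently stops at max_position_embeddings.
    causal_mask = causal_mask[:, :, :max_sequence_length, :max_sequence_length]
    attention_mask = np.ones((1, 1, 1, max_sequence_length), dtype=bool)
    try:
        np.broadcast_to(attention_mask, causal_mask.shape)
    except ValueError:
        return False
    return True
```

Once max_position_embeddings is at least the sequence length, the sliced causal mask is as wide as the attention mask and the broadcast succeeds.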

Nightbringers commented 1 year ago

use_flash_attention does not seem to work: the speed is the same whether it is set to True or False. And what is use_pjit_attention_force?

Nightbringers commented 1 year ago

I'm really looking forward to the Mojo version. Is Mojo a replacement for JAX, or can they work together?

erfanzar commented 1 year ago

Actually, Mojo is more native, at least the version that I'm creating, and it works without any imported libraries. Since Mojo is a fast, native, compiled language, I coded everything from scratch; the only Python library in use is os, and only to read the size of checkpoints (Mojo doesn't have any built-in I/O library).

Nightbringers commented 1 year ago

That's awesome. You want to make a framework like TensorFlow or PyTorch based on Mojo. That requires a significant amount of work.