Hello! Yes, for sure, but which backend do you want to use for multi-host training: multiple TPU pods or GPU servers?
GPU servers on a Slurm cluster.
Can you please test this code?
import jax
from EasyDel import TrainArguments, CausalLMTrainer

num_processes = 6
process_id = 0  # number between 0 and num_processes - 1 that says which node is the current node
coordinator_address = 'ip:port'  # for example 192.168.1.12:8600 (make sure this port is not blocked by a firewall)

jax.distributed.initialize(
    coordinator_address=coordinator_address,
    num_processes=num_processes,
    process_id=process_id,
)

train_args = TrainArguments(
    backend='gpu',
    sharding_array=(num_processes, -1, 1),
    use_wandb=True,
    use_pjit_attention_force=False,
)

trainer = CausalLMTrainer(
    arguments=train_args,
    dataset_train=...,  # to be passed
    ckpt_path=...,  # to be passed: path to ckpt, or None
)

parameters = None  # to fine-tune a model, pass parameters shaped like frozen({"params": ...})
trainer.train(model_parameters=parameters or None)
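Since you mentioned Slurm: a minimal sketch (my assumption, not part of EasyDeL) of deriving those three values from standard Slurm environment variables instead of hard-coding them; the port 8600 is just an arbitrary free port on the coordinator node.

import os
import subprocess

import jax

num_processes = int(os.environ["SLURM_NTASKS"])
process_id = int(os.environ["SLURM_PROCID"])

# Use the first node of the allocation as the coordinator.
first_node = subprocess.check_output(
    ["scontrol", "show", "hostnames", os.environ["SLURM_JOB_NODELIST"]],
    text=True,
).splitlines()[0]
coordinator_address = f"{first_node}:8600"

jax.distributed.initialize(
    coordinator_address=coordinator_address,
    num_processes=num_processes,
    process_id=process_id,
)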
I encountered some issues when I tried to run EasyDeL/examples/training/causal-lm/llama.py. I use llama-13b. I first convert HF LLaMA to Flax using this code:
model = AutoModelForCausalLM.from_pretrained(path)
state_dict = model.state_dict()
flax_params = llama_convert_hf_to_flax(
    state_dict,
    num_hidden_layers=40,
    num_attention_heads=40,
    hidden_size=5120,
    device=device,
)
save_ckpt(flax_params, 'flax_param_easydel')
Is my conversion code correct?
Then I run llama.py. Here is the error:
ValueError: One of pjit outputs was given the sharding of NamedSharding(mesh={'dp': 1, 'fsdp': 8, 'mp': 1}, spec=PartitionSpec('fsdp',)), which implies that the global size of its dimension 0 should be divisible by 8, but it is equal to 40076 (full shape: (40076, 5120))
When you try to run it, pass fully_fsdp=False in config.get_partition_rules; that will fix this problem.
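As a hypothetical sketch (assuming config here is the model config object built in llama.py, as the reply implies), that suggestion would look like:

# Not from the thread: build the partition rules with fully-FSDP sharding disabled.
partition_rules = config.get_partition_rules(fully_fsdp=False)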
This seems to disable fully_fsdp, but I want to use fully_fsdp.
Which model are you trying to use? Can you give me the config of the model so I can write a custom partition rule for it?
llama-13b v2, but the model's vocabulary size has changed. Does this have any impact?
Yes, that's the reason you get this error. Can you tell me your vocab size? For example, is it a version with added EOS/BOS tokens that has 32002 tokens?
Use this:
from jax.sharding import PartitionSpec as PS
from EasyDel import TrainArguments

partition_rules = (
    ("transformer/wte/embedding", PS("dp", "fsdp")),
    ("attention/(wq|wk|wv)/kernel", PS("fsdp")),
    ("attention/wo/kernel", PS("fsdp")),
    ("feed_forward/w1/kernel", PS("fsdp")),
    ("feed_forward/w2/kernel", PS("fsdp")),
    ("feed_forward/w3/kernel", PS("fsdp")),
    ("attention_norm/kernel", PS("fsdp")),
    ("ffn_norm/kernel", PS("fsdp")),
    ("transformer/ln_f/kernel", PS("fsdp")),
    ("lm_head/kernel", PS("fsdp", "dp")),
    (".*", PS("fsdp")),
)

train_args = TrainArguments(
    custom_rule=partition_rules,
    ...
)
This should work fine if you have only changed the vocab size.
yes, it worked!
If you have any other issues, please let me know <3
The calculation of losses has encountered an issue. The error is:
File "/EasyDeL/EasyDel/trainer/fsdp_train.py", line 406, in train
    sharded_trainstate, loss, accuracy = self.sharded_train_step_fn(sharded_trainstate,
File "/EasyDeL/EasyDel/trainer/fsdp_train.py", line 309, in fsdp_trainstep
    (loss, accuracy), grad = grad_fn(state.params)
File "/EasyDeL/EasyDel/trainer/fsdp_train.py", line 303, in calculate_loss
    loss, accuracy = loss_fn(
File "/anaconda3-2/envs/py3.10/lib/python3.10/site-packages/fjutils/easylm.py", line 495, in blockwise_cross_entropy
    logits = rearrange(logits, '(n c) d -> n c d', c=chunk_size)
File "/anaconda3-2/envs/py3.10/lib/python3.10/site-packages/einops/einops.py", line 483, in rearrange
    return reduce(cast(Tensor, tensor), pattern, reduction='rearrange', **axes_lengths)
File "/anaconda3-2/envs/py3.10/lib/python3.10/site-packages/einops/einops.py", line 420, in reduce
    raise EinopsError(message + '\n {}'.format(e))
einops.EinopsError: Error while processing rearrange-reduction pattern "(n c) d -> n c d".
Input tensor shape: (16368, 55296). Additional info: {'c': 1024}.
Shape mismatch, can't divide axis of length 16368 in chunks of 1024
My vocab size is 55296 and my sequence_length is 1024.
Set loss_remat to '':
train_args = TrainArguments(
    custom_rule=partition_rules,
    loss_remat='',
)
This will work. The current error you are getting is because you are trying to use blockwise cross-entropy and the flattened logits axis (16368 in your case) is not divisible by the chunk size of 1024, so you can either change loss_remat to '' or change your loss_chunk. The relevant logic in the trainer is:
if self.arguments.loss_remat != '':
    blockwise_cross = functools.partial(
        blockwise_cross_entropy,
        chunk_size=self.arguments.loss_chunk,
        policy=self.arguments.loss_remat
    )
    loss_fn = blockwise_cross
else:
    loss_fn = cross_entropy_loss_and_accuracy
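For completeness, the other option mentioned above (changing loss_chunk instead of loss_remat) might look like the sketch below; the value 1023 is only an illustration of a chunk size that divides the flattened logits axis from the error (16368 = 16 × 1023), not a recommendation from the thread.

train_args = TrainArguments(
    custom_rule=partition_rules,
    loss_chunk=1023,  # assumption: any divisor of the flattened (batch * sequence) axis works
)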
What is the difference between blockwise_cross and cross_entropy_loss_and_accuracy? Are there any advantages to using blockwise_cross?
blockwise_cross uses jax.lax.scan to calculate the loss, which means it is computed chunk by chunk. This might be slower, but it is certainly more memory-efficient; it might not be usable or suitable for every situation.
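To illustrate the idea (this is a minimal sketch, not EasyDeL's actual loss code), a chunked cross-entropy built on jax.lax.scan could look like this; it assumes logits flattened to (N, vocab) with N divisible by chunk_size:

import jax
import jax.numpy as jnp

def chunked_cross_entropy(logits, labels, chunk_size):
    # logits: (N, vocab), labels: (N,) integer ids; N must be divisible by chunk_size.
    # Only one (chunk_size, vocab) block of log-probabilities is materialized at a
    # time, which is what saves memory compared to the one-shot version.
    n, vocab = logits.shape
    logits = logits.reshape(n // chunk_size, chunk_size, vocab)
    labels = labels.reshape(n // chunk_size, chunk_size)

    def step(total_nll, chunk):
        chunk_logits, chunk_labels = chunk
        logprobs = jax.nn.log_softmax(chunk_logits.astype(jnp.float32), axis=-1)
        nll = -jnp.take_along_axis(logprobs, chunk_labels[:, None], axis=-1).sum()
        return total_nll + nll, None

    init = jnp.zeros((), dtype=jnp.float32)
    total_nll, _ = jax.lax.scan(step, init, (logits, labels))
    return total_nll / n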
How do I convert flax_params to an hf_model? I am trying to convert llama2-13b flax_params back to an hf_model. I first convert the HF model to flax_params, then I try to convert back but get an error. This is the code:
flax_params = llama_convert_hf_to_flax(state_dict, num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, device=device)
pt_params = llama_convert_flax_to_pt(flax_params, n_layers=40, dim=5120, num_attention_heads=40)
The error is:
EasyDeL/EasyDel/transform/llama.py", line 218, in llama_convert_flax_to_pt
    torch_params[key] = torch.from_numpy(tensor.astype(dtype=dtype))
AttributeError: 'dict' object has no attribute 'astype'
You have to use flatten_dict in order to convert Flax to PyTorch, like this:
from flax.traverse_util import unflatten_dict, flatten_dict

flax_params = llama_convert_hf_to_flax(state_dict, num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, device=device)
flax_params = flatten_dict(flax_params)
pt_params = llama_convert_flax_to_pt(flax_params, n_layers=40, dim=5120, num_attention_heads=40)
After flattening, that error is gone, but pt_params = llama_convert_flax_to_pt(flax_params, n_layers=40, dim=5120, num_attention_heads=40) now raises another error:
EasyDeL/EasyDel/transform/llama.py", line 229, in llama_convert_flax_to_pt torch_params[f"transformer.h.{layer_i}.attention.wq.kernel"]
KeyError: 'transformer.h.0.attention.wq.kernel'
Can I have access to your model, or can you tell me which model you are using right now?
llama2-13b.
I ran this:
flax_params = llama_convert_hf_to_flax(state_dict, num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, device=device)
flax_params = flatten_dict(flax_params)
print(flax_params.keys())
got:
dict_keys([('transformer', 'wte', 'embedding'), ('transformer', 'ln_f', 'kernel'), ('transformer', 'h', '0', 'attention', 'wq', 'kernel'), ('transformer', 'h', '0', 'attention', 'wk', 'kernel'), ('transformer', 'h', '0', 'attention', 'wv', 'kernel'), ('transformer', 'h', '0', 'attention', 'wo', 'kernel'), ('transformer', 'h', '0', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '0', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '0', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '0', 'attention_norm', 'kernel'), ('transformer', 'h', '0', 'ffn_norm', 'kernel'), ..., ('transformer', 'h', '39', 'attention', 'wq', 'kernel'), ('transformer', 'h', '39', 'attention', 'wk', 'kernel'), ('transformer', 'h', '39', 'attention', 'wv', 'kernel'), ('transformer', 'h', '39', 'attention', 'wo', 'kernel'), ('transformer', 'h', '39', 'feed_forward', 'w1', 'kernel'), ('transformer', 'h', '39', 'feed_forward', 'w2', 'kernel'), ('transformer', 'h', '39', 'feed_forward', 'w3', 'kernel'), ('transformer', 'h', '39', 'attention_norm', 'kernel'), ('transformer', 'h', '39', 'ffn_norm', 'kernel'), ('lm_head', 'kernel')])
(the same eleven per-layer keys repeat for layers 1 through 38; elided here)
Sorry, I explained part of it wrong. You should not use this:
from flax.traverse_util import unflatten_dict, flatten_dict

flax_params = llama_convert_hf_to_flax(state_dict, num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, device=device)
flax_params = flatten_dict(flax_params)
pt_params = llama_convert_flax_to_pt(flax_params, n_layers=40, dim=5120, num_attention_heads=40)
Use this instead:
from flax.traverse_util import unflatten_dict, flatten_dict

flax_params = llama_convert_hf_to_flax(state_dict, num_hidden_layers=40, num_attention_heads=40, hidden_size=5120, device=device)
flax_params = flatten_dict(flax_params, sep='.')
pt_params = llama_convert_flax_to_pt(flax_params, n_layers=40, dim=5120, num_attention_heads=40)
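To see why sep='.' matters: without a separator, flatten_dict produces tuple keys (as in the printout above), while sep='.' joins them into the dotted strings that llama_convert_flax_to_pt looks up. A tiny standalone illustration (the nested dict is made up):

from flax.traverse_util import flatten_dict

params = {"transformer": {"h": {"0": {"attention": {"wq": {"kernel": 0}}}}}}

print(flatten_dict(params).keys())
# dict_keys([('transformer', 'h', '0', 'attention', 'wq', 'kernel')])

print(flatten_dict(params, sep='.').keys())
# dict_keys(['transformer.h.0.attention.wq.kernel'])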
Docs are available at https://erfanzar.github.io/EasyDeL/docs/
thanks, I will keep testing when I have time.
Import error:
Traceback (most recent call last):
File "EasyDeL/test_llama.py", line 1, in
File "/anaconda3/envs/easy/lib/python3.11/typing.py", line 355, in inner
return func(*args, **kwds)
^^^^^^^^^^^^^^^^^^^
File "/anaconda3/envs/easy/lib/python3.11/typing.py", line 478, in __getitem__
return self._getitem(self, parameters)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/anaconda3/envs/easy/lib/python3.11/typing.py", line 700, in Optional
arg = _type_check(parameters, f"{self} requires a single type.")
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/anaconda3/envs/easy/lib/python3.11/typing.py", line 197, in _type_check
raise TypeError(f"{msg} Got {arg!r:.100}.")
TypeError: typing.Optional requires a single type. Got (None, <class 'jax._src.lax.lax.Precision'>).
Fixed. I'm sorry for such an error :)
If training uses multiple hosts, does the dataset need any additional processing?
And regarding total_batch_size=FLAGS.batch_size: is this total_batch_size the sum of the per-host batch sizes? For example, if total_batch_size = 16 and there are 2 nodes, each with 8 GPUs, is the per-GPU batch size 1 and the per-node batch size 8? Is this right?
Yes, to use EasyDeL you should preprocess your dataset: you should pass a tokenized dataset that contains input_ids and attention_mask.
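A minimal sketch of that preprocessing with Hugging Face datasets/transformers (the tokenizer path, data file, column name, and max_length are placeholders, not values from this thread):

from datasets import load_dataset
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("path/to/tokenizer")

def tokenize(batch):
    # Produces the input_ids and attention_mask columns the trainer expects.
    return tokenizer(
        batch["text"],
        truncation=True,
        padding="max_length",
        max_length=1024,
    )

dataset = load_dataset("json", data_files="train.jsonl", split="train")
dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names)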
And for batch size: the batch size you pass is multiplied by the number of gradient accumulation steps for each step. For example, imagine you passed a batch size of 8 to the trainer with gradient accumulation of 8; the total batch size for the data loader becomes 64. If you have 2 hosts, this becomes a batch size of 32 for each host, and if you have 8 GPUs per machine, this becomes 4 per GPU.
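Spelling out that arithmetic with the numbers from the example above:

batch_size = 8
gradient_accumulation_steps = 8
num_hosts = 2
gpus_per_host = 8

total_batch = batch_size * gradient_accumulation_steps  # 64 samples per optimizer step
per_host = total_batch // num_hosts                      # 32 per host
per_gpu = per_host // gpus_per_host                      # 4 per GPU
print(total_batch, per_host, per_gpu)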
I get this warning: Linking two modules of different target triples: 'LLVMDialectModule' is 'nvptx64-nvidia-gpulibs' whereas '' is 'nvptx64-nvidia-cuda'. Does this warning have any impact?
What is use_pjit_attention_force? What is the difference between use_pjit_attention_force=False and use_pjit_attention_force=True?
Also, use_flash_attention seems not to work; I found that the speed is the same whether it is set to True or False.
I set max_sequence_length = 10240 and got this error: ValueError: Incompatible shapes for broadcasting: (1, 1, 1, 10240) and requested shape (1, 1, 8192, 8192). I suppose this is because my model's max_position_embeddings is 8192?
And when using a large sequence_length, loss=nan occurs.
Yes, you are right. You should change your model's max length.
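A sketch of that change (assuming max_sequence_length is the trainer argument you set; keep it at or below the model's max_position_embeddings, 8192 here):

train_args = TrainArguments(
    max_sequence_length=8192,  # was 10240, which exceeds max_position_embeddings
    ...
)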
use_flash_attention seems not to work; I found that the speed is the same whether it is set to True or False. And what is use_pjit_attention_force?
I'm really looking forward to the Mojo version. Is Mojo a replacement for JAX, or can they work together?
Actually, Mojo is more native, at least the version that I'm creating, and it works without any imported libraries in Mojo. Since Mojo is a fast, native, compiled language, I coded everything from scratch; the only Python library in use is the os library, and only to read the size of checkpoints (Mojo doesn't have any built-in I/O library).
That is awesome. You want to make a framework like TensorFlow or PyTorch based on Mojo. That requires a significant amount of work.
I don't know how to train using multiple hosts. Can you give an example? Thank you.