Hi, you cannot use the flash attention mechanism with sequence-sharding strategies; it will crash. Make sure you are using FSDP sharding instead of SP.
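For example (a minimal sketch, assuming EasyDeL's default axis order of (dp, fsdp, tp, sp); the exact axis names can differ between versions):

```python
# All devices on the FSDP axis -- the layout used in the scripts below,
# and the one that works with flash/sharded attention.
fsdp_sharding_axis_dims = (1, -1, 1, 1)   # (dp, fsdp, tp, sp)

# All devices on the sequence-parallel axis -- this is the layout that
# crashes with flash attention.
sp_sharding_axis_dims = (1, 1, 1, -1)     # (dp, fsdp, tp, sp)
```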
Would this mean I don't have to state the PartitionSpec set-up explicitly again, because it's already defined through EasyDeL's library support?
Actually yes, but you can also change this behavior by creating an ed.PartitionAxis and passing it to the model you're trying to load, or appending it to the module config.
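Something like this (a rough sketch based on the snippets below; the exact keyword support depends on your EasyDeL version):

```python
import easydel as ed
from easydel import AutoEasyDeLModelForCausalLM

partition_axis = ed.PartitionAxis()  # default axis mapping

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    "Qwen/Qwen2-7B",
    partition_axis=partition_axis,                     # pass it to the model you're loading
    config_kwargs=dict(partition_axis=partition_axis)  # and/or append it to the module config
)
```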
At the moment I'm trying to figure out some bugs in the NNX version of the project; I'll try to run your code today or tomorrow.
Thank you for your quick reply. I've simplified my code a bit:
import os
import sys
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logging.info("Script started")

# Set necessary environment variables
os.environ['JAX_PLATFORMS'] = ''
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'
logging.info("Environment variables set")

try:
    import jax
    logging.info(f"JAX imported, version: {jax.__version__}")

    # Initialize JAX's distributed system
    jax.distributed.initialize()
    logging.info("JAX distributed initialized")
    logging.info(f"Number of devices: {jax.device_count()}")
    logging.info(f"Devices: {jax.devices()}")

    # Rest of the imports
    import jax.numpy as jnp
    import easydel as ed
    from easydel import (
        AutoEasyDeLModelForCausalLM,
        TrainArguments,
        CausalLanguageModelTrainer,
        EasyDeLOptimizers,
        EasyDeLSchedulers,
        EasyDeLGradientCheckPointers,
        get_modules_by_type,
        AttentionMechanisms
    )
    from datasets import load_dataset
    from flax.core import FrozenDict
    logging.info("All modules imported successfully")

    jax.print_environment_info()

    pretrained_model_name_or_path = "ssmits/Falcon2-5.5B-Dutch"
    # pretrained_model_name_or_path = "Qwen/Qwen2-7B"
    max_length = 2048
    sharding_axis_dims = (1, -1, 1, 1)
    partition_axis = ed.PartitionAxis()
    input_shape = (1, max_length)
    attn_mechanism = AttentionMechanisms.sharded_vanilla
    dtype = jnp.bfloat16

    # Load and split the dataset
    # import multiprocessing
    # num_cpus = multiprocessing.cpu_count()
    logging.info("Loading dataset...")
    train_dataset = load_dataset("BramVanroy/occiglot-fineweb-v0.5-nl", split="train", streaming=True)
    logging.info("Dataset loaded successfully")

    logging.info("Loading model...")
    model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        device=jax.devices()[0],  # USE TPU0
        # device=jax.devices('cpu')[0],  # USE CPU0
        input_shape=input_shape,
        # device_map="auto",
        # auto_shard_params=True,
        sharding_axis_dims=sharding_axis_dims,
        verbose_params=True,
        config_kwargs=dict(
            use_scan_mlp=False,
            attn_mechanism=attn_mechanism,
            partition_axis=partition_axis
        ),
        partition_axis=partition_axis,
        param_dtype=dtype,
    )
    logging.info("Model loaded successfully")

    config = model.config
    config.freq_max_position_embeddings = config.max_position_embeddings
    config.max_position_embeddings = max_length
    config.c_max_position_embeddings = config.max_position_embeddings

    model.config.add_basic_configurations(
        attn_mechanism=attn_mechanism, shard_attention_computation=True,
    )
    # model.config.add_basic_configurations(
    #     attn_mechanism="flash",  # Using Flash Attention; you can also set this to normal or ring
    #     block_b=1,
    #     block_q=128,
    #     block_k=128,
    #     block_k_major=128,
    # )

    # First, define the basic parameters we know
    total_samples = 16146000  # Replace with the actual number of samples in your dataset
    total_batch_size = 32
    num_train_epochs = 1

    # Calculate the number of training steps
    steps_per_epoch = total_samples // total_batch_size
    max_training_steps = steps_per_epoch * num_train_epochs

    logging.info("Setting up training arguments...")
    # Now define TrainArguments with the calculated max_steps
    train_args = TrainArguments(
        model_class=get_modules_by_type(model.config.model_type)[1],
        configs_to_initialize_model_class={
            "config": model.config,
            "dtype": dtype,
            "param_dtype": dtype,
            "input_shape": input_shape
        },
        init_input_shape=input_shape,
        dtype=dtype,
        param_dtype=dtype,
        custom_rule=model.config.get_partition_rules(True),
        sharding_array=sharding_axis_dims,
        do_shard_fns=True,
        backend="tpu",
        model_name="Qwen-Tune",
        num_train_epochs=num_train_epochs,
        learning_rate=5e-5,
        learning_rate_end=7e-6,
        warmup_steps=1000,
        max_training_steps=max_training_steps,
        optimizer=EasyDeLOptimizers.ADAMW,
        scheduler=EasyDeLSchedulers.WARM_UP_LINEAR,
        weight_decay=0.1,
        z_loss=0.0001,
        label_smoothing_factor=float(0),
        total_batch_size=total_batch_size,
        save_steps=2000,
        save_total_limit=1,
        do_last_save=True,
        max_sequence_length=max_length,
        gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
        gradient_accumulation_steps=4,
        loss_re_mat="",
        force_batch_and_gradient_accumulation_steps_calculation=False,
        step_start_point=0,
        wandb_entity=None
    )
    logging.info("Training arguments set up successfully")

    logging.info("Creating trainer...")
    # Create the trainer
    trainer = CausalLanguageModelTrainer(
        train_args,
        train_dataset.shuffle(),
        checkpoint_path=None
    )
    logging.info("Trainer created successfully")

    model_parameters = FrozenDict({"params": params})

    logging.info("Starting training...")
    output = trainer.train(
        model_parameters=model_parameters,  # pass this as None in case of resuming from the last checkpoint
        state=None
    )
    logging.info("Training completed successfully")

    saved_model_location = f"{str(train_args.get_path())}/{output.last_save_file_name}"
    logging.info(f"Model saved at: {saved_model_location}")

except Exception as e:
    logging.exception(f"An error occurred: {str(e)}")
    sys.exit(1)

logging.info("Script completed successfully")
Converting Model:   0%|          | 0/172 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/nfs_share/tpu-training-dutch/train_shard.py", line 58, in <module>
    model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
  File "/home/air/.local/lib/python3.10/site-packages/easydel/modules/auto_easydel_model.py", line 588, in from_pretrained
    return cls._from_torch(
  File "/home/air/.local/lib/python3.10/site-packages/easydel/modules/auto_easydel_model.py", line 770, in _from_torch
    params = trf(
  File "/home/air/.local/lib/python3.10/site-packages/easydel/transform/easydel_transform.py", line 184, in huggingface_to_easydel
    pt2jax(tensor), dtype
  File "/home/air/.local/lib/python3.10/site-packages/easydel/transform/utils.py", line 107, in pt2jax
    return jax.numpy.asarray(x.detach().cpu().numpy())
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3289, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3133, in array
    return jax.device_put(object)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/api.py", line 2471, in device_put
    out_flat = dispatch.device_put_p.bind(
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/core.py", line 416, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/core.py", line 420, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/core.py", line 921, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 496, in _batched_device_put_impl
    shard_arg_results = pxla.shard_args(shard_arg_shardings, shard_arg_xs)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 119, in shard_args
    return shard_arg_handlers[type(arg)]([arg], shardings)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 170, in _shard_array
    results.append(batched_device_put(aval, sharding, shards, devices))
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 195, in batched_device_put
    return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Cannot copy array to non-addressable device TFRT_CPU_0
I think this is a JAX bug, but I would like to know if I can prevent it somehow. It does not proceed after loading wandb; there is some CPU activity, but not much.
Using
echo 'python3 /nfs_share/tpu-training-dutch/train_shard.py' | podrun -iw
Do this at the start of your imports:
import os
os.environ["EASYDEL_AUTO"]="false"
import jax
jax.print_environment_info()
and check if that fixes it.
After the weekend, thank you!
Now it works for Qwen with:
import jax
import easydel as ed
import jax.numpy as jnp
from easydel import (
    AutoEasyDeLModelForCausalLM,
    TrainArguments,
    CausalLanguageModelTrainer,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    get_modules_by_type,
    AttentionMechanisms
)
from flax.core import FrozenDict
import wandb
import numpy as np
from dataset_utils import load_and_process_dataset

jax.print_environment_info()

# os.environ['WANDB_DISABLED'] = 'true'
# wandb.init(project="EasyDeL-Qwen-Tune", entity="safemantic")

# pretrained_model_name_or_path = "ssmits/Falcon2-5.5B-Dutch"
pretrained_model_name_or_path = "Qwen/Qwen2-7B-Instruct"
max_length = 4096
sharding_axis_dims = (1, -1, 1, 1)
partition_axis = ed.PartitionAxis()
input_shape = (1, max_length)
attn_mechanism = AttentionMechanisms.sharded_vanilla
dtype = jnp.bfloat16

# ed.AttentionModule.test_attentions(axis_dims=sharding_axis_dims)  # you can test the attention modules to find the best one that works for you

# Use the new function to load and process the dataset
tokenized_dataset = load_and_process_dataset("ssmits/processed-falcon-dutch-dataset", max_length=max_length)

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],
    input_shape=input_shape,
    # device_map="auto",
    # auto_shard_params=True,
    sharding_axis_dims=sharding_axis_dims,
    verbose_params=True,
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism=attn_mechanism,
        partition_axis=partition_axis
    ),
    partition_axis=partition_axis,
    param_dtype=dtype,
)

config = model.config
config.freq_max_position_embeddings = config.max_position_embeddings
config.max_position_embeddings = max_length
config.c_max_position_embeddings = config.max_position_embeddings

model.config.add_basic_configurations(
    attn_mechanism=attn_mechanism, shard_attention_computation=True,
)
# model.config.add_basic_configurations(
#     attn_mechanism="flash",  # Using Flash Attention; you can also set this to normal or ring
#     block_b=1,
#     block_q=128,
#     block_k=128,
#     block_k_major=128,
# )


# Add the monkey patch
def patched_log_metrics(self, metrics, step):
    wandb_metrics = {}
    for key, value in metrics.items():
        if isinstance(value, (list, tuple, np.ndarray, jnp.ndarray)):
            wandb_metrics[key] = wandb.Histogram(np.array(value))
        else:
            wandb_metrics[key] = value
    wandb.log(wandb_metrics, step=step)


# Apply the monkey patch
TrainArguments.log_metrics = patched_log_metrics

train_args = TrainArguments(
    model_class=get_modules_by_type(model.config.model_type)[1],
    configs_to_initialize_model_class={
        "config": model.config,
        "dtype": dtype,
        "param_dtype": dtype,
        "input_shape": input_shape
    },
    init_input_shape=input_shape,
    dtype=dtype,
    param_dtype=dtype,
    custom_rule=model.config.get_partition_rules(True),
    sharding_array=sharding_axis_dims,
    do_shard_fns=True,
    backend="tpu",
    model_name="Falcon-Tune",
    num_train_epochs=1,
    learning_rate=5e-5,
    learning_rate_end=7e-6,
    warmup_steps=2000,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.WARM_UP_LINEAR,
    weight_decay=0.1,
    z_loss=0.0001,
    label_smoothing_factor=float(0),
    total_batch_size=8,
    save_steps=2000,
    max_training_steps=100000,
    save_total_limit=1,
    do_last_save=True,
    max_sequence_length=max_length,
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    gradient_accumulation_steps=4,
    loss_re_mat="",
    force_batch_and_gradient_accumulation_steps_calculation=False,
    step_start_point=0,
    # training_time="180Min",  # Set a training time limit; you can set this to None
    wandb_entity="safemantic",
    # Read the docs for more options and a better understanding of them
)

# Print the 5 most important training arguments
print(f"1. Learning Rate: {train_args.learning_rate}")
print(f"2. Number of Training Epochs: {train_args.num_train_epochs}")
print(f"3. Total Batch Size: {train_args.total_batch_size}")
print(f"4. Max Training Steps: {train_args.max_training_steps}")
print(f"5. Gradient Accumulation Steps: {train_args.gradient_accumulation_steps}")

trainer = CausalLanguageModelTrainer(
    train_args,
    tokenized_dataset.shuffle().shuffle(),
    checkpoint_path=None  # In case of resuming from a checkpoint, pass the checkpoint path here and simply
    # skip the model and params steps above.
)

model_parameters = FrozenDict({"params": params})

output = trainer.train(
    model_parameters=model_parameters,  # pass this as None in case of resuming from the last checkpoint
    state=None
)

saved_model_location = f"{str(train_args.get_path())}/{output.last_save_file_name}"
print("Hey, I'm here in case you want to load me:", saved_model_location)
However, for Falcon it still does not work, even with batch size 4 or 1. My suspicion is that Falcon2 is not correctly integrated into EasyDeL, or that something goes wrong with sharded_vanilla. I'll try Falcon-7B and even Falcon-40B to see whether Falcon-11B's newer architecture could be the problem.
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/home/air/workspace/train.py", line 141, in <module>
output = trainer.train(
File "/home/air/workspace/EasyDel/src/easydel/trainers/causal_language_model_trainer/causal_language_model_trainer.py", line 826, in train
) = self.sharded_train_step_function(sharded_state, batch)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 63.80G of 30.75G hbm. Exceeded hbm capacity by 33.05G.
Total hbm usage >= 65.05G:
reserved 1.25G
program 63.80G
arguments 0B
Output size 0B; shares 0B with arguments.
Program hbm requirement 63.80G:
global 1.14M
scoped 4.56M
HLO temp 63.79G (100.0% utilization: Unpadded (63.58G) Padded (63.58G), 0.3% fragmentation (210.89M))
Largest program allocations in hbm:
1. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/0/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1265.remat3 = fusion(custom-call.224, custom-call.241, custom-call.22), kind=kOutput, calls=fused_computation.957.clone.clone.clone
Allocation type: HLO temp
==========================
2. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/27/self_attention/sub" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1069
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.6346.remat2 = fusion(fusion.1211.remat), kind=kOutput, calls=fused_computation.5587.clone.clone
Allocation type: HLO temp
==========================
3. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/27/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1211.remat = fusion(custom-call.240, all-reduce.366, custom-call.38), kind=kOutput, calls=fused_computation.903.clone
Allocation type: HLO temp
==========================
4. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/25/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1215.remat3 = fusion(custom-call.238, custom-call.255, custom-call.36), kind=kOutput, calls=fused_computation.907.clone.clone.clone
Allocation type: HLO temp
==========================
5. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/24/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1217.remat3 = fusion(custom-call.237, custom-call.254, custom-call.35), kind=kOutput, calls=fused_computation.909.clone.clone.clone
Allocation type: HLO temp
==========================
6. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/23/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1219.remat3 = fusion(custom-call.236, custom-call.253, custom-call.34), kind=kOutput, calls=fused_computation.911.clone.clone.clone
Allocation type: HLO temp
==========================
7. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/22/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1221.remat3 = fusion(custom-call.235, custom-call.252, custom-call.33), kind=kOutput, calls=fused_computation.913.clone.clone.clone
Allocation type: HLO temp
==========================
8. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/21/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1223.remat3 = fusion(custom-call.234, custom-call.251, custom-call.32), kind=kOutput, calls=fused_computation.915.clone.clone.clone
Allocation type: HLO temp
==========================
9. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/20/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1225.remat3 = fusion(custom-call.233, custom-call.250, custom-call.31), kind=kOutput, calls=fused_computation.917.clone.clone.clone
Allocation type: HLO temp
==========================
10. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/19/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1227.remat3 = fusion(custom-call.232, custom-call.249, custom-call.30), kind=kOutput, calls=fused_computation.919.clone.clone.clone
Allocation type: HLO temp
==========================
11. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/18/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1229.remat3 = fusion(custom-call.231, custom-call.248, custom-call.29), kind=kOutput, calls=fused_computation.921.clone.clone.clone
Allocation type: HLO temp
==========================
12. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/17/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1231.remat3 = fusion(custom-call.230, custom-call.247, custom-call.28), kind=kOutput, calls=fused_computation.923.clone.clone.clone
Allocation type: HLO temp
==========================
13. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/16/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1233.remat3 = fusion(custom-call.229, custom-call.246, custom-call.27), kind=kOutput, calls=fused_computation.925.clone.clone.clone
Allocation type: HLO temp
==========================
14. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/14/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1237.remat3 = fusion(custom-call.227, custom-call.245, gte.remat.76), kind=kOutput, calls=fused_computation.929.clone.clone.clone
Allocation type: HLO temp
==========================
15. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/13/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1239.remat3 = fusion(copy-done.108, copy-done.125, custom-call.25), kind=kOutput, calls=fused_computation.931.clone.clone.clone
Allocation type: HLO temp
==========================
16. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/12/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1241.remat3 = fusion(copy-done.106, all-reduce.171, custom-call.25), kind=kOutput, calls=fused_computation.933.clone.clone.clone
Allocation type: HLO temp
==========================
17. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/11/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1243.remat3 = fusion(get-tuple-element.3954, custom-call.244, custom-call.25), kind=kOutput, calls=fused_computation.935.clone.clone.clone
Allocation type: HLO temp
==========================
18. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/10/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1245.remat3 = fusion(get-tuple-element.3951, copy-done.121, custom-call.25), kind=kOutput, calls=fused_computation.937.clone.clone.clone
Allocation type: HLO temp
==========================
19. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/9/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1247.remat3 = fusion(copy-done.105, copy-done.120, custom-call.25), kind=kOutput, calls=fused_computation.939.clone.clone.clone
Allocation type: HLO temp
==========================
20. Size: 2.00G
Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/8/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
Unpadded size: 2.00G
XLA label: fusion.1249.remat3 = fusion(copy-done.104, copy-done.119, custom-call.25), kind=kOutput, calls=fused_computation.941.clone.clone.clone
Allocation type: HLO temp
==========================
Falcon-7B also does not work, but my focus is on 11B. Could you take a look at it?
Sure, I'm working on that.
Thank you. Did you manage to find the bug?
Yes, actually that's fixed, but there are still some other issues from new experimental features... they will all be fixed soon. In case you are not in the Discord server, sorry that I forgot to tell you it's fixed. (The PyPI version should work fine.)
Your partition specs are wrong. Replace weight with kernel and the dot separator with /.
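Roughly like this (a sketch only; the module paths and mesh axis names below are placeholders, take the real ones from model.config.get_partition_rules(...) for your model):

```python
from jax.sharding import PartitionSpec

# Wrong: PyTorch-style naming with dot separators
# ("self_attention.query_key_value.weight", PartitionSpec("fsdp", "tp"))

# Right: Flax-style naming ("kernel") with "/" separators
custom_rules = (
    ("self_attention/query_key_value/kernel", PartitionSpec("fsdp", "tp")),
    ("self_attention/dense/kernel", PartitionSpec("tp", "fsdp")),
    (".*", PartitionSpec(None)),  # fallback for anything not matched above
)
```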
Will these fixes be implemented in the next version of EasyDeL?
Actually, they are fixed right now.
@s-smits is it working?
Describe the bug
Can't train with multiple VMs (TPU v4-32). It stops after loading the model and won't even load the data. I've been trying for two days; maybe my set-up is wrong. I really want to know when to use (1, context_window) and when to use (num_devices, context_window) as input_shape. I'm using tpux with the correct IP addresses etc. and podrun train.py for distributed training.
UPDATE: The main problem is probably the PartitionSpec / flash attention variable naming scheme, which has to be exactly correct. If flash attention is working for Falcon-11B, would this mean I don't have to state the PartitionSpec set-up explicitly again, because it's already defined through EasyDeL's library support?
To Reproduce