erfanzar / EasyDeL

Accelerate, Optimize performance with streamlined training and serving options with JAX.
https://easydel.readthedocs.io/en/latest/
Apache License 2.0

TPU v4-32 set-up not working #166

Closed s-smits closed 1 week ago

s-smits commented 1 month ago

Describe the bug

I can't train across multiple VMs on a TPU v4-32. The run stops after loading the model and never gets as far as loading the data. I have been trying for two days, so maybe my setup is wrong. I would really like to know when to use (1, context_window) and when to use (num_devices, context_window) as input_shape. I am using tpux with the correct IP addresses etc., and podrun train.py for distributed training.

UPDATE: The main problem is probably the PartitionSpec / Flash Attention variable naming scheme, which has to be exactly correct. If flash attention is working for Falcon-11B, would this mean I don't have to state the PartitionSpec set-up explicitly again because it's already defined through the EasyDeL library support?
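For readers following the naming-scheme point: partition rules of this kind are typically (regex, spec) pairs where the first regex that matches a parameter's path wins. The sketch below is a minimal, library-free illustration under that assumption. The helper match_rule, the "/"-joined paths, and the kernel-style name are hypothetical and are not taken from EasyDeL; plain tuples stand in for PartitionSpec.

```python
import re

# Rules are (regex, spec) pairs. Specs are plain tuples here instead of
# jax.sharding.PartitionSpec so the sketch runs without JAX installed.
RULES = (
    (r"transformer.h.*.self_attention.query_key_value.weight", ("tp", "fsdp")),
    (r"transformer.h.*.self_attention.dense.weight", ("fsdp", "tp")),
    (r".*", (None,)),  # catch-all: replicate anything not matched above
)


def match_rule(rules, param_name):
    """Return the spec of the first rule whose regex matches param_name."""
    for pattern, spec in rules:
        if re.search(pattern, param_name) is not None:
            return spec
    raise ValueError(f"no partition rule matches {param_name!r}")


# Hypothetical parameter paths, joined with "/" (assumption for illustration).
torch_style = "transformer/h/0/self_attention/query_key_value/weight"
kernel_style = "transformer/h/0/self_attention/query_key_value/kernel"

print(match_rule(RULES, torch_style))   # matched by the first rule
print(match_rule(RULES, kernel_style))  # falls through to the catch-all
```

If a rule is written against a name the parameter tree does not actually use, the parameter silently falls through to the catch-all and ends up replicated, which is one way a rule set can look right and still not shard anything.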

To Reproduce

# MULTI TPU WORKER TRAINING 4x8 TPU'S VM

import os
import time
import json
import numpy as np
import jax
import jax.numpy as jnp
from jax.experimental import enable_x64
from jax.sharding import PartitionSpec
from huggingface_hub import HfApi
from huggingface_hub.hf_api import HfFolder
from datasets import load_dataset, Dataset
from transformers import AutoTokenizer
from flax.core import FrozenDict
from easydel import (
    AutoEasyDeLModelForCausalLM,
    TrainArguments,
    CausalLanguageModelTrainer,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    get_modules_by_type
)
from fjformer import GenerateRNG
import wandb

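# NOTE: enable_x64 is a context manager for 64-bit mode, not bfloat16;
# called bare like this (without `with`) it has no effect.
enable_x64()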

# Constants and configurations
HF_TOKEN = os.environ["HF_TOKEN"]  # read secrets from the environment instead of hard-coding them
WANDB_TOKEN = os.environ["WANDB_API_KEY"]
PRETRAINED_MODEL_NAME_OR_PATH = "ssmits/Falcon2-nano-test"

# Save Hugging Face token and login to Weights & Biases
HfFolder.save_token(HF_TOKEN)
wandb.login(key=WANDB_TOKEN)

# Set environment variable to disable tokenizer parallelism
os.environ["TOKENIZERS_PARALLELISM"] = "false"

rng_g = GenerateRNG()
api = HfApi()

def load_and_convert_model(pretrained_model_name_or_path):
    num_workers = len(jax.devices())
    input_shape = (num_workers, 4096)  

    model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        device=jax.devices('cpu')[0],
        input_shape=input_shape,
        device_map="auto",
        dtype=jnp.bfloat16,
        param_dtype=jnp.bfloat16
    )

    config = model.config

    # Assuming params are already in bfloat16
    model_parameters = FrozenDict({"params": params})

    config.add_basic_configurations(
        attn_mechanism="flash",
        block_b=1,
        block_q=128,
        block_k=128,
        block_k_major=128,
        block_q_major_dkv=128,
        block_k_major_dkv=128,
        block_k_major_dq=128,
        block_k_dkv=128,
        block_q_dkv=128,
        block_k_dq=128,
        block_q_dq=128
    )

    config.attn_dtype = jnp.bfloat16
    config.use_flash_attention = True
    config.parallel_attn = True

    config.freq_max_position_embeddings = config.max_position_embeddings
    config.max_position_embeddings = 4096
    config.c_max_position_embeddings = config.max_position_embeddings

    tokenizer = AutoTokenizer.from_pretrained(
        pretrained_model_name_or_path,
        trust_remote_code=True
    )

    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    config.get_partition_rules = lambda use_pjit: get_partition_rules(config, use_pjit)

    return model, model_parameters, config, tokenizer

def get_partition_rules(config, use_pjit):
    return (
        ("transformer.word_embeddings.weight", PartitionSpec("tp", "fsdp")),
        ("transformer.h.*.input_layernorm.(weight|bias)", PartitionSpec(None)),
        ("transformer.h.*.self_attention.query_key_value.weight", PartitionSpec("tp", "fsdp")),
        ("transformer.h.*.self_attention.dense.weight", PartitionSpec("fsdp", "tp")),
        ("transformer.h.*.mlp.dense_h_to_4h.weight", PartitionSpec("tp", "fsdp")),
        ("transformer.h.*.mlp.dense_4h_to_h.weight", PartitionSpec("fsdp", "tp")),
        ("transformer.ln_f.(weight|bias)", PartitionSpec(None)),
        ("lm_head.weight", PartitionSpec("fsdp", "tp")),
        (".*", PartitionSpec(None)),
    )

def custom_json_encoder(obj):
    if hasattr(obj, '__name__'):
        return obj.__name__
    if hasattr(obj, '__dict__'):
        return {k: custom_json_encoder(v) for k, v in obj.__dict__.items()}
    return str(obj)

def print_flash_attention_config(config):
    flash_attention_params = {k: v for k, v in vars(config).items() 
                              if 'flash' in k.lower() or 'block' in k.lower() or 'attn' in k.lower()}

    print(json.dumps(flash_attention_params, indent=2, default=custom_json_encoder))

testing_size = 1000
testing = False
batch_size = 1000

def preprocess_function(text, tokenizer, max_sequence_length):
    try:
        if not isinstance(text, str):
            print(f"Warning: 'text' is not a string. Text: {text}")
            text = str(text)

        tokens = tokenizer.encode(text, add_special_tokens=False, truncation=True, max_length=max_sequence_length*2)

        packed_input_ids = []
        packed_attention_mask = []

        for i in range(0, len(tokens), max_sequence_length):
            chunk = tokens[i:i+max_sequence_length]
            chunk_length = len(chunk)

            if chunk_length < max_sequence_length:
                chunk += [tokenizer.pad_token_id] * (max_sequence_length - chunk_length)

            packed_input_ids.append(chunk)
            packed_attention_mask.append([1] * chunk_length + [0] * (max_sequence_length - chunk_length))

        return packed_input_ids, packed_attention_mask
    except Exception as e:
        print(f"Error in preprocess_function: {str(e)}")
        print(f"Text: {text[:100]}...")  # Print first 100 characters of the text
        return [], []

def preprocess_dataset(tokenizer, max_sequence_length, testing=testing):
    try:
        dataset = load_dataset('occiglot/occiglot-fineweb-v0.5', data_dir='nl', split='train', streaming=True, verification_mode="no_checks", token=HF_TOKEN)

        all_input_ids = []
        all_attention_masks = []

        total_processed = 0
        start_time = time.time()

        for batch in dataset.iter(batch_size=batch_size):
            for i in range(0, len(batch['text']), 3):  # Process in groups of 3 (text, id, metadata)
                text = batch['text'][i]
                packed_ids, packed_mask = preprocess_function(text, tokenizer, max_sequence_length)
                all_input_ids.extend(packed_ids)
                all_attention_masks.extend(packed_mask)

            total_processed += len(batch['text']) // 3
            elapsed_time = time.time() - start_time
            print(f"Processed {total_processed} examples in {elapsed_time:.2f} seconds. Current dataset size: {len(all_input_ids)}")

            if testing and len(all_input_ids) >= testing_size:
                break

        # Trim to exactly testing_size if testing
        if testing:
            all_input_ids = all_input_ids[:testing_size]
            all_attention_masks = all_attention_masks[:testing_size]

        # Convert to numpy arrays
        input_ids = np.array(all_input_ids)
        attention_mask = np.array(all_attention_masks)

        print(f"Input IDs shape: {input_ids.shape}, dtype: {input_ids.dtype}")
        print(f"Attention Mask shape: {attention_mask.shape}, dtype: {attention_mask.dtype}")

        # Create a Dataset object
        return Dataset.from_dict({
            'input_ids': input_ids,
            'attention_mask': attention_mask
        })

    except Exception as e:
        print(f"An error occurred while processing the dataset: {str(e)}")
        return None

def patched_log_metrics(self, metrics, step):
    wandb_metrics = {}
    for key, value in metrics.items():
        if isinstance(value, (list, tuple, np.ndarray, jnp.ndarray)):
            wandb_metrics[key] = wandb.Histogram(np.array(value))
        else:
            wandb_metrics[key] = value
    wandb.log(wandb_metrics, step=step)

# Apply the monkey patch
TrainArguments.log_metrics = patched_log_metrics

def create_train_args(learning_rate, num_workers, max_sequence_length, config):
    return TrainArguments(
        model_class=get_modules_by_type(config.model_type)[1],
        configs_to_initialize_model_class={
            'config': config,
            'dtype': jnp.bfloat16,
            'param_dtype': jnp.bfloat16,
            'input_shape': (num_workers, max_sequence_length),
        },
        custom_rule=config.get_partition_rules(True),
        model_name="FlashAttentionTest",
        num_train_epochs=1,
        warmup_steps=100,
        optimizer=EasyDeLOptimizers.ADAMW,
        scheduler=EasyDeLSchedulers.LINEAR,
        weight_decay=0.05,
        max_sequence_length=max_sequence_length,
        total_batch_size=32,
        gradient_accumulation_steps=1,
        init_input_shape=(num_workers, max_sequence_length),
        sharding_array=(1, jax.device_count(), 1, 1),
        gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
        dtype=jnp.bfloat16,
        param_dtype=jnp.bfloat16,
        step_start_point=0,
        backend="tpu",
        learning_rate=learning_rate,
        learning_rate_end=learning_rate / 10,
    )

def upload_model(output):
    try:
        api.create_repo('Falcon2-5.5B-Dutch-pretrained-test', token=HF_TOKEN, private=True, exist_ok=True)
        api.upload_file(
            path_or_fileobj=output.checkpoint_path,
            repo_id="ssmits/Falcon2-5.5B-Dutch-pretrained-test",
            repo_type='model',
            path_in_repo=output.last_save_file_name,
            token=HF_TOKEN
        )
    except Exception as e:
        print(f"Failed to upload model: {e}")

def main():
    # Initialize model and tokenizer
    model, model_parameters, config, tokenizer = load_and_convert_model(PRETRAINED_MODEL_NAME_OR_PATH)
    print("Model loaded and converted successfully")
    print_flash_attention_config(config)

    # Process full dataset
    print("Processing full dataset...")
    full_dataset = preprocess_dataset(tokenizer, 4096, testing=False)
    if full_dataset is None:
        print("Failed to process the full dataset. Exiting.")
        return
    print("Dataset processed successfully")

    # Set fixed learning rate
    learning_rate = 1e-5

    # Train model
    print("Creating training arguments...")
    train_args = create_train_args(learning_rate, len(jax.devices()), 4096, config)
    train_args.training_time = "7H"
    print("Training arguments created")

    print("Initializing trainer...")
    trainer = CausalLanguageModelTrainer(train_args, full_dataset, checkpoint_path=None)
    print("Trainer initialized")

    print("Starting training...")
    output = trainer.train(model_parameters=model_parameters, state=None)

    # Upload model if this is the main process
    if jax.process_index() == 0:
        upload_model(output)

    print("Training completed.")

if __name__ == "__main__":
    # Ensure only one worker initializes wandb
    if jax.process_index() == 0:
        wandb.init(project="TPU_TEST")
    main()

erfanzar commented 1 month ago

Hi, you cannot use the flash attention mechanism with sequence-sharding strategies; it will crash. Make sure that you are using FSDP sharding instead of SP.
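To make the FSDP-versus-SP distinction concrete, here is a small sketch of how a four-entry sharding tuple with one -1 entry could resolve against a device count. It assumes the axis order (dp, fsdp, tp, sp), which this thread does not state explicitly. The helper resolve_axis_dims and the device count of 16 are hypothetical and for illustration only.

```python
import math

# Assumed axis order; not confirmed anywhere in this thread.
AXIS_NAMES = ("dp", "fsdp", "tp", "sp")


def resolve_axis_dims(axis_dims, device_count):
    """Fill a single -1 entry so that the product of all axes equals device_count."""
    dims = list(axis_dims)
    if dims.count(-1) > 1:
        raise ValueError("at most one axis may be -1")
    known = math.prod(d for d in dims if d != -1)
    if -1 in dims:
        if device_count % known != 0:
            raise ValueError(f"{device_count} devices not divisible by {known}")
        dims[dims.index(-1)] = device_count // known
    elif known != device_count:
        raise ValueError(f"axes multiply to {known}, expected {device_count}")
    return dict(zip(AXIS_NAMES, dims))


# Hypothetical device count, for illustration only.
fsdp_layout = resolve_axis_dims((1, -1, 1, 1), device_count=16)
sp_layout = resolve_axis_dims((1, 1, 1, -1), device_count=16)

print(fsdp_layout)  # all devices on the fsdp axis
print(sp_layout)    # all devices on the sp axis (sequence sharding)
```

Under that axis-order assumption, (1, -1, 1, 1) places every device on the fsdp axis, while a tuple that puts devices on the last axis would be the sequence-sharding layout that the comment above says does not work with flash attention.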

> would this mean I don't have to state the PartitionSpec set-up explicitly again because it's already defined through the EasyDeL library support?

Actually, yes. But you can also change this behavior by creating an ed.PartitionAxis and passing it to the model you're loading, or by adding it to the module config.

At the moment I'm trying to figure out some bugs in the NNX version of the project; I'll try to run your code today or tomorrow.

s-smits commented 1 month ago

Thank you for your quick reply. I've simplified my code a bit:

import os
import sys
import logging

# Configure logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

logging.info("Script started")

# Set necessary environment variables
os.environ['JAX_PLATFORMS'] = ''
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'
os.environ['HF_HUB_ENABLE_HF_TRANSFER'] = '1'

logging.info("Environment variables set")

try:
    import jax
    logging.info(f"JAX imported, version: {jax.__version__}")

    # Initialize JAX's distributed system
    jax.distributed.initialize()
    logging.info("JAX distributed initialized")

    logging.info(f"Number of devices: {jax.device_count()}")
    logging.info(f"Devices: {jax.devices()}")

    # Rest of your imports
    import jax.numpy as jnp
    import easydel as ed
    from easydel import (
        AutoEasyDeLModelForCausalLM,
        TrainArguments,
        CausalLanguageModelTrainer,
        EasyDeLOptimizers,
        EasyDeLSchedulers,
        EasyDeLGradientCheckPointers,
        get_modules_by_type,
        AttentionMechanisms
    )
    from datasets import load_dataset
    from flax.core import FrozenDict

    logging.info("All modules imported successfully")

    jax.print_environment_info()

    pretrained_model_name_or_path = "ssmits/Falcon2-5.5B-Dutch"
    # pretrained_model_name_or_path = "Qwen/Qwen2-7B"

    max_length = 2048
    sharding_axis_dims = (1, -1, 1, 1)
    partition_axis = ed.PartitionAxis()
    input_shape = (1, max_length)
    attn_mechanism = AttentionMechanisms.sharded_vanilla
    dtype = jnp.bfloat16

    # Load and split the dataset
    #import multiprocessing
    #num_cpus = multiprocessing.cpu_count()
    logging.info("Loading dataset...")
    train_dataset = load_dataset("BramVanroy/occiglot-fineweb-v0.5-nl", split="train", streaming=True)
    logging.info("Dataset loaded successfully")

    logging.info("Loading model...")
    model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
        pretrained_model_name_or_path,
        device=jax.devices()[0],  # USE TPU0
        # device=jax.devices('cpu')[0], #USE CPU0
        input_shape=input_shape,
        # device_map = "auto",
        # auto_shard_params=True,
        sharding_axis_dims=sharding_axis_dims,
        verbose_params=True,
        config_kwargs=dict(
            use_scan_mlp=False,
            attn_mechanism=attn_mechanism,
            partition_axis=partition_axis
        ),
        partition_axis=partition_axis,
        param_dtype=dtype,
    )
    logging.info("Model loaded successfully")

    config = model.config
    config.freq_max_position_embeddings = config.max_position_embeddings
    config.max_position_embeddings = max_length
    config.c_max_position_embeddings = config.max_position_embeddings

    model.config.add_basic_configurations(
        attn_mechanism=attn_mechanism, shard_attention_computation=True,
    )

    # model.config.add_basic_configurations(
    #     attn_mechanism="flash",  # Using Flash Attention here you can simply just set this to normal or ring
    #     block_b=1,
    #     block_q=128,
    #     block_k=128,
    #     block_k_major=128,
    # )

    # First, define the basic parameters we know
    total_samples = 16146000  # Replace with the actual number of samples in your dataset
    total_batch_size = 32
    num_train_epochs = 1

    # Calculate the number of training steps
    steps_per_epoch = total_samples // total_batch_size
    max_training_steps = steps_per_epoch * num_train_epochs

    logging.info("Setting up training arguments...")
    # Now define TrainArguments with the calculated max_steps
    train_args = TrainArguments(
        model_class=get_modules_by_type(model.config.model_type)[1],
        configs_to_initialize_model_class={
            "config": model.config,
            "dtype": dtype,
            "param_dtype": dtype,
            "input_shape": input_shape
        },
        init_input_shape=input_shape,
        dtype=dtype,
        param_dtype=dtype,
        custom_rule=model.config.get_partition_rules(True),
        sharding_array=sharding_axis_dims,
        do_shard_fns=True,
        backend="tpu",
        model_name="Qwen-Tune",
        num_train_epochs=num_train_epochs,
        learning_rate=5e-5,
        learning_rate_end=7e-6,
        warmup_steps=1000,
        max_training_steps=max_training_steps,
        optimizer=EasyDeLOptimizers.ADAMW,
        scheduler=EasyDeLSchedulers.WARM_UP_LINEAR,
        weight_decay=0.1,
        z_loss=0.0001,
        label_smoothing_factor=float(0),
        total_batch_size=total_batch_size,
        save_steps=2000,
        save_total_limit=1,
        do_last_save=True,
        max_sequence_length=max_length,
        gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
        gradient_accumulation_steps=4,
        loss_re_mat="",
        force_batch_and_gradient_accumulation_steps_calculation=False,
        step_start_point=0,
        wandb_entity=None
    )
    logging.info("Training arguments set up successfully")

    logging.info("Creating trainer...")
    # Create the trainer
    trainer = CausalLanguageModelTrainer(
        train_args,
        train_dataset.shuffle(),
        checkpoint_path=None
    )
    logging.info("Trainer created successfully")

    model_parameters = FrozenDict({"params": params})

    logging.info("Starting training...")
    output = trainer.train(
        model_parameters=model_parameters,  # pass this as none in case of resuming from last checkpoint
        state=None
    )
    logging.info("Training completed successfully")

    saved_model_location = f"{str(train_args.get_path())}/{output.last_save_file_name}"
    logging.info(f"Model saved at: {saved_model_location}")

except Exception as e:
    logging.exception(f"An error occurred: {str(e)}")
    sys.exit(1)

logging.info("Script completed successfully")

Converting Model: 0%| | 0/172 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/nfs_share/tpu-training-dutch/train_shard.py", line 58, in <module>
    model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
  File "/home/air/.local/lib/python3.10/site-packages/easydel/modules/auto_easydel_model.py", line 588, in from_pretrained
    return cls._from_torch(
  File "/home/air/.local/lib/python3.10/site-packages/easydel/modules/auto_easydel_model.py", line 770, in _from_torch
    params = trf(
  File "/home/air/.local/lib/python3.10/site-packages/easydel/transform/easydel_transform.py", line 184, in huggingface_to_easydel
    pt2jax(tensor), dtype
  File "/home/air/.local/lib/python3.10/site-packages/easydel/transform/utils.py", line 107, in pt2jax
    return jax.numpy.asarray(x.detach().cpu().numpy())
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3289, in asarray
    return array(a, dtype=dtype, copy=bool(copy), order=order)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 3133, in array
    return jax.device_put(object)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/api.py", line 2471, in device_put
    out_flat = dispatch.device_put_p.bind(
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/core.py", line 416, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/core.py", line 420, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/core.py", line 921, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/dispatch.py", line 496, in _batched_device_put_impl
    shard_arg_results = pxla.shard_args(shard_arg_shardings, shard_arg_xs)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/profiler.py", line 335, in wrapper
    return func(*args, **kwargs)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 119, in shard_args
    return shard_arg_handlers[type(arg)]([arg], shardings)
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 170, in _shard_array
    results.append(batched_device_put(aval, sharding, shards, devices))
  File "/home/air/.local/lib/python3.10/site-packages/jax/_src/interpreters/pxla.py", line 195, in batched_device_put
    return xc.batched_device_put(aval, sharding, xs, list(devices), committed)
jaxlib.xla_extension.XlaRuntimeError: INVALID_ARGUMENT: Cannot copy array to non-addressable device TFRT_CPU_0

I think this is a JAX bug, but I would like to know whether I can prevent it somehow. The run does not progress after wandb loads; there is some CPU activity, but not much. I am launching with echo 'python3 /nfs_share/tpu-training-dutch/train_shard.py' | podrun -iw
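As general JAX background on the "non-addressable device" wording (not a confirmed diagnosis of this traceback): in a multi-process run, jax.devices() lists devices across all processes, while jax.local_devices() lists only those the current process can place arrays on. The sketch below, which runs on a single host, shows how to pick a CPU device that is guaranteed to be addressable. Whether EasyDeL's loader accepts such a device in place of jax.devices('cpu')[0] is an assumption that this thread does not verify.

```python
import jax
import numpy as np

# jax.local_devices() only returns devices owned by this process, so an
# array can always be placed on them with jax.device_put.
local_cpu = jax.local_devices(backend="cpu")[0]

# Every local device reports the index of the process that owns it.
assert local_cpu.process_index == jax.process_index()

x = jax.device_put(np.ones((2, 2), dtype=np.float32), local_cpu)
print(x.shape, local_cpu.platform, local_cpu.process_index)
```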

erfanzar commented 1 month ago

Do this at the very start, before any other imports:

import os
os.environ["EASYDEL_AUTO"]="false"
import jax
jax.print_environment_info()

and check whether that fixes it.
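The ordering matters because environment flags are usually read once, when the library is first imported. A minimal sketch of a guard that makes the ordering explicit follows. It assumes EASYDEL_AUTO is read when easydel is imported, and set_flag_before_import is a hypothetical helper rather than part of any library.

```python
import os
import sys


def set_flag_before_import(name, value, module_name):
    """Set an environment flag, refusing if module_name is already imported."""
    if module_name in sys.modules:
        raise RuntimeError(
            f"{module_name} is already imported; {name} would be set too late"
        )
    os.environ[name] = value


# Must run before the first `import easydel` anywhere in the process.
set_flag_before_import("EASYDEL_AUTO", "false", "easydel")
print(os.environ["EASYDEL_AUTO"])

# import easydel as ed  # the real import would follow here
```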

s-smits commented 1 month ago

After the weekend, thank you!

s-smits commented 1 month ago

Now it works for Qwen with:

import jax
import easydel as ed
import jax.numpy as jnp
from easydel import (
    AutoEasyDeLModelForCausalLM,
    TrainArguments,
    CausalLanguageModelTrainer,
    EasyDeLOptimizers,
    EasyDeLSchedulers,
    EasyDeLGradientCheckPointers,
    get_modules_by_type,
    AttentionMechanisms
)
from flax.core import FrozenDict
import wandb
import numpy as np
from dataset_utils import load_and_process_dataset

jax.print_environment_info()
# os.environ['WANDB_DISABLED'] = 'true'
# wandb.init(project="EasyDeL-Qwen-Tune", entity="safemantic")

#pretrained_model_name_or_path = "ssmits/Falcon2-5.5B-Dutch"
pretrained_model_name_or_path = "Qwen/Qwen2-7B-Instruct"
max_length = 4096
sharding_axis_dims = (1, -1, 1, 1)
partition_axis = ed.PartitionAxis()
input_shape = (1, max_length)
attn_mechanism = AttentionMechanisms.sharded_vanilla
dtype = jnp.bfloat16

# ed.AttentionModule.test_attentions(axis_dims=sharding_axis_dims) # you can test the attention modules to find the best one which works for you

# Use the new function to load and process the dataset
tokenized_dataset = load_and_process_dataset("ssmits/processed-falcon-dutch-dataset", max_length=max_length)

model, params = AutoEasyDeLModelForCausalLM.from_pretrained(
    pretrained_model_name_or_path,
    device=jax.devices('cpu')[0],
    input_shape=input_shape,
    # device_map = "auto",
    # auto_shard_params=True,
    sharding_axis_dims=sharding_axis_dims,
    verbose_params=True,
    config_kwargs=dict(
        use_scan_mlp=False,
        attn_mechanism=attn_mechanism,
        partition_axis=partition_axis
    ),
    partition_axis=partition_axis,
    param_dtype=dtype,
)

config = model.config
config.freq_max_position_embeddings = config.max_position_embeddings
config.max_position_embeddings = max_length
config.c_max_position_embeddings = config.max_position_embeddings

model.config.add_basic_configurations(
    attn_mechanism=attn_mechanism, shard_attention_computation=True,
)

# model.config.add_basic_configurations(
#     attn_mechanism="flash",  # Using Flash Attention here you can simply just set this to normal or ring
#     block_b=1,
#     block_q=128,
#     block_k=128,
#     block_k_major=128,
# )

# Add the monkey patch
def patched_log_metrics(self, metrics, step):
    wandb_metrics = {}
    for key, value in metrics.items():
        if isinstance(value, (list, tuple, np.ndarray, jnp.ndarray)):
            wandb_metrics[key] = wandb.Histogram(np.array(value))
        else:
            wandb_metrics[key] = value
    wandb.log(wandb_metrics, step=step)

# Apply the monkey patch
TrainArguments.log_metrics = patched_log_metrics

train_args = TrainArguments(
    model_class=get_modules_by_type(model.config.model_type)[1],
    configs_to_initialize_model_class={
        "config": model.config,
        "dtype": dtype,
        "param_dtype": dtype,
        "input_shape": input_shape
    },
    init_input_shape=input_shape,
    dtype=dtype,
    param_dtype=dtype,
    custom_rule=model.config.get_partition_rules(True),
    sharding_array=sharding_axis_dims,
    do_shard_fns=True,
    backend="tpu",
    model_name="Falcon-Tune",
    num_train_epochs=1,
    learning_rate=5e-5,
    learning_rate_end=7e-6,
    warmup_steps=2000,
    optimizer=EasyDeLOptimizers.ADAMW,
    scheduler=EasyDeLSchedulers.WARM_UP_LINEAR,
    weight_decay=0.1,
    z_loss=0.0001,
    label_smoothing_factor=float(0),
    total_batch_size=8,
    save_steps=2000,
    max_training_steps=100000,
    save_total_limit=1,
    do_last_save=True,
    max_sequence_length=max_length,
    gradient_checkpointing=EasyDeLGradientCheckPointers.NOTHING_SAVEABLE,
    gradient_accumulation_steps=4,
    loss_re_mat="",
    force_batch_and_gradient_accumulation_steps_calculation=False,
    step_start_point=0,
    # training_time="180Min",  # Optional training time limit (180 minutes here); set to None to disable
    wandb_entity="safemantic",
    # Read docs for more and better understanding of options
)

# Print 5 most important training arguments
print(f"1. Learning Rate: {train_args.learning_rate}")
print(f"2. Number of Training Epochs: {train_args.num_train_epochs}")
print(f"3. Total Batch Size: {train_args.total_batch_size}")
print(f"4. Max Training Steps: {train_args.max_training_steps}")
print(f"5. Gradient Accumulation Steps: {train_args.gradient_accumulation_steps}")

trainer = CausalLanguageModelTrainer(
    train_args,
    tokenized_dataset.shuffle().shuffle(),
    checkpoint_path=None  # In Case of resuming from a checkpoint you can pass checkpoint path here and simply just
    # don't create and run model and params steps above.
)

model_parameters = FrozenDict({"params": params})

output = trainer.train(
    model_parameters=model_parameters,  # pass this as none in case of resuming from last checkpoint
    state=None
)

saved_model_location = f"{str(train_args.get_path())}/{output.last_save_file_name}"

print("Hey im Here in case you want to load me :", saved_model_location)

However, for Falcon it still does not work, even with a batch size of 4 or 1. My suspicion is that Falcon2 is not correctly integrated in EasyDeL, or that something goes wrong with sharded_vanilla. I'll try Falcon-7B and even Falcon-40B to see whether the newer architecture of Falcon-11B could be the problem.
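The XLA report below lists repeated 2.00G temporaries of shape f32[32,4096,4096], which looks like a full attention-score matrix kept per layer. As a quick sanity check of that size, here is the arithmetic as a small sketch. The helper name is hypothetical, and reading the leading 32 as the number of (batch x head) rows on one device is an assumption.

```python
GIB = 1024 ** 3


def attention_scores_bytes(rows, q_len, k_len, bytes_per_element):
    """Bytes needed to materialize a rows x q_len x k_len score tensor."""
    return rows * q_len * k_len * bytes_per_element


# f32[32,4096,4096]: 4 bytes per element, sequence length 4096.
at_4096 = attention_scores_bytes(32, 4096, 4096, 4)
# Same tensor if the sequence length were halved to 2048.
at_2048 = attention_scores_bytes(32, 2048, 2048, 4)

print(at_4096 / GIB)  # 2.0
print(at_2048 / GIB)  # 0.5
```

Because the score tensor grows with the square of the sequence length, halving max_length cuts each of these buffers to a quarter, and attention variants that never materialize the full matrix avoid the allocation entirely.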

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/home/air/workspace/train.py", line 141, in <module>
    output = trainer.train(
  File "/home/air/workspace/EasyDel/src/easydel/trainers/causal_language_model_trainer/causal_language_model_trainer.py", line 826, in train
    ) = self.sharded_train_step_function(sharded_state, batch)
jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: XLA:TPU compile permanent error. Ran out of memory in memory space hbm. Used 63.80G of 30.75G hbm. Exceeded hbm capacity by 33.05G.

Total hbm usage >= 65.05G:
    reserved          1.25G
    program          63.80G
    arguments            0B

Output size 0B; shares 0B with arguments.

Program hbm requirement 63.80G:
    global            1.14M
    scoped            4.56M
    HLO temp         63.79G (100.0% utilization: Unpadded (63.58G) Padded (63.58G), 0.3% fragmentation (210.89M))

  Largest program allocations in hbm:

  1. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/0/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1265.remat3 = fusion(custom-call.224, custom-call.241, custom-call.22), kind=kOutput, calls=fused_computation.957.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  2. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/27/self_attention/sub" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1069
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.6346.remat2 = fusion(fusion.1211.remat), kind=kOutput, calls=fused_computation.5587.clone.clone
     Allocation type: HLO temp
     ==========================

  3. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/27/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1211.remat = fusion(custom-call.240, all-reduce.366, custom-call.38), kind=kOutput, calls=fused_computation.903.clone
     Allocation type: HLO temp
     ==========================

  4. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/25/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1215.remat3 = fusion(custom-call.238, custom-call.255, custom-call.36), kind=kOutput, calls=fused_computation.907.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  5. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/24/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1217.remat3 = fusion(custom-call.237, custom-call.254, custom-call.35), kind=kOutput, calls=fused_computation.909.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  6. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/23/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1219.remat3 = fusion(custom-call.236, custom-call.253, custom-call.34), kind=kOutput, calls=fused_computation.911.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  7. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/22/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1221.remat3 = fusion(custom-call.235, custom-call.252, custom-call.33), kind=kOutput, calls=fused_computation.913.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  8. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/21/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1223.remat3 = fusion(custom-call.234, custom-call.251, custom-call.32), kind=kOutput, calls=fused_computation.915.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  9. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/20/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1225.remat3 = fusion(custom-call.233, custom-call.250, custom-call.31), kind=kOutput, calls=fused_computation.917.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  10. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/19/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1227.remat3 = fusion(custom-call.232, custom-call.249, custom-call.30), kind=kOutput, calls=fused_computation.919.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  11. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/18/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1229.remat3 = fusion(custom-call.231, custom-call.248, custom-call.29), kind=kOutput, calls=fused_computation.921.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  12. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/17/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1231.remat3 = fusion(custom-call.230, custom-call.247, custom-call.28), kind=kOutput, calls=fused_computation.923.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  13. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/16/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1233.remat3 = fusion(custom-call.229, custom-call.246, custom-call.27), kind=kOutput, calls=fused_computation.925.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  14. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/14/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1237.remat3 = fusion(custom-call.227, custom-call.245, gte.remat.76), kind=kOutput, calls=fused_computation.929.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  15. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/13/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1239.remat3 = fusion(copy-done.108, copy-done.125, custom-call.25), kind=kOutput, calls=fused_computation.931.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  16. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/12/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1241.remat3 = fusion(copy-done.106, all-reduce.171, custom-call.25), kind=kOutput, calls=fused_computation.933.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  17. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/11/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1243.remat3 = fusion(get-tuple-element.3954, custom-call.244, custom-call.25), kind=kOutput, calls=fused_computation.935.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  18. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/10/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1245.remat3 = fusion(get-tuple-element.3951, copy-done.121, custom-call.25), kind=kOutput, calls=fused_computation.937.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  19. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/9/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1247.remat3 = fusion(copy-done.105, copy-done.120, custom-call.25), kind=kOutput, calls=fused_computation.939.clone.clone.clone
     Allocation type: HLO temp
     ==========================

  20. Size: 2.00G
     Operator: op_name="jit(casual_language_model_train_step)/jit(main)/jvp(FlaxFalconForCausalLMModule)/transformer/h/8/self_attention/...qhd,...khd->...hqk/dot_general[dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2))) precision=(Precision.DEFAULT, Precision.DEFAULT) preferred_element_type=bfloat16]" source_file="/home/air/workspace/EasyDel/src/easydel/modules/attention_module.py" source_line=1059
     Shape: f32[32,4096,4096]{1,2,0:T(8,128)}
     Unpadded size: 2.00G
     XLA label: fusion.1249.remat3 = fusion(copy-done.104, copy-done.119, custom-call.25), kind=kOutput, calls=fused_computation.941.clone.clone.clone
     Allocation type: HLO temp
     ==========================
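
The 2.00G figure repeated in each entry above follows directly from the reported shape. A minimal sketch of the arithmetic, assuming the three dimensions of `f32[32,4096,4096]` are (heads, query length, key length) as the einsum `...qhd,...khd->...hqk` suggests; the helper name `tensor_bytes` is made up for illustration and is not part of EasyDeL or JAX:

```python
def tensor_bytes(shape, bytes_per_element):
    """Return the unpadded size in bytes of a dense tensor with the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_element


# Shape taken from the dump: f32[32,4096,4096]; float32 uses 4 bytes per element.
score_shape = (32, 4096, 4096)
size = tensor_bytes(score_shape, 4)

print(size)          # 2147483648
print(size / 2**30)  # 2.0 (GiB), matching "Size: 2.00G"
```

Each layer listed in the dump holds one such buffer, and the size grows quadratically with the context length. The dump also shows these attention score matrices being fully materialized at `attention_module.py` line 1059.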
s-smits commented 1 month ago

Falcon-7B also does not work, but the focus for me lies on 11B. Could you take a look at it?

erfanzar commented 1 month ago

Sure, I'm working on that.

s-smits commented 1 month ago

Thank you. Did you manage to find the bug?

erfanzar commented 1 month ago

Yes, that's actually fixed, but there are still some other issues from new experimental features. They will all be fixed soon. If you are not on the Discord server: sorry, I forgot to tell you it's fixed. (The PyPI version should work fine.)

erfanzar commented 1 month ago

Your partition specs are wrong. Replace weight with kernel and the dot separator with /.
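
The renaming described above can be sketched as follows. This is only an illustration under assumptions: the helper names `to_flax_param_path` and `first_matching_spec`, the example rules, and the axis names `"fsdp"` and `"tp"` are hypothetical, and plain tuples stand in for `PartitionSpec(...)` objects so the snippet runs without JAX. Note also that in Flax not every PyTorch `weight` becomes `kernel` (embedding tables and layer norms typically use `embedding` and `scale`), so the real rules may need more cases.

```python
import re


def to_flax_param_path(torch_style_name):
    """Turn 'a.b.weight' into 'a/b/kernel' (hypothetical helper)."""
    parts = torch_style_name.split(".")
    if parts[-1] == "weight":
        parts[-1] = "kernel"
    return "/".join(parts)


def first_matching_spec(rules, param_path):
    """Return the spec of the first rule whose regex matches the path, else None."""
    for pattern, spec in rules:
        if re.search(pattern, param_path):
            return spec
    return None


# Illustrative rules only; tuples stand in for PartitionSpec("tp", "fsdp") etc.
rules = (
    ("self_attention/query_key_value/kernel", ("fsdp", "tp")),
    ("self_attention/dense/kernel", ("tp", "fsdp")),
    (".*", (None,)),  # fallback: replicate everything else
)

path = to_flax_param_path("transformer.h.0.self_attention.dense.weight")
print(path)                              # transformer/h/0/self_attention/dense/kernel
print(first_matching_spec(rules, path))  # ('tp', 'fsdp')
```

A rule written against the old name (ending in `weight`) would not match the `kernel` path, so the parameter would fall through to the catch-all rule instead of being sharded.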

s-smits commented 1 month ago

Will these fixes be implemented in the next version of EasyDeL?

erfanzar commented 1 month ago

Actually, they are fixed right now.

erfanzar commented 4 weeks ago

@s-smits, is it working?